---
license: mit
library_name: pytorch
tags:
- latent-classifier
- stable-diffusion
- diffusion
- concept-probing
- classifier-guidance
- SD1.4
pipeline_tag: text-to-image
language:
- en
---

# Church Latent Classifier (Stable Diffusion v1.4)

**Latent-space binary classifier** trained on **Stable Diffusion v1.4** VAE latents (shape `4×64×64`), with a simple MLP head and a timestep embedding derived from the DDIM scheduler's `alphas_cumprod` (spelled out below). Intended for **concept probing** and **classifier guidance** in diffusion workflows.

- **Concept:** `church`
- **Input:** latent tensor `z ∈ ℝ^{4×64×64}` and a diffusion timestep `t`
- **Output:** logit/probability that `z` contains the concept at timestep `t`
- **Author/Org:** DiffusionConceptErasure
- **Date:** 2025-11-05
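
The timestep embedding is simply the pair of forward-process coefficients at step `t` (this follows directly from the model code in the usage section; reading it as "how noisy the latent is" is our gloss, not something the repo states):

$$
\mathrm{embed}(t) = \left(\sqrt{\bar\alpha_t},\ \sqrt{1-\bar\alpha_t}\right), \qquad q(z_t \mid z_0) = \mathcal{N}\!\left(\sqrt{\bar\alpha_t}\,z_0,\ (1-\bar\alpha_t)\,I\right),
$$

i.e. the same coefficients that mix clean latent and noise in the forward process, so the classifier is told how corrupted its input is.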

## Usage (PyTorch)

```python
import torch
import torch.nn as nn
from diffusers import DDIMScheduler

# ---- model definition (must match training) ----
class FixedTimestepEncoding(nn.Module):
    """Encodes timestep t as (sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t))."""
    def __init__(self, scheduler):
        super().__init__()
        self.register_buffer("alphas_cumprod", scheduler.alphas_cumprod)

    def forward(self, t):
        alpha_bar = self.alphas_cumprod[t]
        return torch.stack([alpha_bar.sqrt(), (1 - alpha_bar).sqrt()], dim=-1)


class LatentClassifierT(nn.Module):
    """MLP over the flattened latent, conditioned on the timestep encoding."""
    def __init__(self, latent_shape=(4, 64, 64), scheduler=None):
        super().__init__()
        c, h, w = latent_shape
        flat_dim = c * h * w
        self.t_embed = FixedTimestepEncoding(scheduler)
        self.fc_t = nn.Linear(2, 1024)
        self.fc_x = nn.Linear(flat_dim, 1024)
        self.net = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
        )

    def forward(self, z, t):
        z_flat = z.flatten(start_dim=1)
        return self.net(self.fc_x(z_flat) + self.fc_t(self.t_embed(t)))


# ---- load weights ----
repo_id = "DiffusionConceptErasure/latent-classifier-church"
ckpt_name = "church.pt"

scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
model = LatentClassifierT(scheduler=scheduler)

state = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{repo_id}/resolve/main/{ckpt_name}",
    map_location="cpu",
)
model.load_state_dict(state.get("model_state_dict", state))
model.eval()

# ---- example inference ----
z = torch.randn(1, 4, 64, 64)                                     # latent
t = torch.randint(0, scheduler.config.num_train_timesteps, (1,))  # timestep
with torch.no_grad():
    logit = model(z, t)  # shape [1, 1]
    prob = torch.sigmoid(logit)
print(prob.item())
```
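
The snippet above classifies a random latent. To probe a real image, you would first encode it with the SD v1.4 VAE and noise it to a timestep. A minimal sketch, continuing from the `model` and `scheduler` above; the file name is a placeholder, and the preprocessing and `0.18215` latent scaling are the standard SD v1.x conventions rather than anything this repo specifies:

```python
import torchvision.transforms as T
from diffusers import AutoencoderKL
from PIL import Image

# Encode a 512x512 RGB image into the SD v1.4 latent space
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
img = Image.open("your_image.png").convert("RGB")  # placeholder path
x = T.Compose([T.Resize(512), T.CenterCrop(512), T.ToTensor()])(img)
x = (x * 2 - 1).unsqueeze(0)  # map [0, 1] -> [-1, 1], add batch dim

with torch.no_grad():
    z0 = vae.encode(x).latent_dist.sample() * 0.18215  # standard SD v1.x scaling

# Noise the clean latent to timestep t, mirroring the forward process
t = torch.tensor([200])
z_t = scheduler.add_noise(z0, torch.randn_like(z0), t)

with torch.no_grad():
    prob = torch.sigmoid(model(z_t, t))
print(f"P(church | z_t, t={t.item()}) = {prob.item():.3f}")
```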

## Notes

- Trained with DDIM power-law timestep sampling biased toward noisier latents.
- For classifier guidance, you can average logits across a few noisy `t` samples; see the sketch below.
- Discriminability is expected to be highest at moderate noise; at extreme noise the signal degrades.
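
A minimal sketch of both ideas, assuming the `model` and `scheduler` from the usage snippet; the guidance scale and timestep offsets are illustrative choices, not values shipped with this repo:

```python
import torch.nn.functional as F

def church_log_prob_grad(z_t, t, guidance_scale=5.0):
    # Gradient of log P(church | z_t, t) w.r.t. the latent; in a sampling
    # loop this would typically be added to the latent update (or folded
    # into the predicted noise) at each denoising step.
    z = z_t.detach().requires_grad_(True)
    log_prob = F.logsigmoid(model(z, t)).sum()
    (grad,) = torch.autograd.grad(log_prob, z)
    return guidance_scale * grad

def averaged_logit(z_t, t, offsets=(-20, 0, 20)):
    # Average logits over a few nearby timesteps to reduce variance.
    t_max = scheduler.config.num_train_timesteps - 1
    ts = [torch.clamp(t + off, 0, t_max) for off in offsets]
    with torch.no_grad():
        return torch.stack([model(z_t, ti) for ti in ts]).mean(dim=0)
```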

## Citation

If you use this model, please cite:

```bibtex
@inproceedings{lu2025concepts,
  title={When Are Concepts Erased From Diffusion Models?},
  author={Kevin Lu and Nicky Kriplani and Rohit Gandikota and Minh Pham and David Bau and Chinmay Hegde and Niv Cohen},
  booktitle={NeurIPS},
  year={2025}
}
```