---
license: mit
library_name: pytorch
tags:
- latent-classifier
- stable-diffusion
- diffusion
- concept-probing
- classifier-guidance
- SD1.4
pipeline_tag: text-to-image
language:
- en
---

# Church Latent Classifier (Stable Diffusion v1.4)

A **latent-space binary classifier** trained on **Stable Diffusion v1.4** VAE latents (shape `4×64×64`), built from a simple MLP head plus a timestep embedding derived from the DDIM scheduler.
It is intended for **concept probing** and **classifier guidance** in diffusion workflows.

- **Concept:** `church`
- **Input:** latent tensor `z ∈ ℝ^{4×64×64}` and a diffusion timestep `t`
- **Output:** logit/probability that `z` contains the concept at timestep `t`
- **Author/Org:** DiffusionConceptErasure
- **Date:** 2025-11-05

## Usage (PyTorch)

```python
import torch
import torch.nn as nn
from diffusers import DDIMScheduler

# ---- model definition (must match training) ----
class FixedTimestepEncoding(nn.Module):
    """Encodes timestep t as (sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t))."""
    def __init__(self, scheduler):
        super().__init__()
        self.register_buffer("alphas_cumprod", scheduler.alphas_cumprod)

    def forward(self, t):
        alpha_bar = self.alphas_cumprod[t]
        return torch.stack([alpha_bar.sqrt(), (1 - alpha_bar).sqrt()], dim=-1)

class LatentClassifierT(nn.Module):
    def __init__(self, latent_shape=(4, 64, 64), scheduler=None):
        super().__init__()
        c, h, w = latent_shape
        flat_dim = c * h * w
        self.t_embed = FixedTimestepEncoding(scheduler)
        self.fc_t = nn.Linear(2, 1024)
        self.fc_x = nn.Linear(flat_dim, 1024)
        self.net = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
        )

    def forward(self, z, t):
        z_flat = z.flatten(start_dim=1)
        return self.net(self.fc_x(z_flat) + self.fc_t(self.t_embed(t)))

# ---- load weights ----
repo_id = "DiffusionConceptErasure/latent-classifier-church"
ckpt_name = "church.pt"

scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
model = LatentClassifierT(scheduler=scheduler)

state = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{repo_id}/resolve/main/{ckpt_name}",
    map_location="cpu"
)
model.load_state_dict(state["model_state_dict"] if "model_state_dict" in state else state)
model.eval()

# Example inference:
z = torch.randn(1, 4, 64, 64)                                      # latent
t = torch.randint(0, scheduler.config.num_train_timesteps, (1,))   # timestep
with torch.no_grad():
    logit = model(z, t)   # shape [1, 1]
    prob = torch.sigmoid(logit)
print(prob.item())
```

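For real inputs, `z` comes from the SD v1.4 VAE rather than `torch.randn`. A minimal sketch, continuing from the snippet above (`model` and `scheduler` reused) and assuming a `[1, 3, 512, 512]` image tensor in `[-1, 1]` and the standard SD latent scaling factor of `0.18215`:

```python
from diffusers import AutoencoderKL

# SD v1.4 VAE; replace the placeholder with a real preprocessed image tensor.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
image_tensor = torch.rand(1, 3, 512, 512) * 2 - 1  # placeholder image in [-1, 1]

with torch.no_grad():
    # Clean latent z_0, shape [1, 4, 64, 64]
    z0 = vae.encode(image_tensor).latent_dist.sample() * 0.18215

# Noise the clean latent to timestep t with the same DDIM scheduler
t = torch.tensor([400])
noise = torch.randn_like(z0)
z_t = scheduler.add_noise(z0, noise, t)

with torch.no_grad():
    prob = torch.sigmoid(model(z_t, t))
print(f"P(church | z_t, t={t.item()}): {prob.item():.3f}")
```
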
## Notes

- Trained with DDIM power-law timestep sampling, biased toward noisier latents (larger `t`).
- For classifier guidance, averaging logits across a few noisy timestep samples can stabilize the signal; see the sketch below.
- Discriminability is expected to be highest at moderate noise levels; at extreme noise the signal degrades.

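A minimal sketch of the averaging idea above, using the gradient of the averaged logit with respect to the latent as a guidance direction. The timestep jitter window, number of samples, and guidance scale are illustrative assumptions, not values from training:

```python
def church_guidance_grad(model, scheduler, z, t, num_t_samples=4, window=50):
    """Average classifier logits over a few timesteps near t and return d(logit)/dz."""
    z = z.detach().requires_grad_(True)
    t_max = scheduler.config.num_train_timesteps
    logits = []
    for _ in range(num_t_samples):
        # Jitter t within a small window, clamped to the valid timestep range
        t_i = (t + torch.randint(-window, window + 1, t.shape)).clamp(0, t_max - 1)
        logits.append(model(z, t_i))
    mean_logit = torch.stack(logits).mean()
    mean_logit.backward()
    return z.grad

# Example: nudge a latent toward (or away from) the concept
z = torch.randn(1, 4, 64, 64)
t = torch.tensor([500])
grad = church_guidance_grad(model, scheduler, z, t)
guidance_scale = 5.0                   # illustrative value
z_guided = z + guidance_scale * grad   # subtract instead to steer away from "church"
```
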
## Citation

If you use this model, please cite:

```bibtex
@inproceedings{lu2025concepts,
  title={When Are Concepts Erased From Diffusion Models?},
  author={Kevin Lu and Nicky Kriplani and Rohit Gandikota and Minh Pham and David Bau and Chinmay Hegde and Niv Cohen},
  booktitle={NeurIPS},
  year={2025}
}
```