---
license: mit
library_name: pytorch
tags:
- latent-classifier
- stable-diffusion
- diffusion
- concept-probing
- classifier-guidance
- SD1.4
pipeline_tag: text-to-image
language:
- en
---

# Airliner Latent Classifier (Stable Diffusion v1.4)

**Latent-space binary classifier** trained on **Stable Diffusion v1.4** VAE latents (shape `4×64×64`) with a simple MLP head and a timestep embedding (from the DDIM scheduler). Intended for **concept probing** and **classifier guidance** in diffusion workflows.

- **Concept:** `airliner`
- **Input:** latent tensor `z ∈ ℝ^{4×64×64}` and a diffusion timestep `t`
- **Output:** logit/probability that `z` contains the concept at timestep `t`
- **Author/Org:** DiffusionConceptErasure
- **Date:** 2025-11-05

## Usage (PyTorch)

```python
import torch
import torch.nn as nn
from diffusers import DDIMScheduler

# ---- model definition (must match training) ----
class FixedTimestepEncoding(nn.Module):
    """Encodes timestep t as (sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t))."""

    def __init__(self, scheduler):
        super().__init__()
        self.register_buffer("alphas_cumprod", scheduler.alphas_cumprod)

    def forward(self, t):
        alpha_bar = self.alphas_cumprod[t]
        return torch.stack([alpha_bar.sqrt(), (1 - alpha_bar).sqrt()], dim=-1)


class LatentClassifierT(nn.Module):
    """MLP over flattened latents, conditioned on the timestep embedding."""

    def __init__(self, latent_shape=(4, 64, 64), scheduler=None):
        super().__init__()
        c, h, w = latent_shape
        flat_dim = c * h * w
        self.t_embed = FixedTimestepEncoding(scheduler)
        self.fc_t = nn.Linear(2, 1024)
        self.fc_x = nn.Linear(flat_dim, 1024)
        self.net = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
        )

    def forward(self, z, t):
        z_flat = z.flatten(start_dim=1)
        return self.net(self.fc_x(z_flat) + self.fc_t(self.t_embed(t)))


# ---- load weights ----
repo_id = "DiffusionConceptErasure/latent-classifier-airliner"
ckpt_name = "airliner.pt"

scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
model = LatentClassifierT(scheduler=scheduler)

state = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{repo_id}/resolve/main/{ckpt_name}",
    map_location="cpu",
)
model.load_state_dict(state["model_state_dict"] if "model_state_dict" in state else state)
model.eval()

# ---- example inference ----
z = torch.randn(1, 4, 64, 64)  # stand-in latent; use real SD v1.4 VAE latents in practice
t = torch.randint(0, scheduler.config.num_train_timesteps, (1,))  # random timestep
with torch.no_grad():
    logit = model(z, t)  # shape [1, 1]
    prob = torch.sigmoid(logit)
print(prob.item())
```
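
The example above classifies random noise as a stand-in. To probe a real image, encode it with the SD v1.4 VAE and diffuse the latent to the desired timestep before classifying. Below is a minimal sketch (not part of the released code) that assumes the `diffusers` VAE for `CompVis/stable-diffusion-v1-4` and reuses `scheduler` and `model` from above; the file name `plane.png` and timestep `300` are placeholders, and `0.18215` is the standard SD v1.x latent scaling factor.

```python
# Sketch: probe a real image. `plane.png`, t=300, and the preprocessing are
# illustrative; `scheduler` and `model` come from the snippet above.
import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
vae.eval()

img = Image.open("plane.png").convert("RGB").resize((512, 512))
x = torch.from_numpy(np.asarray(img).copy()).float()        # [512, 512, 3], 0..255
x = x.permute(2, 0, 1).unsqueeze(0) / 127.5 - 1.0           # [1, 3, 512, 512], in [-1, 1]

with torch.no_grad():
    z0 = vae.encode(x).latent_dist.sample() * 0.18215       # standard SD v1.x latent scale
    t = torch.tensor([300])                                 # moderate noise level (placeholder)
    z_t = scheduler.add_noise(z0, torch.randn_like(z0), t)  # diffuse clean latent to step t
    prob = torch.sigmoid(model(z_t, t))

print(f"P(airliner | z_t) = {prob.item():.3f}")
```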

## Notes

- Trained with DDIM power-law timestep sampling, biased toward noisier (higher-`t`) latents.
- For classifier guidance, averaging logits across a few sampled noisy timesteps can stabilize the estimate; see the sketch below.
- Discriminability is expected to be highest at moderate noise levels; at extreme noise the signal degrades.
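
A minimal guidance sketch, reusing `model` and `scheduler` from the Usage section. The timestep choices, the direct latent nudge, and `guidance_scale` are illustrative assumptions; in a full sampler you would fold the gradient of `log p(concept | z, t)` into the noise prediction rather than nudging `z` directly.

```python
# Illustrative sketch, not the training code: average the gradient of
# log p(concept | z, t) over a few timesteps and nudge the latent.
import torch
import torch.nn.functional as F

def concept_grad(z, timesteps):
    """Mean gradient of the concept log-probability w.r.t. the latent."""
    z = z.detach().requires_grad_(True)
    log_probs = [F.logsigmoid(model(z, t.view(1))) for t in timesteps]
    log_prob = torch.stack(log_probs).mean()
    (grad,) = torch.autograd.grad(log_prob, z)
    return grad

z = torch.randn(1, 4, 64, 64)       # stand-in latent
ts = torch.tensor([200, 400, 600])  # illustrative noisy timesteps
grad = concept_grad(z, ts)

guidance_scale = 1.0                # illustrative strength
z_guided = z + guidance_scale * grad  # push toward the concept (negate to push away)
```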

## Citation

If you use this model, please cite:

```bibtex
@inproceedings{lu2025concepts,
  title={When Are Concepts Erased From Diffusion Models?},
  author={Kevin Lu and Nicky Kriplani and Rohit Gandikota and Minh Pham and David Bau and Chinmay Hegde and Niv Cohen},
  booktitle={NeurIPS},
  year={2025}
}
```