---
license: mit
library_name: pytorch
tags:
- latent-classifier
- stable-diffusion
- diffusion
- concept-probing
- classifier-guidance
- SD1.4
pipeline_tag: text-to-image
language:
- en
---

# Airliner Latent Classifier (Stable Diffusion v1.4)

A **latent-space binary classifier** trained on **Stable Diffusion v1.4** VAE latents (shape `4×64×64`). The model is a simple MLP head combined with a timestep embedding derived from the DDIM scheduler's noise schedule (`alphas_cumprod`).  
Intended for **concept probing** and **classifier guidance** in diffusion workflows.

- **Concept:** `airliner`
- **Input:** latent tensor `z ∈ ℝ^{4×64×64}` and a diffusion timestep `t` (a sketch of preparing such a latent follows this list)
- **Output:** logit/probability that `z` contains the concept at timestep `t`
- **Author/Org:** DiffusionConceptErasure
- **Date:** 2025-11-05
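
The classifier consumes noisy latents in the SD v1.4 VAE latent space. As a minimal, illustrative sketch (assuming a 512×512 RGB input, the standard `0.18215` latent scaling, and a placeholder file name; this is not necessarily the exact training preprocessing), a noisy latent `z_t` at timestep `t` could be produced like this:

```python
import torch
import numpy as np
from PIL import Image
from diffusers import AutoencoderKL, DDIMScheduler

vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

# Load a 512x512 RGB image and normalize to [-1, 1], as the SD VAE expects
img = Image.open("example.jpg").convert("RGB").resize((512, 512))
x = torch.from_numpy(np.array(img)).float().permute(2, 0, 1) / 127.5 - 1.0
x = x.unsqueeze(0)

with torch.no_grad():
    # Clean latent z_0 of shape [1, 4, 64, 64]; 0.18215 is the SD v1.x latent scale
    z0 = vae.encode(x).latent_dist.sample() * 0.18215

# Noise the latent at a chosen timestep, mirroring the forward diffusion process
t = torch.tensor([400])
z_t = scheduler.add_noise(z0, torch.randn_like(z0), t)  # classifier input at timestep t
```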

## Usage (PyTorch)

```python
import torch
import torch.nn as nn
from diffusers import DDIMScheduler

# ---- model definition (must match training) ----
class FixedTimestepEncoding(nn.Module):
    def __init__(self, scheduler):
        super().__init__()
        self.register_buffer("alphas_cumprod", scheduler.alphas_cumprod)
    def forward(self, t):
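        # Encode t as (sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t)): the signal and
        # noise scales of the forward diffusion process at timestep t.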
        alpha_bar = self.alphas_cumprod[t]
        return torch.stack([alpha_bar.sqrt(), (1 - alpha_bar).sqrt()], dim=-1)

class LatentClassifierT(nn.Module):
    def __init__(self, latent_shape=(4, 64, 64), scheduler=None):
        super().__init__()
        c, h, w = latent_shape
        flat_dim = c * h * w
        self.t_embed = FixedTimestepEncoding(scheduler)
        self.fc_t = nn.Linear(2, 1024)
        self.fc_x = nn.Linear(flat_dim, 1024)
        self.net = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.SiLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1)
        )
    def forward(self, z, t):
        z_flat = z.flatten(start_dim=1)
        return self.net(self.fc_x(z_flat) + self.fc_t(self.t_embed(t)))

# ---- load weights ----
repo_id = "DiffusionConceptErasure/latent-classifier-airliner"
ckpt_name = "airliner.pt"

scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
model = LatentClassifierT(scheduler=scheduler)

state = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{repo_id}/resolve/main/{ckpt_name}",
    map_location="cpu"
)
model.load_state_dict(state["model_state_dict"] if "model_state_dict" in state else state)
model.eval()

# Example inference:
z = torch.randn(1, 4, 64, 64)           # latent
t = torch.randint(0, scheduler.config.num_train_timesteps, (1,))  # timestep
with torch.no_grad():
    logit = model(z, t)                 # shape [1, 1]
    prob = torch.sigmoid(logit)
print(prob.item())
```

## Notes

- Trained with timesteps drawn from a power-law distribution over the DDIM schedule, biased toward noisier latents.
- For classifier guidance or probing, averaging logits over a few noisy `t` samples can stabilize the score (see the sketch below).
- Discriminability is expected to be highest at moderate noise levels; at extreme noise the concept signal largely washes out.
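
The averaging suggestion above, and classifier guidance more generally, can be sketched as follows. The helper names, the timestep range, and the `guidance_scale` value are illustrative assumptions rather than the authors' exact procedure; `model` and `scheduler` are the objects built in the Usage section.

```python
import torch
import torch.nn.functional as F

def avg_concept_prob(model, z0, scheduler, n_samples=4, t_low=200, t_high=600):
    # Average the concept probability over several noisy versions of a clean
    # latent z0, drawing timesteps from a mid-noise range (bounds are illustrative).
    probs = []
    for _ in range(n_samples):
        t = torch.randint(t_low, t_high, (z0.shape[0],))
        z_t = scheduler.add_noise(z0, torch.randn_like(z0), t)
        with torch.no_grad():
            probs.append(torch.sigmoid(model(z_t, t)))
    return torch.stack(probs).mean(dim=0)

def airliner_guidance_grad(model, z_t, t, guidance_scale=5.0):
    # Gradient of log p(airliner | z_t, t) with respect to the latent, for
    # classifier guidance: add it to steer sampling toward the concept,
    # or subtract it to steer away.
    z = z_t.detach().requires_grad_(True)
    log_prob = F.logsigmoid(model(z, t)).sum()
    grad = torch.autograd.grad(log_prob, z)[0]
    return guidance_scale * grad
```

In a DDIM sampling loop, the gradient would typically be applied to the current latent (or folded into the predicted noise) at each denoising step, with the sign controlling whether sampling is pushed toward or away from the airliner concept.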

## Citation

If you use this, please cite:

```bibtex
@inproceedings{lu2025concepts,
  title={When Are Concepts Erased From Diffusion Models?},
  author={Kevin Lu and Nicky Kriplani and Rohit Gandikota and Minh Pham and David Bau and Chinmay Hegde and Niv Cohen},
  booktitle={NeurIPS},
  year={2025}
}
```