basimazam committed on
Commit
a2cf62b
·
verified ·
1 Parent(s): 1932ce5

Upload SDG pipeline + classifier weights

Browse files
Files changed (1) hide show
  1. safe_diffusion_guidance.py +61 -66
safe_diffusion_guidance.py CHANGED
@@ -7,93 +7,87 @@ import torch.nn as nn
7
  from diffusers import DiffusionPipeline, StableDiffusionPipeline
8
  from diffusers.utils import BaseOutput
9
 
 
 
10
 
11
- # ----------------------------- Classifier ------------------------------------
12
- CLASS_NAMES = ["gore", "hate", "medical", "safe", "sexual"]
13
 
14
- class SafetyClassifier1280(nn.Module):
15
  """
16
- Safety classifier for mid-UNet features of shape (B, 1280, H, W).
17
- Robust to HxW via AdaptiveAvgPool2d((8,8)) before the head.
18
  """
19
  def __init__(self, num_classes: int = 5):
20
  super().__init__()
21
  self.pre = nn.AdaptiveAvgPool2d((8, 8))
22
- self.net = nn.Sequential(
23
- nn.Conv2d(1280, 512, 3, padding=1),
24
- nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2), # 512x4x4
25
- nn.Conv2d(512, 256, 3, padding=1),
26
- nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2), # 256x2x2
27
- nn.AdaptiveAvgPool2d(1), nn.Flatten(), # 256
28
- nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
 
 
 
 
 
 
29
  nn.Linear(128, num_classes)
30
  )
31
- self.apply(self._init_weights)
32
 
33
  @staticmethod
34
- def _init_weights(m):
35
  if isinstance(m, nn.Linear):
36
- nn.init.xavier_uniform_(m.weight)
37
- if m.bias is not None: nn.init.zeros_(m.bias)
38
  elif isinstance(m, nn.Conv2d):
39
- nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
40
  if m.bias is not None: nn.init.zeros_(m.bias)
41
  elif isinstance(m, nn.BatchNorm2d):
42
  nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
43
 
44
- def forward(self, x: torch.Tensor) -> torch.Tensor:
45
- x = self.pre(x) # (B, 1280, 8, 8)
46
- return self.net(x)
47
-
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def load_classifier_1280(
50
- weights_path: str,
51
- device=None,
52
  dtype: torch.dtype = torch.float32
53
- ) -> SafetyClassifier1280:
54
- """
55
- Robustly load classifier weights saved either as a pure state dict or as a
56
- training checkpoint (with 'model_state_dict'), handling key prefixes and
57
- PyTorch 2.6's weights_only behavior.
58
- """
59
- model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
60
-
61
- # --- load robustly (PyTorch 2.6 & older) ---
62
- try:
63
- state = torch.load(weights_path, map_location="cpu", weights_only=False)
64
- except TypeError:
65
- state = torch.load(weights_path, map_location="cpu")
66
- except Exception:
67
- # Allowlist common numpy globals if needed for old pickles
68
- try:
69
- import numpy as np
70
- from torch.serialization import add_safe_globals
71
- add_safe_globals([np.dtype, np.number])
72
- except Exception:
73
- pass
74
- state = torch.load(weights_path, map_location="cpu", weights_only=False)
75
-
76
- # --- unwrap training checkpoint ---
77
- if isinstance(state, dict) and "model_state_dict" in state:
78
- state = state["model_state_dict"]
79
-
80
- # --- normalize keys: 'module.' -> '', 'model.' -> 'net.' ---
81
- norm_state = {}
82
- for k, v in state.items():
83
- k = k.replace("module.", "")
84
- k = k.replace("model.", "net.")
85
- norm_state[k] = v
86
-
87
- # --- load tolerating tiny diffs (e.g., extra keys from older heads) ---
88
- missing, unexpected = model.load_state_dict(norm_state, strict=False)
89
  if missing or unexpected:
90
- print(f"[SDG] Loaded classifier with relaxed strictness. "
91
- f"Missing: {len(missing)}, Unexpected: {len(unexpected)}")
92
-
93
  model.eval()
94
  return model
95
 
96
-
97
  def _here(*paths: str) -> str:
98
  return os.path.join(os.path.dirname(__file__), *paths)
99
 
@@ -197,8 +191,9 @@ class SafeDiffusionGuidance(DiffusionPipeline):
197
  timesteps = base.scheduler.timesteps
198
 
199
  # 4) classifier (run in fp32)
200
- weights_file = classifier_weights or pick_weights_path()
201
- clf = load_classifier_1280(weights_file, device=device, dtype=torch.float32)
 
202
 
203
  # 5) mid-block hook
204
  mid = {}
@@ -222,7 +217,7 @@ class SafeDiffusionGuidance(DiffusionPipeline):
222
  lcat = torch.cat([lin, lin], dim=0)
223
 
224
  _ = base.unet(lcat, t, encoder_hidden_states=cond_embeds).sample
225
- feat = mid["feat"].detach().float() # (B*2, 1280, H, W)
226
  logits = clf(feat)
227
  probs = torch.softmax(logits, dim=-1)
228
  unsafe = 1.0 - probs[:, safe_class_index].mean() # encourage "safe"
 
7
  from diffusers import DiffusionPipeline, StableDiffusionPipeline
8
  from diffusers.utils import BaseOutput
9
 
10
+ import torch, torch.nn as nn, os
11
+ from typing import Optional
12
 
13
+ CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
 
14
 
15
class AdaptiveClassifier1280(nn.Module):
    """Safety classifier head over mid-UNet features.

    Input is a feature map of shape (B, 1280, H, W); an initial
    AdaptiveAvgPool2d((8, 8)) makes the conv stack robust to any spatial
    size.  The stack is stored under the attribute name ``model`` so that
    checkpoint keys (which start with 'model.*') load without remapping.
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.pre = nn.AdaptiveAvgPool2d((8, 8))
        # Keep the attribute name 'model' to match the checkpoint keys.
        self.model = nn.Sequential(
            nn.Conv2d(1280, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2),  # (512,4,4)
            nn.Dropout2d(0.1),

            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2),  # (256,2,2)
            nn.Dropout2d(0.1),

            nn.AdaptiveAvgPool2d(1),  # -> (256,1,1)
            nn.Flatten(),             # -> (256,)
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
        self.apply(self._init)

    @staticmethod
    def _init(m):
        """Initialize conv/linear/BN weights.

        Fix: the original zeroed ``m.bias`` for nn.Linear unconditionally,
        which raises for bias-free layers; guard it like the Conv2d branch.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Pool to (B, 1280, 8, 8), then run the conv stack to class logits."""
        x = self.pre(x)  # (B,1280,8,8)
        return self.model(x)
53
+
54
+ def _find_weights_path() -> str:
55
+ # 1) explicit env; 2) repo root file; 3) classifiers/ subdir
56
+ env_p = os.getenv("SDG_CLASSIFIER_WEIGHTS")
57
+ if env_p and os.path.exists(env_p): return env_p
58
+ for p in ["safety_classifier_1280.pth", os.path.join("classifiers","safety_classifier_1280.pth")]:
59
+ if os.path.exists(p): return p
60
+ # If running from HF cache, these paths are relative to the cached repo folder.
61
+ raise FileNotFoundError(
62
+ "Safety-classifier weights not found. Provide via env SDG_CLASSIFIER_WEIGHTS, "
63
+ "place 'safety_classifier_1280.pth' at repo root or 'classifiers/', "
64
+ "or pass `classifier_weights=...` to the pipeline call."
65
+ )
66
 
67
def load_classifier_1280(
    weights_path: Optional[str],
    device: torch.device,
    dtype: torch.dtype = torch.float32
) -> AdaptiveClassifier1280:
    """Build an AdaptiveClassifier1280 and load its weights.

    Accepts either a pure state dict or a training checkpoint wrapped as
    ``{"model_state_dict": ...}``.  Loads with ``strict=False`` so minor
    key differences from older heads do not abort.

    Args:
        weights_path: path to the checkpoint; if None, discovered via
            ``_find_weights_path()``.
        device: device to place the classifier on.
        dtype: accepted for interface compatibility but deliberately
            ignored — the classifier is always kept in fp32 (guidance
            features are cast to fp32 before being fed in).

    Returns:
        The classifier in eval mode on ``device``.
    """
    path = weights_path or _find_weights_path()
    # NOTE(security): weights_only=False unpickles arbitrary objects;
    # only load checkpoints from a trusted source.
    ckpt = torch.load(path, map_location="cpu", weights_only=False)

    # Unwrap a training checkpoint; otherwise treat the object as the
    # state dict itself.  (The original's elif/else branches were
    # byte-identical dead code — collapsed into one fallback.)
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    else:
        state = ckpt

    model = AdaptiveClassifier1280().to(device=device, dtype=torch.float32)  # keep classifier in fp32
    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing or unexpected:
        print(f"[SDG] load_state_dict: missing={missing[:4]}... ({len(missing)}), unexpected={unexpected[:4]}... ({len(unexpected)})")
    model.eval()
    return model
90
 
 
91
  def _here(*paths: str) -> str:
92
  return os.path.join(os.path.dirname(__file__), *paths)
93
 
 
191
  timesteps = base.scheduler.timesteps
192
 
193
  # 4) classifier (run in fp32)
194
+ weights = classifier_weights or pick_weights_for_pipe(base)
195
+ clf = load_classifier_1280(weights, device=device, dtype=torch.float32).eval()
196
+
197
 
198
  # 5) mid-block hook
199
  mid = {}
 
217
  lcat = torch.cat([lin, lin], dim=0)
218
 
219
  _ = base.unet(lcat, t, encoder_hidden_states=cond_embeds).sample
220
+ feat = mid["feat"].detach().to(torch.float32)
221
  logits = clf(feat)
222
  probs = torch.softmax(logits, dim=-1)
223
  unsafe = 1.0 - probs[:, safe_class_index].mean() # encourage "safe"