basimazam commited on
Commit
82a327a
·
verified ·
1 Parent(s): c44a3aa

Upload SDG pipeline + classifier weights

Browse files
Files changed (1) hide show
  1. safe_diffusion_guidance.py +20 -54
safe_diffusion_guidance.py CHANGED
@@ -12,44 +12,25 @@ from typing import Optional
12
 
13
  CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
14
 
15
- class AdaptiveClassifier1280(nn.Module):
16
- """
17
- Same CNN topology you trained (keys start with 'model.*').
18
- Input (B,1280,H,W) -> AdaptiveAvgPool2d(8,8) -> conv stack -> head
19
- """
20
  def __init__(self, num_classes: int = 5):
21
  super().__init__()
22
  self.pre = nn.AdaptiveAvgPool2d((8, 8))
23
- # Keep the attribute name 'model' to match the checkpoint keys.
24
- self.model = nn.Sequential(
25
- nn.Conv2d(1280, 512, kernel_size=3, padding=1),
26
- nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2), # (512,4,4)
27
- nn.Dropout2d(0.1),
28
-
29
- nn.Conv2d(512, 256, kernel_size=3, padding=1),
30
- nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2), # (256,2,2)
31
- nn.Dropout2d(0.1),
32
-
33
- nn.AdaptiveAvgPool2d(1), # -> (256,1,1)
34
- nn.Flatten(), # -> (256,)
35
- nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.5),
36
  nn.Linear(128, num_classes)
37
  )
38
- self.apply(self._init)
39
-
40
- @staticmethod
41
- def _init(m):
42
- if isinstance(m, nn.Linear):
43
- nn.init.xavier_uniform_(m.weight); nn.init.zeros_(m.bias)
44
- elif isinstance(m, nn.Conv2d):
45
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
46
- if m.bias is not None: nn.init.zeros_(m.bias)
47
- elif isinstance(m, nn.BatchNorm2d):
48
- nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
49
-
50
- def forward(self, x):
51
- x = self.pre(x) # (B,1280,8,8)
52
- return self.model(x)
53
 
54
  def _find_weights_path() -> str:
55
  # 1) explicit env; 2) repo root file; 3) classifiers/ subdir
@@ -64,27 +45,12 @@ def _find_weights_path() -> str:
64
  "or pass `classifier_weights=...` to the pipeline call."
65
  )
66
 
67
- def load_classifier_1280(
68
- weights_path: Optional[str],
69
- device: torch.device,
70
- dtype: torch.dtype = torch.float32
71
- ) -> AdaptiveClassifier1280:
72
- path = weights_path or _find_weights_path()
73
- ckpt = torch.load(path, map_location="cpu", weights_only=False)
74
-
75
- # Extract actual state dict
76
- if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
77
- state = ckpt["model_state_dict"]
78
- elif isinstance(ckpt, dict) and any(k.startswith("model.") for k in ckpt.keys()):
79
- state = ckpt
80
- else:
81
- # Fallback: allow whole-object saves (only if trusted)
82
- state = ckpt
83
-
84
- model = AdaptiveClassifier1280().to(device=device, dtype=torch.float32) # keep classifier in fp32
85
- missing, unexpected = model.load_state_dict(state, strict=False)
86
- if missing or unexpected:
87
- print(f"[SDG] load_state_dict: missing={missing[:4]}... ({len(missing)}), unexpected={unexpected[:4]}... ({len(unexpected)})")
88
  model.eval()
89
  return model
90
 
 
12
 
13
  CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
14
 
15
+ class SafetyClassifier1280(nn.Module):
 
 
 
 
16
  def __init__(self, num_classes: int = 5):
17
  super().__init__()
18
  self.pre = nn.AdaptiveAvgPool2d((8, 8))
19
+ self.model = nn.Sequential( # <--- use "model" to match checkpoint
20
+ nn.Conv2d(1280, 512, 3, padding=1),
21
+ nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2),
22
+ nn.Conv2d(512, 256, 3, padding=1),
23
+ nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2),
24
+ nn.AdaptiveAvgPool2d(1), nn.Flatten(),
25
+ nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
 
 
 
 
 
 
26
  nn.Linear(128, num_classes)
27
  )
28
+ self.apply(self._init_weights)
29
+
30
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
31
+ x = self.pre(x)
32
+ return self.model(x) # <--- forward through "model"
33
+
 
 
 
 
 
 
 
 
 
34
 
35
  def _find_weights_path() -> str:
36
  # 1) explicit env; 2) repo root file; 3) classifiers/ subdir
 
45
  "or pass `classifier_weights=...` to the pipeline call."
46
  )
47
 
48
+ def load_classifier_1280(weights_path: str, device=None, dtype=torch.float32):
49
+ model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
50
+ state = torch.load(weights_path, map_location="cpu", weights_only=False)
51
+ if isinstance(state, dict) and "model_state_dict" in state:
52
+ state = state["model_state_dict"]
53
+ model.load_state_dict(state, strict=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  model.eval()
55
  return model
56