Upload SDG pipeline + classifier weights
Browse files - safe_diffusion_guidance.py (+20 -54)
safe_diffusion_guidance.py
CHANGED
|
@@ -12,44 +12,25 @@ from typing import Optional
|
|
| 12 |
|
| 13 |
CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
|
| 14 |
|
| 15 |
-
class
|
| 16 |
-
"""
|
| 17 |
-
Same CNN topology you trained (keys start with 'model.*').
|
| 18 |
-
Input (B,1280,H,W) -> AdaptiveAvgPool2d(8,8) -> conv stack -> head
|
| 19 |
-
"""
|
| 20 |
def __init__(self, num_classes: int = 5):
|
| 21 |
super().__init__()
|
| 22 |
self.pre = nn.AdaptiveAvgPool2d((8, 8))
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
nn.
|
| 26 |
-
nn.
|
| 27 |
-
nn.
|
| 28 |
-
|
| 29 |
-
nn.
|
| 30 |
-
nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2), # (256,2,2)
|
| 31 |
-
nn.Dropout2d(0.1),
|
| 32 |
-
|
| 33 |
-
nn.AdaptiveAvgPool2d(1), # -> (256,1,1)
|
| 34 |
-
nn.Flatten(), # -> (256,)
|
| 35 |
-
nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.5),
|
| 36 |
nn.Linear(128, num_classes)
|
| 37 |
)
|
| 38 |
-
self.apply(self.
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
elif isinstance(m, nn.Conv2d):
|
| 45 |
-
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 46 |
-
if m.bias is not None: nn.init.zeros_(m.bias)
|
| 47 |
-
elif isinstance(m, nn.BatchNorm2d):
|
| 48 |
-
nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
|
| 49 |
-
|
| 50 |
-
def forward(self, x):
|
| 51 |
-
x = self.pre(x) # (B,1280,8,8)
|
| 52 |
-
return self.model(x)
|
| 53 |
|
| 54 |
def _find_weights_path() -> str:
|
| 55 |
# 1) explicit env; 2) repo root file; 3) classifiers/ subdir
|
|
@@ -64,27 +45,12 @@ def _find_weights_path() -> str:
|
|
| 64 |
"or pass `classifier_weights=...` to the pipeline call."
|
| 65 |
)
|
| 66 |
|
| 67 |
-
def load_classifier_1280(
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
ckpt = torch.load(path, map_location="cpu", weights_only=False)
|
| 74 |
-
|
| 75 |
-
# Extract actual state dict
|
| 76 |
-
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
| 77 |
-
state = ckpt["model_state_dict"]
|
| 78 |
-
elif isinstance(ckpt, dict) and any(k.startswith("model.") for k in ckpt.keys()):
|
| 79 |
-
state = ckpt
|
| 80 |
-
else:
|
| 81 |
-
# Fallback: allow whole-object saves (only if trusted)
|
| 82 |
-
state = ckpt
|
| 83 |
-
|
| 84 |
-
model = AdaptiveClassifier1280().to(device=device, dtype=torch.float32) # keep classifier in fp32
|
| 85 |
-
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 86 |
-
if missing or unexpected:
|
| 87 |
-
print(f"[SDG] load_state_dict: missing={missing[:4]}... ({len(missing)}), unexpected={unexpected[:4]}... ({len(unexpected)})")
|
| 88 |
model.eval()
|
| 89 |
return model
|
| 90 |
|
|
|
|
| 12 |
|
| 13 |
CLASS_NAMES = ['gore', 'hate', 'medical', 'safe', 'sexual']
|
| 14 |
|
| 15 |
+
class SafetyClassifier1280(nn.Module):
    """CNN safety classifier over 1280-channel feature maps.

    The input ``(B, 1280, H, W)`` is first pooled to a fixed 8x8 grid by
    ``self.pre`` and then passed through ``self.model``: two
    conv/BatchNorm/ReLU/max-pool stages, a global average pool, and a
    two-layer MLP head that emits ``num_classes`` logits.

    The layer stack is stored under the attribute name ``model`` so that
    parameter keys are ``model.<index>.*`` and match the checkpoint. Do not
    rename the attribute or insert/reorder layers, or ``load_state_dict``
    with ``strict=True`` will fail.

    Args:
        num_classes: Number of output logits (defaults to 5).
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        # Normalise any spatial size to 8x8 before the conv stack.
        self.pre = nn.AdaptiveAvgPool2d((8, 8))
        self.model = nn.Sequential(  # attribute name must stay "model"
            nn.Conv2d(1280, 512, 3, padding=1),
            nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.MaxPool2d(2),  # 8x8 -> 4x4
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2),  # 4x4 -> 2x2
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),                        # -> (B, 256)
            nn.Linear(256, 128), nn.ReLU(inplace=True), nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(m: nn.Module) -> None:
        """Initialise one sub-module; invoked for every module by ``apply``.

        Conv layers get Kaiming-normal weights and zero bias; BatchNorm
        layers get unit scale and zero shift. Linear layers keep PyTorch's
        default initialisation. These values are only a starting point and
        are overwritten when a checkpoint is loaded.
        """
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(B, num_classes)`` for ``x``."""
        x = self.pre(x)
        return self.model(x)  # forward through "model" to match checkpoint keys
|
| 33 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
def _find_weights_path() -> str:
|
| 36 |
# 1) explicit env; 2) repo root file; 3) classifiers/ subdir
|
|
|
|
| 45 |
"or pass `classifier_weights=...` to the pipeline call."
|
| 46 |
)
|
| 47 |
|
| 48 |
+
def load_classifier_1280(weights_path: str, device=None, dtype=torch.float32):
    """Build a ``SafetyClassifier1280`` and load its weights from disk.

    Accepted checkpoint layouts:
      * a plain state dict (keys like ``model.0.weight``);
      * a training checkpoint dict holding the state dict under
        ``"model_state_dict"``;
      * either of the above saved from a DataParallel-style wrapper, where
        every key carries a leading ``module.`` prefix.

    Args:
        weights_path: Path to the checkpoint file.
        device: Target device for the model; ``None`` means CPU.
        dtype: Parameter dtype for the model (defaults to float32).

    Returns:
        The classifier on ``device`` with weights loaded, in eval mode.

    Raises:
        TypeError: If the file does not contain a state-dict-like mapping.
        RuntimeError: If the keys/shapes do not match the model exactly
            (raised by ``load_state_dict(strict=True)``).
    """
    model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)

    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load checkpoints from a
    # trusted source; switch to weights_only=True once the checkpoint is
    # confirmed to be a plain tensor state dict.
    state = torch.load(weights_path, map_location="cpu", weights_only=False)

    # Training checkpoints wrap the weights together with other bookkeeping.
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]

    # A whole pickled model (or anything else) is not loadable here; fail
    # with a message that names the file instead of an opaque error.
    if not hasattr(state, "keys"):
        raise TypeError(
            f"Checkpoint at {weights_path!r} is not a state dict "
            f"(got {type(state).__name__})."
        )

    # Checkpoints saved from a DataParallel wrapper prefix every key with
    # "module."; strip it so strict loading matches the bare model.
    prefix = "module."
    if state and all(isinstance(k, str) and k.startswith(prefix) for k in state):
        state = {k[len(prefix):]: v for k, v in state.items()}

    model.load_state_dict(state, strict=True)
    model.eval()
    return model
|
| 56 |
|