Upload SDG pipeline + classifier weights
Browse files- safe_diffusion_guidance.py +37 -3
safe_diffusion_guidance.py
CHANGED
|
@@ -48,14 +48,48 @@ class SafetyClassifier1280(nn.Module):
|
|
| 48 |
|
| 49 |
def load_classifier_1280(
|
| 50 |
weights_path: str,
|
| 51 |
-
device
|
| 52 |
dtype: torch.dtype = torch.float32
|
| 53 |
) -> SafetyClassifier1280:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
if isinstance(state, dict) and "model_state_dict" in state:
|
| 57 |
state = state["model_state_dict"]
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
model.eval()
|
| 60 |
return model
|
| 61 |
|
|
|
|
def load_classifier_1280(
    weights_path: str,
    device=None,
    dtype: torch.dtype = torch.float32,
) -> SafetyClassifier1280:
    """Load SafetyClassifier1280 weights from ``weights_path``.

    Accepts either a bare state dict or a full training checkpoint
    (a dict containing ``'model_state_dict'``), tolerates a
    DataParallel-style ``'module.'`` key prefix and an older ``'model.'``
    attribute-naming scheme, and works on PyTorch versions both with and
    without the ``weights_only`` keyword (added before 2.6 made
    ``weights_only=True`` the default).

    Args:
        weights_path: Path to the saved ``.pt``/``.pth`` file.
        device: Target device (anything ``nn.Module.to`` accepts);
            defaults to ``"cpu"`` when ``None``.
        dtype: Parameter dtype for the instantiated model.

    Returns:
        The classifier, moved to ``device``/``dtype`` and set to eval mode.
    """
    model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)

    # PyTorch >= 2.6 defaults torch.load(weights_only=True), which rejects
    # checkpoints containing non-tensor pickled objects, so request the full
    # unpickler explicitly.  Older PyTorch raises TypeError on the unknown
    # kwarg; fall back to the legacy call.
    # NOTE(review): weights_only=False executes arbitrary pickle code — only
    # load checkpoints from trusted sources.
    # (A previous fallback re-ran the identical weights_only=False call after
    # add_safe_globals(); that could only re-raise the same error, since the
    # safe-globals allowlist is consulted only when weights_only=True, so the
    # dead branch was removed.)
    try:
        state = torch.load(weights_path, map_location="cpu", weights_only=False)
    except TypeError:
        state = torch.load(weights_path, map_location="cpu")

    # Unwrap a training checkpoint saved as {'model_state_dict': ..., ...}.
    if isinstance(state, dict) and "model_state_dict" in state:
        state = state["model_state_dict"]

    # Normalize key PREFIXES only.  The previous code used str.replace(),
    # which rewrites every occurrence and would corrupt a key such as
    # 'net.model.weight' into 'net.net.weight'.
    #   'module.<rest>' -> '<rest>'     (DataParallel wrapper)
    #   'model.<rest>'  -> 'net.<rest>' (older attribute name)
    norm_state = {}
    for key, tensor in state.items():
        if key.startswith("module."):
            key = key[len("module."):]
        if key.startswith("model."):
            key = "net." + key[len("model."):]
        norm_state[key] = tensor

    # Tolerate small mismatches (e.g. extra keys from older heads), but
    # surface them so a silent partial load is noticeable in the logs.
    missing, unexpected = model.load_state_dict(norm_state, strict=False)
    if missing or unexpected:
        print(f"[SDG] Loaded classifier with relaxed strictness. "
              f"Missing: {len(missing)}, Unexpected: {len(unexpected)}")

    model.eval()
    return model