basimazam commited on
Commit
1932ce5
·
verified ·
1 Parent(s): fbfadc4

Upload SDG pipeline + classifier weights

Browse files
Files changed (1) hide show
  1. safe_diffusion_guidance.py +37 -3
safe_diffusion_guidance.py CHANGED
@@ -48,14 +48,48 @@ class SafetyClassifier1280(nn.Module):
48
 
49
  def load_classifier_1280(
50
  weights_path: str,
51
- device: Optional[torch.device] = None,
52
  dtype: torch.dtype = torch.float32
53
  ) -> SafetyClassifier1280:
 
 
 
 
 
54
  model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
55
- state = torch.load(weights_path, map_location="cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  if isinstance(state, dict) and "model_state_dict" in state:
57
  state = state["model_state_dict"]
58
- model.load_state_dict(state, strict=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  model.eval()
60
  return model
61
 
 
48
 
49
  def load_classifier_1280(
50
  weights_path: str,
51
+ device=None,
52
  dtype: torch.dtype = torch.float32
53
  ) -> SafetyClassifier1280:
54
+ """
55
+ Robustly load classifier weights saved either as a pure state dict or as a
56
+ training checkpoint (with 'model_state_dict'), handling key prefixes and
57
+ PyTorch 2.6's weights_only behavior.
58
+ """
59
  model = SafetyClassifier1280().to(device or "cpu", dtype=dtype)
60
+
61
+ # --- load robustly (PyTorch 2.6 & older) ---
62
+ try:
63
+ state = torch.load(weights_path, map_location="cpu", weights_only=False)
64
+ except TypeError:
65
+ state = torch.load(weights_path, map_location="cpu")
66
+ except Exception:
67
+ # Allowlist common numpy globals if needed for old pickles
68
+ try:
69
+ import numpy as np
70
+ from torch.serialization import add_safe_globals
71
+ add_safe_globals([np.dtype, np.number])
72
+ except Exception:
73
+ pass
74
+ state = torch.load(weights_path, map_location="cpu", weights_only=False)
75
+
76
+ # --- unwrap training checkpoint ---
77
  if isinstance(state, dict) and "model_state_dict" in state:
78
  state = state["model_state_dict"]
79
+
80
+ # --- normalize keys: 'module.' -> '', 'model.' -> 'net.' ---
81
+ norm_state = {}
82
+ for k, v in state.items():
83
+ k = k.replace("module.", "")
84
+ k = k.replace("model.", "net.")
85
+ norm_state[k] = v
86
+
87
+ # --- load tolerating tiny diffs (e.g., extra keys from older heads) ---
88
+ missing, unexpected = model.load_state_dict(norm_state, strict=False)
89
+ if missing or unexpected:
90
+ print(f"[SDG] Loaded classifier with relaxed strictness. "
91
+ f"Missing: {len(missing)}, Unexpected: {len(unexpected)}")
92
+
93
  model.eval()
94
  return model
95