karthikeya1212 commited on
Commit
f03d432
·
verified ·
1 Parent(s): 7bd9d8c

Update shadow_generator.py

Browse files
Files changed (1) hide show
  1. shadow_generator.py +9 -1
shadow_generator.py CHANGED
@@ -341,10 +341,17 @@ class SSNWrapper:
341
  _validate_assets()
342
  ssn_model_dir = _find_ssn_model_file()
343
  sys.path.insert(0, str(ssn_model_dir.resolve()))
 
344
  try:
345
  from SSN_Model import SSN_Model
346
  self.model = SSN_Model()
347
- state = torch.load(str(WEIGHT_FILE), map_location=self.device)
 
 
 
 
 
 
348
  if isinstance(state, dict):
349
  if "model_state_dict" in state:
350
  sd = state["model_state_dict"]
@@ -356,6 +363,7 @@ class SSNWrapper:
356
  sd = state
357
  else:
358
  sd = state
 
359
  self.model.load_state_dict(sd, strict=False)
360
  self.model.eval().to(self.device)
361
  _log(f"✅ SSN model loaded on {self.device}")
 
341
  _validate_assets()
342
  ssn_model_dir = _find_ssn_model_file()
343
  sys.path.insert(0, str(ssn_model_dir.resolve()))
344
+
345
  try:
346
  from SSN_Model import SSN_Model
347
  self.model = SSN_Model()
348
+
349
+ # ✅ Patch for numpy scalar safety
350
+ torch.serialization.add_safe_globals([np.core.multiarray.scalar])
351
+
352
+ # ✅ Allow full checkpoint load
353
+ state = torch.load(str(WEIGHT_FILE), map_location=self.device, weights_only=False)
354
+
355
  if isinstance(state, dict):
356
  if "model_state_dict" in state:
357
  sd = state["model_state_dict"]
 
363
  sd = state
364
  else:
365
  sd = state
366
+
367
  self.model.load_state_dict(sd, strict=False)
368
  self.model.eval().to(self.device)
369
  _log(f"✅ SSN model loaded on {self.device}")