Spaces:
Sleeping
Sleeping
Update shadow_generator.py
Browse files- shadow_generator.py +9 -1
shadow_generator.py
CHANGED
|
@@ -341,10 +341,17 @@ class SSNWrapper:
|
|
| 341 |
_validate_assets()
|
| 342 |
ssn_model_dir = _find_ssn_model_file()
|
| 343 |
sys.path.insert(0, str(ssn_model_dir.resolve()))
|
|
|
|
| 344 |
try:
|
| 345 |
from SSN_Model import SSN_Model
|
| 346 |
self.model = SSN_Model()
|
| 347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
if isinstance(state, dict):
|
| 349 |
if "model_state_dict" in state:
|
| 350 |
sd = state["model_state_dict"]
|
|
@@ -356,6 +363,7 @@ class SSNWrapper:
|
|
| 356 |
sd = state
|
| 357 |
else:
|
| 358 |
sd = state
|
|
|
|
| 359 |
self.model.load_state_dict(sd, strict=False)
|
| 360 |
self.model.eval().to(self.device)
|
| 361 |
_log(f"✅ SSN model loaded on {self.device}")
|
|
|
|
| 341 |
_validate_assets()
|
| 342 |
ssn_model_dir = _find_ssn_model_file()
|
| 343 |
sys.path.insert(0, str(ssn_model_dir.resolve()))
|
| 344 |
+
|
| 345 |
try:
|
| 346 |
from SSN_Model import SSN_Model
|
| 347 |
self.model = SSN_Model()
|
| 348 |
+
|
| 349 |
+
# ✅ Patch for numpy scalar safety
|
| 350 |
+
torch.serialization.add_safe_globals([np.core.multiarray.scalar])
|
| 351 |
+
|
| 352 |
+
# ✅ Allow full checkpoint load
|
| 353 |
+
state = torch.load(str(WEIGHT_FILE), map_location=self.device, weights_only=False)
|
| 354 |
+
|
| 355 |
if isinstance(state, dict):
|
| 356 |
if "model_state_dict" in state:
|
| 357 |
sd = state["model_state_dict"]
|
|
|
|
| 363 |
sd = state
|
| 364 |
else:
|
| 365 |
sd = state
|
| 366 |
+
|
| 367 |
self.model.load_state_dict(sd, strict=False)
|
| 368 |
self.model.eval().to(self.device)
|
| 369 |
_log(f"✅ SSN model loaded on {self.device}")
|