Spaces:
Sleeping
Sleeping
fix: remap bare Sequential keys to net.* when loading ddg_da.pt
Browse filesCheckpoint was saved when DriftPredictorMLP used a bare nn.Sequential,
producing keys like '0.weight'. Current class wraps it as self.net,
expecting 'net.0.weight'. Remap on load so old .pt files load correctly.
- app/models/ddg_da.py +4 -0
app/models/ddg_da.py
CHANGED
|
@@ -96,6 +96,10 @@ class DDGDAPredictor:
|
|
| 96 |
if os.path.exists(model_path):
|
| 97 |
try:
|
| 98 |
state = torch.load(model_path, map_location="cpu", weights_only=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
self.mlp.load_state_dict(state)
|
| 100 |
self.mlp.eval()
|
| 101 |
except Exception as e:
|
|
|
|
| 96 |
if os.path.exists(model_path):
|
| 97 |
try:
|
| 98 |
state = torch.load(model_path, map_location="cpu", weights_only=True)
|
| 99 |
+
# Remap bare Sequential keys ("0.weight") → nested ("net.0.weight")
|
| 100 |
+
# to handle checkpoints saved before the self.net wrapper was added.
|
| 101 |
+
if any(k.startswith("0.") or k.startswith("3.") for k in state):
|
| 102 |
+
state = {f"net.{k}": v for k, v in state.items()}
|
| 103 |
self.mlp.load_state_dict(state)
|
| 104 |
self.mlp.eval()
|
| 105 |
except Exception as e:
|