will702 committed on
Commit
ef2e4d9
·
1 Parent(s): ec4688a

fix: remap bare Sequential keys to net.* when loading ddg_da.pt

Browse files

Checkpoint was saved when DriftPredictorMLP used a bare nn.Sequential,
producing keys like '0.weight'. Current class wraps it as self.net,
expecting 'net.0.weight'. Remap on load so old .pt files load correctly.

Files changed (1) hide show
  1. app/models/ddg_da.py +4 -0
app/models/ddg_da.py CHANGED
@@ -96,6 +96,10 @@ class DDGDAPredictor:
96
  if os.path.exists(model_path):
97
  try:
98
  state = torch.load(model_path, map_location="cpu", weights_only=True)
 
 
 
 
99
  self.mlp.load_state_dict(state)
100
  self.mlp.eval()
101
  except Exception as e:
 
96
  if os.path.exists(model_path):
97
  try:
98
  state = torch.load(model_path, map_location="cpu", weights_only=True)
99
+ # Remap bare Sequential keys ("0.weight") → nested ("net.0.weight")
100
+ # to handle checkpoints saved before the self.net wrapper was added.
101
+ if any(k.startswith("0.") or k.startswith("3.") for k in state):
102
+ state = {f"net.{k}": v for k, v in state.items()}
103
  self.mlp.load_state_dict(state)
104
  self.mlp.eval()
105
  except Exception as e: