jenithjain commited on
Commit
14399c4
·
1 Parent(s): 5a125c6

Use correct deepfake checkpoint and reject incompatible weights

Browse files
Files changed (2) hide show
  1. main.py +26 -1
  2. models/best_model.pth +3 -0
main.py CHANGED
@@ -247,9 +247,34 @@ def load_checkpoint_model():
247
  try:
248
  checkpoint = torch.load(path, map_location=DEVICE, weights_only=False)
249
  state_dict = checkpoint.get("model_state_dict", checkpoint)
250
- model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  loaded_any = True
252
  print(f"Loaded checkpoint: {path}")
 
 
 
 
253
  break
254
  except Exception as ex:
255
  print(f"Failed loading checkpoint {path}: {ex}")
 
247
  try:
248
  checkpoint = torch.load(path, map_location=DEVICE, weights_only=False)
249
  state_dict = checkpoint.get("model_state_dict", checkpoint)
250
+
251
+ # Only keep keys that belong to this architecture and match tensor shapes.
252
+ model_state = model.state_dict()
253
+ filtered_state = {}
254
+ for key, value in state_dict.items():
255
+ if key in model_state and hasattr(value, "shape") and model_state[key].shape == value.shape:
256
+ filtered_state[key] = value
257
+
258
+ if not filtered_state:
259
+ print(f"Rejected checkpoint (no compatible keys): {path}")
260
+ continue
261
+
262
+ load_result = model.load_state_dict(filtered_state, strict=False)
263
+
264
+ # Guardrail: require substantial overlap so unrelated checkpoints don't load.
265
+ loaded_ratio = len(filtered_state) / max(len(model_state), 1)
266
+ if loaded_ratio < 0.7:
267
+ print(
268
+ f"Rejected checkpoint (too few compatible keys: {len(filtered_state)}/{len(model_state)} = {loaded_ratio:.2%}): {path}"
269
+ )
270
+ continue
271
+
272
  loaded_any = True
273
  print(f"Loaded checkpoint: {path}")
274
+ print(
275
+ f"Compatible keys: {len(filtered_state)}/{len(model_state)} | "
276
+ f"Missing: {len(load_result.missing_keys)} | Unexpected ignored: {len(load_result.unexpected_keys)}"
277
+ )
278
  break
279
  except Exception as ex:
280
  print(f"Failed loading checkpoint {path}: {ex}")
models/best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b302642edfb3dfc5986fae645ba3538a97ecb1ac5e2b3218ce3bfc8e30cef9b4
3
+ size 19494087