Spaces:
Sleeping
Sleeping
Commit ·
14399c4
1
Parent(s): 5a125c6
Use correct deepfake checkpoint and reject incompatible weights
Browse files- main.py +26 -1
- models/best_model.pth +3 -0
main.py
CHANGED
|
@@ -247,9 +247,34 @@ def load_checkpoint_model():
|
|
| 247 |
try:
|
| 248 |
checkpoint = torch.load(path, map_location=DEVICE, weights_only=False)
|
| 249 |
state_dict = checkpoint.get("model_state_dict", checkpoint)
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
loaded_any = True
|
| 252 |
print(f"Loaded checkpoint: {path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
break
|
| 254 |
except Exception as ex:
|
| 255 |
print(f"Failed loading checkpoint {path}: {ex}")
|
|
|
|
| 247 |
try:
|
| 248 |
checkpoint = torch.load(path, map_location=DEVICE, weights_only=False)
|
| 249 |
state_dict = checkpoint.get("model_state_dict", checkpoint)
|
| 250 |
+
|
| 251 |
+
# Only keep keys that belong to this architecture and match tensor shapes.
|
| 252 |
+
model_state = model.state_dict()
|
| 253 |
+
filtered_state = {}
|
| 254 |
+
for key, value in state_dict.items():
|
| 255 |
+
if key in model_state and hasattr(value, "shape") and model_state[key].shape == value.shape:
|
| 256 |
+
filtered_state[key] = value
|
| 257 |
+
|
| 258 |
+
if not filtered_state:
|
| 259 |
+
print(f"Rejected checkpoint (no compatible keys): {path}")
|
| 260 |
+
continue
|
| 261 |
+
|
| 262 |
+
load_result = model.load_state_dict(filtered_state, strict=False)
|
| 263 |
+
|
| 264 |
+
# Guardrail: require substantial overlap so unrelated checkpoints don't load.
|
| 265 |
+
loaded_ratio = len(filtered_state) / max(len(model_state), 1)
|
| 266 |
+
if loaded_ratio < 0.7:
|
| 267 |
+
print(
|
| 268 |
+
f"Rejected checkpoint (too few compatible keys: {len(filtered_state)}/{len(model_state)} = {loaded_ratio:.2%}): {path}"
|
| 269 |
+
)
|
| 270 |
+
continue
|
| 271 |
+
|
| 272 |
loaded_any = True
|
| 273 |
print(f"Loaded checkpoint: {path}")
|
| 274 |
+
print(
|
| 275 |
+
f"Compatible keys: {len(filtered_state)}/{len(model_state)} | "
|
| 276 |
+
f"Missing: {len(load_result.missing_keys)} | Unexpected ignored: {len(load_result.unexpected_keys)}"
|
| 277 |
+
)
|
| 278 |
break
|
| 279 |
except Exception as ex:
|
| 280 |
print(f"Failed loading checkpoint {path}: {ex}")
|
models/best_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b302642edfb3dfc5986fae645ba3538a97ecb1ac5e2b3218ce3bfc8e30cef9b4
|
| 3 |
+
size 19494087
|