Spaces:
Runtime error
Runtime error
Update networks.py
Browse files- networks.py +27 -4
networks.py
CHANGED
|
@@ -12,7 +12,7 @@ class Options:
|
|
| 12 |
# Default values
|
| 13 |
self.fine_height = 256
|
| 14 |
self.fine_width = 192
|
| 15 |
-
self.grid_size =
|
| 16 |
self.use_dropout = False
|
| 17 |
|
| 18 |
def weights_init_normal(m):
|
|
@@ -499,7 +499,30 @@ def save_checkpoint(model, save_path):
|
|
| 499 |
os.makedirs(os.path.dirname(save_path))
|
| 500 |
torch.save(model.state_dict(), save_path)
|
| 501 |
|
| 502 |
-
def load_checkpoint(model, checkpoint_path):
|
| 503 |
if not os.path.exists(checkpoint_path):
|
| 504 |
-
|
| 505 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
# Default values
|
| 13 |
self.fine_height = 256
|
| 14 |
self.fine_width = 192
|
| 15 |
+
self.grid_size = 5
|
| 16 |
self.use_dropout = False
|
| 17 |
|
| 18 |
def weights_init_normal(m):
|
|
|
|
| 499 |
os.makedirs(os.path.dirname(save_path))
|
| 500 |
torch.save(model.state_dict(), save_path)
|
| 501 |
|
| 502 |
+
def load_checkpoint(model, checkpoint_path, strict=True):
|
| 503 |
if not os.path.exists(checkpoint_path):
|
| 504 |
+
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
|
| 505 |
+
|
| 506 |
+
# Load checkpoint with strict=False to ignore size mismatches
|
| 507 |
+
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
| 508 |
+
|
| 509 |
+
# Filter out size-mismatched keys
|
| 510 |
+
model_state_dict = model.state_dict()
|
| 511 |
+
filtered_state_dict = {k: v for k, v in state_dict.items()
|
| 512 |
+
if k in model_state_dict and v.size() == model_state_dict[k].size()}
|
| 513 |
+
|
| 514 |
+
# Load the filtered state dict
|
| 515 |
+
model.load_state_dict(filtered_state_dict, strict=strict)
|
| 516 |
+
|
| 517 |
+
# Print warnings for mismatched keys
|
| 518 |
+
missing_keys = [k for k in model_state_dict.keys() if k not in state_dict]
|
| 519 |
+
unexpected_keys = [k for k in state_dict.keys() if k not in model_state_dict]
|
| 520 |
+
size_mismatch_keys = [k for k in state_dict.keys()
|
| 521 |
+
if k in model_state_dict and state_dict[k].size() != model_state_dict[k].size()]
|
| 522 |
+
|
| 523 |
+
if missing_keys:
|
| 524 |
+
print(f"Missing keys in checkpoint: {missing_keys}")
|
| 525 |
+
if unexpected_keys:
|
| 526 |
+
print(f"Unexpected keys in checkpoint: {unexpected_keys}")
|
| 527 |
+
if size_mismatch_keys:
|
| 528 |
+
print(f"Size mismatch for keys: {size_mismatch_keys}")
|