|
|
""" |
|
|
Model loading utilities with compatibility fixes |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from pathlib import Path |
|
|
from typing import Dict, Any, Optional |
|
|
|
|
|
def load_model_weights(model: nn.Module, model_path: str) -> bool:
    """
    Load model weights into *model* with compatibility handling.

    Supports several common checkpoint layouts:
      - ``{'state_dict': ...}`` (e.g. Lightning-style checkpoints, with
        optional DataParallel ``module.`` key prefixes)
      - ``{'model_state_dict': ...}``
      - a raw state dict saved via ``torch.save(model.state_dict(), path)``

    Args:
        model: Model instance to load the weights into (mutated in place)
        model_path: Path to model file

    Returns:
        True if successful, False otherwise
    """
    try:
        if not Path(model_path).exists():
            print(f"Model file not found: {model_path}")
            return False

        # map_location='cpu' so checkpoints saved on GPU load on any host
        checkpoint = torch.load(model_path, map_location='cpu')

        if isinstance(checkpoint, dict):
            if 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
                # Strip DataParallel/DDP 'module.' wrapper. Only strip a
                # *leading* prefix: the previous `k.replace('module.', '')`
                # also mangled keys containing 'module.' mid-name.
                state_dict = {
                    (k[len('module.'):] if k.startswith('module.') else k): v
                    for k, v in state_dict.items()
                }
                model.load_state_dict(state_dict)
                print("Loaded model from checkpoint with metadata")
                return True
            elif 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print("Loaded model from checkpoint with model_state_dict")
                return True
            else:
                # Assume the dict itself is a state dict
                try:
                    model.load_state_dict(checkpoint)
                    print("Loaded model from state dict")
                    return True
                except RuntimeError:
                    # Key/shape mismatch: fall back to a best-effort partial
                    # load. (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit.)
                    model.load_state_dict(checkpoint, strict=False)
                    print("Loaded model with strict=False (some keys missing)")
                    return True
        else:
            # Non-dict payload (e.g. an OrderedDict subclass is already
            # handled above); try to load it directly.
            model.load_state_dict(checkpoint)
            print("Loaded model directly")
            return True

    except Exception as e:
        print(f"Error loading model from {model_path}: {e}")
        return False
|
|
|
|
|
def load_model_with_flexibility(model: nn.Module, model_path: str) -> bool:
    """
    Load model weights with flexibility for size mismatches.

    Parameters whose names and shapes match the model are loaded; mismatched,
    missing, and unexpected keys are reported as warnings instead of raising.

    Args:
        model: Model instance to load the weights into (mutated in place)
        model_path: Path to model file

    Returns:
        True if successful (with warnings), False if failed
    """
    try:
        if not Path(model_path).exists():
            print(f"Model file not found: {model_path}")
            return False

        # map_location='cpu' so checkpoints saved on GPU load on any host
        checkpoint = torch.load(model_path, map_location='cpu')

        # Unwrap the same checkpoint layouts as load_model_weights.
        # (Consistency fix: 'model_state_dict' checkpoints were previously
        # not unwrapped here.)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # Strip DataParallel/DDP 'module.' wrapper. Only strip a *leading*
        # prefix: the previous `k.replace('module.', '')` also mangled keys
        # containing 'module.' mid-name.
        state_dict = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }

        model_dict = model.state_dict()

        # Keep only parameters whose name AND shape match the current model;
        # record everything else for the warning report below.
        filtered_state_dict = {}
        missing_keys = []
        unexpected_keys = []
        size_mismatches = []

        for k, v in state_dict.items():
            if k in model_dict:
                if v.size() == model_dict[k].size():
                    filtered_state_dict[k] = v
                else:
                    size_mismatches.append((k, v.size(), model_dict[k].size()))
            else:
                unexpected_keys.append(k)

        for k in model_dict.keys():
            if k not in state_dict:
                missing_keys.append(k)

        # Merge compatible weights over the model's current state, then load
        # non-strictly so unmatched parameters keep their initialized values.
        model_dict.update(filtered_state_dict)
        model.load_state_dict(model_dict, strict=False)

        if size_mismatches:
            print(f"⚠️ Size mismatches ({len(size_mismatches)}):")
            for k, saved_size, current_size in size_mismatches[:3]:
                print(f"  {k}: saved {saved_size} != current {current_size}")
            if len(size_mismatches) > 3:
                print(f"  ... and {len(size_mismatches) - 3} more")

        if missing_keys:
            print(f"⚠️ Missing keys ({len(missing_keys)}): {missing_keys[:5]}")
            if len(missing_keys) > 5:
                print(f"  ... and {len(missing_keys) - 5} more")

        if unexpected_keys:
            print(f"⚠️ Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:5]}")
            if len(unexpected_keys) > 5:
                print(f"  ... and {len(unexpected_keys) - 5} more")

        if filtered_state_dict:
            print(f"✅ Loaded {len(filtered_state_dict)}/{len(model_dict)} parameters")
            return True
        else:
            print("❌ No parameters loaded")
            return False

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        return False
|
|
|
|
|
def create_and_load_model(model_class, model_path: str, **kwargs) -> Optional[nn.Module]:
    """
    Instantiate ``model_class`` and populate it with weights from disk.

    Args:
        model_class: Model class to instantiate
        model_path: Path to model weights
        **kwargs: Arguments forwarded to the model constructor

    Returns:
        The model, switched to eval mode, or None on any failure
    """
    try:
        instance = model_class(**kwargs)
        if not load_model_with_flexibility(instance, model_path):
            return None
        instance.eval()
        return instance
    except Exception as e:
        print(f"Error creating model: {e}")
        return None
|
|
|
|
|
def save_model_with_metadata(model: nn.Module, model_path: str,
                             metadata: Optional[Dict[str, Any]] = None):
    """
    Save a model's state dict plus metadata as a checkpoint dict.

    The checkpoint layout (``{'state_dict': ...}``) matches what
    ``load_model_weights`` expects back.

    Args:
        model: Model to save
        model_path: Path to save to (parent directories are created)
        metadata: Additional metadata; ``None`` is stored as ``{}``.
            (Annotation fixed: the default is None, so the type is Optional.)
    """
    checkpoint = {
        'state_dict': model.state_dict(),
        'model_class': model.__class__.__name__,
        'metadata': metadata or {}
    }

    # Create the target directory so saving to a fresh path doesn't fail
    Path(model_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(checkpoint, model_path)
    print(f"Model saved to {model_path} with metadata")
|
|
|