File size: 6,452 Bytes
f4bee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""
Model loading utilities with compatibility fixes
"""
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, Any, Optional

def load_model_weights(model: nn.Module, model_path: str) -> bool:
    """
    Load model weights into ``model``, handling several checkpoint formats.

    Supported formats:
      * dict with a 'state_dict' key (checkpoint with metadata)
      * dict with a 'model_state_dict' key (alternative trainer format)
      * a bare state dict (dict of parameter tensors, or any other object
        ``load_state_dict`` accepts)

    Args:
        model: Model instance to load weights into (mutated in place).
        model_path: Path to the serialized checkpoint file.

    Returns:
        True if the weights were loaded, False otherwise.
    """
    try:
        if not Path(model_path).exists():
            print(f"Model file not found: {model_path}")
            return False

        # Map to CPU so checkpoints saved on GPU load on CPU-only hosts.
        checkpoint = torch.load(model_path, map_location='cpu')

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict):
            if 'state_dict' in checkpoint:
                # New format with metadata.  Strip a leading 'module.'
                # (added by nn.DataParallel) from each key; only the prefix
                # is removed, never a 'module.' appearing mid-key.
                state_dict = {
                    (k[len('module.'):] if k.startswith('module.') else k): v
                    for k, v in checkpoint['state_dict'].items()
                }
                model.load_state_dict(state_dict)
                print("Loaded model from checkpoint with metadata")
                return True
            elif 'model_state_dict' in checkpoint:
                # Alternative format
                model.load_state_dict(checkpoint['model_state_dict'])
                print("Loaded model from checkpoint with model_state_dict")
                return True
            else:
                # Assume the dict itself is a state dict
                try:
                    model.load_state_dict(checkpoint)
                    print("Loaded model from state dict")
                    return True
                except RuntimeError:
                    # Retry tolerating missing/unexpected keys; load_state_dict
                    # raises RuntimeError on key/shape mismatches.
                    model.load_state_dict(checkpoint, strict=False)
                    print("Loaded model with strict=False (some keys missing)")
                    return True
        else:
            # Non-dict checkpoint: hand it to load_state_dict directly
            model.load_state_dict(checkpoint)
            print("Loaded model directly")
            return True

    except Exception as e:
        # Best-effort loader: report the failure, never propagate.
        print(f"Error loading model from {model_path}: {e}")
        return False

def load_model_with_flexibility(model: nn.Module, model_path: str) -> bool:
    """
    Load model weights, tolerating size mismatches and key differences.

    Parameters whose saved shape matches the current model are loaded;
    mismatched, missing, and unexpected keys are reported as warnings
    instead of raising.

    Args:
        model: Model instance to load weights into (mutated in place).
        model_path: Path to the serialized checkpoint file.

    Returns:
        True if at least one parameter was loaded (possibly with warnings),
        False if nothing could be loaded or an error occurred.
    """
    try:
        if not Path(model_path).exists():
            print(f"Model file not found: {model_path}")
            return False

        # Map to CPU so checkpoints saved on GPU load on CPU-only hosts.
        checkpoint = torch.load(model_path, map_location='cpu')

        # Unwrap the same checkpoint formats load_model_weights accepts.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        # Strip a leading 'module.' (added by nn.DataParallel) from each key;
        # only the prefix is removed, never a 'module.' appearing mid-key.
        state_dict = {
            (k[len('module.'):] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }

        model_dict = model.state_dict()

        # Partition saved keys into loadable / shape-mismatched / unknown.
        filtered_state_dict = {}
        unexpected_keys = []
        size_mismatches = []
        for k, v in state_dict.items():
            if k not in model_dict:
                unexpected_keys.append(k)
            elif v.size() != model_dict[k].size():
                size_mismatches.append((k, v.size(), model_dict[k].size()))
            else:
                filtered_state_dict[k] = v

        # Model parameters the checkpoint does not provide.
        missing_keys = [k for k in model_dict if k not in state_dict]

        # Load the compatible subset; strict=False tolerates the rest.
        model_dict.update(filtered_state_dict)
        model.load_state_dict(model_dict, strict=False)

        # Print warnings (truncated so huge models stay readable)
        if size_mismatches:
            print(f"⚠️  Size mismatches ({len(size_mismatches)}):")
            for k, saved_size, current_size in size_mismatches[:3]:  # Show first 3
                print(f"    {k}: saved {saved_size} != current {current_size}")
            if len(size_mismatches) > 3:
                print(f"    ... and {len(size_mismatches) - 3} more")

        if missing_keys:
            print(f"⚠️  Missing keys ({len(missing_keys)}): {missing_keys[:5]}")
            if len(missing_keys) > 5:
                print(f"    ... and {len(missing_keys) - 5} more")

        if unexpected_keys:
            print(f"⚠️  Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:5]}")
            if len(unexpected_keys) > 5:
                print(f"    ... and {len(unexpected_keys) - 5} more")

        if filtered_state_dict:
            print(f"✅ Loaded {len(filtered_state_dict)}/{len(model_dict)} parameters")
            return True
        else:
            print("❌ No parameters loaded")
            return False

    except Exception as e:
        # Best-effort loader: report the failure, never propagate.
        print(f"❌ Error loading model: {e}")
        return False

def create_and_load_model(model_class, model_path: str, **kwargs) -> Optional[nn.Module]:
    """
    Instantiate ``model_class`` and populate it with weights from disk.

    Args:
        model_class: Model class to instantiate.
        model_path: Path to model weights.
        **kwargs: Keyword arguments forwarded to the model constructor.

    Returns:
        The loaded model in eval mode on success, otherwise None.
    """
    try:
        instance = model_class(**kwargs)
        if not load_model_with_flexibility(instance, model_path):
            return None
        # Inference-ready: disable dropout/batch-norm training behavior.
        instance.eval()
        return instance
    except Exception as err:
        print(f"Error creating model: {err}")
        return None

def save_model_with_metadata(model: nn.Module, model_path: str,
                             metadata: Optional[Dict[str, Any]] = None) -> None:
    """
    Save a model checkpoint together with descriptive metadata.

    The checkpoint is a dict with keys 'state_dict', 'model_class' and
    'metadata' — the format ``load_model_weights`` unwraps via its
    'state_dict' branch.

    Args:
        model: Model whose parameters are saved.
        model_path: Destination file path; parent directories are created
            as needed.
        metadata: Optional extra metadata stored alongside the weights
            (defaults to an empty dict in the checkpoint).
    """
    checkpoint = {
        'state_dict': model.state_dict(),
        'model_class': model.__class__.__name__,
        'metadata': metadata or {},
    }

    # Ensure the target directory exists before writing.
    Path(model_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(checkpoint, model_path)
    print(f"Model saved to {model_path} with metadata")