File size: 8,357 Bytes
c5a3ef9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""
Exact architecture for Coherence Detection Model.
Uses key matching to validate the safetensors file.
"""

import sys
import torch
import torch.nn as nn

# ============================================================================
# AdaptiveConcatPool2d for FastAI model compatibility
# ============================================================================
class AdaptiveConcatPool2d(nn.Module):
    """FastAI-style pooling head: concatenation of adaptive max and avg pooling.

    Doubles the channel dimension of the input: max-pooled features come
    first, average-pooled features second (FastAI's channel ordering).
    """

    def __init__(self, sz=None):
        super().__init__()
        # Default to a 1x1 output when no target size is given.
        output_size = sz or 1
        self.ap = nn.AdaptiveAvgPool2d(output_size)
        self.mp = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        # Concatenate along the channel axis, max features first.
        return torch.cat([pooled_max, pooled_avg], 1)

# Register the pooling class in __main__ (when present) so that pickled
# FastAI checkpoints referencing it by a __main__-qualified name resolve.
# Covers both the exact class name and the shorter alias some exports use;
# merging the two previously-duplicated `if` blocks avoids the repeated
# sys.modules lookup while producing the same attributes.
if '__main__' in sys.modules:
    main_module = sys.modules['__main__']
    for _alias in ('AdaptiveConcatPool2d', 'AdaptiveConcatPool'):
        # Never clobber a name the running script already defines.
        if not hasattr(main_module, _alias):
            setattr(main_module, _alias, AdaptiveConcatPool2d)

# ============================================================================
# Utility function to check torchvision version
# ============================================================================
def _get_resnet_backbone():
    """Build an un-pretrained ResNet-34, handling the torchvision API split.

    torchvision >= 0.13 deprecated the ``pretrained=`` argument in favour of
    ``weights=``; older releases only understand ``pretrained=``.

    Returns:
        A randomly initialised ResNet-34 module.
    """
    from torchvision.models import resnet34
    import torchvision

    # Parse "major.minor" defensively: local builds may carry suffixes
    # (e.g. "0.15.2+cu117"), so fall back to 0 on non-numeric components.
    version = torchvision.__version__.split('.')
    major = int(version[0]) if version[0].isdigit() else 0
    minor = int(version[1]) if len(version) > 1 and version[1].isdigit() else 0

    # Tuple comparison fixes the original `major >= 0 and minor >= 13` test,
    # which would wrongly select the legacy API for any future 1.x release
    # with minor < 13.
    if (major, minor) >= (0, 13):
        # Modern weights API (torchvision >= 0.13).
        return resnet34(weights=None)
    else:
        # Legacy pretrained API.
        return resnet34(pretrained=False)


# ============================================================================
# Clean model with version detection
# ============================================================================
class CoherenceDetectionModel(nn.Sequential):
    """Two-stage sequential model: ResNet-34 features plus a FastAI-style head.

    Index 0 is the ResNet-34 trunk with its final avgpool/fc removed;
    index 1 is the classifier head producing ``num_classes`` logits.
    """

    def __init__(self, num_classes=3):
        # Feature extractor: keep every ResNet child except the last two
        # (the built-in average pool and fully-connected classifier).
        trunk = _get_resnet_backbone()
        feature_extractor = nn.Sequential(*list(trunk.children())[:-2])

        # Head: concat pooling doubles the channels to the 1024 that
        # BatchNorm1d expects, then a BN/Dropout/Linear stack narrows
        # down to the class logits.
        head = nn.Sequential(
            AdaptiveConcatPool2d(),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=0.25, inplace=False),
            nn.Linear(1024, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(512, num_classes, bias=True),
        )

        super().__init__(feature_extractor, head)


# ============================================================================
# Loading function
# ============================================================================
def load_coherence_model(safetensors_path, device='auto'):
    """Load safetensors weights into a CoherenceDetectionModel.

    Args:
        safetensors_path: Path to the .safetensors checkpoint file.
        device: 'auto' (pick CUDA when available), 'cuda', or 'cpu'.

    Returns:
        The model in eval mode on the resolved device. If the checkpoint
        file is missing, the model keeps its random initialisation.
    """
    import safetensors.torch

    # Resolve the target device up front.
    if device != 'auto':
        target_device = torch.device(device)
    else:
        target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = CoherenceDetectionModel(num_classes=3)

    # Load the checkpoint onto CPU first; fall back to the model's own
    # (random) weights when the file is absent so callers can smoke-test.
    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"Error: File '{safetensors_path}' not found.")
        print("Testing with sample model structure...")
        state_dict = model.state_dict()

    # strict=False so mismatches are reported rather than raised.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    if missing_keys:
        print(f"Warning: Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"Warning: Unexpected keys: {unexpected_keys}")

    if missing_keys or unexpected_keys:
        print(f"⚠ CoherenceDetectionModel loaded with key mismatches")
    else:
        print(f"✓ CoherenceDetectionModel loaded successfully (exact match)")

    model = model.to(target_device)
    model.eval()

    print(f"  Device: {target_device}")
    print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")

    return model


# ============================================================================
# Test functions
# ============================================================================
def test_key_matching(safetensors_path="coherence_model.safetensors"):
    """Check that model state-dict keys match the safetensors file's keys.

    Returns the loaded model on an exact key match, otherwise ``None``.
    """
    import safetensors.torch

    print("\nTesting key matching...")

    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        # No checkpoint on disk: exercise the pipeline against a freshly
        # built model's own (random) weights instead.
        print(f"  ⚠ File '{safetensors_path}' not found, using model weights")
        state_dict = CoherenceDetectionModel(num_classes=3).state_dict()

    print("\nTesting CoherenceDetectionModel:")
    candidate = CoherenceDetectionModel(num_classes=3)
    missing, unexpected = candidate.load_state_dict(state_dict, strict=False)

    exact_match = not missing and not unexpected
    if exact_match:
        print("  ✅ Load successful (exact key match)")
    else:
        print(f"  ⚠ Load completed with issues")
        if missing:
            print(f"    Missing keys: {len(missing)}")
        if unexpected:
            print(f"    Unexpected keys: {len(unexpected)}")

    # Only hand back a model when every key lined up.
    return candidate if exact_match else None


def print_key_samples(safetensors_path="coherence_model.safetensors"):
    """Print the first few state-dict keys from file and model for debugging.

    Returns the loaded safetensors state dict, or ``None`` if the file
    could not be found.
    """
    import safetensors.torch

    print("\nKey samples:")

    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"Safetensors file '{safetensors_path}' not found")
        print("Showing model structure keys instead:")
        state_dict = None
    else:
        print("From safetensors file (first 5 keys):")
        for idx, key in enumerate(sorted(state_dict.keys())[:5]):
            print(f"  {idx}: {key}")

    # Always show the model's own keys for side-by-side comparison.
    print("\nFrom CoherenceDetectionModel:")
    reference = CoherenceDetectionModel(num_classes=3)
    for idx, key in enumerate(sorted(reference.state_dict().keys())[:5]):
        print(f"  {idx}: {key}")

    return state_dict


# ============================================================================
# Version compatibility info
# ============================================================================
def print_version_info():
    """Print torch/torchvision/CUDA version information for debugging."""
    import torch
    import torchvision

    print("\n" + "=" * 60)
    print("Version Information")
    print("=" * 60)
    print(f"Torch: {torch.__version__}")
    print(f"Torchvision: {torchvision.__version__}")
    print(f"CUDA Available: {torch.cuda.is_available()}")

    # Parse "major.minor" defensively: local builds may carry suffixes
    # (e.g. "0.15.2+cu117"), so fall back to 0 on non-numeric components.
    version = torchvision.__version__.split('.')
    major = int(version[0]) if version[0].isdigit() else 0
    minor = int(version[1]) if len(version) > 1 and version[1].isdigit() else 0

    # Tuple comparison fixes the original `major >= 0 and minor >= 13`,
    # which would misreport any future 1.x release (minor < 13) as legacy.
    # Kept consistent with the same check in _get_resnet_backbone.
    if (major, minor) >= (0, 13):
        print("✓ Using modern torchvision API (weights parameter)")
    else:
        print("⚠ Using legacy torchvision API (pretrained parameter)")
    print("=" * 60)


if __name__ == "__main__":
    # Script entry point: print environment info, sample keys, then run
    # the key-matching smoke test.
    banner = "=" * 60
    print(banner)
    print("Coherence Detection Model Architecture")
    print(banner)

    print_version_info()
    state_dict = print_key_samples()
    print("\n" + banner)
    model = test_key_matching()

    # Summarise only when test_key_matching returned a fully-matched model.
    if model:
        print("\nModel summary:")
        print(f"  Backbone layers: {len(model[0])}")
        print(f"  Classifier layers: {len(model[1])}")
        print(f"  Total sequential blocks: {len(model)}")