File size: 8,357 Bytes
c5a3ef9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
"""
Exact architecture for Coherence Detection Model.
Uses key matching to validate the safetensors file.
"""
import sys
import torch
import torch.nn as nn
# ============================================================================
# AdaptiveConcatPool2d for FastAI model compatibility
# ============================================================================
class AdaptiveConcatPool2d(nn.Module):
    """Adaptive max pooling and adaptive average pooling, concatenated.

    Both pooling layers receive the same target size. Their outputs are
    joined along dimension 1 with the max-pooled tensor first, so the
    result is twice as large along that dimension as either pooled tensor.

    Args:
        sz: Output size passed to both adaptive pooling layers. Any falsy
            value (``None``, ``0``) falls back to ``1``.
    """

    def __init__(self, sz=None):
        super().__init__()
        target_size = sz if sz else 1
        # Attribute names and registration order are kept as ``ap`` then
        # ``mp`` so the module's children and repr stay the same.
        self.ap = nn.AdaptiveAvgPool2d(target_size)
        self.mp = nn.AdaptiveMaxPool2d(target_size)

    def forward(self, x):
        """Return ``x`` max-pooled and average-pooled, joined on dim 1."""
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        return torch.cat((pooled_max, pooled_avg), dim=1)
# Expose the pooling class on the __main__ module (when one exists) under
# both its own name and the shorter 'AdaptiveConcatPool' alias that some
# FastAI models might expect. Attributes already present are left untouched.
if '__main__' in sys.modules:
    main_module = sys.modules['__main__']
    for _exported_name in ('AdaptiveConcatPool2d', 'AdaptiveConcatPool'):
        if not hasattr(main_module, _exported_name):
            setattr(main_module, _exported_name, AdaptiveConcatPool2d)
# ============================================================================
# Utility function to build the ResNet backbone for the installed torchvision version
# ============================================================================
def _torchvision_has_weights_api(version_string):
    """Return True when *version_string* denotes torchvision 0.13 or newer.

    Only the first two dot-separated components are inspected. A component
    that is not purely numeric is treated as 0, so an unparsable version
    string selects the legacy API.
    """
    parts = version_string.split('.')
    major = int(parts[0]) if parts[0].isdigit() else 0
    minor = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
    # Compare (major, minor) as a tuple: checking the two components
    # independently would classify any 1.x release with minor < 13 as legacy.
    return (major, minor) >= (0, 13)


def _get_resnet_backbone():
    """Build a ResNet-34 without pretrained weights for the installed torchvision.

    Calls ``resnet34(weights=None)`` on torchvision 0.13 or newer and
    ``resnet34(pretrained=False)`` on older versions.

    Returns:
        The object returned by ``torchvision.models.resnet34``.
    """
    from torchvision.models import resnet34
    import torchvision

    if _torchvision_has_weights_api(torchvision.__version__):
        # Use new weights API
        return resnet34(weights=None)
    # Use old pretrained API
    return resnet34(pretrained=False)
# ============================================================================
# Clean model with version detection
# ============================================================================
class CoherenceDetectionModel(nn.Sequential):
    """Two-stage sequential network: ResNet-34 feature layers plus a classifier head.

    Index 0 is an ``nn.Sequential`` of every child of the ResNet-34 backbone
    except its last two. Index 1 is the classifier head: concat pooling,
    flatten, then two BatchNorm/Dropout/Linear stages with a ReLU between
    them. The head expects 1024 features after flattening.
    NOTE(review): 1024 is presumably the backbone's channel count doubled by
    the concat pooling; confirm against the checkpoint.

    Args:
        num_classes: Width of the final linear layer.
    """

    def __init__(self, num_classes=3):
        # The backbone is created through the version-aware helper.
        resnet = _get_resnet_backbone()
        feature_layers = list(resnet.children())[:-2]

        head_layers = [
            AdaptiveConcatPool2d(),
            nn.Flatten(),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=0.25),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        ]

        super().__init__(
            nn.Sequential(*feature_layers),
            nn.Sequential(*head_layers),
        )
# ============================================================================
# Loading function
# ============================================================================
def load_coherence_model(safetensors_path, device='auto', num_classes=3):
    """Build a CoherenceDetectionModel and load weights from a safetensors file.

    No key remapping is performed: the checkpoint keys are loaded as-is with
    ``strict=False``, and any missing or unexpected keys are printed as
    warnings. If the file does not exist, an error is printed and the model
    is returned with its freshly initialized weights instead of raising.

    Args:
        safetensors_path: Path to .safetensors file
        device: 'auto', 'cuda', or 'cpu'. 'auto' selects CUDA when
            ``torch.cuda.is_available()`` and CPU otherwise; any other
            value is passed to ``torch.device``.
        num_classes: Output width of the classifier head. Defaults to 3.

    Returns:
        The model, moved to the selected device and switched to eval mode.
    """
    import safetensors.torch

    # Determine device
    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    # Create model instance
    model = CoherenceDetectionModel(num_classes=num_classes)

    # Load safetensors (always onto CPU first; the model is moved afterwards)
    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"Error: File '{safetensors_path}' not found.")
        print("Testing with sample model structure...")
        state_dict = None

    if state_dict is None:
        # Nothing was read from disk, so do not report a successful load.
        print("⚠ CoherenceDetectionModel is using freshly initialized weights (no checkpoint loaded)")
    else:
        # Load directly (keys should match exactly)
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys:
            print(f"Warning: Missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"Warning: Unexpected keys: {unexpected_keys}")
        if not missing_keys and not unexpected_keys:
            print("✓ CoherenceDetectionModel loaded successfully (exact match)")
        else:
            print("⚠ CoherenceDetectionModel loaded with key mismatches")

    model = model.to(device)
    model.eval()
    print(f" Device: {device}")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
    return model
# ============================================================================
# Test functions
# ============================================================================
def test_key_matching(safetensors_path="coherence_model.safetensors"):
    """Check whether a state dict loads into CoherenceDetectionModel with matching keys.

    The state dict is read from ``safetensors_path``. If that file does not
    exist, the state dict of a newly built model is used instead. The result
    of a non-strict load into a fresh model is printed.

    Args:
        safetensors_path: Path to the .safetensors file to check.

    Returns:
        The model the weights were loaded into when there are no missing and
        no unexpected keys, otherwise ``None``.
    """
    import safetensors.torch

    print("\nTesting key matching...")
    try:
        weights = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f" ⚠ File '{safetensors_path}' not found, using model weights")
        weights = CoherenceDetectionModel(num_classes=3).state_dict()

    print("\nTesting CoherenceDetectionModel:")
    candidate = CoherenceDetectionModel(num_classes=3)
    absent, surplus = candidate.load_state_dict(weights, strict=False)

    if not (absent or surplus):
        print(" ✅ Load successful (exact key match)")
        return candidate

    # At least one key list is non-empty: report counts only.
    print(" ⚠ Load completed with issues")
    if absent:
        print(f" Missing keys: {len(absent)}")
    if surplus:
        print(f" Unexpected keys: {len(surplus)}")
    return None
def print_key_samples(safetensors_path="coherence_model.safetensors"):
    """Print the first five sorted keys of the checkpoint and of the model.

    Args:
        safetensors_path: Path to the .safetensors file to inspect.

    Returns:
        The state dict loaded from the file, or ``None`` when the file was
        not found.
    """
    import safetensors.torch

    print("\nKey samples:")
    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"Safetensors file '{safetensors_path}' not found")
        print("Showing model structure keys instead:")
        state_dict = None
    else:
        print("From safetensors file (first 5 keys):")
        for index, name in enumerate(sorted(state_dict)[:5]):
            print(f" {index}: {name}")

    # The model's own keys are always shown for comparison.
    print("\nFrom CoherenceDetectionModel:")
    reference_keys = sorted(CoherenceDetectionModel(num_classes=3).state_dict())
    for index, name in enumerate(reference_keys[:5]):
        print(f" {index}: {name}")
    return state_dict
# ============================================================================
# Version compatibility info
# ============================================================================
def print_version_info():
    """Print torch/torchvision versions, CUDA availability, and the torchvision API in use.

    The API line reports the modern ``weights`` parameter for torchvision
    0.13 or newer and the legacy ``pretrained`` parameter otherwise. Version
    components that are not purely numeric are treated as 0.
    """
    import torch
    import torchvision

    separator = "=" * 60
    print("\n" + separator)
    print("Version Information")
    print(separator)
    print(f"Torch: {torch.__version__}")
    print(f"Torchvision: {torchvision.__version__}")
    print(f"CUDA Available: {torch.cuda.is_available()}")

    # Check API compatibility
    parts = torchvision.__version__.split('.')
    major = int(parts[0]) if parts[0].isdigit() else 0
    minor = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
    # Compare (major, minor) as a tuple: checking the two components
    # independently would report any 1.x release with minor < 13 as legacy.
    if (major, minor) >= (0, 13):
        print("✓ Using modern torchvision API (weights parameter)")
    else:
        print("⚠ Using legacy torchvision API (pretrained parameter)")
    print(separator)
if __name__ == "__main__":
    # Title banner, emitted as one write with the same three lines.
    print("\n".join(("=" * 60, "Coherence Detection Model Architecture", "=" * 60)))

    print_version_info()
    state_dict = print_key_samples()

    print("\n" + "=" * 60)
    model = test_key_matching()

    # test_key_matching returns None unless the keys matched exactly.
    if model is not None:
        print("\nModel summary:")
        print(f" Backbone layers: {len(model[0])}")
        print(f" Classifier layers: {len(model[1])}")
        print(f" Total sequential blocks: {len(model)}")
|