coherence_detection / model_architecture.py
reddesert's picture
Convert to safetensors format and add model architecture
c5a3ef9
"""
Exact architecture for Coherence Detection Model.
Uses key matching to validate the safetensors file.
"""
import sys
import torch
import torch.nn as nn
# ============================================================================
# AdaptiveConcatPool2d for FastAI model compatibility
# ============================================================================
class AdaptiveConcatPool2d(nn.Module):
    """Pool with adaptive max and adaptive average, then join the two results.

    Args:
        sz: Target output size handed to both pooling layers. A falsy value
            (``None`` by default) falls back to ``1``.

    The forward pass concatenates the max-pooled tensor followed by the
    average-pooled tensor along dimension 1, so that dimension doubles.
    """

    def __init__(self, sz=None):
        super().__init__()
        output_size = sz or 1
        # Attribute names are kept as ``ap`` / ``mp``; presumably serialized
        # FastAI models refer to them by these names - TODO confirm.
        self.ap = nn.AdaptiveAvgPool2d(output_size)
        self.mp = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        max_pooled = self.mp(x)
        avg_pooled = self.ap(x)
        # Max features come first, average features second.
        return torch.cat((max_pooled, avg_pooled), dim=1)
# Expose the pooling class on the running script's module (``__main__``) under
# both its full name and the shorter ``AdaptiveConcatPool`` alias. Names that
# already exist there are left untouched.
# NOTE(review): presumably needed so serialized FastAI models that reference
# ``__main__.AdaptiveConcatPool2d`` can be resolved - confirm.
if '__main__' in sys.modules:
    main_module = sys.modules['__main__']
    for _alias_name in ('AdaptiveConcatPool2d', 'AdaptiveConcatPool'):
        if not hasattr(main_module, _alias_name):
            setattr(main_module, _alias_name, AdaptiveConcatPool2d)
# ============================================================================
# Utility function to build the ResNet backbone across torchvision versions
# ============================================================================
def _get_resnet_backbone():
    """Build a ResNet-34 without pretrained weights, using a version-appropriate API.

    torchvision releases from 0.13 onward are called with ``weights=None``;
    older releases are called with ``pretrained=False``.

    Returns:
        The model object returned by ``torchvision.models.resnet34``.
    """
    import re

    import torchvision
    from torchvision.models import resnet34

    # Use the leading digits of the first two version components so that
    # suffixed versions ("0.13.0+cu117", "0.15rc1") still parse. A component
    # that is missing or has no leading digits counts as 0.
    components = torchvision.__version__.split('.')
    numbers = []
    for index in range(2):
        digits = re.match(r'\d+', components[index]) if index < len(components) else None
        numbers.append(int(digits.group()) if digits else 0)

    # Compare (major, minor) as a tuple: testing ``minor >= 13`` on its own
    # would wrongly send a release such as 1.0 down the legacy path.
    if tuple(numbers) >= (0, 13):
        # Use new weights API
        return resnet34(weights=None)
    # Use old pretrained API
    return resnet34(pretrained=False)
# ============================================================================
# Clean model with version detection
# ============================================================================
class CoherenceDetectionModel(nn.Sequential):
    """Sequential model made of two sub-blocks: a ResNet feature extractor and a classifier head.

    Index 0 is the backbone returned by ``_get_resnet_backbone()`` with its
    last two child modules removed; index 1 is the classifier head.

    Args:
        num_classes: Number of outputs of the final linear layer (default 3).
    """

    def __init__(self, num_classes=3):
        # Feature extractor: every child of the ResNet except the final two.
        resnet = _get_resnet_backbone()
        feature_layers = list(resnet.children())[:-2]
        backbone = nn.Sequential(*feature_layers)

        # Head: concat pooling doubles the channel dimension, hence the
        # 1024-wide first stage (assumes the backbone emits 512 channels -
        # TODO confirm).
        head_layers = [
            AdaptiveConcatPool2d(),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.BatchNorm1d(1024),
            nn.Dropout(p=0.25, inplace=False),
            nn.Linear(1024, 512, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(512, num_classes, bias=True),
        ]
        classifier = nn.Sequential(*head_layers)

        super().__init__(backbone, classifier)
# ============================================================================
# Loading function
# ============================================================================
def load_coherence_model(safetensors_path, device='auto'):
    """
    Build a CoherenceDetectionModel and load safetensors weights into it.

    The state dict is loaded as-is with ``strict=False``; missing or
    unexpected keys are reported on stdout rather than raised.

    Args:
        safetensors_path: Path to .safetensors file
        device: 'auto' (CUDA when available, otherwise CPU) or any value
            accepted by ``torch.device``, e.g. 'cuda' or 'cpu'

    Returns:
        The model, moved to the selected device and switched to eval mode.

    Note:
        If the file is not found, a message is printed and the model's own
        freshly initialised state dict is loaded instead, so the returned
        model carries no trained weights in that case.
    """
    import safetensors.torch

    # Resolve the target device
    if device == 'auto':
        device_spec = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device_spec = device
    device = torch.device(device_spec)

    model = CoherenceDetectionModel(num_classes=3)

    # Weights are always read onto the CPU first; the model is moved later.
    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"Error: File '{safetensors_path}' not found.")
        print("Testing with sample model structure...")
        state_dict = model.state_dict()

    load_result = model.load_state_dict(state_dict, strict=False)
    missing_keys = load_result.missing_keys
    unexpected_keys = load_result.unexpected_keys

    if missing_keys:
        print(f"Warning: Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"Warning: Unexpected keys: {unexpected_keys}")

    if missing_keys or unexpected_keys:
        print("⚠ CoherenceDetectionModel loaded with key mismatches")
    else:
        print("✓ CoherenceDetectionModel loaded successfully (exact match)")

    model = model.to(device)
    model.eval()

    parameter_count = sum(param.numel() for param in model.parameters())
    print(f" Device: {device}")
    print(f" Parameters: {parameter_count:,}")
    return model
# ============================================================================
# Test functions
# ============================================================================
def test_key_matching(safetensors_path="coherence_model.safetensors"):
    """Check whether a safetensors state dict lines up with the model's keys.

    Args:
        safetensors_path: Path to the .safetensors file to check. If it is
            not found, the state dict of a freshly built model is used.

    Returns:
        The loaded model when there are neither missing nor unexpected keys,
        otherwise ``None``.
    """
    import safetensors.torch

    print("\nTesting key matching...")
    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f" ⚠ File '{safetensors_path}' not found, using model weights")
        state_dict = CoherenceDetectionModel(num_classes=3).state_dict()

    print("\nTesting CoherenceDetectionModel:")
    model_clean = CoherenceDetectionModel(num_classes=3)
    missing, unexpected = model_clean.load_state_dict(state_dict, strict=False)

    # Exact match: nothing missing, nothing left over.
    if not (missing or unexpected):
        print(" ✅ Load successful (exact key match)")
        return model_clean

    print(" ⚠ Load completed with issues")
    if missing:
        print(f" Missing keys: {len(missing)}")
    if unexpected:
        print(f" Unexpected keys: {len(unexpected)}")
    return None
def print_key_samples(safetensors_path="coherence_model.safetensors"):
    """Print the first five sorted keys of the file and of the model, for debugging.

    Args:
        safetensors_path: Path to the .safetensors file to inspect.

    Returns:
        The state dict read from the file, or ``None`` if the file was not
        found.
    """
    import safetensors.torch

    print("\nKey samples:")
    try:
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"Safetensors file '{safetensors_path}' not found")
        print("Showing model structure keys instead:")
        state_dict = None
    else:
        print("From safetensors file (first 5 keys):")
        for position, name in enumerate(sorted(state_dict)[:5]):
            print(f" {position}: {name}")

    # The model's own keys are always shown, whether or not the file loaded.
    print("\nFrom CoherenceDetectionModel:")
    reference_keys = sorted(CoherenceDetectionModel(num_classes=3).state_dict())
    for position, name in enumerate(reference_keys[:5]):
        print(f" {position}: {name}")
    return state_dict
# ============================================================================
# Version compatibility info
# ============================================================================
def print_version_info():
    """Print torch/torchvision versions, CUDA availability and the torchvision API in use.

    The API line reports "modern" (``weights`` parameter) for torchvision
    0.13 or newer and "legacy" (``pretrained`` parameter) otherwise.
    """
    import re

    import torch
    import torchvision

    print("\n" + "=" * 60)
    print("Version Information")
    print("=" * 60)
    print(f"Torch: {torch.__version__}")
    print(f"Torchvision: {torchvision.__version__}")
    print(f"CUDA Available: {torch.cuda.is_available()}")

    # Check API compatibility. Leading digits of the first two components are
    # used so suffixed versions ("0.13.0+cu117", "0.15rc1") still parse; a
    # component that is missing or has no leading digits counts as 0.
    components = torchvision.__version__.split('.')
    numbers = []
    for index in range(2):
        digits = re.match(r'\d+', components[index]) if index < len(components) else None
        numbers.append(int(digits.group()) if digits else 0)

    # Tuple comparison: checking ``minor >= 13`` alone would misreport a
    # release such as 1.0 as legacy.
    if tuple(numbers) >= (0, 13):
        print("✓ Using modern torchvision API (weights parameter)")
    else:
        print("⚠ Using legacy torchvision API (pretrained parameter)")
    print("=" * 60)
if __name__ == "__main__":
    # Script entry point: print environment info, sample keys, then run the
    # key-matching check and summarise the model if it matched exactly.
    banner = "=" * 60
    print(banner)
    print("Coherence Detection Model Architecture")
    print(banner)

    print_version_info()
    state_dict = print_key_samples()

    print("\n" + banner)
    model = test_key_matching()
    if model:
        backbone_size = len(model[0])
        classifier_size = len(model[1])
        block_count = len(model)
        print("\nModel summary:")
        print(f" Backbone layers: {backbone_size}")
        print(f" Classifier layers: {classifier_size}")
        print(f" Total sequential blocks: {block_count}")