""" Exact architecture for Coherence Detection Model. Uses key matching to validate the safetensors file. """ import sys import torch import torch.nn as nn # ============================================================================ # AdaptiveConcatPool2d for FastAI model compatitbility # ============================================================================ class AdaptiveConcatPool2d(nn.Module): """FastAI-style adaptive concatenation pooling.""" def __init__(self, sz=None): super().__init__() self.ap = nn.AdaptiveAvgPool2d(sz or 1) self.mp = nn.AdaptiveMaxPool2d(sz or 1) def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1) # Also register in __main__ if it exists if '__main__' in sys.modules: main_module = sys.modules['__main__'] if not hasattr(main_module, 'AdaptiveConcatPool2d'): setattr(main_module, 'AdaptiveConcatPool2d', AdaptiveConcatPool2d) # Also register under common FastAI names if '__main__' in sys.modules: main_module = sys.modules['__main__'] # Some FastAI models might expect this if not hasattr(main_module, 'AdaptiveConcatPool'): setattr(main_module, 'AdaptiveConcatPool', AdaptiveConcatPool2d) # ============================================================================ # Utility function to check torchvision version # ============================================================================ def _get_resnet_backbone(): """Helper to get ResNet backbone with version-appropriate API.""" from torchvision.models import resnet34 import torchvision # Parse version to determine API version = torchvision.__version__.split('.') major = int(version[0]) if version[0].isdigit() else 0 minor = int(version[1]) if len(version) > 1 and version[1].isdigit() else 0 if major >= 0 and minor >= 13: # Use new weights API return resnet34(weights=None) else: # Use old pretrained API return resnet34(pretrained=False) # ============================================================================ # Clean model with version detection # ============================================================================ class CoherenceDetectionModel(nn.Sequential): """ Clean version that handles torchvision API changes properly. """ def __init__(self, num_classes=3): # Get backbone using version-appropriate API backbone = _get_resnet_backbone() backbone = nn.Sequential(*list(backbone.children())[:-2]) # Classifier head classifier = nn.Sequential( AdaptiveConcatPool2d(), nn.Flatten(start_dim=1, end_dim=-1), nn.BatchNorm1d(1024), nn.Dropout(p=0.25, inplace=False), nn.Linear(1024, 512, bias=True), nn.ReLU(inplace=True), nn.BatchNorm1d(512), nn.Dropout(p=0.5, inplace=False), nn.Linear(512, num_classes, bias=True) ) super().__init__(backbone, classifier) # ============================================================================ # Loading function # ============================================================================ def load_coherence_model(safetensors_path, device='auto'): """ Load safetensors weights with automatic key remapping. Args: safetensors_path: Path to .safetensors file device: 'auto', 'cuda', or 'cpu' """ import safetensors.torch # Determine device if device == 'auto': device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') else: device = torch.device(device) # Create model instance model = CoherenceDetectionModel(num_classes=3) # Load safetensors try: state_dict = safetensors.torch.load_file(safetensors_path, device='cpu') except FileNotFoundError: print(f"Error: File '{safetensors_path}' not found.") print("Testing with sample model structure...") state_dict = model.state_dict() # Load directly (keys should match exactly) missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) if missing_keys: print(f"Warning: Missing keys: {missing_keys}") if unexpected_keys: print(f"Warning: Unexpected keys: {unexpected_keys}") if not missing_keys and not unexpected_keys: print(f"✓ CoherenceDetectionModel loaded successfully (exact match)") else: print(f"⚠ CoherenceDetectionModel loaded with key mismatches") model = model.to(device) model.eval() print(f" Device: {device}") print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}") return model # ============================================================================ # Test functions # ============================================================================ def test_key_matching(safetensors_path="coherence_model.safetensors"): """Test that keys match between model and safetensors.""" import safetensors.torch print("\nTesting key matching...") try: state_dict = safetensors.torch.load_file(safetensors_path, device='cpu') except FileNotFoundError: print(f" ⚠ File '{safetensors_path}' not found, using model weights") model = CoherenceDetectionModel(num_classes=3) state_dict = model.state_dict() print("\nTesting CoherenceDetectionModel:") model_clean = CoherenceDetectionModel(num_classes=3) missing, unexpected = model_clean.load_state_dict(state_dict, strict=False) if not missing and not unexpected: print(" ✅ Load successful (exact key match)") else: print(f" ⚠ Load completed with issues") if missing: print(f" Missing keys: {len(missing)}") if unexpected: print(f" Unexpected keys: {len(unexpected)}") return model_clean if not missing and not unexpected else None def print_key_samples(safetensors_path="coherence_model.safetensors"): """Print sample keys for debugging.""" import safetensors.torch print("\nKey samples:") try: state_dict = safetensors.torch.load_file(safetensors_path, device='cpu') print("From safetensors file (first 5 keys):") for i, key in enumerate(sorted(state_dict.keys())[:5]): print(f" {i}: {key}") except FileNotFoundError: print(f"Safetensors file '{safetensors_path}' not found") print("Showing model structure keys instead:") state_dict = None print("\nFrom CoherenceDetectionModel:") model = CoherenceDetectionModel(num_classes=3) for i, key in enumerate(sorted(model.state_dict().keys())[:5]): print(f" {i}: {key}") return state_dict # ============================================================================ # Version compatibility info # ============================================================================ def print_version_info(): """Print version information for debugging.""" import torch import torchvision print("\n" + "=" * 60) print("Version Information") print("=" * 60) print(f"Torch: {torch.__version__}") print(f"Torchvision: {torchvision.__version__}") print(f"CUDA Available: {torch.cuda.is_available()}") # Check API compatibility version = torchvision.__version__.split('.') major = int(version[0]) if version[0].isdigit() else 0 minor = int(version[1]) if len(version) > 1 and version[1].isdigit() else 0 if major >= 0 and minor >= 13: print("✓ Using modern torchvision API (weights parameter)") else: print("⚠ Using legacy torchvision API (pretrained parameter)") print("=" * 60) if __name__ == "__main__": print("=" * 60) print("Coherence Detection Model Architecture") print("=" * 60) print_version_info() state_dict = print_key_samples() print("\n" + "=" * 60) model = test_key_matching() if model: print("\nModel summary:") print(f" Backbone layers: {len(model[0])}") print(f" Classifier layers: {len(model[1])}") print(f" Total sequential blocks: {len(model)}")