|
|
""" |
|
|
Exact architecture for the Coherence Detection Model (ResNet-34 backbone with a FastAI-style classifier head).
|
|
Validates a safetensors checkpoint by matching its state-dict keys against the model.
|
|
""" |
|
|
|
|
|
import sys |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max- and average-pooled features (FastAI style).

    The max-pooled map comes first and the average-pooled map second along
    the channel dimension, so the output has twice the input channels.

    Args:
        sz: Target output size for both pooling layers. Any falsy value
            (``None`` or ``0``) falls back to ``1``.
    """

    def __init__(self, sz=None):
        super().__init__()
        target = sz or 1
        # Registration order (ap, then mp) matches the original definition.
        self.ap = nn.AdaptiveAvgPool2d(target)
        self.mp = nn.AdaptiveMaxPool2d(target)

    def forward(self, x):
        pooled = (self.mp(x), self.ap(x))
        return torch.cat(pooled, dim=1)
|
|
|
|
|
|
|
|
# Expose the pooling layer on ``__main__`` under its current name and under the
# ``AdaptiveConcatPool`` alias. Attributes that already exist on ``__main__``
# are left untouched.
# NOTE(review): presumably this lets pickled checkpoints that reference
# ``__main__.AdaptiveConcatPool*`` resolve -- confirm.
if '__main__' in sys.modules:
    main_module = sys.modules['__main__']
    if not hasattr(main_module, 'AdaptiveConcatPool2d'):
        setattr(main_module, 'AdaptiveConcatPool2d', AdaptiveConcatPool2d)
    if not hasattr(main_module, 'AdaptiveConcatPool'):
        setattr(main_module, 'AdaptiveConcatPool', AdaptiveConcatPool2d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_resnet_backbone():
    """Build an untrained torchvision ResNet-34 with the API the installed
    torchvision supports.

    torchvision 0.13 added the ``weights=`` keyword and deprecated
    ``pretrained=``; older releases only accept ``pretrained=``.

    Returns:
        A randomly initialised ``torchvision.models.resnet34`` instance
        (no pretrained weights are requested).
    """
    import torchvision
    from torchvision.models import resnet34

    # Keep only the leading numeric components, e.g. "0.15.2+cu118" -> (0, 15).
    parts = torchvision.__version__.split('.')
    major = int(parts[0]) if parts[0].isdigit() else 0
    minor = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0

    # Compare as a tuple: the previous ``major >= 0 and minor >= 13`` check
    # wrongly treated versions such as 1.0 as older than 0.13.
    if (major, minor) >= (0, 13):
        return resnet34(weights=None)
    return resnet34(pretrained=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CoherenceDetectionModel(nn.Sequential):
    """ResNet-34 convolutional trunk followed by a FastAI-style classifier head.

    The module is a two-element ``nn.Sequential``:

    * ``self[0]``: the ResNet-34 children without the final average-pool and
      fully-connected layers.
    * ``self[1]``: concat pooling, flatten, then two BatchNorm/Dropout/Linear
      stages ending in ``num_classes`` outputs.

    The submodule indices determine state-dict keys such as ``"1.4.weight"``,
    so the layer order must not change.

    Args:
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, num_classes=3):
        resnet = _get_resnet_backbone()
        # Drop the trailing avgpool and fc layers; keep the convolutional trunk.
        trunk_layers = list(resnet.children())[:-2]
        trunk = nn.Sequential(*trunk_layers)

        # The ResNet-34 trunk yields 512 channels; concat pooling doubles that.
        pooled_features = 2 * 512
        hidden_features = 512

        head = nn.Sequential(
            AdaptiveConcatPool2d(),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.BatchNorm1d(pooled_features),
            nn.Dropout(p=0.25, inplace=False),
            nn.Linear(pooled_features, hidden_features, bias=True),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(hidden_features),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(hidden_features, num_classes, bias=True),
        )

        super().__init__(trunk, head)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_coherence_model(safetensors_path, device='auto', num_classes=3):
    """
    Build a CoherenceDetectionModel and load weights from a safetensors file.

    Checkpoint keys are loaded as-is (no remapping) with ``strict=False``.
    Missing or unexpected keys are printed as warnings rather than raised.
    ``load_state_dict`` still raises ``RuntimeError`` on tensor shape
    mismatches.

    Args:
        safetensors_path: Path to the ``.safetensors`` file.
        device: ``'auto'`` (CUDA if available, else CPU), or any value accepted
            by ``torch.device`` such as ``'cuda'`` or ``'cpu'``.
        num_classes: Output size of the classifier head. It must match the
            checkpoint. Defaults to 3.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Note:
        If the file does not exist, an error is printed and the model keeps
        its random initialisation, so the returned model is untrained.
    """
    import safetensors.torch

    if device == 'auto':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(device)

    model = CoherenceDetectionModel(num_classes=num_classes)

    try:
        # Load on CPU first; the whole model is moved to `device` afterwards.
        state_dict = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"Error: File '{safetensors_path}' not found.")
        print("Testing with sample model structure...")
        # Re-load the model's own fresh weights (effectively a no-op) so the
        # key reporting below still runs.
        state_dict = model.state_dict()

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)

    if missing_keys:
        print(f"Warning: Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"Warning: Unexpected keys: {unexpected_keys}")

    if not missing_keys and not unexpected_keys:
        print("✓ CoherenceDetectionModel loaded successfully (exact match)")
    else:
        print("⚠ CoherenceDetectionModel loaded with key mismatches")

    model = model.to(device)
    model.eval()

    print(f" Device: {device}")
    print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")

    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_key_matching(safetensors_path="coherence_model.safetensors"):
    """Check whether a checkpoint's keys line up exactly with the model's.

    If the file is missing, a freshly initialised model's own state dict is
    used instead, so the check can run without a checkpoint.

    Args:
        safetensors_path: Path to the ``.safetensors`` checkpoint.

    Returns:
        A new ``CoherenceDetectionModel`` holding the loaded weights when the
        keys match exactly; otherwise ``None``.
    """
    import safetensors.torch

    print("\nTesting key matching...")

    try:
        weights = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f" ⚠ File '{safetensors_path}' not found, using model weights")
        weights = CoherenceDetectionModel(num_classes=3).state_dict()

    print("\nTesting CoherenceDetectionModel:")
    candidate = CoherenceDetectionModel(num_classes=3)
    missing, unexpected = candidate.load_state_dict(weights, strict=False)

    if not (missing or unexpected):
        print(" ✅ Load successful (exact key match)")
        return candidate

    print(" ⚠ Load completed with issues")
    if missing:
        print(f" Missing keys: {len(missing)}")
    if unexpected:
        print(f" Unexpected keys: {len(unexpected)}")
    return None
|
|
|
|
|
|
|
|
def print_key_samples(safetensors_path="coherence_model.safetensors"):
    """Print the first five sorted state-dict keys of the checkpoint and model.

    Args:
        safetensors_path: Path to the ``.safetensors`` checkpoint.

    Returns:
        The loaded checkpoint state dict, or ``None`` if the file was not found.
    """
    import safetensors.torch

    print("\nKey samples:")

    loaded = None
    try:
        loaded = safetensors.torch.load_file(safetensors_path, device='cpu')
    except FileNotFoundError:
        print(f"Safetensors file '{safetensors_path}' not found")
        print("Showing model structure keys instead:")
    else:
        print("From safetensors file (first 5 keys):")
        for idx, name in enumerate(sorted(loaded)[:5]):
            print(f" {idx}: {name}")

    print("\nFrom CoherenceDetectionModel:")
    model_keys = sorted(CoherenceDetectionModel(num_classes=3).state_dict())
    for idx, name in enumerate(model_keys[:5]):
        print(f" {idx}: {name}")

    return loaded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_version_info():
    """Print torch/torchvision versions, CUDA availability, and which
    torchvision model API (``weights=`` vs ``pretrained=``) applies."""
    import torch
    import torchvision

    print("\n" + "=" * 60)
    print("Version Information")
    print("=" * 60)
    print(f"Torch: {torch.__version__}")
    print(f"Torchvision: {torchvision.__version__}")
    print(f"CUDA Available: {torch.cuda.is_available()}")

    # Keep only the leading numeric components, e.g. "0.15.2+cu118" -> (0, 15).
    parts = torchvision.__version__.split('.')
    major = int(parts[0]) if parts[0].isdigit() else 0
    minor = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0

    # Tuple comparison so that e.g. 1.0 counts as newer than 0.13; the previous
    # ``major >= 0 and minor >= 13`` check got this wrong.
    if (major, minor) >= (0, 13):
        print("✓ Using modern torchvision API (weights parameter)")
    else:
        print("⚠ Using legacy torchvision API (pretrained parameter)")
    print("=" * 60)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    print("=" * 60)
    print("Coherence Detection Model Architecture")
    print("=" * 60)

    print_version_info()
    # The returned checkpoint dict is not needed here; only the printout is.
    print_key_samples()
    print("\n" + "=" * 60)
    model = test_key_matching()

    # test_key_matching returns None on a key mismatch. Compare against None
    # explicitly: nn.Sequential defines __len__, so plain truthiness would
    # depend on the number of blocks rather than on whether a model exists.
    if model is not None:
        print("\nModel summary:")
        print(f" Backbone layers: {len(model[0])}")
        print(f" Classifier layers: {len(model[1])}")
        print(f" Total sequential blocks: {len(model)}")
|
|
|