Spaces:
Sleeping
Sleeping
File size: 5,645 Bytes
fc6062a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
"""
Model utilities for telecom site classification
Handles ConvNeXt model loading and adaptation for transfer learning
"""
import torch
import torch.nn as nn
import timm
import os
from typing import Dict, Any, Optional, Tuple
class TelecomClassifier(nn.Module):
    """
    ConvNeXt-based telecom site classifier.

    Wraps a ConvNeXt-Large backbone (created via timm with its original
    classification head removed) and a custom MLP head, intended for
    transfer learning from a food detection model.

    Args:
        num_classes: Number of output classes.
        pretrained: Whether to load ImageNet-pretrained backbone weights.
    """

    def __init__(self, num_classes: int = 3, pretrained: bool = True):
        super().__init__()
        # num_classes=0 strips timm's classification head so the backbone
        # emits pooled feature vectors of size `num_features`.
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0  # Remove classification head
        )
        self.feature_dim = self.backbone.num_features
        # Custom MLP head: feature_dim -> 512 -> 128 -> num_classes,
        # with LayerNorm on the backbone features and dropout regularization.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )
        self._init_classifier_weights()

    def _init_classifier_weights(self) -> None:
        """Xavier-initialize the head's linear layers; zero their biases."""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a batch of images through the backbone and classifier head."""
        features = self.backbone(x)
        return self.classifier(features)

    def freeze_backbone(self) -> None:
        """Disable gradients on all backbone parameters (head-only training)."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        # Fixed mojibake in the original log message ("π ...").
        print("🔒 Backbone frozen for transfer learning")

    def unfreeze_backbone(self) -> None:
        """Re-enable gradients on all backbone parameters (full fine-tuning)."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("🔓 Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """Return parameter counts: backbone, classifier, total, trainable."""
        backbone_params = sum(p.numel() for p in self.backbone.parameters())
        classifier_params = sum(p.numel() for p in self.classifier.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            'backbone': backbone_params,
            'classifier': classifier_params,
            'total': backbone_params + classifier_params,
            'trainable': trainable_params
        }
def load_model(
    model_path: str,
    num_classes: int = 3,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """
    Load a trained telecom classifier checkpoint.

    Args:
        model_path: Path to the saved checkpoint file.
        num_classes: Number of output classes the model was trained with.
        device: Device to map the checkpoint tensors onto.

    Returns:
        Tuple of (model in eval mode, model_info dict including
        'best_acc' and 'epoch' pulled from the checkpoint).

    Raises:
        FileNotFoundError: If model_path does not exist.
        KeyError: If the checkpoint lacks 'model_state_dict'.
    """
    print(f"📂 Loading model from {model_path}")
    # pretrained=False: all weights come from the checkpoint, so skip
    # downloading ImageNet weights.
    model = TelecomClassifier(num_classes=num_classes, pretrained=False)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoints (consider weights_only=True where supported).
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    model_info = checkpoint.get('model_info', {})
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
    # Fixed: the original had this literal broken across two source lines
    # (a SyntaxError for a single-quoted f-string) with mojibake "β".
    print("✅ Model loaded successfully")
    print(f"  Best accuracy: {model_info.get('best_acc', 'Unknown')}")
    print(f"  Epoch: {model_info.get('epoch', 'Unknown')}")
    return model, model_info
def create_telecom_model(
    num_classes: int,
    food_model_path: Optional[str] = None,
    freeze_backbone: bool = True
) -> TelecomClassifier:
    """
    Create and initialize the TelecomClassifier model.

    Optionally warm-start the backbone from a food detection checkpoint
    and freeze it for transfer learning.

    Args:
        num_classes: Number of output classes.
        food_model_path: Optional path to a checkpoint whose backbone
            weights are copied into this model (its classifier head is
            ignored). Skipped silently if the file does not exist.
        freeze_backbone: If True, freeze backbone parameters; otherwise
            leave the entire network trainable.

    Returns:
        The initialized TelecomClassifier.
    """
    model = TelecomClassifier(num_classes=num_classes, pretrained=True)
    if food_model_path and os.path.exists(food_model_path):
        print(f"📥 Loading backbone weights from: {food_model_path}")
        state_dict = torch.load(food_model_path, map_location='cpu')
        # Unwrap full training checkpoints down to the raw state dict.
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
        # Keep only backbone weights (drop the source model's classifier),
        # stripping the 'backbone.' prefix ONCE — replace-all could corrupt
        # keys that contain the substring later on.
        backbone_state_dict = {
            k.replace('backbone.', '', 1): v
            for k, v in state_dict.items()
            if k.startswith('backbone')
        }
        # strict=False tolerates missing/unexpected keys between variants.
        model.backbone.load_state_dict(backbone_state_dict, strict=False)
    if freeze_backbone:
        model.freeze_backbone()
    else:
        model.unfreeze_backbone()
    return model
def save_model(
    model: nn.Module,
    path: str,
    epoch: int,
    val_acc: float,
    optimizer_state: Dict[str, Any],
    extra_info: Optional[Dict[str, Any]] = None
) -> None:
    """
    Save a training checkpoint.

    Args:
        model: The model to save (its state_dict is stored).
        path: Destination file path; parent directories are created.
        epoch: Current epoch number.
        val_acc: Validation accuracy, stored under 'best_acc'.
        optimizer_state: Optimizer state dict for resuming training.
        extra_info: Optional extra metadata stored under 'model_info'.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'best_acc': val_acc,
        'optimizer_state_dict': optimizer_state,
        'model_info': extra_info if extra_info is not None else {}
    }
    # os.path.dirname('') is '' for a bare filename, and makedirs('')
    # raises FileNotFoundError — only create a directory when one exists.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(checkpoint, path)
    print(f"💾 Model saved to {path}")
def get_model_summary(model: TelecomClassifier) -> str:
    """
    Build a human-readable summary of the TelecomClassifier.

    The summary is the module's repr followed by a formatted listing of
    its parameter counts (with thousands separators).
    """
    lines = [str(model), "\nParameter counts:"]
    lines.extend(
        f" {name}: {count:,}"
        for name, count in model.get_parameter_count().items()
    )
    return "\n".join(lines)