# acceptIN-v2 / utils / model_utils.py
# Uploaded by Neylton (commit a77ccbe, verified)
"""
Model utilities for telecom site classification
Handles ConvNeXt model loading and adaptation for transfer learning
"""
import torch
import torch.nn as nn
import timm
import os
from typing import Dict, Any, Optional, Tuple
class TelecomClassifier(nn.Module):
    """
    ConvNeXt-based telecom site classifier.

    A ConvNeXt-Large backbone (classification head removed via
    ``num_classes=0``) produces pooled features that a small MLP head maps
    to ``num_classes`` logits. Designed for transfer learning: the backbone
    can be frozen while only the head trains, then unfrozen for fine-tuning.
    """

    def __init__(self, num_classes: int = 3, pretrained: bool = True):
        super().__init__()
        # num_classes=0 asks timm for the headless, pooled-feature model.
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0
        )
        self.feature_dim = self.backbone.num_features
        # MLP head: feature_dim -> 512 -> 128 -> num_classes, with dropout.
        head = [
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head)
        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Xavier-initialize the head's linear layers and zero their biases."""
        linears = (m for m in self.classifier.modules()
                   if isinstance(m, nn.Linear))
        for layer in linears:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run input through backbone then head; return raw logits."""
        return self.classifier(self.backbone(x))

    def freeze_backbone(self):
        """Disable gradients on every backbone parameter (head still trains)."""
        for p in self.backbone.parameters():
            p.requires_grad = False
        print("πŸ”’ Backbone frozen for transfer learning")

    def unfreeze_backbone(self):
        """Re-enable gradients on every backbone parameter for fine-tuning."""
        for p in self.backbone.parameters():
            p.requires_grad = True
        print("πŸ”“ Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """Return parameter counts: backbone, classifier, total, trainable."""
        def count(params):
            return sum(p.numel() for p in params)

        backbone_n = count(self.backbone.parameters())
        head_n = count(self.classifier.parameters())
        trainable_n = count(p for p in self.parameters() if p.requires_grad)
        return {
            'backbone': backbone_n,
            'classifier': head_n,
            'total': backbone_n + head_n,
            'trainable': trainable_n,
        }
def load_model(
    model_path: str,
    num_classes: int = 3,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """
    Load a trained telecom classifier checkpoint.

    Args:
        model_path: Path to the saved checkpoint — a dict with at least
            'model_state_dict', and optionally 'model_info', 'best_acc',
            'epoch'.
        num_classes: Number of output classes (must match the checkpoint).
        device: Device to place the model on (e.g. 'cpu', 'cuda').

    Returns:
        Tuple of (model in eval mode on `device`, model_info dict).

    Raises:
        FileNotFoundError: If `model_path` does not exist.
        KeyError: If the checkpoint has no 'model_state_dict' entry.
    """
    print(f"πŸ“‚ Loading model from {model_path}")
    # pretrained=False: all weights come from the checkpoint, so skip
    # downloading the ImageNet backbone weights.
    model = TelecomClassifier(num_classes=num_classes, pretrained=False)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources; consider weights_only=True on
    # torch >= 2.0 if the checkpoint contents allow it.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Bug fix: previously the model stayed on CPU regardless of `device`.
    # map_location only controls where checkpoint tensors deserialize; the
    # freshly constructed model's parameters must be moved explicitly.
    model.to(device)
    model.eval()
    model_info = checkpoint.get('model_info', {})
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
    print(f"βœ… Model loaded successfully")
    print(f"   Best accuracy: {model_info.get('best_acc', 'Unknown')}")
    print(f"   Epoch: {model_info.get('epoch', 'Unknown')}")
    return model, model_info