# acceptIN-v3 / utils / model_utils.py
# Neylton's picture
# Initial commit with ACCEPTIN app
# fc6062a
"""
Model utilities for telecom site classification
Handles ConvNeXt model loading and adaptation for transfer learning
"""
import torch
import torch.nn as nn
import timm
import os
from typing import Dict, Any, Optional, Tuple
class TelecomClassifier(nn.Module):
    """
    Telecom site classifier: a ConvNeXt feature extractor plus an MLP head.

    The backbone is created through timm without its own classification
    head; a small fully connected head maps the backbone features to
    ``num_classes`` outputs. Intended for transfer learning from a food
    detection model.
    """

    def __init__(self, num_classes: int = 3, pretrained: bool = True):
        """
        Build the backbone and the classification head.

        Args:
            num_classes: Width of the final linear layer.
            pretrained: Forwarded to ``timm.create_model``.
        """
        super().__init__()

        # num_classes=0 makes timm build the network without a classifier head.
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0,
        )
        self.feature_dim = self.backbone.num_features

        # Layer order fixes the Sequential indices, and therefore the
        # state-dict key names used by saved checkpoints.
        head_layers = [
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Apply Xavier-uniform weights and zero biases to the head's linear layers."""
        linear_layers = [
            layer for layer in self.classifier.modules()
            if isinstance(layer, nn.Linear)
        ]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone, then through the classification head."""
        return self.classifier(self.backbone(x))

    def freeze_backbone(self):
        """Turn off gradient tracking for every backbone parameter."""
        for weight in self.backbone.parameters():
            weight.requires_grad_(False)
        print("πŸ”’ Backbone frozen for transfer learning")

    def unfreeze_backbone(self):
        """Turn gradient tracking back on for every backbone parameter."""
        for weight in self.backbone.parameters():
            weight.requires_grad_(True)
        print("πŸ”“ Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """
        Count parameters by component.

        Returns:
            Dict with keys ``'backbone'``, ``'classifier'``, ``'total'``
            (backbone + classifier) and ``'trainable'`` (parameters of the
            whole module that currently require gradients).
        """
        def _numel(params) -> int:
            return sum(p.numel() for p in params)

        n_backbone = _numel(self.backbone.parameters())
        n_head = _numel(self.classifier.parameters())
        n_trainable = _numel(p for p in self.parameters() if p.requires_grad)

        return {
            'backbone': n_backbone,
            'classifier': n_head,
            'total': n_backbone + n_head,
            'trainable': n_trainable,
        }
def load_model(
    model_path: str,
    num_classes: int = 3,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """
    Load a trained telecom classifier from a checkpoint file.

    The file may hold either a checkpoint dict with a ``'model_state_dict'``
    entry (the layout ``save_model`` writes) or a bare state dict. In the
    latter case no metadata is available and ``'Unknown'`` is reported.

    Args:
        model_path: Path to the saved checkpoint.
        num_classes: Number of output classes the checkpoint was trained with.
        device: Device the returned model is placed on.

    Returns:
        Tuple of ``(model, model_info)``. The model is in eval mode on
        ``device``; ``model_info`` is a fresh dict containing the stored
        ``'model_info'`` entries plus ``'best_acc'`` and ``'epoch'``.
    """
    print(f"πŸ“‚ Loading model from {model_path}")
    # pretrained=False: the weights come from the checkpoint, not from timm.
    model = TelecomClassifier(num_classes=num_classes, pretrained=False)
    checkpoint = torch.load(model_path, map_location=device)

    # Accept both the wrapped checkpoint layout and a bare state dict,
    # the same way create_telecom_model does.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        metadata = checkpoint
    else:
        state_dict = checkpoint
        metadata = {}

    model.load_state_dict(state_dict)
    # load_state_dict copies values into the parameters created above (on the
    # default device); map_location alone does not move the model itself.
    model.to(device)
    model.eval()

    # Copy so the caller's dict is independent of the checkpoint object.
    model_info = dict(metadata.get('model_info') or {})
    model_info['best_acc'] = metadata.get('best_acc', 'Unknown')
    model_info['epoch'] = metadata.get('epoch', 'Unknown')

    print(f"βœ… Model loaded successfully")
    print(f" Best accuracy: {model_info.get('best_acc', 'Unknown')}")
    print(f" Epoch: {model_info.get('epoch', 'Unknown')}")
    return model, model_info
def create_telecom_model(
    num_classes: int,
    food_model_path: Optional[str] = None,
    freeze_backbone: bool = True
) -> TelecomClassifier:
    """
    Create and initialize the TelecomClassifier model.

    The model is always built with ``pretrained=True``. If ``food_model_path``
    is given and exists, the entries of that checkpoint whose keys start with
    ``'backbone.'`` are loaded into the backbone (non-strictly, so entries
    that do not match are skipped and reported). Classifier-head weights in
    the checkpoint are never used.

    Args:
        num_classes: Number of output classes for the new head.
        food_model_path: Optional path to a food detection checkpoint, either
            a bare state dict or a dict with a ``'model_state_dict'`` entry.
        freeze_backbone: If True the backbone is frozen, otherwise unfrozen.

    Returns:
        The initialized TelecomClassifier.
    """
    model = TelecomClassifier(num_classes=num_classes, pretrained=True)

    if food_model_path:
        if os.path.exists(food_model_path):
            print(f"πŸ”„ Loading backbone weights from: {food_model_path}")
            state_dict = torch.load(food_model_path, map_location='cpu')
            if 'model_state_dict' in state_dict:
                state_dict = state_dict['model_state_dict']

            # Keep only backbone entries and strip the leading prefix. Slicing
            # (rather than str.replace) leaves any later 'backbone.' substring
            # inside a key untouched.
            prefix = 'backbone.'
            backbone_state_dict = {
                key[len(prefix):]: value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }

            if not backbone_state_dict:
                print(f"Warning: no '{prefix}*' weights found in {food_model_path}; backbone left unchanged")
            else:
                # strict=False tolerates mismatches; surface them instead of
                # discarding the result so a wrong checkpoint is noticeable.
                result = model.backbone.load_state_dict(backbone_state_dict, strict=False)
                if result.missing_keys:
                    print(f"Warning: {len(result.missing_keys)} backbone keys missing from checkpoint")
                if result.unexpected_keys:
                    print(f"Warning: {len(result.unexpected_keys)} checkpoint keys not used by backbone")
        else:
            print(f"Warning: food model not found at {food_model_path}; using default pretrained backbone")

    if freeze_backbone:
        model.freeze_backbone()
    else:
        model.unfreeze_backbone()
    return model
def save_model(model, path, epoch, val_acc, optimizer_state, extra_info=None):
    """
    Save the model checkpoint.

    The checkpoint is a dict with the keys ``'model_state_dict'``,
    ``'epoch'``, ``'best_acc'``, ``'optimizer_state_dict'`` and
    ``'model_info'``.

    Args:
        model: The model to save; its ``state_dict()`` is stored.
        path: Destination file. Missing parent directories are created; a
            bare filename is written to the current working directory.
        epoch: Current epoch.
        val_acc: Validation accuracy, stored under ``'best_acc'``.
        optimizer_state: Optimizer state dict.
        extra_info: Any extra info to save (dict); stored as ``{}`` if None.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'best_acc': val_acc,
        'optimizer_state_dict': optimizer_state,
        'model_info': extra_info if extra_info is not None else {}
    }

    # dirname() is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when there is one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(checkpoint, path)
    print(f"πŸ’Ύ Model saved to {path}")
def get_model_summary(model: TelecomClassifier) -> str:
    """
    Build a text summary of a TelecomClassifier.

    The summary is the model's ``str()`` representation followed by a
    "Parameter counts" section with one line per entry returned by
    ``model.get_parameter_count()``, using thousands separators.
    """
    architecture = str(model)
    count_lines = [
        f" {label}: {amount:,}"
        for label, amount in model.get_parameter_count().items()
    ]
    return "\n".join([architecture, "\nParameter counts:", *count_lines])