"""
Model utilities for telecom site classification
Handles ConvNeXt model loading and adaptation for transfer learning
"""

import torch
import torch.nn as nn
import timm
import os
from typing import Dict, Any, Optional, Tuple


class TelecomClassifier(nn.Module):
    """
    ConvNeXt-based telecom site classifier.

    A timm ``convnext_large.fb_in22k_ft_in1k`` backbone (created with
    ``num_classes=0`` so it has no classification head) feeds a small MLP
    head that produces ``num_classes`` logits.

    NOTE(review): the original docstring said this "uses transfer learning
    from food detection model". Nothing in this file loads such weights; the
    only pretrained weights requested here are the timm ones named above.
    Confirm where (or whether) the food-detection weights are applied.

    Args:
        num_classes: Number of output logits produced by the head.
        pretrained: Forwarded to ``timm.create_model``; when True timm loads
            its pretrained weights for the backbone.
    """

    def __init__(self, num_classes: int = 3, pretrained: bool = True):
        super().__init__()

        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0  # Remove classification head
        )

        # Width of the feature vector the head consumes, as reported by timm.
        self.feature_dim = self.backbone.num_features

        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Xavier-initialise the head's Linear weights and zero their biases."""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                # Guard so a future bias-free Linear layer does not crash here.
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the backbone, then the classification head.

        Args:
            x: Input batch accepted by the timm backbone
                (presumably images; shape is not fixed by this file).

        Returns:
            The head's output; its last dimension is ``num_classes``.
        """
        features = self.backbone(x)
        return self.classifier(features)

    def freeze_backbone(self):
        """Set ``requires_grad = False`` on every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("🔒 Backbone frozen for transfer learning")

    def unfreeze_backbone(self):
        """Set ``requires_grad = True`` on every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("🔓 Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """
        Count parameters by component.

        Returns:
            Dict with keys ``'backbone'``, ``'classifier'``, ``'total'``
            (backbone + classifier) and ``'trainable'`` (all parameters of
            this module whose ``requires_grad`` is True).
        """
        backbone_params = sum(p.numel() for p in self.backbone.parameters())
        classifier_params = sum(p.numel() for p in self.classifier.parameters())
        total_params = backbone_params + classifier_params
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )

        return {
            'backbone': backbone_params,
            'classifier': classifier_params,
            'total': total_params,
            'trainable': trainable_params
        }


def load_model(
    model_path: str,
    num_classes: int = 3,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """
    Load trained telecom classifier model

    The checkpoint is expected to be a dict containing
    ``'model_state_dict'`` and, optionally, ``'model_info'``, ``'best_acc'``
    and ``'epoch'``.

    Args:
        model_path: Path to saved model
        num_classes: Number of output classes
        device: Device to load model on

    Returns:
        Tuple of (model, model_info). The model is on ``device`` and in
        eval mode; ``model_info`` always carries ``'best_acc'`` and
        ``'epoch'`` (``'Unknown'`` when absent from the checkpoint).

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"📂 Loading model from {model_path}")

    # pretrained=False: the weights come from the checkpoint, so there is no
    # need for timm to fetch its own pretrained weights first.
    model = TelecomClassifier(num_classes=num_classes, pretrained=False)

    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
    # from trusted sources (or switch to weights_only=True once the stored
    # checkpoint contents are known to be compatible with it).
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # map_location only places the *checkpoint* tensors; load_state_dict
    # copies them into the model's existing (CPU) parameters. Move the model
    # explicitly so it really ends up on the requested device.
    model.to(device)
    model.eval()

    # Copy so the checkpoint dict is not mutated; tolerate a stored None.
    model_info = dict(checkpoint.get('model_info') or {})
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')

    print("✅ Model loaded successfully")
    print(f" Best accuracy: {model_info.get('best_acc', 'Unknown')}")
    print(f" Epoch: {model_info.get('epoch', 'Unknown')}")

    return model, model_info