Spaces:
Sleeping
Sleeping
| """ | |
| Model utilities for telecom site classification | |
| Handles ConvNeXt model loading and adaptation for transfer learning | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import timm | |
| import os | |
| from typing import Dict, Any, Optional, Tuple | |
class TelecomClassifier(nn.Module):
    """
    ConvNeXt-based telecom site classifier.

    A timm backbone with its classification head removed produces a pooled
    feature vector. A small MLP head (LayerNorm -> 512 -> 128 -> num_classes,
    with ReLU and dropout) maps that vector to class logits.

    The original author describes this as transfer learning from a food
    detection model. The backbone created here uses timm's pretrained weights
    for ``backbone_name`` when ``pretrained=True``. The default tag is
    ImageNet-22k pre-trained and ImageNet-1k fine-tuned.

    Attributes:
        backbone: timm feature extractor (``num_classes=0``).
        feature_dim: Size of the backbone's pooled feature vector.
        classifier: MLP head producing ``num_classes`` logits.
    """

    def __init__(
        self,
        num_classes: int = 3,
        pretrained: bool = True,
        backbone_name: str = 'convnext_large.fb_in22k_ft_in1k',
    ):
        """
        Args:
            num_classes: Number of output logits.
            pretrained: Whether timm should load pretrained backbone weights.
                Pass False when all weights will come from a checkpoint.
            backbone_name: timm model name used for the backbone. A different
                backbone yields a different state_dict layout, so checkpoints
                only load into the backbone they were trained with.
        """
        super().__init__()
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            num_classes=0,  # Remove classification head
        )
        # Width of the features fed to the head; read from the backbone so
        # the head adapts to whichever backbone was chosen.
        self.feature_dim = self.backbone.num_features
        # Keep this exact layer order: saved state_dicts are keyed by the
        # Sequential indices (classifier.0 ... classifier.7).
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )
        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Xavier-uniform init for head Linear weights; zero their biases."""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                # Guard so a bias-free Linear in the head does not crash init.
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for an input batch ``x``.

        Assumes ``x`` is an image batch in the layout the timm backbone
        expects (presumably (batch, 3, H, W)); TODO confirm with the callers.
        """
        features = self.backbone(x)
        output = self.classifier(features)
        return output

    def freeze_backbone(self):
        """Disable gradients for all backbone parameters (train head only)."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("π Backbone frozen for transfer learning")

    def unfreeze_backbone(self):
        """Re-enable gradients for all backbone parameters (fine-tuning)."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("π Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """Return parameter counts for the backbone, the head, and both together.

        Returns:
            Dict with keys ``'backbone'``, ``'classifier'``, ``'total'``
            (backbone + classifier), and ``'trainable'`` (parameters whose
            ``requires_grad`` is True).
        """
        backbone_params = sum(p.numel() for p in self.backbone.parameters())
        classifier_params = sum(p.numel() for p in self.classifier.parameters())
        total_params = backbone_params + classifier_params
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return {
            'backbone': backbone_params,
            'classifier': classifier_params,
            'total': total_params,
            'trainable': trainable_params,
        }
def load_model(
    model_path: str,
    num_classes: int = 3,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """
    Load a trained telecom classifier checkpoint for inference.

    Args:
        model_path: Path to a checkpoint written with ``torch.save``. It must
            be a dict containing ``'model_state_dict'``. It may also contain
            ``'model_info'``, ``'best_acc'`` and ``'epoch'``.
        num_classes: Number of output classes; must match the checkpoint.
        device: Device to load the model onto (e.g. ``'cpu'`` or ``'cuda'``).

    Returns:
        Tuple of (model, model_info). The model is on ``device`` and in eval
        mode. ``model_info`` is a copy of the checkpoint's ``'model_info'``
        dict, plus ``'best_acc'`` and ``'epoch'`` entries (``'Unknown'`` when
        absent from the checkpoint).

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        Errors from ``torch.load`` (e.g. a missing file) and from
        ``load_state_dict`` (architecture mismatch) propagate unchanged.
    """
    print(f"π Loading model from {model_path}")
    # Weights come from the checkpoint, so skip timm's pretrained download.
    model = TelecomClassifier(num_classes=num_classes, pretrained=False)
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # load_state_dict copies into the model's existing parameters, which were
    # created on CPU. map_location alone therefore does not place the model
    # on `device`; move it explicitly.
    model.to(device)
    model.eval()

    # Copy so the loaded checkpoint's own dict is not mutated.
    model_info = dict(checkpoint.get('model_info') or {})
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')

    print("β Model loaded successfully")
    print(f" Best accuracy: {model_info['best_acc']}")
    print(f" Epoch: {model_info['epoch']}")
    return model, model_info