# NOTE: "Spaces: Sleeping" lines were Hugging Face Spaces page-scrape residue,
# not source code; kept here only as a comment so the module parses.
"""
Model utilities for telecom site classification.

Handles ConvNeXt model loading and adaptation for transfer learning.
"""
import os
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn

import timm
class TelecomClassifier(nn.Module):
    """
    ConvNeXt-based telecom site classifier.

    Wraps a timm ConvNeXt-Large backbone (classification head removed) with a
    small MLP head, so the backbone can be transfer-learned from another
    vision task (e.g. a food detection model) while the head trains fresh.
    """

    def __init__(self, num_classes: int = 3, pretrained: bool = True):
        """
        Args:
            num_classes: Number of output classes.
            pretrained: If True, load ImageNet-pretrained backbone weights
                (requires network access or a local timm cache).
        """
        super().__init__()
        # num_classes=0 strips timm's classification head; the backbone then
        # outputs pooled features of size `num_features`.
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0,
        )
        self.feature_dim = self.backbone.num_features
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )
        self._init_classifier_weights()

    def _init_classifier_weights(self) -> None:
        """Xavier-initialize the linear layers of the freshly built head."""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of images."""
        features = self.backbone(x)
        return self.classifier(features)

    def freeze_backbone(self) -> None:
        """Disable gradients on the backbone (head-only transfer learning)."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        # Original message contained mojibake-garbled emoji; plain ASCII now.
        print("Backbone frozen for transfer learning")

    def unfreeze_backbone(self) -> None:
        """Re-enable gradients on the backbone for full fine-tuning."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """Return parameter counts: 'backbone', 'classifier', 'total', 'trainable'."""
        backbone_params = sum(p.numel() for p in self.backbone.parameters())
        classifier_params = sum(p.numel() for p in self.classifier.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            'backbone': backbone_params,
            'classifier': classifier_params,
            'total': backbone_params + classifier_params,
            'trainable': trainable_params,
        }
def load_model(
    model_path: str,
    num_classes: int = 3,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """
    Load a trained telecom classifier checkpoint.

    Args:
        model_path: Path to the saved checkpoint (as written by `save_model`).
        num_classes: Number of output classes the checkpoint was trained with.
        device: Device to place the model on (e.g. 'cpu', 'cuda:0').

    Returns:
        Tuple of (model in eval mode on `device`, model_info dict carrying
        at least 'best_acc' and 'epoch').

    Raises:
        KeyError: If the checkpoint lacks 'model_state_dict'.
    """
    print(f"Loading model from {model_path}")
    # pretrained=False: weights come from the checkpoint, not from timm.
    model = TelecomClassifier(num_classes=num_classes, pretrained=False)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Bug fix: map_location only relocates the *checkpoint* tensors; the
    # freshly built model still lived on CPU. Move it to the requested device.
    model.to(device)
    model.eval()
    model_info = checkpoint.get('model_info', {})
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
    print("Model loaded successfully")
    print(f"  Best accuracy: {model_info['best_acc']}")
    print(f"  Epoch: {model_info['epoch']}")
    return model, model_info
def create_telecom_model(
    num_classes: int,
    food_model_path: Optional[str] = None,
    freeze_backbone: bool = True,
) -> TelecomClassifier:
    """
    Create and initialize a TelecomClassifier.

    Optionally warm-starts the backbone from a source-task (food detection)
    checkpoint and freezes it for head-only transfer learning.

    Args:
        num_classes: Number of output classes.
        food_model_path: Optional path to a checkpoint whose backbone weights
            should be transferred. Silently skipped if missing or None.
        freeze_backbone: If True, freeze the backbone; otherwise leave it
            fully trainable.

    Returns:
        The initialized model (on CPU).
    """
    model = TelecomClassifier(num_classes=num_classes, pretrained=True)
    if food_model_path and os.path.exists(food_model_path):
        print(f"Loading backbone weights from: {food_model_path}")
        state_dict = torch.load(food_model_path, map_location='cpu')
        # Unwrap full training checkpoints saved by `save_model`.
        if 'model_state_dict' in state_dict:
            state_dict = state_dict['model_state_dict']
        # Keep only backbone weights (the source task's classifier head does
        # not match ours) and strip the 'backbone.' prefix for direct loading.
        prefix = 'backbone.'
        backbone_state = {
            k[len(prefix):]: v
            for k, v in state_dict.items()
            if k.startswith(prefix)
        }
        # strict=False tolerates minor key mismatches between source/target.
        model.backbone.load_state_dict(backbone_state, strict=False)
    if freeze_backbone:
        model.freeze_backbone()
    else:
        model.unfreeze_backbone()
    return model
def save_model(
    model: nn.Module,
    path: str,
    epoch: int,
    val_acc: float,
    optimizer_state: Dict[str, Any],
    extra_info: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Save a training checkpoint.

    Args:
        model: The model to save (its state_dict is stored).
        path: Destination file path; parent directories are created as needed.
        epoch: Current epoch number (stored under 'epoch').
        val_acc: Validation accuracy (stored under 'best_acc').
        optimizer_state: Optimizer state dict (stored under
            'optimizer_state_dict').
        extra_info: Optional extra metadata dict (stored under 'model_info').
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'best_acc': val_acc,
        'optimizer_state_dict': optimizer_state,
        'model_info': extra_info if extra_info is not None else {},
    }
    # Bug fix: os.path.dirname returns '' for a bare filename and
    # os.makedirs('') raises FileNotFoundError -- only create directories
    # when the path actually has a parent component.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(checkpoint, path)
    print(f"Model saved to {path}")
def get_model_summary(model: TelecomClassifier) -> str:
    """
    Build a human-readable summary of a TelecomClassifier.

    Includes the module tree (``str(model)``) followed by the formatted
    parameter counts from ``model.get_parameter_count()``.
    """
    counts = model.get_parameter_count()
    lines = [str(model), "\nParameter counts:"]
    lines.extend(f"  {name}: {value:,}" for name, value in counts.items())
    return "\n".join(lines)