|
|
""" |
|
|
Model utilities for telecom site classification |
|
|
Handles ConvNeXt model loading and adaptation for transfer learning |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import timm |
|
|
import os |
|
|
from typing import Dict, Any, Optional, Tuple |
|
|
|
|
|
class TelecomClassifier(nn.Module):
    """
    Telecom site image classifier built on a ConvNeXt feature extractor.

    The backbone is the timm model ``convnext_large.fb_in22k_ft_in1k`` created
    with ``num_classes=0``; its output feeds a small MLP head
    (LayerNorm -> 512 -> 128 -> ``num_classes``). The backbone is meant to be
    initialised from a food detection checkpoint via
    ``load_food_model_weights`` (transfer learning).

    Attributes:
        backbone: The timm ConvNeXt feature extractor.
        feature_dim: Width of the backbone output (``backbone.num_features``).
        classifier: The MLP head producing ``num_classes`` logits.
    """

    def __init__(self, num_classes: int = 2, pretrained: bool = True):
        """
        Build the backbone and the classification head.

        Args:
            num_classes: Size of the final linear layer of the head.
            pretrained: Forwarded to ``timm.create_model``.
        """
        super().__init__()

        # num_classes=0 is forwarded to timm so that the backbone is built
        # without a classification layer of its own (timm convention).
        feature_extractor = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0,
        )
        self.backbone = feature_extractor
        self.feature_dim = feature_extractor.num_features

        # Layer order is part of the checkpoint format: it fixes the
        # ``classifier.<index>`` keys in the state dict.
        width = self.feature_dim
        head_layers = [
            nn.LayerNorm(width),
            nn.Linear(width, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Apply Xavier-uniform weights and zero biases to every linear layer of the head."""
        linear_layers = (
            layer for layer in self.classifier.modules()
            if isinstance(layer, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone, then through the classification head."""
        return self.classifier(self.backbone(x))

    def _set_backbone_trainable(self, trainable: bool) -> None:
        """Set ``requires_grad`` on every backbone parameter."""
        for weight in self.backbone.parameters():
            weight.requires_grad_(trainable)

    def freeze_backbone(self):
        """Disable gradients for all backbone parameters (head-only training)."""
        self._set_backbone_trainable(False)
        print("π Backbone frozen for transfer learning")

    def unfreeze_backbone(self):
        """Re-enable gradients for all backbone parameters (full fine-tuning)."""
        self._set_backbone_trainable(True)
        print("π Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """
        Count the parameters of the model.

        Returns:
            Dict with keys ``backbone``, ``classifier``, ``total`` (backbone plus
            classifier) and ``trainable`` (parameters with ``requires_grad``).
        """
        def _numel(tensors) -> int:
            return sum(t.numel() for t in tensors)

        n_backbone = _numel(self.backbone.parameters())
        n_head = _numel(self.classifier.parameters())
        n_trainable = _numel(t for t in self.parameters() if t.requires_grad)

        return dict(
            backbone=n_backbone,
            classifier=n_head,
            total=n_backbone + n_head,
            trainable=n_trainable,
        )
|
|
|
|
|
def _unwrap_food_checkpoint(checkpoint):
    """Return the state dict held by a loaded checkpoint object."""
    if isinstance(checkpoint, dict):
        if 'model_state_dict' in checkpoint:
            accuracy = checkpoint.get('best_acc', 'Unknown')
            print(f"π Food model accuracy: {accuracy}%")
            return checkpoint['model_state_dict']
        # Some training frameworks nest the weights under 'state_dict' instead.
        if isinstance(checkpoint.get('state_dict'), dict):
            return checkpoint['state_dict']
    # Anything else is treated as a bare state dict.
    return checkpoint


def _to_backbone_key(key):
    """Map a checkpoint key to its ``backbone.*`` name, or None for head weights."""
    # Drop wrapper prefixes first so keys saved from nn.DataParallel
    # ('module.') or from a model that already nests its network under
    # 'backbone.' do not turn into 'backbone.module.*' / 'backbone.backbone.*'.
    for wrapper in ('module.', 'backbone.'):
        if key.startswith(wrapper):
            key = key[len(wrapper):]
    # NOTE(review): this also skips any 'head.norm.*' tensors; if the ConvNeXt
    # backbone keeps a final norm layer under 'head', it is not transferred.
    # Confirm that this is intended.
    if key.startswith(('head', 'classifier')):
        return None
    return f"backbone.{key}"


def load_food_model_weights(model: TelecomClassifier, food_model_path: str) -> TelecomClassifier:
    """
    Initialise ``model.backbone`` from a pre-trained food detection checkpoint.

    Only tensors that map onto an existing ``backbone.*`` entry of ``model``
    with an identical shape are copied; classification-head weights
    (keys starting with ``head`` or ``classifier``) are ignored. The function
    is best-effort: on any problem it reports it and returns ``model``
    with the weights it already had.

    Args:
        model: Model whose ``backbone.*`` parameters should be overwritten.
        food_model_path: Path to the checkpoint. It may be a bare state dict
            or a dict holding one under ``model_state_dict`` or ``state_dict``.

    Returns:
        The same ``model`` instance (modified in place when weights matched).
    """
    if not os.path.exists(food_model_path):
        print(f"β οΈ Food model not found at {food_model_path}")
        print("π Using ImageNet pretrained weights instead")
        return model

    try:
        print(f"π Loading food model weights from {food_model_path}")

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this at checkpoints from a trusted source.
        checkpoint = torch.load(food_model_path, map_location='cpu')
        food_state_dict = _unwrap_food_checkpoint(checkpoint)

        model_dict = model.state_dict()

        # Keep only tensors that have a same-shaped destination in the model.
        filtered_dict = {}
        for key, value in food_state_dict.items():
            target_key = _to_backbone_key(key)
            if target_key is None or target_key not in model_dict:
                continue
            # getattr: a non-tensor entry must not abort the whole transfer.
            if model_dict[target_key].shape == getattr(value, 'shape', None):
                filtered_dict[target_key] = value

        if not filtered_dict:
            # Previously this case was reported as a success with 0 layers.
            print("Warning: no tensors in the food model matched the backbone")
            print("π Using ImageNet pretrained weights instead")
            return model

        # strict=False: only the matched subset of keys is supplied.
        model.load_state_dict(filtered_dict, strict=False)

        print(f"β Successfully loaded {len(filtered_dict)} backbone layers from food model")
        print("π― Transfer learning ready: backbone initialized with food detection weights")

        return model

    except Exception as e:
        # Deliberately broad: a broken checkpoint must not stop model creation.
        print(f"β Error loading food model weights: {e}")
        print("π Using ImageNet pretrained weights instead")
        return model
|
|
|
|
|
def create_telecom_model(
    num_classes: int = 2,
    food_model_path: Optional[str] = None,
    freeze_backbone: bool = True
) -> TelecomClassifier:
    """
    Build a TelecomClassifier, optionally seeded from a food detection model.

    The classifier is always created with ``pretrained=True``; when
    ``food_model_path`` is given, ``load_food_model_weights`` is then applied
    on top of it. Parameter statistics are printed before returning.

    Args:
        num_classes: Number of output classes (2 for good/bad)
        food_model_path: Path to pre-trained food detection model; skipped
            when empty or None
        freeze_backbone: Whether to freeze the backbone for transfer learning

    Returns:
        TelecomClassifier model ready for training
    """
    print("ποΈ Creating telecom site classifier...")

    classifier = TelecomClassifier(num_classes=num_classes, pretrained=True)

    if food_model_path:
        classifier = load_food_model_weights(classifier, food_model_path)

    if freeze_backbone:
        classifier.freeze_backbone()

    # Report the parameter breakdown of the finished model.
    stats = classifier.get_parameter_count()
    print("π Model Statistics:")
    report_rows = (
        ("Backbone", 'backbone'),
        ("Classifier", 'classifier'),
        ("Total", 'total'),
        ("Trainable", 'trainable'),
    )
    for label, key in report_rows:
        print(f" {label} parameters: {stats[key]:,}")

    # Size estimate uses 4 bytes per parameter.
    size_mb = stats['total'] * 4 / 1024**2
    print(f" Model size: ~{size_mb:.1f} MB")

    return classifier
|
|
|
|
|
def _infer_num_classes(model, default=2):
    """Return ``out_features`` of the last linear layer in ``model.classifier``, else ``default``."""
    head = getattr(model, 'classifier', None)
    if isinstance(head, nn.Module):
        linear_layers = [m for m in head.modules() if isinstance(m, nn.Linear)]
        if linear_layers:
            return linear_layers[-1].out_features
    return default


def save_model(
    model: TelecomClassifier,
    save_path: str,
    epoch: int,
    best_acc: float,
    optimizer_state: Optional[Dict] = None,
    additional_info: Optional[Dict] = None
) -> None:
    """
    Save model checkpoint with training information

    The checkpoint is a dict with ``epoch``, ``model_state_dict``,
    ``best_acc`` and a ``model_info`` sub-dict (architecture, num_classes,
    parameter_count, task). ``num_classes`` is read from the model's
    classifier head rather than assumed to be 2.

    Args:
        model: The model to save
        save_path: Path to save the model; missing parent directories are
            created
        epoch: Current epoch number
        best_acc: Best validation accuracy achieved
        optimizer_state: Optimizer state dict, stored under
            ``optimizer_state_dict`` when non-empty
        additional_info: Extra top-level entries merged into the checkpoint;
            keys that collide with the standard ones overwrite them
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'best_acc': best_acc,
        'model_info': {
            'architecture': 'ConvNeXt Large',
            # Taken from the head so checkpoints of non-binary models are
            # described correctly; falls back to 2 if it cannot be determined.
            'num_classes': _infer_num_classes(model),
            'parameter_count': model.get_parameter_count(),
            'task': 'telecom_site_classification',
        },
    }

    if optimizer_state:
        checkpoint['optimizer_state_dict'] = optimizer_state

    if additional_info:
        checkpoint.update(additional_info)

    # torch.save does not create directories; do it here so a fresh output
    # folder does not make the save fail.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(checkpoint, save_path)
    print(f"πΎ Model saved to {save_path}")
|
|
|
|
|
def load_model(
    model_path: str,
    num_classes: int = 2,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """
    Load trained telecom classifier model

    Builds a ``TelecomClassifier`` without pretrained weights, restores the
    saved parameters, moves the model to ``device`` and switches it to eval
    mode.

    Args:
        model_path: Path to saved model. Either a checkpoint dict with a
            ``model_state_dict`` entry (as written by ``save_model``) or a
            bare state dict.
        num_classes: Number of output classes; must match the saved head
        device: Device to load model on

    Returns:
        Tuple of (model, model_info). ``model_info`` is the checkpoint's
        ``model_info`` dict (empty if absent) extended with ``best_acc`` and
        ``epoch``, each ``'Unknown'`` when not stored.

    Raises:
        RuntimeError: From ``load_state_dict`` when the saved weights do not
            fit the model (e.g. a different ``num_classes``).
    """
    print(f"π Loading model from {model_path}")

    model = TelecomClassifier(num_classes=num_classes, pretrained=False)

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    # Accept a bare state dict as well as the wrapped checkpoint format.
    has_wrapper = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
    state_dict = checkpoint['model_state_dict'] if has_wrapper else checkpoint
    metadata = checkpoint if has_wrapper else {}

    model.load_state_dict(state_dict)
    # load_state_dict copies values into the existing (CPU) parameters, so the
    # model has to be moved explicitly to honour ``device``.
    model.to(device)
    model.eval()

    # Copy so the loaded checkpoint object is not mutated.
    model_info = dict(metadata.get('model_info') or {})
    model_info['best_acc'] = metadata.get('best_acc', 'Unknown')
    model_info['epoch'] = metadata.get('epoch', 'Unknown')

    print(f"β Model loaded successfully")
    print(f" Best accuracy: {model_info.get('best_acc', 'Unknown')}")
    print(f" Epoch: {model_info.get('epoch', 'Unknown')}")

    return model, model_info
|
|
|
|
|
def get_model_summary(model: TelecomClassifier) -> str:
    """
    Render a human-readable report about the model.

    Args:
        model: The model to summarize; only its ``get_parameter_count()``
            result is used

    Returns:
        Multi-line string with the architecture, the task, the parameter
        counts, a size estimate (4 bytes per parameter) and whether transfer
        learning is active (fewer trainable than total parameters)
    """
    counts = model.get_parameter_count()

    total = counts['total']
    trainable = counts['trainable']
    size_mb = total * 4 / 1024**2
    if trainable < total:
        transfer_state = 'Enabled'
    else:
        transfer_state = 'Disabled'

    # The empty first and last entries give the leading and trailing newline.
    report_lines = [
        "",
        "π€ Telecom Site Classifier Model Summary",
        '=' * 50,
        "Architecture: ConvNeXt Large + Custom Classifier",
        "Task: Binary Classification (Good/Bad Sites)",
        "",
        "Parameter Counts:",
        f"Backbone (ConvNeXt): {counts['backbone']:,}",
        f"Classifier Head: {counts['classifier']:,}",
        f"Total Parameters: {total:,}",
        f"Trainable Parameters: {trainable:,}",
        "",
        f"Model Size: ~{size_mb:.1f} MB",
        f"Transfer Learning: {transfer_state}",
        "",
    ]

    return "\n".join(report_lines)