File size: 5,645 Bytes
fc6062a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Model utilities for telecom site classification
Handles ConvNeXt model loading and adaptation for transfer learning
"""

import torch
import torch.nn as nn
import timm
import os
from typing import Dict, Any, Optional, Tuple

class TelecomClassifier(nn.Module):
    """
    ConvNeXt-based telecom site classifier.

    Wraps a timm ConvNeXt-Large backbone (with its classification head
    removed via ``num_classes=0``) and adds a small MLP head. Intended for
    transfer learning from a food detection model.
    """

    def __init__(self, num_classes: int = 3, pretrained: bool = True):
        """
        Args:
            num_classes: Number of output classes for the new head.
            pretrained: If True, load ImageNet-pretrained backbone weights.
        """
        super().__init__()  # modern zero-argument super()
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0  # Remove classification head; backbone emits features only
        )
        # Feature width reported by timm for the pooled backbone output.
        self.feature_dim = self.backbone.num_features
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )
        self._init_classifier_weights()

    def _init_classifier_weights(self) -> None:
        """Xavier-initialize the head's linear layers; biases start at zero."""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Extract backbone features, then classify them with the MLP head."""
        features = self.backbone(x)
        return self.classifier(features)

    def freeze_backbone(self) -> None:
        """Disable gradients on all backbone parameters (head-only training)."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("πŸ”’ Backbone frozen for transfer learning")

    def unfreeze_backbone(self) -> None:
        """Re-enable gradients on all backbone parameters (full fine-tuning)."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("πŸ”“ Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """
        Return parameter counts for this model.

        Returns:
            Dict with keys 'backbone', 'classifier', 'total'
            (backbone + classifier), and 'trainable' (params with
            requires_grad=True across the whole module).
        """
        backbone_params = sum(p.numel() for p in self.backbone.parameters())
        classifier_params = sum(p.numel() for p in self.classifier.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            'backbone': backbone_params,
            'classifier': classifier_params,
            'total': backbone_params + classifier_params,
            'trainable': trainable_params
        }

def load_model(
    model_path: str,
    num_classes: int = 3,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """
    Load a trained telecom classifier checkpoint.

    Args:
        model_path: Path to the saved checkpoint (as written by save_model).
        num_classes: Number of output classes the checkpoint was trained with.
        device: Device to place the model on (e.g. 'cpu', 'cuda').

    Returns:
        Tuple of (model in eval mode on `device`, model_info dict with
        'best_acc' and 'epoch' filled in from the checkpoint).

    Raises:
        KeyError: If the checkpoint lacks 'model_state_dict'.
    """
    print(f"πŸ“‚ Loading model from {model_path}")
    model = TelecomClassifier(num_classes=num_classes, pretrained=False)
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # BUG FIX: map_location only controls where checkpoint tensors are
    # deserialized; the model itself was built on CPU and was never moved.
    model.to(device)
    model.eval()
    model_info = checkpoint.get('model_info', {})
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
    print("βœ… Model loaded successfully")
    print(f"   Best accuracy: {model_info['best_acc']}")
    print(f"   Epoch: {model_info['epoch']}")
    return model, model_info


def create_telecom_model(
    num_classes: int,
    food_model_path: Optional[str] = None,
    freeze_backbone: bool = True
) -> TelecomClassifier:
    """
    Create and initialize the TelecomClassifier model.

    Optionally warm-starts the backbone from a food detection checkpoint,
    then freezes (or unfreezes) the backbone for transfer learning.

    Args:
        num_classes: Number of output classes.
        food_model_path: Optional checkpoint path whose backbone weights are
            transferred. A missing file is skipped with a warning.
        freeze_backbone: If True, freeze backbone parameters after loading.

    Returns:
        The initialized TelecomClassifier.
    """
    model = TelecomClassifier(num_classes=num_classes, pretrained=True)
    if food_model_path:
        if os.path.exists(food_model_path):
            print(f"πŸ”„ Loading backbone weights from: {food_model_path}")
            # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
            state_dict = torch.load(food_model_path, map_location='cpu')
            if 'model_state_dict' in state_dict:
                state_dict = state_dict['model_state_dict']
            # Keep only backbone weights; the classifier head is task-specific.
            # BUG FIX: replace only the leading 'backbone.' prefix (count=1),
            # not every occurrence anywhere in the key.
            backbone_state_dict = {
                k.replace('backbone.', '', 1): v
                for k, v in state_dict.items()
                if k.startswith('backbone')
            }
            # strict=False tolerates keys that differ between the two tasks.
            model.backbone.load_state_dict(backbone_state_dict, strict=False)
        else:
            # Previously a missing file was skipped silently; surface it.
            print(f"⚠️ Food model checkpoint not found, skipping transfer: {food_model_path}")
    if freeze_backbone:
        model.freeze_backbone()
    else:
        model.unfreeze_backbone()
    return model


def save_model(
    model: nn.Module,
    path: str,
    epoch: int,
    val_acc: float,
    optimizer_state: Dict[str, Any],
    extra_info: Optional[Dict[str, Any]] = None
) -> None:
    """
    Save a training checkpoint.

    Args:
        model: The model whose state_dict is saved.
        path: Destination file path (parent directories are created).
        epoch: Current epoch number.
        val_acc: Validation accuracy, stored under 'best_acc'.
        optimizer_state: Optimizer state dict.
        extra_info: Optional extra metadata stored under 'model_info'.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'best_acc': val_acc,
        'optimizer_state_dict': optimizer_state,
        'model_info': extra_info if extra_info is not None else {}
    }
    # BUG FIX: os.makedirs('') raises FileNotFoundError when `path` is a bare
    # filename with no directory component — only create dirs when one exists.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(checkpoint, path)
    print(f"πŸ’Ύ Model saved to {path}")


def get_model_summary(model: TelecomClassifier) -> str:
    """
    Return a string summary of the TelecomClassifier model.
    """
    summary_lines = []
    summary_lines.append(str(model))
    param_counts = model.get_parameter_count()
    summary_lines.append(f"\nParameter counts:")
    for k, v in param_counts.items():
        summary_lines.append(f"  {k}: {v:,}")
    return "\n".join(summary_lines)