File size: 3,541 Bytes
a77ccbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""

Model utilities for telecom site classification

Handles ConvNeXt model loading and adaptation for transfer learning

"""

import torch
import torch.nn as nn
import timm
import os
from typing import Dict, Any, Optional, Tuple

class TelecomClassifier(nn.Module):
    """ConvNeXt-large backbone with a small MLP head for telecom site classification.

    The backbone comes from timm with its classification head stripped
    (``num_classes=0``); a fresh three-layer classifier is stacked on top.
    Intended for transfer learning: freeze the backbone first, fine-tune later.
    """

    def __init__(self, num_classes: int = 3, pretrained: bool = True):
        """Build backbone and classifier head.

        Args:
            num_classes: Number of output classes.
            pretrained: Load ImageNet-pretrained backbone weights.
        """
        super().__init__()

        # Feature extractor only — head removed via num_classes=0.
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k', 
            pretrained=pretrained, 
            num_classes=0  # Remove classification head
        )
        self.feature_dim = self.backbone.num_features

        # MLP head: norm -> 512 -> 128 -> num_classes, with dropout between layers.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Xavier-initialize the head's linear layers; zero their biases."""
        linear_layers = (m for m in self.classifier.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone then the classifier head; returns raw logits."""
        return self.classifier(self.backbone(x))

    def freeze_backbone(self):
        """Disable gradients on every backbone parameter (transfer-learning phase)."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("πŸ”’ Backbone frozen for transfer learning")

    def unfreeze_backbone(self):
        """Re-enable gradients on every backbone parameter (fine-tuning phase)."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("πŸ”“ Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """Return parameter counts: backbone, classifier, total, and trainable."""
        count = lambda params: sum(p.numel() for p in params)
        n_backbone = count(self.backbone.parameters())
        n_classifier = count(self.classifier.parameters())
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            'backbone': n_backbone,
            'classifier': n_classifier,
            'total': n_backbone + n_classifier,
            'trainable': n_trainable,
        }

def load_model(
    model_path: str,
    num_classes: int = 3,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """Load a trained telecom classifier from a checkpoint file.

    Accepts either a full training checkpoint (a dict containing
    'model_state_dict' and optional 'model_info' / 'best_acc' / 'epoch'
    entries) or a bare state_dict saved directly with ``torch.save``.

    Args:
        model_path: Path to the saved checkpoint.
        num_classes: Number of output classes the model was trained with.
        device: Device to map the checkpoint tensors onto (e.g. 'cpu', 'cuda').

    Returns:
        Tuple of (model in eval mode, model_info dict with at least
        'best_acc' and 'epoch' keys — 'Unknown' when absent from the file).

    Raises:
        FileNotFoundError: If model_path does not exist.
        RuntimeError: If the state_dict does not match the architecture.
    """
    print(f"πŸ“‚ Loading model from {model_path}")
    model = TelecomClassifier(num_classes=num_classes, pretrained=False)

    # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources. Consider weights_only=True once
    # all saved checkpoints are compatible with it.
    checkpoint = torch.load(model_path, map_location=device)

    # Support both wrapped checkpoints and raw state_dicts; previously a
    # raw state_dict raised KeyError on 'model_state_dict'.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
        checkpoint = {}  # no metadata available in a bare state_dict

    model.load_state_dict(state_dict)
    model.eval()  # inference mode: disables dropout in the classifier head

    model_info = checkpoint.get('model_info', {})
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')

    print("βœ… Model loaded successfully")
    print(f"   Best accuracy: {model_info['best_acc']}")
    print(f"   Epoch: {model_info['epoch']}")
    return model, model_info