Neylton commited on
Commit
a77ccbe
·
verified ·
1 Parent(s): 6e4ff59

Upload model_utils.py

Browse files
Files changed (1) hide show
  1. utils/model_utils.py +90 -1
utils/model_utils.py CHANGED
@@ -1 +1,90 @@
1
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model utilities for telecom site classification
3
+ Handles ConvNeXt model loading and adaptation for transfer learning
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import timm
9
+ import os
10
+ from typing import Dict, Any, Optional, Tuple
11
+
12
+ class TelecomClassifier(nn.Module):
13
+ """
14
+ ConvNeXt-based telecom site classifier
15
+ Uses transfer learning from food detection model
16
+ """
17
+ def __init__(self, num_classes: int = 3, pretrained: bool = True):
18
+ super(TelecomClassifier, self).__init__()
19
+ self.backbone = timm.create_model(
20
+ 'convnext_large.fb_in22k_ft_in1k',
21
+ pretrained=pretrained,
22
+ num_classes=0 # Remove classification head
23
+ )
24
+ self.feature_dim = self.backbone.num_features
25
+ self.classifier = nn.Sequential(
26
+ nn.LayerNorm(self.feature_dim),
27
+ nn.Linear(self.feature_dim, 512),
28
+ nn.ReLU(inplace=True),
29
+ nn.Dropout(0.3),
30
+ nn.Linear(512, 128),
31
+ nn.ReLU(inplace=True),
32
+ nn.Dropout(0.2),
33
+ nn.Linear(128, num_classes)
34
+ )
35
+ self._init_classifier_weights()
36
+ def _init_classifier_weights(self):
37
+ for module in self.classifier.modules():
38
+ if isinstance(module, nn.Linear):
39
+ nn.init.xavier_uniform_(module.weight)
40
+ nn.init.constant_(module.bias, 0)
41
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
42
+ features = self.backbone(x)
43
+ output = self.classifier(features)
44
+ return output
45
+ def freeze_backbone(self):
46
+ for param in self.backbone.parameters():
47
+ param.requires_grad = False
48
+ print("🔒 Backbone frozen for transfer learning")
49
+ def unfreeze_backbone(self):
50
+ for param in self.backbone.parameters():
51
+ param.requires_grad = True
52
+ print("🔓 Backbone unfrozen for fine-tuning")
53
+ def get_parameter_count(self) -> Dict[str, int]:
54
+ backbone_params = sum(p.numel() for p in self.backbone.parameters())
55
+ classifier_params = sum(p.numel() for p in self.classifier.parameters())
56
+ total_params = backbone_params + classifier_params
57
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
58
+ return {
59
+ 'backbone': backbone_params,
60
+ 'classifier': classifier_params,
61
+ 'total': total_params,
62
+ 'trainable': trainable_params
63
+ }
64
+
65
+ def load_model(
66
+ model_path: str,
67
+ num_classes: int = 3,
68
+ device: str = 'cpu'
69
+ ) -> Tuple[TelecomClassifier, Dict[str, Any]]:
70
+ """
71
+ Load trained telecom classifier model
72
+ Args:
73
+ model_path: Path to saved model
74
+ num_classes: Number of output classes
75
+ device: Device to load model on
76
+ Returns:
77
+ Tuple of (model, model_info)
78
+ """
79
+ print(f"📂 Loading model from {model_path}")
80
+ model = TelecomClassifier(num_classes=num_classes, pretrained=False)
81
+ checkpoint = torch.load(model_path, map_location=device)
82
+ model.load_state_dict(checkpoint['model_state_dict'])
83
+ model.eval()
84
+ model_info = checkpoint.get('model_info', {})
85
+ model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')
86
+ model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
87
+ print(f"✅ Model loaded successfully")
88
+ print(f" Best accuracy: {model_info.get('best_acc', 'Unknown')}")
89
+ print(f" Epoch: {model_info.get('epoch', 'Unknown')}")
90
+ return model, model_info