Neylton commited on
Commit
168169c
Β·
1 Parent(s): 5705ccc

Add utils folder with required modules

Browse files
utils/__pycache__/data_utils.cpython-311.pyc ADDED
Binary file (18.7 kB). View file
 
utils/__pycache__/model_utils.cpython-311.pyc ADDED
Binary file (14.2 kB). View file
 
utils/data_utils.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data utilities for telecom site classification
3
+ Handles data loading, transformations, and dataset management
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
9
+ from torchvision import transforms, datasets
10
+ from PIL import Image
11
+ import numpy as np
12
+ from typing import Tuple, Dict, List, Optional
13
+ from collections import Counter
14
+ import random
15
+
16
+ class TelecomSiteDataset(Dataset):
17
+ """
18
+ Custom dataset for telecom site images
19
+ Supports both training and validation modes with appropriate transforms
20
+ """
21
+
22
+ def __init__(self, data_dir: str, split: str = 'train', image_size: int = 224):
23
+ """
24
+ Initialize telecom site dataset
25
+
26
+ Args:
27
+ data_dir: Root directory containing train/val folders
28
+ split: 'train' or 'val'
29
+ image_size: Size to resize images to
30
+ """
31
+ self.data_dir = data_dir
32
+ self.split = split
33
+ self.image_size = image_size
34
+
35
+ # Define class mapping
36
+ self.classes = ['bad', 'good'] # 0: bad, 1: good
37
+ self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
38
+
39
+ # Load image paths and labels
40
+ self.samples = self._load_samples()
41
+
42
+ # Define transforms
43
+ self.transform = self._get_transforms()
44
+
45
+ print(f"πŸ“Š {split.upper()} Dataset loaded:")
46
+ print(f" Total samples: {len(self.samples)}")
47
+ print(f" Classes: {self.classes}")
48
+ self._print_class_distribution()
49
+
50
+ def _load_samples(self) -> List[Tuple[str, int]]:
51
+ """Load image paths and corresponding labels"""
52
+ samples = []
53
+ split_dir = os.path.join(self.data_dir, self.split)
54
+
55
+ for class_name in self.classes:
56
+ class_dir = os.path.join(split_dir, class_name)
57
+ if not os.path.exists(class_dir):
58
+ print(f"⚠️ Warning: {class_dir} not found")
59
+ continue
60
+
61
+ class_idx = self.class_to_idx[class_name]
62
+
63
+ # Load all images from class directory
64
+ for img_name in os.listdir(class_dir):
65
+ if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
66
+ img_path = os.path.join(class_dir, img_name)
67
+ samples.append((img_path, class_idx))
68
+
69
+ return samples
70
+
71
+ def _print_class_distribution(self):
72
+ """Print class distribution for the dataset"""
73
+ class_counts = Counter([label for _, label in self.samples])
74
+ for class_name, class_idx in self.class_to_idx.items():
75
+ count = class_counts.get(class_idx, 0)
76
+ print(f" {class_name}: {count} samples")
77
+
78
+ def _get_transforms(self) -> transforms.Compose:
79
+ """Get appropriate transforms for the split"""
80
+ if self.split == 'train':
81
+ return transforms.Compose([
82
+ transforms.Resize((self.image_size + 32, self.image_size + 32)),
83
+ transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)),
84
+ transforms.RandomHorizontalFlip(p=0.5),
85
+ transforms.RandomRotation(degrees=10),
86
+ transforms.ColorJitter(
87
+ brightness=0.2,
88
+ contrast=0.2,
89
+ saturation=0.2,
90
+ hue=0.1
91
+ ),
92
+ transforms.ToTensor(),
93
+ transforms.Normalize(
94
+ mean=[0.485, 0.456, 0.406],
95
+ std=[0.229, 0.224, 0.225]
96
+ ),
97
+ transforms.RandomErasing(p=0.1, scale=(0.02, 0.08))
98
+ ])
99
+ else:
100
+ return transforms.Compose([
101
+ transforms.Resize((self.image_size, self.image_size)),
102
+ transforms.ToTensor(),
103
+ transforms.Normalize(
104
+ mean=[0.485, 0.456, 0.406],
105
+ std=[0.229, 0.224, 0.225]
106
+ )
107
+ ])
108
+
109
+ def __len__(self) -> int:
110
+ return len(self.samples)
111
+
112
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
113
+ """Get a sample from the dataset"""
114
+ img_path, label = self.samples[idx]
115
+
116
+ # Load image
117
+ try:
118
+ image = Image.open(img_path).convert('RGB')
119
+ except Exception as e:
120
+ print(f"⚠️ Error loading image {img_path}: {e}")
121
+ # Return a black image as fallback
122
+ image = Image.new('RGB', (self.image_size, self.image_size), color='black')
123
+
124
+ # Apply transforms
125
+ if self.transform:
126
+ image = self.transform(image)
127
+
128
+ return image, label
129
+
130
def create_data_loaders(
    data_dir: str,
    batch_size: int = 16,
    num_workers: int = 4,
    image_size: int = 224,
    use_weighted_sampling: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders.

    Args:
        data_dir: Root directory containing train/val folders
        batch_size: Batch size for data loaders
        num_workers: Number of worker processes
        image_size: Size to resize images to
        use_weighted_sampling: Whether to use weighted sampling for imbalanced data

    Returns:
        Tuple of (train_loader, val_loader)
    """
    train_dataset = TelecomSiteDataset(data_dir, 'train', image_size)
    val_dataset = TelecomSiteDataset(data_dir, 'val', image_size)

    # A sampler replaces shuffling, so it is built only when requested
    # and the training set is non-empty.
    train_sampler = None
    if use_weighted_sampling and len(train_dataset) > 0:
        train_sampler = create_weighted_sampler(train_dataset)

    pin = torch.cuda.is_available()

    # drop_last keeps every training batch full (stable batch statistics);
    # shuffle is only legal when no sampler is supplied.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=num_workers,
        pin_memory=pin,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin
    )

    print(f"πŸ“¦ Data loaders created:")
    print(f"   Batch size: {batch_size}")
    print(f"   Num workers: {num_workers}")
    print(f"   Train batches: {len(train_loader)}")
    print(f"   Val batches: {len(val_loader)}")
    print(f"   Weighted sampling: {use_weighted_sampling}")

    return train_loader, val_loader
186
+
187
def create_weighted_sampler(dataset: TelecomSiteDataset) -> WeightedRandomSampler:
    """
    Create weighted random sampler for imbalanced datasets.

    Each sample is weighted by the inverse frequency of its class, so
    rare classes are drawn (with replacement) about as often as common
    ones over an epoch.

    Args:
        dataset: The dataset to create sampler for

    Returns:
        WeightedRandomSampler for balanced sampling
    """
    label_counts = Counter(label for _, label in dataset.samples)
    n_samples = len(dataset.samples)
    n_classes = len(dataset.classes)

    # Inverse-frequency weight per class; .get(idx, 1) guards against a
    # division by zero for a class that has no samples.
    class_weights = {
        idx: n_samples / (n_classes * label_counts.get(idx, 1))
        for idx in range(n_classes)
    }

    # One weight per sample, looked up from its class.
    sample_weights = [class_weights[label] for _, label in dataset.samples]

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    print(f"βš–οΈ Weighted sampler created:")
    for class_name, class_idx in dataset.class_to_idx.items():
        print(f"   {class_name}: weight={class_weights[class_idx]:.3f}")

    return sampler
220
+
221
def get_inference_transform(image_size: int = 224) -> transforms.Compose:
    """
    Get transform for inference/prediction.

    Mirrors the validation pipeline: deterministic resize plus ImageNet
    normalization, no augmentation.

    Args:
        image_size: Size to resize images to

    Returns:
        Transform pipeline for inference
    """
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # Same ImageNet statistics used by the training/validation sets.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ]
    return transforms.Compose(steps)
239
+
240
def prepare_image_for_inference(image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
    """
    Prepare a PIL image for model inference.

    Args:
        image: PIL Image
        transform: Transform pipeline

    Returns:
        Preprocessed tensor with a leading batch dimension, ready for the model
    """
    # The model expects 3-channel input; normalize palette/grayscale modes.
    rgb = image if image.mode == 'RGB' else image.convert('RGB')

    # unsqueeze(0) adds the batch dimension (1, C, H, W).
    return transform(rgb).unsqueeze(0)
257
+
258
def visualize_batch(data_loader: DataLoader, num_samples: int = 8) -> None:
    """
    Visualize a batch of images from the data loader.

    Args:
        data_loader: DataLoader to sample from
        num_samples: Number of samples to visualize
    """
    try:
        import matplotlib.pyplot as plt

        batch_images, batch_labels = next(iter(data_loader))
        n = min(num_samples, len(batch_images))
        if n == 0:
            return

        # ImageNet statistics used by the dataset transforms; needed here
        # to undo normalization before display.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        # Size the grid from n instead of hard-coding 2x4, so asking for
        # more than 8 samples no longer indexes past the axes array.
        cols = min(4, n)
        rows = (n + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
        axes = np.atleast_1d(axes).flatten()

        class_names = ['Bad', 'Good']

        for i in range(n):
            # Denormalize and clamp into the displayable [0, 1] range.
            img = torch.clamp(batch_images[i] * std + mean, 0, 1)

            # (C, H, W) -> (H, W, C) for imshow.
            axes[i].imshow(img.permute(1, 2, 0).numpy())
            axes[i].set_title(f'Class: {class_names[batch_labels[i]]}')
            axes[i].axis('off')

        # Hide any leftover axes in a partially-filled final row.
        for ax in axes[n:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    except ImportError:
        print("⚠️ Matplotlib not available for visualization")
300
+
301
def check_data_directory(data_dir: str) -> Dict[str, Dict[str, int]]:
    """
    Check the data directory structure and count samples.

    Args:
        data_dir: Root directory to check

    Returns:
        Mapping of split name ('train'/'val') to per-class image counts,
        e.g. {'train': {'good': 10, 'bad': 8}}; empty dict if data_dir
        does not exist. (Annotation fixed: the previous Dict[str, int]
        did not match the nested dict actually returned.)
    """
    print(f"πŸ“‚ Checking data directory: {data_dir}")

    if not os.path.exists(data_dir):
        print(f"❌ Data directory not found: {data_dir}")
        return {}

    counts: Dict[str, Dict[str, int]] = {}
    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    for split in ['train', 'val']:
        split_dir = os.path.join(data_dir, split)
        if not os.path.exists(split_dir):
            print(f"⚠️ {split} directory not found")
            continue

        split_counts = {}
        for class_name in ['good', 'bad']:
            class_dir = os.path.join(split_dir, class_name)
            if os.path.exists(class_dir):
                # Count only files with a recognized image extension.
                split_counts[class_name] = sum(
                    1 for f in os.listdir(class_dir)
                    if f.lower().endswith(image_exts)
                )
            else:
                split_counts[class_name] = 0

        counts[split] = split_counts
        print(f"   {split.upper()}: Good={split_counts['good']}, Bad={split_counts['bad']}")

    return counts
339
+
340
def create_sample_data_structure():
    """
    Create sample data directory structure with instructions.

    Prints and returns a usage guide describing the expected on-disk
    layout and the labeling criteria. Nothing is written to disk.
    """
    guide = """
    πŸ“ Data Directory Structure:

    data/
    β”œβ”€β”€ train/
    β”‚ β”œβ”€β”€ good/ # Place good telecom site images here
    β”‚ β”‚ β”œβ”€β”€ good_site_001.jpg
    β”‚ β”‚ β”œβ”€β”€ good_site_002.jpg
    β”‚ β”‚ └── ...
    β”‚ └── bad/ # Place bad telecom site images here
    β”‚ β”œβ”€β”€ bad_site_001.jpg
    β”‚ β”œβ”€β”€ bad_site_002.jpg
    β”‚ └── ...
    └── val/
    β”œβ”€β”€ good/ # Validation good images
    β”‚ β”œβ”€β”€ val_good_001.jpg
    β”‚ └── ...
    └── bad/ # Validation bad images
    β”œβ”€β”€ val_bad_001.jpg
    └── ...

    πŸ“‹ Data Requirements:
    - Minimum 50 images per class for training
    - 20% of data should be reserved for validation
    - Images should be clear and well-lit
    - Recommended resolution: 224x224 or higher
    - Supported formats: JPG, PNG, JPEG, BMP, TIFF

    πŸ“Š Good Site Criteria:
    - Proper cable assembly and routing
    - All cards correctly installed and labeled
    - Clean and organized equipment layout
    - Proper grounding and safety measures
    - Clear and readable labels

    πŸ“Š Bad Site Criteria:
    - Messy or improper cable routing
    - Missing or incorrectly installed cards
    - Poor equipment organization
    - Missing or unreadable labels
    - Safety issues or violations
    """

    print(guide)
    return guide
utils/model_utils.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model utilities for telecom site classification
3
+ Handles ConvNeXt model loading and adaptation for transfer learning
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import timm
9
+ import os
10
+ from typing import Dict, Any, Optional, Tuple
11
+
12
+ class TelecomClassifier(nn.Module):
13
+ """
14
+ ConvNeXt-based telecom site classifier
15
+ Uses transfer learning from food detection model
16
+ """
17
+
18
+ def __init__(self, num_classes: int = 2, pretrained: bool = True):
19
+ super(TelecomClassifier, self).__init__()
20
+
21
+ # Load ConvNeXt Large model (same as food detection)
22
+ self.backbone = timm.create_model(
23
+ 'convnext_large.fb_in22k_ft_in1k',
24
+ pretrained=pretrained,
25
+ num_classes=0 # Remove classification head
26
+ )
27
+
28
+ # Get feature dimensions
29
+ self.feature_dim = self.backbone.num_features
30
+
31
+ # Custom classification head for telecom sites
32
+ self.classifier = nn.Sequential(
33
+ nn.LayerNorm(self.feature_dim),
34
+ nn.Linear(self.feature_dim, 512),
35
+ nn.ReLU(inplace=True),
36
+ nn.Dropout(0.3),
37
+ nn.Linear(512, 128),
38
+ nn.ReLU(inplace=True),
39
+ nn.Dropout(0.2),
40
+ nn.Linear(128, num_classes)
41
+ )
42
+
43
+ # Initialize classifier weights
44
+ self._init_classifier_weights()
45
+
46
+ def _init_classifier_weights(self):
47
+ """Initialize classifier weights using Xavier initialization"""
48
+ for module in self.classifier.modules():
49
+ if isinstance(module, nn.Linear):
50
+ nn.init.xavier_uniform_(module.weight)
51
+ nn.init.constant_(module.bias, 0)
52
+
53
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
54
+ """Forward pass through the model"""
55
+ # Extract features using ConvNeXt backbone
56
+ features = self.backbone(x)
57
+
58
+ # Classify using custom head
59
+ output = self.classifier(features)
60
+
61
+ return output
62
+
63
+ def freeze_backbone(self):
64
+ """Freeze backbone parameters for transfer learning"""
65
+ for param in self.backbone.parameters():
66
+ param.requires_grad = False
67
+ print("πŸ”’ Backbone frozen for transfer learning")
68
+
69
+ def unfreeze_backbone(self):
70
+ """Unfreeze backbone parameters for fine-tuning"""
71
+ for param in self.backbone.parameters():
72
+ param.requires_grad = True
73
+ print("πŸ”“ Backbone unfrozen for fine-tuning")
74
+
75
+ def get_parameter_count(self) -> Dict[str, int]:
76
+ """Get parameter counts for different parts of the model"""
77
+ backbone_params = sum(p.numel() for p in self.backbone.parameters())
78
+ classifier_params = sum(p.numel() for p in self.classifier.parameters())
79
+ total_params = backbone_params + classifier_params
80
+
81
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
82
+
83
+ return {
84
+ 'backbone': backbone_params,
85
+ 'classifier': classifier_params,
86
+ 'total': total_params,
87
+ 'trainable': trainable_params
88
+ }
89
+
90
def load_food_model_weights(model: TelecomClassifier, food_model_path: str) -> TelecomClassifier:
    """
    Load weights from the pre-trained food detection model.

    Only backbone weights are transferred; the food model's classification
    head is ignored. On a missing or unreadable checkpoint, the model is
    returned unchanged (keeping its ImageNet-pretrained weights).
    """
    if not os.path.exists(food_model_path):
        print(f"⚠️ Food model not found at {food_model_path}")
        print("πŸš€ Using ImageNet pretrained weights instead")
        return model

    try:
        print(f"πŸ“‚ Loading food model weights from {food_model_path}")

        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True).
        checkpoint = torch.load(food_model_path, map_location='cpu')

        # Checkpoints may be either a raw state dict or a dict wrapping one.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            food_state_dict = checkpoint['model_state_dict']
            accuracy = checkpoint.get('best_acc', 'Unknown')
            print(f"πŸ“Š Food model accuracy: {accuracy}%")
        else:
            food_state_dict = checkpoint

        # Re-key backbone weights to our wrapper's "backbone." prefix,
        # skipping the food model's head/classifier layers.
        backbone_state_dict = {
            f"backbone.{key}": value
            for key, value in food_state_dict.items()
            if not key.startswith('head') and not key.startswith('classifier')
        }

        # Keep only tensors whose name and shape match our architecture.
        model_dict = model.state_dict()
        filtered_dict = {
            key: value
            for key, value in backbone_state_dict.items()
            if key in model_dict and model_dict[key].shape == value.shape
        }

        model_dict.update(filtered_dict)
        model.load_state_dict(model_dict)

        print(f"βœ… Successfully loaded {len(filtered_dict)} backbone layers from food model")
        print(f"🎯 Transfer learning ready: backbone initialized with food detection weights")

        return model

    except Exception as e:
        # Best-effort by design: any failure falls back to the weights the
        # model already carries instead of aborting.
        print(f"❌ Error loading food model weights: {e}")
        print("πŸš€ Using ImageNet pretrained weights instead")
        return model
146
+
147
def create_telecom_model(
    num_classes: int = 2,
    food_model_path: Optional[str] = None,
    freeze_backbone: bool = True
) -> TelecomClassifier:
    """
    Create telecom classifier model with transfer learning from food detection.

    Args:
        num_classes: Number of output classes (2 for good/bad)
        food_model_path: Path to pre-trained food detection model
        freeze_backbone: Whether to freeze backbone for transfer learning

    Returns:
        TelecomClassifier model ready for training
    """
    print("πŸ—οΈ Creating telecom site classifier...")

    model = TelecomClassifier(num_classes=num_classes, pretrained=True)

    # Optionally seed the backbone from the food-detection checkpoint.
    if food_model_path:
        model = load_food_model_weights(model, food_model_path)

    # Frozen backbone => only the classifier head trains.
    if freeze_backbone:
        model.freeze_backbone()

    stats = model.get_parameter_count()
    print(f"πŸ“Š Model Statistics:")
    print(f"   Backbone parameters: {stats['backbone']:,}")
    print(f"   Classifier parameters: {stats['classifier']:,}")
    print(f"   Total parameters: {stats['total']:,}")
    print(f"   Trainable parameters: {stats['trainable']:,}")
    # Size estimate assumes 4 bytes per parameter (float32).
    print(f"   Model size: ~{stats['total'] * 4 / 1024**2:.1f} MB")

    return model
186
+
187
def save_model(
    model: TelecomClassifier,
    save_path: str,
    epoch: int,
    best_acc: float,
    optimizer_state: Optional[Dict] = None,
    additional_info: Optional[Dict] = None
) -> None:
    """
    Save model checkpoint with training information.

    Args:
        model: The model to save
        save_path: Path to save the model
        epoch: Current epoch number
        best_acc: Best validation accuracy achieved
        optimizer_state: Optimizer state dict
        additional_info: Additional information to save
    """
    # Descriptive metadata makes the checkpoint self-documenting.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'best_acc': best_acc,
        'model_info': {
            'architecture': 'ConvNeXt Large',
            'num_classes': 2,
            'parameter_count': model.get_parameter_count(),
            'task': 'telecom_site_classification'
        }
    }

    if optimizer_state:
        checkpoint['optimizer_state_dict'] = optimizer_state

    if additional_info:
        # Merged at top level, so caller-supplied keys may overwrite the
        # defaults above.
        checkpoint.update(additional_info)

    torch.save(checkpoint, save_path)
    print(f"πŸ’Ύ Model saved to {save_path}")
226
+
227
def load_model(
    model_path: str,
    num_classes: int = 2,
    device: str = 'cpu'
) -> Tuple[TelecomClassifier, Dict[str, Any]]:
    """
    Load trained telecom classifier model.

    Args:
        model_path: Path to saved model
        num_classes: Number of output classes
        device: Device to load model on

    Returns:
        Tuple of (model, model_info); the model is in eval mode on `device`.
    """
    print(f"πŸ“‚ Loading model from {model_path}")

    # Rebuild the architecture; weights come from the checkpoint, so no
    # pretrained download is needed.
    model = TelecomClassifier(num_classes=num_classes, pretrained=False)

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoints (consider weights_only=True).
    checkpoint = torch.load(model_path, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    # Bug fix: the model previously stayed on CPU regardless of `device`
    # (map_location only affects checkpoint tensors); move it explicitly.
    model.to(device)
    model.eval()

    # Surface training metadata alongside the model.
    model_info = checkpoint.get('model_info', {})
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')

    print(f"βœ… Model loaded successfully")
    print(f"   Best accuracy: {model_info.get('best_acc', 'Unknown')}")
    print(f"   Epoch: {model_info.get('epoch', 'Unknown')}")

    return model, model_info
265
+
266
def get_model_summary(model: TelecomClassifier) -> str:
    """
    Get a formatted summary of the model.

    Args:
        model: The model to summarize

    Returns:
        Formatted string with model information
    """
    counts = model.get_parameter_count()

    # Frozen parameters mean trainable < total, i.e. transfer learning.
    transfer = 'Enabled' if counts['trainable'] < counts['total'] else 'Disabled'

    # Size estimate assumes 4 bytes per parameter (float32).
    summary = f"""
    πŸ€– Telecom Site Classifier Model Summary
    {'='*50}
    Architecture: ConvNeXt Large + Custom Classifier
    Task: Binary Classification (Good/Bad Sites)

    Parameter Counts:
    Backbone (ConvNeXt): {counts['backbone']:,}
    Classifier Head: {counts['classifier']:,}
    Total Parameters: {counts['total']:,}
    Trainable Parameters: {counts['trainable']:,}

    Model Size: ~{counts['total'] * 4 / 1024**2:.1f} MB
    Transfer Learning: {transfer}
    """

    return summary