Neylton committed on
Commit
66a8df2
Β·
verified Β·
1 Parent(s): 59228b6

Upload 2 files

Browse files
Files changed (2) hide show
  1. data_utils.py +377 -0
  2. model_utils.py +239 -0
data_utils.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data utilities for fire detection classification
3
+ Handles data loading, transformations, and dataset management
4
+ """
5
+
6
+ import os
7
+ import torch
8
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
9
+ from torchvision import transforms, datasets
10
+ from PIL import Image
11
+ import numpy as np
12
+ from typing import Tuple, Dict, List, Optional
13
+ from collections import Counter
14
+ import random
15
+
16
class FireDetectionDataset(Dataset):
    """
    Custom dataset for fire detection images.

    Images are discovered recursively under ``<data_dir>/<split>/<class>/``
    and labeled by their class folder. Training mode applies augmentation;
    validation mode applies a deterministic resize/normalize pipeline.
    """

    # Single source of truth for accepted image extensions (compared lower-case).
    IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    def __init__(self, data_dir: str, split: str = 'train', image_size: int = 224):
        """
        Initialize fire detection dataset.

        Args:
            data_dir: Root directory containing train/val folders
            split: 'train' or 'val'
            image_size: Size to resize images to
        """
        self.data_dir = data_dir
        self.split = split
        self.image_size = image_size

        # Fixed class mapping: 0 -> fire, 1 -> no_fire
        self.classes = ['fire', 'no_fire']  # 0: fire, 1: no_fire
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        # (path, label) pairs for every image found on disk
        self.samples = self._load_samples()

        # Split-appropriate transform pipeline
        self.transform = self._get_transforms()

        print(f"πŸ”₯ {split.upper()} Dataset loaded:")
        print(f"   Total samples: {len(self.samples)}")
        print(f"   Classes: {self.classes}")
        self._print_class_distribution()

    def _load_samples(self) -> List[Tuple[str, int]]:
        """Collect (image_path, class_index) pairs for this split."""
        samples = []
        split_dir = os.path.join(self.data_dir, self.split)

        for class_name in self.classes:
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.exists(class_dir):
                print(f"⚠️ Warning: {class_dir} not found")
                continue

            class_idx = self.class_to_idx[class_name]

            # Walk recursively so images in nested subdirectories are included
            for root, _dirs, files in os.walk(class_dir):
                for img_name in files:
                    if img_name.lower().endswith(self.IMG_EXTENSIONS):
                        samples.append((os.path.join(root, img_name), class_idx))

        return samples

    def _print_class_distribution(self):
        """Print per-class sample counts for this dataset."""
        class_counts = Counter(label for _, label in self.samples)
        for class_name, class_idx in self.class_to_idx.items():
            count = class_counts.get(class_idx, 0)
            print(f"   {class_name}: {count} samples")

    def _get_transforms(self) -> "transforms.Compose":
        """Build the transform pipeline appropriate for this split."""
        # ImageNet statistics (the backbone was pretrained on ImageNet)
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        if self.split == 'train':
            return transforms.Compose([
                # Oversize first so the random crop has room to vary
                transforms.Resize((self.image_size + 32, self.image_size + 32)),
                transforms.RandomResizedCrop(self.image_size, scale=(0.8, 1.0)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=10),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,
                    hue=0.1
                ),
                transforms.ToTensor(),
                normalize,
                # Light occlusion augmentation; must run after ToTensor
                transforms.RandomErasing(p=0.1, scale=(0.02, 0.08))
            ])
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            normalize
        ])

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return (transformed image tensor, integer label) for one sample."""
        img_path, label = self.samples[idx]

        try:
            image = Image.open(img_path).convert('RGB')
        except Exception as e:
            # Unreadable/corrupt file: log and substitute a black image so a
            # single bad file cannot crash a whole training epoch.
            print(f"⚠️ Error loading image {img_path}: {e}")
            image = Image.new('RGB', (self.image_size, self.image_size), color='black')

        if self.transform:
            image = self.transform(image)

        return image, label
130
+
131
def create_data_loaders(
    data_dir: str,
    batch_size: int = 16,
    num_workers: int = 4,
    image_size: int = 224,
    use_weighted_sampling: bool = True
) -> Tuple[DataLoader, DataLoader]:
    """
    Create train and validation data loaders.

    Args:
        data_dir: Root directory containing train/val folders
        batch_size: Batch size for data loaders
        num_workers: Number of worker processes
        image_size: Size to resize images to
        use_weighted_sampling: Whether to use weighted sampling for imbalanced data

    Returns:
        Tuple of (train_loader, val_loader)
    """
    train_dataset = FireDetectionDataset(data_dir, 'train', image_size)
    val_dataset = FireDetectionDataset(data_dir, 'val', image_size)

    # Optional class-balancing sampler; DataLoader forbids combining a
    # sampler with shuffle=True, hence the conditional below.
    sampler = None
    if use_weighted_sampling and len(train_dataset) > 0:
        sampler = create_weighted_sampler(train_dataset)

    pin = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=sampler is None,
        num_workers=num_workers,
        pin_memory=pin,
        drop_last=True  # keep every training batch full-sized
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin
    )

    print(f"πŸ“¦ Data loaders created:")
    print(f"   Batch size: {batch_size}")
    print(f"   Num workers: {num_workers}")
    print(f"   Train batches: {len(train_loader)}")
    print(f"   Val batches: {len(val_loader)}")
    print(f"   Weighted sampling: {use_weighted_sampling}")

    return train_loader, val_loader
187
+
188
def create_weighted_sampler(dataset: "FireDetectionDataset") -> WeightedRandomSampler:
    """
    Create weighted random sampler for imbalanced datasets.

    The annotation is quoted so this module-level function does not evaluate
    the class name eagerly at definition time. Any object exposing
    ``samples`` (list of ``(path, label)``) and ``class_to_idx`` works.

    Args:
        dataset: The dataset to create sampler for

    Returns:
        WeightedRandomSampler for balanced sampling (with replacement)
    """
    labels = [label for _, label in dataset.samples]
    class_counts = Counter(labels)
    total_samples = len(labels)

    # Inverse-frequency weight per class: rarer classes get larger weights
    class_weights = {idx: total_samples / cnt for idx, cnt in class_counts.items()}

    # One weight per sample, then sample with replacement to rebalance
    sampler = WeightedRandomSampler(
        weights=[class_weights[lbl] for lbl in labels],
        num_samples=total_samples,
        replacement=True
    )

    print(f"βš–οΈ Weighted sampler created:")
    for class_name, class_idx in dataset.class_to_idx.items():
        count = class_counts.get(class_idx, 0)
        weight = class_weights.get(class_idx, 0)
        print(f"   {class_name}: {count} samples, weight: {weight:.2f}")

    return sampler
224
+
225
def get_inference_transform(image_size: int = 224) -> transforms.Compose:
    """
    Get transforms for inference/prediction.

    Args:
        image_size: Size to resize images to

    Returns:
        Transform pipeline for inference
    """
    # Deterministic pipeline mirroring validation: resize, tensorize, then
    # normalize with ImageNet statistics.
    steps = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ]
    return transforms.Compose(steps)
243
+
244
def prepare_image_for_inference(image: Image.Image, transform: transforms.Compose) -> torch.Tensor:
    """
    Prepare an image for inference.

    Args:
        image: PIL Image
        transform: Transform pipeline

    Returns:
        Tensor ready for model inference (batch dimension of 1 prepended)
    """
    # Apply the pipeline, then add the leading batch axis the model expects.
    return transform(image).unsqueeze(0)
262
+
263
def visualize_batch(data_loader: DataLoader, num_samples: int = 8) -> None:
    """
    Visualize a batch of images from the data loader.

    Args:
        data_loader: DataLoader to sample from
        num_samples: Number of samples to visualize (capped at the 2x4
            grid size of 8; larger values previously raised IndexError)
    """
    import matplotlib.pyplot as plt

    # Grab one batch of normalized image tensors
    images, labels = next(iter(data_loader))

    # ImageNet statistics used to undo normalization for display
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    fig, axes = plt.subplots(2, 4, figsize=(15, 8))
    axes = axes.flatten()

    class_names = ['Fire', 'No Fire']

    # Fix: also clamp to the number of axes so num_samples > 8 cannot
    # index past the fixed 2x4 grid.
    n_show = min(num_samples, len(images), len(axes))
    for i in range(n_show):
        # Denormalize back into [0, 1] for imshow
        img = torch.clamp(images[i] * std + mean, 0, 1)

        # CHW tensor -> HWC numpy array
        img_np = img.permute(1, 2, 0).numpy()

        axes[i].imshow(img_np)
        axes[i].set_title(f'{class_names[labels[i]]}')
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()
301
+
302
def check_data_directory(data_dir: str) -> Dict[str, int]:
    """
    Check data directory structure and count samples.

    Args:
        data_dir: Directory to check

    Returns:
        Dictionary with data counts
    """
    valid_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    data_counts = {}

    if not os.path.exists(data_dir):
        print(f"❌ Data directory not found: {data_dir}")
        return data_counts

    print(f"πŸ“Š Data Directory Analysis: {data_dir}")
    print("=" * 50)

    total_samples = 0

    for split in ('train', 'val'):
        split_dir = os.path.join(data_dir, split)
        if not os.path.exists(split_dir):
            continue

        print(f"\n{split.upper()} SET:")
        split_total = 0

        for class_name in ('fire', 'no_fire'):
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.exists(class_dir):
                continue

            # Recursive count of recognized image files under this class
            count = sum(
                1
                for _root, _dirs, files in os.walk(class_dir)
                for name in files
                if name.lower().endswith(valid_exts)
            )

            print(f"   {class_name}: {count} images")
            data_counts[f"{split}_{class_name}"] = count
            split_total += count

        print(f"   Total {split}: {split_total}")
        total_samples += split_total
        data_counts[f"{split}_total"] = split_total

    print(f"\nOVERALL TOTAL: {total_samples} images")
    data_counts['total'] = total_samples
    print("=" * 50)

    return data_counts
356
+
357
def create_sample_data_structure():
    """Create sample data structure for testing"""
    print("πŸ”₯ Creating sample fire detection data structure...")

    # Expected <split>/<class> layout under ./data (idempotent)
    for split in ('train', 'val'):
        for class_name in ('fire', 'no_fire'):
            os.makedirs(f'data/{split}/{class_name}', exist_ok=True)

    print("βœ… Sample data structure created")
    print("   Please add your fire detection images to the appropriate directories")
    print("   - data/train/fire/ (training fire images)")
    print("   - data/train/no_fire/ (training no-fire images)")
    print("   - data/val/fire/ (validation fire images)")
    print("   - data/val/no_fire/ (validation no-fire images)")
model_utils.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model utilities for fire detection classification
3
+ Handles ConvNeXt model loading and adaptation for transfer learning
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import timm
9
+ import os
10
+ from typing import Dict, Any, Optional, Tuple
11
+
12
class FireDetectionClassifier(nn.Module):
    """
    ConvNeXt-based fire detection classifier.

    A pretrained ConvNeXt-Large backbone extracts features, and a small
    fully-connected head maps them to the fire / no-fire logits.
    """

    def __init__(self, num_classes: int = 2, pretrained: bool = True):
        super().__init__()

        # Headless ConvNeXt Large backbone (num_classes=0 strips its own head)
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0  # Remove classification head
        )

        self.feature_dim = self.backbone.num_features

        # Custom MLP head: LayerNorm -> 512 -> 128 -> num_classes, with dropout
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Initialize classifier weights using Xavier initialization"""
        for layer in self.classifier.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Backbone feature extraction followed by the classification head."""
        return self.classifier(self.backbone(x))

    def freeze_backbone(self):
        """Freeze backbone parameters for transfer learning"""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("πŸ”’ Backbone frozen for transfer learning")

    def unfreeze_backbone(self):
        """Unfreeze backbone parameters for fine-tuning"""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("πŸ”“ Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """Get parameter counts for different parts of the model"""
        counts = {
            'backbone': sum(p.numel() for p in self.backbone.parameters()),
            'classifier': sum(p.numel() for p in self.classifier.parameters()),
        }
        counts['total'] = counts['backbone'] + counts['classifier']
        counts['trainable'] = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return counts
89
+
90
def create_fire_detection_model(
    num_classes: int = 2,
    freeze_backbone: bool = True
) -> FireDetectionClassifier:
    """
    Create fire detection classifier model with transfer learning.

    Args:
        num_classes: Number of output classes (2 for fire/no_fire)
        freeze_backbone: Whether to freeze backbone for transfer learning

    Returns:
        FireDetectionClassifier model ready for training
    """
    print("πŸ”₯ Creating fire detection classifier...")

    model = FireDetectionClassifier(num_classes=num_classes, pretrained=True)

    if freeze_backbone:
        model.freeze_backbone()

    # Report the parameter breakdown (size assumes 4 bytes per parameter)
    stats = model.get_parameter_count()
    print(f"πŸ“Š Model Statistics:")
    print(f"   Backbone parameters: {stats['backbone']:,}")
    print(f"   Classifier parameters: {stats['classifier']:,}")
    print(f"   Total parameters: {stats['total']:,}")
    print(f"   Trainable parameters: {stats['trainable']:,}")
    print(f"   Model size: ~{stats['total'] * 4 / 1024**2:.1f} MB")

    return model
123
+
124
def save_model(
    model: "FireDetectionClassifier",
    save_path: str,
    epoch: int,
    best_acc: float,
    optimizer_state: Optional[Dict] = None,
    additional_info: Optional[Dict] = None
) -> None:
    """
    Save model checkpoint with training information.

    Args:
        model: The model to save (needs ``state_dict()`` and
            ``get_parameter_count()``)
        save_path: Path to save the model
        epoch: Current epoch number
        best_acc: Best accuracy achieved
        optimizer_state: Optimizer state dict
        additional_info: Additional information to save (merged into the
            checkpoint at top level)
    """
    # Create the target directory if needed. Guard against a bare filename:
    # os.path.dirname("model.pt") is "" and os.makedirs("") raises.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'best_acc': best_acc,
        'model_info': {
            'num_classes': 2,
            'class_names': ['fire', 'no_fire'],
            'parameter_count': model.get_parameter_count()
        }
    }

    if optimizer_state:
        checkpoint['optimizer_state_dict'] = optimizer_state

    if additional_info:
        checkpoint.update(additional_info)

    torch.save(checkpoint, save_path)
    print(f"πŸ’Ύ Model saved to: {save_path}")
    print(f"πŸ“ˆ Best accuracy: {best_acc:.4f}")
169
+
170
def load_model(
    model_path: str,
    num_classes: int = 2,
    device: str = 'cpu'
) -> Tuple[FireDetectionClassifier, Dict[str, Any]]:
    """
    Load a trained fire detection model.

    Args:
        model_path: Path to the saved model
        num_classes: Number of classes (should be 2)
        device: Device to load model on

    Returns:
        Tuple of (model, model_info)
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")

    checkpoint = torch.load(model_path, map_location=device)

    # Rebuild the architecture without downloading pretrained weights,
    # then restore the trained parameters from the checkpoint.
    model = FireDetectionClassifier(num_classes=num_classes, pretrained=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # Attach training metadata for the caller
    model_info = checkpoint.get('model_info', {})
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')

    print(f"βœ… Model loaded from: {model_path}")
    print(f"πŸ“Š Model accuracy: {model_info.get('best_acc', 'Unknown')}")

    return model, model_info
210
+
211
def get_model_summary(model: FireDetectionClassifier) -> str:
    """
    Get a summary of the model architecture.

    Args:
        model: The model to summarize

    Returns:
        String summary of the model
    """
    counts = model.get_parameter_count()
    divider = '=' * 50

    # Size estimate assumes float32 weights: 4 bytes per parameter.
    return f"""
πŸ”₯ Fire Detection Model Summary
{divider}
Architecture: ConvNeXt Large + Custom Classifier
Classes: fire, no_fire

Parameters:
  Backbone: {counts['backbone']:,}
  Classifier: {counts['classifier']:,}
  Total: {counts['total']:,}
  Trainable: {counts['trainable']:,}

Model Size: ~{counts['total'] * 4 / 1024**2:.1f} MB
{divider}
"""