"""
Model utilities for fire detection classification.

Handles ConvNeXt model loading and adaptation for transfer learning.
"""
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import timm
|
| | import os
|
| | from typing import Dict, Any, Optional, Tuple
|
| |
|
class FireDetectionClassifier(nn.Module):
    """ConvNeXt-based binary fire detection classifier.

    Wraps an ImageNet-pretrained ConvNeXt-Large backbone (via timm) with a
    small fully-connected head, intended for transfer learning.
    """

    def __init__(self, num_classes: int = 2, pretrained: bool = True):
        super().__init__()

        # Backbone with its own classification head removed (num_classes=0),
        # so it outputs pooled features of size `num_features`.
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0
        )
        self.feature_dim = self.backbone.num_features

        # Custom head: LayerNorm -> 512 -> 128 -> num_classes, with dropout.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes)
        )

        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Xavier-initialize the linear layers of the classifier head."""
        for layer in self.classifier.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run a batch of images through the backbone, then the head."""
        return self.classifier(self.backbone(x))

    def freeze_backbone(self):
        """Freeze backbone parameters for transfer learning."""
        for p in self.backbone.parameters():
            p.requires_grad = False
        print("π Backbone frozen for transfer learning")

    def unfreeze_backbone(self):
        """Unfreeze backbone parameters for fine-tuning."""
        for p in self.backbone.parameters():
            p.requires_grad = True
        print("π Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """Return parameter counts: backbone, classifier, total, trainable."""
        counts = {
            'backbone': sum(p.numel() for p in self.backbone.parameters()),
            'classifier': sum(p.numel() for p in self.classifier.parameters()),
        }
        counts['total'] = counts['backbone'] + counts['classifier']
        counts['trainable'] = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )
        return counts
|
| |
|
def create_fire_detection_model(
    num_classes: int = 2,
    freeze_backbone: bool = True
) -> FireDetectionClassifier:
    """Build a fire detection classifier set up for transfer learning.

    Args:
        num_classes: Number of output classes (2 for fire/no_fire)
        freeze_backbone: Whether to freeze backbone for transfer learning

    Returns:
        FireDetectionClassifier model ready for training
    """
    print("π₯ Creating fire detection classifier...")

    model = FireDetectionClassifier(num_classes=num_classes, pretrained=True)
    if freeze_backbone:
        model.freeze_backbone()

    # Report size statistics (~4 bytes per float32 parameter).
    stats = model.get_parameter_count()
    print("π Model Statistics:")
    print(f"   Backbone parameters: {stats['backbone']:,}")
    print(f"   Classifier parameters: {stats['classifier']:,}")
    print(f"   Total parameters: {stats['total']:,}")
    print(f"   Trainable parameters: {stats['trainable']:,}")
    print(f"   Model size: ~{stats['total'] * 4 / 1024**2:.1f} MB")

    return model
|
| |
|
def save_model(
    model: "FireDetectionClassifier",
    save_path: str,
    epoch: int,
    best_acc: float,
    optimizer_state: Optional[Dict] = None,
    additional_info: Optional[Dict] = None
) -> None:
    """
    Save model checkpoint with training information.

    Args:
        model: The model to save
        save_path: Path to save the model
        epoch: Current epoch number
        best_acc: Best accuracy achieved
        optimizer_state: Optimizer state dict
        additional_info: Additional information to save
    """
    # Fix: os.makedirs("") raises FileNotFoundError when save_path is a
    # bare filename with no directory component — only create dirs if any.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'best_acc': best_acc,
        'model_info': {
            'num_classes': 2,
            'class_names': ['fire', 'no_fire'],
            'parameter_count': model.get_parameter_count()
        }
    }

    # `is not None` (not truthiness) so an explicitly-passed empty dict
    # is still recorded in the checkpoint.
    if optimizer_state is not None:
        checkpoint['optimizer_state_dict'] = optimizer_state

    if additional_info is not None:
        checkpoint.update(additional_info)

    torch.save(checkpoint, save_path)
    print(f"πΎ Model saved to: {save_path}")
    print(f"π Best accuracy: {best_acc:.4f}")
|
| |
|
def load_model(
    model_path: str,
    num_classes: int = 2,
    device: str = 'cpu'
) -> Tuple[FireDetectionClassifier, Dict[str, Any]]:
    """
    Load a trained fire detection model.

    Args:
        model_path: Path to the saved model
        num_classes: Number of classes (should be 2)
        device: Device to load model on

    Returns:
        Tuple of (model, model_info)

    Raises:
        FileNotFoundError: If model_path does not exist
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)

    # Build the architecture without downloading pretrained weights;
    # the checkpoint supplies them.
    model = FireDetectionClassifier(num_classes=num_classes, pretrained=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    model_info = checkpoint.get('model_info', {})
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')

    # Fix: this message was split across two physical lines inside an
    # f-string literal, which is a syntax error — rejoined onto one line.
    print(f"β Model loaded from: {model_path}")
    print(f"π Model accuracy: {model_info.get('best_acc', 'Unknown')}")

    return model, model_info
|
| |
|
def get_model_summary(model: "FireDetectionClassifier") -> str:
    """Produce a human-readable summary of the model architecture.

    Args:
        model: The model to summarize

    Returns:
        String summary of the model
    """
    counts = model.get_parameter_count()
    divider = '=' * 50
    # ~4 bytes per float32 parameter.
    size_mb = counts['total'] * 4 / 1024**2

    text = f"""
π₯ Fire Detection Model Summary
{divider}
Architecture: ConvNeXt Large + Custom Classifier
Classes: fire, no_fire

Parameters:
  Backbone: {counts['backbone']:,}
  Classifier: {counts['classifier']:,}
  Total: {counts['total']:,}
  Trainable: {counts['trainable']:,}

Model Size: ~{size_mb:.1f} MB
{divider}
    """
    return text