""" Model utilities for fire detection classification Handles ConvNeXt model loading and adaptation for transfer learning """ import torch import torch.nn as nn import timm import os from typing import Dict, Any, Optional, Tuple class FireDetectionClassifier(nn.Module): """ ConvNeXt-based fire detection classifier Uses transfer learning from ImageNet pretrained model """ def __init__(self, num_classes: int = 2, pretrained: bool = True): super(FireDetectionClassifier, self).__init__() # Load ConvNeXt Large model self.backbone = timm.create_model( 'convnext_large.fb_in22k_ft_in1k', pretrained=pretrained, num_classes=0 # Remove classification head ) # Get feature dimensions self.feature_dim = self.backbone.num_features # Custom classification head for fire detection self.classifier = nn.Sequential( nn.LayerNorm(self.feature_dim), nn.Linear(self.feature_dim, 512), nn.ReLU(inplace=True), nn.Dropout(0.3), nn.Linear(512, 128), nn.ReLU(inplace=True), nn.Dropout(0.2), nn.Linear(128, num_classes) ) # Initialize classifier weights self._init_classifier_weights() def _init_classifier_weights(self): """Initialize classifier weights using Xavier initialization""" for module in self.classifier.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) nn.init.constant_(module.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the model""" # Extract features using ConvNeXt backbone features = self.backbone(x) # Classify using custom head output = self.classifier(features) return output def freeze_backbone(self): """Freeze backbone parameters for transfer learning""" for param in self.backbone.parameters(): param.requires_grad = False print("🔒 Backbone frozen for transfer learning") def unfreeze_backbone(self): """Unfreeze backbone parameters for fine-tuning""" for param in self.backbone.parameters(): param.requires_grad = True print("🔓 Backbone unfrozen for fine-tuning") def get_parameter_count(self) -> Dict[str, int]: """Get parameter counts for different parts of the model""" backbone_params = sum(p.numel() for p in self.backbone.parameters()) classifier_params = sum(p.numel() for p in self.classifier.parameters()) total_params = backbone_params + classifier_params trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) return { 'backbone': backbone_params, 'classifier': classifier_params, 'total': total_params, 'trainable': trainable_params } def create_fire_detection_model( num_classes: int = 2, freeze_backbone: bool = True ) -> FireDetectionClassifier: """ Create fire detection classifier model with transfer learning Args: num_classes: Number of output classes (2 for fire/no_fire) freeze_backbone: Whether to freeze backbone for transfer learning Returns: FireDetectionClassifier model ready for training """ print("🔥 Creating fire detection classifier...") # Create the model model = FireDetectionClassifier(num_classes=num_classes, pretrained=True) # Freeze backbone if requested if freeze_backbone: model.freeze_backbone() # Print model information param_counts = model.get_parameter_count() print(f"📊 Model Statistics:") print(f" Backbone parameters: {param_counts['backbone']:,}") print(f" Classifier parameters: {param_counts['classifier']:,}") print(f" Total parameters: {param_counts['total']:,}") print(f" Trainable parameters: {param_counts['trainable']:,}") print(f" Model size: ~{param_counts['total'] * 4 / 1024**2:.1f} MB") return model def save_model( model: FireDetectionClassifier, save_path: str, epoch: int, 

def save_model(
    model: FireDetectionClassifier,
    save_path: str,
    epoch: int,
    best_acc: float,
    optimizer_state: Optional[Dict] = None,
    additional_info: Optional[Dict] = None
) -> None:
    """
    Save a model checkpoint with training information

    Args:
        model: The model to save
        save_path: Path to save the model
        epoch: Current epoch number
        best_acc: Best accuracy achieved
        optimizer_state: Optimizer state dict
        additional_info: Additional information to save
    """
    # Create the directory if it doesn't exist (guard against bare filenames,
    # where dirname would be an empty string)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Prepare checkpoint
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'best_acc': best_acc,
        'model_info': {
            'num_classes': 2,
            'class_names': ['fire', 'no_fire'],
            'parameter_count': model.get_parameter_count()
        }
    }

    # Add optional information
    if optimizer_state is not None:
        checkpoint['optimizer_state_dict'] = optimizer_state
    if additional_info is not None:
        checkpoint.update(additional_info)

    # Save checkpoint
    torch.save(checkpoint, save_path)
    print(f"💾 Model saved to: {save_path}")
    print(f"📈 Best accuracy: {best_acc:.4f}")


def load_model(
    model_path: str,
    num_classes: int = 2,
    device: str = 'cpu'
) -> Tuple[FireDetectionClassifier, Dict[str, Any]]:
    """
    Load a trained fire detection model

    Args:
        model_path: Path to the saved model
        num_classes: Number of classes (should be 2)
        device: Device to load the model on

    Returns:
        Tuple of (model, model_info)
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")

    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)

    # Create model (no pretrained download needed; weights come from the checkpoint)
    model = FireDetectionClassifier(num_classes=num_classes, pretrained=False)

    # Load state dict
    model.load_state_dict(checkpoint['model_state_dict'])

    # Move to device
    model = model.to(device)

    # Extract model info
    model_info = checkpoint.get('model_info', {})
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')

    print(f"✅ Model loaded from: {model_path}")
    print(f"📊 Model accuracy: {model_info.get('best_acc', 'Unknown')}")

    return model, model_info


def get_model_summary(model: FireDetectionClassifier) -> str:
    """
    Get a summary of the model architecture

    Args:
        model: The model to summarize

    Returns:
        String summary of the model
    """
    param_counts = model.get_parameter_count()

    summary = f"""
🔥 Fire Detection Model Summary
{'=' * 50}
Architecture: ConvNeXt Large + Custom Classifier
Classes: fire, no_fire
Parameters:
  Backbone:   {param_counts['backbone']:,}
  Classifier: {param_counts['classifier']:,}
  Total:      {param_counts['total']:,}
  Trainable:  {param_counts['trainable']:,}
Model Size: ~{param_counts['total'] * 4 / 1024**2:.1f} MB
{'=' * 50}
"""
    return summary
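

# Minimal smoke test (illustrative sketch): exercises model creation, a dummy
# forward pass, and the save/load round-trip defined above. The checkpoint path
# below is an assumption made for this example, not a path used elsewhere in
# the project; creating the model downloads pretrained ConvNeXt weights.
if __name__ == "__main__":
    model = create_fire_detection_model(num_classes=2, freeze_backbone=True)
    print(get_model_summary(model))

    # Dummy forward pass to confirm the head produces one logit per class
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        logits = model(dummy)
    print(f"Output shape: {tuple(logits.shape)}")

    # Save and reload the checkpoint to verify serialization works end to end
    checkpoint_path = "checkpoints/fire_detector_demo.pth"  # assumed demo path
    save_model(model, checkpoint_path, epoch=0, best_acc=0.0)
    reloaded_model, info = load_model(checkpoint_path, num_classes=2, device="cpu")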