|
|
"""
|
|
|
Model utilities for fire detection classification
|
|
|
Handles ConvNeXt model loading and adaptation for transfer learning
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import timm
|
|
|
import os
|
|
|
from typing import Dict, Any, Optional, Tuple
|
|
|
|
|
|
class FireDetectionClassifier(nn.Module):
    """
    ConvNeXt-based fire detection classifier.

    Uses transfer learning from an ImageNet-pretrained timm backbone
    (``convnext_large.fb_in22k_ft_in1k``). The backbone's own head is removed
    (``num_classes=0``) and replaced by a small MLP classification head.
    """

    def __init__(self, num_classes: int = 2, pretrained: bool = True):
        """
        Build the backbone and the classification head.

        Args:
            num_classes: Number of logits produced by the final layer.
            pretrained: Whether timm should load pretrained backbone weights.
        """
        super().__init__()

        # With num_classes=0, timm drops the classification head so the
        # backbone outputs pooled features of width `num_features`.
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0,
        )

        # Width of the feature vector fed to the head.
        self.feature_dim = self.backbone.num_features

        # MLP head: feature_dim -> 512 -> 128 -> num_classes, with dropout
        # between layers for regularization.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

        self._init_classifier_weights()

    def _init_classifier_weights(self):
        """Initialize head Linear layers: Xavier-uniform weights, zero biases."""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the backbone, then the classification head.

        Args:
            x: Input batch passed unchanged to the backbone
               (presumably images shaped (batch, 3, H, W) -- confirm with
               the data pipeline).

        Returns:
            Raw (unnormalized) logits from the final Linear layer.
        """
        features = self.backbone(x)
        return self.classifier(features)

    def freeze_backbone(self):
        """Disable gradients for all backbone parameters (head stays trainable)."""
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("🔒 Backbone frozen for transfer learning")

    def unfreeze_backbone(self):
        """Re-enable gradients for all backbone parameters for fine-tuning."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("🔓 Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """
        Count parameters per model part.

        Returns:
            Dict with keys 'backbone', 'classifier', 'total' (backbone +
            classifier) and 'trainable' (parameters with requires_grad=True).
        """
        backbone_params = sum(p.numel() for p in self.backbone.parameters())
        classifier_params = sum(p.numel() for p in self.classifier.parameters())
        total_params = backbone_params + classifier_params
        trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad
        )

        return {
            'backbone': backbone_params,
            'classifier': classifier_params,
            'total': total_params,
            'trainable': trainable_params,
        }
|
|
|
|
|
|
def create_fire_detection_model(
    num_classes: int = 2,
    freeze_backbone: bool = True,
    pretrained: bool = True,
) -> FireDetectionClassifier:
    """
    Create fire detection classifier model with transfer learning.

    Args:
        num_classes: Number of output classes (2 for fire/no_fire).
        freeze_backbone: Whether to freeze the backbone so only the head trains.
        pretrained: Whether to load pretrained backbone weights. Set False for
            offline use or when weights will come from a checkpoint.

    Returns:
        FireDetectionClassifier model ready for training.
    """
    print("🔥 Creating fire detection classifier...")

    model = FireDetectionClassifier(num_classes=num_classes, pretrained=pretrained)

    if freeze_backbone:
        model.freeze_backbone()

    param_counts = model.get_parameter_count()
    # Rough size estimate: assumes 4 bytes per parameter (float32).
    size_mb = param_counts['total'] * 4 / 1024**2

    print("📊 Model Statistics:")
    print(f" Backbone parameters: {param_counts['backbone']:,}")
    print(f" Classifier parameters: {param_counts['classifier']:,}")
    print(f" Total parameters: {param_counts['total']:,}")
    print(f" Trainable parameters: {param_counts['trainable']:,}")
    print(f" Model size: ~{size_mb:.1f} MB")

    return model
|
|
|
|
|
|
def save_model(
    model: FireDetectionClassifier,
    save_path: str,
    epoch: int,
    best_acc: float,
    optimizer_state: Optional[Dict] = None,
    additional_info: Optional[Dict] = None
) -> None:
    """
    Save a model checkpoint with training information.

    Args:
        model: The model to save.
        save_path: Destination file path. Missing parent directories are
            created; a bare filename is written to the current directory.
        epoch: Current epoch number.
        best_acc: Best accuracy achieved so far.
        optimizer_state: Optimizer state dict, stored under
            'optimizer_state_dict' when not None.
        additional_info: Extra top-level entries merged into the checkpoint.
            Note: keys here override the built-in keys of the same name.
    """
    # os.path.dirname() returns '' for a bare filename, and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when there is one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'best_acc': best_acc,
        'model_info': {
            # Record the head's real output size rather than assuming 2.
            'num_classes': model.classifier[-1].out_features,
            'class_names': ['fire', 'no_fire'],
            'parameter_count': model.get_parameter_count(),
        },
    }

    if optimizer_state is not None:
        checkpoint['optimizer_state_dict'] = optimizer_state

    if additional_info:
        checkpoint.update(additional_info)

    torch.save(checkpoint, save_path)
    print(f"💾 Model saved to: {save_path}")
    print(f"🏆 Best accuracy: {best_acc:.4f}")
|
|
|
|
|
|
def load_model(
    model_path: str,
    num_classes: int = 2,
    device: str = 'cpu'
) -> Tuple[FireDetectionClassifier, Dict[str, Any]]:
    """
    Load a trained fire detection model for inference.

    Args:
        model_path: Path to a checkpoint written by save_model, or to a bare
            state_dict file.
        num_classes: Number of classes the saved head was built with
            (should be 2).
        device: Device to map the checkpoint tensors and the model onto.

    Returns:
        Tuple of (model, model_info). The model is in eval mode; call
        ``model.train()`` to resume training. model_info holds the
        checkpoint's 'model_info' plus 'epoch' and 'best_acc' (or
        'Unknown' when absent).

    Raises:
        FileNotFoundError: If model_path does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")

    # NOTE(security): torch.load deserializes with pickle; only load
    # checkpoints from trusted sources. PyTorch >= 2.6 defaults to
    # weights_only=True, which restricts what can be unpickled.
    checkpoint = torch.load(model_path, map_location=device)

    # Build the architecture without downloading pretrained weights; the
    # checkpoint supplies every parameter.
    model = FireDetectionClassifier(num_classes=num_classes, pretrained=False)

    # save_model nests weights under 'model_state_dict'; otherwise treat the
    # file itself as a plain state_dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)

    model = model.to(device)
    # Disable dropout so predictions from the loaded model are deterministic.
    model.eval()

    model_info = checkpoint.get('model_info', {})
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')

    print(f"✅ Model loaded from: {model_path}")
    print(f"📊 Model accuracy: {model_info.get('best_acc', 'Unknown')}")

    return model, model_info
|
|
|
|
|
|
def get_model_summary(model: FireDetectionClassifier) -> str:
    """
    Get a human-readable summary of the model architecture.

    Args:
        model: The model to summarize.

    Returns:
        Multi-line string with architecture, class names and parameter
        counts. It begins and ends with a newline.
    """
    param_counts = model.get_parameter_count()
    separator = '=' * 50
    # Rough size estimate: assumes 4 bytes per parameter (float32).
    size_mb = param_counts['total'] * 4 / 1024**2

    lines = [
        "",
        "🔥 Fire Detection Model Summary",
        separator,
        "Architecture: ConvNeXt Large + Custom Classifier",
        "Classes: fire, no_fire",
        "",
        "Parameters:",
        f"  Backbone: {param_counts['backbone']:,}",
        f"  Classifier: {param_counts['classifier']:,}",
        f"  Total: {param_counts['total']:,}",
        f"  Trainable: {param_counts['trainable']:,}",
        "",
        f"Model Size: ~{size_mb:.1f} MB",
        separator,
        "",
    ]
    return "\n".join(lines)