File size: 7,761 Bytes
66a8df2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""

Model utilities for fire detection classification

Handles ConvNeXt model loading and adaptation for transfer learning

"""

import torch
import torch.nn as nn
import timm
import os
from typing import Dict, Any, Optional, Tuple

class FireDetectionClassifier(nn.Module):
    """
    ConvNeXt-based fire detection classifier.

    Uses transfer learning from the timm ``convnext_large.fb_in22k_ft_in1k``
    model: the backbone's own classification head is removed and replaced by
    a small MLP head producing ``num_classes`` raw (pre-softmax) scores.

    Args:
        num_classes: Number of outputs of the head's final Linear layer.
        pretrained: Passed to ``timm.create_model``; when True the pretrained
            backbone weights are loaded.
    """

    def __init__(self, num_classes: int = 2, pretrained: bool = True):
        super().__init__()

        # Load ConvNeXt Large; num_classes=0 removes timm's classification
        # head so the backbone emits a feature vector instead of class logits.
        self.backbone = timm.create_model(
            'convnext_large.fb_in22k_ft_in1k',
            pretrained=pretrained,
            num_classes=0,
        )

        # Width of the backbone's output feature vector (input size of the head).
        self.feature_dim = self.backbone.num_features

        # Custom classification head: feature_dim -> 512 -> 128 -> num_classes.
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

        self._init_classifier_weights()

    def _init_classifier_weights(self) -> None:
        """Xavier-uniform init for every Linear weight in the head; zero biases."""
        for module in self.classifier.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                # A Linear built with bias=False has bias=None; nothing to init.
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the backbone, then the head; return the head's raw scores."""
        features = self.backbone(x)
        return self.classifier(features)

    def _set_backbone_trainable(self, trainable: bool) -> None:
        """Set ``requires_grad`` on every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = trainable

    def freeze_backbone(self) -> None:
        """Freeze backbone parameters so only the classifier head is trained."""
        self._set_backbone_trainable(False)
        print("🔒 Backbone frozen for transfer learning")

    def unfreeze_backbone(self) -> None:
        """Unfreeze backbone parameters for full fine-tuning."""
        self._set_backbone_trainable(True)
        print("🔓 Backbone unfrozen for fine-tuning")

    def get_parameter_count(self) -> Dict[str, int]:
        """
        Count parameters per model part.

        Returns:
            Dict with keys ``backbone``, ``classifier``, ``total`` (backbone +
            classifier) and ``trainable`` (parameters with requires_grad=True).
        """
        backbone_params = sum(p.numel() for p in self.backbone.parameters())
        classifier_params = sum(p.numel() for p in self.classifier.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        return {
            'backbone': backbone_params,
            'classifier': classifier_params,
            'total': backbone_params + classifier_params,
            'trainable': trainable_params,
        }

def create_fire_detection_model(
    num_classes: int = 2,
    freeze_backbone: bool = True,
    pretrained: bool = True,
) -> FireDetectionClassifier:
    """
    Create a fire detection classifier set up for transfer learning.

    Args:
        num_classes: Number of output classes (2 for fire/no_fire).
        freeze_backbone: If True, freeze the backbone so only the classifier
            head is trainable.
        pretrained: Whether to load pretrained backbone weights. Defaults to
            True, matching the previously hard-coded behaviour.

    Returns:
        FireDetectionClassifier model ready for training.
    """
    print("🔥 Creating fire detection classifier...")

    model = FireDetectionClassifier(num_classes=num_classes, pretrained=pretrained)

    if freeze_backbone:
        model.freeze_backbone()

    # Report model statistics.
    param_counts = model.get_parameter_count()
    print("📊 Model Statistics:")
    print(f"   Backbone parameters: {param_counts['backbone']:,}")
    print(f"   Classifier parameters: {param_counts['classifier']:,}")
    print(f"   Total parameters: {param_counts['total']:,}")
    print(f"   Trainable parameters: {param_counts['trainable']:,}")
    # Size estimate assumes 4 bytes (float32) per parameter.
    print(f"   Model size: ~{param_counts['total'] * 4 / 1024**2:.1f} MB")

    return model

def save_model(
    model: "FireDetectionClassifier",
    save_path: str,
    epoch: int,
    best_acc: float,
    optimizer_state: Optional[Dict] = None,
    additional_info: Optional[Dict] = None
) -> None:
    """
    Save a model checkpoint together with training metadata.

    The checkpoint is a dict with keys ``model_state_dict``, ``epoch``,
    ``best_acc`` and ``model_info`` (``num_classes``, ``class_names``,
    ``parameter_count``), plus ``optimizer_state_dict`` when provided.

    Args:
        model: The model to save; must provide ``get_parameter_count()``.
        save_path: Destination file path. Missing parent directories are
            created; a bare filename is written to the current directory.
        epoch: Current epoch number.
        best_acc: Best accuracy achieved.
        optimizer_state: Optional optimizer state dict, stored under
            ``optimizer_state_dict``.
        additional_info: Optional extra entries merged into the top level of
            the checkpoint (they overwrite any standard key with the same name).
    """
    # os.path.dirname('model.pt') == '' and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when the path has one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Derive the class count from the head's final Linear layer rather than
    # assuming 2; fall back to 2 if no such layer is found.
    head = getattr(model, 'classifier', None)
    head_linears = (
        [m for m in head.modules() if isinstance(m, nn.Linear)]
        if isinstance(head, nn.Module) else []
    )
    num_classes = head_linears[-1].out_features if head_linears else 2
    class_names = (
        ['fire', 'no_fire'] if num_classes == 2
        else [f'class_{i}' for i in range(num_classes)]
    )

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'best_acc': best_acc,
        'model_info': {
            'num_classes': num_classes,
            'class_names': class_names,
            'parameter_count': model.get_parameter_count()
        }
    }

    if optimizer_state is not None:
        checkpoint['optimizer_state_dict'] = optimizer_state

    if additional_info is not None:
        checkpoint.update(additional_info)

    torch.save(checkpoint, save_path)
    print(f"💾 Model saved to: {save_path}")
    print(f"📈 Best accuracy: {best_acc:.4f}")

def load_model(
    model_path: str,
    num_classes: int = 2,
    device: str = 'cpu'
) -> Tuple[FireDetectionClassifier, Dict[str, Any]]:
    """
    Load a trained fire detection model from a checkpoint.

    Args:
        model_path: Path to a checkpoint whose ``model_state_dict`` matches a
            FireDetectionClassifier with ``num_classes`` outputs.
        num_classes: Number of classes the model was built with (default 2).
        device: Device to map the checkpoint tensors and the model onto.

    Returns:
        Tuple of (model, model_info). ``model_info`` is a copy of the
        checkpoint's ``model_info`` dict (or ``{}``), with ``epoch`` and
        ``best_acc`` added ('Unknown' when absent).

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found at: {model_path}")

    # NOTE(security): torch.load unpickles the file; depending on the
    # installed torch version this may run arbitrary code. Only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)

    # pretrained=False: the weights come from the checkpoint below.
    model = FireDetectionClassifier(num_classes=num_classes, pretrained=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # Copy so the caller's dict is independent of the loaded checkpoint.
    model_info = dict(checkpoint.get('model_info', {}))
    model_info['epoch'] = checkpoint.get('epoch', 'Unknown')
    model_info['best_acc'] = checkpoint.get('best_acc', 'Unknown')

    print(f"✅ Model loaded from: {model_path}")
    print(f"📊 Model accuracy: {model_info.get('best_acc', 'Unknown')}")

    return model, model_info

def get_model_summary(model: "FireDetectionClassifier") -> str:
    """
    Build a human-readable summary of the model's parameter counts.

    Args:
        model: The model to summarize; must provide ``get_parameter_count()``
            returning ``backbone``, ``classifier``, ``total`` and ``trainable``.

    Returns:
        Multi-line summary string. The size estimate assumes 4 bytes
        (float32) per parameter.
    """
    param_counts = model.get_parameter_count()

    summary = f"""

🔥 Fire Detection Model Summary

{'='*50}

Architecture: ConvNeXt Large + Custom Classifier

Classes: fire, no_fire



Parameters:

  Backbone: {param_counts['backbone']:,}

  Classifier: {param_counts['classifier']:,}

  Total: {param_counts['total']:,}

  Trainable: {param_counts['trainable']:,}



Model Size: ~{param_counts['total'] * 4 / 1024**2:.1f} MB

{'='*50}

"""

    return summary