File size: 6,651 Bytes
c8df794
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
ResNet50 model architecture for crop disease detection
"""

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights

# Import the lite version
from .model_lite import CropDiseaseResNet50Lite, TinyDiseaseClassifier, create_memory_optimized_model

class CropDiseaseResNet50(nn.Module):
    """Crop disease classifier: a torchvision ResNet50 backbone with an MLP head.

    The backbone's stock ``fc`` layer is replaced by a 10-layer ``nn.Sequential``
    head. The head's layer order (and therefore its state-dict indices
    ``fc.0`` ... ``fc.9``) is kept fixed to match the saved v2 model architecture.
    """

    def __init__(self, num_classes, pretrained=True, freeze_features=True):
        """
        Args:
            num_classes: Size of the final output layer (number of disease classes).
            pretrained: If True, initialise the backbone with torchvision's
                ``ResNet50_Weights.IMAGENET1K_V2``; otherwise use random weights.
            freeze_features: If True, set ``requires_grad=False`` on every backbone
                parameter before the new head is attached, so initially only the
                head is trainable.
        """
        super().__init__()

        backbone_weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.resnet = models.resnet50(weights=backbone_weights)

        # Freezing happens *before* the head is swapped in, so the freshly built
        # head layers keep their default requires_grad=True.
        if freeze_features:
            self.resnet.requires_grad_(False)

        in_features = self.resnet.fc.in_features
        # Order is load-bearing: positions 0-9 must match the saved v2 checkpoints.
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(in_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        ]
        self.resnet.fc = nn.Sequential(*head_layers)

        self.num_classes = num_classes

    def forward(self, x):
        """Run the backbone plus classification head and return the raw logits."""
        logits = self.resnet(x)
        return logits

    def unfreeze_features(self):
        """Enable gradients for every parameter of the network (backbone and head)."""
        self.resnet.requires_grad_(True)

    def freeze_features(self):
        """Disable gradients for backbone parameters.

        Parameters whose name contains ``'fc'`` (the classification head) are
        left untouched.
        """
        for param_name, param in self.resnet.named_parameters():
            belongs_to_head = 'fc' in param_name
            if not belongs_to_head:
                param.requires_grad = False

    def get_feature_extractor(self):
        """Return every top-level backbone module except the final ``fc`` head, for Grad-CAM."""
        *backbone_modules, _head = self.resnet.children()
        return nn.Sequential(*backbone_modules)

    def get_classifier(self):
        """Return the ``fc`` classification head, for Grad-CAM."""
        head = self.resnet.fc
        return head

def create_model(num_classes, pretrained=True, device='cpu'):
    """Build a ``CropDiseaseResNet50`` with a frozen backbone and move it to ``device``.

    Args:
        num_classes: Number of output classes for the classification head.
        pretrained: Forwarded to ``CropDiseaseResNet50`` to select ImageNet weights.
        device: Target passed to ``Module.to`` (a string or ``torch.device``).

    Returns:
        The constructed model, already placed on ``device``.
    """
    network = CropDiseaseResNet50(
        num_classes=num_classes,
        pretrained=pretrained,
        freeze_features=True,
    )
    return network.to(device)

def get_model_summary(model, input_size=(3, 224, 224)):
    """Print a parameter-count summary of ``model`` and return the counts.

    Args:
        model: Any ``nn.Module``. ``model.num_classes`` is printed when the
            attribute exists; otherwise ``"unknown"`` is shown instead of
            raising ``AttributeError``.
        input_size: Only echoed in the printed summary; the model is not run.

    Returns:
        dict with keys ``total_params``, ``trainable_params`` and
        ``non_trainable_params`` (parameter element counts).
    """
    total_params = 0
    trainable_params = 0
    # One pass over the parameters instead of two separate generator traversals.
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    non_trainable_params = total_params - trainable_params

    # getattr keeps the summary usable for modules that don't define num_classes.
    num_classes = getattr(model, 'num_classes', 'unknown')

    print("=" * 60)
    print("MODEL SUMMARY")
    print("=" * 60)
    print("Model: ResNet50 for Crop Disease Detection")
    print(f"Input size: {input_size}")
    print(f"Number of classes: {num_classes}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {non_trainable_params:,}")
    print("=" * 60)

    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'non_trainable_params': non_trainable_params,
    }

class ModelCheckpoint:
    """Save training checkpoints while tracking the best value of a monitored metric.

    Args:
        filepath: Destination passed to ``torch.save``; every save overwrites it.
        monitor: Key looked up in the ``metrics`` dict passed to ``__call__``.
        mode: ``'max'`` if larger metric values are better, ``'min'`` if smaller are.
        save_best_only: If True, write a checkpoint only when the monitored
            metric improves; if False, write one on every call.

    Raises:
        ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
    """

    _VALID_MODES = ('max', 'min')

    def __init__(self, filepath, monitor='val_accuracy', mode='max', save_best_only=True):
        # Reject typos such as 'Max' instead of silently treating them as 'min'.
        if mode not in self._VALID_MODES:
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.filepath = filepath
        self.monitor = monitor
        self.mode = mode
        self.save_best_only = save_best_only
        # Start from the worst possible score so the first real value improves on it.
        self.best_score = float('-inf') if mode == 'max' else float('inf')

    def _improves(self, score):
        """Return True if ``score`` beats ``self.best_score`` under ``self.mode``."""
        if self.mode == 'max':
            return score > self.best_score
        return score < self.best_score

    def __call__(self, model, optimizer, epoch, metrics):
        """Save a checkpoint if required and report whether one was written.

        Args:
            model: Module whose ``state_dict`` is saved.
            optimizer: Optimizer whose ``state_dict`` is saved.
            epoch: Epoch number stored in the checkpoint.
            metrics: Dict of metric values; ``metrics[self.monitor]`` is compared
                against the best score seen so far.

        Returns:
            True if a checkpoint file was written, False otherwise.
        """
        if self.monitor in metrics:
            current_score = metrics[self.monitor]
            is_better = self._improves(current_score)
        else:
            # A missing metric is never an improvement: substituting a default
            # such as 0 would, in 'min' mode, become an unbeatable best_score.
            import warnings
            warnings.warn(
                f"ModelCheckpoint: metric {self.monitor!r} not found in metrics; "
                "not treating this call as an improvement",
                stacklevel=2,
            )
            current_score = None
            is_better = False

        if self.save_best_only and not is_better:
            return False

        if is_better:
            self.best_score = current_score

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            'best_score': self.best_score
        }

        torch.save(checkpoint, self.filepath)

        if is_better:
            print(f"Saved new best model with {self.monitor}: {current_score:.4f}")

        return True

def load_checkpoint(filepath, model, optimizer=None, device='cpu'):
    """Load model weights (and optionally optimizer state) from a checkpoint file.

    Accepts either the dict format written by ``ModelCheckpoint`` (keys
    ``model_state_dict``, ``optimizer_state_dict``, ``epoch``, ``metrics``,
    ``best_score``) or a bare state dict saved with
    ``torch.save(model.state_dict(), path)``.

    Args:
        filepath: Path of the checkpoint file.
        model: Module that receives the weights via ``load_state_dict``.
        optimizer: Optional optimizer. Its state is restored only when the
            checkpoint contains ``optimizer_state_dict``; otherwise it is left
            unchanged and a notice is printed.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        Tuple ``(model, optimizer, epoch, metrics)``. ``epoch`` defaults to 0
        and ``metrics`` to ``{}`` when the checkpoint does not store them.
    """
    # SECURITY: torch.load unpickles the file unless weights_only=True (the
    # default only in recent PyTorch releases); load trusted checkpoints only.
    checkpoint = torch.load(filepath, map_location=device)

    # Distinguish the wrapped ModelCheckpoint format from a bare state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        meta = checkpoint
    else:
        state_dict = checkpoint
        meta = {}

    model.load_state_dict(state_dict)

    if optimizer is not None:
        if 'optimizer_state_dict' in meta:
            optimizer.load_state_dict(meta['optimizer_state_dict'])
        else:
            print("No optimizer state in checkpoint; optimizer left unchanged")

    epoch = meta.get('epoch', 0)
    metrics = meta.get('metrics', {})
    best_score = meta.get('best_score', 0)

    print(f"Loaded checkpoint from epoch {epoch}")
    print(f"Best score: {best_score:.4f}")

    return model, optimizer, epoch, metrics

if __name__ == "__main__":
    # Smoke test: build the model, print its summary and run a single forward pass.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create model for 17 classes (as per our dataset); uses create_model's
    # default pretrained=True backbone weights.
    model = create_model(num_classes=17, device=device)

    get_model_summary(model)

    # eval() is required: the head's BatchNorm1d layers raise on a batch of one
    # sample in training mode. no_grad() skips building an unused autograd graph.
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    with torch.no_grad():
        output = model(dummy_input)
    print("\nTest forward pass:")
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output probabilities sum: {torch.softmax(output, dim=1).sum():.4f}")