Spaces:
Running
Running
File size: 6,522 Bytes
36dd4e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
"""
ResNet50 model architecture for crop disease detection
"""
import warnings

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights
class CropDiseaseResNet50(nn.Module):
    """ResNet50 backbone with a custom MLP head for crop disease classification.

    The torchvision ResNet50 is used unchanged except for its ``fc`` layer,
    which is replaced by a Dropout/Linear/BatchNorm1d/ReLU head. The layer
    order of that head (indices 0-9) must match saved v2 checkpoints, because
    it determines the ``fc.<index>.*`` state_dict keys.
    """

    def __init__(self, num_classes, pretrained=True, freeze_features=True):
        """
        Args:
            num_classes: Number of disease classes (size of the output logits).
            pretrained: Load ImageNet ``IMAGENET1K_V2`` weights for the backbone.
            freeze_features: If True, set ``requires_grad=False`` on every
                backbone parameter. The new head is created afterwards, so its
                parameters always stay trainable.
        """
        super().__init__()

        # Load ResNet50, optionally with ImageNet pretrained weights.
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.resnet = models.resnet50(weights=weights)

        if freeze_features:
            # Freeze before swapping the head: the freshly created head layers
            # below keep requires_grad=True.
            # NOTE: requires_grad=False does not stop the backbone's BatchNorm2d
            # layers from updating their running statistics in train() mode.
            for param in self.resnet.parameters():
                param.requires_grad = False

        # Replace the final fully connected layer to match the saved v2 model
        # architecture. The index comments give each layer's position in fc.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(0.5),                # 0
            nn.Linear(num_features, 1024),  # 1
            nn.BatchNorm1d(1024),           # 2
            nn.ReLU(inplace=True),          # 3
            nn.Dropout(0.3),                # 4
            nn.Linear(1024, 512),           # 5
            nn.BatchNorm1d(512),            # 6
            nn.ReLU(inplace=True),          # 7
            nn.Dropout(0.2),                # 8
            nn.Linear(512, num_classes),    # 9
        )

        # Stored for reporting (e.g. model summaries).
        self.num_classes = num_classes

    def forward(self, x):
        """Return class logits for input batch *x*.

        In train() mode the BatchNorm1d layers of the head need more than one
        sample per batch; call ``eval()`` for single-image inference.
        """
        return self.resnet(x)

    def unfreeze_features(self):
        """Make every parameter (backbone and head) trainable for fine-tuning."""
        for param in self.resnet.parameters():
            param.requires_grad = True

    def freeze_features(self):
        """Freeze the backbone; parameters of the ``fc`` head are not touched."""
        for name, param in self.resnet.named_parameters():
            # Match the head's module prefix rather than any name that merely
            # contains the substring "fc".
            if not name.startswith('fc.'):
                param.requires_grad = False

    def get_feature_extractor(self, flatten=False):
        """Return the backbone without its ``fc`` head (e.g. for Grad-CAM).

        The result is an ``nn.Sequential`` of the ResNet's child modules except
        the last one (``fc``); it shares parameters with this model. ResNet's
        own ``forward`` flattens the pooled features before ``fc``, but a plain
        Sequential of its children does not. With ``flatten=False`` the output
        therefore keeps the 4-D ``(N, C, 1, 1)`` pooling shape and cannot be
        passed directly to :meth:`get_classifier`.

        Args:
            flatten: If True, append ``nn.Flatten(1)`` so the output is
                ``(N, C)`` and ``classifier(extractor(x))`` matches ``self(x)``.
        """
        layers = list(self.resnet.children())[:-1]
        if flatten:
            layers.append(nn.Flatten(1))
        return nn.Sequential(*layers)

    def get_classifier(self):
        """Return the ``fc`` classifier head (shared, not copied)."""
        return self.resnet.fc
def create_model(num_classes, pretrained=True, device='cpu', freeze_features=True):
    """Build a CropDiseaseResNet50 and move it to *device*.

    Args:
        num_classes: Number of output classes.
        pretrained: Load ImageNet weights for the backbone.
        device: Target passed to ``Module.to`` (string or ``torch.device``).
        freeze_features: Freeze the backbone at construction (default True);
            pass False to build a fully trainable model.

    Returns:
        The model, moved to *device*.
    """
    model = CropDiseaseResNet50(
        num_classes=num_classes,
        pretrained=pretrained,
        freeze_features=freeze_features,
    )
    return model.to(device)
def get_model_summary(model, input_size=(3, 224, 224)):
    """Print a parameter-count report for *model* and return the counts.

    Args:
        model: Module exposing a ``num_classes`` attribute.
        input_size: Only echoed in the report; no forward pass is run.

    Returns:
        Dict with ``total_params``, ``trainable_params`` and
        ``non_trainable_params``.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    frozen = total - trainable

    divider = "=" * 60
    print(divider)
    print("MODEL SUMMARY")
    print(divider)
    print("Model: ResNet50 for Crop Disease Detection")
    print(f"Input size: {input_size}")
    print(f"Number of classes: {model.num_classes}")
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print(f"Non-trainable parameters: {frozen:,}")
    print(divider)

    return dict(
        total_params=total,
        trainable_params=trainable,
        non_trainable_params=frozen,
    )
class ModelCheckpoint:
    """Save model/optimizer checkpoints during training, tracking the best score.

    Args:
        filepath: Destination passed to ``torch.save``; overwritten on each save.
        monitor: Key in the ``metrics`` mapping used to rank epochs.
        mode: ``'max'`` if higher is better; any other value is treated as
            lower-is-better.
        save_best_only: If True, save only when the monitored score improves;
            if False, save on every call.
    """

    def __init__(self, filepath, monitor='val_accuracy', mode='max', save_best_only=True):
        self.filepath = filepath
        self.monitor = monitor
        self.mode = mode
        self.save_best_only = save_best_only
        # Start from the worst possible score so the first real value wins.
        self.best_score = float('-inf') if mode == 'max' else float('inf')

    def __call__(self, model, optimizer, epoch, metrics):
        """Save a checkpoint if the monitored score improved (or always).

        Args:
            model: Module whose ``state_dict()`` is saved.
            optimizer: Optimizer whose ``state_dict()`` is saved.
            epoch: Epoch number stored in the checkpoint.
            metrics: Mapping of metric name to value, stored as-is.

        Returns:
            True if a checkpoint file was written, False otherwise.
        """
        current_score = metrics.get(self.monitor)
        if current_score is None:
            # A missing metric must not count as a real score: defaulting it
            # to 0 would, e.g. in mode='min' on a loss, record 0 as the best
            # score and block every genuine improvement afterwards.
            warnings.warn(
                f"ModelCheckpoint: metric '{self.monitor}' not found in metrics; "
                "best score not updated.",
                RuntimeWarning,
                stacklevel=2,
            )
            is_better = False
        elif self.mode == 'max':
            is_better = current_score > self.best_score
        else:
            is_better = current_score < self.best_score

        if self.save_best_only and not is_better:
            return False

        if is_better:
            self.best_score = current_score
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            'best_score': self.best_score,
        }
        torch.save(checkpoint, self.filepath)
        if is_better:
            print(f"Saved new best model with {self.monitor}: {current_score:.4f}")
        return True
def load_checkpoint(filepath, model, optimizer=None, device='cpu'):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        filepath: File containing a dict with at least ``'model_state_dict'``;
            ``'optimizer_state_dict'`` is required only when *optimizer* is given.
        model: Module that receives the saved weights (modified in place).
        optimizer: Optional optimizer whose state is restored in place.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        Tuple ``(model, optimizer, epoch, metrics)``. ``epoch`` falls back to 0
        and ``metrics`` to ``{}`` when they are absent from the file.
    """
    # NOTE(review): torch.load unpickles the file (without restriction on
    # torch < 2.6, where weights_only defaults to False); load only trusted
    # checkpoints.
    state = torch.load(filepath, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optimizer_state_dict'])

    epoch, metrics, best_score = (
        state.get('epoch', 0),
        state.get('metrics', {}),
        state.get('best_score', 0),
    )
    print(f"Loaded checkpoint from epoch {epoch}")
    print(f"Best score: {best_score:.4f}")
    return model, optimizer, epoch, metrics
if __name__ == "__main__":
    # Smoke test: build the model, print its summary and run one forward pass.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create model for 17 classes (as per our dataset)
    model = create_model(num_classes=17, device=device)

    # Print model summary
    get_model_summary(model)

    # eval() is required: in train() mode the head's BatchNorm1d layers raise
    # an error for a batch of a single sample. no_grad() skips autograd
    # bookkeeping, which this inference-only check does not need.
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224, device=device)
    with torch.no_grad():
        output = model(dummy_input)

    print("\nTest forward pass:")
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output probabilities sum: {torch.softmax(output, dim=1).sum():.4f}")
|