"""
ResNet50 model architecture for crop disease detection
"""
import warnings

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights

# Import the lite version. These names are not used in this file; presumably
# they are imported so callers can reach the lite variants through this module.
from .model_lite import CropDiseaseResNet50Lite, TinyDiseaseClassifier, create_memory_optimized_model


class CropDiseaseResNet50(nn.Module):
    """ResNet50 backbone with a custom MLP classifier head for crop disease classification."""

    def __init__(self, num_classes, pretrained=True, freeze_features=True):
        """
        Args:
            num_classes: Number of disease classes (output size of the final Linear layer).
            pretrained: Load ImageNet ``IMAGENET1K_V2`` weights into the backbone.
            freeze_features: If True, disable gradients for every backbone parameter.
                The replacement classifier head is built after the freeze, so it
                stays trainable.
        """
        super().__init__()

        # Load ResNet50, optionally with ImageNet pretrained weights
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.resnet = models.resnet50(weights=weights)

        # Freeze feature extraction layers if specified. This runs before ``fc``
        # is replaced below, so the new head keeps requires_grad=True.
        if freeze_features:
            for param in self.resnet.parameters():
                param.requires_grad = False

        # Replace the final fully connected layer to match saved v2 model architecture.
        # nn.Sequential keys its parameters by index (noted inline), so changing the
        # order would change the state_dict keys.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(0.5),                   # 0
            nn.Linear(num_features, 1024),     # 1
            nn.BatchNorm1d(1024),              # 2
            nn.ReLU(inplace=True),             # 3
            nn.Dropout(0.3),                   # 4
            nn.Linear(1024, 512),              # 5
            nn.BatchNorm1d(512),               # 6
            nn.ReLU(inplace=True),             # 7
            nn.Dropout(0.2),                   # 8
            nn.Linear(512, num_classes),       # 9
        )

        # Store number of classes
        self.num_classes = num_classes

    def forward(self, x):
        """Forward pass: return raw class logits from the wrapped ResNet50."""
        return self.resnet(x)

    def unfreeze_features(self):
        """Unfreeze all layers (backbone and head) for fine-tuning."""
        for param in self.resnet.parameters():
            param.requires_grad = True

    def freeze_features(self):
        """Freeze backbone parameters; the ``fc`` classifier head is left untouched."""
        for name, param in self.resnet.named_parameters():
            # Head parameters are named ``fc.<idx>.<weight|bias>``. Match the
            # prefix exactly instead of any name merely containing "fc".
            if not name.startswith('fc.'):
                param.requires_grad = False

    def get_feature_extractor(self):
        """Get feature extractor (every ResNet child except the final ``fc``) for Grad-CAM."""
        return nn.Sequential(*list(self.resnet.children())[:-1])

    def get_classifier(self):
        """Get the classifier head (``self.resnet.fc``) for Grad-CAM."""
        return self.resnet.fc


def create_model(num_classes, pretrained=True, device='cpu', freeze_features=True):
    """Create a CropDiseaseResNet50 and move it to ``device``.

    Args:
        num_classes: Number of output classes.
        pretrained: Load ImageNet pretrained backbone weights.
        device: Target device (string or ``torch.device``).
        freeze_features: Freeze the backbone at construction (default True,
            matching the previous hard-coded behavior).

    Returns:
        The model, already moved to ``device``.
    """
    model = CropDiseaseResNet50(
        num_classes=num_classes,
        pretrained=pretrained,
        freeze_features=freeze_features,
    )
    return model.to(device)


def get_model_summary(model, input_size=(3, 224, 224)):
    """Print a parameter-count summary of ``model`` and return the counts.

    Args:
        model: Any ``nn.Module``. ``model.num_classes`` is printed if present.
        input_size: Shown in the summary for reference only; no forward pass is run.

    Returns:
        dict with keys 'total_params', 'trainable_params', 'non_trainable_params'.
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    non_trainable_params = total_params - trainable_params
    num_classes = getattr(model, 'num_classes', 'N/A')

    print("=" * 60)
    print("MODEL SUMMARY")
    print("=" * 60)
    print("Model: ResNet50 for Crop Disease Detection")
    print(f"Input size: {input_size}")
    print(f"Number of classes: {num_classes}")
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Non-trainable parameters: {non_trainable_params:,}")
    print("=" * 60)

    return {
        'total_params': total_params,
        'trainable_params': trainable_params,
        'non_trainable_params': non_trainable_params,
    }


class ModelCheckpoint:
    """Save best model checkpoints during training"""

    def __init__(self, filepath, monitor='val_accuracy', mode='max', save_best_only=True):
        """
        Args:
            filepath: Destination passed to ``torch.save``; overwritten on each save.
            monitor: Key looked up in the ``metrics`` mapping given to ``__call__``.
            mode: 'max' if higher is better; any other value is treated as
                lower-is-better.
            save_best_only: If True, only write when the monitored score improves.
        """
        self.filepath = filepath
        self.monitor = monitor
        self.mode = mode
        self.save_best_only = save_best_only
        self.best_score = float('-inf') if mode == 'max' else float('inf')

    def __call__(self, model, optimizer, epoch, metrics):
        """Save a checkpoint if required.

        Args:
            model: Module whose ``state_dict()`` is saved.
            optimizer: Optimizer whose ``state_dict()`` is saved.
            epoch: Epoch number stored in the checkpoint.
            metrics: Mapping of metric name -> value; ``metrics[self.monitor]``
                decides whether this is a new best.

        Returns:
            True if a checkpoint file was written, otherwise False.
        """
        current_score = metrics.get(self.monitor)
        if current_score is None:
            # Defaulting a missing metric to 0 would, in 'min' mode, record 0 as a
            # best score no real value could beat. Treat it as "no improvement".
            warnings.warn(
                f"ModelCheckpoint: metric '{self.monitor}' not found in metrics; "
                "skipping best-score comparison",
                RuntimeWarning,
                stacklevel=2,
            )
            is_better = False
        elif self.mode == 'max':
            is_better = current_score > self.best_score
        else:
            is_better = current_score < self.best_score

        if self.save_best_only and not is_better:
            return False

        if is_better:
            self.best_score = current_score

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            'best_score': self.best_score,
        }
        torch.save(checkpoint, self.filepath)

        if is_better:
            print(f"Saved new best model with {self.monitor}: {current_score:.4f}")
        return True


def load_checkpoint(filepath, model, optimizer=None, device='cpu'):
    """Load a checkpoint written by ``ModelCheckpoint`` into ``model`` (and ``optimizer``).

    Args:
        filepath: Path to the checkpoint file.
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: Optional optimizer that receives ``checkpoint['optimizer_state_dict']``.
        device: ``map_location`` for ``torch.load``.

    Returns:
        Tuple ``(model, optimizer, epoch, metrics)``. ``epoch`` defaults to 0 and
        ``metrics`` to ``{}`` when absent from the checkpoint.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(filepath, map_location=device)

    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    epoch = checkpoint.get('epoch', 0)
    metrics = checkpoint.get('metrics', {})
    best_score = checkpoint.get('best_score', 0)

    print(f"Loaded checkpoint from epoch {epoch}")
    print(f"Best score: {best_score:.4f}")

    return model, optimizer, epoch, metrics


if __name__ == "__main__":
    # Test model creation
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create model for 17 classes (as per our dataset)
    model = create_model(num_classes=17, device=device)

    # Print model summary
    get_model_summary(model)

    # Test forward pass. The BatchNorm1d layers in the head raise in training
    # mode on a single-sample batch, so switch to eval mode first.
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    with torch.no_grad():
        output = model(dummy_input)

    print("\nTest forward pass:")
    print(f"Input shape: {dummy_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output probabilities sum: {torch.softmax(output, dim=1).sum():.4f}")