"""
Classification Models for Pest and Disease Detection

Supports multiple pretrained backbones: ResNet, EfficientNet, MobileNet
"""

import inspect

import torch
import torch.nn as nn
import torchvision.models as models


# Supported backbone name -> attribute of the torchvision model that holds its
# original classification head (replaced by nn.Identity in the classifier).
# Each key is also the name of the matching builder in ``torchvision.models``.
_BACKBONE_HEADS = {
    'resnet50': 'fc',
    'resnet101': 'fc',
    'efficientnet_b0': 'classifier',
    'efficientnet_b3': 'classifier',
    'mobilenet_v2': 'classifier',
}

# Weight tag that torchvision's legacy ``pretrained=True`` flag resolves to for
# every backbone above; requesting it explicitly through ``weights=`` loads the
# same checkpoint the old flag did.
_LEGACY_PRETRAINED_WEIGHTS = 'IMAGENET1K_V1'


def _build_torchvision_model(name, pretrained):
    """Instantiate ``torchvision.models.<name>`` with or without ImageNet weights.

    Uses the ``weights=`` argument when the installed torchvision accepts it
    (0.13+, where ``pretrained=`` is deprecated and warns on every call) and
    falls back to the boolean ``pretrained=`` flag on older releases.
    """
    builder = getattr(models, name)
    if 'weights' in inspect.signature(builder).parameters:
        return builder(weights=_LEGACY_PRETRAINED_WEIGHTS if pretrained else None)
    return builder(pretrained=pretrained)


def _head_in_features(head):
    """Return the input width of the first ``nn.Linear`` found in ``head``."""
    # ``fc`` heads are a bare nn.Linear; ``classifier`` heads are
    # Sequential(Dropout, Linear). ``modules()`` yields the module itself
    # first, so both layouts are covered.
    linear = next(m for m in head.modules() if isinstance(m, nn.Linear))
    return linear.in_features


class PestDiseaseClassifier(nn.Module):
    """
    General classifier with pretrained backbone for transfer learning.

    The torchvision backbone's original head is replaced by ``nn.Identity`` so
    the backbone outputs its feature vector, which is then passed through a
    custom MLP head:
    Dropout -> Linear(num_features, 512) -> ReLU -> Dropout -> Linear(512, num_classes).
    """

    def __init__(self, num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
        """
        Args:
            num_classes (int): Number of output classes
            backbone (str): Backbone architecture ('resnet50', 'resnet101',
                'efficientnet_b0', 'efficientnet_b3', 'mobilenet_v2')
            pretrained (bool): Use pretrained (ImageNet IMAGENET1K_V1) weights
            dropout (float): Dropout rate for regularization

        Raises:
            ValueError: If ``backbone`` is not one of the supported names.
        """
        super().__init__()
        self.backbone_name = backbone
        self.num_classes = num_classes

        # Validate before touching torchvision so unknown names fail fast.
        if backbone not in _BACKBONE_HEADS:
            raise ValueError(
                f"Unknown backbone: {backbone}. "
                f"Supported: {', '.join(_BACKBONE_HEADS)}"
            )

        head_attr = _BACKBONE_HEADS[backbone]
        self.backbone = _build_torchvision_model(backbone, pretrained)
        num_features = _head_in_features(getattr(self.backbone, head_attr))
        # Drop the original head so the backbone returns feature vectors.
        setattr(self.backbone, head_attr, nn.Identity())

        # Custom classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

        print(f"Model created: {backbone}")
        print(f" Features: {num_features}")
        print(f" Classes: {num_classes}")
        print(f" Pretrained: {pretrained}")

    def forward(self, x):
        """
        Forward pass

        Args:
            x: Input tensor, presumably [batch_size, 3, H, W] image batches

        Returns:
            logits: Output tensor [batch_size, num_classes]
        """
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits

    def _set_backbone_trainable(self, trainable):
        """Set ``requires_grad`` on every backbone parameter."""
        for param in self.backbone.parameters():
            param.requires_grad = trainable

    def freeze_backbone(self):
        """Freeze backbone parameters for fine-tuning.

        Note: this only stops gradient updates. BatchNorm layers in the
        backbone still update their running statistics while the module is in
        train mode.
        """
        self._set_backbone_trainable(False)
        print("Backbone frozen")

    def unfreeze_backbone(self):
        """Unfreeze backbone parameters"""
        self._set_backbone_trainable(True)
        print("Backbone unfrozen")


def create_model(num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
    """
    Factory function to create model

    Args:
        num_classes (int): Number of classes
        backbone (str): Model architecture
        pretrained (bool): Use pretrained weights
        dropout (float): Dropout rate

    Returns:
        model: PestDiseaseClassifier instance
    """
    return PestDiseaseClassifier(
        num_classes=num_classes,
        backbone=backbone,
        pretrained=pretrained,
        dropout=dropout,
    )


def count_parameters(model):
    """Print and return the total and trainable parameter counts of ``model``.

    Returns:
        tuple[int, int]: (total_params, trainable_params)
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("\nModel Parameters:")
    print(f" Total: {total_params:,}")
    print(f" Trainable: {trainable_params:,}")
    print(f" Non-trainable: {total_params - trainable_params:,}")

    return total_params, trainable_params


if __name__ == "__main__":
    # Smoke test: build a few backbones and run a dummy forward pass.
    # NOTE: pretrained=True downloads weights on first use.
    print("Testing Pest and Disease Classification Models")
    print("=" * 60)

    # Test different backbones
    backbones = ['resnet50', 'efficientnet_b0', 'mobilenet_v2']

    for backbone in backbones:
        print(f"\nTesting {backbone}...")
        print("-" * 60)

        model = create_model(num_classes=10, backbone=backbone, pretrained=True)
        count_parameters(model)

        # Test forward pass
        dummy_input = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            output = model(dummy_input)

        print(f" Input shape: {dummy_input.shape}")
        print(f" Output shape: {output.shape}")
        print(f" Output range: [{output.min():.3f}, {output.max():.3f}]")

    print("\n" + "=" * 60)
    print("Model test completed successfully!")
    print("=" * 60)