Spaces:
Sleeping
Sleeping
| """ | |
| Classification Models for Pest and Disease Detection | |
| Supports multiple pretrained backbones: ResNet, EfficientNet, MobileNet | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
class PestDiseaseClassifier(nn.Module):
    """
    General classifier with pretrained backbone for transfer learning.

    A torchvision backbone has its built-in classification layer replaced by
    ``nn.Identity`` so that it emits feature vectors, and a custom head
    (Dropout -> Linear(num_features, 512) -> ReLU -> Dropout ->
    Linear(512, num_classes)) maps those features to class logits.
    """

    # Supported backbone name -> (torchvision weights-enum name, attribute that
    # holds the backbone's built-in classification head). The torchvision
    # builder function has the same name as the key.
    _BACKBONES = {
        'resnet50': ('ResNet50_Weights', 'fc'),
        'resnet101': ('ResNet101_Weights', 'fc'),
        'efficientnet_b0': ('EfficientNet_B0_Weights', 'classifier'),
        'efficientnet_b3': ('EfficientNet_B3_Weights', 'classifier'),
        'mobilenet_v2': ('MobileNet_V2_Weights', 'classifier'),
    }

    def __init__(self, num_classes=10, backbone='resnet50', pretrained=True,
                 dropout=0.3, *, verbose=True):
        """
        Args:
            num_classes (int): Number of output classes
            backbone (str): Backbone architecture ('resnet50', 'resnet101',
                'efficientnet_b0', 'efficientnet_b3', 'mobilenet_v2')
            pretrained (bool): Use pretrained (ImageNet) weights
            dropout (float): Dropout rate for regularization in the custom head
            verbose (bool): Print a short model summary after construction

        Raises:
            ValueError: If ``backbone`` is not one of the supported names.
        """
        super().__init__()
        if backbone not in self._BACKBONES:
            raise ValueError(f"Unknown backbone: {backbone}")
        self.backbone_name = backbone
        self.num_classes = num_classes

        weights_name, head_attr = self._BACKBONES[backbone]
        self.backbone = self._load_backbone(backbone, weights_name, pretrained)

        # ResNet's `fc` is a single Linear layer; for the other backbones the
        # Linear layer of the `classifier` head is read at index 1.
        head = getattr(self.backbone, head_attr)
        linear = head if isinstance(head, nn.Linear) else head[1]
        num_features = linear.in_features
        # Drop the built-in head so the backbone returns features, not logits.
        setattr(self.backbone, head_attr, nn.Identity())

        # Custom classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

        if verbose:
            print(f"Model created: {backbone}")
            print(f" Features: {num_features}")
            print(f" Classes: {num_classes}")
            print(f" Pretrained: {pretrained}")

    @staticmethod
    def _load_backbone(name, weights_name, pretrained):
        """Instantiate torchvision model ``name``, optionally with ImageNet weights."""
        builder = getattr(models, name)
        weights_enum = getattr(models, weights_name, None)
        if weights_enum is None:
            # torchvision < 0.13 has no weights enums; only the legacy flag exists.
            return builder(pretrained=pretrained)
        # torchvision >= 0.13 deprecates `pretrained=`. IMAGENET1K_V1 is the
        # checkpoint that the legacy `pretrained=True` flag resolved to, so the
        # loaded weights are unchanged (DEFAULT may point at a newer V2 set).
        return builder(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor [batch_size, 3, H, W]

        Returns:
            logits: Output tensor [batch_size, num_classes]
        """
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits

    def freeze_backbone(self):
        """
        Freeze backbone parameters for fine-tuning.

        Only sets ``requires_grad=False``. It does not switch the backbone to
        eval mode, so BatchNorm layers still update their running statistics
        while the model is in train mode.
        """
        for param in self.backbone.parameters():
            param.requires_grad = False
        print("Backbone frozen")

    def unfreeze_backbone(self):
        """Unfreeze backbone parameters (set ``requires_grad=True`` on all of them)."""
        for param in self.backbone.parameters():
            param.requires_grad = True
        print("Backbone unfrozen")
def create_model(num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
    """
    Build a PestDiseaseClassifier from the given settings.

    Args:
        num_classes (int): Number of output classes
        backbone (str): Backbone architecture name, forwarded unchanged
        pretrained (bool): Whether to load pretrained backbone weights
        dropout (float): Dropout rate used in the classifier head

    Returns:
        PestDiseaseClassifier: The newly constructed model.
    """
    settings = dict(
        num_classes=num_classes,
        backbone=backbone,
        pretrained=pretrained,
        dropout=dropout,
    )
    return PestDiseaseClassifier(**settings)
def count_parameters(model):
    """
    Print and return the number of parameter elements in ``model``.

    Counts are element counts (``numel``). A parameter is "trainable" when its
    ``requires_grad`` flag is set.

    Returns:
        tuple[int, int]: ``(total_params, trainable_params)``.
    """
    total = 0
    trainable = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total += size
        if tensor.requires_grad:
            trainable += size

    print("\nModel Parameters:")
    print(f" Total: {total:,}")
    print(f" Trainable: {trainable:,}")
    print(f" Non-trainable: {total - trainable:,}")
    return total, trainable
if __name__ == "__main__":
    # Smoke test: build each backbone, report its size, and run one forward
    # pass on random input to check the output shape.
    print("Testing Pest and Disease Classification Models")
    print("=" * 60)
    # Test different backbones
    backbones = ['resnet50', 'efficientnet_b0', 'mobilenet_v2']
    for backbone in backbones:
        print(f"\nTesting {backbone}...")
        print("-" * 60)
        model = create_model(num_classes=10, backbone=backbone, pretrained=True)
        count_parameters(model)
        # Switch to inference mode. This disables dropout and makes BatchNorm
        # use its running statistics instead of updating them from this dummy
        # batch, so the printed output is reproducible.
        model.eval()
        # Test forward pass
        dummy_input = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            output = model(dummy_input)
        print(f" Input shape: {dummy_input.shape}")
        print(f" Output shape: {output.shape}")
        print(f" Output range: [{output.min():.3f}, {output.max():.3f}]")
    print("\n" + "=" * 60)
    print("Model test completed successfully!")
    print("=" * 60)