# Uploaded by Boyun7 (commit 03d5bce: "upload all files")
"""
Classification Models for Pest and Disease Detection
Supports multiple pretrained backbones: ResNet, EfficientNet, MobileNet
"""
import torch
import torch.nn as nn
import torchvision.models as models
class PestDiseaseClassifier(nn.Module):
    """
    General classifier with a pretrained torchvision backbone for transfer learning.

    The backbone's own classification head is replaced with ``nn.Identity`` so the
    backbone returns its feature vector. That vector feeds a custom head:
    Dropout -> Linear(num_features, 512) -> ReLU -> Dropout -> Linear(512, num_classes).
    """

    # Supported backbone name -> attribute of the torchvision model that holds its
    # classification head. ResNets expose a single Linear as ``fc``; EfficientNet
    # and MobileNetV2 expose ``classifier`` as a Sequential containing the Linear.
    _HEAD_ATTR = {
        'resnet50': 'fc',
        'resnet101': 'fc',
        'efficientnet_b0': 'classifier',
        'efficientnet_b3': 'classifier',
        'mobilenet_v2': 'classifier',
    }

    def __init__(self, num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3,
                 verbose=True):
        """
        Args:
            num_classes (int): Number of output classes
            backbone (str): Backbone architecture ('resnet50', 'resnet101', 'efficientnet_b0',
                'efficientnet_b3', 'mobilenet_v2')
            pretrained (bool): Load the torchvision IMAGENET1K_V1 weights for the backbone
            dropout (float): Dropout rate applied before each Linear layer of the head
            verbose (bool): Print a short summary of the created model (default True,
                matching the previous always-print behavior)

        Raises:
            ValueError: If ``backbone`` is not one of the supported names.
        """
        super().__init__()
        if backbone not in self._HEAD_ATTR:
            raise ValueError(f"Unknown backbone: {backbone}")

        self.backbone_name = backbone
        self.num_classes = num_classes

        self.backbone = self._load_backbone(backbone, pretrained)
        num_features = self._strip_head(self.backbone, self._HEAD_ATTR[backbone])

        # Custom classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

        if verbose:
            print(f"Model created: {backbone}")
            print(f" Features: {num_features}")
            print(f" Classes: {num_classes}")
            print(f" Pretrained: {pretrained}")

    @staticmethod
    def _load_backbone(name, pretrained):
        """Instantiate torchvision model ``name``, optionally with pretrained weights."""
        builder = getattr(models, name)
        try:
            # torchvision >= 0.13 multi-weight API. 'IMAGENET1K_V1' is the checkpoint
            # the deprecated ``pretrained=True`` flag resolves to for every supported
            # backbone, so the loaded weights are unchanged and no warning is emitted.
            return builder(weights='IMAGENET1K_V1' if pretrained else None)
        except TypeError:
            # Older torchvision builders reject ``weights`` and only take the boolean flag.
            return builder(pretrained=pretrained)

    @staticmethod
    def _strip_head(net, head_attr):
        """
        Replace ``net.<head_attr>`` with ``nn.Identity`` and return the head's input width.

        The head may be a bare ``nn.Linear`` or a container of modules. For a container,
        the first ``nn.Linear`` found is the layer that consumes the backbone features.

        Raises:
            ValueError: If the head contains no ``nn.Linear``.
        """
        head = getattr(net, head_attr)
        if isinstance(head, nn.Linear):
            in_features = head.in_features
        else:
            linears = [m for m in head.modules() if isinstance(m, nn.Linear)]
            if not linears:
                raise ValueError(f"No nn.Linear found in backbone head '{head_attr}'")
            in_features = linears[0].in_features
        setattr(net, head_attr, nn.Identity())
        return in_features

    def forward(self, x):
        """
        Forward pass
        Args:
            x: Input tensor [batch_size, 3, H, W]
        Returns:
            logits: Output tensor [batch_size, num_classes]
        """
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits

    def freeze_backbone(self):
        """Freeze backbone parameters so only the classifier head is trained."""
        self.backbone.requires_grad_(False)
        print("Backbone frozen")

    def unfreeze_backbone(self):
        """Make backbone parameters trainable again (e.g. for full fine-tuning)."""
        self.backbone.requires_grad_(True)
        print("Backbone unfrozen")
def create_model(num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
    """
    Build a PestDiseaseClassifier from the given settings.

    Args:
        num_classes (int): Number of classes the head predicts
        backbone (str): Name of the backbone architecture
        pretrained (bool): Whether to load pretrained backbone weights
        dropout (float): Dropout rate for the classifier head

    Returns:
        model: A freshly constructed PestDiseaseClassifier
    """
    settings = dict(
        num_classes=num_classes,
        backbone=backbone,
        pretrained=pretrained,
        dropout=dropout,
    )
    return PestDiseaseClassifier(**settings)
def count_parameters(model):
    """
    Print and return the number of parameter elements in ``model``.

    Returns:
        (total, trainable): element counts over all parameters, and over those
        with ``requires_grad`` set.
    """
    total_params = 0
    trainable_params = 0
    # One pass over the parameters, accumulating both tallies.
    for tensor in model.parameters():
        size = tensor.numel()
        total_params += size
        if tensor.requires_grad:
            trainable_params += size

    frozen_params = total_params - trainable_params
    print("\nModel Parameters:")
    print(f" Total: {total_params:,}")
    print(f" Trainable: {trainable_params:,}")
    print(f" Non-trainable: {frozen_params:,}")
    return total_params, trainable_params
if __name__ == "__main__":
    # Smoke test: build each backbone, report its parameter counts, and run a
    # forward pass on a dummy batch to check the output shape.
    print("Testing Pest and Disease Classification Models")
    print("=" * 60)

    # Test different backbones
    backbones = ['resnet50', 'efficientnet_b0', 'mobilenet_v2']
    for backbone in backbones:
        print(f"\nTesting {backbone}...")
        print("-" * 60)
        model = create_model(num_classes=10, backbone=backbone, pretrained=True)
        count_parameters(model)

        # Switch to inference mode: disables dropout and makes BatchNorm use its
        # running statistics. torch.no_grad() alone does neither, so in train mode
        # the forward pass is stochastic and BatchNorm's running stats would be
        # overwritten by this dummy batch.
        model.eval()

        # Test forward pass
        dummy_input = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            output = model(dummy_input)
        print(f" Input shape: {dummy_input.shape}")
        print(f" Output shape: {output.shape}")
        print(f" Output range: [{output.min():.3f}, {output.max():.3f}]")

    print("\n" + "=" * 60)
    print("Model test completed successfully!")
    print("=" * 60)