"""
Classification models for pest and disease detection.

Supports multiple pretrained torchvision backbones: ResNet, EfficientNet, MobileNet.
"""
import torch
import torch.nn as nn
import torchvision.models as models
class PestDiseaseClassifier(nn.Module):
    """
    General classifier with pretrained backbone for transfer learning.

    The torchvision backbone's own classification layer is replaced by
    ``nn.Identity`` so the backbone returns a feature vector; a small MLP head
    (dropout -> Linear(512) -> ReLU -> dropout -> Linear(num_classes)) maps that
    vector to logits.
    """

    # Supported backbone -> (name of the backbone attribute that holds its
    # original classifier, index of the nn.Linear inside that attribute, or None
    # when the attribute is itself the nn.Linear). Used both to validate the
    # ``backbone`` argument and to read the feature width before stripping it.
    _BACKBONE_HEADS = {
        'resnet50': ('fc', None),
        'resnet101': ('fc', None),
        'efficientnet_b0': ('classifier', 1),
        'efficientnet_b3': ('classifier', 1),
        'mobilenet_v2': ('classifier', 1),
    }

    def __init__(self, num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
        """
        Args:
            num_classes (int): Number of output classes
            backbone (str): Backbone architecture ('resnet50', 'resnet101',
                'efficientnet_b0', 'efficientnet_b3', 'mobilenet_v2')
            pretrained (bool): Load the ImageNet (IMAGENET1K_V1) weights for the backbone
            dropout (float): Dropout rate applied twice in the classifier head

        Raises:
            ValueError: If ``backbone`` is not one of the supported names.
        """
        super().__init__()
        self.backbone_name = backbone
        self.num_classes = num_classes

        # Validate before building anything so an unknown name fails fast.
        if backbone not in self._BACKBONE_HEADS:
            supported = ", ".join(self._BACKBONE_HEADS)
            raise ValueError(f"Unknown backbone: {backbone} (supported: {supported})")

        self.backbone = self._build_backbone(backbone, pretrained)
        head_attr, linear_index = self._BACKBONE_HEADS[backbone]
        old_head = getattr(self.backbone, head_attr)
        old_linear = old_head if linear_index is None else old_head[linear_index]
        num_features = old_linear.in_features
        # Drop the ImageNet classifier so the backbone outputs features instead
        # of 1000-way logits (nn.Module.__setattr__ re-registers the submodule).
        setattr(self.backbone, head_attr, nn.Identity())

        # Custom classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

        print(f"Model created: {backbone}")
        print(f" Features: {num_features}")
        print(f" Classes: {num_classes}")
        print(f" Pretrained: {pretrained}")

    @staticmethod
    def _build_backbone(name, pretrained):
        """Instantiate torchvision model ``name`` with or without ImageNet weights."""
        builder = getattr(models, name)
        # torchvision >= 0.13 deprecates ``pretrained=`` (it warns on every call)
        # in favour of ``weights=``. For the supported models ``pretrained=True``
        # was an alias for the IMAGENET1K_V1 checkpoint, so request it explicitly
        # to load the same weights. The *_Weights enums do not exist before 0.13,
        # where only ``pretrained=`` is accepted.
        if hasattr(models, 'ResNet50_Weights'):
            return builder(weights='IMAGENET1K_V1' if pretrained else None)
        return builder(pretrained=pretrained)

    def forward(self, x):
        """
        Forward pass

        Args:
            x: Input tensor [batch_size, 3, H, W]

        Returns:
            logits: Output tensor [batch_size, num_classes]
        """
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits

    def freeze_backbone(self):
        """Freeze backbone parameters so only the classifier head is trained.

        NOTE: this only sets ``requires_grad=False``. BatchNorm layers inside the
        backbone still update their running statistics while the module is in
        train mode; call ``self.backbone.eval()`` as well if that matters.
        """
        self.backbone.requires_grad_(False)
        print("Backbone frozen")

    def unfreeze_backbone(self):
        """Unfreeze backbone parameters for end-to-end fine-tuning."""
        self.backbone.requires_grad_(True)
        print("Backbone unfrozen")
def create_model(num_classes=10, backbone='resnet50', pretrained=True, dropout=0.3):
    """
    Build a PestDiseaseClassifier from keyword settings.

    Args:
        num_classes (int): Size of the output layer
        backbone (str): Architecture name forwarded to the classifier
        pretrained (bool): Whether the backbone should load pretrained weights
        dropout (float): Dropout probability used in the classifier head

    Returns:
        PestDiseaseClassifier: The newly constructed model
    """
    settings = dict(
        num_classes=num_classes,
        backbone=backbone,
        pretrained=pretrained,
        dropout=dropout,
    )
    return PestDiseaseClassifier(**settings)
def count_parameters(model):
    """Print and return the parameter counts of ``model``.

    A parameter is counted as trainable when its ``requires_grad`` flag is set.

    Returns:
        tuple: ``(total, trainable)`` element counts.
    """
    total = 0
    trainable = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total += size
        if tensor.requires_grad:
            trainable += size
    frozen = total - trainable

    print("\nModel Parameters:")
    print(f" Total: {total:,}")
    print(f" Trainable: {trainable:,}")
    print(f" Non-trainable: {frozen:,}")
    return total, trainable
if __name__ == "__main__":
    # Smoke test: build each backbone and push a dummy batch through it.
    print("Testing Pest and Disease Classification Models")
    print("=" * 60)

    # Test different backbones
    backbones = ['resnet50', 'efficientnet_b0', 'mobilenet_v2']
    for backbone in backbones:
        print(f"\nTesting {backbone}...")
        print("-" * 60)
        model = create_model(num_classes=10, backbone=backbone, pretrained=True)
        count_parameters(model)

        # eval() disables dropout and makes BatchNorm use its running statistics,
        # so the check is deterministic and does not mutate the model;
        # torch.no_grad() alone does neither.
        model.eval()
        dummy_input = torch.randn(2, 3, 224, 224)
        with torch.no_grad():
            output = model(dummy_input)
        print(f" Input shape: {dummy_input.shape}")
        print(f" Output shape: {output.shape}")
        print(f" Output range: [{output.min():.3f}, {output.max():.3f}]")

    print("\n" + "=" * 60)
    print("Model test completed successfully!")
    print("=" * 60)
|