# Uploaded GitHub repo file (commit 97fcc90)
import torch.nn as nn
def conv_block(cin: int, cout: int, kernel_size: int = 3, padding: int = 1, p_drop: float = 0.1) -> nn.Sequential:
    """
    Build one feature-extractor stage: Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d -> Dropout2d.

    The convolution runs without a bias term since BatchNorm2d immediately
    re-centers its output; the 2x MaxPool halves spatial resolution and the
    spatial dropout regularises whole feature maps.
    """
    stage = [
        nn.Conv2d(cin, cout, kernel_size=kernel_size, padding=padding, bias=False),
        nn.BatchNorm2d(cout),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
        nn.Dropout2d(p_drop),
    ]
    return nn.Sequential(*stage)
class PlantCNN(nn.Module):
    """
    Baseline CNN for the PlantVillage dataset.

    Architecture, kept deliberately simple:
      1. Three conv stages that deepen channels (3 -> 32 -> 64 -> 128) while
         each stage's MaxPool halves spatial resolution.
      2. A global adaptive average pool to a 1x1 map, decoupling the head
         from the input image size.
      3. A two-layer dense classifier head with dropout.
    """

    def __init__(self, num_classes: int = 39, p_drop: float = 0.5):
        super().__init__()
        # Feature extractor: consecutive channel widths paired into stages.
        widths = [3, 32, 64, 128]
        self.features = nn.Sequential(
            *(conv_block(c_in, c_out) for c_in, c_out in zip(widths, widths[1:]))
        )
        # Global average pool -> fixed-size (128, 1, 1) feature vector.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Classifier head: flatten, hidden layer, dropout, output logits.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(p_drop),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Run features -> global pool -> classifier and return class logits."""
        return self.classifier(self.avgpool(self.features(x)))