"""
SimpleNet — lightweight CNN for EuroSAT satellite image classification.
4 convolutional blocks (double-and-halve pattern: channels double while height/width halve) + FC classifier.
Input: 3×64×64 RGB | Output: 10 land-use classes | ~2.49M parameters
"""
import torch
import torch.nn as nn
# EuroSAT land-use class labels, one per entry (10 in total).
# NOTE(review): presumably ordered to match the classifier's output indices
# (index i <-> logit i); confirm against the training label encoding.
CLASS_NAMES = [
    "AnnualCrop",
    "Forest",
    "HerbaceousVegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "PermanentCrop",
    "Residential",
    "River",
    "SeaLake",
]
class SimpleNet(nn.Module):
    """Four conv blocks (channels 32→64→128→256, each halving H and W) + 2-layer FC head.

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input image (3 for RGB). Keyword-only.
        input_size: Height/width of the square input image. The first Linear
            layer is sized for this value, so inputs of a different spatial
            size fail at the classifier. Must be >= 16. Keyword-only.
        dropout: Dropout probability applied before the final Linear layer.
            Keyword-only.

    Raises:
        ValueError: If ``input_size`` is smaller than 16, which would leave no
            spatial positions after the four 2× max-pools.

    With the defaults, the layers, parameter names and state_dict keys are
    identical to the fixed 3×64×64 / 10-class model (2,492,170 parameters).
    """

    # Each block halves H and W with a 2×2 max-pool; four blocks → ÷16 overall.
    _DOWNSAMPLE = 2 ** 4

    def __init__(
        self,
        num_classes: int = 10,
        *,
        in_channels: int = 3,
        input_size: int = 64,
        dropout: float = 0.3,
    ):
        super().__init__()
        # nn.MaxPool2d(2) floors odd sizes; four successive floor-halvings are
        # equivalent to a single floor division by 16.
        feat_hw = input_size // self._DOWNSAMPLE
        if feat_hw < 1:
            raise ValueError(
                f"input_size must be >= {self._DOWNSAMPLE}, got {input_size}"
            )

        # Built in the original order (block1..block4, then classifier) so
        # that, for a given RNG seed, weight initialisation is unchanged.
        self.block1 = self._conv_block(in_channels, 32)  # 64×64 → 32×32 (defaults)
        self.block2 = self._conv_block(32, 64)           # 32×32 → 16×16
        self.block3 = self._conv_block(64, 128)          # 16×16 → 8×8
        self.block4 = self._conv_block(128, 256)         # 8×8 → 4×4

        # Sequential indices kept as before (1 = first Linear, 4 = output
        # Linear) so existing checkpoints' state_dict keys still match.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * feat_hw * feat_hw, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
        """Return Conv3×3 (padding=1, keeps H/W) → BatchNorm → ReLU → 2×2 max-pool."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map images (N, in_channels, input_size, input_size) to logits (N, num_classes)."""
        for block in (self.block1, self.block2, self.block3, self.block4):
            x = block(x)
        return self.classifier(x)