# Source: Hugging Face repository "eurostat", file model.py (author: yava-code)
# Commit 3a20500 (verified): "Upload model.py with huggingface_hub"
"""
SimpleNet β€” lightweight CNN for EuroSAT satellite image classification.
4 convolutional blocks (double-and-halve pattern) + FC classifier.
Input: 3Γ—64Γ—64 RGB | Output: 10 land-use classes | ~850K parameters
"""
import torch
import torch.nn as nn
# The ten EuroSAT land-use class names, listed in alphabetical order.
# NOTE(review): presumably position i names SimpleNet's output logit i
# (consistent with an alphabetical, ImageFolder-style label encoding) —
# confirm against the training pipeline.
CLASS_NAMES = [
    "AnnualCrop",
    "Forest",
    "HerbaceousVegetation",
    "Highway",
    "Industrial",
    "Pasture",
    "PermanentCrop",
    "Residential",
    "River",
    "SeaLake",
]
class SimpleNet(nn.Module):
    """Small four-stage CNN classifier for 64×64 EuroSAT image tiles.

    Each of ``block1``..``block4`` is Conv3×3 (padding 1) → BatchNorm → ReLU →
    MaxPool(2). Channels go in_channels → 32 → 64 → 128 → 256 while the
    spatial size halves 64 → 32 → 16 → 8 → 4. The resulting 256×4×4 feature
    map is flattened and passed through a two-layer MLP head.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of input image channels (3 for RGB; can be raised,
            e.g. for multispectral inputs).
        dropout: Dropout probability applied before the final Linear layer.

    Note:
        The first Linear layer expects exactly 256 * 4 * 4 features. The input
        spatial size must therefore shrink to 4×4 after four 2× max-pools
        (e.g. 64×64). The attribute names and the layer order inside each
        ``nn.Sequential`` define the ``state_dict`` keys (``block1.0.weight``,
        ``classifier.1.weight``, ...), so keep them stable for saved
        checkpoints to load.
    """

    def __init__(
        self,
        num_classes: int = 10,
        *,
        in_channels: int = 3,
        dropout: float = 0.3,
    ):
        super().__init__()
        self.block1 = self._conv_block(in_channels, 32)  # 64×64 → 32×32
        self.block2 = self._conv_block(32, 64)  # 32×32 → 16×16
        self.block3 = self._conv_block(64, 128)  # 16×16 → 8×8
        self.block4 = self._conv_block(128, 256)  # 8×8 → 4×4
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # 256 channels × 4×4 spatial map produced by block4.
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    @staticmethod
    def _conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
        """Build Conv3×3 → BatchNorm → ReLU → MaxPool(2), which halves H and W."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores (no softmax is applied).

        Args:
            x: Image batch, assumed shaped (N, in_channels, 64, 64) —
                see the class note on the spatial-size constraint.

        Returns:
            Tensor of shape (N, num_classes).
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return self.classifier(x)