File size: 2,458 Bytes
745cefc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 | """
EuroSAT Classifier β Gradio demo for Hugging Face Spaces.
Upload a satellite image β get land-use class predictions.
"""
import torch
import gradio as gr
from torchvision import transforms
from huggingface_hub import hf_hub_download
from PIL import Image
from model import SimpleNet, CLASS_NAMES
# ββ Load model ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def load_model():
"""Download weights from HF Hub and load into SimpleNet."""
# TODO: replace with your actual HF repo id after upload
weights_path = hf_hub_download(
repo_id="yava-code/eurosat-simplenet",
filename="simple_net_v1.pth",
)
model = SimpleNet(num_classes=10)
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
model.eval()
return model
model = load_model()
preprocess = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
# ββ Inference βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
def predict(image: Image.Image) -> dict[str, float]:
"""Return class probabilities for a satellite image."""
if image is None:
return {}
tensor = preprocess(image).unsqueeze(0) # [1, 3, 64, 64]
with torch.no_grad():
logits = model(tensor)
probs = torch.nn.functional.softmax(logits, dim=1)[0]
return {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
# ββ Gradio UI βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
demo = gr.Interface(
fn=predict,
inputs=gr.Image(type="pil", label="Satellite Image"),
outputs=gr.Label(num_top_classes=5, label="Predictions"),
title="π°οΈ EuroSAT Land-Use Classifier",
description=(
"Upload a Sentinel-2 satellite image to classify its land-use type. "
"Custom CNN (SimpleNet, ~850K params) trained from scratch on EuroSAT."
),
examples=[], # add example images if you want
theme=gr.themes.Soft(),
)
if __name__ == "__main__":
demo.launch()
|