File size: 2,458 Bytes
745cefc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
EuroSAT Classifier — Gradio demo for Hugging Face Spaces.
Upload a satellite image → get land-use class predictions.
"""

import torch
import gradio as gr
from torchvision import transforms
from huggingface_hub import hf_hub_download
from PIL import Image

from model import SimpleNet, CLASS_NAMES


# ── Load model ────────────────────────────────────────────────────────

def load_model(
    repo_id: str = "yava-code/eurosat-simplenet",
    filename: str = "simple_net_v1.pth",
) -> "SimpleNet":
    """Download weights from the HF Hub and load them into a SimpleNet.

    Args:
        repo_id: Hugging Face Hub repository that hosts the checkpoint.
        filename: Name of the checkpoint file inside that repository.

    Returns:
        A ``SimpleNet`` built with 10 output classes, with the downloaded
        state dict loaded on CPU and switched to eval mode.
    """
    # TODO: replace with your actual HF repo id after upload
    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # The checkpoint is fetched from the network and torch.load unpickles it.
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state dict needs, so a tampered file cannot run
    # arbitrary code during loading.
    state_dict = torch.load(
        weights_path,
        map_location="cpu",
        weights_only=True,
    )

    model = SimpleNet(num_classes=10)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Build the network once at import time; predict() reuses this instance.
model = load_model()

# Per-channel statistics handed to Normalize.
# NOTE(review): these equal the widely used ImageNet mean/std — confirm the
# model was trained with the same normalisation.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize to 64x64, convert to a tensor, then normalise each channel.
preprocess = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)


# ── Inference ─────────────────────────────────────────────────────────

def predict(image: Image.Image) -> dict[str, float]:
    """Return class probabilities for a satellite image.

    Args:
        image: PIL image to classify, or ``None`` when no image was supplied.

    Returns:
        A mapping from each name in ``CLASS_NAMES`` to the softmax
        probability the model assigns to it; an empty dict when ``image``
        is ``None``.
    """
    if image is None:
        return {}

    # The preprocessing pipeline normalises with three per-channel statistics,
    # so force a three-channel image; grayscale, palette or RGBA inputs are
    # converted instead of reaching the pipeline with the wrong channel count.
    rgb_image = image.convert("RGB")
    batch = preprocess(rgb_image).unsqueeze(0)  # [1, 3, 64, 64]

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(batch)
        # Softmax over the class dimension, then take the single batch row
        # as plain Python floats.
        scores = torch.nn.functional.softmax(logits, dim=1)[0].tolist()

    # Index lookup (rather than zip) keeps a class-count mismatch loud:
    # more names than model outputs raises IndexError instead of truncating.
    return {name: scores[i] for i, name in enumerate(CLASS_NAMES)}


# ── Gradio UI ─────────────────────────────────────────────────────────

# Single-function Gradio app: one image input, one label output showing the
# five highest-scoring classes returned by predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Satellite Image"),
    outputs=gr.Label(num_top_classes=5, label="Predictions"),
    # Satellite emoji (U+1F6F0 followed by variation selector U+FE0F), written
    # as escapes so the title renders correctly even if the source file is
    # read with the wrong encoding.
    title="\U0001F6F0\uFE0F EuroSAT Land-Use Classifier",
    description=(
        "Upload a Sentinel-2 satellite image to classify its land-use type. "
        "Custom CNN (SimpleNet, ~850K params) trained from scratch on EuroSAT."
    ),
    examples=[],  # add example images if you want
    theme=gr.themes.Soft(),
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()