yava-code commited on
Commit
745cefc
Β·
verified Β·
1 Parent(s): 0d7ca47

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ EuroSAT Classifier β€” Gradio demo for Hugging Face Spaces.
3
+ Upload a satellite image β†’ get land-use class predictions.
4
+ """
5
+
6
+ import torch
7
+ import gradio as gr
8
+ from torchvision import transforms
9
+ from huggingface_hub import hf_hub_download
10
+ from PIL import Image
11
+
12
+ from model import SimpleNet, CLASS_NAMES
13
+
14
+
15
+ # ── Load model ────────────────────────────────────────────────────────
16
+
17
+ def load_model():
18
+ """Download weights from HF Hub and load into SimpleNet."""
19
+ # TODO: replace with your actual HF repo id after upload
20
+ weights_path = hf_hub_download(
21
+ repo_id="yava-code/eurosat-simplenet",
22
+ filename="simple_net_v1.pth",
23
+ )
24
+ model = SimpleNet(num_classes=10)
25
+ model.load_state_dict(torch.load(weights_path, map_location="cpu"))
26
+ model.eval()
27
+ return model
28
+
29
+
30
+ model = load_model()
31
+
32
+ preprocess = transforms.Compose([
33
+ transforms.Resize((64, 64)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
36
+ std=[0.229, 0.224, 0.225]),
37
+ ])
38
+
39
+
40
+ # ── Inference ─────────────────────────────────────────────────────────
41
+
42
+ def predict(image: Image.Image) -> dict[str, float]:
43
+ """Return class probabilities for a satellite image."""
44
+ if image is None:
45
+ return {}
46
+
47
+ tensor = preprocess(image).unsqueeze(0) # [1, 3, 64, 64]
48
+
49
+ with torch.no_grad():
50
+ logits = model(tensor)
51
+ probs = torch.nn.functional.softmax(logits, dim=1)[0]
52
+
53
+ return {CLASS_NAMES[i]: float(probs[i]) for i in range(len(CLASS_NAMES))}
54
+
55
+
56
+ # ── Gradio UI ─────────────────────────────────────────────────────────
57
+
58
+ demo = gr.Interface(
59
+ fn=predict,
60
+ inputs=gr.Image(type="pil", label="Satellite Image"),
61
+ outputs=gr.Label(num_top_classes=5, label="Predictions"),
62
+ title="πŸ›°οΈ EuroSAT Land-Use Classifier",
63
+ description=(
64
+ "Upload a Sentinel-2 satellite image to classify its land-use type. "
65
+ "Custom CNN (SimpleNet, ~850K params) trained from scratch on EuroSAT."
66
+ ),
67
+ examples=[], # add example images if you want
68
+ theme=gr.themes.Soft(),
69
+ )
70
+
71
+ if __name__ == "__main__":
72
+ demo.launch()