Islam Mamedov committed on
Commit
b7265bc
·
1 Parent(s): 0fba6d4

Minimal Space: requirements, model, app, placeholder classes

Browse files
Files changed (4) hide show
  1. app.py +65 -0
  2. classes.txt +2 -0
  3. model.py +9 -0
  4. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from pathlib import Path
3
+ import gradio as gr
4
+ import torch, torch.nn.functional as F
5
+ from PIL import Image
6
+ from torchvision import transforms
7
+ from model import build_model
8
+
9
# Paths are resolved relative to this file so the Space works from any CWD.
ROOT = Path(__file__).resolve().parent
CKPT = ROOT / "ckpt_final320" / "best.pt"
CLASSES_TXT = ROOT / "classes.txt"

# Standard ImageNet normalization statistics; inference resolution is 320 px.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
RES = 320
16
+
17
+ def _load_classes(p: Path):
18
+ return [ln.strip() for ln in p.read_text().splitlines() if ln.strip()] if p.exists() else []
19
+
20
# Module state: CLASSES is read eagerly so labels are available at import;
# MODEL stays None until the first prediction triggers a lazy load.
MODEL = None
CLASSES = _load_classes(CLASSES_TXT)
22
+
23
def _maybe_load_model():
    """Lazily build the global MODEL on first call (idempotent afterwards).

    Re-reads the class list, constructs the network, and — when a checkpoint
    file is present — loads its weights. Without a checkpoint the model keeps
    random weights so the UI still works in demo mode.
    """
    global MODEL, CLASSES
    if MODEL is not None:
        return  # already initialized

    # Refresh labels; fall back to two placeholder names so the head has a
    # valid (non-zero) output dimension even without classes.txt.
    names = _load_classes(CLASSES_TXT)
    CLASSES = names if names else ["class_0", "class_1"]

    net = build_model(len(CLASSES), pretrained=False)
    if CKPT.exists():
        # NOTE(review): torch.load unpickles arbitrary objects — fine for a
        # checkpoint you control, but do not point CKPT at untrusted files.
        state = torch.load(CKPT, map_location="cpu")
        # Training checkpoints may wrap the weights under a "model" key.
        if isinstance(state, dict) and "model" in state:
            state = state["model"]
        # strict=False tolerates head-size or key mismatches (demo-friendly).
        net.load_state_dict(state, strict=False)
    net.eval()
    MODEL = net
35
+
36
# Validation-style preprocessing: resize the shorter side, center-crop to the
# inference resolution, then normalize with ImageNet statistics.
_RESIZE = int(RES * 256 / 224)  # classic 256→224 resize-to-crop ratio, scaled to RES
TFM = transforms.Compose(
    [
        transforms.Resize(_RESIZE),
        transforms.CenterCrop(RES),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)
42
+
43
def predict(img: Image.Image):
    """Classify an uploaded image and return a Markdown result string.

    Loads the model on first use. Returns a prompt string when no image is
    supplied; otherwise the top-1 class name with its confidence, plus a
    demo-mode note when no checkpoint file is present.
    """
    _maybe_load_model()
    if img is None:
        return "Please upload an image."

    batch = TFM(img.convert("RGB")).unsqueeze(0)
    with torch.inference_mode():
        probs = F.softmax(MODEL(batch), dim=1)[0]

    top = int(torch.argmax(probs).item())
    conf = float(probs[top].item())
    # Guard against a checkpoint head wider than the label list.
    if top < len(CLASSES):
        name = CLASSES[top]
    else:
        name = f"class_{top}"

    if CKPT.exists():
        note = ""
    else:
        note = "_Running without checkpoint (demo mode). Upload ckpt_final320/best.pt to get real predictions._"
    return f"**Prediction:** {name} — **{conf*100:.2f}%**\n\n{note}"
56
+
57
# --- Gradio UI: single image input, one button, Markdown output. -----------
with gr.Blocks(title="Bird Classifier (Minimal)") as demo:
    gr.Markdown("### Bird Classifier — Minimal Space\nUpload an image and predict. (Works even without checkpoint.)")
    image_input = gr.Image(type="pil", label="Image")
    run_button = gr.Button("Predict", variant="primary")
    result_box = gr.Markdown()
    run_button.click(predict, inputs=[image_input], outputs=[result_box], show_progress="full")
63
+
64
if __name__ == "__main__":
    # Launch the Gradio server when run as a script (Spaces executes app.py).
    demo.launch()
classes.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ bird_a
2
+ bird_b
model.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import torch.nn as nn
3
+ from torchvision.models import resnet50
4
+
5
def build_model(num_classes: int, pretrained: bool = False) -> nn.Module:
    """Build a ResNet-50 with its classifier head resized to *num_classes*.

    Args:
        num_classes: Output dimension of the final fully-connected layer.
        pretrained: When True, initialize the backbone with the ImageNet
            weights the legacy ``pretrained=True`` flag used to load.

    Returns:
        The model in its default (train) mode; callers should ``.eval()``
        it before inference.
    """
    # torchvision deprecated the boolean ``pretrained`` kwarg in 0.13 and
    # removed it in 0.15; with unpinned torchvision in requirements.txt the
    # old call crashes. IMAGENET1K_V1 matches the legacy pretrained weights.
    from torchvision.models import ResNet50_Weights  # local: keeps module imports unchanged

    m = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1 if pretrained else None)
    m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==4.44.0
2
+ torch
3
+ torchvision
4
+ numpy
5
+ pillow
6
+ matplotlib
7
+ scikit-learn