mgbam committed on
Commit
e97ac39
·
verified ·
1 Parent(s): ce9678d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +219 -44
app.py CHANGED
@@ -1,54 +1,229 @@
 
 
 
 
 
1
  import torch
2
- import torch.nn as nn
3
- from torchvision import models, transforms
4
- from torch.utils.data import DataLoader
5
  from PIL import Image
6
- import gradio as gr
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
 
10
- # Load class names (make sure this file is in the Space)
11
- with open("cifar10_classes.txt") as f:
12
- CLASSES = [line.strip() for line in f.readlines()]
13
-
14
- def build_model(num_classes: int, device: str = "cpu"):
15
- try:
16
- weights = models.ResNet18_Weights.DEFAULT
17
- model = models.resnet18(weights=weights)
18
- except AttributeError:
19
- model = models.resnet18(weights="IMAGENET1K_V1")
20
- model.fc = nn.Linear(model.fc.in_features, num_classes)
21
- model = model.to(device)
22
- return model
23
-
24
- num_classes = len(CLASSES)
25
- model = build_model(num_classes, device=DEVICE)
26
-
27
- state_dict = torch.load("ast_cifar10_resnet18.pth", map_location=DEVICE)
28
- model.load_state_dict(state_dict)
29
- model.eval()
30
-
31
- preprocess = transforms.Compose([
32
- transforms.Resize((224, 224)),
33
- transforms.ToTensor(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  ])
35
 
36
- def predict(image: Image.Image):
37
- if image is None:
38
- return {}
39
- x = preprocess(image).unsqueeze(0).to(DEVICE)
40
- with torch.no_grad():
41
- logits = model(x)
42
- probs = torch.softmax(logits, dim=1)[0]
43
- return {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}
44
-
45
- demo = gr.Interface(
46
- fn=predict,
47
- inputs=gr.Image(type="pil", label="Upload CIFAR-like image"),
48
- outputs=gr.Label(num_top_classes=3, label="Top-3 Predictions"),
49
- title="AST CIFAR-10 Classifier",
50
- description="ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.",
51
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
 
53
  if __name__ == "__main__":
54
  demo.launch()
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import Optional, Tuple, Dict, List
4
+
5
+ import gradio as gr
6
  import torch
7
+ import torch.nn.functional as F
 
 
8
  from PIL import Image
 
9
 
10
+ import torchvision.transforms as T
11
+ from torchvision.models import resnet18
12
+
13
# -----------------------------
# Config
# -----------------------------
# The ten CIFAR-10 class names, index-aligned with the model's output logits.
CIFAR10_CLASSES = [
    "airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"
]

# CIFAR-10 normalization (standard)
# Per-channel RGB mean/std of the CIFAR-10 training set, used by `preprocess`.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

EXAMPLES_DIR = Path("Examples")  # you uploaded to this folder
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
+ # -----------------------------
28
+ # Utilities
29
+ # -----------------------------
30
def find_checkpoint(repo_root: Path) -> Optional[Path]:
    """
    Locate a weights file in the repo root.

    Checks a list of conventional filenames first, then falls back to
    globbing for *.pth / *.pt files, preferring names containing
    resnet/cifar/ast/sparse/best. Returns None when nothing matches.
    """
    known_names = (
        "model.pth", "model.pt",
        "checkpoint.pth", "checkpoint.pt",
        "best.pth", "best.pt",
        "resnet18.pth", "resnet18.pt",
        "weights.pth", "weights.pt",
    )
    for name in known_names:
        candidate = repo_root / name
        if candidate.exists() and candidate.is_file():
            return candidate

    # Pattern search fallback; any *.pth match wins over every *.pt match.
    keywords = ("resnet", "cifar", "ast", "sparse", "best")
    for pattern in ("*.pth", "*.pt"):
        matches = sorted(repo_root.glob(pattern))
        preferred = [m for m in matches if any(k in m.name.lower() for k in keywords)]
        if preferred:
            return preferred[0]
        if matches:
            return matches[0]

    return None
60
+
61
+
62
def build_model(num_classes: int = 10) -> torch.nn.Module:
    """Create a randomly initialized ResNet-18 with a num_classes-way head."""
    net = resnet18(weights=None)
    in_features = net.fc.in_features
    net.fc = torch.nn.Linear(in_features, num_classes)
    return net
66
+
67
+
68
def load_weights(model: torch.nn.Module, ckpt_path: Path) -> None:
    """
    Load model weights from a checkpoint file into `model` (in place).

    Handles common checkpoint formats:
      - a plain state_dict
      - a dict wrapping the state_dict under a 'state_dict' or 'model' key

    Keys saved by DDP/DataParallel have their leading 'module.' stripped.
    Loading is non-strict; any mismatched keys are printed to stdout.

    Raises:
        ValueError: if the checkpoint is not a dict-like object.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")

    if not isinstance(ckpt, dict):
        raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")

    if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
        state = ckpt["state_dict"]
    elif "model" in ckpt and isinstance(ckpt["model"], dict):
        state = ckpt["model"]
    else:
        # might already be a state_dict-like dict
        state = ckpt

    # Strip only a *leading* 'module.' prefix (DDP/DataParallel). The previous
    # str.replace would also mangle keys containing 'module.' mid-name.
    prefix = "module."
    new_state = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state.items()
    }

    missing, unexpected = model.load_state_dict(new_state, strict=False)
    # Strict=False to be robust; surface any mismatch in the Space logs.
    if missing or unexpected:
        print("[load_weights] Missing keys:", missing)
        print("[load_weights] Unexpected keys:", unexpected)
98
+
99
+
100
# -----------------------------
# Preprocess + Predict
# -----------------------------
# Standard CIFAR-10 input pipeline: resize to the 32x32 training resolution,
# convert to a float tensor in [0, 1], then normalize with dataset statistics.
preprocess = T.Compose([
    T.Resize((32, 32), interpolation=T.InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])
108
 
109
def pil_to_model_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized [1, 3, 32, 32] input batch."""
    rgb = img.convert("RGB")
    return preprocess(rgb).unsqueeze(0)
113
+
114
def predict(img: "Image.Image"):
    """
    Run the classifier on one uploaded image.

    Returns a 4-tuple matching the Gradio outputs wired to this function:
      (32x32 preview image, {class: probability} dict, top-3 table rows,
       markdown headline text). All four are None when no image is given.

    Raises:
        gr.Error: if the model failed to load at startup.
    """
    if img is None:
        # Must match the arity of the success path: the click handler has
        # four outputs, so returning only three values breaks the UI here.
        return None, None, None, None

    if STATE["model"] is None:
        raise gr.Error("Model is not loaded. Check that your checkpoint exists in the Space repo.")

    # Show exactly what goes into the model (32x32)
    img32 = img.convert("RGB").resize((32, 32), resample=Image.BILINEAR)

    x = pil_to_model_tensor(img).to(DEVICE)
    with torch.inference_mode():
        logits = STATE["model"](x)
        probs = F.softmax(logits, dim=1).squeeze(0)  # [10]

    # Top-3 class names with their probabilities.
    topk = torch.topk(probs, k=3)
    top3 = [(CIFAR10_CLASSES[i], float(topk.values[j])) for j, i in enumerate(topk.indices.tolist())]

    # Gradio Label expects dict label->confidence
    label_dict = {cls: float(probs[i]) for i, cls in enumerate(CIFAR10_CLASSES)}

    # Table rows for the top-3 dataframe.
    top3_table = [[name, f"{p*100:.2f}%"] for name, p in top3]

    # Headline prediction text.
    pred_name, pred_p = top3[0]
    pred_text = f"**{pred_name}** ({pred_p*100:.2f}%)"

    return img32, label_dict, top3_table, pred_text
144
+
145
+
146
+ # -----------------------------
147
+ # App state
148
+ # -----------------------------
149
# Module-level mutable state: holds the loaded model, or None if init() failed.
STATE: Dict[str, Optional[torch.nn.Module]] = {"model": None}
150
+
151
def init():
    """Find and load the checkpoint into STATE['model'] (None on failure)."""
    ckpt = find_checkpoint(Path("."))
    if ckpt is None:
        print("[init] No checkpoint found in repo root.")
        STATE["model"] = None
        return

    print(f"[init] Loading checkpoint: {ckpt}")
    net = build_model(num_classes=len(CIFAR10_CLASSES))
    load_weights(net, ckpt)
    net.to(DEVICE).eval()
    STATE["model"] = net
164
+
165
def get_examples() -> List[List[str]]:
    """Collect example image paths from EXAMPLES_DIR, formatted for gr.Examples."""
    if not EXAMPLES_DIR.exists():
        return []
    allowed = [".png", ".jpg", ".jpeg"]
    paths = sorted(p for p in EXAMPLES_DIR.iterdir() if p.suffix.lower() in allowed)
    # gr.Examples wants one inner list per input component.
    return [[str(p)] for p in paths]
171
+
172
# Load the model and gather example images once, at import time.
init()
EXAMPLES = get_examples()
174
+
175
# -----------------------------
# UI
# -----------------------------
# Gradio Blocks layout: image upload on the left (plus a preview of the exact
# 32x32 input fed to the model), predictions on the right, buttons below.
with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
    gr.Markdown(
        "# AST CIFAR-10 Classifier\n"
        "ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.\n\n"
        f"**Device:** `{DEVICE}`"
    )

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="Upload CIFAR-like image")

            # Show the exact 32×32 fed to model (useful for debugging)
            img_32 = gr.Image(type="pil", label="Model input (32×32)")

        with gr.Column(scale=1):
            gr.Markdown("### Top-3 Predictions")
            # Full per-class probability dict; Label renders the top 3.
            pred_label = gr.Label(num_top_classes=3, label="Probabilities")
            top3_table = gr.Dataframe(
                headers=["class", "confidence"],
                datatype=["str", "str"],
                row_count=3,
                col_count=(2, "fixed"),
                interactive=False,
                label="Top-3"
            )
            # Headline prediction, rendered from markdown built in predict().
            pred_text = gr.Markdown()

    with gr.Row():
        submit = gr.Button("Submit", variant="primary")
        clear = gr.Button("Clear")

    if EXAMPLES:
        gr.Markdown("### Examples (from `Examples/` folder)")
        # NOTE(review): cache_examples=True without fn/outputs may be ignored
        # or rejected depending on the Gradio version — confirm on the Space.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[img_in],
            cache_examples=True
        )

    # predict must return one value per output component listed here.
    submit.click(
        fn=predict,
        inputs=[img_in],
        outputs=[img_32, pred_label, top3_table, pred_text]
    )

    def _clear():
        # One value per output below. NOTE(review): pred_label is not among
        # the outputs, so the probability display is not reset — confirm intent.
        return None, None, None, ""
    clear.click(fn=_clear, inputs=[], outputs=[img_in, img_32, top3_table, pred_text])

# Enable request queuing so concurrent inferences are serialized.
demo.queue()
if __name__ == "__main__":
    demo.launch()