mgbam committed on
Commit
e81731d
·
verified ·
1 Parent(s): e97ac39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -107
app.py CHANGED
@@ -1,6 +1,5 @@
1
- import os
2
  from pathlib import Path
3
- from typing import Optional, Tuple, Dict, List
4
 
5
  import gradio as gr
6
  import torch
@@ -14,63 +13,28 @@ from torchvision.models import resnet18
14
  # Config
15
  # -----------------------------
16
  CIFAR10_CLASSES = [
17
- "airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"
 
18
  ]
19
 
20
- # CIFAR-10 normalization (standard)
21
  CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
22
  CIFAR10_STD = (0.2470, 0.2435, 0.2616)
23
 
24
- EXAMPLES_DIR = Path("Examples") # you uploaded to this folder
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
 
 
 
27
  # -----------------------------
28
- # Utilities
29
  # -----------------------------
30
def find_checkpoint(repo_root: Path) -> Optional[Path]:
    """Locate a model checkpoint file inside *repo_root*.

    Checks a fixed list of well-known filenames first, then falls back to
    globbing for ``*.pth`` / ``*.pt`` files, preferring names that hint at
    this project (resnet / cifar / ast / sparse / best).

    Returns:
        The first matching path, or ``None`` when nothing is found.
    """
    well_known = (
        "model.pth", "model.pt",
        "checkpoint.pth", "checkpoint.pt",
        "best.pth", "best.pt",
        "resnet18.pth", "resnet18.pt",
        "weights.pth", "weights.pt",
    )
    for filename in well_known:
        candidate = repo_root / filename
        if candidate.exists() and candidate.is_file():
            return candidate

    # Fall back to a glob search, one extension at a time (*.pth before *.pt).
    keywords = ("resnet", "cifar", "ast", "sparse", "best")
    for pattern in ("*.pth", "*.pt"):
        matches = sorted(repo_root.glob(pattern))
        if not matches:
            continue
        # Prefer the first (sorted) name containing a project keyword.
        for match in matches:
            lowered = match.name.lower()
            if any(word in lowered for word in keywords):
                return match
        return matches[0]

    return None
60
-
61
-
62
def build_model(num_classes: int = 10) -> torch.nn.Module:
    """Construct an untrained ResNet-18 whose final layer emits *num_classes* logits."""
    backbone = resnet18(weights=None)
    # Swap the ImageNet head (1000-way) for a CIFAR-sized classifier.
    backbone.fc = torch.nn.Linear(backbone.fc.in_features, num_classes)
    return backbone
66
 
67
-
68
  def load_weights(model: torch.nn.Module, ckpt_path: Path) -> None:
69
- """
70
- Loads common checkpoint formats:
71
- - plain state_dict
72
- - dict with 'state_dict' or 'model' keys
73
- """
74
  ckpt = torch.load(ckpt_path, map_location="cpu")
75
 
76
  if isinstance(ckpt, dict):
@@ -79,26 +43,19 @@ def load_weights(model: torch.nn.Module, ckpt_path: Path) -> None:
79
  elif "model" in ckpt and isinstance(ckpt["model"], dict):
80
  state = ckpt["model"]
81
  else:
82
- # might already be a state_dict-like dict
83
  state = ckpt
84
  else:
85
  raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")
86
 
87
- # Strip possible 'module.' prefix if trained with DDP/DataParallel
88
- new_state = {}
89
- for k, v in state.items():
90
- nk = k.replace("module.", "")
91
- new_state[nk] = v
92
-
93
- missing, unexpected = model.load_state_dict(new_state, strict=False)
94
- # Strict=False to be robust; you can change to strict=True if you prefer.
95
  if missing or unexpected:
96
  print("[load_weights] Missing keys:", missing)
97
  print("[load_weights] Unexpected keys:", unexpected)
98
 
99
-
100
  # -----------------------------
101
- # Preprocess + Predict
102
  # -----------------------------
103
  preprocess = T.Compose([
104
  T.Resize((32, 32), interpolation=T.InterpolationMode.BILINEAR),
@@ -106,75 +63,68 @@ preprocess = T.Compose([
106
  T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
107
  ])
108
 
109
def pil_to_model_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized ``[1, 3, 32, 32]`` batch tensor."""
    rgb = img.convert("RGB")
    # `preprocess` yields [3, 32, 32]; unsqueeze adds the batch axis.
    return preprocess(rgb).unsqueeze(0)
 
 
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
def predict(img: Image.Image):
    """Classify an uploaded image with the CIFAR-10 model held in STATE.

    Returns a 4-tuple matching the UI outputs:
        (32x32 preview image, {class: probability} dict for gr.Label,
         top-3 rows ``[[name, "xx.xx%"], ...]``, markdown headline text).

    Raises:
        gr.Error: if no checkpoint was loaded at startup.
    """
    if img is None:
        # FIX: this early return previously yielded only 3 values while every
        # other path (and the 4-component Gradio output wiring) yields 4.
        return None, None, None, ""

    if STATE["model"] is None:
        raise gr.Error("Model is not loaded. Check that your checkpoint exists in the Space repo.")

    # Show exactly what goes into the model (32x32).
    img32 = img.convert("RGB").resize((32, 32), resample=Image.BILINEAR)

    x = pil_to_model_tensor(img).to(DEVICE)
    with torch.inference_mode():
        logits = STATE["model"](x)
        probs = F.softmax(logits, dim=1).squeeze(0)  # [10]

    # Top-3 (class name, probability) pairs, best first.
    topk = torch.topk(probs, k=3)
    top3 = [(CIFAR10_CLASSES[i], float(topk.values[j])) for j, i in enumerate(topk.indices.tolist())]

    # gr.Label expects a dict mapping label -> confidence.
    label_dict = {cls: float(probs[i]) for i, cls in enumerate(CIFAR10_CLASSES)}

    # Table rows for the top-3 display.
    top3_table = [[name, f"{p*100:.2f}%"] for name, p in top3]

    # Headline prediction text.
    pred_name, pred_p = top3[0]
    pred_text = f"**{pred_name}** ({pred_p*100:.2f}%)"

    return img32, label_dict, top3_table, pred_text
144
 
 
 
145
 
146
  # -----------------------------
147
- # App state
148
  # -----------------------------
149
- STATE: Dict[str, Optional[torch.nn.Module]] = {"model": None}
150
-
151
def init():
    """Discover a checkpoint in the repo root and load it into STATE["model"].

    Leaves STATE["model"] as None when no checkpoint file can be found, so the
    UI can surface a helpful error instead of crashing at import time.
    """
    ckpt = find_checkpoint(Path("."))
    if ckpt is None:
        print("[init] No checkpoint found in repo root.")
        STATE["model"] = None
        return

    print(f"[init] Loading checkpoint: {ckpt}")
    net = build_model(num_classes=len(CIFAR10_CLASSES))
    load_weights(net, ckpt)
    net.to(DEVICE).eval()
    STATE["model"] = net
164
-
165
def get_examples() -> List[List[str]]:
    """Build gr.Examples rows from the Examples/ folder ([] if it is missing).

    Gradio expects one inner list per input component, hence the
    single-element wrapping of each path string.
    """
    if not EXAMPLES_DIR.exists():
        return []
    allowed = {".png", ".jpg", ".jpeg"}
    found = sorted(p for p in EXAMPLES_DIR.iterdir() if p.suffix.lower() in allowed)
    return [[str(p)] for p in found]
171
-
172
  init()
173
  EXAMPLES = get_examples()
174
 
175
- # -----------------------------
176
- # UI
177
- # -----------------------------
178
  with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
179
  gr.Markdown(
180
  "# AST CIFAR-10 Classifier\n"
@@ -185,8 +135,6 @@ with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
185
  with gr.Row():
186
  with gr.Column(scale=1):
187
  img_in = gr.Image(type="pil", label="Upload CIFAR-like image")
188
-
189
- # Show the exact 32×32 fed to model (useful for debugging)
190
  img_32 = gr.Image(type="pil", label="Model input (32×32)")
191
 
192
  with gr.Column(scale=1):
@@ -196,7 +144,7 @@ with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
196
  headers=["class", "confidence"],
197
  datatype=["str", "str"],
198
  row_count=3,
199
- col_count=(2, "fixed"),
200
  interactive=False,
201
  label="Top-3"
202
  )
@@ -206,11 +154,14 @@ with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
206
  submit = gr.Button("Submit", variant="primary")
207
  clear = gr.Button("Clear")
208
 
 
209
  if EXAMPLES:
210
  gr.Markdown("### Examples (from `Examples/` folder)")
211
  gr.Examples(
212
  examples=EXAMPLES,
213
  inputs=[img_in],
 
 
214
  cache_examples=True
215
  )
216
 
@@ -220,9 +171,11 @@ with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
220
  outputs=[img_32, pred_label, top3_table, pred_text]
221
  )
222
 
223
- def _clear():
224
- return None, None, None, ""
225
- clear.click(fn=_clear, inputs=[], outputs=[img_in, img_32, top3_table, pred_text])
 
 
226
 
227
  demo.queue()
228
  if __name__ == "__main__":
 
 
1
  from pathlib import Path
2
+ from typing import Optional, Dict, List
3
 
4
  import gradio as gr
5
  import torch
 
13
  # Config
14
  # -----------------------------
15
  CIFAR10_CLASSES = [
16
+ "airplane", "automobile", "bird", "cat", "deer",
17
+ "dog", "frog", "horse", "ship", "truck"
18
  ]
19
 
 
20
  CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
21
  CIFAR10_STD = (0.2470, 0.2435, 0.2616)
22
 
23
+ EXAMPLES_DIR = Path("Examples")
24
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
 
26
+ # If you know the exact checkpoint name, lock it here:
27
+ CKPT_PATH = Path("ast_cifar10_resnet18.pth")
28
+
29
  # -----------------------------
30
+ # Model helpers
31
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
def build_model(num_classes: int = 10) -> torch.nn.Module:
    """Build a randomly initialized ResNet-18 with a *num_classes*-way head."""
    net = resnet18(weights=None)
    in_features = net.fc.in_features
    # Replace the stock 1000-class head with one sized for this dataset.
    net.fc = torch.nn.Linear(in_features, num_classes)
    return net
36
 
 
37
  def load_weights(model: torch.nn.Module, ckpt_path: Path) -> None:
 
 
 
 
 
38
  ckpt = torch.load(ckpt_path, map_location="cpu")
39
 
40
  if isinstance(ckpt, dict):
 
43
  elif "model" in ckpt and isinstance(ckpt["model"], dict):
44
  state = ckpt["model"]
45
  else:
 
46
  state = ckpt
47
  else:
48
  raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")
49
 
50
+ # Remove "module." if saved from DDP
51
+ cleaned = {k.replace("module.", ""): v for k, v in state.items()}
52
+ missing, unexpected = model.load_state_dict(cleaned, strict=False)
 
 
 
 
 
53
  if missing or unexpected:
54
  print("[load_weights] Missing keys:", missing)
55
  print("[load_weights] Unexpected keys:", unexpected)
56
 
 
57
  # -----------------------------
58
+ # Preprocess
59
  # -----------------------------
60
  preprocess = T.Compose([
61
  T.Resize((32, 32), interpolation=T.InterpolationMode.BILINEAR),
 
63
  T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
64
  ])
65
 
66
+ STATE: Dict[str, Optional[torch.nn.Module]] = {"model": None}
67
+
68
def init():
    """Load CKPT_PATH into STATE["model"]; leave it as None if the file is absent."""
    if not CKPT_PATH.exists():
        print(f"[init] Checkpoint not found: {CKPT_PATH}")
        STATE["model"] = None
        return

    print(f"[init] Loading checkpoint: {CKPT_PATH}")
    net = build_model(num_classes=len(CIFAR10_CLASSES))
    load_weights(net, CKPT_PATH)
    # Move to the selected device and freeze into inference mode.
    net.to(DEVICE).eval()
    STATE["model"] = net
79
+
80
def get_examples(directory: Optional[Path] = None) -> List[List[str]]:
    """Collect example image paths for gr.Examples.

    Args:
        directory: Folder to scan. Defaults to the module-level EXAMPLES_DIR
            (backward-compatible generalization: previously hard-coded).

    Returns:
        A list of single-element lists (one per image path, sorted) — Gradio's
        Examples component expects one inner list per input component.
        Returns [] when the folder does not exist.
    """
    if directory is None:
        directory = EXAMPLES_DIR
    if not directory.exists():
        return []
    # Set membership test instead of a list; same accepted suffixes.
    image_suffixes = {".png", ".jpg", ".jpeg"}
    images = sorted(p for p in directory.iterdir() if p.suffix.lower() in image_suffixes)
    return [[str(p)] for p in images]
85
+
86
+ # -----------------------------
87
+ # Predict
88
+ # -----------------------------
89
def predict(img: Image.Image):
    """Run the CIFAR-10 classifier on an uploaded image.

    Returns a 4-tuple for the UI: (32x32 preview image, {class: prob} dict
    for gr.Label, top-3 table rows, markdown headline text).

    Raises:
        gr.Error: if no checkpoint was loaded at startup.
    """
    if img is None:
        return None, {}, [["", ""], ["", ""], ["", ""]], ""

    if STATE["model"] is None:
        raise gr.Error("Model is not loaded. Ensure ast_cifar10_resnet18.pth exists in the repo root.")

    rgb = img.convert("RGB")
    # Preview: the exact 32x32 bitmap the network will see.
    preview = rgb.resize((32, 32), resample=Image.BILINEAR)

    batch = preprocess(rgb).unsqueeze(0).to(DEVICE)  # [1, 3, 32, 32]
    with torch.inference_mode():
        scores = F.softmax(STATE["model"](batch), dim=1).squeeze(0)

    # gr.Label wants a full label -> confidence mapping.
    confidences = {name: float(scores[i]) for i, name in enumerate(CIFAR10_CLASSES)}

    # Top-3 table rows, best first.
    best = torch.topk(scores, k=3)
    rows = [
        [CIFAR10_CLASSES[i], f"{float(v) * 100:.2f}%"]
        for i, v in zip(best.indices.tolist(), best.values)
    ]

    winner = CIFAR10_CLASSES[int(best.indices[0])]
    headline = f"**{winner}** ({float(best.values[0]) * 100.0:.2f}%)"

    return preview, confidences, rows, headline
118
 
119
def clear_all():
    """Reset every widget: input image, preview, label dict, top-3 table, text."""
    empty_row = ["", ""]
    blank_table = [list(empty_row) for _ in range(3)]
    return None, None, {}, blank_table, ""
121
 
122
  # -----------------------------
123
+ # App
124
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  init()
126
  EXAMPLES = get_examples()
127
 
 
 
 
128
  with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
129
  gr.Markdown(
130
  "# AST CIFAR-10 Classifier\n"
 
135
  with gr.Row():
136
  with gr.Column(scale=1):
137
  img_in = gr.Image(type="pil", label="Upload CIFAR-like image")
 
 
138
  img_32 = gr.Image(type="pil", label="Model input (32×32)")
139
 
140
  with gr.Column(scale=1):
 
144
  headers=["class", "confidence"],
145
  datatype=["str", "str"],
146
  row_count=3,
147
+ column_count=2, # <-- fixed (no deprecated col_count)
148
  interactive=False,
149
  label="Top-3"
150
  )
 
154
  submit = gr.Button("Submit", variant="primary")
155
  clear = gr.Button("Clear")
156
 
157
+ # ✅ FIX: if cache_examples=True, you MUST provide fn and outputs
158
  if EXAMPLES:
159
  gr.Markdown("### Examples (from `Examples/` folder)")
160
  gr.Examples(
161
  examples=EXAMPLES,
162
  inputs=[img_in],
163
+ outputs=[img_32, pred_label, top3_table, pred_text],
164
+ fn=predict,
165
  cache_examples=True
166
  )
167
 
 
171
  outputs=[img_32, pred_label, top3_table, pred_text]
172
  )
173
 
174
+ clear.click(
175
+ fn=clear_all,
176
+ inputs=[],
177
+ outputs=[img_in, img_32, pred_label, top3_table, pred_text]
178
+ )
179
 
180
  demo.queue()
181
  if __name__ == "__main__":