# AST CIFAR-10 Classifier — Gradio app (Hugging Face Space)
from pathlib import Path
from typing import Optional, Dict, List

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as T
from torchvision.models import resnet18

# -----------------------------
# Config
# -----------------------------
# CIFAR-10 class names in canonical label order (index == label id).
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]
# Per-channel mean/std of the CIFAR-10 training set, used to normalize inputs.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
# Folder scanned for demo example images shown in the UI.
EXAMPLES_DIR = Path("Examples")
# Run on GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# If you know the exact checkpoint name, lock it here:
CKPT_PATH = Path("ast_cifar10_resnet18.pth")
# -----------------------------
# Model helpers
# -----------------------------
def build_model(num_classes: int = 10) -> torch.nn.Module:
    """Construct a ResNet-18 whose classifier head outputs `num_classes` logits.

    No pretrained weights are loaded here; see `load_weights` for that.
    """
    backbone = resnet18(weights=None)
    in_feats = backbone.fc.in_features
    backbone.fc = torch.nn.Linear(in_feats, num_classes)
    return backbone
def load_weights(model: torch.nn.Module, ckpt_path: Path) -> None:
    """Load a checkpoint into `model` in place (non-strict).

    Accepts either a raw state dict or a dict that wraps one under
    "state_dict" or "model" (common trainer export formats).

    Args:
        model: Module to receive the weights.
        ckpt_path: Path to a torch checkpoint file.

    Raises:
        ValueError: If the loaded checkpoint is not a dict-like object.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if isinstance(ckpt, dict):
        if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
            state = ckpt["state_dict"]
        elif "model" in ckpt and isinstance(ckpt["model"], dict):
            state = ckpt["model"]
        else:
            state = ckpt
    else:
        raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")
    # Strip only a *leading* "module." (added by DataParallel/DDP wrappers).
    # The previous `k.replace("module.", "")` also mangled keys that merely
    # contain the substring, e.g. "encoder.module.weight" -> "encoder.weight".
    cleaned = {k.removeprefix("module."): v for k, v in state.items()}
    # strict=False: tolerate head/backbone mismatches but report them below.
    missing, unexpected = model.load_state_dict(cleaned, strict=False)
    if missing or unexpected:
        print("[load_weights] Missing keys:", missing)
        print("[load_weights] Unexpected keys:", unexpected)
# -----------------------------
# Preprocess
# -----------------------------
# Inference-time pipeline matching CIFAR-10 normalization:
# resize any input to 32x32, convert to a [0,1] tensor, then normalize
# with the dataset mean/std defined above.
preprocess = T.Compose([
    T.Resize((32, 32), interpolation=T.InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])
# Mutable holder for the lazily loaded network (stays None if loading fails).
STATE: Dict[str, Optional[torch.nn.Module]] = {"model": None}

def init():
    """Build the model, load the checkpoint into STATE["model"], and eval it.

    Leaves STATE["model"] as None when the checkpoint file is absent so the
    UI can still render and surface a helpful error at predict time.
    """
    if not CKPT_PATH.exists():
        print(f"[init] Checkpoint not found: {CKPT_PATH}")
        STATE["model"] = None
        return
    print(f"[init] Loading checkpoint: {CKPT_PATH}")
    net = build_model(num_classes=len(CIFAR10_CLASSES))
    load_weights(net, CKPT_PATH)
    net.to(DEVICE).eval()
    STATE["model"] = net
def get_examples() -> List[List[str]]:
    """Return example image paths as single-column rows for `gr.Examples`."""
    if not EXAMPLES_DIR.exists():
        return []
    allowed = {".png", ".jpg", ".jpeg"}
    candidates = (p for p in EXAMPLES_DIR.iterdir() if p.suffix.lower() in allowed)
    return [[str(p)] for p in sorted(candidates)]
# -----------------------------
# Predict
# -----------------------------
def predict(img: Image.Image):
    """Classify one image.

    Returns a 4-tuple for the UI: the 32x32 image actually fed to the model,
    a {class: probability} dict for `gr.Label`, top-3 rows for the table,
    and a markdown string with the best class and confidence.
    """
    if img is None:
        return None, {}, [["", ""], ["", ""], ["", ""]], ""
    if STATE["model"] is None:
        raise gr.Error("Model is not loaded. Ensure ast_cifar10_resnet18.pth exists in the repo root.")
    rgb = img.convert("RGB")
    # Preview at the exact resolution the network consumes.
    img32 = rgb.resize((32, 32), resample=Image.BILINEAR)
    batch = preprocess(rgb).unsqueeze(0).to(DEVICE)  # [1,3,32,32]
    with torch.inference_mode():
        logits = STATE["model"](batch)
    probs = F.softmax(logits, dim=1).squeeze(0)  # [10]
    # Probability per class name, for gr.Label.
    label_dict = {name: float(probs[i]) for i, name in enumerate(CIFAR10_CLASSES)}
    # Top-3 [class, "xx.xx%"] rows for the table.
    top3 = torch.topk(probs, k=3)
    top3_rows = [
        [CIFAR10_CLASSES[cls_idx], f"{float(score) * 100:.2f}%"]
        for cls_idx, score in zip(top3.indices.tolist(), top3.values)
    ]
    best_name = CIFAR10_CLASSES[int(top3.indices[0])]
    best_conf = float(top3.values[0]) * 100.0
    pred_text = f"**{best_name}** ({best_conf:.2f}%)"
    return img32, label_dict, top3_rows, pred_text
def clear_all():
    """Reset all five UI slots: input image, 32x32 preview, label, table, markdown."""
    blank_rows = [["", ""] for _ in range(3)]
    return None, None, {}, blank_rows, ""
# -----------------------------
# App
# -----------------------------
# Load the model once at import time; the UI still renders if it fails,
# and predict() raises a user-visible gr.Error instead.
init()
EXAMPLES = get_examples()

with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
    gr.Markdown(
        "# AST CIFAR-10 Classifier\n"
        "ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.\n\n"
        f"**Device:** `{DEVICE}`"
    )
    with gr.Row():
        # Left column: raw upload plus the 32x32 image actually fed to the net.
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="Upload CIFAR-like image")
            img_32 = gr.Image(type="pil", label="Model input (32×32)")
        # Right column: probabilities, top-3 table, and a markdown verdict.
        with gr.Column(scale=1):
            gr.Markdown("### Top-3 Predictions")
            pred_label = gr.Label(num_top_classes=3, label="Probabilities")
            top3_table = gr.Dataframe(
                headers=["class", "confidence"],
                datatype=["str", "str"],
                row_count=3,
                column_count=2,  # <-- fixed (no deprecated col_count)
                interactive=False,
                label="Top-3"
            )
            pred_text = gr.Markdown()
    with gr.Row():
        submit = gr.Button("Submit", variant="primary")
        clear = gr.Button("Clear")
    # ✅ FIX: if cache_examples=True, you MUST provide fn and outputs
    if EXAMPLES:
        gr.Markdown("### Examples (from `Examples/` folder)")
        gr.Examples(
            examples=EXAMPLES,
            inputs=[img_in],
            outputs=[img_32, pred_label, top3_table, pred_text],
            fn=predict,
            cache_examples=True
        )
    # Wire the buttons: Submit runs predict; Clear blanks every output slot.
    submit.click(
        fn=predict,
        inputs=[img_in],
        outputs=[img_32, pred_label, top3_table, pred_text]
    )
    clear.click(
        fn=clear_all,
        inputs=[],
        outputs=[img_in, img_32, pred_label, top3_table, pred_text]
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()