from pathlib import Path
from typing import Optional, Dict, List

import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as T
from torchvision.models import resnet18

# -----------------------------
# Config
# -----------------------------
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
# Per-channel normalization applied in `preprocess`; it should match the
# normalization used when the checkpoint was trained.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

EXAMPLES_DIR = Path("Examples")
EXAMPLE_SUFFIXES = {".png", ".jpg", ".jpeg"}
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# If you know the exact checkpoint name, lock it here:
CKPT_PATH = Path("ast_cifar10_resnet18.pth")

# Number of top predictions shown in the label widget and the table.
TOP_K = 3

# Key prefix added by DataParallel / DistributedDataParallel wrappers.
_DDP_PREFIX = "module."


# -----------------------------
# Model helpers
# -----------------------------
def build_model(num_classes: int = 10) -> torch.nn.Module:
    """Return an untrained ResNet-18 whose final layer emits `num_classes` logits."""
    m = resnet18(weights=None)
    m.fc = torch.nn.Linear(m.fc.in_features, num_classes)
    return m


def load_weights(model: torch.nn.Module, ckpt_path: Path) -> None:
    """Load the checkpoint at `ckpt_path` into `model` in place.

    The checkpoint must be a dict: either a bare state dict, or a dict that
    wraps one under the "state_dict" or "model" key (checked in that order).
    A leading "module." is stripped from every key. Loading is non-strict:
    missing/unexpected keys are printed instead of raising.

    Raises:
        ValueError: if the loaded checkpoint is not a dict.
    """
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(ckpt, dict):
        raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")

    state = ckpt
    for wrapper_key in ("state_dict", "model"):
        if isinstance(ckpt.get(wrapper_key), dict):
            state = ckpt[wrapper_key]
            break

    # Remove a *leading* "module." if saved from DDP; str.replace would also
    # mangle any key containing "module." somewhere in the middle.
    cleaned = {
        (k[len(_DDP_PREFIX):] if k.startswith(_DDP_PREFIX) else k): v
        for k, v in state.items()
    }
    missing, unexpected = model.load_state_dict(cleaned, strict=False)
    if missing or unexpected:
        print("[load_weights] Missing keys:", missing)
        print("[load_weights] Unexpected keys:", unexpected)


# -----------------------------
# Preprocess
# -----------------------------
preprocess = T.Compose([
    T.Resize((32, 32), interpolation=T.InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Holds the loaded model, or None when the checkpoint is unavailable.
STATE: Dict[str, Optional[torch.nn.Module]] = {"model": None}


def init() -> None:
    """Build the model, load CKPT_PATH onto DEVICE and store it in STATE.

    If the checkpoint file does not exist, a message is printed and
    STATE["model"] is left as None so the UI can still start.
    """
    if not CKPT_PATH.exists():
        print(f"[init] Checkpoint not found: {CKPT_PATH}")
        STATE["model"] = None
        return

    print(f"[init] Loading checkpoint: {CKPT_PATH}")
    model = build_model(num_classes=len(CIFAR10_CLASSES))
    load_weights(model, CKPT_PATH)
    model.to(DEVICE).eval()
    STATE["model"] = model


def get_examples() -> List[List[str]]:
    """Return sorted example image paths from EXAMPLES_DIR, one path per row.

    Returns an empty list when EXAMPLES_DIR is not a directory.
    """
    if not EXAMPLES_DIR.is_dir():
        return []
    imgs = sorted(
        p for p in EXAMPLES_DIR.iterdir()
        if p.is_file() and p.suffix.lower() in EXAMPLE_SUFFIXES
    )
    return [[str(p)] for p in imgs]


# -----------------------------
# Predict
# -----------------------------
def _empty_top3() -> List[List[str]]:
    """Return a fresh placeholder table with TOP_K blank (class, confidence) rows."""
    return [["", ""] for _ in range(TOP_K)]


def predict(img: Optional[Image.Image]):
    """Classify `img` with the loaded model.

    Returns:
        A 4-tuple (img32, label_dict, top_rows, pred_text):
        - img32: the 32x32 RGB PIL image fed to the model (for preview),
        - label_dict: {class_name: probability} for every CIFAR-10 class,
        - top_rows: TOP_K rows of [class_name, "xx.xx%"], best first,
        - pred_text: Markdown string naming the top class and its confidence.
        When `img` is None, returns empty placeholders instead.

    Raises:
        gr.Error: if no model is loaded.
    """
    if img is None:
        return None, {}, _empty_top3(), ""
    model = STATE["model"]
    if model is None:
        raise gr.Error(
            f"Model is not loaded. Ensure {CKPT_PATH} exists in the repo root."
        )

    # Downscale once and push that same 32x32 image through `preprocess`, so
    # the preview shown in the UI is exactly what the model receives.
    img32 = img.convert("RGB").resize((32, 32), resample=Image.BILINEAR)
    x = preprocess(img32).unsqueeze(0).to(DEVICE)  # [1, 3, 32, 32]

    with torch.inference_mode():
        logits = model(x)
    # Move to CPU once so the per-class float conversions below don't each
    # force a device synchronization when running on CUDA.
    probs = F.softmax(logits, dim=1).squeeze(0).cpu()  # [num_classes]

    # label dict for gr.Label
    label_dict = {cls: p for cls, p in zip(CIFAR10_CLASSES, probs.tolist())}

    # top-k table, best first
    top_vals, top_idx = torch.topk(probs, k=TOP_K)
    top_pairs = list(zip(top_idx.tolist(), top_vals.tolist()))
    top3_rows = [[CIFAR10_CLASSES[i], f"{v * 100:.2f}%"] for i, v in top_pairs]

    best_idx, best_prob = top_pairs[0]
    pred_text = f"**{CIFAR10_CLASSES[best_idx]}** ({best_prob * 100.0:.2f}%)"
    return img32, label_dict, top3_rows, pred_text


def clear_all():
    """Reset input, preview, label, table and text components to empty values."""
    return None, None, {}, _empty_top3(), ""


# -----------------------------
# App
# -----------------------------
init()
EXAMPLES = get_examples()
MODEL_READY = STATE["model"] is not None

with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
    gr.Markdown(
        "# AST CIFAR-10 Classifier\n"
        "ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.\n\n"
        f"**Device:** `{DEVICE}`"
    )

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="Upload CIFAR-like image")
            img_32 = gr.Image(type="pil", label="Model input (32×32)")

        with gr.Column(scale=1):
            gr.Markdown("### Top-3 Predictions")
            pred_label = gr.Label(num_top_classes=TOP_K, label="Probabilities")
            top3_table = gr.Dataframe(
                headers=["class", "confidence"],
                datatype=["str", "str"],
                row_count=TOP_K,
                column_count=2,  # used instead of the deprecated `col_count`
                interactive=False,
                label="Top-3",
            )
            pred_text = gr.Markdown()

    # Components written by `predict`, in its return order.
    prediction_outputs = [img_32, pred_label, top3_table, pred_text]

    with gr.Row():
        submit = gr.Button("Submit", variant="primary")
        clear = gr.Button("Clear")

    if EXAMPLES:
        gr.Markdown("### Examples (from `Examples/` folder)")
        # Caching runs `predict` on every example up front, which requires
        # `fn` and `outputs`. Without a loaded model `predict` raises
        # gr.Error, so only cache when the checkpoint was actually loaded.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[img_in],
            outputs=prediction_outputs,
            fn=predict,
            cache_examples=MODEL_READY,
        )

    submit.click(
        fn=predict,
        inputs=[img_in],
        outputs=prediction_outputs,
    )

    clear.click(
        fn=clear_all,
        inputs=[],
        outputs=[img_in, *prediction_outputs],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()