"""Gradio demo app: ResNet-18 classifier fine-tuned on CIFAR-10 with
Adaptive Sparse Training (AST).

Flow: auto-locate a checkpoint in the repo root, load it into a ResNet-18
with a 10-way head, and serve a Blocks UI that shows the exact 32x32 model
input, full class probabilities, and a top-3 table.
"""

import os
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from torchvision.models import resnet18

# -----------------------------
# Config
# -----------------------------
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

# CIFAR-10 normalization (standard per-channel mean/std)
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

EXAMPLES_DIR = Path("Examples")  # example images uploaded to this folder

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# -----------------------------
# Utilities
# -----------------------------
def find_checkpoint(repo_root: Path) -> Optional[Path]:
    """Auto-find a checkpoint file in the repo root.

    Checks a list of common filenames first, then falls back to globbing
    ``*.pth`` / ``*.pt`` and preferring names that look related to this
    project (resnet/cifar/ast/sparse/best).

    Returns the first match, or ``None`` if nothing is found.
    """
    candidates = [
        "model.pth", "model.pt",
        "checkpoint.pth", "checkpoint.pt",
        "best.pth", "best.pt",
        "resnet18.pth", "resnet18.pt",
        "weights.pth", "weights.pt",
    ]
    for name in candidates:
        p = repo_root / name
        if p.exists() and p.is_file():
            return p

    # Try pattern search
    patterns = ["*.pth", "*.pt"]
    for pat in patterns:
        hits = sorted(repo_root.glob(pat))
        # Prefer anything that looks like resnet/cifar/ast
        preferred = [
            h for h in hits
            if any(k in h.name.lower()
                   for k in ["resnet", "cifar", "ast", "sparse", "best"])
        ]
        if preferred:
            return preferred[0]
        if hits:
            return hits[0]
    return None


def build_model(num_classes: int = 10) -> torch.nn.Module:
    """Build a ResNet-18 with a fresh ``num_classes``-way linear head."""
    m = resnet18(weights=None)
    m.fc = torch.nn.Linear(m.fc.in_features, num_classes)
    return m


def load_weights(model: torch.nn.Module, ckpt_path: Path) -> None:
    """Load weights from ``ckpt_path`` into ``model`` (in place).

    Handles common checkpoint formats:
      - a plain state_dict
      - a dict wrapping the state_dict under ``'state_dict'`` or ``'model'``

    Also strips a possible ``'module.'`` prefix left by DDP/DataParallel.
    Loading is non-strict so partially matching checkpoints still load;
    any missing/unexpected keys are printed for inspection.

    Raises:
        ValueError: if the checkpoint is not a dict-like object.
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if isinstance(ckpt, dict):
        if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
            state = ckpt["state_dict"]
        elif "model" in ckpt and isinstance(ckpt["model"], dict):
            state = ckpt["model"]
        else:
            # might already be a state_dict-like dict
            state = ckpt
    else:
        raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")

    # Strip possible 'module.' prefix if trained with DDP/DataParallel
    new_state = {}
    for k, v in state.items():
        nk = k.replace("module.", "")
        new_state[nk] = v

    missing, unexpected = model.load_state_dict(new_state, strict=False)
    # Strict=False to be robust; you can change to strict=True if you prefer.
    if missing or unexpected:
        print("[load_weights] Missing keys:", missing)
        print("[load_weights] Unexpected keys:", unexpected)


# -----------------------------
# Preprocess + Predict
# -----------------------------
preprocess = T.Compose([
    T.Resize((32, 32), interpolation=T.InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])


def pil_to_model_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a normalized [1, 3, 32, 32] model input."""
    img = img.convert("RGB")
    x = preprocess(img).unsqueeze(0)  # [1,3,32,32]
    return x


def predict(img: Optional[Image.Image]):
    """Run inference on ``img``.

    Returns a 4-tuple matching the Gradio outputs:
      (32x32 preview image, {class: prob} dict, top-3 table rows,
       markdown string for the top prediction).

    Raises gr.Error if the model failed to load at startup.
    """
    if img is None:
        # FIX: must return one value per declared output (4), not 3 —
        # otherwise Gradio errors on submit with an empty image.
        return None, None, None, None
    if STATE["model"] is None:
        raise gr.Error(
            "Model is not loaded. "
            "Check that your checkpoint exists in the Space repo."
        )

    # Show exactly what goes into the model (32x32)
    img32 = img.convert("RGB").resize((32, 32), resample=Image.BILINEAR)

    x = pil_to_model_tensor(img).to(DEVICE)
    with torch.inference_mode():
        logits = STATE["model"](x)
        probs = F.softmax(logits, dim=1).squeeze(0)  # [10]

    # Top-3
    topk = torch.topk(probs, k=3)
    top3 = [
        (CIFAR10_CLASSES[i], float(topk.values[j]))
        for j, i in enumerate(topk.indices.tolist())
    ]

    # Gradio Label expects dict label->confidence
    label_dict = {cls: float(probs[i]) for i, cls in enumerate(CIFAR10_CLASSES)}

    # Table for top-3
    top3_table = [[name, f"{p*100:.2f}%"] for name, p in top3]

    # Main prediction text
    pred_name, pred_p = top3[0]
    pred_text = f"**{pred_name}** ({pred_p*100:.2f}%)"

    return img32, label_dict, top3_table, pred_text


# -----------------------------
# App state
# -----------------------------
STATE: Dict[str, Optional[torch.nn.Module]] = {"model": None}


def init() -> None:
    """Locate and load the checkpoint, storing the model in STATE.

    Leaves ``STATE['model']`` as None if no checkpoint is found so the
    UI can still start and report the problem at predict time.
    """
    repo_root = Path(".")
    ckpt = find_checkpoint(repo_root)
    if ckpt is None:
        print("[init] No checkpoint found in repo root.")
        STATE["model"] = None
        return
    print(f"[init] Loading checkpoint: {ckpt}")
    model = build_model(num_classes=len(CIFAR10_CLASSES))
    load_weights(model, ckpt)
    model.to(DEVICE).eval()
    STATE["model"] = model


def get_examples() -> List[List[str]]:
    """Collect example image paths from EXAMPLES_DIR for gr.Examples."""
    if not EXAMPLES_DIR.exists():
        return []
    imgs = sorted(
        p for p in EXAMPLES_DIR.iterdir()
        if p.suffix.lower() in [".png", ".jpg", ".jpeg"]
    )
    # Gradio expects list of lists, each inner list corresponds to inputs
    return [[str(p)] for p in imgs]


init()
EXAMPLES = get_examples()

# -----------------------------
# UI
# -----------------------------
with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
    gr.Markdown(
        "# AST CIFAR-10 Classifier\n"
        "ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.\n\n"
        f"**Device:** `{DEVICE}`"
    )

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="Upload CIFAR-like image")
            # Show the exact 32×32 fed to model (useful for debugging)
            img_32 = gr.Image(type="pil", label="Model input (32×32)")
        with gr.Column(scale=1):
            gr.Markdown("### Top-3 Predictions")
            pred_label = gr.Label(num_top_classes=3, label="Probabilities")
            top3_table = gr.Dataframe(
                headers=["class", "confidence"],
                datatype=["str", "str"],
                row_count=3,
                col_count=(2, "fixed"),
                interactive=False,
                label="Top-3",
            )
            pred_text = gr.Markdown()

    with gr.Row():
        submit = gr.Button("Submit", variant="primary")
        clear = gr.Button("Clear")

    if EXAMPLES:
        gr.Markdown("### Examples (from `Examples/` folder)")
        # NOTE(review): cache_examples=True runs predict at build time; if
        # no checkpoint is present this raises during startup — confirm
        # the Space always ships a checkpoint, or set this to False.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[img_in],
            cache_examples=True,
        )

    submit.click(
        fn=predict,
        inputs=[img_in],
        outputs=[img_32, pred_label, top3_table, pred_text],
    )

    def _clear():
        # FIX: also reset pred_label, which previously kept showing stale
        # probabilities after Clear was pressed.
        return None, None, None, None, ""

    clear.click(
        fn=_clear,
        inputs=[],
        outputs=[img_in, img_32, pred_label, top3_table, pred_text],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()