# Provenance (copied from the file-viewer page this file was saved from):
# author: mgbam; commit e81731d "Update app.py" (verified); size 5.75 kB.
from pathlib import Path
from typing import Optional, Dict, List
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms as T
from torchvision.models import resnet18
# -----------------------------
# Config
# -----------------------------
# Class names in logit-index order: predict() maps output index i to
# CIFAR10_CLASSES[i].
CIFAR10_CLASSES = "airplane automobile bird cat deer dog frog horse ship truck".split()

# Per-channel (R, G, B) normalization statistics consumed by T.Normalize.
# NOTE(review): these must match the values used at training time — confirm.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Folder scanned for clickable example images.
EXAMPLES_DIR = Path("Examples")

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Checkpoint file loaded at startup; change this if the weights file is renamed.
CKPT_PATH = Path("ast_cifar10_resnet18.pth")
# -----------------------------
# Model helpers
# -----------------------------
def build_model(num_classes: int = 10) -> torch.nn.Module:
    """Create an untrained ResNet18 whose classification head has ``num_classes`` outputs.

    The backbone is built with ``weights=None`` (random init); trained weights are
    expected to be loaded afterwards.
    """
    net = resnet18(weights=None)
    # Swap the stock final linear layer for one sized to our label set.
    head_in = net.fc.in_features
    net.fc = torch.nn.Linear(head_in, num_classes)
    return net
# Prefix that DataParallel / DistributedDataParallel prepend to every key.
_WRAPPER_PREFIX = "module."


def _strip_wrapper_prefix(key: str) -> str:
    """Remove leading ``"module."`` wrapper prefixes from one state-dict key."""
    # Only strip at the start: a plain str.replace would also mangle real
    # names such as "submodule.weight" -> "subweight".
    while key.startswith(_WRAPPER_PREFIX):
        key = key[len(_WRAPPER_PREFIX):]
    return key


def load_weights(model: torch.nn.Module, ckpt_path: Path) -> None:
    """Load parameters from ``ckpt_path`` into ``model`` in place.

    Accepted checkpoint layouts:
      * a dict whose ``"state_dict"`` entry is a dict;
      * a dict whose ``"model"`` entry is a dict;
      * any other dict, treated as the state dict itself.

    Leading ``"module."`` prefixes (left by DataParallel/DDP wrappers) are
    stripped from the keys. Loading is non-strict: mismatched keys are
    printed rather than raised.

    Args:
        model: Module to receive the weights.
        ckpt_path: Path of the file written by ``torch.save``.

    Raises:
        ValueError: if the loaded object is not a dict.
    """
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only hosts.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(ckpt, dict):
        raise ValueError(f"Unsupported checkpoint format: {type(ckpt)}")

    if isinstance(ckpt.get("state_dict"), dict):
        state = ckpt["state_dict"]
    elif isinstance(ckpt.get("model"), dict):
        state = ckpt["model"]
    else:
        state = ckpt

    cleaned = {_strip_wrapper_prefix(k): v for k, v in state.items()}
    missing, unexpected = model.load_state_dict(cleaned, strict=False)
    # strict=False never raises on key mismatches, so surface them here.
    if missing:
        print("[load_weights] Missing keys:", missing)
    if unexpected:
        print("[load_weights] Unexpected keys:", unexpected)
# -----------------------------
# Preprocess
# -----------------------------
# Ordered pipeline applied to a PIL image before it is fed to the model.
_preprocess_steps = [
    # Force the 32x32 input resolution the model is run at.
    T.Resize((32, 32), interpolation=T.InterpolationMode.BILINEAR),
    # PIL image -> float tensor in channel-first layout.
    T.ToTensor(),
    # Standardize each channel with the CIFAR-10 statistics from the config.
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
]
preprocess = T.Compose(_preprocess_steps)

# Shared holder for the loaded classifier; "model" stays None when no
# checkpoint has been loaded.
STATE: Dict[str, Optional[torch.nn.Module]] = dict(model=None)
def init():
    """Populate ``STATE["model"]`` from ``CKPT_PATH``.

    If the checkpoint file is absent, a message is printed and the slot is set
    to None (the UI then reports that the model is not loaded). Otherwise a
    fresh model is built, the weights are loaded, and the model is moved to
    ``DEVICE`` in eval mode.
    """
    if CKPT_PATH.exists():
        print(f"[init] Loading checkpoint: {CKPT_PATH}")
        net = build_model(num_classes=len(CIFAR10_CLASSES))
        load_weights(net, CKPT_PATH)
        # Module.to() moves in place and returns the same module.
        net.to(DEVICE).eval()
        STATE["model"] = net
    else:
        print(f"[init] Checkpoint not found: {CKPT_PATH}")
        STATE["model"] = None
def get_examples(examples_dir: Optional[Path] = None) -> List[List[str]]:
    """Collect example images for ``gr.Examples``.

    Args:
        examples_dir: Folder to scan. When None, ``EXAMPLES_DIR`` is used.

    Returns:
        One single-element row ``[path_str]`` per regular file directly inside
        the folder whose suffix is ``.png``, ``.jpg`` or ``.jpeg``
        (case-insensitive), sorted by path. Returns an empty list when the
        folder does not exist or is not a directory.
    """
    folder = EXAMPLES_DIR if examples_dir is None else Path(examples_dir)
    # is_dir() (not exists()) so a stray file at this path doesn't make
    # iterdir() raise.
    if not folder.is_dir():
        return []
    suffixes = {".png", ".jpg", ".jpeg"}
    # is_file() skips sub-directories that merely have an image-like name.
    images = sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in suffixes
    )
    return [[str(p)] for p in images]
# -----------------------------
# Predict
# -----------------------------
def predict(img: Image.Image):
    """Classify one uploaded image.

    Args:
        img: Image from the upload widget, or None when nothing is uploaded.

    Returns:
        ``(img32, label_dict, top3_rows, pred_text)``:
          * img32 — the RGB 32x32 image actually passed through ``preprocess``;
          * label_dict — class name -> softmax probability, for ``gr.Label``;
          * top3_rows — ``[[class, "xx.xx%"], ...]`` for the most likely classes;
          * pred_text — Markdown with the top class and its confidence.
        When ``img`` is None, blank placeholders are returned instead.

    Raises:
        gr.Error: if ``STATE["model"]`` is None (no checkpoint loaded).
    """
    if img is None:
        return None, {}, [["", ""], ["", ""], ["", ""]], ""
    model = STATE["model"]
    if model is None:
        raise gr.Error("Model is not loaded. Ensure ast_cifar10_resnet18.pth exists in the repo root.")

    # Resize once and feed this exact image to the model, so the preview shown
    # in the UI is the real model input (preprocess's own 32x32 Resize is then
    # effectively a no-op on it).
    img32 = img.convert("RGB").resize((32, 32), resample=Image.BILINEAR)
    x = preprocess(img32).unsqueeze(0).to(DEVICE)  # [1, 3, 32, 32]
    with torch.inference_mode():
        logits = model(x)
    # One device->host transfer; the reads below are then plain Python floats
    # instead of one sync per element on CUDA.
    probs = F.softmax(logits, dim=1).squeeze(0).cpu()  # [num_classes]

    # Full distribution for gr.Label.
    label_dict = dict(zip(CIFAR10_CLASSES, probs.tolist()))

    # Top-k table; k is capped so topk cannot fail on a small head.
    k = min(3, probs.numel())
    top_vals, top_idx = torch.topk(probs, k=k)
    top_pairs = list(zip(top_idx.tolist(), top_vals.tolist()))
    top3_rows = [[CIFAR10_CLASSES[i], f"{p * 100:.2f}%"] for i, p in top_pairs]

    best_idx, best_prob = top_pairs[0]
    pred_text = f"**{CIFAR10_CLASSES[best_idx]}** ({best_prob * 100.0:.2f}%)"
    return img32, label_dict, top3_rows, pred_text
def clear_all():
    """Return blank values for the five components reset by the Clear button.

    Order: input image, 32x32 preview, label probabilities, top-3 table
    (three empty ``["", ""]`` rows), prediction text.
    """
    blank_rows = [["", ""] for _ in range(3)]
    return None, None, {}, blank_rows, ""
# -----------------------------
# App
# -----------------------------
init()
EXAMPLES = get_examples()

with gr.Blocks(title="AST CIFAR-10 Classifier") as demo:
    gr.Markdown(
        "# AST CIFAR-10 Classifier\n"
        "ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.\n\n"
        f"**Device:** `{DEVICE}`"
    )
    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="pil", label="Upload CIFAR-like image")
            img_32 = gr.Image(type="pil", label="Model input (32×32)")
        with gr.Column(scale=1):
            gr.Markdown("### Top-3 Predictions")
            pred_label = gr.Label(num_top_classes=3, label="Probabilities")
            top3_table = gr.Dataframe(
                headers=["class", "confidence"],
                datatype=["str", "str"],
                row_count=3,
                column_count=2,  # used instead of the deprecated `col_count`
                interactive=False,
                label="Top-3",
            )
            pred_text = gr.Markdown()
    with gr.Row():
        submit = gr.Button("Submit", variant="primary")
        clear = gr.Button("Clear")

    if EXAMPLES:
        gr.Markdown("### Examples (from `Examples/` folder)")
        # Caching requires fn/outputs and runs `predict` on every example at
        # startup. Without a loaded model `predict` raises gr.Error, which
        # would break startup, so only cache when the checkpoint was loaded.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[img_in],
            outputs=[img_32, pred_label, top3_table, pred_text],
            fn=predict,
            cache_examples=STATE["model"] is not None,
        )

    submit.click(
        fn=predict,
        inputs=[img_in],
        outputs=[img_32, pred_label, top3_table, pred_text],
    )
    clear.click(
        fn=clear_all,
        inputs=[],
        outputs=[img_in, img_32, pred_label, top3_table, pred_text],
    )

demo.queue()

if __name__ == "__main__":
    demo.launch()