# TSAI-S8 / app.py — CIFAR-100 image classifier demo
# (Hugging Face Space by shwethd, commit 74a5188)
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import ResNet18
# Canonical CIFAR-100 fine-label names (alphabetical, 100 entries),
# index-aligned with the model's output logits.
cifar100_classes = (
    'apple aquarium_fish baby bear beaver bed bee beetle '
    'bicycle bottle bowl boy bridge bus butterfly camel '
    'can castle caterpillar cattle chair chimpanzee clock '
    'cloud cockroach couch crab crocodile cup dinosaur '
    'dolphin elephant flatfish forest fox girl hamster '
    'house kangaroo keyboard lamp lawn_mower leopard lion '
    'lizard lobster man maple_tree motorcycle mountain mouse '
    'mushroom oak_tree orange orchid otter palm_tree pear '
    'pickup_truck pine_tree plain plate poppy porcupine '
    'possum rabbit raccoon ray road rocket rose '
    'sea seal shark shrew skunk skyscraper snail snake '
    'spider squirrel streetcar sunflower sweet_pepper table '
    'tank telephone television tiger tractor train trout '
    'tulip turtle wardrobe whale willow_tree wolf woman '
    'worm'
).split()
# Load model (CPU for HF Spaces)
# HF Spaces free tier has no GPU, so everything runs on CPU.
device = torch.device('cpu')
# Project-local ResNet18 (see model.py); starts with random weights —
# the checkpoint-loading loop below fills them in if a file is found.
net = ResNet18().to(device)
# Load checkpoint if present (strict, prefix-robust)
def _extract_state_dict(state: dict) -> dict:
if isinstance(state, dict) and 'state_dict' in state and isinstance(state['state_dict'], dict):
return state['state_dict']
if isinstance(state, dict) and 'net' in state and isinstance(state['net'], dict):
return state['net']
if isinstance(state, dict):
return state
return {}
def _strip_module_prefix(sd: dict) -> dict:
if any(k.startswith('module.') for k in sd.keys()):
return {k.replace('module.', '', 1): v for k, v in sd.items()}
return sd
# Candidate checkpoint locations, tried in order; the first one that loads
# cleanly wins and the loop stops.
ckpt_paths = ['ckpt.pth', 'checkpoint/resnet18_cifar100.pth', 'resnet18_cifar100.pth']
# True once any checkpoint loads; if all fail the app still starts, but the
# model runs with its random initial weights.
loaded = False
for path in ckpt_paths:
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only ship
        # trusted checkpoint files with this Space.
        raw = torch.load(path, map_location=device)
        sd = _extract_state_dict(raw)
        sd = _strip_module_prefix(sd)
        # With strict=True a key mismatch raises (caught below), so the
        # returned missing/unexpected lists are always empty here.
        missing, unexpected = net.load_state_dict(sd, strict=True)
        loaded = True
        print(f"Loaded weights from {path}")
        break
    except Exception as e:
        # Best-effort: a missing or incompatible file just falls through
        # to the next candidate path.
        print(f"Failed to load {path}: {e}")
        continue
# Inference-only app: disable dropout / use batch-norm running stats.
net.eval()
# Test-time augmentation pipeline: aspect-ratio-preserving resize to
# shorter side 40, then TenCrop(32) (4 corners + center, plus horizontal
# flips of each), with per-channel normalization on every crop.
_CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR_STD = (0.2023, 0.1994, 0.2010)
_crop_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
])
transform = transforms.Compose([
    transforms.Resize(40),   # keep aspect ratio, shorter side = 40
    transforms.TenCrop(32),  # 5 crops + their horizontal flips
    # Stack the 10 normalized crops into one [10, 3, 32, 32] batch.
    transforms.Lambda(lambda crops: torch.stack([_crop_to_tensor(c) for c in crops])),
])
def predict(image):
    """Classify a PIL image and return (top-1 label, top-5 results).

    The TTA transform yields a [10, 3, 32, 32] batch of crops; softmax
    probabilities are averaged over the ten crops before ranking.
    """
    crops = transform(image)  # [10, 3, 32, 32]
    with torch.no_grad():
        logits = net(crops)
        avg_probs = torch.softmax(logits, dim=1).mean(0)  # average TTA
    top_probs, top_indices = torch.topk(avg_probs, k=5)
    results = [
        {"label": cifar100_classes[i], "score": round(p, 4)}
        for p, i in zip(top_probs.tolist(), top_indices.tolist())
    ]
    return results[0]["label"], results
# Gradio interface
# Fix: the original description claimed the model "resizes to 32×32",
# contradicting the actual preprocessing (shorter-side-40 resize plus
# ten-crop 32×32 test-time augmentation) defined above.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Prediction"), gr.JSON(label="Top-5")],
    title="CIFAR-100 Image Classification",
    description=(
        "Upload an image (any size). The model resizes the shorter side to 40 px, "
        "averages predictions over ten 32×32 crops (TTA), and reports the top-5 classes."
    )
)
# Launch the app only when run as a script (HF Spaces executes app.py directly).
if __name__ == "__main__":
    iface.launch()