|
|
import gradio as gr |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
from PIL import Image |
|
|
from model import ResNet18 |
|
|
|
|
|
|
|
|
# CIFAR-100 fine-label names. Position i in this list is the class name
# reported for output index i of the network, so the order must match the
# label encoding the checkpoint was trained with (this is the usual
# alphabetical CIFAR-100 ordering -- confirm against the training script).
cifar100_classes = """
    apple aquarium_fish baby bear beaver bed bee beetle
    bicycle bottle bowl boy bridge bus butterfly camel
    can castle caterpillar cattle chair chimpanzee clock
    cloud cockroach couch crab crocodile cup dinosaur
    dolphin elephant flatfish forest fox girl hamster
    house kangaroo keyboard lamp lawn_mower leopard lion
    lizard lobster man maple_tree motorcycle mountain mouse
    mushroom oak_tree orange orchid otter palm_tree pear
    pickup_truck pine_tree plain plate poppy porcupine
    possum rabbit raccoon ray road rocket rose
    sea seal shark shrew skunk skyscraper snail snake
    spider squirrel streetcar sunflower sweet_pepper table
    tank telephone television tiger tractor train trout
    tulip turtle wardrobe whale willow_tree wolf woman
    worm
""".split()
|
|
|
|
|
|
|
|
# Inference runs on the CPU. The network is built once at import time and
# then reused by every prediction.
device = torch.device('cpu')

net = ResNet18()
net.to(device)  # nn.Module.to moves the parameters in place (and returns self)
|
|
|
|
|
|
|
|
def _extract_state_dict(state: dict) -> dict: |
|
|
if isinstance(state, dict) and 'state_dict' in state and isinstance(state['state_dict'], dict): |
|
|
return state['state_dict'] |
|
|
if isinstance(state, dict) and 'net' in state and isinstance(state['net'], dict): |
|
|
return state['net'] |
|
|
if isinstance(state, dict): |
|
|
return state |
|
|
return {} |
|
|
|
|
|
def _strip_module_prefix(sd: dict) -> dict: |
|
|
if any(k.startswith('module.') for k in sd.keys()): |
|
|
return {k.replace('module.', '', 1): v for k, v in sd.items()} |
|
|
return sd |
|
|
|
|
|
# Candidate checkpoint locations, tried in order. The first one that loads
# with an exact (strict) key/shape match provides the network's weights.
ckpt_paths = ['ckpt.pth', 'checkpoint/resnet18_cifar100.pth', 'resnet18_cifar100.pth']

loaded = False

for path in ckpt_paths:
    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources
        # (newer PyTorch versions support weights_only=True for this).
        raw = torch.load(path, map_location=device)
        sd = _strip_module_prefix(_extract_state_dict(raw))
        # strict=True raises on any missing/unexpected key or shape mismatch,
        # so getting past this call means every parameter was loaded.
        net.load_state_dict(sd, strict=True)
    except Exception as e:
        # Best-effort: a missing or incompatible file just moves on to the
        # next candidate path.
        print(f"Failed to load {path}: {e}")
        continue
    loaded = True
    print(f"Loaded weights from {path}")
    break

if not loaded:
    # Previously this case was silent. Keep serving, but make it obvious
    # that predictions come from the untrained network (a failed strict
    # load may also have already copied some matching tensors into it).
    print(
        f"WARNING: no checkpoint could be loaded from {ckpt_paths}; "
        "predictions will come from an untrained model."
    )

net.eval()
|
|
|
|
|
|
|
|
# Per-channel normalization statistics.
# NOTE(review): these are the widely used CIFAR-10 mean/std, not CIFAR-100's.
# They must match whatever the checkpoint was trained with, so confirm
# against the training script before changing them.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)
_norm = transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)

# Single-crop pipeline: PIL image -> float CHW tensor -> normalized tensor.
_to_tensor_norm = transforms.Compose([transforms.ToTensor(), _norm])


def _stack_normalized_crops(crops):
    """Normalize every PIL crop and stack them into one (n_crops, C, H, W) batch."""
    tensors = [_to_tensor_norm(crop) for crop in crops]
    return torch.stack(tensors)


# Test-time augmentation: scale the shorter side to 40 px, cut the four
# corner + centre 32x32 patches plus their flipped versions (TenCrop), then
# turn those ten crops into a single normalized batch.
transform = transforms.Compose([
    transforms.Resize(40),
    transforms.TenCrop(32),
    transforms.Lambda(_stack_normalized_crops),
])
|
|
|
|
|
def predict(image):
    """Classify *image* and return its top-5 CIFAR-100 predictions.

    The module-level ``transform`` turns the image into a batch of
    normalized crops (currently ten 32x32 crops). The per-crop softmax
    probabilities are averaged before the classes are ranked.

    Args:
        image: A PIL image as delivered by ``gr.Image(type="pil")``, or
            ``None`` when the form is submitted without an upload.

    Returns:
        ``(label, results)``: the top-1 class name, and a list of five
        ``{"label": str, "score": float}`` dicts ordered by descending
        score (score rounded to 4 decimals). Returns ``("", [])`` when
        *image* is ``None``.
    """
    if image is None:
        # Nothing to classify; avoid crashing inside the transform.
        return "", []

    # ToTensor followed by a 3-channel Normalize requires RGB input.
    # RGBA (PNG with alpha), grayscale or palette images would otherwise fail.
    batch = transform(image.convert("RGB")).to(device)

    with torch.no_grad():
        logits = net(batch)  # one row of class scores per crop

    probs = torch.softmax(logits, dim=1).mean(0)  # average over the crops
    topk_probs, topk_idx = torch.topk(probs, k=5)

    results = [
        {"label": cifar100_classes[idx], "score": round(p, 4)}
        for p, idx in zip(topk_probs.tolist(), topk_idx.tolist())
    ]
    return results[0]['label'], results
|
|
|
|
|
|
|
|
# Gradio UI: one image input. The outputs are the top-1 label and the full
# top-5 list as JSON.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Prediction"), gr.JSON(label="Top-5")],
    title="CIFAR-100 Image Classification",
    # Describe the preprocessing the transform actually performs (shorter
    # side -> 40 px, predictions averaged over ten 32x32 crops). The old text
    # claimed a plain 32x32 resize.
    description=(
        "Upload an image (any size). It is resized so its shorter side is "
        "40 px, and the model averages its predictions over ten 32×32 crops "
        "to report the top-5 classes."
    ),
)


if __name__ == "__main__":
    iface.launch()
|
|
|