"""Gradio demo: CIFAR-100 classification with a ResNet-18 and ten-crop TTA."""

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from model import ResNet18

# CIFAR-100 fine-label class names, ordered by label index (0-99).
cifar100_classes = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle',
    'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel',
    'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock',
    'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster',
    'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion',
    'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree',
    'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy',
    'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
    'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail',
    'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper',
    'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train',
    'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf',
    'woman', 'worm'
]

# Load model (CPU only -- HF Spaces free tier has no GPU).
device = torch.device('cpu')
net = ResNet18().to(device)


def _extract_state_dict(state: dict) -> dict:
    """Pull the parameter dict out of a checkpoint.

    Supports three common layouts: ``{'state_dict': ...}``, ``{'net': ...}``,
    or a bare state dict. Returns ``{}`` for anything else.
    """
    if isinstance(state, dict) and isinstance(state.get('state_dict'), dict):
        return state['state_dict']
    if isinstance(state, dict) and isinstance(state.get('net'), dict):
        return state['net']
    if isinstance(state, dict):
        return state
    return {}


def _strip_module_prefix(sd: dict) -> dict:
    """Drop a leading 'module.' (nn.DataParallel) prefix from every key."""
    if any(k.startswith('module.') for k in sd.keys()):
        return {k.replace('module.', '', 1): v for k, v in sd.items()}
    return sd


# Try checkpoint locations in order; the first one that loads strictly wins.
ckpt_paths = ['ckpt.pth', 'checkpoint/resnet18_cifar100.pth', 'resnet18_cifar100.pth']
loaded = False
for path in ckpt_paths:
    try:
        # NOTE(review): torch.load unpickles arbitrary objects -- only ever
        # point these paths at trusted, repo-local checkpoints.
        raw = torch.load(path, map_location=device)
        sd = _strip_module_prefix(_extract_state_dict(raw))
        # strict=True raises on any key/shape mismatch, so no need to
        # inspect the returned missing/unexpected lists.
        net.load_state_dict(sd, strict=True)
        loaded = True
        print(f"Loaded weights from {path}")
        break
    except Exception as e:
        print(f"Failed to load {path}: {e}")
if not loaded:
    # Fix: previously the app silently served random weights when no
    # checkpoint was found.
    print("WARNING: no checkpoint loaded; predictions use randomly "
          "initialized weights.")

net.eval()

# Aspect-ratio preserving resize + TenCrop(32) TTA, then normalize.
_norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
_to_tensor_norm = transforms.Compose([transforms.ToTensor(), _norm])
transform = transforms.Compose([
    transforms.Resize(40),   # keep aspect ratio, shorter side = 40
    transforms.TenCrop(32),  # 4 corners + center, plus horizontal flips
    transforms.Lambda(lambda crops: torch.stack([_to_tensor_norm(c) for c in crops])),
])


def predict(image):
    """Classify a PIL image.

    Returns a tuple of (top-1 label, list of top-5 {"label", "score"} dicts).
    Ten-crop test-time augmentation: probabilities are averaged over crops.
    """
    if image is None:
        # Fix: Gradio passes None when the form is submitted with no image;
        # the old code crashed with AttributeError.
        return "", []
    # Fix: uploads may be RGBA / grayscale / palette images, which break the
    # 3-channel Normalize -- force RGB first.
    image = image.convert('RGB')
    batch = transform(image)  # shape [10, 3, 32, 32]
    with torch.no_grad():
        outputs = net(batch)
        probs = torch.softmax(outputs, dim=1).mean(0)  # average over the 10 crops
    topk_probs, topk_idx = torch.topk(probs, k=5)
    results = [
        {"label": cifar100_classes[idx], "score": round(p, 4)}
        for p, idx in zip(topk_probs.tolist(), topk_idx.tolist())
    ]
    return results[0]['label'], results


# Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(label="Prediction"), gr.JSON(label="Top-5")],
    title="CIFAR-100 Image Classification",
    description="Upload an image (any size). The model resizes to 32×32 and predicts top-5 classes.",
)

if __name__ == "__main__":
    iface.launch()