"""Gradio demo that serves a ResNet-50 ImageNet-1k classifier from a local checkpoint."""

import os
import warnings
from typing import Any, Callable, Dict, List, Tuple, Union

import torch
from torch import nn
import torchvision.models as tvm
from torchvision.transforms import functional as F
from torchvision import transforms as T
from PIL import Image
import gradio as gr

# Checkpoint file to serve; override with the CKPT_PATH environment variable.
CHECKPOINT_PATH = os.environ.get("CKPT_PATH", "best.pth")


def get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def build_model(num_classes: int = 1000) -> nn.Module:
    """Build an uninitialised ResNet-50 whose final layer has ``num_classes`` outputs."""
    model = tvm.resnet50(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model


def get_preprocess_and_labels() -> Tuple[Callable, List[str]]:
    """Return ``(preprocess, labels)`` for ImageNet-1k inference.

    Uses the transforms and category names attached to torchvision's
    ``ResNet50_Weights.IMAGENET1K_V2`` when that enum exists. Otherwise falls
    back to a hand-built resize/crop/normalise pipeline and the stringified
    class indices "0".."999" as labels.
    """
    # Only a missing attribute should trigger the fallback, so look the enum
    # up with getattr defaults instead of swallowing every exception.
    weights_enum = getattr(tvm, "ResNet50_Weights", None)
    weights = getattr(weights_enum, "IMAGENET1K_V2", None)

    if weights is not None:
        preprocess = weights.transforms()
        labels = weights.meta.get("categories", [str(i) for i in range(1000)])
    else:
        preprocess = T.Compose(
            [
                T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        labels = [str(i) for i in range(1000)]
    return preprocess, labels


def load_checkpoint_into_model(model: nn.Module, checkpoint_path: str) -> None:
    """Load weights from ``checkpoint_path`` into ``model`` and switch it to eval mode.

    The file may hold a raw state_dict or a dict that wraps the state_dict
    under a ``"model"`` or ``"state_dict"`` key. Loading is non-strict, but
    any missing or unexpected keys are reported with a warning so that a
    mismatched checkpoint does not silently leave layers uninitialised.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at '{checkpoint_path}'. "
            f"Place your file there or set the CKPT_PATH env var."
        )

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point CKPT_PATH at checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Support either a wrapped training checkpoint or a raw state_dict.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ("model", "state_dict"):
            if wrapper_key in checkpoint:
                state_dict = checkpoint[wrapper_key]
                break

    # Keys that all carry a "module." prefix (as saved from a
    # DataParallel-wrapped model) would match nothing; strip the prefix.
    prefix = "module."
    if (
        isinstance(state_dict, dict)
        and state_dict
        and all(isinstance(k, str) and k.startswith(prefix) for k in state_dict)
    ):
        state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}

    incompatible = model.load_state_dict(state_dict, strict=False)
    missing = list(incompatible.missing_keys)
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint '{checkpoint_path}' did not match the model exactly: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}."
        )
    model.eval()


device = get_device()
model = build_model(num_classes=1000).to(device)
preprocess, imagenet_labels = get_preprocess_and_labels()
load_checkpoint_into_model(model, CHECKPOINT_PATH)


def predict_images(
    images: Union[Image.Image, List[Image.Image]],
    top_k: int = 5,
) -> List[List[Dict[str, Any]]]:
    """Classify one image or a list of images and return the top-k classes for each.

    Args:
        images: a PIL image, a list of PIL images, or ``None``. Items that are
            not PIL images are passed through ``Image.fromarray``.
        top_k: number of classes to report per image. It is cast to ``int``
            and clamped to ``[0, number of model outputs]``.

    Returns:
        One list per input image, each holding ``{"label", "probability"}``
        dicts ordered from most to least probable. ``[]`` when ``images`` is
        ``None``.
    """
    if images is None:
        return []
    if not isinstance(images, list):
        images = [images]

    # UI sliders may deliver a float; torch.topk requires an int.
    requested_k = int(top_k)

    results: List[List[Dict[str, Any]]] = []
    with torch.no_grad():
        for image in images:
            if not isinstance(image, Image.Image):
                # Array-like input (e.g. a numpy array) is converted to PIL.
                image = Image.fromarray(image)
            # The normalisation step is defined for 3 channels; RGBA or
            # grayscale input would fail without this conversion.
            image = image.convert("RGB")

            tensor = preprocess(image).unsqueeze(0).to(device)
            logits = model(tensor)
            probs = torch.softmax(logits, dim=1)[0]

            # torch.topk raises when k exceeds the number of classes or is negative.
            k = max(0, min(requested_k, probs.numel()))
            best = torch.topk(probs, k=k)

            sample_result: List[Dict[str, Any]] = []
            for score, idx in zip(best.values.tolist(), best.indices.tolist()):
                label = imagenet_labels[idx] if 0 <= idx < len(imagenet_labels) else str(idx)
                sample_result.append({"label": label, "probability": float(score)})
            results.append(sample_result)
    return results


with gr.Blocks(title="ResNet-50 ImageNet-1k Classifier") as demo:
    # Report the checkpoint path actually in use rather than a hard-coded one.
    gr.Markdown(
        "**ResNet-50 ImageNet-1k Classifier**\n\n"
        "- Upload one or more images and get top-5 predictions.\n"
        f"- Model weights loaded from `{CHECKPOINT_PATH}`.\n"
    )
    with gr.Row():
        with gr.Column():
            input_images = gr.Image(
                label="Upload images",
                type="pil",
                sources=["upload", "clipboard"],
            )
            gr.Examples(
                examples=[
                    "goldfish.png",
                    "tiger-shark.png",
                    "toilet-tissue.png",
                ],
                inputs=input_images,
                label="Example images",
            )
            topk = gr.Slider(1, 10, value=5, step=1, label="Top-K")
            run_btn = gr.Button("Predict")
        with gr.Column():
            output = gr.JSON(label="Predictions (per-image top-K)")

    run_btn.click(fn=predict_images, inputs=[input_images, topk], outputs=output)


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))