File size: 4,489 Bytes
793b943
 
 
 
 
 
 
 
 
 
 
 
ade401c
793b943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
07376d6
ade401c
 
793b943
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
from typing import List, Dict, Any, Union

import torch
from torch import nn
import torchvision.models as tvm
from torchvision.transforms import functional as F
from torchvision import transforms as T
from PIL import Image
import gradio as gr


# Path of the weights file to serve; override with the CKPT_PATH env var.
CHECKPOINT_PATH = os.getenv("CKPT_PATH", "best.pth")


def get_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)


def build_model(num_classes: int = 1000) -> nn.Module:
    """Create a ResNet-50 whose classifier head has ``num_classes`` outputs.

    No pretrained weights are requested (``weights=None``), and the final
    ``fc`` layer is always replaced by a freshly initialised ``nn.Linear``.
    The weights are expected to come from a checkpoint loaded afterwards.
    """
    network = tvm.resnet50(weights=None)
    feature_dim = network.fc.in_features
    network.fc = nn.Linear(feature_dim, num_classes)
    return network


def get_preprocess_and_labels():
    """Return the ImageNet-1k preprocessing transform and class-label list.

    Prefers the metadata bundled with torchvision's
    ``ResNet50_Weights.IMAGENET1K_V2``: its canonical eval transform and its
    human-readable category names. On torchvision builds without the
    multi-weight enum, falls back to a Resize(256) / CenterCrop(224) /
    ToTensor / Normalize pipeline with numeric string labels.

    NOTE(review): torchvision's IMAGENET1K_V2 transform resizes to 232 rather
    than 256, so the two paths are not pixel-identical. Confirm which one
    matches how the served checkpoint was evaluated.

    Returns:
        ``(preprocess, labels)``. ``preprocess`` maps a PIL image to a
        normalized image tensor. ``labels`` is indexed by class id.
    """
    fallback_labels = [str(i) for i in range(1000)]
    try:
        weights = tvm.ResNet50_Weights.IMAGENET1K_V2
    except AttributeError:
        # Older torchvision releases have no weights enum. Only this case is
        # expected, so other errors are no longer silently swallowed.
        weights = None

    if weights is None:
        preprocess = T.Compose(
            [
                T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
        return preprocess, fallback_labels

    preprocess = weights.transforms()
    labels = weights.meta.get("categories", fallback_labels)
    return preprocess, labels


def load_checkpoint_into_model(model: nn.Module, checkpoint_path: str) -> None:
    """Load weights from ``checkpoint_path`` into ``model`` and switch it to eval mode.

    Accepted checkpoint layouts:
      * a raw ``state_dict`` (mapping of parameter name -> tensor);
      * a training-checkpoint dict holding the state_dict under ``"model"``
        or ``"state_dict"``.
    A leading ``"module."`` on parameter names (as produced when saving a
    DataParallel/DistributedDataParallel-wrapped model) is stripped first.

    Loading remains non-strict. Missing or unexpected keys are reported with
    ``warnings.warn`` rather than being silently ignored.

    Raises:
        FileNotFoundError: ``checkpoint_path`` does not exist.
        TypeError: the file does not contain a dict-like checkpoint.
        RuntimeError: none of the model's keys were found in the checkpoint,
            which would leave the model with its initial (untrained) weights.
    """
    import warnings

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at '{checkpoint_path}'. "
            f"Place your checkpoint there or set the CKPT_PATH env var."
        )
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Expected a dict-like checkpoint in '{checkpoint_path}', "
            f"got {type(checkpoint).__name__}."
        )

    # Support either a full training checkpoint dict or a raw state_dict.
    state_dict = checkpoint
    for key in ("model", "state_dict"):
        nested = checkpoint.get(key)
        if isinstance(nested, dict):
            state_dict = nested
            break

    prefix = "module."
    state_dict = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in state_dict.items()
    }

    expected_count = len(model.state_dict())
    result = model.load_state_dict(state_dict, strict=False)
    if expected_count and len(result.missing_keys) == expected_count:
        raise RuntimeError(
            f"No parameters from '{checkpoint_path}' matched the model; "
            f"check that the checkpoint belongs to this architecture."
        )
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Partial checkpoint load from '{checkpoint_path}': "
            f"missing={list(result.missing_keys)}, "
            f"unexpected={list(result.unexpected_keys)}"
        )
    model.eval()


# Module-level objects shared by every prediction request: pick the device,
# build the network and move it there, fetch preprocessing + labels, then load
# the trained weights. Importing this module raises FileNotFoundError if the
# checkpoint file is missing.
device = get_device()
model = build_model(num_classes=1000)
model = model.to(device)
preprocess, imagenet_labels = get_preprocess_and_labels()
load_checkpoint_into_model(model, CHECKPOINT_PATH)


def predict_images(
    images: Union[Image.Image, List[Image.Image]],
    top_k: int = 5,
) -> List[List[Dict[str, Any]]]:
    """Classify one or more images and return the top-K labels for each.

    Uses the module-level ``model``, ``preprocess``, ``device`` and
    ``imagenet_labels``.

    Args:
        images: A PIL image or a list of them. ``None`` (nothing uploaded)
            yields ``[]``. Non-PIL items are passed through
            ``Image.fromarray`` (e.g. numpy arrays). Every image is converted
            to RGB before preprocessing.
        top_k: Number of predictions per image. It is coerced to ``int`` and
            clamped to ``[0, number of model outputs]``.

    Returns:
        One list per input image. Each list holds dicts
        ``{"label": str, "probability": float}`` in descending probability
        order (``torch.topk`` sorts by default).
    """
    if images is None:
        return []
    if not isinstance(images, list):
        images = [images]

    # int(): guard against a float arriving from the UI slider.
    # max(): a negative k would make torch.topk raise.
    k = max(int(top_k), 0)

    results: List[List[Dict[str, Any]]] = []
    with torch.no_grad():
        for image in images:
            if not isinstance(image, Image.Image):
                # Non-PIL inputs such as numpy arrays are converted here.
                image = Image.fromarray(image)
            if image.mode != "RGB":
                # Grayscale, RGBA and palette images would yield 1 or 4
                # channels and break the 3-channel normalization step.
                image = image.convert("RGB")
            tensor = preprocess(image).unsqueeze(0).to(device)
            probs = torch.softmax(model(tensor), dim=1)[0]
            # Clamp k so asking for more classes than the model outputs does
            # not make torch.topk raise.
            topk = torch.topk(probs, k=min(k, probs.numel()))
            sample_result: List[Dict[str, Any]] = []
            for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
                label = imagenet_labels[idx] if 0 <= idx < len(imagenet_labels) else str(idx)
                sample_result.append({"label": label, "probability": float(score)})
            results.append(sample_result)
    return results


# Gradio UI: a single-image input (upload or paste) with example images, a
# Top-K slider, and a JSON panel showing the output of predict_images.
with gr.Blocks(title="ResNet-50 ImageNet-1k Classifier") as demo:
    # The header reports the checkpoint path actually in use (CHECKPOINT_PATH)
    # instead of a hard-coded path that may not match the running app.
    gr.Markdown(
        f"""
        **ResNet-50 ImageNet-1k Classifier**

        - Upload or paste an image and get its top-K predictions (K defaults to 5).
        - Model weights loaded from `{CHECKPOINT_PATH}`.
        """
    )
    with gr.Row():
        with gr.Column():
            input_images = gr.Image(
                label="Upload image",
                type="pil",
                sources=["upload", "clipboard"],
            )
            # NOTE(review): example files are given as bare relative paths,
            # presumably expected next to the app; confirm they ship with it.
            gr.Examples(
                examples=[
                    "goldfish.png",
                    "tiger-shark.png",
                    "toilet-tissue.png",
                ],
                inputs=input_images,
                label="Example images",
            )
            topk = gr.Slider(1, 10, value=5, step=1, label="Top-K")
            run_btn = gr.Button("Predict")
        with gr.Column():
            output = gr.JSON(label="Predictions (per-image top-K)")

    run_btn.click(fn=predict_images, inputs=[input_images, topk], outputs=output)


if __name__ == "__main__":
    # Listen on all interfaces. The port comes from $PORT when set, else 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)