# Gradio demo: ResNet-50 ImageNet-1k image classifier (weights loaded from CKPT_PATH).
import os
from typing import List, Dict, Any, Union
import torch
from torch import nn
import torchvision.models as tvm
from torchvision.transforms import functional as F
from torchvision import transforms as T
from PIL import Image
import gradio as gr
CHECKPOINT_PATH = os.environ.get("CKPT_PATH", "best.pth")
def get_device() -> torch.device:
    """Return the preferred inference device: CUDA when available, else CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def build_model(num_classes: int = 1000) -> nn.Module:
    """Build an untrained ResNet-50 whose classifier head emits *num_classes* logits."""
    backbone = tvm.resnet50(weights=None)
    in_features = backbone.fc.in_features
    backbone.fc = nn.Linear(in_features, num_classes)
    return backbone
def get_preprocess_and_labels():
    """Return an (preprocess, labels) pair for ImageNet-1k inference.

    Prefers the canonical transforms and category names bundled with
    torchvision's ``ResNet50_Weights.IMAGENET1K_V2``.  On older torchvision
    builds without the weights enum, falls back to an equivalent hand-built
    pipeline and numeric string labels.
    """
    try:
        weights = tvm.ResNet50_Weights.IMAGENET1K_V2
    except Exception:
        weights = None  # weights enum missing on this torchvision version

    if weights is None:
        # Hand-rolled equivalent of the canonical ImageNet eval transform.
        fallback_steps = [
            T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        return T.Compose(fallback_steps), [str(i) for i in range(1000)]

    labels = weights.meta.get("categories", [str(i) for i in range(1000)])
    return weights.transforms(), labels
def load_checkpoint_into_model(model: nn.Module, checkpoint_path: str) -> None:
    """Load weights from *checkpoint_path* into *model* in place and set eval mode.

    Accepts either a raw state_dict or a full training checkpoint that nests
    the weights under a common key (``"model"``, ``"state_dict"``, or
    ``"model_state_dict"``).  ``"module."`` prefixes left behind by
    DataParallel/DistributedDataParallel are stripped so keys match a bare
    model — previously such checkpoints were silently ignored under
    ``strict=False``.

    Args:
        model: The module to receive the weights (mutated in place).
        checkpoint_path: Filesystem path to the ``.pth`` file.

    Raises:
        FileNotFoundError: If *checkpoint_path* does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at '{checkpoint_path}'. "
            f"Place your file at runs/exp1/best.pth or set CKPT_PATH env var."
        )
    # SECURITY NOTE: torch.load unpickles arbitrary objects; only load files
    # you trust (consider weights_only=True on torch versions that support it).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Unwrap a full training checkpoint down to the bare state_dict.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for key in ("model", "state_dict", "model_state_dict"):
            if key in checkpoint:
                state_dict = checkpoint[key]
                break
    # Strip DataParallel's "module." prefix so keys line up with a bare model.
    prefix = "module."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    # strict=False tolerates head-size mismatches from fine-tuned checkpoints.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
# --- Module-level inference state, built once at import time ---
device = get_device()
# Fresh ResNet-50 with a 1000-way head; real weights come from the checkpoint below.
model = build_model(num_classes=1000).to(device)
preprocess, imagenet_labels = get_preprocess_and_labels()
# Fails fast at startup (FileNotFoundError) when the checkpoint is missing.
load_checkpoint_into_model(model, CHECKPOINT_PATH)
def predict_images(
    images: Union[Image.Image, List[Image.Image]],
    top_k: int = 5,
) -> List[List[Dict[str, Any]]]:
    """Classify one or more images and return per-image top-k predictions.

    Args:
        images: A single PIL image or a list of them; ndarrays from some
            gradio versions are converted defensively.  ``None`` yields ``[]``.
        top_k: Predictions per image.  Coerced to ``int`` (gradio sliders can
            deliver floats) and clamped to ``[1, num_classes]``.

    Returns:
        For each input image, a list of ``{"label": str, "probability": float}``
        dicts sorted by descending probability.
    """
    if images is None:
        return []
    if not isinstance(images, list):
        images = [images]
    results: List[List[Dict[str, Any]]] = []
    with torch.no_grad():
        for image in images:
            if not isinstance(image, Image.Image):
                # Some gradio versions return ndarrays; convert defensively.
                image = Image.fromarray(image)
            # The ImageNet pipeline expects 3 channels; normalize grayscale or
            # RGBA uploads instead of crashing inside the model.
            image = image.convert("RGB")
            tensor = preprocess(image).unsqueeze(0).to(device)
            logits = model(tensor)
            probs = torch.softmax(logits, dim=1)[0]
            # Sliders may hand back floats, and k must not exceed the class count.
            k = max(1, min(int(top_k), probs.numel()))
            topk = torch.topk(probs, k=k)
            sample_result: List[Dict[str, Any]] = []
            for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
                label = imagenet_labels[idx] if 0 <= idx < len(imagenet_labels) else str(idx)
                sample_result.append({"label": label, "probability": float(score)})
            results.append(sample_result)
    return results
# --- Gradio UI wiring (built at import time so `demo` is importable) ---
with gr.Blocks(title="ResNet-50 ImageNet-1k Classifier") as demo:
    gr.Markdown(
        """
        **ResNet-50 ImageNet-1k Classifier**
        - Upload one or more images and get top-5 predictions.
        - Model weights loaded from `runs/exp1/best.pth`.
        """
    )
    with gr.Row():
        with gr.Column():
            # Single-image input; predict_images also accepts lists defensively.
            input_images = gr.Image(
                label="Upload images",
                type="pil",
                sources=["upload", "clipboard"],
            )
            # Example files are resolved relative to the working directory.
            gr.Examples(
                examples=[
                    "goldfish.png",
                    "tiger-shark.png",
                    "toilet-tissue.png",
                ],
                inputs=input_images,
                label="Example images",
            )
            topk = gr.Slider(1, 10, value=5, step=1, label="Top-K")
            run_btn = gr.Button("Predict")
        with gr.Column():
            output = gr.JSON(label="Predictions (per-image top-K)")
    # Wire the button to inference: (image, k) -> JSON payload.
    run_btn.click(fn=predict_images, inputs=[input_images, topk], outputs=output)
# Bind to all interfaces; port comes from the PORT env var (default 7860).
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))