# Page-header residue captured with this file (not program code):
#   godsofheaven's picture
#   Update app.py
#   07376d6 verified
import os
import warnings
from typing import List, Dict, Any, Union

import gradio as gr
import torch
import torchvision.models as tvm
from PIL import Image
from torch import nn
from torchvision import transforms as T
from torchvision.transforms import functional as F
# Checkpoint location: taken from the CKPT_PATH environment variable when set,
# otherwise "best.pth" relative to the working directory.
CHECKPOINT_PATH = os.getenv("CKPT_PATH", "best.pth")
def get_device() -> torch.device:
    """Pick the torch device used for inference.

    Returns:
        The CUDA device when ``torch.cuda.is_available()`` is true,
        otherwise the CPU device.
    """
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
def build_model(num_classes: int = 1000) -> nn.Module:
    """Create a ResNet-50 without pretrained weights and a resized head.

    Args:
        num_classes: Output width of the final fully connected layer.

    Returns:
        A torchvision ResNet-50 built with ``weights=None`` whose ``fc`` layer
        has been replaced by a new ``nn.Linear`` of ``num_classes`` outputs.
    """
    network = tvm.resnet50(weights=None)
    head_inputs = network.fc.in_features
    network.fc = nn.Linear(head_inputs, num_classes)
    return network
def get_preprocess_and_labels():
    """Return the ImageNet-1k preprocessing callable and class labels.

    When torchvision exposes ``ResNet50_Weights.IMAGENET1K_V2``, its bundled
    transforms and ``categories`` metadata are used. Otherwise a hand-written
    resize / center-crop / normalise pipeline is built and the labels fall
    back to the stringified class indices ``"0"`` .. ``"999"``.

    Returns:
        A ``(preprocess, labels)`` tuple: a callable applied to each input
        image, and a list of class-name strings.
    """
    num_classes = 1000
    fallback_labels = [str(i) for i in range(num_classes)]

    # A plain attribute lookup can only fail with AttributeError (weights enum
    # not present in this torchvision), so catch exactly that instead of
    # masking unrelated errors.
    try:
        weights = tvm.ResNet50_Weights.IMAGENET1K_V2
    except AttributeError:
        weights = None

    if weights is not None:
        preprocess = weights.transforms()
        labels = weights.meta.get("categories", fallback_labels)
        return preprocess, labels

    # Manual fallback pipeline with the usual ImageNet channel statistics.
    preprocess = T.Compose(
        [
            T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    return preprocess, fallback_labels
def load_checkpoint_into_model(model: nn.Module, checkpoint_path: str) -> None:
    """Load weights from ``checkpoint_path`` into ``model`` and set eval mode.

    The file may hold a raw state dict or a training checkpoint that nests the
    state dict under ``"model"``, ``"state_dict"`` or ``"model_state_dict"``.
    A leading ``"module."`` on parameter names (as written when saving a
    ``DataParallel``-wrapped model) is stripped when the target model itself
    has no such prefix.

    Loading stays non-strict, but any missing or unexpected keys are reported
    with a ``RuntimeWarning`` instead of being silently ignored, since a
    mismatch leaves those parameters at their initial values.

    Args:
        model: Module that receives the weights (modified in place).
        checkpoint_path: Path of the checkpoint file to read.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(
            f"Checkpoint not found at '{checkpoint_path}'. "
            f"Place your file there or set CKPT_PATH env var."
        )

    # NOTE(security): torch.load unpickles the file. Only load checkpoints from
    # a trusted source; consider passing weights_only=True where the installed
    # torch supports it and the checkpoint contains only tensors.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Unwrap a full training checkpoint; otherwise treat it as a raw state dict.
    state_dict = checkpoint
    for wrapper_key in ("model", "state_dict", "model_state_dict"):
        if wrapper_key in checkpoint:
            state_dict = checkpoint[wrapper_key]
            break

    # Strip the DataParallel-style prefix only when it would otherwise prevent
    # every key from matching the (unwrapped) target model.
    prefix = "module."
    target_is_wrapped = any(name.startswith(prefix) for name in model.state_dict())
    if not target_is_wrapped and any(name.startswith(prefix) for name in state_dict):
        state_dict = {
            (name[len(prefix):] if name.startswith(prefix) else name): value
            for name, value in state_dict.items()
        }

    outcome = model.load_state_dict(state_dict, strict=False)
    if outcome.missing_keys or outcome.unexpected_keys:
        warnings.warn(
            f"Checkpoint '{checkpoint_path}' does not fully match the model: "
            f"{len(outcome.missing_keys)} missing key(s) "
            f"(e.g. {outcome.missing_keys[:3]}), "
            f"{len(outcome.unexpected_keys)} unexpected key(s) "
            f"(e.g. {outcome.unexpected_keys[:3]}).",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()
# Module-level inference state, built once when the module is imported.
preprocess, imagenet_labels = get_preprocess_and_labels()
device = get_device()
model = build_model(num_classes=1000)
model = model.to(device)
# Raises FileNotFoundError at import time if the checkpoint file is absent.
load_checkpoint_into_model(model, CHECKPOINT_PATH)
def predict_images(
    images: Union[Image.Image, List[Image.Image]],
    top_k: int = 5,
) -> List[List[Dict[str, Any]]]:
    """Classify one or more images and return the top-k labels for each.

    Uses the module-level ``model``, ``preprocess``, ``device`` and
    ``imagenet_labels``.

    Args:
        images: A single PIL image, or a list/tuple of them. ``None`` yields
            an empty result. Non-PIL items are passed to ``Image.fromarray``.
        top_k: Number of predictions to keep per image. Cast to ``int`` and
            capped at the number of model outputs.

    Returns:
        One list per input image, each holding ``{"label", "probability"}``
        dicts ordered from most to least probable.
    """
    if images is None:
        return []
    if not isinstance(images, (list, tuple)):
        images = [images]

    # UI sliders may deliver the value as a float; torch.topk requires an int.
    top_k = int(top_k)

    results: List[List[Dict[str, Any]]] = []
    with torch.no_grad():
        for image in images:
            if not isinstance(image, Image.Image):
                # Assumes an array-like (e.g. numpy) image from the caller.
                image = Image.fromarray(image)
            # Grayscale / RGBA / palette inputs would not match the
            # three-channel normalisation and model input, so force RGB.
            image = image.convert("RGB")

            tensor = preprocess(image).unsqueeze(0).to(device)
            logits = model(tensor)
            probs = torch.softmax(logits, dim=1)[0]
            # Never request more entries than the model produces.
            topk = torch.topk(probs, k=min(top_k, probs.numel()))

            sample_result: List[Dict[str, Any]] = []
            for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
                # Fall back to the raw index if the label list is shorter
                # than the model's output.
                label = imagenet_labels[idx] if 0 <= idx < len(imagenet_labels) else str(idx)
                sample_result.append({"label": label, "probability": float(score)})
            results.append(sample_result)
    return results
# Gradio UI: a single-image input with a Top-K slider on the left, JSON
# predictions on the right. The intro text reports the checkpoint path that
# is actually configured rather than a hard-coded one.
with gr.Blocks(title="ResNet-50 ImageNet-1k Classifier") as demo:
    gr.Markdown(
        "\n"
        "**ResNet-50 ImageNet-1k Classifier**\n"
        "- Upload an image and get the top-K predictions.\n"
        f"- Model weights loaded from `{CHECKPOINT_PATH}`.\n"
    )
    with gr.Row():
        with gr.Column():
            input_images = gr.Image(
                label="Upload images",
                type="pil",
                sources=["upload", "clipboard"],
            )
            # Example files are resolved relative to the working directory.
            gr.Examples(
                examples=[
                    "goldfish.png",
                    "tiger-shark.png",
                    "toilet-tissue.png",
                ],
                inputs=input_images,
                label="Example images",
            )
            topk = gr.Slider(1, 10, value=5, step=1, label="Top-K")
            run_btn = gr.Button("Predict")
        with gr.Column():
            output = gr.JSON(label="Predictions (per-image top-K)")
    # Wire the button to the classifier: (image, top_k) -> JSON predictions.
    run_btn.click(fn=predict_images, inputs=[input_images, topk], outputs=output)
# Script entry point: serve the UI on all interfaces, on the port given by the
# PORT environment variable (7860 when unset).
if __name__ == "__main__":
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)