|
|
import os |
|
|
from typing import List, Dict, Any, Union |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
import torchvision.models as tvm |
|
|
from torchvision.transforms import functional as F |
|
|
from torchvision import transforms as T |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Checkpoint file to load at start-up. The CKPT_PATH environment variable
# takes precedence; without it the file name "best.pth" is used.
CHECKPOINT_PATH = os.getenv("CKPT_PATH", "best.pth")
|
|
|
|
|
|
|
|
def get_device() -> torch.device:
    """Pick the device used for inference.

    Returns:
        The ``"cuda"`` device when ``torch.cuda.is_available()`` is true,
        otherwise the ``"cpu"`` device.
    """
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
|
|
|
|
|
|
|
|
def build_model(num_classes: int = 1000) -> nn.Module:
    """Create a ResNet-50 whose classifier head emits ``num_classes`` logits.

    Args:
        num_classes: Output size of the replacement fully connected layer.

    Returns:
        A torchvision ResNet-50 constructed with ``weights=None`` and with
        its ``fc`` layer swapped for a new ``nn.Linear``.
    """
    backbone = tvm.resnet50(weights=None)

    # Reuse the input width of the stock head so only the output size changes.
    head_inputs = backbone.fc.in_features
    backbone.fc = nn.Linear(head_inputs, num_classes)
    return backbone
|
|
|
|
|
|
|
|
def get_preprocess_and_labels():
    """Return the input transform and the class names used for inference.

    When the installed torchvision exposes ``ResNet50_Weights.IMAGENET1K_V2``,
    the transform bundled with those weights and their ``"categories"``
    metadata are used. Otherwise a hand-built resize / centre-crop /
    to-tensor / normalise pipeline is returned together with the numeric
    placeholder labels ``"0"`` .. ``"999"``.

    Returns:
        A ``(preprocess, labels)`` tuple: the transform to apply to each
        input image, and the list of class names indexed by class id.
    """
    fallback_labels = [str(i) for i in range(1000)]

    try:
        weights = tvm.ResNet50_Weights.IMAGENET1K_V2
    except AttributeError:
        # Only a missing attribute (a torchvision build without this weights
        # enum) should trigger the fallback; any other error is a real
        # problem and is allowed to propagate.
        weights = None

    if weights is not None:
        preprocess = weights.transforms()
        labels = weights.meta.get("categories", fallback_labels)
        return preprocess, labels

    # NOTE(review): this fallback resizes to 256 before the 224 crop; confirm
    # that it matches the preprocessing used when the checkpoint was trained.
    preprocess = T.Compose(
        [
            T.Resize(256, interpolation=T.InterpolationMode.BILINEAR),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
    )
    return preprocess, fallback_labels
|
|
|
|
|
|
|
|
def load_checkpoint_into_model(model: nn.Module, checkpoint_path: str) -> None:
    """Load weights from ``checkpoint_path`` into ``model`` and set eval mode.

    The checkpoint may be either a bare state dict or a dict that stores the
    state dict under the ``"model"`` key. Loading is non-strict, so key
    mismatches do not abort; they are reported through ``warnings.warn``
    instead of being dropped silently.

    Args:
        model: Module that receives the weights (modified in place).
        checkpoint_path: Location of the checkpoint file on disk.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    import warnings

    if not os.path.exists(checkpoint_path):
        # Point the user at the path that was actually checked rather than a
        # hard-coded location the loader never looks at.
        raise FileNotFoundError(
            f"Checkpoint not found at '{checkpoint_path}'. "
            "Place your file at that path or set CKPT_PATH env var."
        )

    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from trusted sources (consider weights_only=True where the installed
    # torch version supports it).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Accept both {"model": state_dict} wrappers and bare state dicts.
    state_dict = checkpoint.get("model", checkpoint)

    load_result = model.load_state_dict(state_dict, strict=False)
    missing = list(load_result.missing_keys)
    unexpected = list(load_result.unexpected_keys)
    if missing or unexpected:
        # A non-strict load succeeds even when nothing matches, which would
        # leave the model with its initial weights; make that visible.
        warnings.warn(
            f"Checkpoint '{checkpoint_path}' did not match the model exactly: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}.",
            stacklevel=2,
        )

    model.eval()
|
|
|
|
|
|
|
|
# One-time start-up work performed at import: choose the device, build the
# network on it, obtain the preprocessing transform and class names, and
# finally load the checkpoint weights into the network.
device = get_device()
model = build_model(num_classes=1000)
model = model.to(device)
preprocess, imagenet_labels = get_preprocess_and_labels()
load_checkpoint_into_model(model, CHECKPOINT_PATH)
|
|
|
|
|
|
|
|
def predict_images(
    images: "Union[Image.Image, List[Image.Image]]",
    top_k: int = 5,
) -> List[List[Dict[str, Any]]]:
    """Classify one or more images and return the top-K labels for each.

    Args:
        images: A single image or a list of images. Items that are not PIL
            images are passed through ``Image.fromarray`` first. ``None``
            yields an empty result.
        top_k: Number of predictions to return per image. The value is
            coerced to ``int`` and clamped to ``[1, number of classes]``.

    Returns:
        One list per input image; each inner list holds dicts of the form
        ``{"label": str, "probability": float}`` as returned by
        ``torch.topk`` over the softmax probabilities.
    """
    if images is None:
        return []
    if not isinstance(images, list):
        images = [images]

    # The value may arrive from a UI slider as a float; torch.topk needs int.
    requested_k = int(top_k)

    results: List[List[Dict[str, Any]]] = []
    with torch.no_grad():
        for image in images:
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image)
            # The normalisation and the network expect three channels;
            # RGBA, grayscale or palette images would otherwise fail.
            image = image.convert("RGB")

            tensor = preprocess(image).unsqueeze(0).to(device)
            probs = torch.softmax(model(tensor), dim=1)[0]

            # torch.topk raises when k exceeds the number of classes.
            k = max(1, min(requested_k, probs.numel()))
            topk = torch.topk(probs, k=k)

            results.append(
                [
                    {
                        "label": (
                            imagenet_labels[idx]
                            if 0 <= idx < len(imagenet_labels)
                            else str(idx)
                        ),
                        "probability": float(score),
                    }
                    for score, idx in zip(
                        topk.values.tolist(), topk.indices.tolist()
                    )
                ]
            )
    return results
|
|
|
|
|
|
|
|
# Gradio UI: a single-image input with example files and a Top-K slider on
# the left, and the JSON predictions on the right.
with gr.Blocks(title="ResNet-50 ImageNet-1k Classifier") as demo:
    # The text is assembled line by line so it does not depend on how an
    # indented triple-quoted literal would be dedented, and it reports the
    # checkpoint path that is actually in use.
    gr.Markdown(
        "\n".join(
            [
                "**ResNet-50 ImageNet-1k Classifier**",
                "",
                "- Upload an image and get the top-K predictions.",
                f"- Model weights loaded from `{CHECKPOINT_PATH}`.",
            ]
        )
    )
    with gr.Row():
        with gr.Column():
            input_images = gr.Image(
                label="Upload images",
                type="pil",
                sources=["upload", "clipboard"],
            )
            gr.Examples(
                examples=[
                    "goldfish.png",
                    "tiger-shark.png",
                    "toilet-tissue.png",
                ],
                inputs=input_images,
                label="Example images",
            )
            topk = gr.Slider(1, 10, value=5, step=1, label="Top-K")
            run_btn = gr.Button("Predict")
        with gr.Column():
            output = gr.JSON(label="Predictions (per-image top-K)")

    # Clicking the button sends the image and the slider value to
    # predict_images and shows its return value in the JSON panel.
    run_btn.click(fn=predict_images, inputs=[input_images, topk], outputs=output)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces; the PORT environment variable overrides the
    # default port 7860.
    listen_port = int(os.getenv("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)
|
|
|
|
|
|
|
|
|