|
|
import torch |
|
|
import torch.nn as nn |
|
|
import yaml |
|
|
from torchvision import models, transforms |
|
|
from PIL import Image |
|
|
import gradio as gr |
|
|
import os |
|
|
import base64 |
|
|
import io |
|
|
import time |
|
|
import threading |
|
|
from typing import List, Dict, Union, Tuple, Optional |
|
|
|
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel |
|
|
|
|
|
# Request body accepted by POST /predict: a single encoded image.
class Base64Image(BaseModel):
    # Base64 text of the image. base64_to_pil also accepts a prefix that
    # ends in "base64," (for example a data URI) and strips it.
    image_data: str
|
|
|
|
|
# Request body accepted by POST /batch_predict: several encoded images that
# are classified together in one call to predict_batch.
class BatchBase64Images(BaseModel):
    # One base64 string per image; results come back in the same order.
    image_data_list: List[str]

    # Forwarded to predict_batch as its use_gpu argument; defaults to True.
    use_gpu: bool = True
|
|
|
|
|
# Deployment settings; each one can be overridden through the environment.
CONFIG_PATH: str = os.environ.get('CONFIG_PATH', 'staging_config.yaml')

# NOTE: the checkpoint location is read from the CHECKPOINT_PATH variable,
# even though the constant itself is called CHECKPOINT_FILENAME.
CHECKPOINT_FILENAME: str = os.environ.get('CHECKPOINT_PATH', 'model.pt')

# Lock acquired around the model forward pass in predict_batch.
model_lock: threading.Lock = threading.Lock()
|
|
|
|
|
def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Build a network of the requested architecture with a resized head.

    Only ``"efficientnet_b0"`` is recognised. The network is constructed with
    ``weights=None`` and the linear layer stored at ``classifier[1]`` is
    replaced by a new ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        model_name: Architecture identifier.
        num_classes: Number of output units for the replacement head.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
    """
    # Guard clause: reject unknown architectures before building anything.
    if model_name != "efficientnet_b0":
        raise ValueError(f"Model '{model_name}' not supported.")

    net = models.efficientnet_b0(weights=None)
    # Keep the incoming feature width of the stock head, swap its output size.
    stock_head = net.classifier[1]
    net.classifier[1] = nn.Linear(stock_head.in_features, num_classes)
    return net
|
|
|
|
|
def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
    """Restore a trained model and its index-to-class mapping from disk.

    The checkpoint is expected to be a mapping with the keys ``'model_name'``,
    ``'class_to_idx'`` and ``'state_dict'``. The network is rebuilt through
    ``get_model``, its weights are loaded, and it is moved to ``device`` and
    switched to eval mode.

    Args:
        checkpoint_path: Location of the checkpoint file.
        device: Device used both as ``map_location`` and as the model's home.

    Returns:
        ``(model, idx_to_class)`` where ``idx_to_class`` is the inverse of the
        checkpoint's ``class_to_idx`` mapping.

    Raises:
        FileNotFoundError: If nothing exists at ``checkpoint_path``.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device)
    architecture = ckpt['model_name']
    label_to_index = ckpt['class_to_idx']

    # A two-class problem uses a single output unit; otherwise one per class.
    n_labels = len(label_to_index)
    if n_labels == 2:
        head_width = 1
    else:
        head_width = n_labels

    net = get_model(architecture, num_classes=head_width)
    net.load_state_dict(ckpt['state_dict'])
    net.to(device)
    net.eval()

    # Invert name -> index into index -> name for decoding predictions.
    index_to_label = dict(zip(label_to_index.values(), label_to_index.keys()))
    return net, index_to_label
|
|
|
|
|
# Load the YAML configuration once at import time. A missing file is turned
# into a RuntimeError carrying a deployment hint.
try:
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config: dict = yaml.safe_load(f)
except FileNotFoundError as err:
    # Chain the original error so the traceback shows the real cause.
    raise RuntimeError(
        f"ERROR: Config file not found at '{CONFIG_PATH}'. Make sure it's uploaded to the Space."
    ) from err
|
|
|
|
|
# Optional GPU copy of the model: loaded only when CUDA is available,
# otherwise both GPU handles are None and callers use the CPU model.
if not torch.cuda.is_available():
    gpu_device: Optional[torch.device] = None
    gpu_model: Optional[nn.Module] = None
    print("No GPU available")
else:
    IDX_TO_CLASS: Dict[int, str]
    gpu_device = torch.device("cuda")
    gpu_model, IDX_TO_CLASS = load_checkpoint(CHECKPOINT_FILENAME, gpu_device)
    print(f"GPU model loaded successfully on {gpu_device}")
|
|
|
|
|
# The CPU model is always loaded. Its class mapping becomes the module-wide
# IDX_TO_CLASS used when decoding predictions.
cpu_device: torch.device = torch.device("cpu")
cpu_model: nn.Module
IDX_TO_CLASS: Dict[int, str]
(cpu_model, IDX_TO_CLASS) = load_checkpoint(CHECKPOINT_FILENAME, cpu_device)

print("CPU model loaded successfully")
print(f"Class mapping: {IDX_TO_CLASS}")
|
|
|
|
|
# Side length of the square model input, taken from the loaded config.
IMG_SIZE: int = config['data_params']['image_size']

# Preprocessing applied to every image before inference: resize to a square,
# convert to a tensor, then normalise each channel.
# NOTE(review): the mean/std triples are presumably the ImageNet statistics —
# confirm they match what was used at training time.
inference_transform: transforms.Compose = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
def base64_to_pil(base64_str: str) -> Image.Image:
    """Decode a base64 string into a PIL image.

    If the input contains the marker ``"base64,"`` (as in a data URI), the
    text after the marker is used as the payload.

    Args:
        base64_str: Base64 text, with or without a ``...base64,`` prefix.

    Returns:
        The image opened from the decoded bytes.

    Raises:
        ValueError: If decoding the text or opening the image fails. The
            original exception is attached as ``__cause__``.
    """
    try:
        if "base64," in base64_str:
            base64_str = base64_str.split("base64,")[1]
        image_data: bytes = base64.b64decode(base64_str)
        return Image.open(io.BytesIO(image_data))
    except Exception as e:
        # Chain explicitly so the traceback shows the underlying failure as
        # the direct cause of the ValueError.
        raise ValueError(f"Invalid base64 string: {e}") from e
|
|
|
|
|
def predict_batch(pil_images: List[Image.Image], use_gpu: bool) -> List[Dict[str, Union[dict, float]]]:
    """Classify a batch of images in a single forward pass.

    Args:
        pil_images: Images to classify. Images whose mode is not ``"RGB"``
            are converted before preprocessing.
        use_gpu: Prefer the GPU model. The CPU model is used when no GPU
            model was loaded at startup.

    Returns:
        One dict per input image, in input order, with the keys
        ``"prediction"`` (class name -> probability) and ``"metadata"``
        (``"device"``, ``"inference_ms"`` and ``"image_size"``). An empty
        input yields an empty list.
    """
    # torch.stack rejects an empty sequence, so answer an empty batch directly.
    if not pil_images:
        return []

    # Explicit None checks: module truthiness can depend on __len__.
    on_gpu: bool = use_gpu and gpu_device is not None and gpu_model is not None
    device: torch.device = gpu_device if on_gpu else cpu_device
    model: nn.Module = gpu_model if on_gpu else cpu_model

    image_tensors: List[torch.Tensor] = [
        inference_transform(img if img.mode == "RGB" else img.convert("RGB"))
        for img in pil_images
    ]
    batch_tensor: torch.Tensor = torch.stack(image_tensors).to(device)

    with model_lock, torch.no_grad():
        # perf_counter is monotonic and intended for measuring intervals.
        start_time: float = time.perf_counter()
        output: torch.Tensor = model(batch_tensor)
        if on_gpu:
            # CUDA kernels are launched asynchronously; wait for them so the
            # measured time covers the actual computation.
            torch.cuda.synchronize(device)
        batch_time: float = time.perf_counter() - start_time

    predictions: List[Dict[str, float]]
    if output.dim() > 1 and output.shape[-1] > 1:
        # Multi-class head (load_checkpoint builds one output per class when
        # there are more than two classes).
        # NOTE(review): softmax assumes single-label training — confirm.
        rows = torch.softmax(output, dim=-1).tolist()
        predictions = [
            {IDX_TO_CLASS.get(idx, f"Class {idx}"): p for idx, p in enumerate(row)}
            for row in rows
        ]
    else:
        # Single-logit binary head: sigmoid gives the probability of class 1.
        class_0_name: str = IDX_TO_CLASS.get(0, "Class 0")
        class_1_name: str = IDX_TO_CLASS.get(1, "Class 1")
        # reshape(-1) always yields a flat list, including for batch size 1.
        probs: List[float] = torch.sigmoid(output).reshape(-1).tolist()
        predictions = [{class_0_name: 1 - p, class_1_name: p} for p in probs]

    device_label: str = "gpu" if on_gpu else "cpu"
    # The batch is timed as a whole; report the average per image.
    per_image_ms: float = batch_time * 1000 / len(pil_images)

    return [
        {
            "prediction": prediction,
            "metadata": {
                "device": device_label,
                "inference_ms": per_image_ms,
                "image_size": f"{img.width}x{img.height}",
            },
        }
        for img, prediction in zip(pil_images, predictions)
    ]
|
|
|
|
|
# ASGI application exposing the JSON prediction endpoints.
app: FastAPI = FastAPI(
    title="Image Classifier API",
    description="A FastAPI server with a Gradio UI for image classification. Supports batch processing",
)

# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive CORS policy — confirm it is intended for this deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
@app.post("/predict", response_model=dict)
async def predict_api(request: Base64Image, use_gpu: bool = True) -> dict:
    """Classify one base64-encoded image.

    Args:
        request: Body carrying the encoded image in ``image_data``.
        use_gpu: Passed through to ``predict_batch``; defaults to True.

    Returns:
        The single result dict produced by ``predict_batch`` for the image.

    Raises:
        HTTPException: Status 400 with the error text as ``detail`` when
            decoding or inference fails.
    """
    # NOTE(review): inference runs synchronously inside this coroutine.
    try:
        pil_image: Image.Image = base64_to_pil(request.image_data)
        return predict_batch([pil_image], use_gpu)[0]
    except Exception as e:
        # Chain explicitly so server logs keep the original traceback.
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
@app.post("/batch_predict", response_model=List[dict])
async def batch_predict_api(request: BatchBase64Images) -> List[dict]:
    """Classify several base64-encoded images in one batch.

    Args:
        request: Body carrying ``image_data_list`` and the ``use_gpu`` flag.

    Returns:
        The list of result dicts from ``predict_batch``, in input order.

    Raises:
        HTTPException: Status 400 with the error text as ``detail`` when
            decoding any image or running inference fails.
    """
    # NOTE(review): inference runs synchronously inside this coroutine.
    try:
        pil_images: List[Image.Image] = [
            base64_to_pil(base64_str) for base64_str in request.image_data_list
        ]
        return predict_batch(pil_images, request.use_gpu)
    except Exception as e:
        # Chain explicitly so server logs keep the original traceback.
        raise HTTPException(status_code=400, detail=str(e)) from e
|
|
|
|
|
def predict_from_pil(pil_image: Image.Image) -> Optional[dict]:
    """Callback used by the Gradio interface.

    Runs ``predict_batch`` on the single image with ``use_gpu=True`` and
    returns only the ``"prediction"`` part of the result. Returns ``None``
    when no image was supplied.
    """
    if pil_image is not None:
        outcome = predict_batch([pil_image], use_gpu=True)
        return outcome[0]["prediction"]
    return None
|
|
|
|
|
# Browser UI: one image in, a label widget showing the top two classes out.
gradio_iface: gr.Interface = gr.Interface(
    fn=predict_from_pil,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Image Classifier",
    description="Upload an image to see its classification. The API is available at the /docs endpoint.",
    allow_flagging="never",
)

# Mount the UI on the FastAPI app at the root path; `app` is rebound to the
# object returned by mount_gradio_app.
app = gr.mount_gradio_app(app, gradio_iface, path="/")