import os

# Point the Hugging Face and Matplotlib caches at writable local directories
# before importing any library that reads these variables.
os.environ['HF_HOME'] = './hf_cache'
os.environ['MPLCONFIGDIR'] = './mpl_cache'

import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import base64
import io
import time
import threading
from typing import List, Dict, Union, Tuple, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import ConvNextV2Config, ConvNextV2ForImageClassification

MODEL_CHECKPOINTS = {
    "ConvNeXt tiny (Best)": "checkpoints/convnext_v2_tiny_best.pth",
    "EfficientNet-B0": "checkpoints/effnet_b0_best.pth",
    "EfficientNet-B3": "checkpoints/effnet_b3_best.pth",
    "Vision Transformer B-16": "checkpoints/vit_b_16_best.pth",
}
DEFAULT_MODEL_NAME = "ConvNeXt tiny (Best)"

CONVNEXT_CONFIG_PATH = "convnext_config.json"

# Model registries keyed by display name; each entry holds (model, idx_to_class).
GPU_MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
CPU_MODELS: Dict[str, Tuple[nn.Module, Dict[int, str]]] = {}
CONFIG_PATH: str = os.getenv('CONFIG_PATH', 'staging_config.yaml')
# Serializes inference so concurrent requests do not interleave on one device.
model_lock: threading.Lock = threading.Lock()


def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Build the bare architecture named in a checkpoint, with a fresh classifier head."""
    model: Optional[nn.Module] = None
    if model_name == "efficientnet_b0":
        model = models.efficientnet_b0(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "efficientnet_b3":
        model = models.efficientnet_b3(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        num_ftrs = model.heads.head.in_features
        model.heads.head = nn.Linear(num_ftrs, num_classes)
    elif "convnextv2" in model_name:
        config = ConvNextV2Config.from_json_file(CONVNEXT_CONFIG_PATH)
        config.num_labels = num_classes
        model = ConvNextV2ForImageClassification(config)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model


def load_checkpoint(checkpoint_path: str, device: torch.device) -> Tuple[nn.Module, Dict[int, str]]:
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    checkpoint: dict = torch.load(checkpoint_path, map_location=device)
    model_name_from_ckpt: str = checkpoint['model_name']
    state_dict = checkpoint['state_dict']

    # Checkpoints saved from a wrapper module prefix every key with "model.";
    # strip only that leading prefix so nested keys are left untouched.
    if any(key.startswith("model.") for key in state_dict.keys()):
        print(f" > Unwrapping state_dict for {model_name_from_ckpt}...")
        state_dict = {k[len("model."):] if k.startswith("model.") else k: v
                      for k, v in state_dict.items()}

    class_to_idx: Dict[str, int] = checkpoint['class_to_idx']
    # Single-logit binary head: probabilities come from sigmoid in predict_batch.
    model: nn.Module = get_model(model_name_from_ckpt, num_classes=1)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    idx_to_class: Dict[int, str] = {v: k for k, v in class_to_idx.items()}
    return model, idx_to_class
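

# Expected checkpoint layout, inferred from the keys read above (the class
# names here are hypothetical placeholders):
#
#   torch.save({
#       "model_name": "efficientnet_b0",       # one of the names get_model accepts
#       "state_dict": model.state_dict(),      # keys optionally prefixed with "model."
#       "class_to_idx": {"cat": 0, "dog": 1},  # hypothetical class names
#   }, "checkpoints/effnet_b0_best.pth")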


print("--- Loading all models into memory ---")
cpu_device = torch.device("cpu")
gpu_device = torch.device("cuda") if torch.cuda.is_available() else None

for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if os.path.exists(ckpt_path):
        print(f"Loading '{display_name}'...")
        try:
            cpu_model, idx_to_class = load_checkpoint(ckpt_path, cpu_device)
            CPU_MODELS[display_name] = (cpu_model, idx_to_class)
            print(f" > Loaded '{display_name}' for CPU.")
            if gpu_device:
                gpu_model, _ = load_checkpoint(ckpt_path, gpu_device)
                GPU_MODELS[display_name] = (gpu_model, idx_to_class)
                print(f" > Loaded '{display_name}' for GPU.")
        except Exception as e:
            print(f" > FAILED to load '{display_name}'. Error: {e}")
    else:
        print(f"WARNING: Checkpoint for '{display_name}' not found at {ckpt_path}. It will not be available.")

if not CPU_MODELS:
    raise RuntimeError("No models were loaded. Please check the `checkpoints` directory.")


try:
    with open(CONFIG_PATH, 'r') as f:
        config: dict = yaml.safe_load(f)
except FileNotFoundError:
    raise RuntimeError(f"ERROR: Config file not found at '{CONFIG_PATH}'.")

IMG_SIZE: int = config['data_params']['image_size']
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    # ImageNet mean/std, matching the pretrained backbones used at training time.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def base64_to_pil(base64_str: str) -> Image.Image:
    try:
        # Accept both raw base64 and full data URLs ("data:image/png;base64,...").
        if "base64," in base64_str:
            base64_str = base64_str.split("base64,")[1]
        image_data: bytes = base64.b64decode(base64_str)
        return Image.open(io.BytesIO(image_data))
    except Exception as e:
        raise ValueError(f"Invalid base64 string: {e}")


class Base64Image(BaseModel):
    image_data: str


class BatchBase64Images(BaseModel):
    image_data_list: List[str]
    model_name: str = DEFAULT_MODEL_NAME
    use_gpu: bool = True


def predict_batch(pil_images: List[Image.Image], use_gpu: bool, model_name: str) -> List[Dict[str, dict]]:
    # Fall back to the CPU registry when no GPU is available.
    model_dict = GPU_MODELS if use_gpu and gpu_device else CPU_MODELS
    if model_name not in model_dict:
        raise ValueError(f"Model '{model_name}' not loaded or not available.")

    model, idx_to_class = model_dict[model_name]
    device = gpu_device if use_gpu and gpu_device else cpu_device

    image_tensors = [inference_transform(img.convert("RGB")) for img in pil_images]
    batch_tensor = torch.stack(image_tensors).to(device)

    with model_lock, torch.no_grad():
        start_time = time.time()
        output_obj = model(batch_tensor)
        batch_time = time.time() - start_time

    # Hugging Face models return an output object; torchvision models return a raw tensor.
    if hasattr(output_obj, 'logits'):
        logits = output_obj.logits
    else:
        logits = output_obj

    results = []
    # Single-logit binary head: sigmoid gives P(class 1). squeeze() collapses a
    # one-image batch to a scalar, so re-wrap it in a list.
    probs = torch.sigmoid(logits).squeeze().tolist()
    if not isinstance(probs, list):
        probs = [probs]

    class_0_name = idx_to_class.get(0, "Class 0")
    class_1_name = idx_to_class.get(1, "Class 1")

    for prob in probs:
        results.append({
            "prediction": {class_0_name: 1 - prob, class_1_name: prob},
            "metadata": {"device": str(device), "inference_ms": (batch_time * 1000) / len(pil_images)}
        })
    return results


app = FastAPI(title="Multi-Model Image Classifier API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/predict", response_model=dict)
async def predict_api(request: Base64Image, model_name: str = DEFAULT_MODEL_NAME, use_gpu: bool = True):
    try:
        pil_image = base64_to_pil(request.image_data)
        return predict_batch([pil_image], use_gpu, model_name)[0]
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/batch_predict", response_model=List[dict])
async def batch_predict_api(request: BatchBase64Images):
    try:
        pil_images = [base64_to_pil(b64) for b64 in request.image_data_list]
        return predict_batch(pil_images, request.use_gpu, request.model_name)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
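

# Example client call against /predict (a sketch, assuming the server runs on
# localhost:8000 and "example.jpg" is a hypothetical local image file). Note that
# model_name and use_gpu travel as query parameters, the image as the JSON body:
#
#   import base64, requests
#   with open("example.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       params={"model_name": "ConvNeXt tiny (Best)", "use_gpu": "false"},
#       json={"image_data": b64},
#   )
#   print(resp.json())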


@app.get("/models", response_model=List[str])
async def get_available_models():
    return list(CPU_MODELS.keys())


def predict_gradio(pil_image: Image.Image, model_name: str) -> Optional[dict]:
    # Prefers the GPU copy; predict_batch falls back to CPU when none is loaded.
    if pil_image is None:
        return None
    result = predict_batch([pil_image], use_gpu=True, model_name=model_name)[0]
    return result["prediction"]


gradio_iface = gr.Interface(
    fn=predict_gradio,
    inputs=[
        gr.Image(type="pil", label="Input Image", sources=["upload", "webcam", "clipboard"]),
        gr.Dropdown(
            choices=list(CPU_MODELS.keys()),
            value=DEFAULT_MODEL_NAME,
            label="Select Model"
        )
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Multi-Model Image Classifier",
    description="Upload an image and select a model to see its classification. The API is available at the /docs endpoint.",
    allow_flagging="never"
)

# Mount the Gradio UI at the root path; the FastAPI routes (/predict,
# /batch_predict, /models, /docs) remain available alongside it.
app = gr.mount_gradio_app(app, gradio_iface, path="/")
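

# A minimal way to serve this module locally: a sketch assuming uvicorn is
# installed. In deployment you would typically point uvicorn at the module
# path instead (e.g. `uvicorn app:app`, if this file is saved as app.py).
if __name__ == "__main__":
    import uvicorn

    # Serve the FastAPI app, with the Gradio UI mounted at "/", on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)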