import os

os.environ['HF_HOME'] = './hf_cache'
os.environ['MPLCONFIGDIR'] = './mpl_cache'

import base64
import io
import threading
import time
from typing import Dict, List, Optional, Union

import gradio as gr
import torch
import torch.nn as nn
import yaml
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
from torchvision import models, transforms
from transformers import ConvNextV2Config, ConvNextV2ForImageClassification

MODEL_CHECKPOINTS = {
    "convnext_tiny_best": "checkpoints/convnext_v2_tiny_best.pth",
    "efficientnet_b0": "checkpoints/effnet_b0_best.pth",
    "efficientnet_b3": "checkpoints/effnet_b3_best.pth",
    "vit_b_16": "checkpoints/vit_b_16_best.pth",
}
DEFAULT_MODEL_NAME = "vit_b_16"
CONVNEXT_CONFIG_PATH = "convnext_config.json"

GPU_MODELS: Dict[str, nn.Module] = {}
CPU_MODELS: Dict[str, nn.Module] = {}
CONFIG_PATH: str = os.getenv('CONFIG_PATH', 'cm_config.yaml')
model_lock: threading.Lock = threading.Lock()


def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Build an architecture by name with a fresh classification head."""
    model: Optional[nn.Module] = None
    if model_name == "efficientnet_b0":
        model = models.efficientnet_b0(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "efficientnet_b3":
        model = models.efficientnet_b3(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        num_ftrs = model.heads.head.in_features
        model.heads.head = nn.Linear(num_ftrs, num_classes)
    elif "convnextv2" in model_name:
        config = ConvNextV2Config.from_json_file(CONVNEXT_CONFIG_PATH)
        config.num_labels = num_classes
        model = ConvNextV2ForImageClassification(config)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model


def load_checkpoint(checkpoint_path: str, device: torch.device) -> nn.Module:
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    checkpoint: dict = torch.load(checkpoint_path, map_location=device)
    model_name_from_ckpt: str = checkpoint['model_name']
    state_dict = checkpoint['state_dict']
    # Checkpoints saved from a training wrapper may prefix every key with
    # "model."; strip it only as a prefix, so keys that merely contain
    # "model." elsewhere are left untouched.
    if any(key.startswith("model.") for key in state_dict.keys()):
        print(f" > Unwrapping state_dict for {model_name_from_ckpt}...")
        state_dict = {
            (k[len("model."):] if k.startswith("model.") else k): v
            for k, v in state_dict.items()
        }
    model: nn.Module = get_model(model_name_from_ckpt, num_classes=1)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


print("--- Loading all models into memory ---")
cpu_device = torch.device("cpu")
gpu_device = torch.device("cuda") if torch.cuda.is_available() else None

for display_name, ckpt_path in MODEL_CHECKPOINTS.items():
    if os.path.exists(ckpt_path):
        print(f"Loading '{display_name}'...")
        try:
            cpu_model = load_checkpoint(ckpt_path, cpu_device)
            CPU_MODELS[display_name] = cpu_model
            print(f" > Loaded '{display_name}' for CPU.")
            if gpu_device:
                gpu_model = load_checkpoint(ckpt_path, gpu_device)
                GPU_MODELS[display_name] = gpu_model
                print(f" > Loaded '{display_name}' for GPU.")
        except Exception as e:
            print(f" > FAILED to load '{display_name}'. Error: {e}")
    else:
        print(f"WARNING: Checkpoint for '{display_name}' not found at {ckpt_path}. It will not be available.")

if not CPU_MODELS:
    raise RuntimeError("No models were loaded. Please check the `checkpoints` directory.")
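# NOTE: load_checkpoint() above expects each .pth file to be a dict carrying a
# 'model_name' string and a 'state_dict'. A sketch of the assumed save call on
# the training side (names here are illustrative, not taken from this repo):
#
#   torch.save(
#       {"model_name": "vit_b_16", "state_dict": model.state_dict()},
#       "checkpoints/vit_b_16_best.pth",
#   )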
try:
    with open(CONFIG_PATH, 'r') as f:
        config: dict = yaml.safe_load(f)
except FileNotFoundError:
    raise RuntimeError(f"ERROR: Config file not found at '{CONFIG_PATH}'.")

IMG_SIZE: int = config['data_params']['image_size']

# Standard ImageNet normalization; must match the transform used at training time.
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def base64_to_pil(base64_str: str) -> Image.Image:
    try:
        # Accept both raw base64 and data URLs ("data:image/png;base64,...").
        if "base64," in base64_str:
            base64_str = base64_str.split("base64,")[1]
        image_data: bytes = base64.b64decode(base64_str)
        return Image.open(io.BytesIO(image_data))
    except Exception as e:
        raise ValueError(f"Invalid base64 string: {e}")


class Base64Image(BaseModel):
    image_data: str


class BatchBase64Images(BaseModel):
    image_data_list: List[str]
    model_name: str = DEFAULT_MODEL_NAME
    use_gpu: bool = True


def predict_batch(
    pil_images: List[Image.Image], use_gpu: bool, model_name: str
) -> List[Dict[str, Union[dict, float]]]:
    model_dict = GPU_MODELS if use_gpu and gpu_device else CPU_MODELS
    if model_name not in model_dict:
        raise ValueError(
            f"Model '{model_name}' not loaded or not available. "
            f"Available: {list(model_dict.keys())}"
        )
    model = model_dict[model_name]
    device = gpu_device if use_gpu and gpu_device else cpu_device

    image_tensors = [inference_transform(img.convert("RGB")) for img in pil_images]
    batch_tensor = torch.stack(image_tensors).to(device)

    with model_lock, torch.no_grad():
        start_time = time.time()
        output_obj = model(batch_tensor)
        batch_time = time.time() - start_time

    # HF models return an output object exposing .logits; torchvision models
    # return the logits tensor directly.
    logits = output_obj.logits if hasattr(output_obj, 'logits') else output_obj

    results = []
    probs = torch.sigmoid(logits).squeeze().tolist()
    if not isinstance(probs, list):  # a single image squeezes down to a scalar
        probs = [probs]
    for prob in probs:
        results.append({
            "prediction": {"clean": 1 - prob, "messy": prob},
            "metadata": {
                "device": str(device),
                "inference_ms": (batch_time * 1000) / len(pil_images),
            },
        })
    return results


app = FastAPI(title="Messy vs Clean Image Classifier API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/predict", response_model=dict)
async def predict_api(request: Base64Image, model_name: str = DEFAULT_MODEL_NAME, use_gpu: bool = True):
    try:
        pil_image = base64_to_pil(request.image_data)
        return predict_batch([pil_image], use_gpu, model_name)[0]
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.post("/batch_predict", response_model=List[dict])
async def batch_predict_api(request: BatchBase64Images):
    try:
        pil_images = [base64_to_pil(b64) for b64 in request.image_data_list]
        return predict_batch(pil_images, request.use_gpu, request.model_name)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/models", response_model=List[str])
async def get_available_models():
    return list(CPU_MODELS.keys())
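# Example client call against /predict (a sketch; the host, port, and image
# path are assumptions -- adjust them to your deployment):
#
#   import base64, requests
#   with open("room.jpg", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode()
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       params={"model_name": "vit_b_16", "use_gpu": "false"},
#       json={"image_data": b64},
#   )
#   print(resp.json())  # {"prediction": {"clean": ..., "messy": ...}, "metadata": {...}}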
def predict_gradio(pil_image: Image.Image, model_name: str) -> Optional[dict]:
    if pil_image is None:
        return None
    result = predict_batch([pil_image], use_gpu=True, model_name=model_name)[0]
    return result["prediction"]


gradio_iface = gr.Interface(
    fn=predict_gradio,
    inputs=[
        gr.Image(type="pil", label="Input Image", sources=["upload", "webcam", "clipboard"]),
        gr.Dropdown(
            choices=list(CPU_MODELS.keys()),
            value=DEFAULT_MODEL_NAME,
            label="Select Model",
        ),
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Messy vs Clean Image Classifier",
    description=(
        "Upload an image and select a model to see its classification for "
        "'messy' vs 'clean'. The interactive API documentation is available "
        "at the /docs endpoint."
    ),
    allow_flagging="never",
)

app = gr.mount_gradio_app(app, gradio_iface, path="/")
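# A minimal sketch for running the server directly (assumes uvicorn is
# installed; adjust host/port to your environment):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)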