import os

# Redirect Hugging Face and matplotlib caches to writable local directories.
# These must be set before the libraries that read them are imported.
os.environ['HF_HOME'] = './hf_cache'
os.environ['MPLCONFIGDIR'] = './mpl_cache'

import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import base64
import io
import time
import threading
from typing import List, Dict, Union, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import ConvNextV2Config, ConvNextV2ForImageClassification

# Display name -> checkpoint path. Missing checkpoints are skipped at startup.
MODEL_CHECKPOINTS = {
    "convnext_tiny_best": "checkpoints/convnext_v2_tiny_best.pth",
    "efficientnet_b0": "checkpoints/effnet_b0_best.pth",
    "efficientnet_b3": "checkpoints/effnet_b3_best.pth",
    "vit_b_16": "checkpoints/vit_b_16_best.pth"
}
DEFAULT_MODEL_NAME = "vit_b_16"
CONVNEXT_CONFIG_PATH = "convnext_config.json"

# Separate registries so CPU inference stays available when no GPU is present.
GPU_MODELS: Dict[str, nn.Module] = {}
CPU_MODELS: Dict[str, nn.Module] = {}
CONFIG_PATH: str = os.getenv('CONFIG_PATH', 'cm_config.yaml')

# Serializes model access across concurrent requests.
model_lock: threading.Lock = threading.Lock()

def get_model(model_name: str, num_classes: int) -> nn.Module:
    """Builds an unweighted backbone with a fresh classification head."""
    model: Optional[nn.Module] = None
    if model_name == "efficientnet_b0":
        model = models.efficientnet_b0(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "efficientnet_b3":
        model = models.efficientnet_b3(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        num_ftrs = model.heads.head.in_features
        model.heads.head = nn.Linear(num_ftrs, num_classes)
    elif "convnextv2" in model_name:
        config = ConvNextV2Config.from_json_file(CONVNEXT_CONFIG_PATH)
        config.num_labels = num_classes
        model = ConvNextV2ForImageClassification(config)
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model

def load_checkpoint(checkpoint_path: str, device: torch.device) -> nn.Module:
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found at: {checkpoint_path}")
    checkpoint: dict = torch.load(checkpoint_path, map_location=device)
    model_name_from_ckpt: str = checkpoint['model_name']
    state_dict = checkpoint['state_dict']
    if any(key.startswith("model.") for key in state_dict.keys()):
        print(f" > Unwrapping state_dict for {model_name_from_ckpt}...")
        # Strip only the leading "model." prefix; str.replace would also mangle
        # any key that contains "model." elsewhere in its name.
        state_dict = {
            (k[len("model."):] if k.startswith("model.") else k): v
            for k, v in state_dict.items()
        }
    # Single-logit head: the binary decision is taken via sigmoid at inference time.
    model: nn.Module = get_model(model_name_from_ckpt, num_classes=1)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
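
# A compatible checkpoint could be produced like this (illustrative sketch
# inferred from the keys read above; the training script is not part of this file):
#   torch.save(
#       {"model_name": "vit_b_16", "state_dict": model.state_dict()},
#       "checkpoints/vit_b_16_best.pth",
#   )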
| print("--- Loading all models into memory ---") | |
| cpu_device = torch.device("cpu") | |
| gpu_device = torch.device("cuda") if torch.cuda.is_available() else None | |
| for display_name, ckpt_path in MODEL_CHECKPOINTS.items(): | |
| if os.path.exists(ckpt_path): | |
| print(f"Loading '{display_name}'...") | |
| try: | |
| cpu_model = load_checkpoint(ckpt_path, cpu_device) | |
| CPU_MODELS[display_name] = cpu_model | |
| print(f" > Loaded '{display_name}' for CPU.") | |
| if gpu_device: | |
| gpu_model = load_checkpoint(ckpt_path, gpu_device) | |
| GPU_MODELS[display_name] = gpu_model | |
| print(f" > Loaded '{display_name}' for GPU.") | |
| except Exception as e: | |
| print(f" > FAILED to load '{display_name}'. Error: {e}") | |
| else: | |
| print(f"WARNING: Checkpoint for '{display_name}' not found at {ckpt_path}. It will not be available.") | |
| if not CPU_MODELS: | |
| raise RuntimeError("No models were loaded. Please check the `checkpoints` directory.") | |

try:
    with open(CONFIG_PATH, 'r') as f:
        config: dict = yaml.safe_load(f)
except FileNotFoundError:
    raise RuntimeError(f"ERROR: Config file not found at '{CONFIG_PATH}'.")

IMG_SIZE: int = config['data_params']['image_size']
inference_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    # Standard ImageNet normalization statistics.
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def base64_to_pil(base64_str: str) -> Image.Image:
    try:
        # Accept both raw base64 and data-URI strings ("data:image/...;base64,...").
        if "base64," in base64_str:
            base64_str = base64_str.split("base64,")[1]
        image_data: bytes = base64.b64decode(base64_str)
        return Image.open(io.BytesIO(image_data))
    except Exception as e:
        raise ValueError(f"Invalid base64 string: {e}")

class Base64Image(BaseModel):
    image_data: str

class BatchBase64Images(BaseModel):
    image_data_list: List[str]
    model_name: str = DEFAULT_MODEL_NAME
    use_gpu: bool = True
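
# Example batch payload (illustrative; "<base64>" stands for real encoded image data):
#   {"image_data_list": ["<base64>", "<base64>"], "model_name": "vit_b_16", "use_gpu": false}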

def predict_batch(pil_images: List[Image.Image], use_gpu: bool, model_name: str) -> List[Dict[str, Union[dict, float]]]:
    # Fall back to the CPU copy when no GPU is available.
    model_dict = GPU_MODELS if use_gpu and gpu_device else CPU_MODELS
    if model_name not in model_dict:
        raise ValueError(f"Model '{model_name}' not loaded or not available. Available: {list(model_dict.keys())}")
    model = model_dict[model_name]
    device = gpu_device if use_gpu and gpu_device else cpu_device
    image_tensors = [inference_transform(img.convert("RGB")) for img in pil_images]
    batch_tensor = torch.stack(image_tensors).to(device)
    with model_lock, torch.no_grad():
        start_time = time.time()
        output_obj = model(batch_tensor)
        batch_time = time.time() - start_time
    # Hugging Face models return an output object; torchvision models return raw logits.
    if hasattr(output_obj, 'logits'):
        logits = output_obj.logits
    else:
        logits = output_obj
    results = []
    # Single-logit binary head: sigmoid gives the "messy" probability. squeeze()
    # collapses a batch of one to a scalar, so re-wrap it in a list.
    probs = torch.sigmoid(logits).squeeze().tolist()
    if not isinstance(probs, list):
        probs = [probs]
    for prob in probs:
        results.append({
            "prediction": {"clean": 1 - prob, "messy": prob},
            "metadata": {"device": str(device), "inference_ms": (batch_time * 1000) / len(pil_images)}
        })
    return results
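
# Direct in-process usage sketch (assumes a local test image at ./test.jpg):
#   img = Image.open("test.jpg")
#   predict_batch([img], use_gpu=False, model_name=DEFAULT_MODEL_NAME)
#   -> [{"prediction": {"clean": ..., "messy": ...}, "metadata": {...}}]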

app = FastAPI(title="Messy vs Clean Image Classifier API")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])

@app.post("/predict")  # route path assumed
async def predict_api(request: Base64Image, model_name: str = DEFAULT_MODEL_NAME, use_gpu: bool = True):
    try:
        pil_image = base64_to_pil(request.image_data)
        return predict_batch([pil_image], use_gpu, model_name)[0]
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

@app.post("/batch_predict")  # route path assumed
async def batch_predict_api(request: BatchBase64Images):
    try:
        pil_images = [base64_to_pil(b64) for b64 in request.image_data_list]
        return predict_batch(pil_images, request.use_gpu, request.model_name)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
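
# Example request against the endpoint above (both the port, 7860 being the
# usual Spaces default, and the route path are assumptions):
#   curl -X POST http://localhost:7860/batch_predict \
#     -H "Content-Type: application/json" \
#     -d '{"image_data_list": ["<base64>"], "use_gpu": false}'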

@app.get("/models")  # route path assumed
async def get_available_models():
    return list(CPU_MODELS.keys())

def predict_gradio(pil_image: Image.Image, model_name: str) -> Optional[dict]:
    if pil_image is None:
        return None
    # use_gpu=True silently falls back to CPU inside predict_batch when no GPU exists.
    result = predict_batch([pil_image], use_gpu=True, model_name=model_name)[0]
    return result["prediction"]

gradio_iface = gr.Interface(
    fn=predict_gradio,
    inputs=[
        gr.Image(type="pil", label="Input Image", sources=["upload", "webcam", "clipboard"]),
        gr.Dropdown(
            choices=list(CPU_MODELS.keys()),
            value=DEFAULT_MODEL_NAME,
            label="Select Model"
        )
    ],
    outputs=gr.Label(num_top_classes=2, label="Predictions"),
    title="Messy vs Clean Image Classifier",
    description="Upload an image and select a model to see its classification for 'messy' vs 'clean'. The API is available at the /docs endpoint.",
    allow_flagging="never"
)

app = gr.mount_gradio_app(app, gradio_iface, path="/")
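
# Local run sketch (module name "app" assumed; a Space typically launches this
# automatically):
#   uvicorn app:app --host 0.0.0.0 --port 7860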