"""Room-type classifier service: loads .pth checkpoints from CHECKPOINT_DIR,
serves predictions via a FastAPI JSON endpoint and a mounted Gradio UI."""

import os
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import base64
import io
import time
import threading
from typing import List, Dict, Union, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import ConvNextV2ForImageClassification

CHECKPOINT_DIR = "checkpoints"
CONFIG_PATH = "cm_config.yaml"
MODELS = {}   # display_name -> loaded nn.Module (eval mode, on `device`)
LABELS = {}   # display_name -> {class_index: class_name}


class HFConvNeXtWrapper(nn.Module):
    """Adapter exposing a HuggingFace ConvNeXtV2 classifier as a plain
    nn.Module that returns raw logits, matching the torchvision models' API.
    """

    def __init__(self, model_name, num_labels):
        super(HFConvNeXtWrapper, self).__init__()
        # ignore_mismatched_sizes lets us swap in a fresh num_labels head.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name, num_labels=num_labels, ignore_mismatched_sizes=True)

    def forward(self, x):
        # HF models return an output object; callers only need the logits.
        return self.model(x).logits


def get_model(model_name, num_classes):
    """Build an untrained architecture named by `model_name` with a
    `num_classes`-way classification head, ready for checkpoint weights.

    Raises:
        ValueError: if `model_name` matches no supported architecture.
    """
    if model_name.startswith("efficientnet"):
        model = models.efficientnet_b0(weights=None) if "b0" in model_name else models.efficientnet_b3(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif "convnextv2" in model_name:
        model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

model_files = [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith('.pth')]
default_model_name = None

print(f"--- Loading models from {CHECKPOINT_DIR} ---")
for filename in model_files:
    path = os.path.join(CHECKPOINT_DIR, filename)
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (consider weights_only=True where possible).
        ckpt = torch.load(path, map_location=device)
        m_name = ckpt.get('model_name', 'efficientnet_b0')
        n_classes = ckpt.get('num_classes', 5)
        model = get_model(m_name, n_classes)
        model.load_state_dict(ckpt['state_dict'])
        model.to(device)
        model.eval()
        display_name = filename.replace('.pth', '')
        MODELS[display_name] = model
        if 'class_to_idx' in ckpt:
            # Invert class_name -> index into index -> class_name for readout.
            LABELS[display_name] = {v: k for k, v in ckpt['class_to_idx'].items()}
        else:
            LABELS[display_name] = {0: 'Bat', 1: 'Bed', 2: 'Din', 3: 'Kit', 4: 'Liv'}
        if default_model_name is None:
            default_model_name = display_name
        print(f"Loaded: {display_name}")
    except Exception as e:
        # BUG FIX: report which file failed instead of the literal "(unknown)".
        print(f"Failed to load {filename}: {e}")

if not MODELS:
    print("WARNING: No models loaded. Using Dummy for build.")
    default_model_name = "dummy"

# Standard ImageNet preprocessing; all supported backbones expect 224x224.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


class Base64Image(BaseModel):
    """Request body for /predict: a base64-encoded image plus an optional
    model name (defaults to the first successfully loaded checkpoint)."""
    image_data: str
    model_name: Optional[str] = default_model_name


def base64_to_pil(base64_str: str) -> Image.Image:
    """Decode a base64 string (with or without a data-URL prefix) to a PIL image."""
    if "base64," in base64_str:
        # Strip a "data:image/...;base64," data-URL prefix if present.
        base64_str = base64_str.split("base64,")[1]
    return Image.open(io.BytesIO(base64.b64decode(base64_str)))


def run_inference(pil_image, model_key):
    """Classify `pil_image` with MODELS[model_key].

    Returns:
        dict mapping class name -> softmax probability.

    Raises:
        ValueError: if `model_key` is not a loaded model.
    """
    if model_key not in MODELS:
        raise ValueError("Model not found")
    model = MODELS[model_key]
    idx_map = LABELS[model_key]
    img_tensor = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_tensor)
        probs = torch.softmax(logits, dim=1).squeeze().tolist()
    return {idx_map[i]: float(probs[i]) for i in range(len(probs))}


app = FastAPI(title="Room Type Classifier API")
# Open CORS so browser front-ends on any origin can call the API.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


@app.get("/")
def home():
    """Health check plus the list of available model names."""
    return {"message": "Room Classifier API is running", "models": list(MODELS.keys())}


@app.post("/predict")
def predict_api(payload: Base64Image):
    """Classify a base64 image; any failure surfaces as HTTP 500 with detail."""
    m_name = payload.model_name if payload.model_name else default_model_name
    try:
        img = base64_to_pil(payload.image_data)
        result = run_inference(img, m_name)
        return {"model": m_name, "predictions": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def predict_gradio(img, model_choice):
    """Gradio callback: None image -> None, else class-probability dict."""
    if img is None:
        return None
    return run_inference(img, model_choice)


# Mount the Gradio UI only when at least one real model is available.
if MODELS:
    gradio_iface = gr.Interface(
        fn=predict_gradio,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=list(MODELS.keys()), value=default_model_name, label="Model")
        ],
        outputs=gr.Label(num_top_classes=5),
        title="Room Type Classifier",
        description="Detects: Bathroom, Bedroom, Dining, Kitchen, Living",
        allow_flagging="never"
    )
    app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")