# Hugging Face Spaces page residue ("Spaces: Sleeping" status banner) removed.
| import os | |
| import torch | |
| import torch.nn as nn | |
| import yaml | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import base64 | |
| import io | |
| import time | |
| import threading | |
| from typing import List, Dict, Union, Optional | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from transformers import ConvNextV2ForImageClassification | |
# Directory scanned at import time for trained model checkpoints (*.pth).
CHECKPOINT_DIR = "checkpoints"
# Training/config YAML path. NOTE(review): not read anywhere in the visible
# code (yaml is imported but unused here) — confirm it is used elsewhere.
CONFIG_PATH = "cm_config.yaml"
# Registries filled by the startup loop below:
#   MODELS: display name -> loaded, eval-mode nn.Module
#   LABELS: display name -> {class index: label string}
MODELS = {}
LABELS = {}
class HFConvNeXtWrapper(nn.Module):
    """Adapter that exposes a HuggingFace ConvNeXtV2 classifier as a plain
    ``nn.Module`` returning raw logits, matching the torchvision backbones."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        # Load pretrained weights but re-size the classification head;
        # ignore_mismatched_sizes allows swapping in a new head cleanly.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        # HF models return an output object; callers only want the logits tensor.
        return self.model(x).logits
def get_model(model_name, num_classes):
    """Build an untrained classifier backbone with a head sized for
    ``num_classes``.

    Supported names: anything starting with ``efficientnet`` (b0 unless the
    name contains "b3"... specifically b0 when it contains "b0", else b3),
    anything containing ``convnextv2`` (HuggingFace wrapper), and
    ``vit_b_16``. Raises ``ValueError`` for anything else.
    """
    if model_name.startswith("efficientnet"):
        builder = models.efficientnet_b0 if "b0" in model_name else models.efficientnet_b3
        net = builder(weights=None)
        # Replace the final linear layer to match our label count.
        in_features = net.classifier[1].in_features
        net.classifier[1] = nn.Linear(in_features, num_classes)
        return net
    if "convnextv2" in model_name:
        return HFConvNeXtWrapper(model_name, num_labels=num_classes)
    if model_name == "vit_b_16":
        net = models.vit_b_16(weights=None)
        net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
        return net
    raise ValueError(f"Unknown model: {model_name}")
# Select GPU when available; every model is loaded onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the checkpoint directory on a fresh deployment so the listdir
# below cannot fail (it will simply find no models).
if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

model_files = [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith('.pth')]
default_model_name = None
print(f"--- Loading models from {CHECKPOINT_DIR} ---")
for filename in model_files:
    path = os.path.join(CHECKPOINT_DIR, filename)
    try:
        # NOTE(review): torch.load deserializes pickle — only load
        # checkpoints produced by this project, never untrusted files.
        ckpt = torch.load(path, map_location=device)
        m_name = ckpt.get('model_name', 'efficientnet_b0')
        n_classes = ckpt.get('num_classes', 5)
        model = get_model(m_name, n_classes)
        model.load_state_dict(ckpt['state_dict'])
        model.to(device)
        model.eval()
        display_name = filename.replace('.pth', '')
        MODELS[display_name] = model
        if 'class_to_idx' in ckpt:
            # Invert the stored class->index map into index->class for output.
            LABELS[display_name] = {v: k for k, v in ckpt['class_to_idx'].items()}
        else:
            # Fallback label set (abbreviated room types) for checkpoints
            # that do not record their class mapping.
            LABELS[display_name] = {0: 'Bat', 1: 'Bed', 2: 'Din', 3: 'Kit', 4: 'Liv'}
        if default_model_name is None:
            default_model_name = display_name
        print(f"Loaded: {display_name}")
    except Exception as e:
        # Fix: report WHICH checkpoint failed (the message previously said
        # "(unknown)" even though filename is in scope), so broken files can
        # be identified from the startup log.
        print(f"Failed to load {filename}: {e}")

if not MODELS:
    print("WARNING: No models loaded. Using Dummy for build.")
    default_model_name = "dummy"
# Preprocessing applied to every inference image: fixed 224x224 resize, tensor
# conversion, then ImageNet mean/std normalization — the standard pipeline for
# ImageNet-pretrained backbones like the ones built in get_model().
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
class Base64Image(BaseModel):
    """Request body for the prediction endpoint."""
    # Base64-encoded image, either raw or a data-URI ("data:...;base64,...").
    image_data: str
    # Which loaded checkpoint to use; defaults to the first model loaded at
    # startup (the default is captured at class-definition time).
    # NOTE(review): in pydantic v2, field names starting with "model_" clash
    # with the protected namespace and warn — confirm the pinned version.
    model_name: Optional[str] = default_model_name
def base64_to_pil(base64_str: str) -> Image.Image:
    """Decode a base64 string — optionally a data-URI carrying a
    ``base64,`` prefix — into a PIL image."""
    payload = base64_str.split("base64,")[1] if "base64," in base64_str else base64_str
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw))
def run_inference(pil_image, model_key):
    """Classify ``pil_image`` with the model registered under ``model_key``.

    Returns a dict mapping class label -> probability (one entry per class).
    Raises ``ValueError`` when ``model_key`` is not a loaded model.
    """
    if model_key not in MODELS:
        raise ValueError("Model not found")
    model = MODELS[model_key]
    idx_map = LABELS[model_key]
    # Force RGB (uploads may be RGBA/greyscale), add the batch dimension,
    # and move the tensor to the model's device.
    img_tensor = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_tensor)
        # Fix: squeeze only the batch dim. A bare .squeeze() would collapse a
        # single-class output to a 0-d tensor, whose .tolist() returns a bare
        # float and breaks len(probs) below.
        probs = torch.softmax(logits, dim=1).squeeze(0).tolist()
    return {idx_map[i]: float(probs[i]) for i in range(len(probs))}
# ASGI application; the Gradio UI is mounted onto it further below.
app = FastAPI(title="Room Type Classifier API")
# NOTE(review): wildcard CORS (all origins/methods/headers) is wide open —
# acceptable for a public demo, tighten for production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Fix: without a route decorator this handler is never registered with the
# app (the decorator was presumably lost in transcription — TODO confirm the
# intended path); expose it at the API root.
@app.get("/")
def home():
    """Health/info endpoint: confirms the API is up and lists loaded models."""
    return {"message": "Room Classifier API is running", "models": list(MODELS.keys())}
# Fix: register the handler as a POST route — without a decorator it is never
# reachable. The "/predict" path is an assumption (decorator presumably lost
# in transcription) — TODO confirm the intended route.
@app.post("/predict")
def predict_api(payload: Base64Image):
    """Classify a base64-encoded image.

    Falls back to the startup default model when ``payload.model_name`` is
    empty. Any failure (malformed base64, unknown model, inference error)
    surfaces as HTTP 500 with the error text as detail.
    """
    m_name = payload.model_name if payload.model_name else default_model_name
    try:
        img = base64_to_pil(payload.image_data)
        result = run_inference(img, m_name)
        return {"model": m_name, "predictions": result}
    except Exception as e:
        # NOTE(review): client mistakes (bad model name, malformed base64)
        # also surface as 500 here; 400/404 would be more accurate statuses.
        raise HTTPException(status_code=500, detail=str(e)) from e
def predict_gradio(img, model_choice):
    """Gradio callback: classify the uploaded image with the chosen model,
    or pass None through when no image was provided."""
    return None if img is None else run_inference(img, model_choice)
# Build and mount the Gradio demo only when at least one checkpoint loaded;
# otherwise the bare FastAPI app still serves the JSON API.
if MODELS:
    gradio_iface = gr.Interface(
        fn=predict_gradio,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=list(MODELS.keys()), value=default_model_name, label="Model")
        ],
        outputs=gr.Label(num_top_classes=5),
        title="Room Type Classifier",
        description="Detects: Bathroom, Bedroom, Dining, Kitchen, Living",
        # NOTE(review): allow_flagging is deprecated in Gradio 4.x in favour
        # of flagging_mode — confirm the pinned gradio version.
        allow_flagging="never"
    )
    # Serve the UI under /gradio on the same ASGI app as the JSON API.
    app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")