# room_classifier / main.py
# Author: Nightfury16 — "Initial commit" (8317439)
import os
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import base64
import io
import time
import threading
from typing import List, Dict, Union, Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import ConvNextV2ForImageClassification
# Directory scanned at startup for *.pth model checkpoints.
CHECKPOINT_DIR = "checkpoints"
# Path to a YAML model config; NOTE(review): never read in this file — confirm it is used elsewhere.
CONFIG_PATH = "cm_config.yaml"
# Registries filled by the startup loading loop below:
#   MODELS: display name -> loaded, eval-mode nn.Module
#   LABELS: display name -> {class index: class label}
MODELS = {}
LABELS = {}
class HFConvNeXtWrapper(nn.Module):
    """Adapter exposing a Hugging Face ConvNeXt V2 classifier through the
    plain tensor-in / logits-out interface the rest of this file expects.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # Re-head the pretrained backbone for our label count; mismatched
        # classifier-layer sizes are silently replaced rather than raising.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        # HF models return an output object; downstream code wants the raw logits.
        return self.model(x).logits
def get_model(model_name, num_classes):
    """Build an untrained architecture matching a checkpoint's recorded name.

    Args:
        model_name: architecture identifier stored in the checkpoint
            ("efficientnet_b0"/"efficientnet_b3", a ConvNeXt V2 HF id,
            or "vit_b_16").
        num_classes: output dimension of the replacement classifier head.

    Returns:
        An nn.Module with a freshly initialized head (no pretrained weights
        for the torchvision paths — weights come from the checkpoint later).

    Raises:
        ValueError: if model_name matches none of the known architectures.
    """
    if model_name.startswith("efficientnet"):
        if "b0" in model_name:
            backbone = models.efficientnet_b0(weights=None)
        else:
            backbone = models.efficientnet_b3(weights=None)
        in_features = backbone.classifier[1].in_features
        backbone.classifier[1] = nn.Linear(in_features, num_classes)
        return backbone

    if "convnextv2" in model_name:
        return HFConvNeXtWrapper(model_name, num_labels=num_classes)

    if model_name == "vit_b_16":
        backbone = models.vit_b_16(weights=None)
        backbone.heads.head = nn.Linear(backbone.heads.head.in_features, num_classes)
        return backbone

    raise ValueError(f"Unknown model: {model_name}")
# --- Device selection and checkpoint discovery (runs once, at import time) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure the checkpoint directory exists so the listdir below cannot fail.
if not os.path.exists(CHECKPOINT_DIR):
    os.makedirs(CHECKPOINT_DIR)

model_files = [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith('.pth')]
default_model_name = None

print(f"--- Loading models from {CHECKPOINT_DIR} ---")
for filename in model_files:
    path = os.path.join(CHECKPOINT_DIR, filename)
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (weights_only=True would be safer,
        # but requires plain-tensor checkpoints and a recent torch).
        ckpt = torch.load(path, map_location=device)
        # Checkpoints record their architecture and class count; fall back to
        # the historical defaults for older checkpoints that omit them.
        m_name = ckpt.get('model_name', 'efficientnet_b0')
        n_classes = ckpt.get('num_classes', 5)
        model = get_model(m_name, n_classes)
        model.load_state_dict(ckpt['state_dict'])
        model.to(device)
        model.eval()
        display_name = filename.replace('.pth', '')
        MODELS[display_name] = model
        if 'class_to_idx' in ckpt:
            # Invert class->index into index->label for prediction output.
            LABELS[display_name] = {v: k for k, v in ckpt['class_to_idx'].items()}
        else:
            LABELS[display_name] = {0: 'Bat', 1: 'Bed', 2: 'Din', 3: 'Kit', 4: 'Liv'}
        # First successfully loaded checkpoint becomes the API/UI default.
        if default_model_name is None:
            default_model_name = display_name
        print(f"Loaded: {display_name}")
    except Exception as e:
        # Bug fix: the original message printed the literal "(unknown)" instead
        # of naming the checkpoint that failed, making failures undiagnosable.
        print(f"Failed to load {filename}: {e}")

if not MODELS:
    print("WARNING: No models loaded. Using Dummy for build.")
    default_model_name = "dummy"
# Preprocessing applied to every incoming image before the forward pass:
# fixed 224x224 resize, tensor conversion, ImageNet mean/std normalization.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
class Base64Image(BaseModel):
    """Request body for POST /predict: a base64-encoded image plus an optional model key."""
    # Raw base64 payload; a data-URL prefix ("data:image/...;base64,") is tolerated.
    image_data: str
    # Key into the MODELS registry; defaults to the first checkpoint loaded at startup.
    # NOTE(review): the default is captured once at class-definition time.
    model_name: Optional[str] = default_model_name
def base64_to_pil(base64_str: str) -> Image.Image:
    """Decode a base64 string (optionally data-URL-prefixed) into a PIL image.

    Anything up to and including the first "base64," marker is discarded, so
    both bare payloads and "data:image/png;base64,..." strings are accepted.
    """
    _, marker, tail = base64_str.partition("base64,")
    payload = tail if marker else base64_str
    return Image.open(io.BytesIO(base64.b64decode(payload)))
def run_inference(pil_image, model_key):
    """Classify a PIL image with the named model.

    Args:
        pil_image: any PIL image; converted to RGB before preprocessing.
        model_key: key into the module-level MODELS/LABELS registries.

    Returns:
        dict mapping class label -> softmax probability (one entry per class).

    Raises:
        ValueError: if model_key is not a loaded model.
    """
    if model_key not in MODELS:
        raise ValueError("Model not found")
    model = MODELS[model_key]
    idx_map = LABELS[model_key]
    img_tensor = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_tensor)
        # Bug fix: index the batch dimension explicitly instead of squeeze().
        # squeeze() also drops the class dimension when num_classes == 1,
        # collapsing probs to a scalar and breaking len(probs) below.
        probs = torch.softmax(logits, dim=1)[0].tolist()
    return {idx_map[i]: float(probs[i]) for i in range(len(probs))}
# FastAPI application; the Gradio UI is mounted onto it under /gradio further below.
app = FastAPI(title="Room Type Classifier API")
# NOTE(review): wildcard CORS (any origin/method/header) — acceptable for a public
# demo Space, but tighten before exposing anything sensitive.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
@app.get("/")
def home():
    """Health/info endpoint: service banner plus the loaded model names."""
    return {"message": "Room Classifier API is running", "models": list(MODELS)}
@app.post("/predict")
def predict_api(payload: Base64Image):
    """JSON prediction endpoint: decode the base64 image and run the chosen model.

    Falls back to the startup default model when the payload names none.
    Any failure (bad base64, unknown model, inference error) becomes HTTP 500.
    """
    chosen = payload.model_name or default_model_name
    try:
        pil_img = base64_to_pil(payload.image_data)
        predictions = run_inference(pil_img, chosen)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return {"model": chosen, "predictions": predictions}
def predict_gradio(img, model_choice):
    """Gradio callback: forward the uploaded image to run_inference.

    Returns None when no image was provided, which keeps the Label output empty.
    """
    if img is None:
        return None
    return run_inference(img, model_choice)
# Build and mount the interactive demo UI only when at least one checkpoint
# loaded; otherwise the app still serves the JSON API (with no usable models).
if MODELS:
    gradio_iface = gr.Interface(
        fn=predict_gradio,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=list(MODELS.keys()), value=default_model_name, label="Model")
        ],
        outputs=gr.Label(num_top_classes=5),
        title="Room Type Classifier",
        description="Detects: Bathroom, Bedroom, Dining, Kitchen, Living",
        allow_flagging="never"  # NOTE(review): deprecated in Gradio 4.x (renamed flagging_mode) — confirm installed version
    )
    # Serve the UI at /gradio alongside the FastAPI routes on the same app.
    app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")