import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr
import base64
import io
from typing import Optional
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import ConvNextV2ForImageClassification

CHECKPOINT_DIR = "checkpoints"
CONFIG_PATH = "cm_config.yaml"  # not read by this script

# Registries keyed by checkpoint filename (minus the .pth extension).
MODELS = {}
LABELS = {}

class HFConvNeXtWrapper(nn.Module):
    """Wraps a Hugging Face ConvNeXt V2 classifier so it exposes a plain
    tensor-in / logits-out interface like the torchvision models."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name, num_labels=num_labels, ignore_mismatched_sizes=True)

    def forward(self, x):
        return self.model(x).logits

def get_model(model_name, num_classes):
    """Build an architecture by name and swap in a num_classes-way head."""
    if model_name.startswith("efficientnet"):
        model = models.efficientnet_b0(weights=None) if "b0" in model_name else models.efficientnet_b3(weights=None)
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif "convnextv2" in model_name:
        model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model
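
# Example (model ids illustrative): get_model("efficientnet_b0", 5) builds an
# EfficientNet-B0 with a 5-way head, while any Hugging Face id containing
# "convnextv2" routes through HFConvNeXtWrapper.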

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

model_files = [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith('.pth')]
default_model_name = None

print(f"--- Loading models from {CHECKPOINT_DIR} ---")
for filename in model_files:
    path = os.path.join(CHECKPOINT_DIR, filename)
    try:
        ckpt = torch.load(path, map_location=device)
        m_name = ckpt.get('model_name', 'efficientnet_b0')
        n_classes = ckpt.get('num_classes', 5)
        
        model = get_model(m_name, n_classes)
        model.load_state_dict(ckpt['state_dict'])
        model.to(device)
        model.eval()
        
        display_name = os.path.splitext(filename)[0]
        MODELS[display_name] = model

        # Invert class_to_idx so output indices map back to label names.
        if 'class_to_idx' in ckpt:
            LABELS[display_name] = {v: k for k, v in ckpt['class_to_idx'].items()}
        else:
            # Fallback for checkpoints saved without class metadata.
            LABELS[display_name] = {0: 'Bathroom', 1: 'Bedroom', 2: 'Dining', 3: 'Kitchen', 4: 'Living'}

        if default_model_name is None:
            default_model_name = display_name
        print(f"Loaded: {display_name}")
        
    except Exception as e:
        print(f"Failed to load {filename}: {e}")

if not MODELS:
    # No real fallback model is registered, so inference will fail until a
    # checkpoint is dropped into CHECKPOINT_DIR.
    print("WARNING: No models loaded. Using dummy default name for build.")
    default_model_name = "dummy"

# Standard ImageNet preprocessing: 224x224 input, per-channel mean/std normalization.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

class Base64Image(BaseModel):
    """Request body for /predict: base64 image data plus an optional model key."""
    image_data: str
    model_name: Optional[str] = default_model_name

def base64_to_pil(base64_str: str) -> Image.Image:
    # Accept both raw base64 and full data URLs ("data:image/png;base64,...").
    if "base64," in base64_str:
        base64_str = base64_str.split("base64,")[1]
    return Image.open(io.BytesIO(base64.b64decode(base64_str)))

def run_inference(pil_image, model_key):
    """Classify a PIL image and return {label: probability} for every class."""
    if model_key not in MODELS:
        raise ValueError(f"Model not found: {model_key}")

    model = MODELS[model_key]
    idx_map = LABELS[model_key]

    # Preprocess, add a batch dimension, and move to the model's device.
    img_tensor = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(img_tensor)
        # squeeze(0) drops only the batch dimension, never the class dimension.
        probs = torch.softmax(logits, dim=1).squeeze(0).tolist()

    return {idx_map[i]: float(probs[i]) for i in range(len(probs))}

app = FastAPI(title="Room Type Classifier API")
# Wide-open CORS so browser clients from any origin can call the API.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

@app.get("/")
def home():
    return {"message": "Room Classifier API is running", "models": list(MODELS.keys())}

@app.post("/predict")
def predict_api(payload: Base64Image):
    m_name = payload.model_name if payload.model_name else default_model_name
    try:
        img = base64_to_pil(payload.image_data)
        result = run_inference(img, m_name)
        return {"model": m_name, "predictions": result}
    except ValueError as e:
        # Unknown model name or similar caller error.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
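
# Example request (illustrative: assumes the server runs on uvicorn's default
# localhost:8000 and that a checkpoint named "effnet_b0.pth" was loaded):
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"image_data": "<base64 string>", "model_name": "effnet_b0"}'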

def predict_gradio(img, model_choice):
    if img is None:
        return None
    return run_inference(img, model_choice)

# Mount a simple Gradio UI at /gradio alongside the JSON API, but only when
# at least one checkpoint loaded successfully.
if MODELS:
    gradio_iface = gr.Interface(
        fn=predict_gradio,
        inputs=[
            gr.Image(type="pil", label="Image"),
            gr.Dropdown(choices=list(MODELS.keys()), value=default_model_name, label="Model")
        ],
        outputs=gr.Label(num_top_classes=5),
        title="Room Type Classifier",
        description="Detects: Bathroom, Bedroom, Dining, Kitchen, Living",
        allow_flagging="never"
    )
    app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")
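
# Minimal local entry point (a sketch: assumes uvicorn is installed; port 7860
# is Gradio's usual default and may need adjusting for your deployment).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)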