# room_classifier / app.py
# Nightfury16 — Initial commit (8317439)
import os
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification
# Path to the fine-tuned classifier weights; startup fails fast if missing (see below).
CHECKPOINT_PATH = "checkpoints/room_classifier_best.pth"
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class HFConvNeXtWrapper(nn.Module):
    """Adapt a HuggingFace ConvNeXtV2 classifier to the torchvision-style
    calling convention: ``model(x)`` returns a logits tensor rather than
    a model-output object."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ignore_mismatched_sizes lets from_pretrained replace the original
        # classification head with a freshly initialized num_labels-way head.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        outputs = self.model(x)
        return outputs.logits
def get_model(model_name, num_classes):
    """Build an architecture named by *model_name* with a fresh
    *num_classes*-way classification head (no pretrained torchvision weights).

    Supported: "efficientnet_b0" / "efficientnet_b3" (any other name starting
    with "efficientnet" also maps to b3), any name containing "convnextv2"
    (loaded through the HuggingFace wrapper), and "vit_b_16".

    Raises:
        ValueError: if *model_name* matches none of the supported families.
    """
    if model_name.startswith("efficientnet"):
        factory = models.efficientnet_b0 if "b0" in model_name else models.efficientnet_b3
        net = factory(weights=None)
        # Swap the final linear layer for one sized to our label set.
        in_features = net.classifier[1].in_features
        net.classifier[1] = nn.Linear(in_features, num_classes)
        return net

    if "convnextv2" in model_name:
        return HFConvNeXtWrapper(model_name, num_labels=num_classes)

    if model_name == "vit_b_16":
        net = models.vit_b_16(weights=None)
        net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
        return net

    raise ValueError(f"Unknown model: {model_name}")
# ---- One-time model loading (runs at import/startup) ----
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")
print(f"Loading model from {CHECKPOINT_PATH}...")
# NOTE(review): torch.load unpickles arbitrary objects — acceptable only
# because the checkpoint ships with the app; consider weights_only=True
# if the installed torch version supports it.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model_name = checkpoint['model_name']  # required key: selects the architecture
num_classes = checkpoint.get('num_classes', 5)
class_to_idx = checkpoint.get('class_to_idx', None)
if class_to_idx:
    # Invert the training-time label mapping: output index -> class name.
    idx_to_class = {v: k for k, v in class_to_idx.items()}
else:
    print("Warning: class_to_idx not found in checkpoint. Using default 5 classes.")
    idx_to_class = {0: 'Bathroom', 1: 'Bedroom', 2: 'Dining', 3: 'Kitchen', 4: 'Living'}
model = get_model(model_name, num_classes)
model.load_state_dict(checkpoint['state_dict'])
model.to(DEVICE)
model.eval()  # inference mode: freezes dropout / batch-norm statistics
# Preprocessing for every uploaded image: fixed 224x224 resize, then
# normalization with the standard ImageNet mean/std — presumably the same
# transform used at training time (TODO confirm against the training script).
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict(pil_image):
    """Classify *pil_image* and return {class_name: probability} for gr.Label.

    Returns None when no image was provided (Gradio passes None on clear).
    """
    if pil_image is None:
        return None
    pil_image = pil_image.convert("RGB")  # drop alpha / grayscale -> 3 channels
    tensor = inference_transform(pil_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(tensor)
        # squeeze(0), not squeeze(): a bare squeeze would collapse a (1, 1)
        # logits tensor to 0-D and break iteration for a 1-class head.
        probs = torch.softmax(logits, dim=1).squeeze(0)
    # Single .tolist() call instead of per-index tensor item extraction.
    return {idx_to_class[i]: float(p) for i, p in enumerate(probs.tolist())}
# Gradio UI wiring: single image input, ranked label output (all 5 classes shown).
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Room Image"),
    outputs=gr.Label(num_top_classes=5, label="Predictions"),
    title="Room Type Classifier 🏠",
    description=f"Classifies images into: {', '.join(idx_to_class.values())}",
)
if __name__ == "__main__":
    # Start the local Gradio server (blocking call).
    iface.launch()