"""Gradio demo app: classify a room photo into one of five room types.

Loads a trained checkpoint from disk at import time, rebuilds the matching
architecture, and serves a Gradio Interface that maps an uploaded image to
class probabilities.
"""

import os

import torch
import torch.nn as nn
import yaml  # NOTE(review): unused here — presumably kept for config tooling; confirm before removing
from torchvision import models, transforms
from PIL import Image  # NOTE(review): unused directly; Gradio hands us a PIL image already
import gradio as gr
from transformers import ConvNextV2ForImageClassification

# Checkpoint produced by the training script; must contain at least
# 'model_name' and 'state_dict' (see loading code below).
CHECKPOINT_PATH = "checkpoints/room_classifier_best.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class HFConvNeXtWrapper(nn.Module):
    """Adapter that makes a HuggingFace ConvNeXtV2 classifier look like a
    torchvision model: ``forward(x)`` returns raw logits (a tensor), not a
    HF output object."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            # The pretrained head usually has a different class count;
            # let HF re-initialize the classification layer.
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        # .logits unwraps the HF ImageClassifierOutput.
        return self.model(x).logits


def get_model(model_name, num_classes):
    """Build an *untrained* architecture matching ``model_name``.

    Weights are expected to be loaded from a checkpoint afterwards, so all
    torchvision models are created with ``weights=None``.

    Args:
        model_name: One of ``efficientnet_b0``/``efficientnet_b3``, a HF
            ConvNeXtV2 model id containing ``"convnextv2"``, or ``vit_b_16``.
        num_classes: Size of the classification head.

    Returns:
        An ``nn.Module`` whose head outputs ``num_classes`` logits.

    Raises:
        ValueError: If ``model_name`` is not recognized.
    """
    if model_name.startswith("efficientnet"):
        model = (
            models.efficientnet_b0(weights=None)
            if "b0" in model_name
            else models.efficientnet_b3(weights=None)
        )
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif "convnextv2" in model_name:
        model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model


# --- Load the trained model at import time (module-level, by design) ------
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")

print(f"Loading model from {CHECKPOINT_PATH}...")
# NOTE(review): checkpoint holds plain dicts/tensors/strings, which load
# fine under torch>=2.6's weights_only default — confirm if pickled
# custom objects are ever added to the checkpoint.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)

model_name = checkpoint['model_name']
num_classes = checkpoint.get('num_classes', 5)
class_to_idx = checkpoint.get('class_to_idx', None)

if class_to_idx:
    # Invert the mapping so we can translate prediction indices to labels.
    idx_to_class = {v: k for k, v in class_to_idx.items()}
else:
    # Fall back to the label set the training script used by convention.
    print("Warning: class_to_idx not found in checkpoint. Using default 5 classes.")
    idx_to_class = {0: 'Bathroom', 1: 'Bedroom', 2: 'Dining', 3: 'Kitchen', 4: 'Living'}

model = get_model(model_name, num_classes)
model.load_state_dict(checkpoint['state_dict'])
model.to(DEVICE)
model.eval()

# Standard ImageNet preprocessing (must match what training used).
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(pil_image):
    """Run inference on one PIL image.

    Args:
        pil_image: A PIL image (any mode; converted to RGB), or None when
            the user cleared the Gradio input.

    Returns:
        A dict mapping class label -> probability (for ``gr.Label``), or
        None when no image was provided.
    """
    if pil_image is None:
        return None
    pil_image = pil_image.convert("RGB")
    # Add the batch dimension expected by the model.
    tensor = inference_transform(pil_image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(tensor)
    probs = torch.softmax(logits, dim=1).squeeze()
    return {idx_to_class[i]: float(probs[i]) for i in range(len(probs))}


iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Room Image"),
    outputs=gr.Label(num_top_classes=5, label="Predictions"),
    title="Room Type Classifier 🏠",
    description=f"Classifies images into: {', '.join(idx_to_class.values())}",
)

if __name__ == "__main__":
    iface.launch()