Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import yaml | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import gradio as gr | |
| from transformers import ConvNextV2ForImageClassification | |
# Path to the best checkpoint saved by training; the loading code below reads
# 'model_name' and 'state_dict' from it, plus optional 'num_classes' /
# 'class_to_idx' metadata.
CHECKPOINT_PATH = "checkpoints/room_classifier_best.pth"
# Run inference on GPU when one is available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class HFConvNeXtWrapper(nn.Module):
    """Adapter that makes a Hugging Face ConvNeXtV2 classifier look like a
    torchvision model: forward(x) returns a plain logits tensor instead of a
    HF output object."""

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ignore_mismatched_sizes lets the pretrained head be replaced with a
        # freshly initialized num_labels-way classification head.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        # The HF model returns an output dataclass; expose only its logits.
        return self.model(x).logits
def get_model(model_name, num_classes):
    """Build an untrained backbone selected by `model_name` with its
    classification head resized to `num_classes` outputs.

    Raises ValueError for an unrecognized model name.
    """
    if model_name.startswith("efficientnet"):
        # "b0" anywhere in the name selects the B0 variant, otherwise B3.
        backbone = (
            models.efficientnet_b0(weights=None)
            if "b0" in model_name
            else models.efficientnet_b3(weights=None)
        )
        in_features = backbone.classifier[1].in_features
        backbone.classifier[1] = nn.Linear(in_features, num_classes)
        return backbone

    if "convnextv2" in model_name:
        # Hugging Face checkpoint name; the wrapper normalizes its output.
        return HFConvNeXtWrapper(model_name, num_labels=num_classes)

    if model_name == "vit_b_16":
        backbone = models.vit_b_16(weights=None)
        backbone.heads.head = nn.Linear(backbone.heads.head.in_features, num_classes)
        return backbone

    raise ValueError(f"Unknown model: {model_name}")
# --- Load the trained checkpoint and rebuild the model -----------------------
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")

print(f"Loading model from {CHECKPOINT_PATH}...")
# weights_only=False is required because the checkpoint stores metadata
# (model name, class map) alongside the state_dict; PyTorch >= 2.6 defaults
# to weights_only=True, which would reject such payloads. The file is a
# trusted local artifact — never load untrusted checkpoints this way.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=False)

model_name = checkpoint['model_name']
class_to_idx = checkpoint.get('class_to_idx', None)
if class_to_idx:
    # Invert name->index into index->name for prediction display.
    idx_to_class = {v: k for k, v in class_to_idx.items()}
else:
    print("Warning: class_to_idx not found in checkpoint. Using default 5 classes.")
    idx_to_class = {0: 'Bathroom', 1: 'Bedroom', 2: 'Dining', 3: 'Kitchen', 4: 'Living'}

# Prefer the explicit count stored in the checkpoint; otherwise derive it from
# the class map so the classifier head size always matches the label set
# (the old hard-coded fallback of 5 could disagree with class_to_idx).
num_classes = checkpoint.get('num_classes', len(idx_to_class))

model = get_model(model_name, num_classes)
model.load_state_dict(checkpoint['state_dict'])
model.to(DEVICE)
model.eval()  # disable dropout/batch-norm updates for inference
# Preprocessing pipeline: fixed 224x224 resize, tensor conversion, then
# normalization with the standard ImageNet channel mean/std.
# NOTE(review): presumably mirrors the validation transform used at training
# time — confirm against the training code.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict(pil_image):
    """Classify one PIL image.

    Returns a {class_name: probability} dict (the format gr.Label expects),
    or None when no image was provided.
    """
    if pil_image is None:
        return None

    rgb = pil_image.convert("RGB")
    batch = inference_transform(rgb).unsqueeze(0).to(DEVICE)

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        scores = torch.softmax(model(batch), dim=1).squeeze()

    return {idx_to_class[idx]: float(p) for idx, p in enumerate(scores)}
# Gradio UI: one uploaded image in, per-class probabilities out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Room Image"),
    outputs=gr.Label(num_top_classes=5, label="Predictions"),
    title="Room Type Classifier 🏠",
    description=f"Classifies images into: {', '.join(idx_to_class.values())}",
)

if __name__ == "__main__":
    # Start the local Gradio server only when run as a script
    # (not when imported, e.g. by a Spaces runner).
    iface.launch()