Spaces:
Sleeping
Sleeping
File size: 2,966 Bytes
8317439 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import os
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification
# Path to the best checkpoint saved by training (holds state_dict + metadata
# such as model_name / num_classes / class_to_idx — see loading code below).
CHECKPOINT_PATH = "checkpoints/room_classifier_best.pth"
# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class HFConvNeXtWrapper(nn.Module):
    """Adapter exposing a HuggingFace ConvNeXtV2 classifier as a plain nn.Module.

    HuggingFace image-classification models return an output object; the
    surrounding code expects ``forward`` to yield a logits tensor directly,
    so this wrapper unwraps ``.logits``.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ignore_mismatched_sizes lets the freshly-sized num_labels head
        # replace the pretrained head without a shape-mismatch error.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        return self.model(x).logits
def get_model(model_name, num_classes):
    """Build an untrained backbone for ``model_name`` with a ``num_classes``-way head.

    Supported families: torchvision EfficientNet-B0/B3, HuggingFace
    ConvNeXtV2 (via HFConvNeXtWrapper), and torchvision ViT-B/16.

    Raises:
        ValueError: if ``model_name`` matches none of the known families.
    """
    if model_name.startswith("efficientnet"):
        factory = models.efficientnet_b0 if "b0" in model_name else models.efficientnet_b3
        net = factory(weights=None)
        # Swap the final classifier layer for a fresh num_classes head.
        in_dim = net.classifier[1].in_features
        net.classifier[1] = nn.Linear(in_dim, num_classes)
        return net
    if "convnextv2" in model_name:
        return HFConvNeXtWrapper(model_name, num_labels=num_classes)
    if model_name == "vit_b_16":
        net = models.vit_b_16(weights=None)
        net.heads.head = nn.Linear(net.heads.head.in_features, num_classes)
        return net
    raise ValueError(f"Unknown model: {model_name}")
# --- Load the trained model once at import time (Gradio serves it below). ---
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")
print(f"Loading model from {CHECKPOINT_PATH}...")
# NOTE(review): torch.load unpickles arbitrary objects; consider
# weights_only=True if the checkpoint contains only tensors/primitives —
# confirm against how training saved it.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
# Checkpoint metadata: which architecture to rebuild and how many classes.
model_name = checkpoint['model_name']
num_classes = checkpoint.get('num_classes', 5)
class_to_idx = checkpoint.get('class_to_idx', None)
if class_to_idx:
    # Invert the training-time mapping so we can turn indices back into names.
    idx_to_class = {v: k for k, v in class_to_idx.items()}
else:
    print("Warning: class_to_idx not found in checkpoint. Using default 5 classes.")
    idx_to_class = {0: 'Bathroom', 1: 'Bedroom', 2: 'Dining', 3: 'Kitchen', 4: 'Living'}
# Rebuild the architecture, restore the trained weights, and freeze for inference.
model = get_model(model_name, num_classes)
model.load_state_dict(checkpoint['state_dict'])
model.to(DEVICE)
model.eval()
# Preprocessing for inference: fixed 224x224 resize plus ImageNet mean/std
# normalization (the standard stats for ImageNet-pretrained backbones).
# Presumably mirrors the training pipeline — confirm against the trainer.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict(pil_image):
    """Classify one PIL image; return {class_name: probability} for gr.Label.

    Returns None when Gradio passes no image (cleared input).
    """
    if pil_image is None:
        return None
    rgb = pil_image.convert("RGB")
    # Preprocess, add the batch dimension, and move to the model's device.
    batch = inference_transform(rgb).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        scores = torch.softmax(model(batch), dim=1).squeeze()
    return {idx_to_class[idx]: float(p) for idx, p in enumerate(scores)}
# Gradio UI: single image input -> per-class probability label output.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload Room Image"),
    outputs=gr.Label(num_top_classes=5, label="Predictions"),
    title="Room Type Classifier 🏠",
    description=f"Classifies images into: {', '.join(idx_to_class.values())}",
)
if __name__ == "__main__":
    iface.launch()
|