File size: 2,966 Bytes
8317439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import torch
import torch.nn as nn
import yaml
from torchvision import models, transforms
from PIL import Image
import gradio as gr
from transformers import ConvNextV2ForImageClassification

CHECKPOINT_PATH = "checkpoints/room_classifier_best.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class HFConvNeXtWrapper(nn.Module):
    """Adapt a Hugging Face ConvNeXt V2 classifier so calling it yields a logits tensor.

    ``ConvNextV2ForImageClassification`` returns an output object rather than a
    tensor; this wrapper unwraps ``.logits`` so callers can feed the result
    straight into ``torch.softmax`` like the torchvision models.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        # ignore_mismatched_sizes allows the classification head to be created
        # with `num_labels` outputs even if the pretrained head has another size.
        # NOTE: the attribute name `model` fixes the "model." prefix of the
        # state_dict keys, so it must match how the checkpoint was saved.
        self.model = ConvNextV2ForImageClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        """Run the wrapped model on ``x`` and return only its ``logits``."""
        outputs = self.model(x)
        return outputs.logits

def get_model(model_name, num_classes):
    """Build a classifier architecture with a ``num_classes``-way output head.

    The architecture is chosen from ``model_name``:

    * names starting with ``"efficientnet"``: torchvision EfficientNet-B0 if
      the name contains ``"b0"``, EfficientNet-B3 if it contains ``"b3"``
      (built with ``weights=None``), with ``classifier[1]`` replaced;
    * names containing ``"convnextv2"``: :class:`HFConvNeXtWrapper`, which
      passes ``model_name`` to ``from_pretrained``;
    * exactly ``"vit_b_16"``: torchvision ViT-B/16 (``weights=None``) with
      ``heads.head`` replaced.

    Args:
        model_name: Architecture identifier (as stored in the checkpoint).
        num_classes: Number of output classes for the new head.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``model_name`` matches none of the supported
            architectures, including EfficientNet names that are neither
            B0 nor B3 (previously these silently became B3).
    """
    if model_name.startswith("efficientnet"):
        # Pick the variant explicitly; defaulting every non-b0 name to B3
        # would build the wrong network and fail later in load_state_dict.
        if "b0" in model_name:
            model = models.efficientnet_b0(weights=None)
        elif "b3" in model_name:
            model = models.efficientnet_b3(weights=None)
        else:
            raise ValueError(
                f"Unknown model: {model_name} (supported EfficientNet variants: b0, b3)"
            )
        num_ftrs = model.classifier[1].in_features
        model.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif "convnextv2" in model_name:
        model = HFConvNeXtWrapper(model_name, num_labels=num_classes)
    elif model_name == "vit_b_16":
        model = models.vit_b_16(weights=None)
        model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    else:
        raise ValueError(f"Unknown model: {model_name}")
    return model

# --- Load the trained checkpoint (runs at import time) ----------------------
if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CHECKPOINT_PATH}")

print(f"Loading model from {CHECKPOINT_PATH}...")
# NOTE(review): depending on the PyTorch version / weights_only setting,
# torch.load may unpickle arbitrary objects -- only load trusted checkpoints.
# If this checkpoint holds only tensors and plain containers, consider
# passing weights_only=True explicitly.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
model_name = checkpoint['model_name']

class_to_idx = checkpoint.get('class_to_idx', None)
# An explicit 'num_classes' wins; otherwise infer it from the label mapping so
# the head size matches the saved weights, falling back to 5 as before.
num_classes = checkpoint.get('num_classes', len(class_to_idx) if class_to_idx else 5)

if class_to_idx:
    idx_to_class = {v: k for k, v in class_to_idx.items()}
else:
    print("Warning: class_to_idx not found in checkpoint. Using default 5 classes.")
    idx_to_class = {0: 'Bathroom', 1: 'Bedroom', 2: 'Dining', 3: 'Kitchen', 4: 'Living'}

# Ensure every output index has a name; otherwise a class count larger than
# the label mapping would make prediction lookups fail with KeyError.
_unlabeled = [i for i in range(num_classes) if i not in idx_to_class]
if _unlabeled:
    print(f"Warning: no class names for indices {_unlabeled}; using placeholders.")
    idx_to_class.update({i: f"class_{i}" for i in _unlabeled})

model = get_model(model_name, num_classes)
model.load_state_dict(checkpoint['state_dict'])
model.to(DEVICE)
model.eval()  # inference mode (disables dropout, uses running BN stats)

# Preprocessing applied to each uploaded image: fixed 224x224 resize (aspect
# ratio is not preserved), conversion to a float tensor, then per-channel
# normalization with the commonly used ImageNet mean/std.
# NOTE(review): presumably mirrors the training-time eval transform -- confirm.
inference_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def predict(pil_image):
    """Classify one image and return per-class probabilities.

    Args:
        pil_image: A PIL image (from ``gr.Image(type="pil")``) or ``None``.

    Returns:
        ``None`` when ``pil_image`` is ``None``; otherwise a dict mapping each
        class name to its softmax probability as a Python ``float`` -- the
        mapping format ``gr.Label`` displays.
    """
    if pil_image is None:
        return None

    # Force 3 channels so grayscale/RGBA/palette uploads match the
    # 3-channel Normalize in the transform.
    batch = inference_transform(pil_image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)
    # Take row 0 of the single-image batch instead of squeeze(): squeeze()
    # would turn a one-class output into a 0-d tensor and break iteration.
    probabilities = torch.softmax(logits, dim=1)[0].tolist()

    # Use a placeholder name for any index missing from the label map rather
    # than failing the whole request with KeyError.
    return {
        idx_to_class.get(index, f"class_{index}"): prob
        for index, prob in enumerate(probabilities)
    }

# --- Gradio UI ----------------------------------------------------------------
# One image upload in; predict()'s name->probability dict out, rendered as a
# label widget showing the top 5 classes.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(
        type="pil",
        label="Upload Room Image",
    ),
    outputs=gr.Label(
        num_top_classes=5,
        label="Predictions",
    ),
    title="Room Type Classifier 🏠",
    description="Classifies images into: " + ", ".join(idx_to_class.values()),
)

if __name__ == "__main__":
    # Launch the web app only when executed as a script, not when imported.
    iface.launch()