File size: 2,364 Bytes
c2ccc93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3542a05
 
c2ccc93
3542a05
c2ccc93
 
 
 
 
 
 
 
 
 
 
3542a05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58d299f
c2ccc93
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
import numpy as np
from PIL import Image
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from huggingface_hub import hf_hub_download
import os

# The model predicts a single category.
ROOM_CLASSES = ["room"]

# Location of the fine-tuned checkpoint on the Hugging Face Hub.
HF_REPO = "TallManager267/SG_Room_Segmentation"
WEIGHTS_FILE = "sg_room_segmentation_726_4000itr_0.004lr.pth"

# Ensure a local "model" directory exists, then fetch the checkpoint.
# NOTE(review): the directory is not passed to hf_hub_download, so the file
# presumably lands in the hub cache instead -- confirm whether "model" is needed.
os.makedirs("model", exist_ok=True)
weights_path = hf_hub_download(
    repo_id=HF_REPO,
    filename=WEIGHTS_FILE,
)

def load_model(weights_path, device="cpu", score_thresh=0.5):
    """Build a Mask R-CNN predictor and the metadata used to visualise it.

    The configuration starts from the model-zoo
    ``mask_rcnn_R_101_FPN_3x`` instance-segmentation config and is then
    pointed at the given checkpoint.

    Args:
        weights_path: Path to the checkpoint file to load.
        device: Value assigned to ``cfg.MODEL.DEVICE`` (defaults to "cpu").
        score_thresh: Value assigned to
            ``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST`` (defaults to 0.5).

    Returns:
        A ``(predictor, metadata)`` tuple: the ``DefaultPredictor`` built from
        the configuration and the "room_metadata" catalog entry.
    """
    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
        )
    )
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.DEVICE = device
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    # Derive the head size from the class list so the two cannot drift apart.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(ROOM_CLASSES)

    # Metadata for visualization
    metadata = MetadataCatalog.get("room_metadata")
    metadata.set(thing_classes=ROOM_CLASSES)

    predictor = DefaultPredictor(cfg)
    return predictor, metadata

# Build the predictor once at import time; predict() reuses these globals.
predictor, metadata = load_model(weights_path)

import random

def predict(pil_img):
    """Segment rooms in a floor plan and overlay the predicted masks.

    Args:
        pil_img: PIL image supplied by the Gradio input, or ``None`` when the
            form was submitted without an upload.

    Returns:
        A PIL RGB image showing the input with every predicted instance mask
        drawn on top in a random colour, or ``None`` if no image was given.
    """
    # Nothing was uploaded: return an empty output instead of crashing.
    if pil_img is None:
        return None

    # Convert to RGB uint8
    rgb = np.array(pil_img.convert("RGB"), dtype=np.uint8)

    # DefaultPredictor takes BGR input (and converts internally if the model
    # wants RGB), so flip the channel axis before inference.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    outputs = predictor(bgr)
    instances = outputs["instances"].to("cpu")

    # Visualizer works in RGB, so it gets the unflipped array and its output
    # can be returned without another channel swap.
    v = Visualizer(
        rgb,
        metadata=metadata,
        scale=1.0,
    )

    # Draw ONLY masks with random colors
    if instances.has("pred_masks"):
        for mask in instances.pred_masks:
            random_color = (
                random.random(),  # R
                random.random(),  # G
                random.random(),  # B
            )
            v.draw_binary_mask(
                mask.numpy(),
                color=random_color,
                alpha=0.6,
            )

    return Image.fromarray(v.output.get_image())

# Web UI: one uploaded floor plan in, one annotated image out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload floor plan"),
    outputs=gr.Image(type="pil", label="Room segmentation"),
    title="Room Segmentation (Detectron2)",
    description="Upload a floor plan image to segment the room using Detectron2.",
)

# Serve on all network interfaces, port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)