# TallManager267's picture
# Update app.py
# 58d299f verified
import gradio as gr
import torch
import numpy as np
from PIL import Image
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from huggingface_hub import hf_hub_download
import os
# The model predicts a single category.
ROOM_CLASSES = ["room"]

# Hugging Face Hub repository and checkpoint file holding the trained weights.
HF_REPO = "TallManager267/SG_Room_Segmentation"
WEIGHTS_FILE = "sg_room_segmentation_726_4000itr_0.004lr.pth"

# NOTE(review): "model" is created here, but it is not passed to the
# hf_hub_download call below (no directory argument is given) — confirm
# whether this directory is still needed.
os.makedirs("model", exist_ok=True)

# Fetch the checkpoint from the Hub; the result is later handed to
# load_model as the weights path.
weights_path = hf_hub_download(
    repo_id=HF_REPO,
    filename=WEIGHTS_FILE,
)
def load_model(weights_path: str, *, device: str = "cpu", score_thresh: float = 0.5):
    """Build the room-segmentation predictor and its visualization metadata.

    The configuration starts from the model-zoo file
    ``COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml`` and is then
    pointed at the given weights.

    Args:
        weights_path: Path to the checkpoint assigned to ``cfg.MODEL.WEIGHTS``.
        device: Value assigned to ``cfg.MODEL.DEVICE``. Defaults to ``"cpu"``.
        score_thresh: Value assigned to
            ``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``. Defaults to ``0.5``.

    Returns:
        A ``(predictor, metadata)`` tuple: the ``DefaultPredictor`` built from
        the configuration, and the ``"room_metadata"`` catalog entry whose
        ``thing_classes`` is set to ``ROOM_CLASSES``.
    """
    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
        )
    )
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.DEVICE = device
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    # Derive the head size from the class list so the two cannot drift apart.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(ROOM_CLASSES)

    # Metadata for visualization
    metadata = MetadataCatalog.get("room_metadata")
    metadata.set(thing_classes=ROOM_CLASSES)

    predictor = DefaultPredictor(cfg)
    return predictor, metadata
# Build the predictor once at import time; predict() reads these module-level
# names on every request.
predictor, metadata = load_model(weights_path)
import random
def predict(pil_img: "Image.Image | None") -> "Image.Image | None":
    """Segment rooms in a floor-plan image and overlay the predicted masks.

    Args:
        pil_img: Image supplied by the Gradio input component, or ``None``
            when the request was submitted without an image.

    Returns:
        A PIL image showing the input with each predicted instance mask drawn
        in a random color, or ``None`` when no image was supplied.
    """
    # Submitting without an upload passes None; return an empty result
    # instead of failing on ``None.convert``.
    if pil_img is None:
        return None

    # Convert to RGB uint8
    rgb = np.array(pil_img.convert("RGB"), dtype=np.uint8)

    # DefaultPredictor expects channels in cfg.INPUT.FORMAT order, which this
    # file leaves at Detectron2's default (BGR), so flip the PIL RGB array
    # before inference. ascontiguousarray avoids handing a negative-stride
    # view to downstream code.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    outputs = predictor(bgr)
    instances = outputs["instances"].to("cpu")

    # Visualizer takes an RGB image, so the unflipped array is used as-is and
    # the rendered output needs no channel swap either.
    v = Visualizer(rgb, metadata=metadata, scale=1.0)

    # Draw ONLY masks with random colors
    if instances.has("pred_masks"):
        for mask in instances.pred_masks:
            random_color = (
                random.random(),  # R
                random.random(),  # G
                random.random(),  # B
            )
            v.draw_binary_mask(mask.numpy(), color=random_color, alpha=0.6)

    return Image.fromarray(v.output.get_image())
# Assemble the web UI: one image in, one image out, both exchanged as PIL
# objects with predict().
floor_plan_input = gr.Image(type="pil", label="Upload floor plan")
segmentation_output = gr.Image(type="pil", label="Room segmentation")

demo = gr.Interface(
    fn=predict,
    inputs=floor_plan_input,
    outputs=segmentation_output,
    title="Room Segmentation (Detectron2)",
    description="Upload a floor plan image to segment the room using Detectron2.",
)

# Listen on all interfaces at port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)