Spaces:
Sleeping
Sleeping
import os
import random

import gradio as gr
import numpy as np
import torch
from PIL import Image
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import Visualizer
from huggingface_hub import hf_hub_download
# The model predicts a single category; the list order defines class ids.
ROOM_CLASSES = ["room"]

# Hugging Face repo info
HF_REPO = "TallManager267/SG_Room_Segmentation"
WEIGHTS_FILE = "sg_room_segmentation_726_4000itr_0.004lr.pth"

# Local directory that receives the downloaded checkpoint.
MODEL_DIR = "model"

# Download weights from Hugging Face into MODEL_DIR. Passing `local_dir`
# makes the directory created here the actual download target instead of an
# empty folder that nothing reads.
os.makedirs(MODEL_DIR, exist_ok=True)
weights_path = hf_hub_download(
    repo_id=HF_REPO,
    filename=WEIGHTS_FILE,
    local_dir=MODEL_DIR,
)
def load_model(weights_path: str, *, device: str = "cpu", score_thresh: float = 0.5):
    """Build a Mask R-CNN (R101-FPN) predictor and its visualization metadata.

    Args:
        weights_path: Path to the fine-tuned checkpoint to load.
        device: Value assigned to ``cfg.MODEL.DEVICE``. Defaults to ``"cpu"``.
        score_thresh: Minimum detection score kept at test time
            (``cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST``). Defaults to ``0.5``.

    Returns:
        A ``(predictor, metadata)`` tuple: the ``DefaultPredictor`` built from
        the config, and the ``"room_metadata"`` catalog entry carrying
        ``ROOM_CLASSES`` as ``thing_classes``.
    """
    cfg = get_cfg()
    cfg.merge_from_file(
        model_zoo.get_config_file(
            "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
        )
    )
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.DEVICE = device
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    # Derive the head size from the class list so the two cannot drift apart.
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(ROOM_CLASSES)

    # Metadata for visualization
    metadata = MetadataCatalog.get("room_metadata")
    metadata.set(thing_classes=ROOM_CLASSES)

    predictor = DefaultPredictor(cfg)
    return predictor, metadata
# Load the model once at start-up so every request reuses the same predictor.
# (`import random` now lives with the other imports at the top of the file.)
predictor, metadata = load_model(weights_path)
def predict(pil_img):
    """Segment rooms in a floor-plan image and return the mask overlay.

    Args:
        pil_img: PIL image supplied by the Gradio input component, or ``None``
            when the user submits without uploading an image.

    Returns:
        An RGB PIL image with every predicted instance mask drawn in a random
        color at 60% opacity, or ``None`` when no input image was provided.
    """
    # Gradio passes None for an empty image input; there is nothing to segment.
    if pil_img is None:
        return None

    # Convert to RGB uint8
    rgb = np.array(pil_img.convert("RGB"), dtype=np.uint8)

    # DefaultPredictor expects BGR channel order (the OpenCV convention), while
    # PIL yields RGB, so flip the channels for inference only. The copy avoids
    # handing a negative-stride view to downstream code.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    outputs = predictor(bgr)
    instances = outputs["instances"].to("cpu")

    # Visualizer works in RGB, so it gets the unflipped array and its output
    # can be returned without any further channel swap.
    v = Visualizer(rgb, metadata=metadata, scale=1.0)

    # Draw ONLY masks with random colors
    if instances.has("pred_masks"):
        for mask in instances.pred_masks:
            random_color = (
                random.random(),  # R
                random.random(),  # G
                random.random(),  # B
            )
            v.draw_binary_mask(mask.numpy(), color=random_color, alpha=0.6)

    return Image.fromarray(v.output.get_image())
# Build the UI at import time, but only start the (blocking) server when this
# file is executed as a script, so importing the module has no server side
# effect.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload floor plan"),
    outputs=gr.Image(type="pil", label="Room segmentation"),
    title="Room Segmentation (Detectron2)",
    description="Upload a floor plan image to segment the room using Detectron2.",
)

if __name__ == "__main__":
    # Listen on all interfaces so the app is reachable from outside the host.
    demo.launch(server_name="0.0.0.0", server_port=7860)