TallManager267 commited on
Commit
c2ccc93
·
verified ·
1 Parent(s): 0e353f2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from detectron2 import model_zoo
6
+ from detectron2.config import get_cfg
7
+ from detectron2.engine import DefaultPredictor
8
+ from detectron2.utils.visualizer import Visualizer
9
+ from detectron2.data import MetadataCatalog
10
+ from huggingface_hub import hf_hub_download
11
+ import os
12
+
13
+ # Only 1 class
14
+ ROOM_CLASSES = ["room"]
15
+
16
+ # Hugging Face repo info
17
+ HF_REPO = "TallManager267/SG_Room_Segmentation"
18
+ WEIGHTS_FILE = "sg_room_segmentation_726_4000itr_0.004lr.pth"
19
+
20
+ # Download weights from Hugging Face
21
+ os.makedirs("model", exist_ok=True)
22
+ weights_path = hf_hub_download(repo_id=HF_REPO, filename=WEIGHTS_FILE)
23
+
24
+ def load_model(weights_path):
25
+ cfg = get_cfg()
26
+ cfg.merge_from_file(model_zoo.get_config_file(
27
+ "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x.yaml"
28
+ ))
29
+ cfg.MODEL.WEIGHTS = weights_path
30
+ cfg.MODEL.DEVICE = "cpu"
31
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
32
+ cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
33
+
34
+ # Metadata for visualization
35
+ metadata = MetadataCatalog.get("room_metadata")
36
+ metadata.set(thing_classes=ROOM_CLASSES)
37
+
38
+ predictor = DefaultPredictor(cfg)
39
+ return predictor, metadata
40
+
41
+ predictor, metadata = load_model(weights_path)
42
+
43
+ def predict(pil_img):
44
+ # Convert to RGB and ensure dtype is uint8
45
+ img = np.array(pil_img.convert("RGB"), dtype=np.uint8)
46
+
47
+ outputs = predictor(img)
48
+ instances = outputs["instances"].to("cpu")
49
+
50
+ v = Visualizer(
51
+ img[:, :, ::-1],
52
+ metadata=metadata,
53
+ scale=1.0
54
+ )
55
+ out = v.draw_instance_predictions(instances)
56
+ out_img = out.get_image()[:, :, ::-1]
57
+
58
+ return Image.fromarray(out_img)
59
+
60
+ gr.Interface(
61
+ fn=predict,
62
+ inputs=gr.Image(type="pil", label="Upload floor plan"),
63
+ outputs=gr.Image(type="pil", label="Room segmentation"),
64
+ title="Room Segmentation (Detectron2)",
65
+ description="Upload a floor plan image to segment the room using Detectron2."
66
+ ).launch(server_name="0.0.0.0", server_port=7860)