MarouaneAyech commited on
Commit
50319a8
·
verified ·
1 Parent(s): 3baa1ce

Update main3.py

Browse files
Files changed (1) hide show
  1. main3.py +191 -191
main3.py CHANGED
@@ -1,191 +1,191 @@
1
- import io
2
- import os
3
- import gdown
4
- import base64
5
- import cv2
6
- import numpy as np
7
- from PIL import Image
8
- from typing import Optional
9
- from fastapi import FastAPI, UploadFile, File, Form
10
- from fastapi.responses import JSONResponse
11
- from fastapi.middleware.cors import CORSMiddleware
12
- from detectron2.engine import DefaultPredictor
13
- from detectron2.config import get_cfg
14
- from detectron2.projects.point_rend import add_pointrend_config
15
-
16
- # -------------------------------
17
- # FastAPI setup
18
- # -------------------------------
19
- app = FastAPI(title="Rooftop Segmentation API")
20
-
21
- app.add_middleware(
22
- CORSMiddleware,
23
- allow_origins=["*"],
24
- allow_credentials=True,
25
- allow_methods=["*"],
26
- allow_headers=["*"],
27
- )
28
-
29
- # -------------------------------
30
- # Available epsilons
31
- # -------------------------------
32
- EPSILONS = [0.01, 0.005, 0.004, 0.003, 0.001]
33
-
34
- @app.get("/epsilons")
35
- def get_epsilons():
36
- return {"epsilons": EPSILONS}
37
-
38
- # -------------------------------
39
- # Google Drive model download (irregular-flat)
40
- # -------------------------------
41
- MODEL_PATH_IRREGULAR = "/tmp/model_irregular_flat.pth"
42
- DRIVE_FILE_ID = "1GO_Ut-N2e2we8t9mnsysb0P1qMsBF8FW"
43
-
44
- def download_irregular_model():
45
- if not os.path.exists(MODEL_PATH_IRREGULAR):
46
- url = f"https://drive.google.com/uc?id={DRIVE_FILE_ID}"
47
- tmp_dir = "/tmp/gdown"
48
- os.makedirs(tmp_dir, exist_ok=True)
49
- os.environ["GDOWN_CACHE_DIR"] = tmp_dir
50
- print("Downloading irregular-flat Detectron2 model...")
51
- gdown.download(url, MODEL_PATH_IRREGULAR, quiet=False, fuzzy=True, use_cookies=False)
52
- print("Download complete.")
53
- else:
54
- print("Irregular-flat model already exists, skipping download.")
55
-
56
- download_irregular_model()
57
-
58
- if os.path.exists(MODEL_PATH_IRREGULAR):
59
- print("Irregular-flat model is ready at", MODEL_PATH_IRREGULAR)
60
- else:
61
- print("Irregular-flat model NOT found! Something went wrong!")
62
-
63
- # -------------------------------
64
- # Detectron2 model setup
65
- # -------------------------------
66
- def setup_model_rect(weights_path: str):
67
- cfg = get_cfg()
68
- add_pointrend_config(cfg)
69
- cfg_path = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
70
- cfg.merge_from_file(cfg_path)
71
- cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
72
- cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
73
- cfg.MODEL.WEIGHTS = weights_path
74
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
75
- cfg.MODEL.DEVICE = "cpu"
76
- return DefaultPredictor(cfg)
77
-
78
- def setup_model_irregular(weights_path: str):
79
- cfg = get_cfg()
80
- add_pointrend_config(cfg)
81
- cfg_path = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
82
- cfg.merge_from_file(cfg_path)
83
- cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
84
- cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
85
- cfg.MODEL.WEIGHTS = weights_path
86
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
87
- cfg.MODEL.DEVICE = "cpu"
88
- return DefaultPredictor(cfg)
89
-
90
- # Load models
91
- predictor_rect = setup_model_rect("/app/model_rect_final.pth")
92
- predictor_irregular_flat = setup_model_irregular(MODEL_PATH_IRREGULAR)
93
-
94
- # -------------------------------
95
- # Utility functions
96
- # -------------------------------
97
- def im_to_b64_png(im: np.ndarray) -> str:
98
- _, buffer = cv2.imencode(".png", im)
99
- return base64.b64encode(buffer).decode()
100
-
101
- def extract_polygon(mask: np.ndarray, epsilon_ratio: float = 0.004):
102
- mask_uint8 = (mask * 255).astype(np.uint8)
103
- contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
104
- if not contours:
105
- return None
106
- c = max(contours, key=cv2.contourArea)
107
- epsilon = epsilon_ratio * cv2.arcLength(c, True)
108
- polygon = cv2.approxPolyDP(c, epsilon, True)
109
- return polygon.reshape(-1, 2)
110
-
111
- def overlay_polygon(im: np.ndarray, polygon: Optional[np.ndarray], vertex_color=(0,0,255), line_color=(0,255,0)):
112
- overlay = im.copy()
113
- if polygon is not None:
114
- # Draw polygon outline (thin)
115
- cv2.polylines(overlay, [polygon.astype(np.int32)], True, line_color, thickness=2)
116
-
117
- # Draw vertices
118
- for i, (x, y) in enumerate(polygon):
119
- cv2.circle(overlay, (int(x), int(y)), 4, vertex_color, -1)
120
- # Draw vertex index (black number)
121
- cv2.putText(overlay, str(i+1), (int(x)+5, int(y)-5),
122
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (20,20,20), 1, cv2.LINE_AA)
123
-
124
- # Display vertex count on top
125
- vertex_count = len(polygon)
126
- cv2.putText(overlay, f"num_vertices = {vertex_count}", (20, 35),
127
- cv2.FONT_HERSHEY_SIMPLEX, 0.9, (20,20,20), 2, cv2.LINE_AA)
128
-
129
- return overlay
130
-
131
- # -------------------------------
132
- # API endpoints
133
- # -------------------------------
134
- @app.get("/")
135
- def root():
136
- return {"message": "Rooftop Segmentation API is running!"}
137
-
138
- @app.post("/predict")
139
- async def predict(
140
- file: UploadFile = File(...),
141
- rooftop_type: str = Form(...),
142
- epsilon: float = Form(0.004)
143
- ):
144
- contents = await file.read()
145
- try:
146
- im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
147
- except Exception as e:
148
- return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})
149
-
150
- im = np.array(im_pil)[:, :, ::-1].copy() # RGB -> BGR
151
-
152
- if rooftop_type.lower() == "rectangular":
153
- predictor = predictor_rect
154
- model_used = "model_rect_final.pth"
155
- elif rooftop_type.lower() == "irregular":
156
- predictor = predictor_irregular_flat
157
- model_used = "model_irregular_flat.pth"
158
- else:
159
- return JSONResponse(status_code=400, content={"error": "Invalid rooftop_type. Choose 'rectangular' or 'irregular'."})
160
-
161
- outputs = predictor(im)
162
- instances = outputs["instances"].to("cpu")
163
-
164
- if len(instances) == 0:
165
- return {
166
- "polygon": None,
167
- "vertices": None,
168
- "vertex_count": 0,
169
- "image": None,
170
- "model_used": model_used,
171
- "rooftop_type": rooftop_type,
172
- "epsilon": epsilon
173
- }
174
-
175
- idx = int(instances.scores.argmax().item())
176
- raw_mask = instances.pred_masks[idx].numpy().astype(np.uint8)
177
-
178
- polygon = extract_polygon(raw_mask, epsilon_ratio=epsilon)
179
- vertex_count = len(polygon) if polygon is not None else 0
180
-
181
- overlay = overlay_polygon(im, polygon)
182
- img_b64 = im_to_b64_png(overlay)
183
-
184
- return {
185
- "polygon": polygon.tolist() if polygon is not None else None,
186
- "vertex_count": vertex_count,
187
- "image": img_b64,
188
- "model_used": model_used,
189
- "rooftop_type": rooftop_type,
190
- "epsilon": epsilon
191
- }
 
1
+ import io
2
+ import os
3
+ import gdown
4
+ import base64
5
+ import cv2
6
+ import numpy as np
7
+ from PIL import Image
8
+ from typing import Optional
9
+ from fastapi import FastAPI, UploadFile, File, Form
10
+ from fastapi.responses import JSONResponse
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from detectron2.engine import DefaultPredictor
13
+ from detectron2.config import get_cfg
14
+ from detectron2.projects.point_rend import add_pointrend_config
15
+
16
+ # -------------------------------
17
+ # FastAPI setup
18
+ # -------------------------------
19
+ app = FastAPI(title="Rooftop Segmentation API")
20
+
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=["*"],
24
+ allow_credentials=True,
25
+ allow_methods=["*"],
26
+ allow_headers=["*"],
27
+ )
28
+
29
+ # # -------------------------------
30
+ # # Available epsilons
31
+ # # -------------------------------
32
+ # EPSILONS = [0.01, 0.005, 0.004, 0.003, 0.001]
33
+
34
+ # @app.get("/epsilons")
35
+ # def get_epsilons():
36
+ # return {"epsilons": EPSILONS}
37
+
38
+ # # -------------------------------
39
+ # # Google Drive model download (irregular-flat)
40
+ # # -------------------------------
41
+ # MODEL_PATH_IRREGULAR = "/tmp/model_irregular_flat.pth"
42
+ # DRIVE_FILE_ID = "1GO_Ut-N2e2we8t9mnsysb0P1qMsBF8FW"
43
+
44
+ # def download_irregular_model():
45
+ # if not os.path.exists(MODEL_PATH_IRREGULAR):
46
+ # url = f"https://drive.google.com/uc?id={DRIVE_FILE_ID}"
47
+ # tmp_dir = "/tmp/gdown"
48
+ # os.makedirs(tmp_dir, exist_ok=True)
49
+ # os.environ["GDOWN_CACHE_DIR"] = tmp_dir
50
+ # print("Downloading irregular-flat Detectron2 model...")
51
+ # gdown.download(url, MODEL_PATH_IRREGULAR, quiet=False, fuzzy=True, use_cookies=False)
52
+ # print("Download complete.")
53
+ # else:
54
+ # print("Irregular-flat model already exists, skipping download.")
55
+
56
+ # download_irregular_model()
57
+
58
+ # if os.path.exists(MODEL_PATH_IRREGULAR):
59
+ # print("Irregular-flat model is ready at", MODEL_PATH_IRREGULAR)
60
+ # else:
61
+ # print("Irregular-flat model NOT found! Something went wrong!")
62
+
63
+ # # -------------------------------
64
+ # # Detectron2 model setup
65
+ # # -------------------------------
66
+ # def setup_model_rect(weights_path: str):
67
+ # cfg = get_cfg()
68
+ # add_pointrend_config(cfg)
69
+ # cfg_path = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
70
+ # cfg.merge_from_file(cfg_path)
71
+ # cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
72
+ # cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
73
+ # cfg.MODEL.WEIGHTS = weights_path
74
+ # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
75
+ # cfg.MODEL.DEVICE = "cpu"
76
+ # return DefaultPredictor(cfg)
77
+
78
+ # def setup_model_irregular(weights_path: str):
79
+ # cfg = get_cfg()
80
+ # add_pointrend_config(cfg)
81
+ # cfg_path = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
82
+ # cfg.merge_from_file(cfg_path)
83
+ # cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
84
+ # cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
85
+ # cfg.MODEL.WEIGHTS = weights_path
86
+ # cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
87
+ # cfg.MODEL.DEVICE = "cpu"
88
+ # return DefaultPredictor(cfg)
89
+
90
+ # # Load models
91
+ # predictor_rect = setup_model_rect("/app/model_rect_final.pth")
92
+ # predictor_irregular_flat = setup_model_irregular(MODEL_PATH_IRREGULAR)
93
+
94
+ # # -------------------------------
95
+ # # Utility functions
96
+ # # -------------------------------
97
+ # def im_to_b64_png(im: np.ndarray) -> str:
98
+ # _, buffer = cv2.imencode(".png", im)
99
+ # return base64.b64encode(buffer).decode()
100
+
101
+ # def extract_polygon(mask: np.ndarray, epsilon_ratio: float = 0.004):
102
+ # mask_uint8 = (mask * 255).astype(np.uint8)
103
+ # contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
104
+ # if not contours:
105
+ # return None
106
+ # c = max(contours, key=cv2.contourArea)
107
+ # epsilon = epsilon_ratio * cv2.arcLength(c, True)
108
+ # polygon = cv2.approxPolyDP(c, epsilon, True)
109
+ # return polygon.reshape(-1, 2)
110
+
111
+ # def overlay_polygon(im: np.ndarray, polygon: Optional[np.ndarray], vertex_color=(0,0,255), line_color=(0,255,0)):
112
+ # overlay = im.copy()
113
+ # if polygon is not None:
114
+ # # Draw polygon outline (thin)
115
+ # cv2.polylines(overlay, [polygon.astype(np.int32)], True, line_color, thickness=2)
116
+
117
+ # # Draw vertices
118
+ # for i, (x, y) in enumerate(polygon):
119
+ # cv2.circle(overlay, (int(x), int(y)), 4, vertex_color, -1)
120
+ # # Draw vertex index (black number)
121
+ # cv2.putText(overlay, str(i+1), (int(x)+5, int(y)-5),
122
+ # cv2.FONT_HERSHEY_SIMPLEX, 0.5, (20,20,20), 1, cv2.LINE_AA)
123
+
124
+ # # Display vertex count on top
125
+ # vertex_count = len(polygon)
126
+ # cv2.putText(overlay, f"num_vertices = {vertex_count}", (20, 35),
127
+ # cv2.FONT_HERSHEY_SIMPLEX, 0.9, (20,20,20), 2, cv2.LINE_AA)
128
+
129
+ # return overlay
130
+
131
+ # -------------------------------
132
+ # API endpoints
133
+ # -------------------------------
134
+ @app.get("/")
135
+ def root():
136
+ return {"message": "Rooftop Segmentation API is running!"}
137
+
138
+ # @app.post("/predict")
139
+ # async def predict(
140
+ # file: UploadFile = File(...),
141
+ # rooftop_type: str = Form(...),
142
+ # epsilon: float = Form(0.004)
143
+ # ):
144
+ # contents = await file.read()
145
+ # try:
146
+ # im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
147
+ # except Exception as e:
148
+ # return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})
149
+
150
+ # im = np.array(im_pil)[:, :, ::-1].copy() # RGB -> BGR
151
+
152
+ # if rooftop_type.lower() == "rectangular":
153
+ # predictor = predictor_rect
154
+ # model_used = "model_rect_final.pth"
155
+ # elif rooftop_type.lower() == "irregular":
156
+ # predictor = predictor_irregular_flat
157
+ # model_used = "model_irregular_flat.pth"
158
+ # else:
159
+ # return JSONResponse(status_code=400, content={"error": "Invalid rooftop_type. Choose 'rectangular' or 'irregular'."})
160
+
161
+ # outputs = predictor(im)
162
+ # instances = outputs["instances"].to("cpu")
163
+
164
+ # if len(instances) == 0:
165
+ # return {
166
+ # "polygon": None,
167
+ # "vertices": None,
168
+ # "vertex_count": 0,
169
+ # "image": None,
170
+ # "model_used": model_used,
171
+ # "rooftop_type": rooftop_type,
172
+ # "epsilon": epsilon
173
+ # }
174
+
175
+ # idx = int(instances.scores.argmax().item())
176
+ # raw_mask = instances.pred_masks[idx].numpy().astype(np.uint8)
177
+
178
+ # polygon = extract_polygon(raw_mask, epsilon_ratio=epsilon)
179
+ # vertex_count = len(polygon) if polygon is not None else 0
180
+
181
+ # overlay = overlay_polygon(im, polygon)
182
+ # img_b64 = im_to_b64_png(overlay)
183
+
184
+ # return {
185
+ # "polygon": polygon.tolist() if polygon is not None else None,
186
+ # "vertex_count": vertex_count,
187
+ # "image": img_b64,
188
+ # "model_used": model_used,
189
+ # "rooftop_type": rooftop_type,
190
+ # "epsilon": epsilon
191
+ # }