import base64
import io
import os
import threading
from typing import Optional

import gdown
import cv2
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.projects.point_rend import add_pointrend_config
| |
|
| | |
| | |
| | |
app = FastAPI(title="Rooftop Segmentation API")

# Open the API to browser clients on any origin, with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive. The CORS spec forbids a literal "*" alongside credentials, so any
# site may end up making credentialed cross-origin calls. Confirm this is
# intended, or list the real front-end origins.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
| |
|
| | |
| | |
| | |
| | EPSILONS = [0.01, 0.005, 0.004, 0.003, 0.001] |
| |
|
| | @app.get("/epsilons") |
| | def get_epsilons(): |
| | return {"epsilons": EPSILONS} |
| |
|
| | |
| | |
| | |
# Local file the irregular-flat weights are downloaded to and loaded from.
MODEL_PATH_IRREGULAR: str = "/tmp/model_irregular_flat.pth"
# Google Drive file id of the irregular-flat weights.
DRIVE_FILE_ID: str = "15vi4zPhCs3aBnGepVnXFOqQjxdK1jpnA"
| |
|
def download_irregular_model():
    """Download the irregular-flat model weights from Google Drive if missing.

    The download is skipped when ``MODEL_PATH_IRREGULAR`` already exists.

    After downloading, the outcome is checked. If gdown returns ``None`` or no
    file exists at ``MODEL_PATH_IRREGULAR``, a failure message is printed
    instead of "Download complete.". Older gdown releases return ``None`` on
    some failures rather than raising. Exceptions raised by gdown itself are
    not caught and propagate to the caller.
    """
    if os.path.exists(MODEL_PATH_IRREGULAR):
        print("Irregular-flat model already exists, skipping download.")
        return

    url = f"https://drive.google.com/uc?id={DRIVE_FILE_ID}"

    # Point gdown at a writable scratch directory under /tmp.
    # NOTE(review): confirm the installed gdown version actually reads
    # GDOWN_CACHE_DIR; this mutates the process-wide environment.
    tmp_dir = "/tmp/gdown"
    os.makedirs(tmp_dir, exist_ok=True)
    os.environ["GDOWN_CACHE_DIR"] = tmp_dir

    print("Downloading irregular-flat Detectron2 model...")
    result = gdown.download(
        url,
        MODEL_PATH_IRREGULAR,
        quiet=False,
        fuzzy=True,
        use_cookies=False,
    )

    # Only claim success when the weights are really on disk.
    if result is None or not os.path.exists(MODEL_PATH_IRREGULAR):
        print("Download failed: irregular-flat model was not written to", MODEL_PATH_IRREGULAR)
        return
    print("Download complete.")
| |
|
| | |
# Fetch the irregular-flat weights at import time. This is a no-op when the
# file is already on disk.
download_irregular_model()

# Report the outcome. This is informational only: start-up continues either
# way, and it is the predictor built from MODEL_PATH_IRREGULAR that actually
# needs the file.
if not os.path.exists(MODEL_PATH_IRREGULAR):
    print("Irregular-flat model NOT found! Something went wrong!")
else:
    print("Irregular-flat model is ready at", MODEL_PATH_IRREGULAR)
| |
|
| | |
| | |
| | |
# PointRend base config. The path is relative, so it resolves against the
# process working directory.
# NOTE(review): this only works when the server starts from the directory that
# contains the detectron2 checkout; confirm against the deployment.
_POINTREND_CFG_PATH = (
    "detectron2/projects/PointRend/configs/"
    "InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
)


def _build_pointrend_predictor(
    weights_path: str,
    num_classes: int,
    *,
    score_thresh: float = 0.5,
    device: str = "cpu",
) -> "DefaultPredictor":
    """Build a PointRend ``DefaultPredictor`` from the shared base config.

    Args:
        weights_path: path to the trained ``.pth`` weights.
        num_classes: number of classes for both the ROI heads and the point head.
        score_thresh: minimum detection score kept at test time.
        device: torch device string the model runs on.
    """
    cfg = get_cfg()
    add_pointrend_config(cfg)
    cfg.merge_from_file(_POINTREND_CFG_PATH)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    # Keep the point head's class count in sync with the ROI heads.
    cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.DEVICE = device
    return DefaultPredictor(cfg)


def setup_model_rect(
    weights_path: str,
    *,
    score_thresh: float = 0.5,
    device: str = "cpu",
) -> "DefaultPredictor":
    """Build the 2-class PointRend predictor used for rectangular rooftops.

    What the two classes mean is defined by the training data, which is not
    visible here. ``score_thresh`` and ``device`` default to the values this
    service has always used (0.5, CPU).
    """
    return _build_pointrend_predictor(
        weights_path, 2, score_thresh=score_thresh, device=device
    )


def setup_model_irregular(
    weights_path: str,
    *,
    score_thresh: float = 0.5,
    device: str = "cpu",
) -> "DefaultPredictor":
    """Build the 1-class PointRend predictor used for irregular flat rooftops.

    ``score_thresh`` and ``device`` default to the values this service has
    always used (0.5, CPU).
    """
    return _build_pointrend_predictor(
        weights_path, 1, score_thresh=score_thresh, device=device
    )
| |
|
| | |
# Rectangular-roof weights are expected to already exist at this path; this
# file does not download them.
MODEL_PATH_RECT = "/app/model_rect_final.pth"

# Both predictors are built once at import time and shared by every request.
predictor_rect = setup_model_rect(MODEL_PATH_RECT)
predictor_irregular_flat = setup_model_irregular(MODEL_PATH_IRREGULAR)
| |
|
| | |
| | |
| | |
def postprocess_rect(mask: np.ndarray, epsilon: float) -> Optional[np.ndarray]:
    """Simplify the largest blob in *mask* and rasterize it back into a mask.

    Args:
        mask: single-channel mask, presumably 0/1 valued (``predict`` passes a
            uint8 0/1 mask). It is scaled by 255 before contour extraction.
        epsilon: ``approxPolyDP`` tolerance, as a fraction of the largest
            contour's perimeter.

    Returns:
        A uint8 array shaped like *mask*, with the simplified polygon filled
        with 255 and 0 elsewhere. Returns None when the mask has no contours.
    """
    binary = (mask * 255).astype(np.uint8)
    found, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(found) == 0:
        return None

    biggest = max(found, key=cv2.contourArea)
    tolerance = epsilon * cv2.arcLength(biggest, True)
    vertices = cv2.approxPolyDP(biggest, tolerance, True)

    canvas = np.zeros_like(binary)
    cv2.fillPoly(canvas, [vertices], 255)
    return canvas
| |
|
def postprocess_irregular(mask: np.ndarray, epsilon: float) -> Optional[np.ndarray]:
    """Return the simplified outline of the largest blob in *mask*.

    Args:
        mask: single-channel mask, presumably 0/1 valued (``predict`` passes a
            uint8 0/1 mask). It is scaled by 255 before contour extraction.
        epsilon: ``approxPolyDP`` tolerance, as a fraction of the largest
            contour's perimeter.

    Returns:
        An (N, 2) array of polygon vertices, or None when the mask has no
        contours.
    """
    binary = (mask * 255).astype(np.uint8)
    found, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(found) == 0:
        return None

    biggest = max(found, key=cv2.contourArea)
    tolerance = epsilon * cv2.arcLength(biggest, True)
    vertices = cv2.approxPolyDP(biggest, tolerance, True)
    return vertices.reshape(-1, 2)
| |
|
def mask_to_polygon(mask: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Return the largest external contour of *mask* as an (N, 2) point array.

    Args:
        mask: single-channel mask; non-zero pixels are foreground. Boolean
            masks are converted to uint8 first. ``None`` is accepted, because
            ``postprocess_rect`` returns ``None`` for an empty mask.

    Returns:
        An (N, 2) array of contour points in OpenCV (x, y) order. Returns None
        when *mask* is None or contains no foreground.
    """
    if mask is None:
        return None
    # cv2.findContours rejects bool arrays, so convert them to 8-bit.
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    return largest.reshape(-1, 2)
| |
|
def im_to_b64_png(im: np.ndarray) -> str:
    """Encode *im* as PNG and return the bytes as a base64 ASCII string.

    Raises:
        ValueError: if ``cv2.imencode`` reports that encoding failed. This is
            raised instead of silently returning an empty/invalid image string.
    """
    ok, buffer = cv2.imencode(".png", im)
    if not ok:
        raise ValueError("PNG encoding failed")
    return base64.b64encode(buffer).decode("ascii")
| |
|
def overlay_polygon(im: np.ndarray, polygon: Optional[np.ndarray]) -> np.ndarray:
    """Return a copy of *im* with *polygon* drawn on it as a closed outline.

    The outline is 2 px wide in colour (0, 0, 255), which is red in OpenCV's
    BGR order. The input image is never modified. When *polygon* is None the
    copy is returned unchanged.
    """
    canvas = im.copy()
    if polygon is None:
        return canvas
    outline = polygon.astype(np.int32)
    cv2.polylines(canvas, [outline], True, (0, 0, 255), 2)
    return canvas
| |
|
| | |
| | |
| | |
# GET /: simple liveness message confirming the service is up.
@app.get("/")
def root():
    status = {"message": "Rooftop Segmentation API is running!"}
    return status
| |
|
# Serializes model inference. The endpoint previously called the predictor
# directly on the event loop. Besides stalling every other request, that meant
# only one inference ran at a time. This lock keeps the one-at-a-time
# behaviour now that inference runs in a worker thread.
_INFERENCE_LOCK = threading.Lock()


def _run_predictor(predictor, im: np.ndarray):
    """Run *predictor* on *im* while holding ``_INFERENCE_LOCK``."""
    with _INFERENCE_LOCK:
        return predictor(im)


@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    rooftop_type: str = Form(...),
    epsilon: float = Form(0.004),
):
    """Segment the rooftop in an uploaded image and return its outline.

    Form fields:
        file: an image in any format Pillow can open.
        rooftop_type: ``"rectangular"`` or ``"irregular"`` (case-insensitive).
            Selects the model and the post-processing.
        epsilon: polygon simplification tolerance, as a fraction of the
            contour perimeter. Must be finite and >= 0.

    Returns a JSON object with these keys:
        polygon: list of [x, y] points, or null.
        image: base64 PNG of the input with the polygon drawn, or null when
            nothing was detected.
        model_used: the weights file that produced the result.
        rooftop_type: echoed exactly as sent.
        epsilon: the tolerance used.

    Invalid input yields HTTP 400 with an ``error`` key.
    """
    contents = await file.read()
    try:
        im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:  # Pillow raises several error types for bad input
        return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})

    # Pillow yields RGB; detectron2's DefaultPredictor takes BGR by default.
    im = np.array(im_pil)[:, :, ::-1].copy()

    kind = rooftop_type.lower()
    if kind == "rectangular":
        predictor = predictor_rect
        model_used = "model_rect_final.pth"
    elif kind == "irregular":
        predictor = predictor_irregular_flat
        model_used = "model_irregular_flat.pth"
    else:
        return JSONResponse(status_code=400, content={"error": "Invalid rooftop_type. Choose 'rectangular' or 'irregular'."})

    # cv2.approxPolyDP raises on negative, NaN or infinite tolerances, which
    # previously surfaced as a 500. The chained comparison is False for all
    # three, since NaN compares False.
    if not (0.0 <= epsilon < float("inf")):
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid epsilon. It must be a finite, non-negative number."},
        )

    # Inference is CPU-bound, so run it off the event loop to keep other
    # requests (e.g. GET /) responsive.
    outputs = await run_in_threadpool(_run_predictor, predictor, im)
    instances = outputs["instances"].to("cpu")

    if len(instances) == 0:
        return {
            "polygon": None,
            "image": None,
            "model_used": model_used,
            "rooftop_type": rooftop_type,
            "epsilon": epsilon,
        }

    # Keep only the highest-scoring detection.
    idx = int(instances.scores.argmax().item())
    raw_mask = instances.pred_masks[idx].numpy().astype(np.uint8)

    if kind == "rectangular":
        # postprocess_rect returns None when the mask has no contour; guard it
        # so the response carries a null polygon instead of crashing.
        # NOTE(review): the polygon is re-traced from the rasterized simplified
        # shape. Edges that are not axis-aligned therefore likely come back as
        # pixel staircases with many vertices, rather than as the approxPolyDP
        # vertices. Confirm this is intended.
        simplified_mask = postprocess_rect(raw_mask, epsilon)
        polygon = mask_to_polygon(simplified_mask) if simplified_mask is not None else None
    else:
        polygon = postprocess_irregular(raw_mask, epsilon)

    overlay = overlay_polygon(im, polygon)
    img_b64 = im_to_b64_png(overlay)

    return {
        "polygon": polygon.tolist() if polygon is not None else None,
        "image": img_b64,
        "model_used": model_used,
        "rooftop_type": rooftop_type,
        "epsilon": epsilon,
    }
| |
|