Spaces:
Build error
Build error
| import io | |
| import os | |
| import gdown | |
| import base64 | |
| from typing import Optional | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from fastapi import FastAPI, UploadFile, File, Form | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from detectron2.engine import DefaultPredictor | |
| from detectron2.config import get_cfg | |
| from detectron2.projects.point_rend import add_pointrend_config | |
# -------------------------------
# FastAPI setup
# -------------------------------
# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): allow_credentials=True together with a "*" origin makes
# Starlette echo the caller's Origin, i.e. credentialed requests from any site
# are allowed. Nothing visible in this service uses cookies/auth, so consider
# allow_credentials=False — confirm with the frontend owners.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI(title="Rooftop Segmentation API")
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# -------------------------------
# Available epsilons
# -------------------------------
# Suggested approxPolyDP tolerances, expressed as a fraction of the contour
# perimeter (the post-processing functions multiply epsilon by arcLength).
EPSILONS = [0.01, 0.005, 0.004, 0.003, 0.001]


# NOTE(review): this handler had no route decorator, so it was never
# registered with the app; the "/epsilons" path is reconstructed from the
# function name — confirm against the clients.
@app.get("/epsilons")
def get_epsilons():
    """Return the suggested epsilon ratios as ``{"epsilons": [...]}``."""
    return {"epsilons": EPSILONS}
# -------------------------------
# Google Drive model download (irregular-flat)
# -------------------------------
MODEL_PATH_IRREGULAR = "/tmp/model_irregular_flat.pth"
DRIVE_FILE_ID = "15vi4zPhCs3aBnGepVnXFOqQjxdK1jpnA"


def download_irregular_model():
    """Fetch the irregular-flat Detectron2 weights from Google Drive once.

    Does nothing (apart from a log line) when ``MODEL_PATH_IRREGULAR`` already
    exists. A failed download is reported on stdout instead of claiming
    success; exceptions raised by gdown propagate unchanged. Callers should
    check for the file afterwards.
    """
    if os.path.exists(MODEL_PATH_IRREGULAR):
        print("Irregular-flat model already exists, skipping download.")
        return

    url = f"https://drive.google.com/uc?id={DRIVE_FILE_ID}"
    tmp_dir = "/tmp/gdown"
    os.makedirs(tmp_dir, exist_ok=True)
    # NOTE(review): presumably meant to keep gdown's cache in a writable /tmp
    # directory — confirm the installed gdown version reads GDOWN_CACHE_DIR.
    os.environ["GDOWN_CACHE_DIR"] = tmp_dir

    print("Downloading irregular-flat Detectron2 model...")
    result = gdown.download(url, MODEL_PATH_IRREGULAR, quiet=False, fuzzy=True, use_cookies=False)
    # gdown may signal failure by returning None instead of raising; only
    # report success when the weights file is actually on disk.
    if result is None or not os.path.exists(MODEL_PATH_IRREGULAR):
        print("Download of irregular-flat model failed.")
        return
    print("Download complete.")
# Make sure the irregular-flat weights are present before the predictors are
# built further down, and log the outcome.
download_irregular_model()
if not os.path.exists(MODEL_PATH_IRREGULAR):
    print("Irregular-flat model NOT found! Something went wrong!")
else:
    print("Irregular-flat model is ready at", MODEL_PATH_IRREGULAR)
# -------------------------------
# Detectron2 model setup
# -------------------------------
# Relative path, resolved against the process working directory.
# NOTE(review): assumes the service is started from a directory that contains
# the detectron2 source checkout — confirm for the deployment image.
POINTREND_CONFIG_PATH = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"


def _setup_pointrend_predictor(
    weights_path: str,
    num_classes: int,
    score_thresh: float = 0.5,
    device: str = "cpu",
) -> DefaultPredictor:
    """Build a PointRend instance-segmentation ``DefaultPredictor``.

    Args:
        weights_path: path to the trained ``.pth`` weights.
        num_classes: number of foreground classes the weights were trained on.
        score_thresh: minimum detection score kept at test time.
        device: torch device string for inference.

    Returns:
        A ``DefaultPredictor`` configured from ``POINTREND_CONFIG_PATH``.
    """
    cfg = get_cfg()
    add_pointrend_config(cfg)  # PointRend keys must exist before merging its YAML
    cfg.merge_from_file(POINTREND_CONFIG_PATH)
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    # The point head's class count is kept equal to the ROI head's.
    cfg.MODEL.POINT_HEAD.NUM_CLASSES = num_classes
    cfg.MODEL.WEIGHTS = weights_path
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.DEVICE = device
    return DefaultPredictor(cfg)


def setup_model_rect(weights_path: str) -> DefaultPredictor:
    """Build the rectangular-rooftop predictor (2 classes, CPU, score >= 0.5)."""
    return _setup_pointrend_predictor(weights_path, num_classes=2)


def setup_model_irregular(weights_path: str) -> DefaultPredictor:
    """Build the irregular-flat-rooftop predictor (1 class, CPU, score >= 0.5)."""
    return _setup_pointrend_predictor(weights_path, num_classes=1)
# Load models
# Both predictors are built once at import time; the request handler picks one
# per call based on the requested rooftop type.
RECT_WEIGHTS_PATH = "/app/model_rect_final.pth"
predictor_rect = setup_model_rect(RECT_WEIGHTS_PATH)
predictor_irregular_flat = setup_model_irregular(MODEL_PATH_IRREGULAR)
# -------------------------------
# Post-processing functions
# -------------------------------
def postprocess_rect(mask: np.ndarray, epsilon: float) -> Optional[np.ndarray]:
    """Simplify the largest blob of *mask* and rasterize it back into a mask.

    Args:
        mask: single-channel mask; pixels that are non-zero after ``* 255``
            and a uint8 cast are treated as foreground.
        epsilon: approxPolyDP tolerance as a fraction of the contour perimeter.

    Returns:
        A uint8 array shaped like *mask* with the simplified polygon filled
        with 255 (0 elsewhere), or None when no external contour is found.
    """
    binary = (mask * 255).astype(np.uint8)
    found, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(found) == 0:
        return None

    largest = max(found, key=cv2.contourArea)
    tolerance = epsilon * cv2.arcLength(largest, True)
    simplified = cv2.approxPolyDP(largest, tolerance, True)

    canvas = np.zeros_like(binary)
    cv2.fillPoly(canvas, [simplified], 255)
    return canvas
def postprocess_irregular(mask: np.ndarray, epsilon: float) -> Optional[np.ndarray]:
    """Return the simplified outline of the largest blob in *mask*.

    Args:
        mask: single-channel mask (scaled by 255 and cast to uint8 internally).
        epsilon: approxPolyDP tolerance as a fraction of the contour perimeter.

    Returns:
        An (N, 2) array of polygon vertices, or None when no contour is found.
    """
    # extract_polygon_vertices runs exactly this pipeline (scale, largest
    # external contour, approxPolyDP, reshape), so reuse it.
    return extract_polygon_vertices(mask, epsilon_ratio=epsilon)
def mask_to_polygon(mask: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Return the largest external contour of *mask* as an (N, 2) point array.

    Args:
        mask: single-channel mask, or None. Non-uint8 masks (e.g. bool) are
            binarized to 0/1 uint8 first, because findContours in
            RETR_EXTERNAL mode only accepts 8-bit single-channel input.

    Returns:
        (N, 2) array of (x, y) contour points, or None when *mask* is None or
        contains no foreground.
    """
    # Upstream post-processing can yield None for an empty mask; treat that
    # as "no polygon" instead of crashing inside OpenCV.
    if mask is None:
        return None
    if mask.dtype != np.uint8:
        mask = (mask != 0).astype(np.uint8)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    return largest.reshape(-1, 2)
def extract_polygon_vertices(mask: np.ndarray, epsilon_ratio: float = 0.004):
    """Simplify the largest blob in a binary mask to its polygon corners.

    Args:
        mask: single-channel mask; pixels non-zero after ``* 255`` and a
            uint8 cast count as foreground.
        epsilon_ratio: approxPolyDP tolerance as a fraction of the contour
            perimeter (larger values give fewer vertices).

    Returns:
        An (N, 2) array of vertices, or None when no external contour exists.
    """
    scaled = (mask * 255).astype(np.uint8)
    outlines = cv2.findContours(scaled, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
    if not outlines:
        return None

    biggest = max(outlines, key=cv2.contourArea)
    perimeter = cv2.arcLength(biggest, True)
    corners = cv2.approxPolyDP(biggest, epsilon_ratio * perimeter, True)
    return corners.reshape(-1, 2)
def im_to_b64_png(im: np.ndarray) -> str:
    """Encode *im* as PNG and return the bytes as a base64 ASCII string.

    Raises:
        ValueError: if OpenCV reports that the image could not be encoded.
    """
    ok, buffer = cv2.imencode(".png", im)
    # imencode can report failure through its boolean flag; surface it rather
    # than returning a base64 string of an invalid buffer.
    if not ok:
        raise ValueError("cv2.imencode failed to encode image as PNG")
    return base64.b64encode(buffer).decode("ascii")
def overlay_polygon(im: np.ndarray, polygon: Optional[np.ndarray]) -> np.ndarray:
    """Return a copy of *im* with *polygon* drawn on top.

    The closed outline is drawn with color (0, 255, 0) at thickness 2, and each
    vertex as a filled dot of color (0, 0, 255) and radius 4 (green / red for
    the BGR images this service produces). *im* is never modified; when
    *polygon* is None the untouched copy is returned.
    """
    canvas = im.copy()
    if polygon is None:
        return canvas

    outline = polygon.astype(np.int32)
    cv2.polylines(canvas, [outline], True, (0, 255, 0), 2)
    for px, py in polygon:
        cv2.circle(canvas, (int(px), int(py)), radius=4, color=(0, 0, 255), thickness=-1)
    return canvas
# -------------------------------
# API endpoints
# -------------------------------
# NOTE(review): this handler had no route decorator, so it was never
# registered; "/" is reconstructed from its health-check message — confirm.
@app.get("/")
def root():
    """Liveness check: report that the API process is up."""
    return {"message": "Rooftop Segmentation API is running!"}
# NOTE(review): this handler had no route decorator, so it was never
# registered; the "/predict" path is reconstructed — confirm against clients.
@app.post("/predict")
async def predict(
    file: UploadFile = File(...),
    rooftop_type: str = Form(...),
    epsilon: float = Form(0.004),
):
    """Segment the highest-scoring rooftop in an uploaded image.

    Form fields:
        file: image in any format PIL can open (converted to RGB).
        rooftop_type: "rectangular" or "irregular" (case-insensitive); selects
            which predictor is used.
        epsilon: approxPolyDP tolerance as a fraction of the contour
            perimeter; must be finite and >= 0.

    Returns:
        A dict with "polygon" and "vertices" (list of [x, y] or None),
        "vertex_count", "image" (base64 PNG of the overlay, or None when
        nothing was detected), "model_used", "rooftop_type" (as sent) and
        "epsilon". Invalid input yields a 400 JSONResponse with an "error" key.
    """
    contents = await file.read()
    try:
        im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:  # request boundary: any decode failure is a client error
        return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})
    im = np.array(im_pil)[:, :, ::-1].copy()  # RGB -> BGR

    kind = rooftop_type.lower()
    if kind == "rectangular":
        predictor, model_used = predictor_rect, "model_rect_final.pth"
    elif kind == "irregular":
        predictor, model_used = predictor_irregular_flat, "model_irregular_flat.pth"
    else:
        return JSONResponse(status_code=400, content={"error": "Invalid rooftop_type. Choose 'rectangular' or 'irregular'."})

    # approxPolyDP raises on negative / non-finite tolerances, and NaN/inf
    # would not serialise as strict JSON in the response; reject them as a
    # client error instead of failing with a 500.
    if not np.isfinite(epsilon) or epsilon < 0:
        return JSONResponse(status_code=400, content={"error": "Invalid epsilon. Must be a finite, non-negative number."})

    # NOTE(review): predictor() is synchronous, CPU-bound work inside an async
    # handler, so it blocks the event loop while it runs; consider
    # fastapi.concurrency.run_in_threadpool if the predictors are safe to call
    # concurrently.
    outputs = predictor(im)
    instances = outputs["instances"].to("cpu")
    if len(instances) == 0:
        return {
            "polygon": None,
            "vertices": None,
            "vertex_count": 0,
            "image": None,
            "model_used": model_used,
            "rooftop_type": rooftop_type,
            "epsilon": epsilon,
        }

    # Keep only the highest-scoring detection.
    idx = int(instances.scores.argmax().item())
    raw_mask = instances.pred_masks[idx].numpy().astype(np.uint8)

    # Take the vertices straight from approxPolyDP for both rooftop types.
    # Re-tracing a rasterised polygon mask (fillPoly + findContours) would
    # return pixel-staircase points along non-axis-aligned edges instead of
    # the simplified corners. None means the mask had no contour.
    polygon = extract_polygon_vertices(raw_mask, epsilon)

    overlay = overlay_polygon(im, polygon)
    img_b64 = im_to_b64_png(overlay)
    vertices = polygon.tolist() if polygon is not None else None
    return {
        "polygon": vertices,
        "vertices": vertices,
        "vertex_count": len(vertices) if vertices is not None else 0,
        "image": img_b64,
        "model_used": model_used,
        "rooftop_type": rooftop_type,
        "epsilon": epsilon,
    }