"""Rooftop segmentation API.

Serves a Detectron2 PointRend instance-segmentation model behind FastAPI.
For an uploaded image, the highest-scoring detected instance is turned into
a simplified polygon and drawn onto the image.
"""

import base64
import io
from typing import Optional

import cv2
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.projects.point_rend import add_pointrend_config

# -------------------------------
# FastAPI Setup
# -------------------------------
app = FastAPI(title="Rooftop Segmentation API")

# Wide-open CORS: every origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# -------------------------------
# Detectron2 Config + Predictor
# -------------------------------
cfg = get_cfg()
add_pointrend_config(cfg)

pointrend_cfg_path = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
cfg.merge_from_file(pointrend_cfg_path)

# NOTE(review): the comment that used to sit on the next line read
# "0: rectangular, 1: irregular", which contradicts CLASS_NAMES below
# (0: irregular, 1: rectangular). Runtime behaviour follows CLASS_NAMES;
# confirm the order against the training dataset's category ids.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
cfg.MODEL.WEIGHTS = "/app/model_final.pth"  # Path on Hugging Face
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.DEVICE = "cpu"

predictor = DefaultPredictor(cfg)

# Predicted class id -> class name used to pick the postprocessing routine.
CLASS_NAMES = {0: "irregular", 1: "rectangular"}

# Polygon simplification tolerance, as a fraction of the contour perimeter.
RECTANGULAR_EPSILON_RATIO = 0.01
IRREGULAR_EPSILON_RATIO = 0.004


# -------------------------------
# POSTPROCESSING FUNCTIONS
# -------------------------------
def _binarize(mask: np.ndarray) -> np.ndarray:
    """Return *mask* as a uint8 array holding 255 where mask > 0, else 0."""
    return (np.asarray(mask) > 0).astype(np.uint8) * 255


def _simplify_largest_contour(mask_uint8: np.ndarray, epsilon_ratio: float) -> np.ndarray:
    """Fill the simplified outline of the largest external contour.

    Args:
        mask_uint8: Single-channel uint8 mask; non-zero pixels are foreground.
        epsilon_ratio: ``cv2.approxPolyDP`` tolerance as a fraction of the
            contour's perimeter.

    Returns:
        A uint8 array shaped like ``mask_uint8`` with the simplified polygon
        filled with 255, or all zeros when no contour is found.
    """
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    simplified = np.zeros_like(mask_uint8)
    if contours:
        largest = max(contours, key=cv2.contourArea)
        epsilon = epsilon_ratio * cv2.arcLength(largest, True)
        approx = cv2.approxPolyDP(largest, epsilon, True)
        cv2.fillPoly(simplified, [approx], 255)
    return simplified


def postprocess_rectangular(mask: np.ndarray) -> Optional[np.ndarray]:
    """Postprocess for rectangular rooftops (simple, clean edges).

    Args:
        mask: Binary instance mask (non-zero = rooftop), or ``None``.

    Returns:
        ``None`` when ``mask`` is ``None``; otherwise a uint8 0/255 mask
        containing the filled, simplified polygon of the largest contour.
    """
    if mask is None:
        return None
    return _simplify_largest_contour(_binarize(mask), RECTANGULAR_EPSILON_RATIO)


def morphological_open(mask, kernel_size=20, iterations=1):
    """Apply morphological open to clean irregular rooftop mask.

    Args:
        mask: Image array accepted by ``cv2.morphologyEx``.
        kernel_size: Side length of the square structuring element.
        iterations: Number of times the opening is applied.

    Returns:
        The opened mask.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
    return cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=iterations)


def postprocess_irregular(mask: np.ndarray, epsilon_ratio: float = IRREGULAR_EPSILON_RATIO) -> Optional[np.ndarray]:
    """Postprocess for irregular rooftops (morphology + simplify).

    Args:
        mask: Binary instance mask (non-zero = rooftop), or ``None``.
        epsilon_ratio: Polygon simplification tolerance as a fraction of the
            contour perimeter.

    Returns:
        ``None`` when ``mask`` is ``None``; otherwise a uint8 0/255 mask
        containing the filled, simplified polygon of the largest contour
        that survives the morphological opening (all zeros if none does).
    """
    if mask is None:
        return None
    # Binarize before opening so bool masks are accepted by OpenCV; for a
    # binary mask this gives the same result as opening first and scaling after.
    mask_clean = morphological_open(_binarize(mask))
    return _simplify_largest_contour(mask_clean, epsilon_ratio)


def mask_to_polygon(mask: np.ndarray) -> Optional[np.ndarray]:
    """Convert simplified mask to polygon coordinates.

    Args:
        mask: Single-channel uint8 mask, or ``None``.

    Returns:
        An ``(N, 2)`` array with the points of the largest external contour,
        or ``None`` when ``mask`` is ``None`` or contains no contour.
    """
    if mask is None:
        return None
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    largest = max(contours, key=cv2.contourArea)
    return largest.reshape(-1, 2)


def im_to_b64_png(im: np.ndarray) -> str:
    """Convert BGR image to base64 PNG.

    Raises:
        ValueError: If OpenCV reports that PNG encoding failed.
    """
    ok, buffer = cv2.imencode(".png", im)
    if not ok:
        raise ValueError("PNG encoding failed")
    return base64.b64encode(buffer).decode()


# -------------------------------
# API ENDPOINTS
# -------------------------------
@app.get("/")
def root():
    """Liveness endpoint returning a static status message."""
    return {"message": "Rooftop Segmentation API is running!"}


@app.post("/polygon")
async def polygon_endpoint(file: UploadFile = File(...)):
    """Segment the uploaded image and return the best instance as a polygon.

    Returns a 400 JSON error when the upload cannot be decoded as an image.
    When nothing is detected, every field of the response is ``None``.
    Otherwise the response carries the class name of the highest-scoring
    instance, its simplified polygon as a list of points (``None`` if
    postprocessing leaves no contour) and a base64 PNG of the input image
    with the polygon drawn on it.
    """
    contents = await file.read()
    try:
        im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})

    im = np.array(im_pil)[:, :, ::-1].copy()  # RGB -> BGR

    # NOTE: this call is synchronous inside an async handler, so the event
    # loop is blocked while inference runs.
    outputs = predictor(im)
    instances = outputs["instances"].to("cpu")

    if len(instances) == 0:
        # "chosen" is kept for existing clients; "chosen_class" matches the
        # key used by the success response below.
        return {"chosen": None, "chosen_class": None, "polygon": None, "image": None}

    # Take the highest-score instance
    idx = int(instances.scores.argmax().item())
    raw_mask = instances.pred_masks[idx].numpy().astype(np.uint8)
    cls_id = int(instances.pred_classes[idx].item())

    # Choose class name
    chosen_class = CLASS_NAMES.get(cls_id, "unknown")

    # --- Choose postprocessing dynamically ---
    # Anything that is not "rectangular" (including "unknown") takes the
    # irregular path.
    if chosen_class == "rectangular":
        simp_mask = postprocess_rectangular(raw_mask)
    else:
        simp_mask = postprocess_irregular(raw_mask)

    poly = mask_to_polygon(simp_mask)

    # --- Visualization ---
    overlay = im.copy()
    poly_list = None
    if poly is not None:
        cv2.polylines(overlay, [poly.astype(np.int32)], True, (0, 0, 255), 2)
        poly_list = poly.tolist()

    img_b64 = im_to_b64_png(overlay)

    return {
        "chosen_class": chosen_class,
        "polygon": poly_list,
        "image": img_b64,
    }