# Scraped page header (not part of the program):
#   Noursine's picture / Update app.py / 44b7142 verified
import io
import base64
from typing import Optional
import cv2
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.projects.point_rend import add_pointrend_config
# -------------------------------
# FastAPI Setup
# -------------------------------
app = FastAPI(title="Rooftop Segmentation API")

# CORS: every origin, method and header is accepted. Credentials are
# disabled because wildcard origins/methods/headers must not be combined
# with credentialed requests (the CORS spec forbids it, and FastAPI's docs
# require explicit values when allow_credentials is True). No endpoint in
# this file reads cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------------------------------
# Detectron2 Config + Predictor
# -------------------------------
# Start from Detectron2's default config and extend it with the PointRend
# options before merging the PointRend YAML file.
cfg = get_cfg()
add_pointrend_config(cfg)
# Relative path: resolved against the process working directory at import time.
pointrend_cfg_path = "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
cfg.merge_from_file(pointrend_cfg_path)
# Two rooftop classes.
# NOTE(review): this line was originally annotated "0: rectangular, 1: irregular",
# but polygon_endpoint maps 0 -> "irregular" and 1 -> "rectangular".
# Confirm which order matches the training labels.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
# The point head is given the same class count as the ROI heads.
cfg.MODEL.POINT_HEAD.NUM_CLASSES = cfg.MODEL.ROI_HEADS.NUM_CLASSES
cfg.MODEL.WEIGHTS = "/app/model_final.pth"  # Path on Hugging Face
# Score threshold applied to detections at inference time.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.DEVICE = "cpu"
# Built once at import time and shared by every request handler.
predictor = DefaultPredictor(cfg)
# -------------------------------
# POSTPROCESSING FUNCTIONS
# -------------------------------
def postprocess_rectangular(mask: np.ndarray) -> Optional[np.ndarray]:
    """Redraw a rectangular-rooftop mask as a simplified filled polygon.

    The mask is scaled by 255 and cast to uint8, its largest external
    contour (by area) is simplified with ``cv2.approxPolyDP`` using a
    tolerance of 1% of the contour perimeter, and the result is filled
    onto a blank canvas of the same shape.

    Args:
        mask: Instance mask (assumed to hold 0/1 values - TODO confirm).

    Returns:
        A uint8 mask with the simplified polygon filled with 255, an
        all-zero mask when no contour is found, or None if ``mask`` is None.
    """
    if mask is None:
        return None

    binary = (mask * 255).astype(np.uint8)
    canvas = np.zeros_like(binary)

    outlines, _hierarchy = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not outlines:
        return canvas

    biggest = max(outlines, key=cv2.contourArea)
    tolerance = 0.01 * cv2.arcLength(biggest, True)
    simplified = cv2.approxPolyDP(biggest, tolerance, True)
    cv2.fillPoly(canvas, [simplified], 255)
    return canvas
def morphological_open(mask, kernel_size=20, iterations=1):
    """Run a morphological opening on ``mask`` with a square kernel.

    Args:
        mask: Image/mask array passed straight to ``cv2.morphologyEx``.
        kernel_size: Side length of the rectangular structuring element.
        iterations: Number of times the opening is applied.

    Returns:
        The opened mask as returned by ``cv2.morphologyEx``.
    """
    structuring_element = cv2.getStructuringElement(
        cv2.MORPH_RECT, (kernel_size, kernel_size)
    )
    return cv2.morphologyEx(
        mask,
        cv2.MORPH_OPEN,
        structuring_element,
        iterations=iterations,
    )
def postprocess_irregular(mask: np.ndarray, epsilon_ratio: float = 0.004) -> Optional[np.ndarray]:
    """Clean an irregular-rooftop mask and redraw it as a simplified polygon.

    The mask is first opened morphologically (default kernel of
    ``morphological_open``), then scaled by 255 and cast to uint8. The
    largest external contour is simplified with ``cv2.approxPolyDP`` and
    filled onto a blank canvas of the same shape.

    Args:
        mask: Instance mask (assumed to hold 0/1 values - TODO confirm).
        epsilon_ratio: Simplification tolerance as a fraction of the
            contour perimeter.

    Returns:
        A uint8 mask with the simplified polygon filled with 255, an
        all-zero mask when no contour survives the opening, or None if
        ``mask`` is None.
    """
    if mask is None:
        return None

    opened = morphological_open(mask)
    binary = (opened * 255).astype(np.uint8)
    canvas = np.zeros_like(binary)

    outlines, _hierarchy = cv2.findContours(
        binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not outlines:
        return canvas

    biggest = max(outlines, key=cv2.contourArea)
    tolerance = epsilon_ratio * cv2.arcLength(biggest, True)
    simplified = cv2.approxPolyDP(biggest, tolerance, True)
    cv2.fillPoly(canvas, [simplified], 255)
    return canvas
def mask_to_polygon(mask: np.ndarray) -> Optional[np.ndarray]:
    """Convert a simplified mask to the vertex list of its largest contour.

    Args:
        mask: Single-channel uint8 mask, or None. The postprocess helpers
            in this file may return None, so None is accepted here too.

    Returns:
        An array of shape (N, 2) holding the contour points as produced by
        ``cv2.findContours``, or None when ``mask`` is None or contains no
        contour.
    """
    # Guard first: cv2.findContours raises on None instead of returning
    # an empty result.
    if mask is None:
        return None

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None

    largest = max(contours, key=cv2.contourArea)
    # Flatten OpenCV's contour layout into one row per point.
    return largest.reshape(-1, 2)
def im_to_b64_png(im: np.ndarray) -> str:
    """Encode an image as PNG and return it as a base64 string.

    Args:
        im: Image array accepted by ``cv2.imencode`` (the caller in this
            file passes a BGR image).

    Returns:
        The PNG bytes, base64-encoded, as an ASCII string.

    Raises:
        ValueError: If OpenCV reports that PNG encoding failed.
    """
    ok, buffer = cv2.imencode(".png", im)
    # imencode signals failure through its first return value; without this
    # check a failed encode would surface later as an unrelated error.
    if not ok:
        raise ValueError("Failed to encode image as PNG")
    return base64.b64encode(buffer).decode("ascii")
# -------------------------------
# API ENDPOINTS
# -------------------------------
@app.get("/")
def root():
    """Return a static message confirming the API is reachable."""
    status = {"message": "Rooftop Segmentation API is running!"}
    return status
@app.post("/polygon")
async def polygon_endpoint(file: UploadFile = File(...)):
    """Segment the highest-scoring rooftop in an uploaded image.

    The upload is decoded with PIL, run through the module-level
    ``predictor``, and the top-scoring instance's mask is simplified with
    the class-specific postprocessing before being turned into a polygon.

    Args:
        file: Uploaded image file.

    Returns:
        On success, a dict with:
          * ``chosen_class``: "irregular", "rectangular" or "unknown".
          * ``polygon``: list of [x, y] points, or None if no contour
            survived postprocessing.
          * ``image``: base64 PNG of the input with the polygon outlined.
        When nothing is detected, all of those values are None (the legacy
        ``chosen`` key is also included for existing clients).
        A 400 JSON response is returned if the upload is not a readable
        image.
    """
    contents = await file.read()
    try:
        im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:  # request boundary: any decode failure is a bad upload
        return JSONResponse(
            status_code=400,
            content={"error": "Invalid image", "detail": str(e)},
        )

    im = np.array(im_pil)[:, :, ::-1].copy()  # RGB -> BGR

    # NOTE(review): this is a blocking call inside an async handler, so the
    # event loop is held for the duration of inference.
    outputs = predictor(im)
    instances = outputs["instances"].to("cpu")
    if len(instances) == 0:
        # Same "chosen_class" key as the success response; "chosen" is the
        # key this branch used originally and is kept so existing clients
        # keep working.
        return {
            "chosen": None,
            "chosen_class": None,
            "polygon": None,
            "image": None,
        }

    # Take the highest-score instance
    idx = int(instances.scores.argmax().item())
    raw_mask = instances.pred_masks[idx].numpy().astype(np.uint8)
    cls_id = int(instances.pred_classes[idx].item())

    # NOTE(review): the model config section of this file annotates the
    # classes as "0: rectangular, 1: irregular", the opposite of this
    # mapping. Confirm against the training labels.
    class_names = {0: "irregular", 1: "rectangular"}
    chosen_class = class_names.get(cls_id, "unknown")

    # --- Choose postprocessing dynamically ---
    # Unknown class ids fall through to the irregular pipeline.
    if chosen_class == "rectangular":
        simp_mask = postprocess_rectangular(raw_mask)
    else:
        simp_mask = postprocess_irregular(raw_mask)
    poly = mask_to_polygon(simp_mask)

    # --- Visualization ---
    overlay = im.copy()
    poly_list = None
    if poly is not None:
        # (0, 0, 255) is red in BGR.
        cv2.polylines(overlay, [poly.astype(np.int32)], True, (0, 0, 255), 2)
        poly_list = poly.tolist()
    img_b64 = im_to_b64_png(overlay)

    return {
        "chosen_class": chosen_class,
        "polygon": poly_list,
        "image": img_b64,
    }