# Final-rooftop-detection / Final-app.py
# Last change: "Update Final-app.py" by Noursine (commit d0e7fa4).
import base64
import io
import os
import threading
from typing import Optional, List

import cv2
import numpy as np
from PIL import Image
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.projects import point_rend
# -------------------
# Detectron2 / PointRend setup (multi-class)
# -------------------
# Class names, indexed by the model's predicted class id.
classes = ["Irregular-rooftops", "Rectangular-rooftops"]

cfg = get_cfg()
# The PointRend YAML (and the POINT_HEAD override below) use config keys that
# detectron2's default config does not define; they must be registered first,
# otherwise merge_from_file rejects them as non-existent keys.
point_rend.add_pointrend_config(cfg)
cfg.merge_from_file(
    os.environ.get(
        "POINTREND_CONFIG",
        "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml",
    )
)
# Both the box/mask heads and the PointRend point head must agree on the class count.
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)
cfg.MODEL.POINT_HEAD.NUM_CLASSES = len(classes)
cfg.MODEL.WEIGHTS = os.environ.get("MODEL_WEIGHTS", "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.DEVICE = os.environ.get("MODEL_DEVICE", "cpu")  # set to "cuda" to use a GPU
MetadataCatalog.get("__unused__").thing_classes = classes
predictor = DefaultPredictor(cfg)
# -------------------
# FastAPI setup
# -------------------
app = FastAPI(title="PointRend Multi-Class Rooftop API")

# Fully permissive CORS: every origin, method and request header is accepted.
# NOTE(review): the CORS spec does not allow a literal "*" origin on
# credentialed requests; nothing visible in this file uses cookies or auth,
# so allow_credentials=False or an explicit origin list may be preferable.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# -------------------
# Post-processing / helpers
# -------------------
def postprocess_simplified(mask: np.ndarray, epsilon_factor: float = 0.01) -> Optional[np.ndarray]:
    """Redraw a mask as the filled, simplified polygon of its largest region.

    Args:
        mask: 2-D mask. If ``mask.max() <= 1`` it is treated as a 0..1 mask and
            scaled by 255; otherwise it is cast to uint8 as-is. Pixels above
            127 after that conversion are foreground.
        epsilon_factor: Polygon-approximation tolerance, as a fraction of the
            largest contour's perimeter (passed to ``cv2.approxPolyDP``).

    Returns:
        A uint8 array shaped like ``mask`` that is 255 inside the simplified
        polygon and 0 elsewhere (all zeros when there is no foreground), or
        None when ``mask`` is None.
    """
    if mask is None:
        return None
    if mask.size == 0:
        # max() has no identity on an empty array; there is nothing to outline.
        return np.zeros(mask.shape, dtype=np.uint8)
    mask_uint8 = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
    bw = (mask_uint8 > 127).astype(np.uint8) * 255
    simp = np.zeros_like(bw)
    if not bw.any():
        # No foreground pixels, so no contours; skip the OpenCV work.
        return simp
    contours, _ = cv2.findContours(bw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return simp
    # Keep only the largest outer contour; smaller fragments are dropped.
    largest = max(contours, key=cv2.contourArea)
    epsilon = epsilon_factor * cv2.arcLength(largest, True)
    approx = cv2.approxPolyDP(largest, epsilon, True)
    cv2.fillPoly(simp, [approx], 255)
    return simp
def mask_to_polygon(mask: np.ndarray, epsilon_factor: float = 0.01, min_area: float = 150) -> Optional[np.ndarray]:
    """Extract a simplified outline polygon of the largest region in a mask.

    Args:
        mask: 2-D mask. If ``mask.max() <= 1`` it is treated as a 0..1 mask and
            scaled by 255; otherwise it is cast to uint8 as-is. Pixels above
            127 after that conversion are foreground.
        epsilon_factor: Polygon-approximation tolerance, as a fraction of the
            largest contour's perimeter (passed to ``cv2.approxPolyDP``).
        min_area: The largest contour is rejected if ``cv2.contourArea`` of it
            is below this value.

    Returns:
        A (K, 2) array of polygon vertices, or None when ``mask`` is None or
        empty, has no foreground, or its largest region is below ``min_area``.
    """
    if mask is None or mask.size == 0:
        # An empty array would make mask.max() raise; there is no polygon anyway.
        return None
    mask_uint8 = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
    bw = (mask_uint8 > 127).astype(np.uint8) * 255
    if not bw.any():
        # No foreground pixels, so no contours; skip the OpenCV work.
        return None
    contours, _ = cv2.findContours(bw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None
    contour = max(contours, key=cv2.contourArea)
    if cv2.contourArea(contour) < min_area:
        return None
    epsilon = epsilon_factor * cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, epsilon, True)
    # approxPolyDP returns shape (K, 1, 2); flatten to one row per vertex.
    return approx.reshape(-1, 2)
def im_to_b64_png(im: np.ndarray) -> str:
    """Encode an image as PNG with OpenCV and return the bytes as base64 text.

    Raises:
        RuntimeError: if ``cv2.imencode`` reports failure.
    """
    success, encoded = cv2.imencode(".png", im)
    if success:
        return base64.b64encode(encoded).decode("utf-8")
    raise RuntimeError("Failed to encode image")
# -------------------
# Prediction endpoint (multi-class)
# -------------------
# Inference runs in worker threads (see predict_endpoint). This lock keeps it to
# one predictor call at a time, so concurrent uploads cannot multiply CPU/memory
# use and the code does not rely on DefaultPredictor being thread-safe.
_PREDICTOR_LOCK = threading.Lock()


def _decode_image(contents: bytes) -> np.ndarray:
    """Decode uploaded image bytes into a BGR array.

    Raises:
        Exception: whatever PIL raises when ``contents`` is not a readable
            image (predict_endpoint turns this into an HTTP 400 response).
    """
    im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    # Reverse the channel axis (RGB -> BGR) and copy so the result is a
    # contiguous array rather than a negatively-strided view.
    return np.array(im_pil)[:, :, ::-1].copy()


def _predict_instances(im: np.ndarray) -> List[dict]:
    """Run the predictor on a BGR image and build one JSON-ready dict per instance.

    Each dict holds ``class_id``, ``class_name``, ``polygon`` (list of vertex
    pairs, or None when no polygon survives simplification), ``overlay_image``
    (base64 PNG of the input with that polygon drawn) and ``mask_image``
    (base64 PNG of the simplified mask).
    """
    with _PREDICTOR_LOCK:
        instances = predictor(im)["instances"].to("cpu")
    results = []
    for cls_id, raw_mask in zip(instances.pred_classes.tolist(), instances.pred_masks.numpy()):
        simp_mask = postprocess_simplified(raw_mask.astype(np.uint8), epsilon_factor=0.01)
        poly = mask_to_polygon(simp_mask, epsilon_factor=0.01, min_area=150)
        # Fresh copy per instance so each overlay shows only its own polygon.
        overlay = im.copy()
        poly_list = None
        if poly is not None:
            # Closed outline, colour (0, 0, 255) in BGR order, 2 px thick.
            cv2.polylines(overlay, [poly.astype(np.int32)], True, (0, 0, 255), 2)
            poly_list = poly.tolist()
        results.append({
            "class_id": cls_id,
            "class_name": classes[cls_id],
            "polygon": poly_list,
            "overlay_image": im_to_b64_png(overlay),
            "mask_image": im_to_b64_png(cv2.cvtColor(simp_mask, cv2.COLOR_GRAY2BGR)),
        })
    return results


@app.post("/predict")
async def predict_endpoint(file: UploadFile = File(...)):
    """Segment rooftops in an uploaded image.

    Returns ``{"instances": [...]}`` with one entry per detected instance, or
    HTTP 400 ``{"error": "Invalid image", "detail": ...}`` when the upload
    cannot be decoded as an image.
    """
    contents = await file.read()
    # Decoding, inference and per-instance post-processing are synchronous and
    # CPU-bound; running them in the threadpool keeps the event loop free to
    # serve other requests while a prediction is in progress.
    try:
        im = await run_in_threadpool(_decode_image, contents)
    except Exception as e:
        return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})
    results = await run_in_threadpool(_predict_instances, im)
    if not results:
        return JSONResponse(content={"instances": []})
    return {"instances": results}
# -------------------
# Run API
# -------------------
if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than an import string. An import string only
    # works if this file is saved under that module name, and it makes uvicorn
    # import the module again, rebuilding the predictor a second time. With
    # reload=False uvicorn does not need an import string.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")), reload=False)