Spaces:
Sleeping
Sleeping
import asyncio
import base64
import io
import os
from typing import List, Optional

import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.projects import point_rend
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import numpy as np
from PIL import Image
# -------------------
# Detectron2 / PointRend setup (multi-class)
# -------------------
# Index i must correspond to class id i of the trained model: predictions are
# mapped back to names with classes[pred_class].
classes = ["Irregular-rooftops", "Rectangular-rooftops"]

cfg = get_cfg()
# PointRend's config keys (e.g. MODEL.POINT_HEAD, assigned below) are not part
# of detectron2's base config. They must be registered before merging the
# PointRend YAML, otherwise merge_from_file rejects them as unknown keys.
point_rend.add_pointrend_config(cfg)
# Paths and device may be overridden through environment variables; the
# defaults are the original hard-coded values (relative to the working dir).
cfg.merge_from_file(
    os.environ.get(
        "POINTREND_CONFIG",
        "detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml",
    )
)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(classes)
cfg.MODEL.POINT_HEAD.NUM_CLASSES = len(classes)
cfg.MODEL.WEIGHTS = os.environ.get("MODEL_WEIGHTS", "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.MODEL.DEVICE = os.environ.get("MODEL_DEVICE", "cpu")  # e.g. "cuda"
MetadataCatalog.get("__unused__").thing_classes = classes
# Built once at import time and shared by every request.
predictor = DefaultPredictor(cfg)
# -------------------
# FastAPI setup
# -------------------
app = FastAPI(title="PointRend Multi-Class Rooftop API")

# Open CORS so browser front-ends on any origin can call the API.
# Credentials stay disabled: a wildcard origin is not valid for credentialed
# CORS requests, and no endpoint in this service reads cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
| # ------------------- | |
| # Post-processing / helpers | |
| # ------------------- | |
def postprocess_simplified(mask: np.ndarray, epsilon_factor: float = 0.01) -> Optional[np.ndarray]:
    """Return a clean binary mask built from the largest blob in ``mask``.

    The mask is binarised, and its largest external contour is simplified with
    ``cv2.approxPolyDP``. The simplified contour is then filled back in, so
    smaller disconnected regions are dropped and interior holes are filled.

    Args:
        mask: Mask array. If its maximum is <= 1 (bool / 0-1 / float) it is
            scaled to 0-255 first; otherwise it is treated as 0-255 already.
        epsilon_factor: Simplification tolerance as a fraction of the contour
            perimeter; larger values give coarser shapes.

    Returns:
        ``None`` if ``mask`` is ``None``. Otherwise a ``uint8`` array shaped like
        ``mask`` with 255 inside the simplified shape and 0 elsewhere. It is all
        zeros when there is no foreground or the mask is empty.
    """
    if mask is None:
        return None
    if mask.size == 0:
        # mask.max() below would raise ValueError on an empty array.
        return np.zeros(mask.shape, dtype=np.uint8)
    mask_uint8 = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
    bw = (mask_uint8 > 127).astype(np.uint8) * 255
    simp = np.zeros_like(bw)
    # [-2] selects the contour list under both OpenCV 3.x (3 return values)
    # and OpenCV 4.x (2 return values).
    contours = cv2.findContours(bw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return simp
    largest = max(contours, key=cv2.contourArea)
    epsilon = epsilon_factor * cv2.arcLength(largest, True)
    approx = cv2.approxPolyDP(largest, epsilon, True)
    cv2.fillPoly(simp, [approx], 255)
    return simp
def mask_to_polygon(mask: np.ndarray, epsilon_factor: float = 0.01, min_area: int = 150) -> Optional[np.ndarray]:
    """Convert the largest blob in ``mask`` into a simplified polygon.

    Args:
        mask: Mask array. If its maximum is <= 1 it is scaled to 0-255 first;
            otherwise it is treated as 0-255 already.
        epsilon_factor: ``cv2.approxPolyDP`` tolerance as a fraction of the
            contour perimeter.
        min_area: Contours with a smaller area (in pixels) are rejected.

    Returns:
        A ``(K, 2)`` array of contour vertex coordinates with K >= 3.
        Returns ``None`` in any of these cases:

        * ``mask`` is ``None`` or empty;
        * the mask has no foreground;
        * the largest contour is smaller than ``min_area``;
        * simplification leaves fewer than three vertices.
    """
    if mask is None or mask.size == 0:
        # mask.max() below would raise ValueError on an empty array.
        return None
    mask_uint8 = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)
    bw = (mask_uint8 > 127).astype(np.uint8) * 255
    # [-2] selects the contour list under both OpenCV 3.x and 4.x.
    contours = cv2.findContours(bw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return None
    contour = max(contours, key=cv2.contourArea)
    if cv2.contourArea(contour) < min_area:
        return None
    epsilon = epsilon_factor * cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, epsilon, True)
    if len(approx) < 3:
        # A point or a line segment is not a usable polygon.
        return None
    return approx.reshape(-1, 2)
def im_to_b64_png(im: np.ndarray) -> str:
    """Encode ``im`` as PNG and return the PNG bytes as base64 text.

    Raises:
        RuntimeError: if OpenCV reports that PNG encoding failed.
    """
    encoded_ok, png_buffer = cv2.imencode(".png", im)
    if encoded_ok:
        return base64.b64encode(png_buffer).decode("utf-8")
    raise RuntimeError("Failed to encode image")
| # ------------------- | |
| # Prediction endpoint (multi-class) | |
| # ------------------- | |
def _instance_result(im: np.ndarray, cls_id: int, raw_mask: np.ndarray) -> dict:
    """Build the JSON-serialisable result for a single predicted instance."""
    simp_mask = postprocess_simplified(raw_mask, epsilon_factor=0.01)
    poly = mask_to_polygon(simp_mask, epsilon_factor=0.01, min_area=150)
    # Each instance gets its own overlay, drawn on a fresh copy of the image.
    overlay = im.copy()
    poly_list = None
    if poly is not None:
        # (0, 0, 255) is red in the BGR channel order used for ``im``.
        cv2.polylines(overlay, [poly.astype(np.int32)], True, (0, 0, 255), 2)
        poly_list = poly.tolist()
    return {
        "class_id": cls_id,
        "class_name": classes[cls_id],
        "polygon": poly_list,
        "overlay_image": im_to_b64_png(overlay),
        "mask_image": im_to_b64_png(cv2.cvtColor(simp_mask, cv2.COLOR_GRAY2BGR)),
    }


def _run_predictions(im: np.ndarray) -> List[dict]:
    """Run the model on a BGR image and build one result dict per instance.

    This is CPU-bound (inference plus PNG encoding), so ``predict_endpoint``
    runs it in a worker thread rather than on the event loop.
    """
    instances = predictor(im)["instances"].to("cpu")
    return [
        _instance_result(
            im,
            int(instances.pred_classes[idx]),
            instances.pred_masks[idx].numpy().astype(np.uint8),
        )
        for idx in range(len(instances))
    ]


# NOTE(review): the route decorator was missing, so this handler was never
# registered with the app; the "/predict" path is assumed. Confirm it against clients.
@app.post("/predict")
async def predict_endpoint(file: UploadFile = File(...)):
    """Segment rooftops in an uploaded image.

    Responds with ``{"instances": [...]}``, one entry per detected instance
    (an empty list if nothing is detected). Each entry holds:

    * ``class_id`` and ``class_name``;
    * ``polygon``: a list of vertex coordinates, or ``None`` when no usable
      polygon was found;
    * ``overlay_image``: the input with that polygon drawn in red;
    * ``mask_image``: the simplified mask.

    Both images are base64-encoded PNGs. An upload that cannot be decoded as
    an image gets a 400 JSON error.
    """
    contents = await file.read()
    try:
        im_pil = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as e:  # API boundary: any decode failure is a client error
        return JSONResponse(status_code=400, content={"error": "Invalid image", "detail": str(e)})
    im = np.array(im_pil)[:, :, ::-1].copy()  # RGB -> BGR
    # Running the blocking work in the default thread pool keeps the event
    # loop free to serve other requests meanwhile.
    # NOTE(review): concurrent requests now share `predictor` across threads
    # (as FastAPI does for sync endpoints); add a lock if that proves unsafe.
    loop = asyncio.get_running_loop()
    results = await loop.run_in_executor(None, _run_predictions, im)
    return {"instances": results}
| # ------------------- | |
| # Run API | |
| # ------------------- | |
if __name__ == "__main__":
    import uvicorn

    # Pass the app object that is already built instead of the import string
    # "app_pointrend_multiclass:app". With a string, uvicorn imports this file
    # again under that module name, which is a separate module object from
    # __main__. That re-runs the model setup above (loading the weights twice)
    # and fails outright if the file has any other name. The object form is
    # supported because reload is disabled.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "8000")),
        reload=False,
    )