# Noursine's picture
# Update app2.py
# ea63f49 verified
import io
import cv2
import torch
import os
import numpy as np
import uvicorn
from fastapi import FastAPI, UploadFile
from fastapi.responses import StreamingResponse, JSONResponse
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.data import MetadataCatalog
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from hydra import initialize
from hydra.core.global_hydra import GlobalHydra
# -------------------
# Init FastAPI
# -------------------
# Application object; the route decorators in this module register on it.
# The title/description/version values are passed to FastAPI as metadata.
app = FastAPI(
    title="Polygon Segmentation API",
    description="Mask R-CNN + SAM2 edge-refinement with polygon cleaning",
    version="2.0"
)
# -------------------
# Health check
# -------------------
@app.get("/")
def home():
    """Health check: return a fixed payload showing the service is up."""
    status_payload = {"status": "running"}
    return status_payload
# -------------------
# Detectron2 setup
# -------------------
# Start from Detectron2's default config and overlay the model-zoo
# Mask R-CNN R50-FPN 3x instance-segmentation settings.
det_cfg = get_cfg()
det_cfg.merge_from_file(
    model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
)
det_cfg.MODEL.WEIGHTS = "/app/model_final.pth"  # your trained weights path
det_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # test-time score threshold for detections
det_cfg.MODEL.DEVICE = "cpu"  # Hugging Face free tier is CPU only
det_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # single class ("toproof", see below)
# Attaches the class name to a placeholder metadata entry.
# NOTE(review): nothing else in this file reads the "__unused__" entry --
# confirm whether this registration is still needed.
MetadataCatalog.get("__unused__").thing_classes = ["toproof"]
# Built once at import time; the /predict handler reuses this instance.
predictor = DefaultPredictor(det_cfg)
# -------------------
# SAM2 setup
# -------------------
# Changes the process-wide working directory.
# NOTE(review): presumably so the relative checkpoint path below resolves
# under /app -- confirm.
os.chdir("/app")
# Clear any already-initialized global Hydra state before initializing again.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
# Hydra is initialized with config_path="." while the SAM2 model is built on CPU.
with initialize(version_base=None, config_path="."):
    sam2_model = build_sam2("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt", device="cpu")
# Built once at import time; the /predict handler passes this instance to
# refine_predictions.
sam2_predictor = SAM2ImagePredictor(sam2_model)
# -------------------
# Polygonization helpers
# -------------------
def _largest_contour(mask, use_chain_approx_none=True):
    """Return the external contour of ``mask`` with the largest area.

    Args:
        mask: Image handed directly to ``cv2.findContours``.
        use_chain_approx_none: When true, contours are extracted with
            ``cv2.CHAIN_APPROX_NONE``; otherwise with
            ``cv2.CHAIN_APPROX_SIMPLE``.

    Returns:
        The contour maximizing ``cv2.contourArea``, or ``None`` when no
        contour is found.
    """
    if use_chain_approx_none:
        approx_method = cv2.CHAIN_APPROX_NONE
    else:
        approx_method = cv2.CHAIN_APPROX_SIMPLE

    # Only outer boundaries are requested (RETR_EXTERNAL).
    found, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, approx_method)
    return max(found, key=cv2.contourArea, default=None)
def _smooth_closed_poly(pts, k=7):
    """Smooth a closed polygon with a wrap-around moving average.

    Each output vertex is the mean of ``k`` consecutive input vertices,
    with the vertex list treated as circular so the closing edge is
    smoothed like any other.

    Args:
        pts: Vertex array reshapeable to ``(N, 2)``, e.g. an OpenCV contour
            of shape ``(N, 1, 2)``.
        k: Window length in vertices. Windows shorter than 3, or longer
            than the polygon, disable smoothing.

    Returns:
        ``float32`` array of shape ``(N, 2)``. The vertex count always
        equals the input's, for odd and even ``k`` alike, and the shape is
        the same whether or not smoothing was applied.
    """
    pts = np.asarray(pts).reshape(-1, 2).astype(np.float32)
    if k < 3 or len(pts) < k:
        return pts

    # Pad k - 1 vertices in total so a "valid" convolution yields exactly N
    # samples. For odd k the split is symmetric; for even k the extra vertex
    # goes on the leading side (padding k // 2 on both sides would emit N + 1
    # vertices).
    lead = k // 2
    trail = k - 1 - lead
    wrapped = np.vstack([pts[-lead:], pts, pts[:trail]])

    kernel = np.ones((k,), dtype=np.float32) / k
    xs = np.convolve(wrapped[:, 0], kernel, mode="valid")
    ys = np.convolve(wrapped[:, 1], kernel, mode="valid")
    return np.stack([xs, ys], axis=1).astype(np.float32)
def _poly_mask_iou(mask, poly):
    """Compute intersection-over-union between a mask and a filled polygon.

    Args:
        mask: Array whose first two dimensions give the raster size; pixels
            with value > 0 count as foreground.
        poly: Polygon vertices, cast to ``int32`` and reshaped to
            ``(-1, 1, 2)`` before rasterization.

    Returns:
        IoU as a Python ``float``. A small constant (1e-6) is added to the
        union so an empty union does not divide by zero.
    """
    height, width = mask.shape[:2]

    # Rasterize the polygon onto a blank canvas the size of the mask.
    poly_raster = np.zeros((height, width), dtype=np.uint8)
    vertices = poly.astype(np.int32).reshape(-1, 1, 2)
    cv2.fillPoly(poly_raster, [vertices], 1)

    mask_fg = (mask > 0).astype(np.uint8)
    overlap = np.logical_and(poly_raster, mask_fg).sum()
    combined_area = np.logical_or(poly_raster, mask_fg).sum()
    return float(overlap) / float(combined_area + 1e-6)
def _min_area_rect_to_poly(cnt):
    """Return the minimum-area bounding rectangle of ``cnt`` as a polygon.

    Args:
        cnt: Point set passed to ``cv2.minAreaRect``.

    Returns:
        ``float32`` array of the rectangle's corner points (from
        ``cv2.boxPoints``), reshaped to ``(-1, 1, 2)``.
    """
    corners = cv2.boxPoints(cv2.minAreaRect(cnt))
    corners = corners.astype(np.float32)
    return corners.reshape(-1, 1, 2)
def mask_to_polygon_no_holes(mask, epsilon_factor=0.01, min_area=100):
    """Approximate the largest external contour of ``mask`` with a polygon.

    Only external contours are considered (``cv2.RETR_EXTERNAL``).

    Args:
        mask: Image handed directly to ``cv2.findContours``.
        epsilon_factor: Fraction of the contour's closed perimeter used as
            the ``cv2.approxPolyDP`` tolerance.
        min_area: Contours with a smaller ``cv2.contourArea`` are rejected.

    Returns:
        The simplified closed polygon from ``cv2.approxPolyDP``, or ``None``
        when there is no contour or the largest one is below ``min_area``.
    """
    outlines, _hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    biggest = max(outlines, key=cv2.contourArea, default=None)
    if biggest is None or cv2.contourArea(biggest) < min_area:
        return None

    tolerance = epsilon_factor * cv2.arcLength(biggest, True)
    return cv2.approxPolyDP(biggest, tolerance, True)
def clean_polygon_strict(mask, epsilon_factor=0.005,
                         smooth_k=7, rect_iou_thresh=0.88, min_area=100):
    """Turn a binary mask into a single clean polygon.

    Two candidates are built from the largest external contour of the mask:
    its minimum-area bounding rectangle (computed on an optionally smoothed
    contour) and a Douglas-Peucker simplification of the raw contour. The
    rectangle is chosen when it fits the mask well; otherwise the simplified
    polygon is used.

    Args:
        mask: Mask array. Boolean masks use ``True`` as foreground; any
            other dtype is cast to ``uint8`` and pixels > 127 are foreground.
        epsilon_factor: Simplification tolerance as a fraction of the
            contour perimeter, forwarded to ``mask_to_polygon_no_holes``.
        smooth_k: Moving-average window for contour smoothing before the
            rectangle fit; falsy or < 3 disables smoothing.
        rect_iou_thresh: Minimum rectangle-vs-mask IoU for the rectangle to
            be eligible.
        min_area: Minimum contour area for the simplified-polygon candidate.
            It gates only that candidate: when the simplified polygon is
            rejected the rectangle is still returned.

    Returns:
        ``int32`` polygon of shape ``(-1, 1, 2)``, or ``None`` when the mask
        has no contour or the largest contour has fewer than 4 points.
    """
    if mask.dtype == np.bool_:
        # Casting a boolean mask to uint8 gives 0/1, which a > 127 threshold
        # would wipe out entirely; treat True as foreground instead.
        bw = mask.astype(np.uint8)
    else:
        if mask.dtype != np.uint8:
            mask = mask.astype(np.uint8)
        bw = (mask > 127).astype(np.uint8)

    cnt = _largest_contour(bw, use_chain_approx_none=True)
    if cnt is None or len(cnt) < 4:
        return None

    # Smoothing only affects the rectangle fit; the simplified polygon below
    # is derived from the unsmoothed mask.
    if smooth_k and smooth_k >= 3:
        cnt_for_rect = _smooth_closed_poly(cnt, k=smooth_k).reshape(-1, 1, 2)
    else:
        cnt_for_rect = cnt.astype(np.float32)

    rect_poly = _min_area_rect_to_poly(cnt_for_rect)
    # Honor the caller's epsilon_factor instead of a hard-coded 0.005.
    approx = mask_to_polygon_no_holes(
        bw, epsilon_factor=epsilon_factor, min_area=min_area
    )

    iou_rect = _poly_mask_iou(bw, rect_poly)
    iou_approx = _poly_mask_iou(bw, approx) if approx is not None else 0

    # Prefer the rectangle when it fits well and is no more than 0.01 IoU
    # worse than the simplified polygon.
    if iou_rect >= rect_iou_thresh and (iou_rect + 0.01) >= iou_approx:
        best = rect_poly
    else:
        best = approx if approx is not None else rect_poly
    return best.astype(np.int32)
# -------------------
# Core pipeline
# -------------------
def refine_predictions(im, boxes, masks, sam2_predictor):
    """Refine Mask R-CNN instance masks with SAM2 and polygonize them.

    For every box, SAM2 is prompted with the box, its highest-scoring mask
    is cleaned (morphological close, blur, re-threshold) and intersected
    with the dilated Mask R-CNN mask; the result is converted to a polygon
    by ``clean_polygon_strict``.

    Args:
        im: Image passed unchanged to ``sam2_predictor.set_image``.
            NOTE(review): SAM2 documents RGB input -- confirm the caller's
            channel order.
        boxes: Sequence of per-instance boxes; each one is given a leading
            axis (``box[None, :]``) and used as the SAM2 box prompt.
        masks: Per-instance Mask R-CNN masks aligned with ``boxes``. They
            are cast to ``uint8`` and scaled by 255, so 0/1 or boolean
            values are assumed -- TODO confirm.
        sam2_predictor: Object exposing ``set_image`` and ``predict``.

    Returns:
        List with one ``(combined_mask, polygons)`` tuple per box, where
        ``polygons`` is ``[polygon]`` or ``[]`` when polygonization failed.
        Empty when ``boxes`` is empty.
    """
    # Skip the image embedding entirely when there is nothing to refine.
    if len(boxes) == 0:
        return []

    sam2_predictor.set_image(im)

    # Structuring elements are loop-invariant; build them once.
    close_kernel = np.ones((3, 3), np.uint8)
    dilate_kernel = np.ones((5, 5), np.uint8)

    refined_masks = []
    for box, det_mask in zip(boxes, masks):
        mask_rcnn = det_mask.astype(np.uint8) * 255

        sam_masks, sam_scores, _ = sam2_predictor.predict(
            box=box[None, :],
            multimask_output=True,
        )
        best_index = np.argmax(sam_scores)
        sam_mask = sam_masks[best_index].astype(np.uint8) * 255

        # Close small gaps, soften the edge, then re-binarize.
        sam_mask_clean = cv2.morphologyEx(sam_mask, cv2.MORPH_CLOSE, close_kernel)
        sam_mask_clean = cv2.GaussianBlur(sam_mask_clean, (3, 3), 0)
        _, sam_mask_clean = cv2.threshold(sam_mask_clean, 127, 255, cv2.THRESH_BINARY)

        # Keep SAM2's mask only within a dilated copy of the Mask R-CNN mask.
        mask_rcnn_dilated = cv2.dilate(mask_rcnn, dilate_kernel, iterations=1)
        combined = cv2.bitwise_and(mask_rcnn_dilated, sam_mask_clean)

        final_polygon = clean_polygon_strict(
            combined,
            epsilon_factor=0.005,
            smooth_k=7,
            rect_iou_thresh=0.88,
        )
        final_polygons = [final_polygon] if final_polygon is not None else []
        refined_masks.append((combined, final_polygons))
    return refined_masks
# -------------------
# API Endpoint
# -------------------
@app.post("/predict")
async def get_polygon(file: UploadFile):
    """Accepts an image upload and returns polygonized result.

    The upload is decoded with OpenCV, run through the Mask R-CNN
    ``predictor``, refined with SAM2 via ``refine_predictions``, and the
    resulting polygons are drawn in green on a copy of the input.

    Returns a PNG image on success, a 400 JSON error when the upload is
    empty or cannot be decoded, and a 500 JSON error carrying the exception
    message for any other failure.
    """
    try:
        contents = await file.read()
        # An empty buffer makes cv2.imdecode raise instead of returning None,
        # which would otherwise surface as a 500; reject it up front.
        if not contents:
            return JSONResponse(content={"error": "Invalid image"}, status_code=400)

        nparr = np.frombuffer(contents, np.uint8)
        im = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if im is None:
            return JSONResponse(content={"error": "Invalid image"}, status_code=400)

        # The decoded BGR array goes to the Detectron2 predictor as-is.
        outputs = predictor(im)
        instances = outputs["instances"].to("cpu")

        final_polygons_all = []
        if instances.has("pred_boxes") and instances.has("pred_masks"):
            boxes = instances.pred_boxes.tensor.numpy()
            masks = instances.pred_masks.numpy()
            # SAM2's set_image documents RGB input, while cv2.imdecode yields
            # BGR; convert before handing the image to the SAM2 refinement.
            im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            refined_masks = refine_predictions(im_rgb, boxes, masks, sam2_predictor)
            for _combined, polygons in refined_masks:
                final_polygons_all.extend(polygons)

        result = im.copy()
        if final_polygons_all:
            cv2.polylines(result, final_polygons_all, isClosed=True,
                          color=(0, 255, 0), thickness=2)

        ok, encoded_img = cv2.imencode(".png", result)
        if not ok:
            return JSONResponse(content={"error": "Failed to encode result image"},
                                status_code=500)
        return StreamingResponse(io.BytesIO(encoded_img.tobytes()), media_type="image/png")
    except Exception as e:
        # Top-level boundary: report any failure as a JSON 500.
        return JSONResponse(content={"error": str(e)}, status_code=500)