# Hugging Face Space (status when captured: Sleeping)
| import io | |
| import cv2 | |
| import torch | |
| import os | |
| import numpy as np | |
| import uvicorn | |
| from fastapi import FastAPI, UploadFile | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from detectron2.engine import DefaultPredictor | |
| from detectron2.config import get_cfg | |
| from detectron2 import model_zoo | |
| from detectron2.data import MetadataCatalog | |
| from sam2.build_sam import build_sam2 | |
| from sam2.sam2_image_predictor import SAM2ImagePredictor | |
| from hydra import initialize | |
| from hydra.core.global_hydra import GlobalHydra | |
| # ------------------- | |
| # Init FastAPI | |
| # ------------------- | |
# ASGI application object; the metadata below is handed to FastAPI unchanged.
app = FastAPI(
    version="2.0",
    title="Polygon Segmentation API",
    description="Mask R-CNN + SAM2 edge-refinement with polygon cleaning",
)
| # ------------------- | |
| # Health check | |
| # ------------------- | |
# NOTE(review): no route decorator was present in the source, which left this
# handler unreachable. "/" is assumed from the function name -- confirm.
@app.get("/")
def home():
    """Health-check endpoint.

    Returns:
        dict: ``{"status": "running"}``.
    """
    return {"status": "running"}
| # ------------------- | |
| # Detectron2 setup | |
| # ------------------- | |
# Start from the model-zoo Mask R-CNN R50-FPN 3x instance-segmentation config,
# then apply the overrides this deployment needs.
det_cfg = get_cfg()
det_cfg.merge_from_file(
    model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
)

# Runtime: weights file and device.
det_cfg.MODEL.DEVICE = "cpu"  # Hugging Face free tier is CPU only
det_cfg.MODEL.WEIGHTS = "/app/model_final.pth"  # your trained weights path

# ROI head: a single class, with a test-time score threshold of 0.5.
det_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
det_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5

# Attach the class name to the "__unused__" metadata entry.
MetadataCatalog.get("__unused__").thing_classes = ["toproof"]

# Built once at import time and shared by every request.
predictor = DefaultPredictor(det_cfg)
| # ------------------- | |
| # SAM2 setup | |
| # ------------------- | |
# The SAM2 config and checkpoint below are given as bare file names, so the
# process is moved into /app first (presumably where both files live).
os.chdir("/app")

# Drop any Hydra state that is already initialised before calling initialize().
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()

# Build the SAM2 model on CPU and wrap it in an image predictor; both objects
# are module-level and reused across requests.
with initialize(version_base=None, config_path="."):
    sam2_model = build_sam2(
        "sam2.1_hiera_l.yaml",
        "sam2.1_hiera_large.pt",
        device="cpu",
    )
    sam2_predictor = SAM2ImagePredictor(sam2_model)
| # ------------------- | |
| # Polygonization helpers | |
| # ------------------- | |
def _largest_contour(mask, use_chain_approx_none=True):
    """Return the largest-area external contour of ``mask``.

    Args:
        mask: Single-channel image handed directly to ``cv2.findContours``.
        use_chain_approx_none: When true, use ``cv2.CHAIN_APPROX_NONE``;
            otherwise use ``cv2.CHAIN_APPROX_SIMPLE``.

    Returns:
        The contour with the greatest ``cv2.contourArea``, or ``None`` when
        no contour is found.
    """
    if use_chain_approx_none:
        approx_method = cv2.CHAIN_APPROX_NONE
    else:
        approx_method = cv2.CHAIN_APPROX_SIMPLE

    # RETR_EXTERNAL: only outer boundaries are retrieved.
    found, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, approx_method)
    return max(found, key=cv2.contourArea) if found else None
def _smooth_closed_poly(pts, k=7):
    """Smooth a closed polygon with a circular moving average.

    Each output vertex is the mean of ``k`` consecutive input vertices, with
    the vertex list wrapped around so the polygon stays closed.

    Args:
        pts: Array of vertices reshapeable to ``(N, 2)``, for example an
            OpenCV contour of shape ``(N, 1, 2)``.
        k: Window length. Values below 3, or larger than the number of
            vertices, disable smoothing.

    Returns:
        ``float32`` array of shape ``(N, 2)``. The vertex count always equals
        the input's, for odd and even ``k`` alike.
    """
    # Normalise the layout up front so every return path yields (N, 2) float32.
    pts = np.asarray(pts).reshape(-1, 2).astype(np.float32)
    if k < 3 or len(pts) < k:
        return pts

    # A "valid" convolution drops k - 1 samples, so exactly k - 1 wrapped
    # samples are added. Padding k // 2 on both sides is one too many for an
    # even k and yields N + 1 points.
    left = k // 2
    right = k - 1 - left
    pts_pad = np.vstack([pts[-left:], pts, pts[:right]])

    kernel = np.ones((k,), dtype=np.float32) / k
    xs = np.convolve(pts_pad[:, 0], kernel, mode="valid")
    ys = np.convolve(pts_pad[:, 1], kernel, mode="valid")
    return np.stack([xs, ys], axis=1).astype(np.float32)
def _poly_mask_iou(mask, poly):
    """Intersection-over-union between a filled polygon and a mask.

    Args:
        mask: 2-D (or H x W x C) array; pixels ``> 0`` count as foreground.
        poly: Polygon vertices, reshapeable to ``(-1, 1, 2)``. Coordinates
            are cast to ``int32`` before rasterising.

    Returns:
        float: IoU of the rasterised polygon and the mask foreground. A small
        constant in the denominator keeps the empty-union case finite.
    """
    height, width = mask.shape[:2]

    # Rasterise the polygon onto a blank canvas of the mask's size.
    canvas = np.zeros((height, width), dtype=np.uint8)
    vertices = poly.astype(np.int32).reshape(-1, 1, 2)
    cv2.fillPoly(canvas, [vertices], 1)

    foreground = (mask > 0).astype(np.uint8)
    overlap = np.logical_and(canvas, foreground).sum()
    covered = np.logical_or(canvas, foreground).sum()
    return float(overlap) / float(covered + 1e-6)
def _min_area_rect_to_poly(cnt):
    """Return the minimum-area rotated rectangle around ``cnt`` as a polygon.

    Args:
        cnt: Point set accepted by ``cv2.minAreaRect``.

    Returns:
        ``float32`` array of the rectangle's corner points from
        ``cv2.boxPoints``, reshaped to ``(-1, 1, 2)``.
    """
    corners = cv2.boxPoints(cv2.minAreaRect(cnt))
    return corners.astype(np.float32).reshape(-1, 1, 2)
def mask_to_polygon_no_holes(mask, epsilon_factor=0.01, min_area=100):
    """Approximate the largest external contour of ``mask`` with a polygon.

    Only external contours are retrieved (``cv2.RETR_EXTERNAL``), so interior
    holes never contribute to the result.

    Args:
        mask: Single-channel image handed directly to ``cv2.findContours``.
        epsilon_factor: Approximation tolerance as a fraction of the
            contour's closed perimeter.
        min_area: Contours whose area is below this value are rejected.

    Returns:
        The ``cv2.approxPolyDP`` result for the largest contour, or ``None``
        when there is no contour or the largest one is smaller than
        ``min_area``.
    """
    outlines, _hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    biggest = max(outlines, key=cv2.contourArea) if outlines else None
    if biggest is None or cv2.contourArea(biggest) < min_area:
        return None

    # Tolerance scales with the perimeter so large and small shapes are
    # simplified proportionally.
    tolerance = epsilon_factor * cv2.arcLength(biggest, True)
    return cv2.approxPolyDP(biggest, tolerance, True)
def clean_polygon_strict(mask, epsilon_factor=0.005,
                         smooth_k=7, rect_iou_thresh=0.88, min_area=100):
    """Turn a binary mask into one clean polygon.

    Two candidates are built: the minimum-area rotated rectangle around the
    (optionally smoothed) largest contour, and a Douglas-Peucker
    approximation of the largest contour. The rectangle wins when it matches
    the mask well enough and is not clearly worse than the approximation.

    Args:
        mask: Mask image. Boolean masks treat ``True`` as foreground; other
            dtypes are cast to ``uint8`` and thresholded at 127.
        epsilon_factor: Tolerance for the polygon approximation, as a
            fraction of the contour perimeter.
        smooth_k: Moving-average window for the contour that feeds the
            rectangle fit. Falsy or below 3 disables smoothing.
        rect_iou_thresh: Minimum IoU between rectangle and mask for the
            rectangle to be eligible.
        min_area: Minimum contour area for the approximated polygon.

    Returns:
        ``int32`` array of polygon vertices, or ``None`` when the mask has no
        contour or its largest contour has fewer than 4 points.
    """
    if mask.dtype == np.bool_:
        # A boolean mask cast to uint8 holds only 0/1, which the >127
        # threshold would wipe out entirely; use True as foreground directly.
        bw = mask.astype(np.uint8)
    else:
        if mask.dtype != np.uint8:
            mask = mask.astype(np.uint8)
        bw = (mask > 127).astype(np.uint8)

    cnt = _largest_contour(bw, use_chain_approx_none=True)
    if cnt is None or len(cnt) < 4:
        return None

    # The smoothed contour feeds only the rectangle fit; the approximated
    # polygon below is derived from the unsmoothed mask.
    if smooth_k and smooth_k >= 3:
        cnt_for_rect = _smooth_closed_poly(cnt, k=smooth_k).reshape(-1, 1, 2)
    else:
        cnt_for_rect = cnt.astype(np.float32)
    rect_poly = _min_area_rect_to_poly(cnt_for_rect)

    # Honour the caller's epsilon_factor (it was previously hard-coded to
    # 0.005 here, silently ignoring the parameter).
    approx = mask_to_polygon_no_holes(
        bw, epsilon_factor=epsilon_factor, min_area=min_area
    )

    iou_rect = _poly_mask_iou(bw, rect_poly)
    iou_approx = _poly_mask_iou(bw, approx) if approx is not None else 0

    # Prefer the rectangle when it clears the threshold and is at most 0.01
    # IoU worse than the approximation.
    if iou_rect >= rect_iou_thresh and (iou_rect + 0.01) >= iou_approx:
        best = rect_poly
    else:
        # NOTE(review): when approx is None because of min_area, the rectangle
        # is still returned, so min_area does not reject tiny masks -- confirm
        # this is intended.
        best = approx if approx is not None else rect_poly
    return best.astype(np.int32)
| # ------------------- | |
| # Core pipeline | |
| # ------------------- | |
def refine_predictions(im, boxes, masks, sam2_predictor):
    """Refine Mask R-CNN instance masks with SAM2 and polygonize them.

    For each detection, SAM2 is prompted with the detection box. Its
    best-scoring mask is cleaned, intersected with a slightly dilated
    Mask R-CNN mask, and converted to a polygon.

    Args:
        im: Colour image in OpenCV's BGR channel order (the visible caller
            obtains it from ``cv2.imdecode``).
        boxes: One box per instance, indexable by row; each row is passed to
            SAM2 as a box prompt.
        masks: One mask per instance, aligned with ``boxes``. Assumed to be
            boolean or 0/1 -- TODO confirm against callers.
        sam2_predictor: Object exposing ``set_image`` and ``predict`` as used
            below.

    Returns:
        list[tuple]: One ``(combined_mask, polygons)`` pair per instance.
        ``combined_mask`` is a 0/255 ``uint8`` mask and ``polygons`` is a
        list holding zero or one polygon array.
    """
    # SAM2's image predictor documents RGB input for numpy arrays, whereas
    # OpenCV-decoded images are BGR; convert before embedding.
    sam2_predictor.set_image(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))

    # Structuring elements are loop-invariant; build them once.
    close_kernel = np.ones((3, 3), np.uint8)
    dilate_kernel = np.ones((5, 5), np.uint8)

    refined_masks = []
    for box, mask in zip(boxes, masks):
        mask_rcnn = mask.astype(np.uint8) * 255

        sam_masks, sam_scores, _ = sam2_predictor.predict(
            box=box[None, :],
            multimask_output=True,
        )
        best_index = np.argmax(sam_scores)
        sam_mask = sam_masks[best_index].astype(np.uint8) * 255

        # Close small gaps, then blur + re-threshold to soften jagged edges.
        sam_mask_clean = cv2.morphologyEx(sam_mask, cv2.MORPH_CLOSE, close_kernel)
        sam_mask_clean = cv2.GaussianBlur(sam_mask_clean, (3, 3), 0)
        _, sam_mask_clean = cv2.threshold(sam_mask_clean, 127, 255, cv2.THRESH_BINARY)

        # Keep SAM2 pixels only where the (slightly grown) Mask R-CNN mask
        # agrees.
        mask_rcnn_dilated = cv2.dilate(mask_rcnn, dilate_kernel, iterations=1)
        combined = cv2.bitwise_and(mask_rcnn_dilated, sam_mask_clean)

        final_polygon = clean_polygon_strict(
            combined,
            epsilon_factor=0.005,
            smooth_k=7,
            rect_iou_thresh=0.88,
        )
        final_polygons = [final_polygon] if final_polygon is not None else []
        refined_masks.append((combined, final_polygons))
    return refined_masks
| # ------------------- | |
| # API Endpoint | |
| # ------------------- | |
# NOTE(review): no route decorator was present in the source, which left this
# handler unreachable. The "/polygon" path is an assumption -- confirm it
# against existing clients.
@app.post("/polygon")
async def get_polygon(file: UploadFile):
    """Accepts an image upload and returns polygonized result.

    The upload is decoded, run through Mask R-CNN, refined with SAM2, and the
    resulting polygons are drawn in green on a copy of the input.

    Args:
        file: Uploaded image file.

    Returns:
        A PNG ``StreamingResponse`` on success. A ``JSONResponse`` with
        status 400 for an empty or undecodable upload, or status 500 when
        processing or encoding fails.
    """
    try:
        contents = await file.read()
        # cv2.imdecode raises on an empty buffer, which would otherwise be
        # reported as a 500; reject it as a client error instead.
        if not contents:
            return JSONResponse(content={"error": "Invalid image"}, status_code=400)

        nparr = np.frombuffer(contents, np.uint8)
        im = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if im is None:
            return JSONResponse(content={"error": "Invalid image"}, status_code=400)

        # NOTE(review): inference below is synchronous inside an async
        # handler, so it occupies the event loop for the whole request.
        outputs = predictor(im)
        instances = outputs["instances"].to("cpu")

        final_polygons_all = []
        if instances.has("pred_boxes") and instances.has("pred_masks"):
            boxes = instances.pred_boxes.tensor.numpy()
            masks = instances.pred_masks.numpy()
            # Skip the SAM2 image embedding entirely when nothing was
            # detected.
            if len(boxes) > 0:
                refined_masks = refine_predictions(im, boxes, masks, sam2_predictor)
                for _combined, polygons in refined_masks:
                    final_polygons_all.extend(polygons)

        result = im.copy()
        if final_polygons_all:
            cv2.polylines(result, final_polygons_all, isClosed=True,
                          color=(0, 255, 0), thickness=2)

        encoded_ok, encoded_img = cv2.imencode(".png", result)
        if not encoded_ok:
            return JSONResponse(content={"error": "Failed to encode result image"},
                                status_code=500)
        return StreamingResponse(io.BytesIO(encoded_img.tobytes()), media_type="image/png")
    except Exception as e:
        # Top-level boundary: report any processing failure as a JSON 500.
        return JSONResponse(content={"error": str(e)}, status_code=500)