# Noursine's picture
# Update app4.py
# 4cdf8a4 verified
import base64
import io
import cv2
import numpy as np
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import torch
import os
import uvicorn
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2 import model_zoo
from detectron2.data import MetadataCatalog
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from hydra import initialize
from hydra.core.global_hydra import GlobalHydra
# -------------------
# Detectron2 setup
# -------------------
def _build_detectron2_cfg():
    """Return the Mask R-CNN config used by this service.

    Starts from the model-zoo COCO ``mask_rcnn_R_50_FPN_3x`` config and
    overrides the weights file, test score threshold, device and class count.
    """
    cfg = get_cfg()
    zoo_yaml = model_zoo.get_config_file(
        "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"
    )
    cfg.merge_from_file(zoo_yaml)
    model = cfg.MODEL
    model.WEIGHTS = "/app/model_final2.pth"  # your trained weights
    model.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # drop detections scored below 0.5
    model.DEVICE = "cpu"
    model.ROI_HEADS.NUM_CLASSES = 1  # single foreground class ("toproof")
    return cfg


det_cfg = _build_detectron2_cfg()
# NOTE(review): class names are registered under a placeholder dataset name;
# DefaultPredictor itself does not read them, so this looks informational only.
MetadataCatalog.get("__unused__").thing_classes = ["toproof"]
# DefaultPredictor expects images in cfg.INPUT.FORMAT channel order
# (BGR unless the config overrides it).
predictor = DefaultPredictor(det_cfg)
# -------------------
# SAM2 setup
# -------------------
# Build the SAM2.1 Hiera-Large model on CPU and wrap it in an image predictor.
# NOTE(review): os.chdir changes the working directory of the whole process;
# it is what makes the relative checkpoint path "sam2.1_hiera_large.pt" below
# resolve under /app.
os.chdir("/app")
# Drop any Hydra state that is already initialized (presumably by importing
# sam2) so the `initialize` call below can set its own config search path.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
# NOTE(review): Hydra resolves `config_path` relative to the directory of the
# calling file, not the cwd -- confirm "sam2.1_hiera_l.yaml" sits next to this file.
with initialize(version_base=None, config_path="."):
    sam2_model = build_sam2("sam2.1_hiera_l.yaml", "sam2.1_hiera_large.pt", device="cpu")
# Shared, stateful predictor: set_image() must precede predict() for each image.
sam2_predictor = SAM2ImagePredictor(sam2_model)
# -------------------
# FastAPI app
# -------------------
app = FastAPI()

# Permissive CORS: every origin, method and header is allowed.
# NOTE(review): the CORS spec does not allow a wildcard origin together with
# credentials; check how Starlette treats this combination and whether
# credentials are needed at all (this API has no auth).
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
@app.get("/")
def home():
    """Liveness probe: report that the service is up."""
    payload = {"status": "running"}
    return payload
# -------------------
# Helpers
# -------------------
def _largest_contour(mask):
    """Return the largest-area external contour of ``mask``, or None.

    ``mask`` is passed straight to ``cv2.findContours`` (external contours
    only, simple chain approximation). None is returned when no contour exists.
    """
    found, _hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return max(found, key=cv2.contourArea, default=None)
def _min_area_rect_to_poly(cnt):
    """Return the rotated minimum-area bounding rectangle of ``cnt`` as a polygon.

    The four corners from ``cv2.boxPoints`` are returned as a float32 array of
    shape (4, 1, 2), the point layout OpenCV contour functions accept.
    """
    corners = cv2.boxPoints(cv2.minAreaRect(cnt))
    as_float = corners.astype(np.float32)
    return as_float.reshape(-1, 1, 2)
def mask_to_polygon_no_holes(mask, epsilon_factor=0.005, min_area=150):
    """Simplify the outline of the largest blob in ``mask`` into a polygon.

    Only external contours are traced, so interior holes are ignored.

    Args:
        mask: 2-D mask. A non-``uint8`` mask whose maximum is <= 1 is scaled
            by 255 before the ``uint8`` cast; any other non-``uint8`` mask is
            cast directly. Every non-zero pixel then counts as foreground.
        epsilon_factor: ``cv2.approxPolyDP`` tolerance as a fraction of the
            contour perimeter.
        min_area: Minimum contour area in pixels; smaller blobs yield None.

    Returns:
        The ``cv2.approxPolyDP`` output for the largest contour, or None when
        there is no contour or it is smaller than ``min_area``.
    """
    if mask.dtype != np.uint8:
        scaled = mask * 255 if mask.max() <= 1 else mask
        mask = scaled.astype(np.uint8)
    binary = np.where(mask > 0, 255, 0).astype(np.uint8)

    outlines, _hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    biggest = max(outlines, key=cv2.contourArea, default=None)
    if biggest is None or cv2.contourArea(biggest) < min_area:
        return None

    tolerance = epsilon_factor * cv2.arcLength(biggest, True)
    return cv2.approxPolyDP(biggest, tolerance, True)
def clean_polygon_strict(mask, epsilon_factor=0.01, min_area=150):
    """Polygonize the largest blob of ``mask``, snapping near-rectangles to a box.

    The mask is normalized to a 0/255 ``uint8`` image. Masks whose maximum is
    <= 1 (bool, float probabilities, or ``uint8`` 0/1) are scaled by 255, and
    other masks are cast to ``uint8``. Pixels > 127 are foreground. The largest
    external contour is then simplified with ``mask_to_polygon_no_holes``.

    Args:
        mask: 2-D array-like mask.
        epsilon_factor: ``cv2.approxPolyDP`` tolerance as a fraction of the
            contour perimeter.
        min_area: Contours with an area (pixels) below this are rejected.

    Returns:
        ``(polygon, label)``:

        * ``(None, "No contour")`` -- empty array or no foreground pixels.
        * ``(None, "No polygon")`` -- largest contour is smaller than ``min_area``.
        * ``(rect, "Rectangle")`` -- the simplified polygon has exactly 4
          vertices and the minimum-area rectangle's area is within 5 % of the
          contour area; ``rect`` is that rectangle as a float32 (4, 1, 2) array.
        * ``(poly, "Polygon (Irregular)")`` -- otherwise; ``poly`` is the
          ``approxPolyDP`` output.
    """
    mask = np.asarray(mask)
    if mask.size == 0:
        return None, "No contour"

    # Normalize to 0..255 uint8. uint8 0/1 masks are scaled too, so they are not
    # wiped out by the >127 threshold below.
    if mask.max() <= 1:
        mask = (mask * 255).astype(np.uint8)
    elif mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)
    bw = (mask > 127).astype(np.uint8) * 255

    cnt = _largest_contour(bw)
    if cnt is None:
        return None, "No contour"

    poly = mask_to_polygon_no_holes(bw, epsilon_factor=epsilon_factor, min_area=min_area)
    if poly is None:
        # Only possible when the largest contour is below min_area: treat it as
        # noise instead of silently returning its bounding rectangle.
        return None, "No polygon"

    # More than 4 vertices after simplification: irregular outline.
    if len(poly) > 4:
        return poly, "Polygon (Irregular)"

    # Strict rectangle test: a quadrilateral that (almost) fills its rotated
    # minimum-area bounding box.
    rect_poly = _min_area_rect_to_poly(cnt)
    contour_area = cv2.contourArea(cnt)
    area_ratio = cv2.contourArea(rect_poly) / contour_area if contour_area > 0 else 0
    if len(poly) == 4 and 0.95 < area_ratio < 1.05:
        return rect_poly, "Rectangle"
    return poly, "Polygon (Irregular)"
# -------------------
# API Endpoint
# -------------------
def _refine_detection(rcnn_mask, box):
    """Refine one Mask R-CNN detection with SAM2 and polygonize the result.

    Prompts the shared ``sam2_predictor``, so ``sam2_predictor.set_image`` must
    already have been called for the image this detection came from.

    Args:
        rcnn_mask: The detection's Mask R-CNN instance mask (a ``pred_masks`` row).
        box: The detection's box (a ``pred_boxes`` row, (x1, y1, x2, y2) pixels),
            used as the SAM2 box prompt.

    Returns:
        ``(final_mask, polygon, label)``. ``polygon`` and ``label`` come from
        ``clean_polygon_strict``, and ``polygon`` may be None.
    """
    mask_rcnn = rcnn_mask.astype(np.uint8) * 255

    # SAM2: prompt with the box and keep the highest-scoring candidate mask.
    sam_masks, sam_scores, _ = sam2_predictor.predict(box=box[None, :], multimask_output=True)
    sam_mask = sam_masks[np.argmax(sam_scores)].astype(np.uint8) * 255

    # Close small gaps, then smooth the edges and re-binarize.
    sam_clean = cv2.morphologyEx(sam_mask, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
    sam_clean = cv2.GaussianBlur(sam_clean, (3, 3), 0)
    _, sam_clean = cv2.threshold(sam_clean, 127, 255, cv2.THRESH_BINARY)

    # Rectangular shape: use the SAM2 mask alone. Irregular shape: keep only
    # pixels inside both SAM2 and a slightly dilated Mask R-CNN mask.
    _, shape = clean_polygon_strict(sam_clean)
    if shape == "Rectangle":
        final_mask = sam_clean
    else:
        rcnn_dilated = cv2.dilate(mask_rcnn, np.ones((5, 5), np.uint8), iterations=1)
        final_mask = cv2.bitwise_and(rcnn_dilated, sam_clean)

    polygon, label = clean_polygon_strict(final_mask)
    return final_mask, polygon, label


def _encode_overlay_png(image_rgb, polygon):
    """Draw ``polygon`` in red on ``image_rgb`` and return it as base64 PNG text.

    OpenCV draws and encodes in BGR order, so the RGB image is converted first.
    Without the conversion, the PNG would come out with red and blue swapped.
    """
    overlay = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)  # new array; input untouched
    cv2.polylines(overlay, [polygon.astype(np.int32)], True, (0, 0, 255), 2)
    ok, buffer = cv2.imencode(".png", overlay)
    if not ok:
        raise RuntimeError("PNG encoding of the overlay failed")
    return base64.b64encode(buffer).decode("utf-8")


@app.post("/polygon")
async def polygon_endpoint(file: UploadFile = File(...)):
    """Detect a "toproof" object in the uploaded image and return its outline.

    Pipeline: Mask R-CNN detection -> SAM2 box-prompted refinement -> mask
    fusion -> polygon simplification. Detections are tried in the order
    Detectron2 returns them, and the first one that yields a polygon is used.

    The JSON response has these keys:

    * ``chosen``: shape label or status message.
    * ``polygon``: list of [x, y] vertices, or None.
    * ``image``: base64 PNG preview with the polygon drawn, or None.

    Undecodable uploads get HTTP 400.
    """
    # NOTE(review): inference runs synchronously in this async handler and
    # blocks the event loop. That also keeps requests from interleaving on the
    # shared, stateful sam2_predictor, so moving it to a thread pool would
    # require a lock.
    contents = await file.read()
    try:
        im = np.array(Image.open(io.BytesIO(contents)).convert("RGB"))
    except (OSError, ValueError, Image.DecompressionBombError):
        # Not a decodable image (or too large to decode safely): client error.
        return JSONResponse(
            status_code=400,
            content={"chosen": "Invalid image", "polygon": None, "image": None},
        )

    # --- Step 1: Mask R-CNN ---
    # DefaultPredictor expects cfg.INPUT.FORMAT order (BGR by default); PIL gave RGB.
    outputs = predictor(cv2.cvtColor(im, cv2.COLOR_RGB2BGR))
    instances = outputs["instances"].to("cpu")
    boxes = instances.pred_boxes.tensor.numpy()
    masks = instances.pred_masks.numpy()
    if len(masks) == 0:
        return JSONResponse(content={"chosen": "No mask found", "polygon": None, "image": None})

    # --- Steps 2-4: SAM2 refinement, fusion, polygonization ---
    sam2_predictor.set_image(im)  # SAM2 takes the RGB image
    for box, rcnn_mask in zip(boxes, masks):
        _, final_poly, chosen = _refine_detection(rcnn_mask, box)
        if final_poly is not None:
            break
    else:
        return JSONResponse(content={"chosen": "No polygon", "polygon": None, "image": None})

    # --- Step 5: Preview overlay ---
    return {
        "chosen": chosen,
        "polygon": final_poly.reshape(-1, 2).tolist(),
        "image": _encode_overlay_png(im, final_poly),
    }