Spaces:
Sleeping
Sleeping
"""
Grounded-SAM Flask API (CPU only)

POST /segment
    Body (multipart/form-data):
      - image: the house photo
      - prompt: text prompt, e.g. "roof sheet"
    Query params:
      - overlay (bool, default=false): if true, return the photo with the
        segmented pixels painted red instead of the bare mask
    Returns:
      - image/png mask (single channel) OR overlay
      - JSON error with status 400 if the image or prompt is missing, or
        422 if no box is found for the prompt
"""
import argparse
import io
import os
import threading

import numpy as np
from PIL import Image
from flask import Flask, request, send_file
from flask_cors import CORS
import torch

import groundingdino
from groundingdino.util.inference import Model as GroundingModel
from segment_anything import sam_model_registry, SamPredictor
# --- Load models once ----------------------------------------------------------
# Both models are built at import time so every request reuses the same weights.
device = torch.device("cpu")

# Weight/config locations. Each can be overridden with an environment variable
# so the files can live outside the working directory; defaults are unchanged.
DINO_CKPT = os.environ.get("DINO_CKPT", "weights/groundingdino_swint_ogc.pth")
# GroundingDINO's Model needs the architecture config file in addition to the
# checkpoint. Default to the SwinT-OGC config shipped inside the groundingdino
# package, which is the architecture of the swint_ogc checkpoint above.
DINO_CONFIG = os.environ.get(
    "DINO_CONFIG",
    os.path.join(
        os.path.dirname(groundingdino.__file__),
        "config",
        "GroundingDINO_SwinT_OGC.py",
    ),
)
SAM_CKPT = os.environ.get("SAM_CKPT", "weights/sam_vit_h_4b8939.pth")

# Model's signature is (model_config_path, model_checkpoint_path, device).
# Passing only the checkpoint positionally would bind it to the *config* slot
# and leave the checkpoint argument missing, so pass both by keyword.
grounder = GroundingModel(
    model_config_path=DINO_CONFIG,
    model_checkpoint_path=DINO_CKPT,
    device=str(device),  # GroundingDINO documents a string device ("cpu")
)
sam = sam_model_registry["vit_h"](checkpoint=SAM_CKPT).to(device)
predictor = SamPredictor(sam)

# --- Flask app -----------------------------------------------------------------
app = Flask(__name__)
CORS(app)  # allow browser clients served from other origins to call the API
# SamPredictor is stateful: set_image() stores the image embedding on the shared
# predictor and predict() reads it back. Flask serves requests on multiple
# threads, so without this lock two requests could interleave and one would be
# segmented against the other's image.
_SAM_LOCK = threading.Lock()


def segment(
    image_pil: Image.Image,
    prompt: str,
    box_threshold: float = 0.3,
    text_threshold: float = 0.25,
) -> np.ndarray:
    """Segment the object described by *prompt* in *image_pil*.

    GroundingDINO turns the text prompt into candidate boxes. The box with the
    largest area is then given to SAM as a box prompt.

    Args:
        image_pil: Input photo (converted to RGB if it is not already).
        prompt: Free-text description of the target, e.g. "roof sheet".
        box_threshold: GroundingDINO box-confidence threshold.
        text_threshold: GroundingDINO text-token threshold.

    Returns:
        Boolean HxW mask of the selected object.

    Raises:
        ValueError: If GroundingDINO finds no box for the prompt.
    """
    rgb = np.array(image_pil.convert("RGB"))

    # 1) GroundingDINO: Model.predict_with_caption expects an OpenCV-style BGR
    #    uint8 array and returns (Detections, phrases); Detections.xyxy holds
    #    absolute-pixel boxes. A contiguous copy keeps cv2 happy with the
    #    reversed channel view.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    detections, _ = grounder.predict_with_caption(
        image=bgr,
        caption=prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    boxes = np.asarray(detections.xyxy)
    if len(boxes) == 0:
        raise ValueError("No boxes found for prompt.")

    # 2) Largest box -> mask via SAM.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    box = boxes[int(np.argmax(areas))]
    with _SAM_LOCK:
        predictor.set_image(rgb)
        # A box is an unambiguous prompt, so request a single mask instead of
        # taking the first of three multimask candidates, which is not
        # necessarily the best one.
        masks, _, _ = predictor.predict(box=box, multimask_output=False)
    return masks[0]  # boolean HxW
@app.route("/segment", methods=["POST"])
def segment_endpoint():
    """Handle ``POST /segment``: segment the uploaded photo by text prompt.

    Expects multipart form fields ``image`` (file) and ``prompt`` (text). The
    optional ``overlay`` query parameter (``true``/``1``/``yes``/``on``)
    returns the photo with the masked pixels painted red instead of the mask.

    Returns:
        A PNG response (single-channel mask, or RGB overlay), or a JSON error
        with status 400 (missing/invalid input) or 422 (no box for prompt).
    """
    if "image" not in request.files or "prompt" not in request.form:
        return {"error": "image file and prompt are required."}, 400
    prompt = request.form["prompt"].strip()
    if not prompt:
        return {"error": "image file and prompt are required."}, 400

    # Image.open is lazy; convert() forces decoding, so both calls sit in the
    # try. PIL.UnidentifiedImageError (bad or empty upload) is an OSError, as
    # are truncated-file errors, so they become a 400 rather than a 500.
    try:
        image = Image.open(request.files["image"].stream).convert("RGB")
    except OSError:
        return {"error": "could not decode image file."}, 400

    try:
        mask = segment(image, prompt)
    except ValueError as e:
        return {"error": str(e)}, 422

    overlay = request.args.get("overlay", "false").strip().lower() in (
        "1",
        "true",
        "yes",
        "on",
    )
    if overlay:
        colored = np.array(image)  # np.array already returns a fresh copy
        colored[mask] = [255, 0, 0]  # paint masked pixels solid red
        out_img = Image.fromarray(colored)
    else:
        # Boolean mask -> 0/255 single-channel image.
        out_img = Image.fromarray((mask * 255).astype(np.uint8))

    buf = io.BytesIO()
    out_img.save(buf, format="PNG")
    buf.seek(0)
    return send_file(buf, mimetype="image/png")
# --- CLI -----------------------------------------------------------------------
if __name__ == "__main__":
    # Run Flask's built-in server; --host/--port select the bind address.
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", type=int, default=7860)
    # The parsed namespace holds exactly `host` and `port`, which map 1:1 onto
    # the keyword arguments of app.run.
    app.run(**vars(cli.parse_args()))