"""Flask service exposing SAM3 text-prompted detections to Label Studio.

Each image task is run through SAM3 with a single text prompt
(``DEFAULT_LABEL``).  Every detected instance is returned as a
``rectanglelabels`` region expressed in percent of the image size.
"""

import io
import logging
import os
import threading
import uuid
from typing import Optional

import requests
import torch
from flask import Flask, jsonify, request
from PIL import Image

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("sam3-ls-backend")

# --- Configuration (overridable through environment variables) -------------
MODEL_ID = os.environ.get("SAM3_MODEL_ID", "facebook/sam3")
MODEL_VERSION = os.environ.get("MODEL_VERSION", "sam3-real-v1")
# Used both as the SAM3 text prompt and as the rectangle label emitted.
DEFAULT_LABEL = os.environ.get("DEFAULT_LABEL", "butterfly")
# Passed to the processor's post-processing as the instance score threshold.
CONFIDENCE_THRESHOLD = float(os.environ.get("CONFIDENCE_THRESHOLD", "0.5"))
# Passed to the processor's post-processing as the mask binarisation threshold.
MASK_THRESHOLD = float(os.environ.get("MASK_THRESHOLD", "0.5"))
# Control names in the Label Studio labeling config.  The defaults are the
# values that used to be hard-coded.
FROM_NAME = os.environ.get("LS_FROM_NAME", "label")
TO_NAME = os.environ.get("LS_TO_NAME", "image")
PORT = int(os.environ.get("PORT", "7860"))

app = Flask(__name__)

# Lazily-loaded model state, shared by every request-handling thread.
_model = None
_processor = None
_load_lock = threading.Lock()
_load_error: Optional[str] = None  # message of the most recent failed load


def _empty_prediction() -> dict:
    """Return a prediction payload containing no regions."""
    return {"model_version": MODEL_VERSION, "score": 0.0, "result": []}


def _json_body() -> dict:
    """Return the request's JSON body, or ``{}`` if it is absent or not an object."""
    payload = request.get_json(silent=True)
    return payload if isinstance(payload, dict) else {}


def get_model():
    """Return ``(model, processor)``, loading them on first call.

    Loading is serialised by ``_load_lock``.  A failed load records its
    message in ``_load_error`` (reported by ``/health``) and re-raises.
    The next call retries the load.

    Raises:
        Exception: whatever ``transformers`` raises while importing or loading.
    """
    global _model, _processor, _load_error
    # Lock-free fast path.  This is only safe because _model is assigned
    # after the model is fully initialised (moved to device and in eval mode).
    if _model is not None:
        return _model, _processor
    with _load_lock:
        if _model is not None:
            return _model, _processor
        try:
            from transformers import Sam3Model, Sam3Processor

            device = "cuda" if torch.cuda.is_available() else "cpu"
            log.info("Loading SAM3 (%s) on %s...", MODEL_ID, device)
            processor = Sam3Processor.from_pretrained(MODEL_ID)
            model = Sam3Model.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(device)
            model.eval()
        except Exception as e:
            _load_error = str(e)
            log.exception("Model load failed")
            raise
        # Publish the processor first: the fast path only checks _model, so
        # _processor must already be set by the time _model becomes visible.
        _processor = processor
        _model = model
        _load_error = None  # clear a stale error from an earlier failed attempt
        log.info("SAM3 ready.")
        return _model, _processor


def fetch_image(url: str) -> Image.Image:
    """Download ``url`` and return it as an RGB PIL image.

    Raises:
        requests.RequestException: on network errors or a non-2xx status.
        PIL.UnidentifiedImageError: if the body is not a readable image.
    """
    resp = requests.get(url, timeout=30, headers={"User-Agent": "sam3-ls-backend/1.0"})
    resp.raise_for_status()
    img = Image.open(io.BytesIO(resp.content))
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img


def run_inference(image: Image.Image, label: str):
    """Run SAM3 on ``image`` with text prompt ``label``.

    Returns the processor's post-processed result for this single image.
    The caller reads ``"boxes"`` and ``"scores"`` from it.
    """
    model, processor = get_model()
    params = next(model.parameters())
    # Move inputs to the model's device and dtype (the model is loaded in bfloat16).
    inputs = processor(
        images=[image],
        text=[label],
        return_tensors="pt",
    ).to(params.device, dtype=params.dtype)
    with torch.no_grad():
        outputs = model(**inputs)
    results = processor.post_process_instance_segmentation(
        outputs,
        threshold=CONFIDENCE_THRESHOLD,
        mask_threshold=MASK_THRESHOLD,
        target_sizes=inputs.get("original_sizes").tolist(),
    )
    return results[0]


def to_ls_prediction(image: Image.Image, result, label: str) -> dict:
    """Convert one post-processed SAM3 result into a Label Studio prediction.

    Boxes are treated as absolute-pixel ``(x1, y1, x2, y2)``.  They are
    clamped to the image and converted to percentages.  Boxes that have no
    area after clamping are dropped.  The prediction's overall score is the
    best region score, or 0.0 when there are no regions.
    """
    width, height = image.size
    boxes = result.get("boxes")
    if boxes is None or len(boxes) == 0:
        return _empty_prediction()

    scores = result.get("scores")
    # Defensive: treat missing scores as zero rather than crashing on None.
    score_list = scores.tolist() if scores is not None else [0.0] * len(boxes)

    items = []
    for (x1, y1, x2, y2), score in zip(boxes.tolist(), score_list):
        # Predicted boxes may extend past the image border; Label Studio
        # expects coordinates within 0-100 %.
        x1 = min(max(float(x1), 0.0), float(width))
        x2 = min(max(float(x2), 0.0), float(width))
        y1 = min(max(float(y1), 0.0), float(height))
        y2 = min(max(float(y2), 0.0), float(height))
        if x2 <= x1 or y2 <= y1:
            continue
        items.append({
            "id": str(uuid.uuid4())[:8],
            "from_name": FROM_NAME,
            "to_name": TO_NAME,
            "type": "rectanglelabels",
            "original_width": width,
            "original_height": height,
            "image_rotation": 0,
            "value": {
                "x": x1 / width * 100.0,
                "y": y1 / height * 100.0,
                "width": (x2 - x1) / width * 100.0,
                "height": (y2 - y1) / height * 100.0,
                "rotation": 0,
                "rectanglelabels": [label],
            },
            "score": float(score),
        })

    overall = max((it["score"] for it in items), default=0.0)
    return {"model_version": MODEL_VERSION, "score": float(overall), "result": items}


@app.route("/health", methods=["GET"])
def health():
    """Report liveness, model-load state and CUDA availability."""
    return jsonify({
        "status": "UP",
        "model_version": MODEL_VERSION,
        "model_loaded": _model is not None,
        "load_error": _load_error,
        "cuda_available": torch.cuda.is_available(),
    })


@app.route("/setup", methods=["POST"])
def setup():
    """Acknowledge project setup by returning the model version."""
    payload = _json_body()
    log.info("setup: project=%s", payload.get("project"))
    return jsonify({"model_version": MODEL_VERSION})


@app.route("/predict", methods=["POST"])
def predict():
    """Return one prediction per task in the request body's ``tasks`` list.

    Tasks without an image URL, or that are malformed, get an empty
    prediction.  A failure while processing one task is logged and reported
    in that task's ``"error"`` field; it does not fail the whole request.
    """
    payload = _json_body()
    tasks = payload.get("tasks") or []
    if not isinstance(tasks, list):
        tasks = []
    log.info("predict: %d task(s)", len(tasks))
    out = []
    for t in tasks:
        task = t if isinstance(t, dict) else {}
        data = task.get("data")
        url = data.get("image") if isinstance(data, dict) else None
        if not url:
            out.append(_empty_prediction())
            continue
        try:
            img = fetch_image(url)
            r = run_inference(img, DEFAULT_LABEL)
            out.append(to_ls_prediction(img, r, DEFAULT_LABEL))
        except Exception as e:
            log.exception("predict failed for task %s", task.get("id"))
            out.append({"model_version": MODEL_VERSION, "score": 0.0, "result": [], "error": str(e)})
    return jsonify({"results": out})


@app.route("/webhook", methods=["POST"])
def webhook():
    """Log the incoming webhook's ``action`` and acknowledge it."""
    payload = _json_body()
    log.info("webhook event: %s", payload.get("action"))
    return jsonify({"status": "ok"})


@app.route("/", methods=["GET"])
def root():
    """Describe the service and list its endpoints."""
    return jsonify({
        "service": "sam3-ls-backend",
        "model_id": MODEL_ID,
        "model_version": MODEL_VERSION,
        "model_loaded": _model is not None,
        "endpoints": ["/health", "/setup", "/predict", "/webhook"],
    })


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=PORT)