Spaces:
Sleeping
Sleeping
| import io | |
| import logging | |
| import os | |
| import threading | |
| import uuid | |
| from typing import Optional | |
| import requests | |
| import torch | |
| from flask import Flask, jsonify, request | |
| from PIL import Image | |
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("sam3-ls-backend")

# --- Runtime configuration; every value can be overridden by an env variable ---
MODEL_ID = os.getenv("SAM3_MODEL_ID", "facebook/sam3")  # passed to from_pretrained()
MODEL_VERSION = os.getenv("MODEL_VERSION", "sam3-real-v1")  # reported in responses
DEFAULT_LABEL = os.getenv("DEFAULT_LABEL", "butterfly")  # text prompt and region label
# Score cut-off for detections and binarisation cut-off for masks, both
# forwarded to post_process_instance_segmentation().
CONFIDENCE_THRESHOLD, MASK_THRESHOLD = (
    float(os.getenv(env_name, "0.5"))
    for env_name in ("CONFIDENCE_THRESHOLD", "MASK_THRESHOLD")
)

app = Flask(__name__)

# Lazily-populated model state; get_model() fills these in under _load_lock.
_model = _processor = None
_load_lock = threading.Lock()
# str() of the exception from the last failed model load, or None.
_load_error: Optional[str] = None
def get_model():
    """Return the ``(model, processor)`` pair, loading SAM3 on first use.

    Uses double-checked locking: once loaded, the cached pair is returned
    without taking ``_load_lock``. The model is moved to CUDA when available
    (otherwise CPU), loaded in bfloat16 and switched to eval mode *before* it
    is published to the module globals, so the lock-free fast path can never
    hand out a half-initialised model.

    On failure, ``_load_error`` records the message, the exception is logged
    and re-raised, and the next call retries the load. A successful load
    clears any previously recorded ``_load_error``.

    Raises:
        ImportError: the installed ``transformers`` has no Sam3 classes.
        Exception: anything raised while downloading or building the model.
    """
    global _model, _processor, _load_error
    # Fast path: no locking once the model has been published.
    if _model is not None:
        return _model, _processor
    with _load_lock:
        # Another thread may have finished loading while we waited.
        if _model is not None:
            return _model, _processor
        try:
            from transformers import Sam3Model, Sam3Processor

            device = "cuda" if torch.cuda.is_available() else "cpu"
            log.info("Loading SAM3 (%s) on %s...", MODEL_ID, device)
            processor = Sam3Processor.from_pretrained(MODEL_ID)
            model = Sam3Model.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(device)
            model.eval()
        except Exception as e:
            _load_error = str(e)
            log.exception("Model load failed")
            raise
        # Publish the processor before the model: the fast path keys on _model,
        # so any reader that sees _model set also sees a ready _processor.
        _processor = processor
        _model = model
        _load_error = None
        log.info("SAM3 ready.")
        return _model, _processor
def fetch_image(url: str) -> Image.Image:
    """Download an image and return it in RGB mode.

    Host-relative URLs (no scheme and no host, e.g. Label Studio's
    ``/data/upload/...``) cannot be fetched by ``requests`` on their own.
    When the ``LABEL_STUDIO_URL`` environment variable is set, such URLs are
    resolved against it. When ``LABEL_STUDIO_API_KEY`` is also set, it is
    sent as a ``Token`` Authorization header, but only to the Label Studio
    host, so the key is never leaked to third-party image hosts. With neither
    variable set, this is a plain GET exactly as before.

    Raises:
        requests.RequestException: network error, the 30 s timeout, or (via
            ``raise_for_status``) a non-2xx HTTP response.
        PIL.UnidentifiedImageError: the response body is not a recognised image.
    """
    from urllib.parse import urlparse

    headers = {"User-Agent": "sam3-ls-backend/1.0"}
    ls_base = (os.environ.get("LABEL_STUDIO_URL") or "").rstrip("/")
    parsed = urlparse(url)
    if ls_base and not parsed.scheme and not parsed.netloc:
        url = ls_base + "/" + url.lstrip("/")
    api_key = os.environ.get("LABEL_STUDIO_API_KEY")
    if ls_base and api_key and urlparse(url).netloc == urlparse(ls_base).netloc:
        headers["Authorization"] = f"Token {api_key}"

    resp = requests.get(url, timeout=30, headers=headers)
    resp.raise_for_status()
    img = Image.open(io.BytesIO(resp.content))
    # Normalise palette/greyscale/RGBA/etc. to the RGB input the processor gets.
    if img.mode != "RGB":
        img = img.convert("RGB")
    return img
def run_inference(image: Image.Image, label: str):
    """Run text-prompted SAM3 instance segmentation on a single image.

    The image is batched alone with ``label`` as its text prompt. Inputs are
    moved to the model's device and dtype, and the forward pass runs without
    autograd. The outputs are then filtered by ``CONFIDENCE_THRESHOLD`` /
    ``MASK_THRESHOLD`` and rescaled to the original image size.

    Returns:
        The single per-image entry from
        ``processor.post_process_instance_segmentation``. ``to_ls_prediction``
        reads its ``"boxes"`` and ``"scores"`` keys.
    """
    model, processor = get_model()
    # Device/dtype of the first parameter stands in for the whole model.
    reference_param = next(model.parameters())

    batch = processor(images=[image], text=[label], return_tensors="pt")
    # NOTE(review): presumably BatchFeature.to() casts only floating-point
    # tensors to `dtype`, leaving integer ids/sizes intact — confirm.
    batch = batch.to(reference_param.device, dtype=reference_param.dtype)

    with torch.no_grad():
        raw_outputs = model(**batch)

    per_image = processor.post_process_instance_segmentation(
        raw_outputs,
        threshold=CONFIDENCE_THRESHOLD,
        mask_threshold=MASK_THRESHOLD,
        target_sizes=batch.get("original_sizes").tolist(),
    )
    first, *_ = per_image
    return first
def to_ls_prediction(
    image: Image.Image,
    result,
    label: str,
    *,
    from_name: str = "label",
    to_name: str = "image",
) -> dict:
    """Convert a post-processed SAM3 result into a Label Studio prediction.

    Args:
        image: The image inference ran on; only its pixel size is used.
        result: Mapping with tensor-like ``"boxes"`` (one ``[x1, y1, x2, y2]``
            row per detection, in pixels of *image*) and ``"scores"`` (one
            confidence per box).
        label: Label name written into every region.
        from_name: Name of the RectangleLabels control in the labeling config.
        to_name: Name of the image object tag that the control targets.

    Returns:
        ``{"model_version", "score", "result"}``. ``result`` holds one
        ``rectanglelabels`` region per box that has non-zero area inside the
        image, with coordinates as percentages of the image size. ``score`` is
        the highest region score, or 0.0 when there are no regions.
    """
    W, H = image.size
    boxes = result.get("boxes")
    scores = result.get("scores")
    if boxes is None or len(boxes) == 0:
        return {"model_version": MODEL_VERSION, "score": 0.0, "result": []}

    def _clamp(value, upper):
        """Clip a pixel coordinate into [0, upper]."""
        return min(max(float(value), 0.0), float(upper))

    items = []
    for (x1, y1, x2, y2), score in zip(boxes.tolist(), scores.tolist()):
        # NOTE(review): whether post-processing already clips boxes to the image
        # is not visible here; clamping keeps regions within 0..100% either way.
        x1, x2 = _clamp(x1, W), _clamp(x2, W)
        y1, y2 = _clamp(y1, H), _clamp(y2, H)
        if x2 <= x1 or y2 <= y1:
            continue  # Nothing left inside the image; a zero-area region is useless.
        items.append({
            "id": uuid.uuid4().hex[:8],
            "from_name": from_name,
            "to_name": to_name,
            "type": "rectanglelabels",
            "original_width": W,
            "original_height": H,
            "image_rotation": 0,
            "value": {
                "x": x1 / W * 100.0,
                "y": y1 / H * 100.0,
                "width": (x2 - x1) / W * 100.0,
                "height": (y2 - y1) / H * 100.0,
                "rotation": 0,
                "rectanglelabels": [label],
            },
            "score": float(score),
        })
    overall = max((it["score"] for it in items), default=0.0)
    return {"model_version": MODEL_VERSION, "score": float(overall), "result": items}
@app.route("/health", methods=["GET"])
def health():
    """Report liveness plus the current model-load state.

    Always answers ``"status": "UP"`` and never triggers a model load.
    Callers that need to know whether inference will work should inspect
    ``model_loaded`` and ``load_error``.
    """
    return jsonify({
        "status": "UP",
        "model_version": MODEL_VERSION,
        "model_loaded": _model is not None,
        "load_error": _load_error,
        "cuda_available": torch.cuda.is_available(),
    })
@app.route("/setup", methods=["POST"])
def setup():
    """Acknowledge a project setup call and report the model version.

    The JSON body is only used to log its ``project`` field. A missing,
    malformed or non-object body is tolerated rather than causing a 500.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    log.info("setup: project=%s", payload.get("project"))
    return jsonify({"model_version": MODEL_VERSION})
@app.route("/predict", methods=["POST"])
def predict():
    """Return one bounding-box prediction per submitted task.

    Reads ``tasks`` from the JSON body; each task's ``data.image`` is taken
    as the image URL, and ``id`` is used only for logging. Every task yields
    exactly one entry in ``results``, in input order:
    - tasks without an image URL (or malformed tasks) get an empty prediction;
    - tasks whose download or inference fails get an empty prediction plus an
      ``"error"`` message.
    A missing, null or non-list ``tasks`` is treated as an empty batch.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    tasks = payload.get("tasks")
    if not isinstance(tasks, list):
        tasks = []
    log.info("predict: %d task(s)", len(tasks))
    out = []
    for t in tasks:
        task = t if isinstance(t, dict) else {}
        data = task.get("data")
        url = data.get("image") if isinstance(data, dict) else None
        if not url:
            out.append({"model_version": MODEL_VERSION, "score": 0.0, "result": []})
            continue
        try:
            img = fetch_image(url)
            r = run_inference(img, DEFAULT_LABEL)
            out.append(to_ls_prediction(img, r, DEFAULT_LABEL))
        except Exception as e:
            # Per-task boundary: one bad task must not fail the whole batch.
            log.exception("predict failed for task %s", task.get("id"))
            out.append({"model_version": MODEL_VERSION, "score": 0.0, "result": [], "error": str(e)})
    return jsonify({"results": out})
@app.route("/webhook", methods=["POST"])
def webhook():
    """Log an incoming webhook event and acknowledge it.

    Only the ``action`` field of the JSON body is read (for logging). A
    missing, malformed or non-object body is tolerated.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    log.info("webhook event: %s", payload.get("action"))
    return jsonify({"status": "ok"})
@app.route("/", methods=["GET"])
def root():
    """Describe the service: model identity, load state and available endpoints."""
    return jsonify({
        "service": "sam3-ls-backend",
        "model_id": MODEL_ID,
        "model_version": MODEL_VERSION,
        "model_loaded": _model is not None,
        "endpoints": ["/health", "/setup", "/predict", "/webhook"],
    })
if __name__ == "__main__":
    # Listen on all interfaces. PORT overrides the listening port; 7860 stays
    # the default so existing deployments are unaffected.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))