# sam3-ls-backend / app.py
# Uploaded to Hugging Face by davanstrien (HF Staff),
# commit 41acb38 ("Upload folder using huggingface_hub", verified).
import io
import logging
import os
import threading
import uuid
from typing import Optional
import requests
import torch
from flask import Flask, jsonify, request
from PIL import Image
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("sam3-ls-backend")

# Runtime configuration. Every value can be overridden through an environment
# variable of the name shown; the second argument is the fallback default.
MODEL_ID = os.getenv("SAM3_MODEL_ID", "facebook/sam3")
MODEL_VERSION = os.getenv("MODEL_VERSION", "sam3-real-v1")
# Text prompt sent to SAM3 and the label attached to every predicted region.
DEFAULT_LABEL = os.getenv("DEFAULT_LABEL", "butterfly")
# Thresholds forwarded to the processor's instance-segmentation post-processing.
CONFIDENCE_THRESHOLD = float(os.getenv("CONFIDENCE_THRESHOLD", "0.5"))
MASK_THRESHOLD = float(os.getenv("MASK_THRESHOLD", "0.5"))
app = Flask(__name__)

# Lazily-populated model state, filled in by get_model() on first use.
# The lock serialises that one-time load; _load_error keeps the message of a
# failed load so /health can report it.
_load_lock = threading.Lock()
_model = _processor = None
_load_error: Optional[str] = None
def get_model():
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    The unlocked check is a fast path once loading has finished; the check is
    repeated under ``_load_lock`` so concurrent first callers load only once.

    Returns:
        tuple: ``(model, processor)`` -- the SAM3 model (bfloat16, in eval
        mode, on CUDA when available, else CPU) and its processor.

    Raises:
        Exception: anything raised while importing or loading; its message is
        also stored in ``_load_error`` before re-raising.
    """
    global _model, _processor, _load_error
    if _model is not None:
        return _model, _processor
    with _load_lock:
        if _model is not None:
            return _model, _processor
        try:
            from transformers import Sam3Model, Sam3Processor

            device = "cuda" if torch.cuda.is_available() else "cpu"
            log.info("Loading SAM3 (%s) on %s...", MODEL_ID, device)
            # Build into locals first so nothing half-initialised is ever
            # visible through the module globals.
            processor = Sam3Processor.from_pretrained(MODEL_ID)
            model = Sam3Model.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(device)
            model.eval()
        except Exception as e:
            _load_error = str(e)
            log.exception("Model load failed")
            raise
        # The unlocked fast path treats a non-None ``_model`` as "ready", so
        # publish the processor first and the (already eval-mode) model last.
        _processor = processor
        _model = model
        # A successful (re)try supersedes any earlier failure.
        _load_error = None
        log.info("SAM3 ready.")
        return _model, _processor
def fetch_image(url: str) -> Image.Image:
    """Download ``url`` and decode it into an RGB PIL image.

    Raises whatever ``requests`` raises for network failures or non-2xx
    responses (via ``raise_for_status``), and whatever PIL raises when the
    bytes are not a decodable image.
    """
    # NOTE(review): the URL comes from the caller's task data and is fetched
    # server-side; restrict allowed hosts if untrusted clients can reach this.
    response = requests.get(
        url,
        headers={"User-Agent": "sam3-ls-backend/1.0"},
        timeout=30,
    )
    response.raise_for_status()
    picture = Image.open(io.BytesIO(response.content))
    # Normalise any non-RGB mode (palette, greyscale, RGBA, ...) to 3 channels.
    return picture if picture.mode == "RGB" else picture.convert("RGB")
def run_inference(image: Image.Image, label: str):
    """Run text-prompted SAM3 instance segmentation on a single image.

    Args:
        image: RGB image to segment.
        label: Text prompt describing what to find.

    Returns:
        The post-processed result for this image (the first and only element
        of the batch), as returned by
        ``processor.post_process_instance_segmentation``; callers read its
        ``"boxes"`` and ``"scores"`` entries.
    """
    model, processor = get_model()
    # Match the inputs to wherever the weights live (device and dtype).
    reference_param = next(model.parameters())
    batch = processor(images=[image], text=[label], return_tensors="pt")
    batch = batch.to(reference_param.device, dtype=reference_param.dtype)

    with torch.no_grad():
        raw_outputs = model(**batch)

    # Rescale predictions back to each input image's original size.
    original_sizes = batch.get("original_sizes").tolist()
    per_image = processor.post_process_instance_segmentation(
        raw_outputs,
        threshold=CONFIDENCE_THRESHOLD,
        mask_threshold=MASK_THRESHOLD,
        target_sizes=original_sizes,
    )
    return per_image[0]
def to_ls_prediction(
    image: Image.Image,
    result,
    label: str,
    *,
    from_name: str = "label",
    to_name: str = "image",
) -> dict:
    """Convert one post-processed SAM3 result into a Label Studio prediction.

    Args:
        image: The image inference ran on. Only ``image.size`` (width, height)
            is read, to convert pixel boxes into percentages.
        result: Mapping with ``"boxes"`` and ``"scores"`` entries supporting
            ``len()`` and ``.tolist()``. Each box is read as
            ``(x1, y1, x2, y2)`` pixel corners.
        label: Class name placed in every region's ``rectanglelabels``.
        from_name: ``from_name`` of the labeling-config control tag
            (default ``"label"``).
        to_name: ``to_name`` of the labeling-config image tag
            (default ``"image"``).

    Returns:
        ``{"model_version", "score", "result"}``, where ``result`` holds one
        ``rectanglelabels`` region per box, with coordinates as percentages of
        the image size, and ``score`` is the highest region score (0.0 when
        there are no boxes).
    """
    W, H = image.size
    boxes = result.get("boxes")
    scores = result.get("scores")
    if boxes is None or len(boxes) == 0:
        return {"model_version": MODEL_VERSION, "score": 0.0, "result": []}

    def _clamp(value, upper):
        # Keep corners inside the image so regions never leave the 0-100 %
        # range that Label Studio's percentage coordinates describe.
        return min(max(float(value), 0.0), float(upper))

    items = []
    for (x1, y1, x2, y2), score in zip(boxes.tolist(), scores.tolist()):
        x1, x2 = _clamp(x1, W), _clamp(x2, W)
        y1, y2 = _clamp(y1, H), _clamp(y2, H)
        items.append({
            "id": str(uuid.uuid4())[:8],
            "from_name": from_name,
            "to_name": to_name,
            "type": "rectanglelabels",
            "original_width": W,
            "original_height": H,
            "image_rotation": 0,
            "value": {
                "x": x1 / W * 100.0,
                "y": y1 / H * 100.0,
                "width": (x2 - x1) / W * 100.0,
                "height": (y2 - y1) / H * 100.0,
                "rotation": 0,
                "rectanglelabels": [label],
            },
            "score": float(score),
        })
    overall = max((it["score"] for it in items), default=0.0)
    return {"model_version": MODEL_VERSION, "score": float(overall), "result": items}
@app.route("/health", methods=["GET"])
def health():
    """Report liveness, whether the model is loaded, and any load error."""
    return jsonify(
        status="UP",
        model_version=MODEL_VERSION,
        model_loaded=_model is not None,
        load_error=_load_error,
        cuda_available=torch.cuda.is_available(),
    )
@app.route("/setup", methods=["POST"])
def setup():
    """Log the project from the setup call and reply with the model version."""
    body = request.get_json(silent=True) or {}
    project = body.get("project")
    log.info("setup: project=%s", project)
    return jsonify(model_version=MODEL_VERSION)
@app.route("/predict", methods=["POST"])
def predict():
    """Return one SAM3 box prediction per task in the request.

    The JSON body is expected to carry a ``tasks`` list whose items hold an
    image URL at ``data.image``. ``results`` contains exactly one entry per
    task, in order. Tasks without a usable URL get an empty prediction. Tasks
    whose fetch or inference raised get an empty prediction plus an
    ``"error"`` message, so one bad task never fails the whole batch.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # Missing/invalid JSON, or a JSON value that is not an object.
        payload = {}
    tasks = payload.get("tasks")
    if not isinstance(tasks, list):
        # Covers both an absent key and an explicit ``"tasks": null``.
        tasks = []
    log.info("predict: %d task(s)", len(tasks))
    out = []
    for t in tasks:
        # Malformed tasks (non-dict task or ``data``) still get a placeholder,
        # keeping ``results`` aligned one-to-one with ``tasks``.
        data = t.get("data") if isinstance(t, dict) else None
        url = data.get("image") if isinstance(data, dict) else None
        if not url:
            out.append({"model_version": MODEL_VERSION, "score": 0.0, "result": []})
            continue
        # NOTE(review): a URL without a scheme (e.g. a relative Label Studio
        # path) makes ``requests`` raise, which is reported below as a
        # per-task error.
        try:
            img = fetch_image(url)
            r = run_inference(img, DEFAULT_LABEL)
            out.append(to_ls_prediction(img, r, DEFAULT_LABEL))
        except Exception as e:
            log.exception("predict failed for task %s", t.get("id"))
            out.append({"model_version": MODEL_VERSION, "score": 0.0, "result": [], "error": str(e)})
    return jsonify({"results": out})
@app.route("/webhook", methods=["POST"])
def webhook():
    """Acknowledge an incoming webhook, logging its ``action`` field."""
    event = (request.get_json(silent=True) or {}).get("action")
    log.info("webhook event: %s", event)
    return jsonify(status="ok")
@app.route("/", methods=["GET"])
def root():
    """Describe the service: identity, model info, and available endpoints."""
    available = ["/health", "/setup", "/predict", "/webhook"]
    return jsonify(
        service="sam3-ls-backend",
        model_id=MODEL_ID,
        model_version=MODEL_VERSION,
        model_loaded=_model is not None,
        endpoints=available,
    )
if __name__ == "__main__":
    # Direct-run entry point: Flask's built-in server on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")