import io
import logging
import os
import threading
import uuid
from typing import Optional
import requests
import torch
from flask import Flask, jsonify, request
from PIL import Image
# Logging: INFO-level root configuration plus a named logger for this service.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("sam3-ls-backend")

# Runtime configuration; every value can be overridden through the environment.
MODEL_ID = os.getenv("SAM3_MODEL_ID", "facebook/sam3")
MODEL_VERSION = os.getenv("MODEL_VERSION", "sam3-real-v1")
DEFAULT_LABEL = os.getenv("DEFAULT_LABEL", "butterfly")
CONFIDENCE_THRESHOLD = float(os.getenv("CONFIDENCE_THRESHOLD", "0.5"))
MASK_THRESHOLD = float(os.getenv("MASK_THRESHOLD", "0.5"))

app = Flask(__name__)

# Lazily populated model state, filled in by get_model().
# _load_lock serialises the load; _load_error is set by get_model() when
# loading fails and is reported by the /health endpoint.
_model = None
_processor = None
_load_lock = threading.Lock()
_load_error: Optional[str] = None
def get_model():
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    The first caller loads the processor and model named by ``MODEL_ID`` while
    holding ``_load_lock``; later callers take the lock-free fast path once
    ``_model`` is set.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        Exception: whatever the import or ``from_pretrained`` calls raise. The
            message is stored in ``_load_error`` and the exception is re-raised,
            so a later call will attempt the load again.
    """
    global _model, _processor, _load_error
    if _model is not None:
        return _model, _processor
    with _load_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _model is not None:
            return _model, _processor
        try:
            # Imported here, inside the try, so that an import failure is
            # recorded in _load_error like any other load failure.
            from transformers import Sam3Model, Sam3Processor

            device = "cuda" if torch.cuda.is_available() else "cpu"
            log.info("Loading SAM3 (%s) on %s...", MODEL_ID, device)
            processor = Sam3Processor.from_pretrained(MODEL_ID)
            model = Sam3Model.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16).to(device)
            model.eval()
        except Exception as e:
            _load_error = str(e)
            log.exception("Model load failed")
            raise
        # Publish only fully initialised objects. _model is assigned last
        # because the lock-free fast path above keys on it; assigning it
        # earlier could hand out a model that is not yet in eval mode.
        _processor = processor
        _model = model
        # A previous failed attempt must not keep showing up in /health.
        _load_error = None
        log.info("SAM3 ready.")
        return _model, _processor
def fetch_image(url: str) -> Image.Image:
    """Download the image at *url* and return it as an RGB PIL image.

    The request uses a 30-second timeout and a fixed ``User-Agent`` header.
    Errors from ``requests`` (including ``raise_for_status`` on an HTTP error
    status) and from ``Image.open`` propagate to the caller.
    """
    request_headers = {"User-Agent": "sam3-ls-backend/1.0"}
    response = requests.get(url, timeout=30, headers=request_headers)
    response.raise_for_status()
    picture = Image.open(io.BytesIO(response.content))
    # Only convert when needed; an image that is already RGB is returned as opened.
    return picture if picture.mode == "RGB" else picture.convert("RGB")
def run_inference(image: Image.Image, label: str):
    """Run the model on one image with one text prompt.

    Args:
        image: The image to segment.
        label: Text prompt passed to the processor alongside the image.

    Returns:
        The first (and only) entry of the list produced by
        ``processor.post_process_instance_segmentation`` for this one-image
        batch, filtered with ``CONFIDENCE_THRESHOLD`` and ``MASK_THRESHOLD``.
    """
    model, processor = get_model()
    # Use the model's own first parameter to decide where, and in which dtype,
    # the inputs must live.
    reference_param = next(model.parameters())
    batch = processor(images=[image], text=[label], return_tensors="pt")
    batch = batch.to(reference_param.device, dtype=reference_param.dtype)
    with torch.no_grad():
        outputs = model(**batch)
    original_sizes = batch.get("original_sizes").tolist()
    per_image = processor.post_process_instance_segmentation(
        outputs,
        threshold=CONFIDENCE_THRESHOLD,
        mask_threshold=MASK_THRESHOLD,
        target_sizes=original_sizes,
    )
    return per_image[0]
def to_ls_prediction(image: Image.Image, result, label: str) -> dict:
    """Convert one inference result into a ``rectanglelabels`` prediction dict.

    Args:
        image: The image the result belongs to; only ``image.size`` is used.
        result: Mapping with ``"boxes"`` and ``"scores"`` entries that support
            ``len()`` / ``.tolist()``. Each box is unpacked as
            ``(x1, y1, x2, y2)`` and divided by the image width/height, so the
            values are presumably absolute pixel coordinates — confirm against
            the processor's post-processing output.
        label: Label name attached to every emitted rectangle.

    Returns:
        ``{"model_version", "score", "result"}`` where ``result`` holds one
        rectangle per box, expressed in percent of the image size, and
        ``score`` is the highest per-box score (``0.0`` when there are no
        boxes).
    """
    W, H = image.size
    boxes = result.get("boxes")
    scores = result.get("scores")
    if boxes is None or len(boxes) == 0:
        return {"model_version": MODEL_VERSION, "score": 0.0, "result": []}

    items = []
    for box, score in zip(boxes.tolist(), scores.tolist()):
        x1, y1, x2, y2 = box
        # Clamp each corner to the image so the percentages below always lie
        # in [0, 100]; predicted boxes may overshoot the image border slightly.
        x1 = min(max(x1, 0.0), W)
        y1 = min(max(y1, 0.0), H)
        x2 = min(max(x2, 0.0), W)
        y2 = min(max(y2, 0.0), H)
        items.append({
            "id": str(uuid.uuid4())[:8],
            # NOTE(review): "label" / "image" presumably have to match the tag
            # names in the project's labeling config — confirm.
            "from_name": "label",
            "to_name": "image",
            "type": "rectanglelabels",
            "original_width": W,
            "original_height": H,
            "image_rotation": 0,
            "value": {
                "x": x1 / W * 100.0,
                "y": y1 / H * 100.0,
                # Never emit a negative extent, even for an inverted box.
                "width": max(x2 - x1, 0.0) / W * 100.0,
                "height": max(y2 - y1, 0.0) / H * 100.0,
                "rotation": 0,
                "rectanglelabels": [label],
            },
            "score": float(score),
        })
    overall = max((it["score"] for it in items), default=0.0)
    return {"model_version": MODEL_VERSION, "score": float(overall), "result": items}
@app.route("/health", methods=["GET"])
def health():
    """Report service status, model load state and CUDA availability as JSON."""
    status_report = dict(
        status="UP",
        model_version=MODEL_VERSION,
        model_loaded=_model is not None,
        load_error=_load_error,
        cuda_available=torch.cuda.is_available(),
    )
    return jsonify(status_report)
@app.route("/setup", methods=["POST"])
def setup():
    """Log the ``project`` field of the posted JSON and reply with the model version."""
    body = request.get_json(silent=True)
    if not body:
        # Missing, unparsable or empty JSON is treated as an empty payload.
        body = {}
    project = body.get("project")
    log.info("setup: project=%s", project)
    reply = {"model_version": MODEL_VERSION}
    return jsonify(reply)
@app.route("/predict", methods=["POST"])
def predict():
    """Produce one prediction per posted task.

    Reads ``tasks`` from the JSON body; for each task, takes the image URL from
    ``task["data"]["image"]``, downloads it, runs inference with
    ``DEFAULT_LABEL`` and converts the result to a prediction dict.

    The response is ``{"results": [...]}`` with exactly one entry per task, in
    order. A task without a usable image URL gets an empty prediction; a task
    whose processing raises gets an empty prediction plus an ``"error"``
    message. A malformed body (not an object, or ``tasks`` not a list) is
    treated as zero tasks instead of failing the request.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        # Missing/unparsable JSON, or a JSON array/scalar: nothing to predict.
        payload = {}
    tasks = payload.get("tasks")
    if not isinstance(tasks, list):
        # Covers both an absent key and an explicit ``"tasks": null``.
        tasks = []
    log.info("predict: %d task(s)", len(tasks))

    out = []
    for t in tasks:
        # Tolerate malformed entries so one bad task cannot fail the batch.
        data = t.get("data") if isinstance(t, dict) else None
        url = data.get("image") if isinstance(data, dict) else None
        if not url:
            out.append({"model_version": MODEL_VERSION, "score": 0.0, "result": []})
            continue
        try:
            img = fetch_image(url)
            r = run_inference(img, DEFAULT_LABEL)
            out.append(to_ls_prediction(img, r, DEFAULT_LABEL))
        except Exception as e:
            # Request boundary: log with traceback and keep the output aligned
            # with the input tasks.
            log.exception("predict failed for task %s", t.get("id"))
            out.append({"model_version": MODEL_VERSION, "score": 0.0, "result": [], "error": str(e)})
    return jsonify({"results": out})
@app.route("/webhook", methods=["POST"])
def webhook():
    """Log the ``action`` field of the posted JSON and acknowledge the event."""
    event = request.get_json(silent=True)
    if not event:
        # Missing, unparsable or empty JSON is treated as an empty payload.
        event = {}
    action = event.get("action")
    log.info("webhook event: %s", action)
    acknowledgement = {"status": "ok"}
    return jsonify(acknowledgement)
@app.route("/", methods=["GET"])
def root():
    """Return a JSON description of the service: ids, load state and endpoints."""
    endpoint_paths = ["/health", "/setup", "/predict", "/webhook"]
    description = dict(
        service="sam3-ls-backend",
        model_id=MODEL_ID,
        model_version=MODEL_VERSION,
        model_loaded=_model is not None,
        endpoints=endpoint_paths,
    )
    return jsonify(description)
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces, port 7860.
    app.run(host="0.0.0.0", port=7860)