invensync's picture
Initial commit — SigLIP-2 embed API
2c45e8a verified
"""
InvenSync V3 — Embed API minimal (FastAPI sur HF Space gratuit).
POST /embed { image_base64 } → { vector[768], dim, model, latency_ms }
GET / → health check (ok, model, device, img_size, has_visual_proj)
Pourquoi ce Space existe : HF Inference Providers (free tier) ne sert pas les
SigLIP-2 custom. On self-héberge donc le modèle ici, gratuitement, en CPU.
"""
import base64
import io
import time
from typing import Optional
import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoProcessor
# Identifier passed to `from_pretrained` for both the model and its processor.
MODEL_ID = "invensync/siglip2-base-invensync-v1"
# Side length, in pixels, of the square that _decode_image resizes inputs to.
IMG_SIZE = 384
# Use CUDA when torch reports it as available, otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Pre-load at container cold start (this runs at import time).
# On a free 2-vCPU HF Space (CPU), loading the ~800 MB of weights takes ~5-10 s.
print(f"[boot] Loading {MODEL_ID} on {DEVICE} …")
_t0 = time.time()
# The model is moved to DEVICE and switched to eval mode once, here.
_model = AutoModel.from_pretrained(MODEL_ID).to(DEVICE).eval()
_processor = AutoProcessor.from_pretrained(MODEL_ID)
print(f"[boot] Model loaded in {time.time() - _t0:.1f}s")
def _underlying(m):
    """Return ``m.get_base_model()`` when *m* has that method, else *m* itself.

    Lets the rest of the module address the bare model whether or not it was
    loaded inside a wrapper object that exposes ``get_base_model``.
    """
    if not hasattr(m, "get_base_model"):
        return m
    return m.get_base_model()
# Unwrap once at import time; request handlers read `vision_model` and
# `visual_projection` from this object.
_u = _underlying(_model)

# True only when the unwrapped model has a `visual_projection` attribute that
# is something other than a pass-through torch.nn.Identity.
if hasattr(_u, "visual_projection"):
    _has_visual_proj = not isinstance(_u.visual_projection, torch.nn.Identity)
else:
    _has_visual_proj = False

# ASGI application object; routes are registered on it further down.
app = FastAPI(
    title="InvenSync V3 Embed API",
    description="SigLIP-2 fine-tuned embedding API",
    version="1.0.0",
)
# Request body accepted by POST /embed.
# (Documented with comments rather than a class docstring so the generated
# JSON schema is left unchanged.)
class EmbedRequest(BaseModel):
    # Required. Base64 text of the image; a leading "data:...;base64," prefix
    # is tolerated and stripped by _decode_image.
    image_base64: str = Field(..., description="PNG/JPEG/WebP en base64 (avec ou sans préfixe data:)")
# Response body returned by POST /embed.
class EmbedResponse(BaseModel):
    # Embedding values as produced by _embed (normalized along the last axis).
    vector: list[float]
    # Number of entries in `vector`.
    dim: int
    # Model identifier that produced the embedding (MODEL_ID).
    model: str
    # Time spent handling the request (decode + embed), in whole milliseconds.
    latency_ms: int
def _decode_image(b64: str) -> Image.Image:
    """Decode a base64 (optionally data-URL) payload into a square RGB image.

    Args:
        b64: Base64 text of an encoded image, with or without a leading
            ``data:<mime>;base64,`` prefix. Surrounding whitespace is ignored.

    Returns:
        The image converted to RGB and resized to ``IMG_SIZE`` x ``IMG_SIZE``
        with bilinear resampling (the aspect ratio is not preserved).

    Raises:
        HTTPException: status 400 when the base64 text cannot be decoded or
            the decoded bytes cannot be opened as an image.
    """
    # Strip first so a stray leading newline/space does not hide the "data:"
    # prefix and cause the whole data URL to be decoded as garbage.
    payload = b64.strip()
    # Drop a "data:image/...;base64," prefix when present.
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        img_bytes = base64.b64decode(payload, validate=False)
    except ValueError as e:
        # binascii.Error (bad padding) and the non-ASCII input error are both
        # ValueError subclasses/instances, so this covers every decode failure.
        raise HTTPException(status_code=400, detail=f"base64 decode failed: {e}") from e
    try:
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except Exception as e:
        # Deliberately broad: the imaging library raises several unrelated
        # exception types on bad input, and each one means the client's image
        # is unusable, i.e. a 400 rather than a 500.
        raise HTTPException(status_code=400, detail=f"image open failed: {e}") from e
    return img.resize((IMG_SIZE, IMG_SIZE), Image.BILINEAR)
@torch.no_grad()
def _embed(img: Image.Image) -> list[float]:
    """Run one image through the vision model and return its normalized embedding.

    The image is preprocessed with the module-level processor, passed through
    ``_u.vision_model``, optionally through ``_u.visual_projection`` (when
    ``_has_visual_proj`` is set), normalized along the last dimension, and the
    first row of the batch is returned as a plain Python list.
    Gradient tracking is disabled for the whole call.
    """
    batch = _processor(images=img, return_tensors="pt").to(DEVICE)
    features = _u.vision_model(pixel_values=batch["pixel_values"]).pooler_output
    if _has_visual_proj:
        features = _u.visual_projection(features)
    unit = torch.nn.functional.normalize(features, dim=-1)
    first_row = unit[0]
    return first_row.cpu().tolist()
@app.get("/")
def health():
    # Health/config probe: reports the module-level settings fixed at import
    # time. Kept as comments (no docstring) so the OpenAPI description of this
    # route stays as it was.
    status = dict(
        ok=True,
        model=MODEL_ID,
        device=DEVICE,
        img_size=IMG_SIZE,
        has_visual_proj=_has_visual_proj,
    )
    return status
# Embedding width this API promises to clients; a different width means the
# loaded checkpoint does not match what callers expect.
_EXPECTED_DIM = 768


@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    # Decode the submitted image, embed it, and return the vector with timing.
    #
    # Raises HTTPException 400 (via _decode_image) for undecodable input and
    # HTTPException 500 when the model yields a vector of unexpected width.
    # Documented with comments rather than a docstring so the OpenAPI
    # description of this route is unchanged.
    #
    # perf_counter is monotonic, so the reported latency cannot go negative or
    # jump when the system clock is adjusted mid-request.
    t0 = time.perf_counter()
    img = _decode_image(req.image_base64)
    vec = _embed(img)
    dim = len(vec)
    if dim != _EXPECTED_DIM:
        raise HTTPException(status_code=500, detail=f"unexpected embedding dim: {dim}")
    return EmbedResponse(
        vector=vec,
        dim=dim,
        model=MODEL_ID,
        latency_ms=int((time.perf_counter() - t0) * 1000),
    )