File size: 3,350 Bytes
2c45e8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
InvenSync V3 — minimal Embed API (FastAPI on a free HF Space).

POST /embed { image_base64 } → { vector[768], dim, model, latency_ms }
GET  /                       → health check (ok, model, device, img_size, has_visual_proj)

Why this Space exists: HF Inference Providers (free tier) does not serve
custom SigLIP-2 models, so the model is self-hosted here, for free, on CPU.
"""

import base64
import io
import time
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException
from PIL import Image
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoProcessor

# Model identifier handed to `from_pretrained` for both the model and its processor.
MODEL_ID = "invensync/siglip2-base-invensync-v1"
# Edge length, in pixels, of the square that decoded images are resized to.
IMG_SIZE = 384
# Run on CUDA when torch reports a usable GPU, otherwise on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Pre-load at container cold start, so the work happens once at import time
# rather than on the first request.
# On a free 2-vCPU HF Space (CPU), loading the ~800 MB of weights takes ~5-10 s.
print(f"[boot] Loading {MODEL_ID} on {DEVICE} …")
_t0 = time.time()
# Move the weights to the selected device and switch the module to eval mode.
_model = AutoModel.from_pretrained(MODEL_ID).to(DEVICE).eval()
_processor = AutoProcessor.from_pretrained(MODEL_ID)
print(f"[boot] Model loaded in {time.time() - _t0:.1f}s")


def _underlying(m):
    """Return the model wrapped by ``m``, or ``m`` itself when it wraps nothing.

    Args:
        m: Any object. If it exposes a ``get_base_model`` attribute, that
            attribute is called with no arguments.

    Returns:
        The result of ``m.get_base_model()`` when the attribute exists,
        otherwise ``m`` unchanged.
    """
    if not hasattr(m, "get_base_model"):
        return m
    return m.get_base_model()


# Model object whose sub-modules are called directly when embedding.
_u = _underlying(_model)

# True only when a `visual_projection` attribute exists and is something other
# than a pass-through torch.nn.Identity.
_has_visual_proj = hasattr(_u, "visual_projection")
if _has_visual_proj and isinstance(_u.visual_projection, torch.nn.Identity):
    _has_visual_proj = False

# ASGI application object; the route handlers below register themselves on it.
app = FastAPI(
    version="1.0.0",
    title="InvenSync V3 Embed API",
    description="SigLIP-2 fine-tuned embedding API",
)


class EmbedRequest(BaseModel):
    # Request body of POST /embed.
    # image_base64: base64 text of a PNG/JPEG/WebP image. A leading
    # "data:...," prefix is accepted; _decode_image strips it before decoding.
    image_base64: str = Field(..., description="PNG/JPEG/WebP en base64 (avec ou sans préfixe data:)")


class EmbedResponse(BaseModel):
    # Response body of POST /embed.
    vector: list[float]  # normalized embedding values produced by _embed
    dim: int  # number of entries in `vector`
    model: str  # identifier of the model that produced the vector (MODEL_ID)
    latency_ms: int  # elapsed handling time of the request, in whole milliseconds


def _decode_image(b64: str) -> Image.Image:
    """Turn a base64 image payload into an RGB image of IMG_SIZE x IMG_SIZE.

    Args:
        b64: Base64 text of the image, optionally preceded by a
            ``data:...,`` prefix.

    Returns:
        The decoded image, converted to RGB and resized with bilinear
        resampling to a square of ``IMG_SIZE`` pixels per side.

    Raises:
        HTTPException: Status 400 when base64 decoding fails, or when the
            decoded bytes cannot be opened and converted as an image.
    """
    payload = b64
    # Data-URI form: keep only what follows the first comma.
    if payload.startswith("data:") and "," in payload:
        payload = payload.partition(",")[2]

    try:
        raw = base64.b64decode(payload, validate=False)
    except Exception as err:
        raise HTTPException(status_code=400, detail=f"base64 decode failed: {err}")

    try:
        buffer = io.BytesIO(raw)
        rgb = Image.open(buffer).convert("RGB")
    except Exception as err:
        raise HTTPException(status_code=400, detail=f"image open failed: {err}")

    target_size = (IMG_SIZE, IMG_SIZE)
    return rgb.resize(target_size, Image.BILINEAR)


def _embed(img: Image.Image) -> list[float]:
    """Compute the normalized image embedding for a single image.

    Args:
        img: Image to embed; it is passed to the module-level processor.

    Returns:
        The embedding of the first batch item as a plain list of floats,
        normalized along its last dimension.
    """
    # Gradient tracking is disabled for the whole forward pass.
    with torch.no_grad():
        batch = _processor(images=img, return_tensors="pt").to(DEVICE)
        pixel_values = batch["pixel_values"]
        embedding = _u.vision_model(pixel_values=pixel_values).pooler_output
        # Apply the projection head only when the model has a real one.
        if _has_visual_proj:
            embedding = _u.visual_projection(embedding)
        unit = torch.nn.functional.normalize(embedding, dim=-1)
        # Move to CPU before converting so the result is a plain Python list.
        return unit[0].cpu().tolist()


@app.get("/")
def health():
    # Health-check endpoint: reports the serving configuration chosen at boot
    # (model id, device, input size, and whether a projection head is applied).
    # Documented with comments rather than a docstring because FastAPI would
    # publish a handler docstring as the operation description.
    status = dict(
        ok=True,
        model=MODEL_ID,
        device=DEVICE,
        img_size=IMG_SIZE,
        has_visual_proj=_has_visual_proj,
    )
    return status


@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
    # Embed one base64-encoded image and report how long handling took.
    #
    # Parameters:
    #   req: request body carrying the image as base64 text.
    # Returns:
    #   EmbedResponse with the embedding vector, its length, the model
    #   identifier and the elapsed time in whole milliseconds.
    # Raises:
    #   HTTPException 400 (propagated from _decode_image) when the payload
    #   cannot be decoded or opened as an image; HTTPException 500 when the
    #   embedding does not have the expected number of components.
    #
    # Documented with comments rather than a docstring because FastAPI would
    # publish a handler docstring as the operation description.
    expected_dim = 768

    # perf_counter is monotonic: unlike time.time(), the measured latency
    # cannot be skewed (or go negative) when the system clock is adjusted.
    started = time.perf_counter()
    img = _decode_image(req.image_base64)
    vec = _embed(img)
    if len(vec) != expected_dim:
        raise HTTPException(status_code=500, detail=f"unexpected embedding dim: {len(vec)}")
    return EmbedResponse(
        vector=vec,
        dim=len(vec),
        model=MODEL_ID,
        latency_ms=int((time.perf_counter() - started) * 1000),
    )