# autocaptions / server.py
# pandemuliada's picture
# Fix rerank string parsing
# 8a9e1fd
import json
import logging
from io import BytesIO

import requests
import torch
from fastapi import FastAPI, Request, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AutoProcessor, AutoModel
# ASGI application object; the @app.get / @app.post decorators in this file
# register their route handlers on it.
app = FastAPI()
@app.get("/health")
async def health():
    """Liveness probe: always reports the service as healthy."""
    payload = {"status": "healthy"}
    return payload
# SigLIP checkpoint shared by every endpoint
# (768-dim embeddings, sigmoid loss for better fine-grained matching).
model_name = "google/siglip-base-patch16-224"
processor = AutoProcessor.from_pretrained(model_name)
# The service only runs inference, so switch to eval mode right after loading
# (Module.eval() returns the module itself).
model = AutoModel.from_pretrained(model_name).eval()
def load_image_from_url(url: str) -> Image.Image:
    """Download the resource at ``url`` and return it as an RGB PIL image.

    The request uses a 60 second timeout. Errors are not handled here: any
    exception from ``requests`` (including the one ``raise_for_status``
    raises for an error status) or from PIL while opening/converting the
    payload propagates to the caller.
    """
    reply = requests.get(url, timeout=60)
    # Fail on an error status instead of trying to decode an error page.
    reply.raise_for_status()
    raw_bytes = BytesIO(reply.content)
    picture = Image.open(raw_bytes)
    return picture.convert("RGB")
def get_image_embedding(image: Image.Image) -> list:
    """Return the L2-normalised embedding of ``image`` as a plain Python list.

    The image is preprocessed with the module-level ``processor`` and passed
    through ``model.get_image_features`` with gradients disabled. The feature
    tensor is divided by its L2 norm along the last dimension, size-1
    dimensions are squeezed away, and the result is converted with
    ``tolist()``.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only, no autograd graph needed
        raw_features = model.get_image_features(**model_inputs)
    l2_norm = raw_features.norm(p=2, dim=-1, keepdim=True)
    unit_features = raw_features / l2_norm
    return unit_features.squeeze().tolist()
@app.post("/embed")
async def embed_image(
    url: str = Form(None),  # Optional URL
    file: UploadFile = File(None),  # Optional file
):
    """Embed one image supplied either by URL or as an uploaded file.

    Form fields:
        url: Address to download the image from. Takes precedence over
            ``file`` when both are given.
        file: Uploaded image, used when ``url`` is empty.

    Returns:
        ``{"embedding": [...], "dimension": len(embedding)}`` on success, or
        ``{"message": ...}`` when neither input was provided (this handler
        does not set an error status for that case).

    Exceptions raised while downloading, decoding or embedding the image are
    not caught here and propagate to the framework.
    """
    if not url and not file:
        return {"message": "Either 'url' or 'file' must be provided"}

    if url:
        # requests.get() is a blocking call with a 60 s timeout; running it
        # directly inside this coroutine would stall the event loop (and every
        # other request, /health included) for the duration of the download.
        image = await run_in_threadpool(load_image_from_url, url)
    else:
        # The guard above guarantees that `file` is set on this branch.
        uploaded_bytes = await file.read()
        image = Image.open(BytesIO(uploaded_bytes)).convert("RGB")

    # Model inference is CPU-bound and synchronous, so it is offloaded as well.
    embedding = await run_in_threadpool(get_image_embedding, image)
    return {"embedding": embedding, "dimension": len(embedding)}
def _best_similarity(query_tensor, image_urls) -> float:
    """Return the highest cosine similarity between the query and any image in ``image_urls``.

    Images that cannot be downloaded, decoded, embedded or compared are
    skipped. The result never drops below the 0.0 starting value, which is
    also what is returned when no image could be scored.
    """
    log = logging.getLogger(__name__)
    best = 0.0
    for image_url in image_urls:
        try:
            image = load_image_from_url(image_url)
            candidate_embedding = get_image_embedding(image)
            candidate_tensor = torch.tensor(
                candidate_embedding, dtype=torch.float32
            ).unsqueeze(0)
            similarity = torch.cosine_similarity(query_tensor, candidate_tensor).item()
        except Exception:
            # Best effort: one bad image must not fail the whole request, but
            # leave a trace instead of discarding the error silently.
            log.warning("rerank: skipping image %s", image_url, exc_info=True)
            continue
        best = max(best, similarity)
    return best


def _score_candidates(query_tensor, candidates) -> list:
    """Score every candidate against the query and return the entries sorted best-first."""
    scored = [
        {
            "product_id": candidate["product_id"],
            "max_similarity": _best_similarity(query_tensor, candidate["image_urls"]),
        }
        for candidate in candidates
    ]
    scored.sort(key=lambda entry: entry["max_similarity"], reverse=True)
    return scored


@app.post("/rerank")
async def rerank(request: Request):
    """Rerank candidate products by image similarity to a query embedding.

    Expects a JSON body with:
        query_embedding: the query vector, either as a list of numbers or as
            a JSON-encoded string of such a list.
        candidates: a list of objects, each with ``product_id`` and
            ``image_urls``.

    Returns:
        ``{"results": [{"product_id": ..., "max_similarity": ...}, ...]}``
        sorted by ``max_similarity`` in descending order. A candidate's score
        is the best cosine similarity over its images, floored at 0.0.

    A missing ``query_embedding`` or ``candidates`` key raises ``KeyError``,
    which is not handled here.
    """
    body = await request.json()
    query_embedding = body["query_embedding"]  # 768-dim from query image
    candidates = body["candidates"]  # [{product_id, image_urls: [...]}]

    # pgvector returns embeddings as strings, parse if needed
    if isinstance(query_embedding, str):
        query_embedding = json.loads(query_embedding)

    # Force a float dtype: a list of integer-valued numbers would otherwise
    # produce an integer tensor, for which .norm() raises.
    query_tensor = torch.tensor(query_embedding, dtype=torch.float32).unsqueeze(0)
    query_tensor = query_tensor / query_tensor.norm(p=2, dim=-1, keepdim=True)

    # Downloads (up to 60 s each) and model inference are blocking; run them
    # in a worker thread so the event loop keeps serving other requests.
    results = await run_in_threadpool(_score_candidates, query_tensor, candidates)
    return {"results": results}