"""FastAPI service exposing SigLIP image embeddings and image-based reranking.

Endpoints:
  GET  /health  - liveness probe
  POST /embed   - embed an image supplied by URL or file upload
  POST /rerank  - re-rank candidate products against a query embedding
"""

from io import BytesIO
import json

import requests
import torch
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from PIL import Image
from transformers import AutoModel, AutoProcessor

app = FastAPI()


@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "healthy"}


# Load SigLIP model (768-dim embeddings, sigmoid loss for better fine-grained
# matching). Loaded once at import time and shared by all requests.
model_name = "google/siglip-base-patch16-224"
model = AutoModel.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
model.eval()  # inference only: disable dropout / training-mode layers


def load_image_from_url(url: str) -> Image.Image:
    """Download *url* and return it as an RGB PIL image.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the download exceeds 60 seconds.
    """
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


def get_image_embedding(image: Image.Image) -> list:
    """Return the L2-normalized SigLIP embedding of *image* as a flat list."""
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    # Normalize so downstream dot products are cosine similarities.
    features = features / features.norm(p=2, dim=-1, keepdim=True)
    return features.squeeze().tolist()


@app.post("/embed")
async def embed_image(
    url: str = Form(None),          # Optional URL
    file: UploadFile = File(None),  # Optional file
):
    """Embed an image supplied either by URL or as an uploaded file.

    Returns the normalized embedding and its dimensionality.
    """
    if not url and not file:
        # BUG FIX: previously returned HTTP 200 with a "message" body; a
        # missing input is a client error and must be signalled as 400.
        raise HTTPException(
            status_code=400,
            detail="Either 'url' or 'file' must be provided",
        )

    if url:
        # Load image from URL
        image = load_image_from_url(url)
    else:
        # Load image from uploaded file
        image = Image.open(BytesIO(await file.read())).convert("RGB")

    embedding = get_image_embedding(image)
    return {"embedding": embedding, "dimension": len(embedding)}


@app.post("/rerank")
async def rerank(request: Request):
    """Re-rank candidate products by max cosine similarity to a query embedding.

    Expected JSON body:
        query_embedding: 768-dim list of floats (or a JSON-encoded string,
                         as pgvector returns embeddings as strings)
        candidates:      [{"product_id": ..., "image_urls": [...]}, ...]
    """
    body = await request.json()
    query_embedding = body["query_embedding"]  # 768-dim from query image
    candidates = body["candidates"]  # [{product_id, image_urls: [...]}]

    # pgvector returns embeddings as strings, parse if needed
    if isinstance(query_embedding, str):
        query_embedding = json.loads(query_embedding)

    query_tensor = torch.tensor(query_embedding).unsqueeze(0)
    query_tensor = query_tensor / query_tensor.norm(p=2, dim=-1, keepdim=True)

    results = []
    for candidate in candidates:
        # BUG FIX: the original initialized max_similarity to 0.0, which
        # silently clamped genuinely negative cosine similarities to zero.
        # Track None until at least one image is scored.
        best = None
        for image_url in candidate["image_urls"]:
            try:
                image = load_image_from_url(image_url)
                candidate_embedding = get_image_embedding(image)
                candidate_tensor = torch.tensor(candidate_embedding).unsqueeze(0)
                similarity = torch.cosine_similarity(
                    query_tensor, candidate_tensor
                ).item()
                best = similarity if best is None else max(best, similarity)
            except Exception:
                # Best-effort: skip images that fail to download or decode.
                continue
        results.append({
            "product_id": candidate["product_id"],
            # 0.0 preserved as the sentinel when no image could be scored.
            "max_similarity": best if best is not None else 0.0,
        })

    results.sort(key=lambda x: x["max_similarity"], reverse=True)
    return {"results": results}