Spaces:
Running
Running
from io import BytesIO
import json

import requests
import torch
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from PIL import Image
from transformers import AutoModel, AutoProcessor
app = FastAPI()  # ASGI application; handlers below are its endpoints (route decorators not visible in this paste — confirm against the deployed file)
# NOTE(review): no @app.get(...) decorator is visible here — presumably lost
# when the file was pasted; confirm the route registration in the deployed file.
async def health():
    """Liveness probe: report that the service process is up."""
    status_payload = {"status": "healthy"}
    return status_payload
# Load SigLIP model (768-dim embeddings, sigmoid loss for better fine-grained matching)
model_name = "google/siglip-base-patch16-224"
model = AutoModel.from_pretrained(model_name)  # downloads/caches weights at import time — service startup blocks on this
processor = AutoProcessor.from_pretrained(model_name)  # matching image preprocessor (resize/normalize to the model's expected input)
model.eval()  # inference mode (disables dropout/batchnorm updates); pairs with torch.no_grad() in get_image_embedding
def load_image_from_url(url: str) -> Image.Image:
    """Fetch *url* and decode the response body as an RGB PIL image.

    Raises requests.HTTPError for non-2xx responses and
    requests.Timeout after 60 seconds.
    """
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()
    payload = BytesIO(resp.content)
    return Image.open(payload).convert("RGB")
def get_image_embedding(image: Image.Image) -> list:
    """Return the L2-normalized SigLIP image embedding as a flat list of floats."""
    batch = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        vec = model.get_image_features(**batch)
    # Unit-normalize so downstream cosine similarity reduces to a dot product.
    vec = vec / vec.norm(p=2, dim=-1, keepdim=True)
    return vec.squeeze().tolist()
async def embed_image(
    url: str = Form(None),          # Optional: fetch the image from this URL
    file: UploadFile = File(None),  # Optional: read the image from an upload
):
    """Embed one image, supplied either as a URL form field or an uploaded file.

    Returns {"embedding": [...], "dimension": int}.

    Raises:
        HTTPException(400) when neither input is provided, or when the
        provided input cannot be fetched/decoded as an image.
    """
    if not url and not file:
        # Fix: previously answered HTTP 200 with a message dict, which clients
        # could mistake for success. A missing input is a client error.
        raise HTTPException(
            status_code=400,
            detail="Either 'url' or 'file' must be provided",
        )
    try:
        if url:
            # Load image from URL
            image = load_image_from_url(url)
        else:
            # Load image from uploaded file
            image = Image.open(BytesIO(await file.read())).convert("RGB")
    except Exception as exc:
        # Surface bad/unreachable inputs as 400 rather than an opaque 500.
        raise HTTPException(status_code=400, detail=f"Could not load image: {exc}") from exc
    embedding = get_image_embedding(image)
    return {"embedding": embedding, "dimension": len(embedding)}
async def rerank(request: Request):
    """Re-rank candidate products by visual similarity to a query embedding.

    Expects a JSON body with:
        query_embedding: 768-dim vector (list of floats, or pgvector's
                         string form, e.g. "[0.1, ...]")
        candidates:      [{product_id, image_urls: [...]}]

    Returns {"results": [{product_id, max_similarity}]} sorted best-first.

    Raises:
        HTTPException(400) when a required body field is missing.
    """
    body = await request.json()
    try:
        query_embedding = body["query_embedding"]  # 768-dim from query image
        candidates = body["candidates"]            # [{product_id, image_urls: [...]}]
    except KeyError as exc:
        # Fix: a missing field used to bubble up as an opaque 500.
        raise HTTPException(status_code=400, detail=f"Missing field: {exc.args[0]}") from exc
    # pgvector returns embeddings as strings, parse if needed
    if isinstance(query_embedding, str):
        query_embedding = json.loads(query_embedding)
    query_tensor = torch.tensor(query_embedding).unsqueeze(0)
    query_tensor = query_tensor / query_tensor.norm(p=2, dim=-1, keepdim=True)
    results = []
    for candidate in candidates:
        # Fix: starting the running max at 0.0 clipped genuinely negative
        # cosine similarities (range [-1, 1]) to zero, corrupting the ranking.
        # Track None until at least one image scores.
        best = None
        for image_url in candidate["image_urls"]:
            try:
                image = load_image_from_url(image_url)
                candidate_tensor = torch.tensor(get_image_embedding(image)).unsqueeze(0)
                similarity = torch.cosine_similarity(query_tensor, candidate_tensor).item()
                best = similarity if best is None else max(best, similarity)
            except Exception:
                # Best-effort: one broken image URL must not sink the candidate.
                continue
        results.append({
            "product_id": candidate["product_id"],
            # 0.0 is kept as the sentinel when no image could be scored,
            # preserving the previous API behavior for that case.
            "max_similarity": best if best is not None else 0.0,
        })
    results.sort(key=lambda x: x["max_similarity"], reverse=True)
    return {"results": results}