import os
from io import BytesIO
from pathlib import Path

import chromadb
import torch
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoModel, AutoProcessor

# ── Paths ─────────────────────────────────────────────────────────────────────
BASE_DIR = Path(__file__).parent
CHROMADB_DIR = BASE_DIR / "chromadb"
CSV_PATH = BASE_DIR / "furniture_dataset.csv"
# Use the Hub's ``resolve`` route so that ``{IMAGE_BASE_URL}/{filename}`` points
# at the raw file. The previous ``tree/main`` route is the HTML repository
# browser, so URLs built from it could not be used as image sources.
IMAGE_BASE_URL = "https://huggingface.co/datasets/MohamedSameh77i/Furniture_Synthetic_Dataset/resolve/main"
SIGLIP_MODEL_ID = "google/siglip2-so400m-patch16-naflex"
# Default number of results; overridable through the SEARCH_TOP_K env variable.
DEFAULT_TOP_K = int(os.getenv("SEARCH_TOP_K", "5"))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ── Load models ───────────────────────────────────────────────────────────────
# Everything below runs once at import time, so the first request does not pay
# the model-loading cost.
print(f"Loading SigLIP2 on {DEVICE}...")
processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
model = AutoModel.from_pretrained(SIGLIP_MODEL_ID, torch_dtype=torch.float32).to(DEVICE)
model.eval()

chroma_client = chromadb.PersistentClient(path=str(CHROMADB_DIR))
collection = chroma_client.get_collection("furniture")
# Counted once at startup; the endpoints report this cached value.
N_ITEMS = collection.count()
print(f"Ready — {N_ITEMS} items.")


# ── Embed ─────────────────────────────────────────────────────────────────────
@torch.inference_mode()
def embed(pil_image: Image.Image) -> list[float]:
    """Return the L2-normalised vision embedding of *pil_image*.

    The image is run through the processor and the model's vision tower as a
    batch of one; the pooled output is divided by its Euclidean norm and
    returned as a flat list of Python floats.
    """
    inputs = processor(images=[pil_image], return_tensors="pt").to(DEVICE)
    outputs = model.vision_model(**inputs)
    vec = outputs.pooler_output
    vec = vec / vec.norm(dim=-1, keepdim=True)
    # Drop only the batch axis; a bare ``squeeze()`` would also collapse any
    # other size-1 dimension.
    return vec.squeeze(0).cpu().float().tolist()


# ── FastAPI Setup ─────────────────────────────────────────────────────────────
app = FastAPI(title="IntelliRoom HF API", version="1.0.0")

# Wide-open CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def root():
    """Describe the service: name, compute device, item count and endpoints."""
    return {
        "service": "IntelliRoom HF API",
        "device": DEVICE,
        "items": N_ITEMS,
        "endpoints": ["/health", "/search", "/docs"],
    }
@app.get("/health")
def health():
    """Liveness probe: report status, compute device and the cached item count."""
    return {
        "status": "running",
        "device": DEVICE,
        "items": N_ITEMS,
    }


def _build_matches(results) -> list[dict]:
    """Convert the first result set of a collection query into response rows."""
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]
    matches = []
    for rank, (meta, dist) in enumerate(zip(metadatas, distances), start=1):
        # Treat a missing metadata entry as empty instead of failing on ``.get``.
        meta = meta or {}
        filename = meta.get("filename")
        matches.append(
            {
                "rank": rank,
                "filename": filename,
                "name": meta.get("name"),
                # NOTE(review): ``1 - dist`` is a similarity only if the
                # collection was built with a cosine distance — confirm.
                "similarity": round(1 - dist, 3),
                "image_url": f"{IMAGE_BASE_URL}/{filename}" if filename else "",
            }
        )
    return matches


@app.post("/search")
async def search_endpoint(
    file: UploadFile = File(...),
    top_k: int = Query(DEFAULT_TOP_K, ge=1, le=50),
):
    """Find the items most similar to an uploaded image.

    Args:
        file: The uploaded image.
        top_k: Number of neighbours to return (1-50).

    Returns:
        ``{"results": [...], "count": n}`` where each result carries ``rank``,
        ``filename``, ``name``, ``similarity`` and ``image_url``.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as an
            image; 500 when embedding or the vector query fails.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # Decode outside the broad handler below so that a bad upload is reported
    # as a client error (400) rather than a server error (500).
    try:
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc

    try:
        # NOTE(review): embed() is synchronous and runs on the event loop;
        # consider offloading it to a worker thread if request latency matters.
        vector = embed(image)
        results = collection.query(
            query_embeddings=[vector],
            n_results=top_k,
            include=["distances", "metadatas"],
        )
        matches = _build_matches(results)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"results": matches, "count": len(matches)}