| import os |
| from io import BytesIO |
| from pathlib import Path |
|
|
| import chromadb |
| import torch |
| from fastapi import FastAPI, File, HTTPException, Query, UploadFile |
| from fastapi.middleware.cors import CORSMiddleware |
| from PIL import Image |
| from transformers import AutoModel, AutoProcessor |
|
|
| |
# Paths are anchored to this module's directory rather than the process CWD.
BASE_DIR = Path(__file__).parent
# Directory handed to the persistent Chroma client.
CHROMADB_DIR = BASE_DIR / "chromadb"
# Dataset CSV shipped next to this module (not read by the code visible here).
CSV_PATH = BASE_DIR / "furniture_dataset.csv"

# Prefix for the per-item image links returned by /search. The Hub's
# "resolve/<revision>" route serves the raw file; the "tree/<revision>" route
# is the HTML file browser, so links built on it do not load as images.
IMAGE_BASE_URL = "https://huggingface.co/datasets/MohamedSameh77i/Furniture_Synthetic_Dataset/resolve/main"
# Hugging Face model id used for both the processor and the model.
SIGLIP_MODEL_ID = "google/siglip2-so400m-patch16-naflex"
# Default number of neighbours for /search; overridable via SEARCH_TOP_K.
# A non-integer value raises ValueError at import time.
DEFAULT_TOP_K = int(os.getenv("SEARCH_TOP_K", "5"))
# Prefer the GPU when torch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
| |
# Everything below runs once, at import time, so request handlers can reuse
# the loaded model and the open collection.
print(f"Loading SigLIP2 on {DEVICE}...")
processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
model = AutoModel.from_pretrained(SIGLIP_MODEL_ID, torch_dtype=torch.float32).to(DEVICE)
model.eval()  # inference only: put the module in evaluation mode

# Open the on-disk Chroma store and look up the "furniture" collection.
# NOTE(review): this presumably expects the collection to have been built
# beforehand - confirm how CHROMADB_DIR is populated.
chroma_client = chromadb.PersistentClient(path=str(CHROMADB_DIR))
collection = chroma_client.get_collection("furniture")
# Snapshot of the item count taken at start-up; it is not refreshed later.
N_ITEMS = collection.count()
# The separator was a mis-decoded em dash ("β"); restored here.
print(f"Ready — {N_ITEMS} items.")
|
|
|
|
| |
@torch.inference_mode()
def embed(pil_image: Image.Image) -> list[float]:
    """Embed one image with the vision tower and return a unit-length vector.

    Args:
        pil_image: Image to embed; it is passed to the processor as a
            single-element batch.

    Returns:
        The pooled vision output divided by its norm along the last
        dimension, squeezed and converted to a plain list of floats.
    """
    batch = processor(images=[pil_image], return_tensors="pt")
    batch = batch.to(DEVICE)

    pooled = model.vision_model(**batch).pooler_output
    unit = pooled / pooled.norm(dim=-1, keepdim=True)

    # squeeze() drops size-1 dimensions (here, the batch of one image).
    return unit.squeeze().cpu().float().tolist()
|
|
|
|
| |
# Application object; the route decorators in this module register on it.
app = FastAPI(
    title="IntelliRoom HF API",
    version="1.0.0",
)

# CORS policy: any origin, any method, any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
|
|
@app.get("/")
def root():
    """Return a short description of the service.

    Returns:
        A dict with the service name, the torch device in use, the item
        count captured at start-up, and the list of endpoint paths.
    """
    service_info = dict(
        service="IntelliRoom HF API",
        device=DEVICE,
        items=N_ITEMS,
        endpoints=["/health", "/search", "/docs"],
    )
    return service_info
|
|
|
|
@app.get("/health")
def health():
    """Report liveness plus the device and start-up item count.

    Returns:
        A dict with ``status`` fixed to ``"running"``, the torch device in
        use, and the item count captured at start-up.
    """
    report = {"status": "running"}
    report["device"] = DEVICE
    report["items"] = N_ITEMS
    return report
|
|
|
|
def _format_matches(results) -> list[dict]:
    """Convert a single-query Chroma result into ranked match records.

    Args:
        results: Mapping returned by ``collection.query`` when asked to
            include ``"metadatas"`` and ``"distances"``. Only the first
            query's results are read, since one embedding is sent.

    Returns:
        One dict per hit, in the order returned, with a 1-based ``rank``.
    """
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]

    matches = []
    for rank, (meta, dist) in enumerate(zip(metadatas, distances), start=1):
        # Guard against a hit that carries no metadata (None).
        meta = meta or {}
        filename = meta.get("filename")
        matches.append(
            {
                "rank": rank,
                "filename": filename,
                "name": meta.get("name"),
                # NOTE(review): 1 - distance is a similarity only if the
                # collection uses cosine distance - confirm its config.
                "similarity": round(1 - dist, 3),
                "image_url": f"{IMAGE_BASE_URL}/{filename}" if filename else "",
            }
        )
    return matches


@app.post("/search")
async def search_endpoint(
    file: UploadFile = File(...),
    top_k: int = Query(DEFAULT_TOP_K, ge=1, le=50),
):
    """Find the stored items most similar to an uploaded image.

    Args:
        file: Uploaded image file.
        top_k: Number of neighbours to return (1-50).

    Returns:
        ``{"results": [...], "count": n}`` where each result has ``rank``,
        ``filename``, ``name``, ``similarity`` and ``image_url``.

    Raises:
        HTTPException: 400 when the upload is empty or cannot be decoded as
            an image; 500 when embedding or the vector query fails.
    """
    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")

    # A bad upload is the client's fault, so report it as 400 rather than
    # letting it fall into the generic 500 handler below.
    try:
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    # NOTE(review): embed() and collection.query() are synchronous calls, so
    # they run on the event loop for the duration of the request.
    try:
        vector = embed(image)
        results = collection.query(
            query_embeddings=[vector],
            n_results=top_k,
            include=["distances", "metadatas"],
        )
        matches = _format_matches(results)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return {"results": matches, "count": len(matches)}