# IntelliVDB / app.py
# Uploaded by MohamedSameh77i with huggingface_hub (commit d1ae96c).
import asyncio
import os
from io import BytesIO
from pathlib import Path
from urllib.parse import quote

import chromadb
import torch
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoModel, AutoProcessor
# ── Paths ─────────────────────────────────────────────────────────────────────
BASE_DIR = Path(__file__).parent
CHROMADB_DIR = BASE_DIR / "chromadb"
CSV_PATH = BASE_DIR / "furniture_dataset.csv"
# The Hub serves raw file bytes under /resolve/<revision>/; /tree/<revision>/ is
# the HTML repository browser, so URLs built on it do not return image data.
IMAGE_BASE_URL = "https://huggingface.co/datasets/MohamedSameh77i/Furniture_Synthetic_Dataset/resolve/main"
SIGLIP_MODEL_ID = "google/siglip2-so400m-patch16-naflex"
# Default for /search's `top_k` query parameter, overridable via SEARCH_TOP_K.
# FastAPI does not validate defaults against ge/le, so clamp to at least 1 here:
# a zero or negative value would otherwise reach the vector query unchecked.
DEFAULT_TOP_K = max(1, int(os.getenv("SEARCH_TOP_K", "5")))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ── Load models ───────────────────────────────────────────────────────────────
print(f"Loading SigLIP2 on {DEVICE}...")
processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_ID)
model = AutoModel.from_pretrained(SIGLIP_MODEL_ID, torch_dtype=torch.float32).to(DEVICE)
# eval() disables training-only behaviour (e.g. dropout) before inference.
model.eval()
# Open the on-disk Chroma store; get_collection raises if no "furniture"
# collection exists there, so a missing index fails at startup, not per request.
chroma_client = chromadb.PersistentClient(path=str(CHROMADB_DIR))
collection = chroma_client.get_collection("furniture")
# Item count captured once at startup and reported by / and /health.
N_ITEMS = collection.count()
print(f"Ready — {N_ITEMS} items.")
# ── Embed ─────────────────────────────────────────────────────────────────────
@torch.inference_mode()
def embed(pil_image: Image.Image) -> list[float]:
    """Return the unit-length SigLIP2 image embedding of *pil_image* as a list.

    Only the vision tower is run; its pooled output is divided by its own
    L2 norm before being moved to the CPU and converted to Python floats.
    """
    batch = processor(images=[pil_image], return_tensors="pt")
    batch = batch.to(DEVICE)
    pooled = model.vision_model(**batch).pooler_output
    norms = pooled.norm(dim=-1, keepdim=True)
    unit = pooled / norms
    # Drop the size-1 batch axis before converting to a plain list.
    return unit.squeeze().cpu().float().tolist()
# ── FastAPI Setup ─────────────────────────────────────────────────────────────
app = FastAPI(version="1.0.0", title="IntelliRoom HF API")
# Fully permissive CORS: browsers on any origin may call any route with any
# method and any request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.get("/")
def root():
    """Describe the service: its name, compute device, item count and routes."""
    listed_routes = ["/health", "/search", "/docs"]
    info = dict(service="IntelliRoom HF API", device=DEVICE, items=N_ITEMS)
    info["endpoints"] = listed_routes
    return info
@app.get("/health")
def health():
    """Report that the service is up, with its compute device and item count."""
    return dict(status="running", device=DEVICE, items=N_ITEMS)
def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: status 400 when the upload is empty or is not an image
            Pillow can read; that is a client error, not a server failure.
    """
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    try:
        # Image.open only parses the header; convert() forces the full decode so
        # truncated or corrupt files fail here rather than later in the model.
        return Image.open(BytesIO(image_bytes)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # PIL.UnidentifiedImageError is a subclass of OSError.
        raise HTTPException(status_code=400, detail=f"Invalid image: {exc}") from exc


def _format_matches(results: dict) -> list[dict]:
    """Convert a single-query Chroma ``query`` result into ranked match records."""
    matches = []
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]
    for rank, (meta, dist) in enumerate(zip(metadatas, distances), start=1):
        # Chroma may return None for items stored without metadata.
        meta = meta or {}
        filename = meta.get("filename")
        matches.append(
            {
                "rank": rank,
                "filename": filename,
                "name": meta.get("name"),
                # NOTE(review): 1 - distance is cosine similarity only when the
                # collection uses the "cosine" space (or "ip" with unit vectors);
                # under Chroma's "l2" space it is not. Confirm how it was built.
                "similarity": round(1 - dist, 3),
                # quote() percent-encodes spaces etc. but keeps "/" separators.
                "image_url": f"{IMAGE_BASE_URL}/{quote(filename)}" if filename else "",
            }
        )
    return matches


def _run_search(image: Image.Image, top_k: int) -> list[dict]:
    """Embed *image* and return its *top_k* nearest items (blocking call)."""
    vector = embed(image)
    results = collection.query(
        query_embeddings=[vector],
        n_results=top_k,
        include=["distances", "metadatas"],
    )
    return _format_matches(results)


@app.post("/search")
async def search_endpoint(file: UploadFile = File(...), top_k: int = Query(DEFAULT_TOP_K, ge=1, le=50)):
    """Find the indexed items most similar to an uploaded image.

    Args:
        file: Uploaded image file (any format Pillow can decode).
        top_k: Number of matches requested, between 1 and 50.

    Returns:
        ``{"results": [...], "count": n}``, with ``rank`` starting at 1 in the
        order the vector query returned the neighbours.

    Raises:
        HTTPException: 400 for an empty or undecodable upload; 500 if embedding
            or the vector query fails.
    """
    image_bytes = await file.read()
    # Decoding, model inference and the Chroma query are synchronous and
    # compute-bound. Running them in a worker thread keeps the event loop free
    # to serve other requests (e.g. /health) in the meantime.
    image = await asyncio.to_thread(_decode_image, image_bytes)
    try:
        matches = await asyncio.to_thread(_run_search, image, top_k)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"results": matches, "count": len(matches)}