Yashashvibhardwaj's picture
Deploy backend code
3eaabcf
raw
history blame
4.39 kB
from fastapi import FastAPI, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
import requests
import io
import faiss
import json
import os
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
# Init FastAPI
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # you can restrict to your Vercel URL later
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
)
# Load CLIP model once
print("🧠 Loading CLIP model...")
model = SentenceTransformer("clip-ViT-B-32")
# Load dataset
PRODUCTS_FILE = "products.json"
INDEX_FILE = "products.index"
with open(PRODUCTS_FILE, "r", encoding="utf-8", errors="ignore") as f:
products = json.load(f)
# Build or load FAISS index
if os.path.exists(INDEX_FILE):
print("πŸ“¦ Loading existing FAISS index...")
index = faiss.read_index(INDEX_FILE)
else:
print("⚑ Building FAISS index from products.json (first startup only)...")
# Encode product names (lightweight, avoids downloading images)
texts = [p["name"] + " " + p["category"] + " " + p["brand"]
for p in products]
embeddings = model.encode(
texts, convert_to_numpy=True, normalize_embeddings=True)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings.astype("float32"))
faiss.write_index(index, INDEX_FILE)
print(f"βœ… Saved FAISS index with {index.ntotal} vectors")
def embed_image(img: Image.Image):
return model.encode(img, convert_to_numpy=True, normalize_embeddings=True)
def embed_text(query: str):
return model.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0]
@app.post("/match")
async def match(
file: UploadFile = None,
image_url: str = Form(None),
min_score: float = Form(0.6),
top_k: int = Form(60),
categories: str = Form(None),
brands: str = Form(None),
min_price: float = Form(0),
max_price: float = Form(9999)
):
try:
# Get query image
if file:
img = Image.open(io.BytesIO(await file.read())).convert("RGB")
elif image_url:
img = Image.open(io.BytesIO(requests.get(
image_url).content)).convert("RGB")
else:
return {"matches": []}
# Encode query
q_emb = embed_image(img).reshape(1, -1)
# Search FAISS
scores, ids = index.search(q_emb, top_k)
# Parse filters
categories = json.loads(categories) if categories else []
brands = json.loads(brands) if brands else []
# Collect results
results = []
for score, idx in zip(scores[0], ids[0]):
if score < min_score:
continue
p = products[idx]
# Apply filters
if categories and p["category"] not in categories:
continue
if brands and p["brand"] not in brands:
continue
if not (min_price <= p["price"] <= max_price):
continue
results.append({**p, "score": float(score)})
return {"matches": results}
except Exception as e:
return {"error": str(e)}
@app.post("/search_text")
async def search_text(
query: str = Form(...),
min_score: float = Form(0.6),
top_k: int = Form(60),
categories: str = Form(None),
brands: str = Form(None),
min_price: float = Form(0),
max_price: float = Form(9999)
):
try:
# Encode text query
q_emb = embed_text(query).reshape(1, -1)
# Search FAISS
scores, ids = index.search(q_emb, top_k)
# Parse filters
categories = json.loads(categories) if categories else []
brands = json.loads(brands) if brands else []
# Collect results
results = []
for score, idx in zip(scores[0], ids[0]):
if score < min_score:
continue
p = products[idx]
# Apply filters
if categories and p["category"] not in categories:
continue
if brands and p["brand"] not in brands:
continue
if not (min_price <= p["price"] <= max_price):
continue
results.append({**p, "score": float(score)})
return {"matches": results}
except Exception as e:
return {"error": str(e)}