|
|
from fastapi import FastAPI, UploadFile, Form |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
import requests |
|
|
import io |
|
|
import faiss |
|
|
import json |
|
|
import os |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from sentence_transformers import SentenceTransformer |
|
|
|
|
|
|
|
|
app = FastAPI()

# CORS policy: the API is open to every origin, method and header.
# Credentials (cookies / Authorization headers) are deliberately NOT allowed:
# browsers reject a wildcard origin on credentialed requests, and FastAPI's
# CORS documentation requires explicit origins whenever allow_credentials=True.
# To support credentialed clients, replace "*" with the explicit list of
# trusted origins and only then turn allow_credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
# Load the encoder once at import time so every request (and the index build
# below) reuses the same instance. "clip-ViT-B-32" is the checkpoint name
# handed to SentenceTransformer; the same model encodes both images
# (embed_image) and text (embed_text, product descriptions).
print("🧠 Loading CLIP model...")
model = SentenceTransformer("clip-ViT-B-32")
|
|
|
|
|
|
|
|
# Paths (relative to the working directory) of the product catalogue and of
# the FAISS index cached from it.
PRODUCTS_FILE = "products.json"
INDEX_FILE = "products.index"

# Read the whole catalogue into memory once at import time.
# NOTE: errors="ignore" drops any bytes that are not valid UTF-8 instead of
# raising, so a badly encoded catalogue loads with characters missing.
with open(PRODUCTS_FILE, mode="r", encoding="utf-8", errors="ignore") as catalog_file:
    raw_catalog = catalog_file.read()
products = json.loads(raw_catalog)
|
|
|
|
|
|
|
|
def _build_index():
    """Embed every product and return a fresh FAISS inner-product index.

    Each product is represented by the text "<name> <category> <brand>".
    Embeddings are L2-normalised, so the inner product computed by
    ``IndexFlatIP`` equals cosine similarity. Vector ``i`` of the returned
    index corresponds to ``products[i]``.

    Raises:
        KeyError: If a product lacks a "name", "category" or "brand" key.
    """
    texts = [
        " ".join((p["name"], p["category"], p["brand"]))
        for p in products
    ]
    embeddings = model.encode(
        texts, convert_to_numpy=True, normalize_embeddings=True)
    built = faiss.IndexFlatIP(embeddings.shape[1])
    built.add(embeddings.astype("float32"))
    return built


# Reuse the cached index when it is present AND still matches the catalogue.
# Search results are mapped back with products[idx], so an index whose vector
# count differs from len(products) would return the wrong products (or raise
# IndexError); in that case it is discarded and rebuilt. This size check
# cannot detect a catalogue that was edited without changing its length --
# delete INDEX_FILE manually after such edits.
index = None
if os.path.exists(INDEX_FILE):
    print("📦 Loading existing FAISS index...")
    index = faiss.read_index(INDEX_FILE)
    if index.ntotal != len(products):
        print(
            f"FAISS index holds {index.ntotal} vectors but products.json has "
            f"{len(products)} products; rebuilding the index...")
        index = None
else:
    print("⚡ Building FAISS index from products.json (first startup only)...")

if index is None:
    index = _build_index()
    # Persist the index so later startups can skip the embedding pass.
    faiss.write_index(index, INDEX_FILE)
    print(f"✅ Saved FAISS index with {index.ntotal} vectors")
|
|
|
|
|
|
|
|
def embed_image(img: Image.Image):
    """Return the L2-normalised embedding of a PIL image as a NumPy array.

    The image is encoded with the module-level ``model``.
    """
    encode_options = {"convert_to_numpy": True, "normalize_embeddings": True}
    return model.encode(img, **encode_options)
|
|
|
|
|
|
|
|
def embed_text(query: str):
    """Return the L2-normalised embedding of a text query as a NumPy array.

    The query is encoded as a one-element batch with the module-level
    ``model``; the single resulting row is returned.
    """
    batch = model.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return batch[0]
|
|
|
|
|
|
|
|
# Seconds to wait for a remote image. Without a timeout a stalled host would
# hold the request open indefinitely.
_IMAGE_FETCH_TIMEOUT = 10


def _parse_filter(raw, field):
    """Decode an optional JSON form field into a list of allowed values.

    Args:
        raw: JSON text from the form (e.g. '["Shoes", "Bags"]'), or None/"".
        field: Name of the form field; used only in the error message.

    Returns:
        The list of allowed values. An empty list means "do not filter".

    Raises:
        ValueError: If ``raw`` is not valid JSON, or decodes to something
            other than an array or a single string.
    """
    if not raw:
        return []
    values = json.loads(raw)
    if isinstance(values, str):
        # A bare JSON string would turn the later `in` test into a substring
        # match; treat it as a one-element list instead.
        return [values]
    if not isinstance(values, list):
        raise ValueError(f"{field} must be a JSON array")
    return values


def _search_products(q_emb, min_score, top_k, categories, brands,
                     min_price, max_price):
    """Search the FAISS index and return the products that pass every filter.

    Args:
        q_emb: Query embedding (any shape that flattens to one vector).
        min_score: Hits scoring below this value are dropped.
        top_k: Maximum number of neighbours requested from the index.
        categories: JSON array (as text) of allowed categories, or None.
        brands: JSON array (as text) of allowed brands, or None.
        min_price: Inclusive lower price bound.
        max_price: Inclusive upper price bound.

    Returns:
        A list of product dicts, each extended with a float "score" key, in
        the order returned by the index.
    """
    allowed_categories = _parse_filter(categories, "categories")
    allowed_brands = _parse_filter(brands, "brands")

    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with id -1, and products[-1] would silently be the LAST product.
    k = min(top_k, index.ntotal)
    if k < 1:
        return []

    query = np.asarray(q_emb, dtype="float32").reshape(1, -1)
    scores, ids = index.search(query, k)

    matches = []
    for score, idx in zip(scores[0], ids[0]):
        if idx < 0 or score < min_score:
            continue
        product = products[idx]
        if allowed_categories and product["category"] not in allowed_categories:
            continue
        if allowed_brands and product["brand"] not in allowed_brands:
            continue
        if not (min_price <= product["price"] <= max_price):
            continue
        matches.append({**product, "score": float(score)})
    return matches


@app.post("/match")
async def match(
    file: UploadFile = None,
    image_url: str = Form(None),
    min_score: float = Form(0.6),
    top_k: int = Form(60),
    categories: str = Form(None),
    brands: str = Form(None),
    min_price: float = Form(0),
    max_price: float = Form(9999),
):
    """Find products visually similar to an uploaded or remotely hosted image.

    The image comes from ``file`` when it has content, otherwise from
    ``image_url``. With neither, the response is ``{"matches": []}``.

    Returns:
        ``{"matches": [...]}`` on success. Any failure is reported as
        ``{"error": "<message>"}`` instead of being raised.
    """
    try:
        # Browsers submit an empty file part when no file is chosen; treat an
        # empty upload as "no file" so image_url can still be used.
        image_bytes = (await file.read()) if file is not None else b""

        if not image_bytes and image_url:
            # NOTE(security): image_url is caller-supplied and fetched from
            # the server (SSRF risk) -- restrict allowed hosts if this API is
            # exposed to untrusted clients.
            # NOTE: requests.get is synchronous and blocks this coroutine
            # while the download runs.
            response = requests.get(image_url, timeout=_IMAGE_FETCH_TIMEOUT)
            # Surface HTTP errors instead of feeding an error page to PIL.
            response.raise_for_status()
            image_bytes = response.content

        if not image_bytes:
            return {"matches": []}

        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return {
            "matches": _search_products(
                embed_image(img), min_score, top_k,
                categories, brands, min_price, max_price,
            )
        }
    except Exception as e:
        # Top-level boundary: keep the existing {"error": ...} contract.
        return {"error": str(e)}


@app.post("/search_text")
async def search_text(
    query: str = Form(...),
    min_score: float = Form(0.6),
    top_k: int = Form(60),
    categories: str = Form(None),
    brands: str = Form(None),
    min_price: float = Form(0),
    max_price: float = Form(9999),
):
    """Find products whose embedding is similar to a free-text query.

    Returns:
        ``{"matches": [...]}`` on success. Any failure is reported as
        ``{"error": "<message>"}`` instead of being raised.
    """
    try:
        return {
            "matches": _search_products(
                embed_text(query), min_score, top_k,
                categories, brands, min_price, max_price,
            )
        }
    except Exception as e:
        # Top-level boundary: keep the existing {"error": ...} contract.
        return {"error": str(e)}
|
|
|