File size: 4,393 Bytes
3eaabcf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from fastapi import FastAPI, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
import requests
import io
import faiss
import json
import os
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
# Init FastAPI
app = FastAPI()

# CORS configuration: every origin, method and header is accepted and
# credentials are allowed. The origin list can be restricted to the deployed
# frontend (e.g. the Vercel URL) later.
# NOTE(review): the FastAPI CORS documentation advises against combining
# wildcard origins with allow_credentials=True — confirm whether credentials
# are actually needed by the frontend.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Load CLIP model once, at import time; the module-level `model` is what the
# embedding code in this file calls.
CLIP_MODEL_NAME = "clip-ViT-B-32"
print("🧠 Loading CLIP model...")
model = SentenceTransformer(CLIP_MODEL_NAME)
# Load dataset
PRODUCTS_FILE = "products.json"
INDEX_FILE = "products.index"

# errors="ignore" silently drops bytes that are not valid UTF-8 instead of
# failing at startup.
with open(PRODUCTS_FILE, "r", encoding="utf-8", errors="ignore") as f:
    raw_products = f.read()
products = json.loads(raw_products)
# Build or load FAISS index
index = None
if os.path.exists(INDEX_FILE):
    print("📦 Loading existing FAISS index...")
    index = faiss.read_index(INDEX_FILE)
    # Search results are row numbers that are used directly as positions in
    # `products`. A cached index whose vector count differs from the current
    # products.json would therefore return the wrong product or an
    # out-of-range position, so it is discarded and rebuilt.
    # NOTE: this only detects a change in the number of products, not edits
    # to existing entries — delete products.index after such edits.
    if index.ntotal != len(products):
        print(
            f"⚠️ Cached FAISS index has {index.ntotal} vectors but "
            f"products.json has {len(products)} products; rebuilding...")
        index = None

if index is None:
    print("⚡ Building FAISS index from products.json (first startup only)...")
    # Encode product names (lightweight, avoids downloading images)
    texts = [" ".join((p["name"], p["category"], p["brand"]))
             for p in products]
    embeddings = model.encode(
        texts, convert_to_numpy=True, normalize_embeddings=True)
    # Inner product over the normalized embeddings requested above.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings.astype("float32"))
    faiss.write_index(index, INDEX_FILE)
    print(f"✅ Saved FAISS index with {index.ntotal} vectors")
def embed_image(img: Image.Image):
    """Encode a PIL image with the module-level CLIP model.

    The model is asked for a NumPy result with normalized embeddings; the
    value returned by ``model.encode`` is passed back unchanged.
    """
    encode_options = {"convert_to_numpy": True, "normalize_embeddings": True}
    return model.encode(img, **encode_options)
def embed_text(query: str):
    """Encode a text query with the module-level CLIP model.

    The query is encoded as a one-element batch (NumPy output, normalized
    embeddings requested) and the first row of that batch is returned.
    """
    batch = model.encode(
        [query], convert_to_numpy=True, normalize_embeddings=True)
    return batch[0]
# Upper bound, in seconds, on how long a remote image download may take.
_IMAGE_FETCH_TIMEOUT_SECONDS = 10


@app.post("/match")
async def match(
    file: UploadFile = None,
    image_url: str = Form(None),
    min_score: float = Form(0.6),
    top_k: int = Form(60),
    categories: str = Form(None),
    brands: str = Form(None),
    min_price: float = Form(0),
    max_price: float = Form(9999)
):
    """Find products similar to a query image.

    The query image is taken from the uploaded ``file`` when one is given,
    otherwise it is downloaded from ``image_url``. The image is embedded,
    searched against the FAISS index, and the hits are filtered.

    Args:
        file: Uploaded image; takes precedence over ``image_url``.
        image_url: URL to download the image from when no file is uploaded.
        min_score: Hits scoring below this value are dropped.
            NOTE(review): verify that the 0.6 default is an appropriate
            threshold for image queries against this index.
        top_k: Number of nearest neighbours requested from the index.
        categories: JSON-encoded collection of accepted ``category`` values;
            empty/omitted means no category filter.
        brands: JSON-encoded collection of accepted ``brand`` values;
            empty/omitted means no brand filter.
        min_price: Inclusive lower bound on the product ``price``.
        max_price: Inclusive upper bound on the product ``price``.

    Returns:
        ``{"matches": [...]}`` where each entry is the product dict plus a
        ``"score"`` key, or ``{"error": message}`` if anything fails.
    """
    try:
        # A non-positive neighbour count cannot yield matches; answer
        # directly rather than letting the index search fail.
        if top_k <= 0:
            return {"matches": []}

        # Get query image
        if file:
            image_bytes = await file.read()
        elif image_url:
            # NOTE(security): image_url is caller-supplied and fetched from
            # the server, so it can be pointed at internal hosts (SSRF).
            # Consider an allow-list of hosts before exposing this publicly.
            # NOTE(review): requests.get is a blocking call inside an async
            # handler; the timeout bounds how long it can stall the server.
            response = requests.get(
                image_url, timeout=_IMAGE_FETCH_TIMEOUT_SECONDS)
            # Fail with the HTTP error instead of handing an error page to
            # the image decoder.
            response.raise_for_status()
            image_bytes = response.content
        else:
            return {"matches": []}
        img = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        # Encode query
        q_emb = embed_image(img).reshape(1, -1)

        # Search FAISS
        scores, ids = index.search(q_emb, top_k)

        # Parse filters
        categories = json.loads(categories) if categories else []
        brands = json.loads(brands) if brands else []

        # Collect results
        results = []
        for score, idx in zip(scores[0], ids[0]):
            # FAISS fills unused result slots with id -1 when the index holds
            # fewer than top_k vectors; -1 would otherwise silently select
            # the last product in the list.
            if idx < 0:
                continue
            if score < min_score:
                continue
            p = products[idx]
            # Apply filters
            if categories and p["category"] not in categories:
                continue
            if brands and p["brand"] not in brands:
                continue
            if not (min_price <= p["price"] <= max_price):
                continue
            results.append({**p, "score": float(score)})
        return {"matches": results}
    except Exception as e:
        # Errors are reported in the response body; this shape is kept so
        # existing clients continue to work.
        return {"error": str(e)}
@app.post("/search_text")
async def search_text(
    query: str = Form(...),
    min_score: float = Form(0.6),
    top_k: int = Form(60),
    categories: str = Form(None),
    brands: str = Form(None),
    min_price: float = Form(0),
    max_price: float = Form(9999)
):
    """Find products matching a free-text query.

    The query is embedded, searched against the FAISS index, and the hits
    are filtered.

    Args:
        query: Text to search for (required form field).
        min_score: Hits scoring below this value are dropped.
        top_k: Number of nearest neighbours requested from the index.
        categories: JSON-encoded collection of accepted ``category`` values;
            empty/omitted means no category filter.
        brands: JSON-encoded collection of accepted ``brand`` values;
            empty/omitted means no brand filter.
        min_price: Inclusive lower bound on the product ``price``.
        max_price: Inclusive upper bound on the product ``price``.

    Returns:
        ``{"matches": [...]}`` where each entry is the product dict plus a
        ``"score"`` key, or ``{"error": message}`` if anything fails.
    """
    try:
        # A non-positive neighbour count cannot yield matches; answer
        # directly rather than letting the index search fail.
        if top_k <= 0:
            return {"matches": []}

        # Encode text query
        q_emb = embed_text(query).reshape(1, -1)

        # Search FAISS
        scores, ids = index.search(q_emb, top_k)

        # Parse filters
        categories = json.loads(categories) if categories else []
        brands = json.loads(brands) if brands else []

        # Collect results
        results = []
        for score, idx in zip(scores[0], ids[0]):
            # FAISS fills unused result slots with id -1 when the index holds
            # fewer than top_k vectors; -1 would otherwise silently select
            # the last product in the list.
            if idx < 0:
                continue
            if score < min_score:
                continue
            p = products[idx]
            # Apply filters
            if categories and p["category"] not in categories:
                continue
            if brands and p["brand"] not in brands:
                continue
            if not (min_price <= p["price"] <= max_price):
                continue
            results.append({**p, "score": float(score)})
        return {"matches": results}
    except Exception as e:
        # Errors are reported in the response body; this shape is kept so
        # existing clients continue to work.
        return {"error": str(e)}
|