File size: 4,393 Bytes
3eaabcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from fastapi import FastAPI, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
import requests
import io
import faiss
import json
import os
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer

# Init FastAPI
app = FastAPI()

# CORS: accept every origin, method and header so the frontend can call the
# API from any host.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True
# deserves a second look -- browsers refuse a literal "*" on credentialed
# responses, and the CORS middleware may instead reflect the caller's Origin.
# Restrict allow_origins (e.g. to the Vercel URL) once it is known.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Load the CLIP model a single time at import; both embed_image and
# embed_text reuse this shared instance.
print("🧠 Loading CLIP model...")
model = SentenceTransformer(model_name_or_path="clip-ViT-B-32")

# Load dataset
PRODUCTS_FILE = "products.json"
INDEX_FILE = "products.index"

# NOTE(review): errors="ignore" silently drops bytes that are not valid UTF-8,
# which can corrupt product text without any warning.
with open(PRODUCTS_FILE, "r", encoding="utf-8", errors="ignore") as f:
    products = json.load(f)

# Build or load FAISS index.
# Search hits are mapped back with products[idx], so vector i must describe
# products[i]. A cached index is therefore only reused when it holds exactly
# one vector per product; otherwise it is rebuilt.
# NOTE: only the vector count is compared -- edits to products.json that keep
# the count unchanged still require deleting INDEX_FILE by hand.
index = None
if os.path.exists(INDEX_FILE):
    print("📦 Loading existing FAISS index...")
    index = faiss.read_index(INDEX_FILE)
    if index.ntotal != len(products):
        print(
            f"⚠️ Cached FAISS index has {index.ntotal} vectors but "
            f"{PRODUCTS_FILE} has {len(products)} products; rebuilding..."
        )
        index = None

if index is None:
    print("⚡ Building FAISS index from products.json (first startup only)...")
    # Encode product names (lightweight, avoids downloading images).
    # Missing or null fields become empty strings instead of aborting startup.
    texts = [
        " ".join(str(p.get(field) or "") for field in ("name", "category", "brand"))
        for p in products
    ]
    embeddings = model.encode(
        texts, convert_to_numpy=True, normalize_embeddings=True)
    # Inner product on L2-normalised embeddings equals cosine similarity.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings.astype("float32"))
    faiss.write_index(index, INDEX_FILE)
    print(f"✅ Saved FAISS index with {index.ntotal} vectors")


def embed_image(img: Image.Image) -> np.ndarray:
    """Return the L2-normalised CLIP embedding of *img* as a NumPy array."""
    vector = model.encode(
        img,
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    return vector


def embed_text(query: str) -> np.ndarray:
    """Return the L2-normalised CLIP embedding of *query* as a NumPy array."""
    batch = model.encode(
        [query],
        normalize_embeddings=True,
        convert_to_numpy=True,
    )
    # encode() received a one-element batch; unwrap its only row.
    first_row = batch[0]
    return first_row


@app.post("/match")
async def match(
    file: UploadFile = None,
    image_url: str = Form(None),
    min_score: float = Form(0.6),
    top_k: int = Form(60),
    categories: str = Form(None),
    brands: str = Form(None),
    min_price: float = Form(0),
    max_price: float = Form(9999)
):
    """Find catalogue products similar to a query image.

    The query image is taken from the uploaded ``file`` when present,
    otherwise downloaded from ``image_url``; with neither, no matches are
    returned. The image is embedded with CLIP, searched against the FAISS
    product index, and hits are filtered by score, category, brand and price.

    Args:
        file: Uploaded query image; takes precedence over ``image_url``.
        image_url: URL of the query image, used when no file is uploaded.
        min_score: Minimum similarity score for a hit to be kept.
        top_k: Number of nearest neighbours fetched from FAISS *before*
            filtering, so fewer results may be returned.
        categories: Optional JSON array (or single JSON string) of allowed
            categories; absent/empty means no category filter.
        brands: Optional JSON array (or single JSON string) of allowed
            brands; absent/empty means no brand filter.
        min_price: Inclusive lower price bound.
        max_price: Inclusive upper price bound.

    Returns:
        ``{"matches": [...]}`` where each item is the product dict plus a
        ``"score"`` key, in the order FAISS returned them; or
        ``{"error": message}`` if any step fails.
    """
    try:
        # Get query image
        if file:
            img = Image.open(io.BytesIO(await file.read())).convert("RGB")
        elif image_url:
            # Bounded timeout so an unresponsive host cannot hang the request;
            # raise_for_status reports HTTP failures directly instead of
            # trying to decode an error page as an image.
            # NOTE(review): requests blocks inside this async handler and
            # fetches any client-supplied URL (SSRF risk) -- consider an async
            # client and a host allow-list.
            resp = requests.get(image_url, timeout=10)
            resp.raise_for_status()
            img = Image.open(io.BytesIO(resp.content)).convert("RGB")
        else:
            return {"matches": []}

        # Parse filters before the expensive encode/search. A bare JSON
        # string (e.g. '"Shoes"') is wrapped in a list so the membership
        # tests below are not substring checks.
        categories = json.loads(categories) if categories else []
        brands = json.loads(brands) if brands else []
        if isinstance(categories, str):
            categories = [categories]
        if isinstance(brands, str):
            brands = [brands]

        # FAISS rejects k <= 0 and pads with id -1 when k exceeds the number
        # of indexed vectors, so clamp k to what the index can return.
        k = min(top_k, index.ntotal)
        if k <= 0:
            return {"matches": []}

        # Encode query as a 2-D float32 batch of one row
        q_emb = embed_image(img).reshape(1, -1).astype("float32")

        # Search FAISS
        scores, ids = index.search(q_emb, k)

        # Collect results
        results = []
        for score, idx in zip(scores[0], ids[0]):
            # Skip padding ids and ids with no corresponding product row.
            if not 0 <= idx < len(products):
                continue
            if score < min_score:
                continue
            p = products[idx]

            # Apply filters
            if categories and p["category"] not in categories:
                continue
            if brands and p["brand"] not in brands:
                continue
            if not (min_price <= p["price"] <= max_price):
                continue

            results.append({**p, "score": float(score)})
        return {"matches": results}
    except Exception as e:
        return {"error": str(e)}


@app.post("/search_text")
async def search_text(
    query: str = Form(...),
    min_score: float = Form(0.6),
    top_k: int = Form(60),
    categories: str = Form(None),
    brands: str = Form(None),
    min_price: float = Form(0),
    max_price: float = Form(9999)
):
    """Find catalogue products matching a free-text query.

    The query is embedded with CLIP, searched against the FAISS product
    index, and hits are filtered by score, category, brand and price.

    Args:
        query: Free-text search query (required form field).
        min_score: Minimum similarity score for a hit to be kept.
        top_k: Number of nearest neighbours fetched from FAISS *before*
            filtering, so fewer results may be returned.
        categories: Optional JSON array (or single JSON string) of allowed
            categories; absent/empty means no category filter.
        brands: Optional JSON array (or single JSON string) of allowed
            brands; absent/empty means no brand filter.
        min_price: Inclusive lower price bound.
        max_price: Inclusive upper price bound.

    Returns:
        ``{"matches": [...]}`` where each item is the product dict plus a
        ``"score"`` key, in the order FAISS returned them; or
        ``{"error": message}`` if any step fails.
    """
    try:
        # Parse filters before the expensive encode/search. A bare JSON
        # string (e.g. '"Shoes"') is wrapped in a list so the membership
        # tests below are not substring checks.
        categories = json.loads(categories) if categories else []
        brands = json.loads(brands) if brands else []
        if isinstance(categories, str):
            categories = [categories]
        if isinstance(brands, str):
            brands = [brands]

        # FAISS rejects k <= 0 and pads with id -1 when k exceeds the number
        # of indexed vectors, so clamp k to what the index can return.
        k = min(top_k, index.ntotal)
        if k <= 0:
            return {"matches": []}

        # Encode text query as a 2-D float32 batch of one row
        q_emb = embed_text(query).reshape(1, -1).astype("float32")

        # Search FAISS
        scores, ids = index.search(q_emb, k)

        # Collect results
        results = []
        for score, idx in zip(scores[0], ids[0]):
            # Skip padding ids and ids with no corresponding product row.
            if not 0 <= idx < len(products):
                continue
            if score < min_score:
                continue
            p = products[idx]

            # Apply filters
            if categories and p["category"] not in categories:
                continue
            if brands and p["brand"] not in brands:
                continue
            if not (min_price <= p["price"] <= max_price):
                continue

            results.append({**p, "score": float(score)})
        return {"matches": results}
    except Exception as e:
        return {"error": str(e)}