File size: 2,606 Bytes
3eaabcf
fac2e05
 
3eaabcf
fac2e05
 
3eaabcf
fac2e05
 
 
 
 
 
 
3eaabcf
 
fac2e05
dae3a4f
3eaabcf
 
dae3a4f
3eaabcf
 
fac2e05
3eaabcf
 
fac2e05
 
3eaabcf
 
fac2e05
 
3eaabcf
fac2e05
 
 
3eaabcf
 
fac2e05
 
 
3eaabcf
 
 
dae3a4f
fac2e05
 
 
 
 
 
dae3a4f
 
 
 
 
 
fac2e05
dae3a4f
 
 
 
 
fac2e05
dae3a4f
fac2e05
dae3a4f
 
 
 
 
 
 
 
 
 
fac2e05
 
dae3a4f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os

# Fix caching permissions for Hugging Face: point every HF cache at a local
# directory. huggingface_hub / transformers / sentence_transformers read these
# variables when they are first imported, so they MUST be set before the
# imports below — assigning them afterwards leaves those libraries using their
# default cache paths.
CACHE_DIR = "./cache"
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["SENTENCE_TRANSFORMERS_HOME"] = CACHE_DIR

# These imports intentionally follow the environment setup above.
import json  # noqa: E402
import faiss  # noqa: E402
import numpy as np  # noqa: E402
from fastapi import FastAPI, UploadFile, File, Form  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
from PIL import Image  # noqa: E402
import io  # noqa: E402

app = FastAPI()

# Enable CORS so the frontend (on Netlify) can call this backend (on HF).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # TODO: restrict to the Netlify domain; "*" allows any site
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Product metadata. The search endpoints index it with the integer ids that
# FAISS returns, so entry order must match the order vectors were added to the
# index (NOTE(review): presumably a JSON list — confirm).
with open("id_mapping.json", "r", encoding="utf-8") as f:
    products = json.load(f)

# Prebuilt FAISS index over product embeddings (presumably produced with the
# same CLIP model loaded below — verify if the model is ever changed).
index = faiss.read_index("products.index")

# CLIP model used to embed both text queries and images into one vector space.
print("🧠 Loading CLIP model...")
model = SentenceTransformer("sentence-transformers/clip-ViT-B-32", cache_folder=CACHE_DIR)


@app.get("/")
def root():
    """Return a static message confirming the API is running."""
    # The previous literal held "πŸš€", a mis-decoded UTF-8 rocket emoji.
    return {"message": "🚀 Visual Product Matcher API is running!"}


@app.post("/search_text")
def search_text(query: str = Form(...), top_k: int = 5, min_score: float = 0.0):
    """
    Search products using a text query.

    The query is embedded with the CLIP model and looked up in the FAISS
    index. Each hit is returned as a *copy* of its product record with an
    added ``"score"`` field.

    Args:
        query: Free-text search query (form field).
        top_k: Maximum number of hits to request from the index. Values
            <= 0 return no matches.
        min_score: Hits whose raw FAISS score is below this value are
            dropped. NOTE(review): this assumes larger scores are better
            (e.g. an inner-product index); for an L2 index the value is a
            distance and the filter direction would be inverted — confirm
            against how products.index was built.

    Returns:
        ``{"matches": [...]}`` with at most ``top_k`` product dicts.
    """
    if top_k <= 0:
        # faiss rejects k <= 0, and there would be nothing to return anyway.
        return {"matches": []}

    query_emb = model.encode([query], convert_to_numpy=True)
    distances, indices = index.search(query_emb, top_k)

    results = []
    for score, idx in zip(distances[0], indices[0]):
        # FAISS pads with id -1 when the index holds fewer than top_k vectors;
        # products[-1] would otherwise silently return the last product.
        if idx >= 0 and score >= min_score:
            # Build a new dict rather than writing "score" into the shared
            # module-level record, which would leak across requests.
            results.append({**products[idx], "score": float(score)})

    return {"matches": results}


@app.post("/match")  # 👈 Renamed to match frontend
async def search_image(file: UploadFile = File(None), image_url: str = Form(None), top_k: int = 5, min_score: float = 0.0):
    """
    Search products using an image query (upload or URL).

    An uploaded ``file`` takes precedence over ``image_url``. The image is
    converted to RGB, embedded with the CLIP model and looked up in the FAISS
    index. Each hit is returned as a *copy* of its product record with an
    added ``"score"`` field.

    Args:
        file: Uploaded image (multipart form field), optional.
        image_url: URL of an image to download, optional.
        top_k: Maximum number of hits to request from the index. Values
            <= 0 return no matches.
        min_score: Hits whose raw FAISS score is below this value are
            dropped. NOTE(review): assumes larger scores are better (e.g. an
            inner-product index) — confirm against how products.index was built.

    Returns:
        ``{"matches": [...]}`` on success, or ``{"error": ...}`` when no image
        was provided or the image could not be downloaded/decoded.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        if file:
            image_bytes = await file.read()
        elif image_url:
            import requests

            # NOTE(security): this downloads an arbitrary caller-supplied URL
            # from the server (SSRF risk); consider restricting schemes/hosts.
            # The blocking download runs in a worker thread so it does not
            # stall the event loop, and the timeout (seconds) keeps an
            # unresponsive host from hanging the request forever.
            response = await run_in_threadpool(requests.get, image_url, timeout=10)
            response.raise_for_status()
            image_bytes = response.content
        else:
            return {"error": "No image provided"}
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as exc:
        # requests.RequestException derives from IOError (== OSError) and
        # PIL.UnidentifiedImageError is an OSError subclass, so this covers
        # both network/HTTP failures and undecodable image data.
        return {"error": f"Could not load image: {exc}"}

    if top_k <= 0:
        # faiss rejects k <= 0, and there would be nothing to return anyway.
        return {"matches": []}

    # CLIP inference is compute-bound; run it off the event loop, as FastAPI
    # already does automatically for the synchronous /search_text handler.
    image_emb = await run_in_threadpool(model.encode, [image], convert_to_numpy=True)
    distances, indices = index.search(image_emb, top_k)

    results = []
    for score, idx in zip(distances[0], indices[0]):
        # FAISS pads with id -1 when the index holds fewer than top_k vectors;
        # products[-1] would otherwise silently return the last product.
        if idx >= 0 and score >= min_score:
            # Build a new dict rather than writing "score" into the shared
            # module-level record, which would leak across requests.
            results.append({**products[idx], "score": float(score)})

    return {"matches": results}