Diwank Singh commited on
Commit ·
26d97be
1
Parent(s): 6883665
fixes
Browse files- Dockerfile +8 -5
- backend/ai.py +81 -31
- backend/data/metadata.csv +0 -0
- backend/generate_metadata.py +69 -51
- backend/main.py +103 -49
- backend/store.py +92 -51
Dockerfile
CHANGED
|
@@ -1,17 +1,20 @@
|
|
| 1 |
FROM python:3.10
|
|
|
|
| 2 |
WORKDIR /app
|
|
|
|
| 3 |
COPY . /app
|
|
|
|
| 4 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 5 |
|
|
|
|
| 6 |
RUN python -c "from transformers import CLIPModel, CLIPProcessor; \
|
| 7 |
-
|
| 8 |
-
|
| 9 |
|
| 10 |
-
# Lock HuggingFace to use only the cached model at runtime
|
| 11 |
ENV TRANSFORMERS_OFFLINE=1
|
| 12 |
ENV HF_DATASETS_OFFLINE=1
|
|
|
|
| 13 |
|
| 14 |
EXPOSE 7860
|
| 15 |
-
CMD ["gunicorn", "-k", "uvicorn.workers.UvicornWorker", "backend.main:app", "-b", "0.0.0.0:7860"]
|
| 16 |
-
|
| 17 |
|
|
|
|
|
|
| 1 |
FROM python:3.10
|
| 2 |
+
|
| 3 |
WORKDIR /app
|
| 4 |
+
|
| 5 |
COPY . /app
|
| 6 |
+
|
| 7 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 8 |
|
| 9 |
+
# Pre-download model
|
| 10 |
RUN python -c "from transformers import CLIPModel, CLIPProcessor; \
|
| 11 |
+
CLIPModel.from_pretrained('openai/clip-vit-base-patch32'); \
|
| 12 |
+
CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')"
|
| 13 |
|
|
|
|
| 14 |
ENV TRANSFORMERS_OFFLINE=1
|
| 15 |
ENV HF_DATASETS_OFFLINE=1
|
| 16 |
+
ENV TOKENIZERS_PARALLELISM=false
|
| 17 |
|
| 18 |
EXPOSE 7860
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
CMD ["gunicorn", "-k", "uvicorn.workers.UvicornWorker", "backend.main:app", "-b", "0.0.0.0:7860"]
|
backend/ai.py
CHANGED
|
@@ -1,49 +1,99 @@
|
|
| 1 |
-
import
|
| 2 |
-
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 3 |
|
|
|
|
| 4 |
import torch
|
|
|
|
| 5 |
from transformers import CLIPModel, CLIPProcessor
|
| 6 |
|
| 7 |
-
|
| 8 |
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
@torch.no_grad()
|
| 19 |
-
def encode_image(img):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
if img is None:
|
| 21 |
-
raise
|
| 22 |
|
| 23 |
-
|
| 24 |
-
batch = processor(images=img, return_tensors="pt").to(device)
|
| 25 |
-
vec = model.get_image_features(**batch)
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
elif hasattr(vec, "last_hidden_state"):
|
| 32 |
-
vec = vec.last_hidden_state[:, 0, :]
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
vec = vec / vec.norm(dim=-1, keepdim=True)
|
| 35 |
-
return vec.cpu().numpy().astype("float32")
|
| 36 |
|
| 37 |
-
|
| 38 |
-
def encode_text(text):
|
| 39 |
-
inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
|
| 40 |
-
vec = model.get_text_features(**inputs)
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
elif hasattr(vec, "last_hidden_state"):
|
| 46 |
-
vec = vec.last_hidden_state[:, 0, :]
|
| 47 |
|
| 48 |
-
|
| 49 |
-
return vec.cpu().numpy().astype("float32")
|
|
|
|
| 1 |
+
import logging
|
|
|
|
| 2 |
|
| 3 |
+
import numpy as np
|
| 4 |
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
from transformers import CLIPModel, CLIPProcessor
|
| 7 |
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
+
# Stable CLIP model for Apple Silicon / CPU
|
| 11 |
+
_MODEL_NAME = "openai/clip-vit-base-patch32"
|
| 12 |
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# Lazy initialization
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
|
| 17 |
+
_device: torch.device | None = None
|
| 18 |
+
_model: CLIPModel | None = None
|
| 19 |
+
_processor: CLIPProcessor | None = None
|
| 20 |
|
| 21 |
+
|
| 22 |
+
def _get_device() -> torch.device:
|
| 23 |
+
# Force CPU for stability on M1/M2 Macs
|
| 24 |
+
return torch.device("cpu")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _load() -> None:
|
| 28 |
+
"""Load model and processor exactly once."""
|
| 29 |
+
global _device, _model, _processor
|
| 30 |
+
|
| 31 |
+
if _model is not None:
|
| 32 |
+
return
|
| 33 |
+
|
| 34 |
+
_device = _get_device()
|
| 35 |
+
|
| 36 |
+
logger.info(
|
| 37 |
+
"Loading CLIP model '%s' on %s…",
|
| 38 |
+
_MODEL_NAME,
|
| 39 |
+
_device,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
_model = CLIPModel.from_pretrained(_MODEL_NAME).to(_device)
|
| 44 |
+
_model.eval()
|
| 45 |
+
|
| 46 |
+
_processor = CLIPProcessor.from_pretrained(_MODEL_NAME)
|
| 47 |
+
|
| 48 |
+
except Exception as exc:
|
| 49 |
+
_model = None
|
| 50 |
+
_processor = None
|
| 51 |
+
|
| 52 |
+
raise RuntimeError(
|
| 53 |
+
f"Failed to load CLIP model '{_MODEL_NAME}'"
|
| 54 |
+
) from exc
|
| 55 |
+
|
| 56 |
+
logger.info("CLIP model ready.")
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ---------------------------------------------------------------------------
|
| 60 |
+
# Public API
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
|
| 63 |
@torch.no_grad()
|
| 64 |
+
def encode_image(img: Image.Image) -> np.ndarray:
|
| 65 |
+
"""
|
| 66 |
+
Encode a PIL image into a normalized float32 embedding.
|
| 67 |
+
|
| 68 |
+
Returns
|
| 69 |
+
-------
|
| 70 |
+
np.ndarray
|
| 71 |
+
Shape (1, 512), dtype float32, L2-normalized.
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
if img is None:
|
| 75 |
+
raise ValueError("encode_image() called with img=None")
|
| 76 |
|
| 77 |
+
_load()
|
|
|
|
|
|
|
| 78 |
|
| 79 |
+
batch = _processor(
|
| 80 |
+
images=img,
|
| 81 |
+
return_tensors="pt"
|
| 82 |
+
).to(_device)
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
# Forward pass through vision encoder
|
| 85 |
+
outputs = _model.vision_model(**batch)
|
| 86 |
+
|
| 87 |
+
# Extract pooled CLS embedding
|
| 88 |
+
vec = outputs.pooler_output
|
| 89 |
+
|
| 90 |
+
# L2 normalize
|
| 91 |
vec = vec / vec.norm(dim=-1, keepdim=True)
|
|
|
|
| 92 |
|
| 93 |
+
result = vec.cpu().numpy().astype("float32")
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
+
assert result.shape == (1, 768), (
|
| 96 |
+
f"Unexpected embedding shape: {result.shape}"
|
| 97 |
+
)
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
return result
|
|
|
backend/data/metadata.csv
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
backend/generate_metadata.py
CHANGED
|
@@ -1,78 +1,96 @@
|
|
| 1 |
import os
|
| 2 |
import csv
|
|
|
|
| 3 |
from PIL import Image
|
| 4 |
-
from ai import encode_image
|
| 5 |
-
# Root folder where product images are stored.
|
| 6 |
-
# Each subfolder represents a frame style (aviator, round, etc.)
|
| 7 |
|
| 8 |
BASE_DIR = os.path.dirname(__file__)
|
| 9 |
-
|
| 10 |
IMAGE_DIR = os.path.join(BASE_DIR, "data/images")
|
| 11 |
META_FILE = os.path.join(BASE_DIR, "data/metadata.csv")
|
| 12 |
-
|
| 13 |
-
# Keeping this explicit avoids relying on folder naming everywhere else.
|
| 14 |
STYLE_MAP = {
|
| 15 |
-
"aviator": "Aviator",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
"rectangle": "Rectangle",
|
| 17 |
}
|
| 18 |
-
# Simple material rotation to add demo variety.
|
| 19 |
-
# This keeps the dataset from feeling artificially uniform.
|
| 20 |
MATERIALS = ["Metal", "Plastic", "Steel"]
|
| 21 |
-
|
| 22 |
-
|
| 23 |
def is_valid_image(filename: str) -> bool:
|
| 24 |
-
|
| 25 |
-
|
| 26 |
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
continue
|
| 35 |
|
| 36 |
-
|
|
|
|
| 37 |
|
| 38 |
-
|
| 39 |
-
if not
|
| 40 |
continue
|
| 41 |
|
| 42 |
-
|
| 43 |
|
| 44 |
-
#
|
| 45 |
-
|
| 46 |
-
try:
|
| 47 |
-
image = Image.open(image_path).convert("RGB")
|
| 48 |
-
except Exception as e:
|
| 49 |
-
print(f"Skipping corrupted image: {image_path} ({e})")
|
| 50 |
continue
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
with open(META_FILE, "w", newline="", encoding="utf-8") as f:
|
| 72 |
writer = csv.DictWriter(f, fieldnames=rows[0].keys())
|
| 73 |
writer.writeheader()
|
| 74 |
writer.writerows(rows)
|
| 75 |
-
else:
|
| 76 |
-
print("Warning: No products were generated. Check IMAGE_DIR and folder structure.")
|
| 77 |
|
| 78 |
-
print(f"Generated {pid} products
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import csv
|
| 3 |
+
import sys
|
| 4 |
from PIL import Image
|
| 5 |
+
from backend.ai import encode_image
|
|
|
|
|
|
|
| 6 |
|
| 7 |
BASE_DIR = os.path.dirname(__file__)
|
|
|
|
| 8 |
IMAGE_DIR = os.path.join(BASE_DIR, "data/images")
|
| 9 |
META_FILE = os.path.join(BASE_DIR, "data/metadata.csv")
|
| 10 |
+
|
|
|
|
| 11 |
STYLE_MAP = {
|
| 12 |
+
"aviator": "Aviator",
|
| 13 |
+
"round": "Round",
|
| 14 |
+
"square": "Square",
|
| 15 |
+
"rimless": "Rimless",
|
| 16 |
+
"transparent": "Transparent",
|
| 17 |
"rectangle": "Rectangle",
|
| 18 |
}
|
|
|
|
|
|
|
| 19 |
MATERIALS = ["Metal", "Plastic", "Steel"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
def is_valid_image(filename: str) -> bool:
|
| 23 |
+
"""Return True for supported image formats."""
|
| 24 |
+
return filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
|
| 25 |
|
| 26 |
|
| 27 |
+
def main(force: bool = False) -> None:
|
| 28 |
+
if os.path.exists(META_FILE) and not force:
|
| 29 |
+
print(f"{META_FILE} already exists. Use --force to overwrite.")
|
| 30 |
+
sys.exit(0)
|
| 31 |
|
| 32 |
+
rows = []
|
| 33 |
+
pid = 0
|
|
|
|
| 34 |
|
| 35 |
+
for folder in sorted(os.listdir(IMAGE_DIR)):
|
| 36 |
+
style = STYLE_MAP.get(folder.lower())
|
| 37 |
|
| 38 |
+
# Skip unknown folders (e.g. stray files or system artifacts)
|
| 39 |
+
if not style:
|
| 40 |
continue
|
| 41 |
|
| 42 |
+
folder_path = os.path.join(IMAGE_DIR, folder)
|
| 43 |
|
| 44 |
+
# Guard against loose files sitting directly inside IMAGE_DIR
|
| 45 |
+
if not os.path.isdir(folder_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
continue
|
| 47 |
|
| 48 |
+
for img in sorted(os.listdir(folder_path)):
|
| 49 |
+
if not is_valid_image(img):
|
| 50 |
+
continue
|
| 51 |
+
|
| 52 |
+
image_path = os.path.join(folder_path, img)
|
| 53 |
+
|
| 54 |
+
# Skip corrupted images without aborting the whole run
|
| 55 |
+
try:
|
| 56 |
+
image = Image.open(image_path).convert("RGB")
|
| 57 |
+
except Exception as e:
|
| 58 |
+
print(f"Skipping corrupted image: {image_path} ({e})")
|
| 59 |
+
continue
|
| 60 |
+
|
| 61 |
+
# Generate embedding for similarity-based search
|
| 62 |
+
emb = encode_image(image)[0]
|
| 63 |
+
|
| 64 |
+
rows.append({
|
| 65 |
+
"product_id": pid,
|
| 66 |
+
"image": f"{folder}/{img}",
|
| 67 |
+
"brand": "Lenskart",
|
| 68 |
+
"material": MATERIALS[pid % len(MATERIALS)],
|
| 69 |
+
"price": 1800 + (pid % 6) * 300,
|
| 70 |
+
"style": style,
|
| 71 |
+
# f"{x:.9g}" preserves full float32 precision (vs default str())
|
| 72 |
+
"embedding": " ".join(f"{x:.9g}" for x in emb.tolist()),
|
| 73 |
+
})
|
| 74 |
+
|
| 75 |
+
pid += 1
|
| 76 |
+
|
| 77 |
+
if not rows:
|
| 78 |
+
print("Warning: No products were generated. Check IMAGE_DIR and folder structure.")
|
| 79 |
+
sys.exit(1)
|
| 80 |
+
|
| 81 |
with open(META_FILE, "w", newline="", encoding="utf-8") as f:
|
| 82 |
writer = csv.DictWriter(f, fieldnames=rows[0].keys())
|
| 83 |
writer.writeheader()
|
| 84 |
writer.writerows(rows)
|
|
|
|
|
|
|
| 85 |
|
| 86 |
+
print(f"Generated {pid} products → {META_FILE}")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
|
| 90 |
+
import argparse
|
| 91 |
+
|
| 92 |
+
parser = argparse.ArgumentParser(description="Ingest product images into metadata CSV.")
|
| 93 |
+
parser.add_argument("--force", action="store_true", help="Overwrite existing metadata.csv")
|
| 94 |
+
args = parser.parse_args()
|
| 95 |
+
|
| 96 |
+
main(force=args.force)
|
backend/main.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import csv
|
| 3 |
from contextlib import asynccontextmanager
|
|
|
|
|
|
|
| 4 |
from fastapi import FastAPI, UploadFile, File, Query, HTTPException
|
| 5 |
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
from fastapi.staticfiles import StaticFiles
|
|
@@ -11,22 +14,66 @@ from backend.store import VectorStore
|
|
| 11 |
from backend.feedback import get_boost, record_click
|
| 12 |
from backend.accuracy_test import run_accuracy_check
|
| 13 |
|
|
|
|
|
|
|
| 14 |
BASE_DIR = os.path.dirname(__file__)
|
| 15 |
-
BASE_URL = "https://diwank3221-visual-search-backend.hf.space"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
@asynccontextmanager
|
| 17 |
async def lifespan(app: FastAPI):
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
yield
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
app = FastAPI(title="Visual Product API", lifespan=lifespan)
|
| 24 |
|
| 25 |
-
# Serve product images
|
| 26 |
app.mount(
|
| 27 |
"/images",
|
| 28 |
StaticFiles(directory=os.path.join(BASE_DIR, "data/images")),
|
| 29 |
-
name="images"
|
| 30 |
)
|
| 31 |
|
| 32 |
app.add_middleware(
|
|
@@ -36,25 +83,37 @@ app.add_middleware(
|
|
| 36 |
allow_headers=["*"],
|
| 37 |
)
|
| 38 |
|
| 39 |
-
store = VectorStore()
|
| 40 |
-
PRODUCTS = {}
|
| 41 |
-
SEARCH_BASE = "https://www.lenskart.com/search?q="
|
| 42 |
|
| 43 |
-
#
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
try:
|
| 53 |
return store.classify(encode_image(image))
|
| 54 |
except Exception:
|
|
|
|
| 55 |
return None
|
| 56 |
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
@app.post("/search")
|
| 59 |
async def search(
|
| 60 |
file: UploadFile = File(...),
|
|
@@ -62,62 +121,57 @@ async def search(
|
|
| 62 |
max_price: int = Query(10000, ge=0),
|
| 63 |
material: str | None = None,
|
| 64 |
style: str | None = None,
|
| 65 |
-
frame: str | None = None
|
| 66 |
):
|
| 67 |
-
# Only guard image decoding – this is the most common real-world failure
|
| 68 |
try:
|
| 69 |
img = Image.open(file.file).convert("RGB")
|
| 70 |
except Exception:
|
| 71 |
-
raise HTTPException(status_code=400, detail="Invalid image file")
|
|
|
|
| 72 |
q = encode_image(img)
|
| 73 |
|
| 74 |
if not store.is_eyewear(q):
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
|
| 80 |
-
tag =
|
| 81 |
-
raw = store.search(q, k=60)
|
| 82 |
results = []
|
| 83 |
|
| 84 |
-
for item, score in
|
| 85 |
pid = int(item["product_id"])
|
| 86 |
p = PRODUCTS.get(pid)
|
| 87 |
if not p:
|
| 88 |
continue
|
| 89 |
|
| 90 |
price = int(p["price"])
|
| 91 |
-
if not (min_price <= price <= max_price):
|
| 92 |
-
|
| 93 |
-
if
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
|
| 97 |
-
r["score"] = score * get_boost(pid)
|
| 98 |
-
r["image_url"] = f"{BASE_URL}/images/{r['image']}"
|
| 99 |
-
r["buy_url"] = SEARCH_BASE + f"{r['style']} {r['material']} glasses under {r['price']}".replace(" ", "+")
|
| 100 |
-
results.append(r)
|
| 101 |
|
| 102 |
results.sort(key=lambda x: x["score"], reverse=True)
|
| 103 |
return {"tag": tag, "results": results[:8]}
|
| 104 |
-
# Product catalog
|
| 105 |
-
@app.get("/products")
|
| 106 |
-
def products():
|
| 107 |
-
out = []
|
| 108 |
|
| 109 |
-
for pid, p in PRODUCTS.items():
|
| 110 |
-
r = {k: v for k, v in p.items() if k != "embedding"}
|
| 111 |
-
r["boost"] = get_boost(pid)
|
| 112 |
-
r["image_url"] = f"{BASE_URL}/images/{r['image']}"
|
| 113 |
-
r["buy_url"] = SEARCH_BASE + f"{r['style']} {r['material']} glasses under {r['price']}".replace(" ", "+")
|
| 114 |
-
out.append(r)
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
out.sort(key=lambda x: x["boost"], reverse=True)
|
| 117 |
return {"count": len(out), "results": out}
|
| 118 |
|
| 119 |
-
|
| 120 |
@app.post("/click/{pid}")
|
| 121 |
def click(pid: int):
|
| 122 |
record_click(pid)
|
| 123 |
-
return {"status": "ok"}
|
|
|
|
| 1 |
+
import logging
|
| 2 |
import os
|
| 3 |
import csv
|
| 4 |
from contextlib import asynccontextmanager
|
| 5 |
+
from urllib.parse import quote_plus
|
| 6 |
+
|
| 7 |
from fastapi import FastAPI, UploadFile, File, Query, HTTPException
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
from fastapi.staticfiles import StaticFiles
|
|
|
|
| 14 |
from backend.feedback import get_boost, record_click
|
| 15 |
from backend.accuracy_test import run_accuracy_check
|
| 16 |
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
BASE_DIR = os.path.dirname(__file__)
|
| 20 |
+
BASE_URL = os.getenv("BASE_URL", "https://diwank3221-visual-search-backend.hf.space")
|
| 21 |
+
SEARCH_BASE = "https://www.lenskart.com/search?q="
|
| 22 |
+
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
# Startup / shutdown
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
|
| 27 |
+
store: VectorStore
|
| 28 |
+
PRODUCTS: dict[int, dict] = {}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
@asynccontextmanager
|
| 32 |
async def lifespan(app: FastAPI):
|
| 33 |
+
global store, PRODUCTS
|
| 34 |
+
|
| 35 |
+
# Initialise heavy resources once, at startup — not at import time
|
| 36 |
+
store = VectorStore()
|
| 37 |
+
PRODUCTS = _load_catalog(os.path.join(BASE_DIR, "data/metadata.csv"))
|
| 38 |
+
|
| 39 |
+
logger.info("Starting backend health check…")
|
| 40 |
+
# run_accuracy_check()
|
| 41 |
+
logger.info("Backend ready — %d products loaded.", len(PRODUCTS))
|
| 42 |
+
|
| 43 |
yield
|
| 44 |
+
# (add any teardown here if needed)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _load_catalog(path: str) -> dict[int, dict]:
|
| 48 |
+
"""Read metadata.csv and return a dict keyed by product_id."""
|
| 49 |
+
products: dict[int, dict] = {}
|
| 50 |
+
|
| 51 |
+
with open(path, newline="", encoding="utf-8") as f:
|
| 52 |
+
for r in csv.DictReader(f):
|
| 53 |
+
try:
|
| 54 |
+
pid = int(r["product_id"])
|
| 55 |
+
except (KeyError, ValueError) as exc:
|
| 56 |
+
logger.warning("Skipping malformed CSV row (%s): %s", exc, r)
|
| 57 |
+
continue
|
| 58 |
+
|
| 59 |
+
# Strip any accidental "images/" prefix left in the CSV
|
| 60 |
+
r["image"] = r["image"].removeprefix("images/")
|
| 61 |
+
# NOTE: .title() was removed — it corrupts paths ("img.jpg" → "Img.Jpg")
|
| 62 |
+
products[pid] = r
|
| 63 |
+
|
| 64 |
+
return products
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ---------------------------------------------------------------------------
|
| 68 |
+
# App
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
|
| 71 |
app = FastAPI(title="Visual Product API", lifespan=lifespan)
|
| 72 |
|
|
|
|
| 73 |
app.mount(
|
| 74 |
"/images",
|
| 75 |
StaticFiles(directory=os.path.join(BASE_DIR, "data/images")),
|
| 76 |
+
name="images",
|
| 77 |
)
|
| 78 |
|
| 79 |
app.add_middleware(
|
|
|
|
| 83 |
allow_headers=["*"],
|
| 84 |
)
|
| 85 |
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
# ---------------------------------------------------------------------------
|
| 88 |
+
# Helpers
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
def _build_buy_url(style: str, material: str, price: str) -> str:
|
| 92 |
+
return SEARCH_BASE + quote_plus(f"{style} {material} glasses under {price}")
|
| 93 |
+
|
| 94 |
|
| 95 |
+
def _tag_image(image: Image.Image) -> str | None:
|
| 96 |
+
"""Classify the uploaded image into a style tag; returns None on failure."""
|
| 97 |
try:
|
| 98 |
return store.classify(encode_image(image))
|
| 99 |
except Exception:
|
| 100 |
+
logger.exception("tag_image failed")
|
| 101 |
return None
|
| 102 |
|
| 103 |
+
|
| 104 |
+
def _format_product(pid: int, p: dict, score: float | None = None) -> dict:
|
| 105 |
+
r = {k: v for k, v in p.items() if k != "embedding"}
|
| 106 |
+
r["image_url"] = f"{BASE_URL}/images/{p['image']}"
|
| 107 |
+
r["buy_url"] = _build_buy_url(p["style"], p["material"], p["price"])
|
| 108 |
+
if score is not None:
|
| 109 |
+
r["score"] = score
|
| 110 |
+
return r
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
# Routes
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
|
| 117 |
@app.post("/search")
|
| 118 |
async def search(
|
| 119 |
file: UploadFile = File(...),
|
|
|
|
| 121 |
max_price: int = Query(10000, ge=0),
|
| 122 |
material: str | None = None,
|
| 123 |
style: str | None = None,
|
| 124 |
+
frame: str | None = None,
|
| 125 |
):
|
|
|
|
| 126 |
try:
|
| 127 |
img = Image.open(file.file).convert("RGB")
|
| 128 |
except Exception:
|
| 129 |
+
raise HTTPException(status_code=400, detail="Invalid image file.")
|
| 130 |
+
|
| 131 |
q = encode_image(img)
|
| 132 |
|
| 133 |
if not store.is_eyewear(q):
|
| 134 |
+
raise HTTPException(
|
| 135 |
+
status_code=400,
|
| 136 |
+
detail="No eyewear detected. Please upload a glasses image.",
|
| 137 |
+
)
|
| 138 |
|
| 139 |
+
tag = _tag_image(img)
|
|
|
|
| 140 |
results = []
|
| 141 |
|
| 142 |
+
for item, score in store.search(q, k=60):
|
| 143 |
pid = int(item["product_id"])
|
| 144 |
p = PRODUCTS.get(pid)
|
| 145 |
if not p:
|
| 146 |
continue
|
| 147 |
|
| 148 |
price = int(p["price"])
|
| 149 |
+
if not (min_price <= price <= max_price):
|
| 150 |
+
continue
|
| 151 |
+
if material and p["material"] != material:
|
| 152 |
+
continue
|
| 153 |
+
if style and p["style"] != style:
|
| 154 |
+
continue
|
| 155 |
+
if frame and p["style"] != frame:
|
| 156 |
+
continue
|
| 157 |
|
| 158 |
+
results.append(_format_product(pid, p, score=score * get_boost(pid)))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
results.sort(key=lambda x: x["score"], reverse=True)
|
| 161 |
return {"tag": tag, "results": results[:8]}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
@app.get("/products")
|
| 165 |
+
def products():
|
| 166 |
+
out = [
|
| 167 |
+
{**_format_product(pid, p), "boost": get_boost(pid)}
|
| 168 |
+
for pid, p in PRODUCTS.items()
|
| 169 |
+
]
|
| 170 |
out.sort(key=lambda x: x["boost"], reverse=True)
|
| 171 |
return {"count": len(out), "results": out}
|
| 172 |
|
| 173 |
+
|
| 174 |
@app.post("/click/{pid}")
|
| 175 |
def click(pid: int):
|
| 176 |
record_click(pid)
|
| 177 |
+
return {"status": "ok"}
|
backend/store.py
CHANGED
|
@@ -1,82 +1,123 @@
|
|
| 1 |
import csv
|
|
|
|
| 2 |
import os
|
|
|
|
|
|
|
| 3 |
import faiss
|
| 4 |
import numpy as np
|
| 5 |
-
|
|
|
|
|
|
|
| 6 |
BASE_DIR = os.path.dirname(__file__)
|
| 7 |
META_FILE = os.path.join(BASE_DIR, "data/metadata.csv")
|
| 8 |
INDEX_FILE = os.path.join(BASE_DIR, "index.faiss")
|
| 9 |
CENTROID_FILE = os.path.join(BASE_DIR, "centroids.npy")
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
class VectorStore:
|
| 13 |
"""
|
| 14 |
-
FAISS
|
| 15 |
-
- similarity search
|
| 16 |
-
-
|
|
|
|
| 17 |
"""
|
| 18 |
-
def is_eyewear(self, q, threshold=0.53):
|
| 19 |
-
q = q / np.linalg.norm(q, axis=1, keepdims=True)
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
def __init__(self):
|
| 26 |
-
self.meta = self._load_meta()
|
| 27 |
-
# Load cached index if available, otherwise build fresh
|
| 28 |
if os.path.exists(INDEX_FILE):
|
| 29 |
-
|
|
|
|
| 30 |
else:
|
| 31 |
-
self._build_index()
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
with open(META_FILE, newline="", encoding="utf-8") as f:
|
| 36 |
rows = list(csv.DictReader(f))
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
-
def _build_index(self):
|
| 41 |
vectors = []
|
| 42 |
for r in self.meta:
|
| 43 |
v = np.fromstring(r["embedding"], sep=" ").astype("float32")
|
| 44 |
-
|
| 45 |
-
vectors.append(v)
|
| 46 |
|
| 47 |
-
|
| 48 |
self.index = faiss.IndexFlatIP(DIM)
|
| 49 |
-
self.index.add(
|
| 50 |
faiss.write_index(self.index, INDEX_FILE)
|
|
|
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def _build_centroids(self):
|
| 55 |
-
clusters = defaultdict(list)
|
| 56 |
for r in self.meta:
|
| 57 |
v = np.fromstring(r["embedding"], sep=" ").astype("float32")
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
self.centroids = {}
|
| 61 |
for style, vecs in clusters.items():
|
| 62 |
c = np.mean(vecs, axis=0)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
D, I = self.index.search(q, k)
|
| 72 |
-
return [
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import csv
|
| 2 |
+
import logging
|
| 3 |
import os
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
|
| 6 |
import faiss
|
| 7 |
import numpy as np
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
BASE_DIR = os.path.dirname(__file__)
|
| 12 |
META_FILE = os.path.join(BASE_DIR, "data/metadata.csv")
|
| 13 |
INDEX_FILE = os.path.join(BASE_DIR, "index.faiss")
|
| 14 |
CENTROID_FILE = os.path.join(BASE_DIR, "centroids.npy")
|
| 15 |
+
|
| 16 |
+
# CLIP embedding dimension (fixed for patrickjohncyh/fashion-clip)
|
| 17 |
+
DIM = 768
|
| 18 |
+
# Minimum cosine similarity to the nearest style centroid for a query
|
| 19 |
+
# to be considered eyewear. Calibrated empirically on the current dataset.
|
| 20 |
+
_EYEWEAR_THRESHOLD = 0.53
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _normalise(x: np.ndarray) -> np.ndarray:
|
| 24 |
+
"""L2-normalise each row of a 2-D array in-place-free fashion."""
|
| 25 |
+
return x / np.linalg.norm(x, axis=-1, keepdims=True)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
class VectorStore:
|
| 29 |
"""
|
| 30 |
+
FAISS-backed vector store providing:
|
| 31 |
+
- k-NN similarity search over product embeddings
|
| 32 |
+
- coarse style classification via centroid cosine similarity
|
| 33 |
+
- eyewear gating to reject non-glasses queries
|
| 34 |
"""
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
def __init__(self) -> None:
|
| 37 |
+
self.meta = self._load_meta()
|
| 38 |
|
| 39 |
+
if not self.meta:
|
| 40 |
+
raise RuntimeError(
|
| 41 |
+
f"No products found in {META_FILE}. Run ingest.py first."
|
| 42 |
+
)
|
| 43 |
|
|
|
|
|
|
|
|
|
|
| 44 |
if os.path.exists(INDEX_FILE):
|
| 45 |
+
self.index = faiss.read_index(INDEX_FILE)
|
| 46 |
+
logger.info("Loaded FAISS index from cache (%d vectors).", self.index.ntotal)
|
| 47 |
else:
|
| 48 |
+
self._build_index()
|
| 49 |
+
|
| 50 |
+
# Always rebuild centroids from the CSV so they stay in sync with
|
| 51 |
+
# the metadata even when the FAISS index is loaded from cache.
|
| 52 |
+
self._build_centroids()
|
| 53 |
+
|
| 54 |
+
# ------------------------------------------------------------------
|
| 55 |
+
# Internal helpers
|
| 56 |
+
# ------------------------------------------------------------------
|
| 57 |
+
|
| 58 |
+
def _load_meta(self) -> list[dict]:
|
| 59 |
with open(META_FILE, newline="", encoding="utf-8") as f:
|
| 60 |
rows = list(csv.DictReader(f))
|
| 61 |
+
logger.info("Loaded %d products from metadata.", len(rows))
|
| 62 |
+
return rows
|
| 63 |
|
| 64 |
+
def _build_index(self) -> None:
|
| 65 |
vectors = []
|
| 66 |
for r in self.meta:
|
| 67 |
v = np.fromstring(r["embedding"], sep=" ").astype("float32")
|
| 68 |
+
vectors.append(_normalise(v[np.newaxis])[0])
|
|
|
|
| 69 |
|
| 70 |
+
matrix = np.vstack(vectors)
|
| 71 |
self.index = faiss.IndexFlatIP(DIM)
|
| 72 |
+
self.index.add(matrix)
|
| 73 |
faiss.write_index(self.index, INDEX_FILE)
|
| 74 |
+
logger.info("Built FAISS index with %d vectors.", self.index.ntotal)
|
| 75 |
|
| 76 |
+
def _build_centroids(self) -> None:
|
| 77 |
+
clusters: dict[str, list[np.ndarray]] = defaultdict(list)
|
|
|
|
|
|
|
| 78 |
for r in self.meta:
|
| 79 |
v = np.fromstring(r["embedding"], sep=" ").astype("float32")
|
| 80 |
+
clusters[r["style"]].append(_normalise(v[np.newaxis])[0])
|
| 81 |
+
|
| 82 |
+
self.centroids: dict[str, np.ndarray] = {}
|
| 83 |
for style, vecs in clusters.items():
|
| 84 |
c = np.mean(vecs, axis=0)
|
| 85 |
+
self.centroids[style] = _normalise(c[np.newaxis])[0]
|
| 86 |
+
|
| 87 |
+
logger.info("Built %d style centroids.", len(self.centroids))
|
| 88 |
+
|
| 89 |
+
# ------------------------------------------------------------------
|
| 90 |
+
# Public API
|
| 91 |
+
# ------------------------------------------------------------------
|
| 92 |
+
|
| 93 |
+
def is_eyewear(self, q: np.ndarray, threshold: float = _EYEWEAR_THRESHOLD) -> bool:
|
| 94 |
+
"""Return True if *q* is close enough to any style centroid."""
|
| 95 |
+
q = _normalise(q)
|
| 96 |
+
scores = [float(np.dot(q[0], c)) for c in self.centroids.values()]
|
| 97 |
+
return bool(scores) and max(scores) > threshold
|
| 98 |
+
|
| 99 |
+
def search(self, q: np.ndarray, k: int = 40) -> list[tuple[dict, float]]:
|
| 100 |
+
"""
|
| 101 |
+
Return the *k* most similar products to query *q*.
|
| 102 |
+
|
| 103 |
+
Filters out FAISS sentinel index -1, which is returned when
|
| 104 |
+
k > index.ntotal (not enough vectors to fill the result set).
|
| 105 |
+
"""
|
| 106 |
+
q = _normalise(q)
|
| 107 |
+
k = min(k, self.index.ntotal) # guard: k must not exceed index size
|
| 108 |
D, I = self.index.search(q, k)
|
| 109 |
+
return [
|
| 110 |
+
(self.meta[i], float(score))
|
| 111 |
+
for i, score in zip(I[0], D[0])
|
| 112 |
+
if i != -1 # FAISS sentinel for unfilled slots
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
def classify(self, q: np.ndarray) -> str | None:
|
| 116 |
+
"""Return the style name whose centroid is closest to *q*."""
|
| 117 |
+
if not self.centroids:
|
| 118 |
+
return None
|
| 119 |
+
q = _normalise(q)
|
| 120 |
+
return max(
|
| 121 |
+
self.centroids,
|
| 122 |
+
key=lambda style: float(np.dot(q[0], self.centroids[style])),
|
| 123 |
+
)
|