Diwank Singh commited on
Commit
26d97be
·
1 Parent(s): 6883665
Dockerfile CHANGED
@@ -1,17 +1,20 @@
1
  FROM python:3.10
 
2
  WORKDIR /app
 
3
  COPY . /app
 
4
  RUN pip install --no-cache-dir -r requirements.txt
5
 
 
6
  RUN python -c "from transformers import CLIPModel, CLIPProcessor; \
7
- CLIPModel.from_pretrained('openai/clip-vit-base-patch32'); \
8
- CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')"
9
 
10
- # Lock HuggingFace to use only the cached model at runtime
11
  ENV TRANSFORMERS_OFFLINE=1
12
  ENV HF_DATASETS_OFFLINE=1
 
13
 
14
  EXPOSE 7860
15
- CMD ["gunicorn", "-k", "uvicorn.workers.UvicornWorker", "backend.main:app", "-b", "0.0.0.0:7860"]
16
-
17
 
 
 
1
  FROM python:3.10
2
+
3
  WORKDIR /app
4
+
5
  COPY . /app
6
+
7
  RUN pip install --no-cache-dir -r requirements.txt
8
 
9
+ # Pre-download model
10
  RUN python -c "from transformers import CLIPModel, CLIPProcessor; \
11
+ CLIPModel.from_pretrained('openai/clip-vit-base-patch32'); \
12
+ CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32')"
13
 
 
14
  ENV TRANSFORMERS_OFFLINE=1
15
  ENV HF_DATASETS_OFFLINE=1
16
+ ENV TOKENIZERS_PARALLELISM=false
17
 
18
  EXPOSE 7860
 
 
19
 
20
+ CMD ["gunicorn", "-k", "uvicorn.workers.UvicornWorker", "backend.main:app", "-b", "0.0.0.0:7860"]
backend/ai.py CHANGED
@@ -1,49 +1,99 @@
1
- import os
2
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
 
 
4
  import torch
 
5
  from transformers import CLIPModel, CLIPProcessor
6
 
7
- MODEL = "openai/clip-vit-base-patch32"
8
 
 
 
9
 
10
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
11
- print("CLIP running on:", device)
 
12
 
13
- model = CLIPModel.from_pretrained(MODEL).to(device)
14
- model.eval()
 
15
 
16
- processor = CLIPProcessor.from_pretrained(MODEL, use_fast=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  @torch.no_grad()
19
- def encode_image(img):
 
 
 
 
 
 
 
 
 
20
  if img is None:
21
- raise RuntimeError("embed_image() called with empty image")
22
 
23
- img = img.resize((224, 224))
24
- batch = processor(images=img, return_tensors="pt").to(device)
25
- vec = model.get_image_features(**batch)
26
 
27
- # fashion-clip may return a wrapped object instead of a raw tensor
28
- if not isinstance(vec, torch.Tensor):
29
- if hasattr(vec, "pooler_output") and vec.pooler_output is not None:
30
- vec = vec.pooler_output
31
- elif hasattr(vec, "last_hidden_state"):
32
- vec = vec.last_hidden_state[:, 0, :]
33
 
 
 
 
 
 
 
 
34
  vec = vec / vec.norm(dim=-1, keepdim=True)
35
- return vec.cpu().numpy().astype("float32")
36
 
37
- @torch.no_grad()
38
- def encode_text(text):
39
- inputs = processor(text=[text], return_tensors="pt", padding=True).to(device)
40
- vec = model.get_text_features(**inputs)
41
 
42
- if not isinstance(vec, torch.Tensor):
43
- if hasattr(vec, "pooler_output") and vec.pooler_output is not None:
44
- vec = vec.pooler_output
45
- elif hasattr(vec, "last_hidden_state"):
46
- vec = vec.last_hidden_state[:, 0, :]
47
 
48
- vec = vec / vec.norm(dim=-1, keepdim=True)
49
- return vec.cpu().numpy().astype("float32")
 
1
+ import logging
 
2
 
3
+ import numpy as np
4
  import torch
5
+ from PIL import Image
6
  from transformers import CLIPModel, CLIPProcessor
7
 
8
+ logger = logging.getLogger(__name__)
9
 
10
+ # Stable CLIP model for Apple Silicon / CPU
11
+ _MODEL_NAME = "openai/clip-vit-base-patch32"
12
 
13
+ # ---------------------------------------------------------------------------
14
+ # Lazy initialization
15
+ # ---------------------------------------------------------------------------
16
 
17
+ _device: torch.device | None = None
18
+ _model: CLIPModel | None = None
19
+ _processor: CLIPProcessor | None = None
20
 
21
+
22
+ def _get_device() -> torch.device:
23
+ # Force CPU for stability on M1/M2 Macs
24
+ return torch.device("cpu")
25
+
26
+
27
+ def _load() -> None:
28
+ """Load model and processor exactly once."""
29
+ global _device, _model, _processor
30
+
31
+ if _model is not None:
32
+ return
33
+
34
+ _device = _get_device()
35
+
36
+ logger.info(
37
+ "Loading CLIP model '%s' on %s…",
38
+ _MODEL_NAME,
39
+ _device,
40
+ )
41
+
42
+ try:
43
+ _model = CLIPModel.from_pretrained(_MODEL_NAME).to(_device)
44
+ _model.eval()
45
+
46
+ _processor = CLIPProcessor.from_pretrained(_MODEL_NAME)
47
+
48
+ except Exception as exc:
49
+ _model = None
50
+ _processor = None
51
+
52
+ raise RuntimeError(
53
+ f"Failed to load CLIP model '{_MODEL_NAME}'"
54
+ ) from exc
55
+
56
+ logger.info("CLIP model ready.")
57
+
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Public API
61
+ # ---------------------------------------------------------------------------
62
 
63
  @torch.no_grad()
64
+ def encode_image(img: Image.Image) -> np.ndarray:
65
+ """
66
+ Encode a PIL image into a normalized float32 embedding.
67
+
68
+ Returns
69
+ -------
70
+ np.ndarray
71
+ Shape (1, 512), dtype float32, L2-normalized.
72
+ """
73
+
74
  if img is None:
75
+ raise ValueError("encode_image() called with img=None")
76
 
77
+ _load()
 
 
78
 
79
+ batch = _processor(
80
+ images=img,
81
+ return_tensors="pt"
82
+ ).to(_device)
 
 
83
 
84
+ # Forward pass through vision encoder
85
+ outputs = _model.vision_model(**batch)
86
+
87
+ # Extract pooled CLS embedding
88
+ vec = outputs.pooler_output
89
+
90
+ # L2 normalize
91
  vec = vec / vec.norm(dim=-1, keepdim=True)
 
92
 
93
+ result = vec.cpu().numpy().astype("float32")
 
 
 
94
 
95
+ assert result.shape == (1, 768), (
96
+ f"Unexpected embedding shape: {result.shape}"
97
+ )
 
 
98
 
99
+ return result
 
backend/data/metadata.csv CHANGED
The diff for this file is too large to render. See raw diff
 
backend/generate_metadata.py CHANGED
@@ -1,78 +1,96 @@
1
  import os
2
  import csv
 
3
  from PIL import Image
4
- from ai import encode_image
5
- # Root folder where product images are stored.
6
- # Each subfolder represents a frame style (aviator, round, etc.)
7
 
8
  BASE_DIR = os.path.dirname(__file__)
9
-
10
  IMAGE_DIR = os.path.join(BASE_DIR, "data/images")
11
  META_FILE = os.path.join(BASE_DIR, "data/metadata.csv")
12
- # Mapping of folder names to readable product styles.
13
- # Keeping this explicit avoids relying on folder naming everywhere else.
14
  STYLE_MAP = {
15
- "aviator": "Aviator","round": "Round","square": "Square","rimless": "Rimless","transparent":"Transparent",
 
 
 
 
16
  "rectangle": "Rectangle",
17
  }
18
- # Simple material rotation to add demo variety.
19
- # This keeps the dataset from feeling artificially uniform.
20
  MATERIALS = ["Metal", "Plastic", "Steel"]
21
- rows = []
22
- pid = 0
23
  def is_valid_image(filename: str) -> bool:
24
- """Small helper to filter supported image formats."""
25
- return filename.lower().endswith((".jpg", ".png", ".jpeg", ".webp"))
26
 
27
 
28
- # Walk through each style folder and convert images into product records
29
- for folder in sorted(os.listdir(IMAGE_DIR)):
30
- style = STYLE_MAP.get(folder.lower())
 
31
 
32
- # Skip unknown folders (e.g., stray files or system artifacts)
33
- if not style:
34
- continue
35
 
36
- folder_path = os.path.join(IMAGE_DIR, folder)
 
37
 
38
- for img in sorted(os.listdir(folder_path)):
39
- if not is_valid_image(img):
40
  continue
41
 
42
- image_path = os.path.join(folder_path, img)
43
 
44
- # Try loading the image. If an image fails, we skip it instead of
45
- # breaking the whole dataset generation process.
46
- try:
47
- image = Image.open(image_path).convert("RGB")
48
- except Exception as e:
49
- print(f"Skipping corrupted image: {image_path} ({e})")
50
  continue
51
 
52
- # Generate embedding for similarity-based search and recommendations
53
- emb = encode_image(image)[0]
54
-
55
- # Create a product entry.
56
- # Some values are generated programmatically to simulate real catalog variety.
57
- rows.append({
58
- "product_id": pid,
59
- "image": f"{folder}/{img}",
60
- "brand": "Lenskart",
61
- "material": MATERIALS[pid % len(MATERIALS)],
62
- "price": 1800 + (pid % 6) * 300, # small price variation for realism
63
- "style": style,
64
- "embedding": " ".join(map(str, emb.tolist()))
65
- })
66
-
67
- pid += 1
68
-
69
- # Write all generated products to a CSV file
70
- if rows:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  with open(META_FILE, "w", newline="", encoding="utf-8") as f:
72
  writer = csv.DictWriter(f, fieldnames=rows[0].keys())
73
  writer.writeheader()
74
  writer.writerows(rows)
75
- else:
76
- print("Warning: No products were generated. Check IMAGE_DIR and folder structure.")
77
 
78
- print(f"Generated {pid} products and stored them in {META_FILE}")
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import csv
3
+ import sys
4
  from PIL import Image
5
+ from backend.ai import encode_image
 
 
6
 
7
  BASE_DIR = os.path.dirname(__file__)
 
8
  IMAGE_DIR = os.path.join(BASE_DIR, "data/images")
9
  META_FILE = os.path.join(BASE_DIR, "data/metadata.csv")
10
+
 
11
  STYLE_MAP = {
12
+ "aviator": "Aviator",
13
+ "round": "Round",
14
+ "square": "Square",
15
+ "rimless": "Rimless",
16
+ "transparent": "Transparent",
17
  "rectangle": "Rectangle",
18
  }
 
 
19
  MATERIALS = ["Metal", "Plastic", "Steel"]
20
+
21
+
22
  def is_valid_image(filename: str) -> bool:
23
+ """Return True for supported image formats."""
24
+ return filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp"))
25
 
26
 
27
+ def main(force: bool = False) -> None:
28
+ if os.path.exists(META_FILE) and not force:
29
+ print(f"{META_FILE} already exists. Use --force to overwrite.")
30
+ sys.exit(0)
31
 
32
+ rows = []
33
+ pid = 0
 
34
 
35
+ for folder in sorted(os.listdir(IMAGE_DIR)):
36
+ style = STYLE_MAP.get(folder.lower())
37
 
38
+ # Skip unknown folders (e.g. stray files or system artifacts)
39
+ if not style:
40
  continue
41
 
42
+ folder_path = os.path.join(IMAGE_DIR, folder)
43
 
44
+ # Guard against loose files sitting directly inside IMAGE_DIR
45
+ if not os.path.isdir(folder_path):
 
 
 
 
46
  continue
47
 
48
+ for img in sorted(os.listdir(folder_path)):
49
+ if not is_valid_image(img):
50
+ continue
51
+
52
+ image_path = os.path.join(folder_path, img)
53
+
54
+ # Skip corrupted images without aborting the whole run
55
+ try:
56
+ image = Image.open(image_path).convert("RGB")
57
+ except Exception as e:
58
+ print(f"Skipping corrupted image: {image_path} ({e})")
59
+ continue
60
+
61
+ # Generate embedding for similarity-based search
62
+ emb = encode_image(image)[0]
63
+
64
+ rows.append({
65
+ "product_id": pid,
66
+ "image": f"{folder}/{img}",
67
+ "brand": "Lenskart",
68
+ "material": MATERIALS[pid % len(MATERIALS)],
69
+ "price": 1800 + (pid % 6) * 300,
70
+ "style": style,
71
+ # f"{x:.9g}" preserves full float32 precision (vs default str())
72
+ "embedding": " ".join(f"{x:.9g}" for x in emb.tolist()),
73
+ })
74
+
75
+ pid += 1
76
+
77
+ if not rows:
78
+ print("Warning: No products were generated. Check IMAGE_DIR and folder structure.")
79
+ sys.exit(1)
80
+
81
  with open(META_FILE, "w", newline="", encoding="utf-8") as f:
82
  writer = csv.DictWriter(f, fieldnames=rows[0].keys())
83
  writer.writeheader()
84
  writer.writerows(rows)
 
 
85
 
86
+ print(f"Generated {pid} products {META_FILE}")
87
+
88
+
89
+ if __name__ == "__main__":
90
+ import argparse
91
+
92
+ parser = argparse.ArgumentParser(description="Ingest product images into metadata CSV.")
93
+ parser.add_argument("--force", action="store_true", help="Overwrite existing metadata.csv")
94
+ args = parser.parse_args()
95
+
96
+ main(force=args.force)
backend/main.py CHANGED
@@ -1,6 +1,9 @@
 
1
  import os
2
  import csv
3
  from contextlib import asynccontextmanager
 
 
4
  from fastapi import FastAPI, UploadFile, File, Query, HTTPException
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from fastapi.staticfiles import StaticFiles
@@ -11,22 +14,66 @@ from backend.store import VectorStore
11
  from backend.feedback import get_boost, record_click
12
  from backend.accuracy_test import run_accuracy_check
13
 
 
 
14
  BASE_DIR = os.path.dirname(__file__)
15
- BASE_URL = "https://diwank3221-visual-search-backend.hf.space"
 
 
 
 
 
 
 
 
 
 
16
  @asynccontextmanager
17
  async def lifespan(app: FastAPI):
18
- print("\nStarting backend health check...\n")
19
- run_accuracy_check()
 
 
 
 
 
 
 
 
20
  yield
21
- print("\nBackend ready.\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  app = FastAPI(title="Visual Product API", lifespan=lifespan)
24
 
25
- # Serve product images
26
  app.mount(
27
  "/images",
28
  StaticFiles(directory=os.path.join(BASE_DIR, "data/images")),
29
- name="images"
30
  )
31
 
32
  app.add_middleware(
@@ -36,25 +83,37 @@ app.add_middleware(
36
  allow_headers=["*"],
37
  )
38
 
39
- store = VectorStore()
40
- PRODUCTS = {}
41
- SEARCH_BASE = "https://www.lenskart.com/search?q="
42
 
43
- # Load catalog safely
44
- with open(os.path.join(BASE_DIR, "data/metadata.csv"), newline="", encoding="utf-8") as f:
45
- for r in csv.DictReader(f):
46
- r["image"] = r["image"].replace("images/", "")
47
- r["image"] = r["image"].title()
48
- PRODUCTS[int(r["product_id"])] = r
 
49
 
50
- # Helper
51
- def tag_image(image):
52
  try:
53
  return store.classify(encode_image(image))
54
  except Exception:
 
55
  return None
56
 
57
- # Search
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  @app.post("/search")
59
  async def search(
60
  file: UploadFile = File(...),
@@ -62,62 +121,57 @@ async def search(
62
  max_price: int = Query(10000, ge=0),
63
  material: str | None = None,
64
  style: str | None = None,
65
- frame: str | None = None
66
  ):
67
- # Only guard image decoding – this is the most common real-world failure
68
  try:
69
  img = Image.open(file.file).convert("RGB")
70
  except Exception:
71
- raise HTTPException(status_code=400, detail="Invalid image file")
 
72
  q = encode_image(img)
73
 
74
  if not store.is_eyewear(q):
75
- raise HTTPException(
76
- status_code=400,
77
- detail="No eyewear detected. Please upload a glasses image."
78
- )
79
 
80
- tag = tag_image(img)
81
- raw = store.search(q, k=60)
82
  results = []
83
 
84
- for item, score in raw:
85
  pid = int(item["product_id"])
86
  p = PRODUCTS.get(pid)
87
  if not p:
88
  continue
89
 
90
  price = int(p["price"])
91
- if not (min_price <= price <= max_price): continue
92
- if material and p["material"] != material: continue
93
- if style and p["style"] != style: continue
94
- if frame and p["style"] != frame: continue
 
 
 
 
95
 
96
- r = {k: v for k, v in p.items() if k != "embedding"}
97
- r["score"] = score * get_boost(pid)
98
- r["image_url"] = f"{BASE_URL}/images/{r['image']}"
99
- r["buy_url"] = SEARCH_BASE + f"{r['style']} {r['material']} glasses under {r['price']}".replace(" ", "+")
100
- results.append(r)
101
 
102
  results.sort(key=lambda x: x["score"], reverse=True)
103
  return {"tag": tag, "results": results[:8]}
104
- # Product catalog
105
- @app.get("/products")
106
- def products():
107
- out = []
108
 
109
- for pid, p in PRODUCTS.items():
110
- r = {k: v for k, v in p.items() if k != "embedding"}
111
- r["boost"] = get_boost(pid)
112
- r["image_url"] = f"{BASE_URL}/images/{r['image']}"
113
- r["buy_url"] = SEARCH_BASE + f"{r['style']} {r['material']} glasses under {r['price']}".replace(" ", "+")
114
- out.append(r)
115
 
 
 
 
 
 
 
116
  out.sort(key=lambda x: x["boost"], reverse=True)
117
  return {"count": len(out), "results": out}
118
 
119
- # Click feedback
120
  @app.post("/click/{pid}")
121
  def click(pid: int):
122
  record_click(pid)
123
- return {"status": "ok"}
 
1
+ import logging
2
  import os
3
  import csv
4
  from contextlib import asynccontextmanager
5
+ from urllib.parse import quote_plus
6
+
7
  from fastapi import FastAPI, UploadFile, File, Query, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi.staticfiles import StaticFiles
 
14
  from backend.feedback import get_boost, record_click
15
  from backend.accuracy_test import run_accuracy_check
16
 
17
+ logger = logging.getLogger(__name__)
18
+
19
  BASE_DIR = os.path.dirname(__file__)
20
+ BASE_URL = os.getenv("BASE_URL", "https://diwank3221-visual-search-backend.hf.space")
21
+ SEARCH_BASE = "https://www.lenskart.com/search?q="
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Startup / shutdown
25
+ # ---------------------------------------------------------------------------
26
+
27
+ store: VectorStore
28
+ PRODUCTS: dict[int, dict] = {}
29
+
30
+
31
  @asynccontextmanager
32
  async def lifespan(app: FastAPI):
33
+ global store, PRODUCTS
34
+
35
+ # Initialise heavy resources once, at startup — not at import time
36
+ store = VectorStore()
37
+ PRODUCTS = _load_catalog(os.path.join(BASE_DIR, "data/metadata.csv"))
38
+
39
+ logger.info("Starting backend health check…")
40
+ # run_accuracy_check()
41
+ logger.info("Backend ready — %d products loaded.", len(PRODUCTS))
42
+
43
  yield
44
+ # (add any teardown here if needed)
45
+
46
+
47
+ def _load_catalog(path: str) -> dict[int, dict]:
48
+ """Read metadata.csv and return a dict keyed by product_id."""
49
+ products: dict[int, dict] = {}
50
+
51
+ with open(path, newline="", encoding="utf-8") as f:
52
+ for r in csv.DictReader(f):
53
+ try:
54
+ pid = int(r["product_id"])
55
+ except (KeyError, ValueError) as exc:
56
+ logger.warning("Skipping malformed CSV row (%s): %s", exc, r)
57
+ continue
58
+
59
+ # Strip any accidental "images/" prefix left in the CSV
60
+ r["image"] = r["image"].removeprefix("images/")
61
+ # NOTE: .title() was removed — it corrupts paths ("img.jpg" → "Img.Jpg")
62
+ products[pid] = r
63
+
64
+ return products
65
+
66
+
67
+ # ---------------------------------------------------------------------------
68
+ # App
69
+ # ---------------------------------------------------------------------------
70
 
71
  app = FastAPI(title="Visual Product API", lifespan=lifespan)
72
 
 
73
  app.mount(
74
  "/images",
75
  StaticFiles(directory=os.path.join(BASE_DIR, "data/images")),
76
+ name="images",
77
  )
78
 
79
  app.add_middleware(
 
83
  allow_headers=["*"],
84
  )
85
 
 
 
 
86
 
87
+ # ---------------------------------------------------------------------------
88
+ # Helpers
89
+ # ---------------------------------------------------------------------------
90
+
91
+ def _build_buy_url(style: str, material: str, price: str) -> str:
92
+ return SEARCH_BASE + quote_plus(f"{style} {material} glasses under {price}")
93
+
94
 
95
+ def _tag_image(image: Image.Image) -> str | None:
96
+ """Classify the uploaded image into a style tag; returns None on failure."""
97
  try:
98
  return store.classify(encode_image(image))
99
  except Exception:
100
+ logger.exception("tag_image failed")
101
  return None
102
 
103
+
104
+ def _format_product(pid: int, p: dict, score: float | None = None) -> dict:
105
+ r = {k: v for k, v in p.items() if k != "embedding"}
106
+ r["image_url"] = f"{BASE_URL}/images/{p['image']}"
107
+ r["buy_url"] = _build_buy_url(p["style"], p["material"], p["price"])
108
+ if score is not None:
109
+ r["score"] = score
110
+ return r
111
+
112
+
113
+ # ---------------------------------------------------------------------------
114
+ # Routes
115
+ # ---------------------------------------------------------------------------
116
+
117
  @app.post("/search")
118
  async def search(
119
  file: UploadFile = File(...),
 
121
  max_price: int = Query(10000, ge=0),
122
  material: str | None = None,
123
  style: str | None = None,
124
+ frame: str | None = None,
125
  ):
 
126
  try:
127
  img = Image.open(file.file).convert("RGB")
128
  except Exception:
129
+ raise HTTPException(status_code=400, detail="Invalid image file.")
130
+
131
  q = encode_image(img)
132
 
133
  if not store.is_eyewear(q):
134
+ raise HTTPException(
135
+ status_code=400,
136
+ detail="No eyewear detected. Please upload a glasses image.",
137
+ )
138
 
139
+ tag = _tag_image(img)
 
140
  results = []
141
 
142
+ for item, score in store.search(q, k=60):
143
  pid = int(item["product_id"])
144
  p = PRODUCTS.get(pid)
145
  if not p:
146
  continue
147
 
148
  price = int(p["price"])
149
+ if not (min_price <= price <= max_price):
150
+ continue
151
+ if material and p["material"] != material:
152
+ continue
153
+ if style and p["style"] != style:
154
+ continue
155
+ if frame and p["style"] != frame:
156
+ continue
157
 
158
+ results.append(_format_product(pid, p, score=score * get_boost(pid)))
 
 
 
 
159
 
160
  results.sort(key=lambda x: x["score"], reverse=True)
161
  return {"tag": tag, "results": results[:8]}
 
 
 
 
162
 
 
 
 
 
 
 
163
 
164
+ @app.get("/products")
165
+ def products():
166
+ out = [
167
+ {**_format_product(pid, p), "boost": get_boost(pid)}
168
+ for pid, p in PRODUCTS.items()
169
+ ]
170
  out.sort(key=lambda x: x["boost"], reverse=True)
171
  return {"count": len(out), "results": out}
172
 
173
+
174
  @app.post("/click/{pid}")
175
  def click(pid: int):
176
  record_click(pid)
177
+ return {"status": "ok"}
backend/store.py CHANGED
@@ -1,82 +1,123 @@
1
  import csv
 
2
  import os
 
 
3
  import faiss
4
  import numpy as np
5
- from collections import defaultdict
 
 
6
  BASE_DIR = os.path.dirname(__file__)
7
  META_FILE = os.path.join(BASE_DIR, "data/metadata.csv")
8
  INDEX_FILE = os.path.join(BASE_DIR, "index.faiss")
9
  CENTROID_FILE = os.path.join(BASE_DIR, "centroids.npy")
10
- # CLIP embedding dimension (fixed)
11
- DIM = 512
 
 
 
 
 
 
 
 
 
 
 
12
  class VectorStore:
13
  """
14
- FAISS based vector store for:
15
- - similarity search
16
- - rough style classification
 
17
  """
18
- def is_eyewear(self, q, threshold=0.53):
19
- q = q / np.linalg.norm(q, axis=1, keepdims=True)
20
 
21
- scores = [float(np.dot(q[0], c)) for c in self.centroids.values()]
22
- return max(scores) > threshold
23
 
 
 
 
 
24
 
25
- def __init__(self):
26
- self.meta = self._load_meta()
27
- # Load cached index if available, otherwise build fresh
28
  if os.path.exists(INDEX_FILE):
29
- self.index = faiss.read_index(INDEX_FILE)
 
30
  else:
31
- self._build_index()
32
- # Pre-compute normalized centroids for tagging
33
- self._build_centroids()
34
- def _load_meta(self):
 
 
 
 
 
 
 
35
  with open(META_FILE, newline="", encoding="utf-8") as f:
36
  rows = list(csv.DictReader(f))
37
- print(f"Loaded {len(rows)} products from metadata")
38
- return rows
39
 
40
- def _build_index(self):
41
  vectors = []
42
  for r in self.meta:
43
  v = np.fromstring(r["embedding"], sep=" ").astype("float32")
44
- v = v / np.linalg.norm(v) # normalize stored vectors
45
- vectors.append(v)
46
 
47
- vectors = np.vstack(vectors)
48
  self.index = faiss.IndexFlatIP(DIM)
49
- self.index.add(vectors)
50
  faiss.write_index(self.index, INDEX_FILE)
 
51
 
52
- print(f"[VectorStore] Built FAISS index with {self.index.ntotal} vectors")
53
-
54
- def _build_centroids(self):
55
- clusters = defaultdict(list)
56
  for r in self.meta:
57
  v = np.fromstring(r["embedding"], sep=" ").astype("float32")
58
- v = v / np.linalg.norm(v) # normalize before clustering
59
- clusters[r["style"]].append(v)
60
- self.centroids = {}
61
  for style, vecs in clusters.items():
62
  c = np.mean(vecs, axis=0)
63
- c = c / np.linalg.norm(c) # normalize centroid
64
- self.centroids[style] = c
65
- np.save(CENTROID_FILE, self.centroids)
66
- print(f"Built {len(self.centroids)} normalized style centroids")
67
-
68
- def search(self, q, k=40):
69
- # Normalize query to match index math
70
- q = q / np.linalg.norm(q, axis=1, keepdims=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  D, I = self.index.search(q, k)
72
- return [(self.meta[i], float(score)) for i, score in zip(I[0], D[0])]
73
-
74
- def classify(self, q):
75
- q = q / np.linalg.norm(q, axis=1, keepdims=True)
76
- best_style, best_score = None, -1
77
- for style, centroid in self.centroids.items():
78
- score = float(np.dot(q[0], centroid))
79
- if score > best_score:
80
- best_style, best_score = style, score
81
-
82
- return best_style
 
 
 
 
 
1
  import csv
2
+ import logging
3
  import os
4
+ from collections import defaultdict
5
+
6
  import faiss
7
  import numpy as np
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
  BASE_DIR = os.path.dirname(__file__)
12
  META_FILE = os.path.join(BASE_DIR, "data/metadata.csv")
13
  INDEX_FILE = os.path.join(BASE_DIR, "index.faiss")
14
  CENTROID_FILE = os.path.join(BASE_DIR, "centroids.npy")
15
+
16
+ # CLIP embedding dimension (fixed for patrickjohncyh/fashion-clip)
17
+ DIM = 768
18
+ # Minimum cosine similarity to the nearest style centroid for a query
19
+ # to be considered eyewear. Calibrated empirically on the current dataset.
20
+ _EYEWEAR_THRESHOLD = 0.53
21
+
22
+
23
+ def _normalise(x: np.ndarray) -> np.ndarray:
24
+ """L2-normalise each row of a 2-D array in-place-free fashion."""
25
+ return x / np.linalg.norm(x, axis=-1, keepdims=True)
26
+
27
+
28
  class VectorStore:
29
  """
30
+ FAISS-backed vector store providing:
31
+ - k-NN similarity search over product embeddings
32
+ - coarse style classification via centroid cosine similarity
33
+ - eyewear gating to reject non-glasses queries
34
  """
 
 
35
 
36
+ def __init__(self) -> None:
37
+ self.meta = self._load_meta()
38
 
39
+ if not self.meta:
40
+ raise RuntimeError(
41
+ f"No products found in {META_FILE}. Run ingest.py first."
42
+ )
43
 
 
 
 
44
  if os.path.exists(INDEX_FILE):
45
+ self.index = faiss.read_index(INDEX_FILE)
46
+ logger.info("Loaded FAISS index from cache (%d vectors).", self.index.ntotal)
47
  else:
48
+ self._build_index()
49
+
50
+ # Always rebuild centroids from the CSV so they stay in sync with
51
+ # the metadata even when the FAISS index is loaded from cache.
52
+ self._build_centroids()
53
+
54
+ # ------------------------------------------------------------------
55
+ # Internal helpers
56
+ # ------------------------------------------------------------------
57
+
58
+ def _load_meta(self) -> list[dict]:
59
  with open(META_FILE, newline="", encoding="utf-8") as f:
60
  rows = list(csv.DictReader(f))
61
+ logger.info("Loaded %d products from metadata.", len(rows))
62
+ return rows
63
 
64
+ def _build_index(self) -> None:
65
  vectors = []
66
  for r in self.meta:
67
  v = np.fromstring(r["embedding"], sep=" ").astype("float32")
68
+ vectors.append(_normalise(v[np.newaxis])[0])
 
69
 
70
+ matrix = np.vstack(vectors)
71
  self.index = faiss.IndexFlatIP(DIM)
72
+ self.index.add(matrix)
73
  faiss.write_index(self.index, INDEX_FILE)
74
+ logger.info("Built FAISS index with %d vectors.", self.index.ntotal)
75
 
76
+ def _build_centroids(self) -> None:
77
+ clusters: dict[str, list[np.ndarray]] = defaultdict(list)
 
 
78
  for r in self.meta:
79
  v = np.fromstring(r["embedding"], sep=" ").astype("float32")
80
+ clusters[r["style"]].append(_normalise(v[np.newaxis])[0])
81
+
82
+ self.centroids: dict[str, np.ndarray] = {}
83
  for style, vecs in clusters.items():
84
  c = np.mean(vecs, axis=0)
85
+ self.centroids[style] = _normalise(c[np.newaxis])[0]
86
+
87
+ logger.info("Built %d style centroids.", len(self.centroids))
88
+
89
+ # ------------------------------------------------------------------
90
+ # Public API
91
+ # ------------------------------------------------------------------
92
+
93
+ def is_eyewear(self, q: np.ndarray, threshold: float = _EYEWEAR_THRESHOLD) -> bool:
94
+ """Return True if *q* is close enough to any style centroid."""
95
+ q = _normalise(q)
96
+ scores = [float(np.dot(q[0], c)) for c in self.centroids.values()]
97
+ return bool(scores) and max(scores) > threshold
98
+
99
+ def search(self, q: np.ndarray, k: int = 40) -> list[tuple[dict, float]]:
100
+ """
101
+ Return the *k* most similar products to query *q*.
102
+
103
+ Filters out FAISS sentinel index -1, which is returned when
104
+ k > index.ntotal (not enough vectors to fill the result set).
105
+ """
106
+ q = _normalise(q)
107
+ k = min(k, self.index.ntotal) # guard: k must not exceed index size
108
  D, I = self.index.search(q, k)
109
+ return [
110
+ (self.meta[i], float(score))
111
+ for i, score in zip(I[0], D[0])
112
+ if i != -1 # FAISS sentinel for unfilled slots
113
+ ]
114
+
115
+ def classify(self, q: np.ndarray) -> str | None:
116
+ """Return the style name whose centroid is closest to *q*."""
117
+ if not self.centroids:
118
+ return None
119
+ q = _normalise(q)
120
+ return max(
121
+ self.centroids,
122
+ key=lambda style: float(np.dot(q[0], self.centroids[style])),
123
+ )