Commit ·
d992912
0
Parent(s):
feat: HuggingFace Spaces deployment
Browse files- .dockerignore +18 -0
- .gitattributes +3 -0
- Dockerfile +36 -0
- asos_clean.csv +3 -0
- asos_engine/faiss_image_index.bin +3 -0
- asos_engine/faiss_text_index.bin +3 -0
- asos_engine/image_embeddings.npy +3 -0
- asos_engine/text_embeddings.npy +3 -0
- backend/.dockerignore +6 -0
- backend/Dockerfile +23 -0
- backend/app/__init__.py +1 -0
- backend/app/config.py +161 -0
- backend/app/dependencies.py +12 -0
- backend/app/engine/__init__.py +14 -0
- backend/app/engine/bm25.py +50 -0
- backend/app/engine/encoder.py +195 -0
- backend/app/engine/evaluator.py +111 -0
- backend/app/engine/index.py +127 -0
- backend/app/engine/nlp.py +307 -0
- backend/app/engine/query_parser.py +299 -0
- backend/app/engine/reranker.py +387 -0
- backend/app/engine/search_engine.py +636 -0
- backend/app/exceptions.py +16 -0
- backend/app/main.py +88 -0
- backend/app/models/__init__.py +8 -0
- backend/app/models/product.py +53 -0
- backend/app/models/search.py +63 -0
- backend/app/routers/__init__.py +0 -0
- backend/app/routers/health.py +30 -0
- backend/app/routers/products.py +27 -0
- backend/app/routers/search.py +60 -0
- backend/app/services/__init__.py +0 -0
- backend/app/services/search_service.py +325 -0
- backend/pyproject.toml +9 -0
- backend/requirements.txt +14 -0
- backend/tests/__init__.py +0 -0
- backend/tests/conftest.py +7 -0
- backend/tests/test_api_health.py +47 -0
- backend/tests/test_api_products.py +92 -0
- backend/tests/test_api_search.py +108 -0
- backend/tests/test_bm25.py +51 -0
- backend/tests/test_nlp.py +83 -0
- backend/tests/test_query_parser.py +194 -0
.dockerignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git/
|
| 2 |
+
docs/
|
| 3 |
+
frontend/node_modules/
|
| 4 |
+
frontend/.next/
|
| 5 |
+
asos_image_cache/
|
| 6 |
+
__pycache__/
|
| 7 |
+
**/__pycache__/
|
| 8 |
+
*.pyc
|
| 9 |
+
.pytest_cache/
|
| 10 |
+
backend/tests/
|
| 11 |
+
.venv/
|
| 12 |
+
*.ipynb
|
| 13 |
+
.ipynb_checkpoints/
|
| 14 |
+
EDA_Visualization.md
|
| 15 |
+
Preprocessing.md
|
| 16 |
+
README.md
|
| 17 |
+
.env
|
| 18 |
+
.env.local
|
.gitattributes
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# System packages needed by PyTorch, FAISS, and Pillow
|
| 4 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 5 |
+
gcc g++ libgomp1 && \
|
| 6 |
+
rm -rf /var/lib/apt/lists/*
|
| 7 |
+
|
| 8 |
+
# Install Python dependencies as root so they are globally accessible
|
| 9 |
+
COPY backend/requirements.txt /tmp/requirements.txt
|
| 10 |
+
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
|
| 11 |
+
|
| 12 |
+
# HuggingFace Spaces requires a non-root user with uid=1000
|
| 13 |
+
RUN useradd -m -u 1000 user
|
| 14 |
+
USER user
|
| 15 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 16 |
+
|
| 17 |
+
WORKDIR /app
|
| 18 |
+
|
| 19 |
+
# Copy backend Python source
|
| 20 |
+
COPY --chown=user backend/ /app/backend/
|
| 21 |
+
|
| 22 |
+
# Bake the pre-built data into the image.
|
| 23 |
+
# This avoids a 3-5 hour FAISS rebuild on free CPU on every cold start.
|
| 24 |
+
COPY --chown=user asos_clean.csv /app/data/asos_clean.csv
|
| 25 |
+
COPY --chown=user asos_engine/ /app/data/asos_engine/
|
| 26 |
+
|
| 27 |
+
# Tell the backend where to find its data
|
| 28 |
+
ENV PYTHONPATH=/app
|
| 29 |
+
ENV ASOS_DATA_PATH=/app/data/asos_clean.csv
|
| 30 |
+
ENV ASOS_PERSISTENT_DIR=/app/data/asos_engine
|
| 31 |
+
ENV ASOS_LOG_LEVEL=INFO
|
| 32 |
+
|
| 33 |
+
# HuggingFace Spaces requires port 7860
|
| 34 |
+
EXPOSE 7860
|
| 35 |
+
|
| 36 |
+
CMD ["uvicorn", "backend.app.main:app", "--host", "0.0.0.0", "--port", "7860"]
|
asos_clean.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:674c28a2d26a52a612cfb59a3011b23dc4dd8aa81f978f840045973b322e4c92
|
| 3 |
+
size 53197197
|
asos_engine/faiss_image_index.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37a71cff75158cf78b85ac79626fbfec30dad880e447bee745510053b7e4e41d
|
| 3 |
+
size 62146851
|
asos_engine/faiss_text_index.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a6f484a66cbde5ac53f8e83e3ed6347d8b016d891539c0acc77b30adea38eb1b
|
| 3 |
+
size 62146851
|
asos_engine/image_embeddings.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c1c990cfb083b033cb13e969040396f79a71e29c11e2380900a085e9115320a3
|
| 3 |
+
size 61380736
|
asos_engine/text_embeddings.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fee77c8c548dab94ee71cc2b30279bbb4db0aceb4867cb18f00f8deb833a2623
|
| 3 |
+
size 61380736
|
backend/.dockerignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
*.pyc
|
| 3 |
+
.pytest_cache
|
| 4 |
+
tests/
|
| 5 |
+
*.npy
|
| 6 |
+
*.bin
|
backend/Dockerfile
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system dependencies for PIL, torch, faiss
|
| 6 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
+
gcc g++ libgomp1 && \
|
| 8 |
+
rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
COPY requirements.txt .
|
| 11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 12 |
+
|
| 13 |
+
# Copy backend source (data files are mounted at runtime)
|
| 14 |
+
COPY . /app/backend/
|
| 15 |
+
|
| 16 |
+
ENV PYTHONPATH=/app
|
| 17 |
+
ENV ASOS_DATA_PATH=/data/asos_clean.csv
|
| 18 |
+
ENV ASOS_PERSISTENT_DIR=/data/asos_engine
|
| 19 |
+
|
| 20 |
+
# PORT is injected by Render (~10000) and HuggingFace Spaces (7860); falls back to 8000 locally
|
| 21 |
+
EXPOSE ${PORT:-8000}
|
| 22 |
+
|
| 23 |
+
CMD ["sh", "-c", "uvicorn backend.app.main:app --host 0.0.0.0 --port ${PORT:-8000}"]
|
backend/app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""ASOS Multimodal Fashion Search Engine Backend."""
|
backend/app/config.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from pydantic_settings import BaseSettings
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger("asos_search")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _detect_environment() -> str:
|
| 15 |
+
if "google.colab" in sys.modules:
|
| 16 |
+
return "colab"
|
| 17 |
+
if "KAGGLE_KERNEL_RUN_TYPE" in os.environ:
|
| 18 |
+
return "kaggle"
|
| 19 |
+
return "local"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Settings(BaseSettings):
|
| 23 |
+
"""Server-level settings loaded from environment variables."""
|
| 24 |
+
|
| 25 |
+
host: str = "0.0.0.0"
|
| 26 |
+
port: int = 8000
|
| 27 |
+
cors_origins: list[str] = ["*"]
|
| 28 |
+
data_dir: str = ""
|
| 29 |
+
data_path: str = ""
|
| 30 |
+
persistent_dir: str = ""
|
| 31 |
+
image_cache_dir: str = ""
|
| 32 |
+
log_level: str = "INFO"
|
| 33 |
+
hf_token: Optional[str] = None
|
| 34 |
+
|
| 35 |
+
model_config = {"env_prefix": "ASOS_"}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class SearchConfig:
|
| 40 |
+
"""Central configuration for the search engine."""
|
| 41 |
+
|
| 42 |
+
# Model
|
| 43 |
+
primary_model: str = "patrickjohncyh/fashion-clip"
|
| 44 |
+
fallback_model: str = "openai/clip-vit-base-patch32"
|
| 45 |
+
embedding_dim: int = 512
|
| 46 |
+
device: str = ""
|
| 47 |
+
hf_token: Optional[str] = None
|
| 48 |
+
|
| 49 |
+
# FAISS Index
|
| 50 |
+
n_clusters: int = 256
|
| 51 |
+
n_probe: int = 20
|
| 52 |
+
|
| 53 |
+
# Search Pipeline
|
| 54 |
+
retrieval_top_k: int = 300
|
| 55 |
+
final_top_n: int = 20
|
| 56 |
+
|
| 57 |
+
# Dual-Index Fusion
|
| 58 |
+
rrf_k: int = 60
|
| 59 |
+
image_index_weight: float = 0.55
|
| 60 |
+
text_index_weight: float = 0.45
|
| 61 |
+
|
| 62 |
+
# Re-ranking Weights
|
| 63 |
+
alpha_clip: float = 0.55
|
| 64 |
+
beta_tags: float = 0.25
|
| 65 |
+
gamma_text: float = 0.15
|
| 66 |
+
delta_freshness: float = 0.05
|
| 67 |
+
|
| 68 |
+
# CLIP Prompt Ensembling
|
| 69 |
+
prompt_templates: Tuple[str, ...] = (
|
| 70 |
+
"a photo of {}, a fashion product",
|
| 71 |
+
"a product photo of {}",
|
| 72 |
+
"a fashion item: {}",
|
| 73 |
+
"{}, studio product photography",
|
| 74 |
+
"an e-commerce photo of {}",
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Embedding Computation
|
| 78 |
+
embed_batch_size: int = 32
|
| 79 |
+
embed_checkpoint_interval: int = 2000
|
| 80 |
+
|
| 81 |
+
# Features
|
| 82 |
+
enable_multilingual: bool = True
|
| 83 |
+
enable_spell_correction: bool = True
|
| 84 |
+
|
| 85 |
+
# Paths (auto-detected)
|
| 86 |
+
data_dir: str = ""
|
| 87 |
+
data_path: str = ""
|
| 88 |
+
persistent_dir: str = ""
|
| 89 |
+
image_cache_dir: str = ""
|
| 90 |
+
|
| 91 |
+
# Derived Paths
|
| 92 |
+
image_index_path: str = ""
|
| 93 |
+
text_index_path: str = ""
|
| 94 |
+
image_embeddings_path: str = ""
|
| 95 |
+
text_embeddings_path: str = ""
|
| 96 |
+
|
| 97 |
+
def __post_init__(self):
|
| 98 |
+
if not self.device:
|
| 99 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 100 |
+
if not self.hf_token:
|
| 101 |
+
self.hf_token = os.environ.get("HF_TOKEN", None)
|
| 102 |
+
|
| 103 |
+
env = _detect_environment()
|
| 104 |
+
|
| 105 |
+
if env == "colab":
|
| 106 |
+
drive_base = "/content/drive/MyDrive/Colab Notebooks"
|
| 107 |
+
if not self.data_dir:
|
| 108 |
+
self.data_dir = drive_base
|
| 109 |
+
if not self.persistent_dir:
|
| 110 |
+
self.persistent_dir = os.path.join(drive_base, "asos_engine")
|
| 111 |
+
if not self.image_cache_dir:
|
| 112 |
+
self.image_cache_dir = "/content/asos_image_cache"
|
| 113 |
+
elif env == "kaggle":
|
| 114 |
+
if not self.data_dir:
|
| 115 |
+
self.data_dir = "/kaggle/input"
|
| 116 |
+
if not self.persistent_dir:
|
| 117 |
+
self.persistent_dir = "/kaggle/working/asos_engine"
|
| 118 |
+
if not self.image_cache_dir:
|
| 119 |
+
self.image_cache_dir = "/kaggle/working/asos_image_cache"
|
| 120 |
+
else:
|
| 121 |
+
project_root = str(Path(__file__).resolve().parent.parent.parent)
|
| 122 |
+
if not self.data_dir:
|
| 123 |
+
self.data_dir = project_root
|
| 124 |
+
if not self.persistent_dir:
|
| 125 |
+
self.persistent_dir = os.path.join(project_root, "asos_engine")
|
| 126 |
+
if not self.image_cache_dir:
|
| 127 |
+
self.image_cache_dir = os.path.join(project_root, "asos_image_cache")
|
| 128 |
+
|
| 129 |
+
if not self.data_path:
|
| 130 |
+
pq = Path(self.data_dir) / "asos_clean.parquet"
|
| 131 |
+
csv = Path(self.data_dir) / "asos_clean.csv"
|
| 132 |
+
if pq.exists():
|
| 133 |
+
self.data_path = str(pq)
|
| 134 |
+
elif csv.exists():
|
| 135 |
+
self.data_path = str(csv)
|
| 136 |
+
else:
|
| 137 |
+
self.data_path = str(csv)
|
| 138 |
+
|
| 139 |
+
Path(self.persistent_dir).mkdir(parents=True, exist_ok=True)
|
| 140 |
+
|
| 141 |
+
p = Path(self.persistent_dir)
|
| 142 |
+
self.image_index_path = str(p / "faiss_image_index.bin")
|
| 143 |
+
self.text_index_path = str(p / "faiss_text_index.bin")
|
| 144 |
+
self.image_embeddings_path = str(p / "image_embeddings.npy")
|
| 145 |
+
self.text_embeddings_path = str(p / "text_embeddings.npy")
|
| 146 |
+
|
| 147 |
+
@classmethod
|
| 148 |
+
def from_settings(cls, settings: Settings) -> "SearchConfig":
|
| 149 |
+
"""Create SearchConfig from server Settings, allowing env overrides."""
|
| 150 |
+
kwargs = {}
|
| 151 |
+
if settings.data_dir:
|
| 152 |
+
kwargs["data_dir"] = settings.data_dir
|
| 153 |
+
if settings.data_path:
|
| 154 |
+
kwargs["data_path"] = settings.data_path
|
| 155 |
+
if settings.persistent_dir:
|
| 156 |
+
kwargs["persistent_dir"] = settings.persistent_dir
|
| 157 |
+
if settings.image_cache_dir:
|
| 158 |
+
kwargs["image_cache_dir"] = settings.image_cache_dir
|
| 159 |
+
if settings.hf_token:
|
| 160 |
+
kwargs["hf_token"] = settings.hf_token
|
| 161 |
+
return cls(**kwargs)
|
backend/app/dependencies.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Request
|
| 2 |
+
|
| 3 |
+
from backend.app.engine.search_engine import ASOSSearchEngine
|
| 4 |
+
from backend.app.exceptions import EngineNotReadyError
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_engine(request: Request) -> ASOSSearchEngine:
|
| 8 |
+
"""FastAPI dependency: retrieve the engine singleton from app state."""
|
| 9 |
+
engine: ASOSSearchEngine = getattr(request.app.state, "engine", None)
|
| 10 |
+
if engine is None or not engine._is_ready:
|
| 11 |
+
raise EngineNotReadyError("Search engine is not ready")
|
| 12 |
+
return engine
|
backend/app/engine/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ASOS Search Engine core modules."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def __getattr__(name):
|
| 5 |
+
if name == "ASOSSearchEngine":
|
| 6 |
+
from backend.app.engine.search_engine import ASOSSearchEngine
|
| 7 |
+
return ASOSSearchEngine
|
| 8 |
+
if name == "SearchEvaluator":
|
| 9 |
+
from backend.app.engine.evaluator import SearchEvaluator
|
| 10 |
+
return SearchEvaluator
|
| 11 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
__all__ = ["ASOSSearchEngine", "SearchEvaluator"]
|
backend/app/engine/bm25.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from collections import Counter
|
| 3 |
+
from typing import Dict, List
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
__all__ = ["SimpleBM25"]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SimpleBM25:
|
| 11 |
+
def __init__(self, k1: float = 1.5, b: float = 0.75):
|
| 12 |
+
self.k1 = k1
|
| 13 |
+
self.b = b
|
| 14 |
+
self.doc_tokens: List[List[str]] = []
|
| 15 |
+
self.avg_dl: float = 0
|
| 16 |
+
self.df: Dict[str, int] = {}
|
| 17 |
+
self.n_docs: int = 0
|
| 18 |
+
|
| 19 |
+
def fit(self, documents: List[str]):
|
| 20 |
+
self.doc_tokens = [self._tokenize(d) for d in documents]
|
| 21 |
+
self.n_docs = len(self.doc_tokens)
|
| 22 |
+
self.avg_dl = np.mean([len(t) for t in self.doc_tokens]) if self.doc_tokens else 1
|
| 23 |
+
self.df = Counter()
|
| 24 |
+
for tokens in self.doc_tokens:
|
| 25 |
+
for t in set(tokens):
|
| 26 |
+
self.df[t] += 1
|
| 27 |
+
|
| 28 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 29 |
+
return re.findall(r'\b[a-z]+\b', str(text).lower())
|
| 30 |
+
|
| 31 |
+
def score_candidates(self, query: str, candidate_indices: List[int]) -> np.ndarray:
|
| 32 |
+
q_tokens = self._tokenize(query)
|
| 33 |
+
scores = np.zeros(len(candidate_indices), dtype=np.float32)
|
| 34 |
+
for i, doc_idx in enumerate(candidate_indices):
|
| 35 |
+
if doc_idx >= len(self.doc_tokens):
|
| 36 |
+
continue
|
| 37 |
+
doc = self.doc_tokens[doc_idx]
|
| 38 |
+
dl = len(doc)
|
| 39 |
+
tf_doc = Counter(doc)
|
| 40 |
+
s = 0.0
|
| 41 |
+
for qt in q_tokens:
|
| 42 |
+
if qt not in self.df:
|
| 43 |
+
continue
|
| 44 |
+
tf = tf_doc.get(qt, 0)
|
| 45 |
+
idf = np.log((self.n_docs - self.df[qt] + 0.5) / (self.df[qt] + 0.5) + 1)
|
| 46 |
+
num = tf * (self.k1 + 1)
|
| 47 |
+
den = tf + self.k1 * (1 - self.b + self.b * dl / self.avg_dl)
|
| 48 |
+
s += idf * num / den
|
| 49 |
+
scores[i] = s
|
| 50 |
+
return scores
|
backend/app/engine/encoder.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
engine/encoder.py
|
| 3 |
+
|
| 4 |
+
FashionCLIPEncoder — wraps a HuggingFace CLIP model for text and image encoding.
|
| 5 |
+
Extracted from finalized_search_engine_full_script.py (lines 482-652).
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Optional
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from transformers import CLIPModel, CLIPProcessor
|
| 17 |
+
|
| 18 |
+
from backend.app.config import SearchConfig
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
__all__ = ["FashionCLIPEncoder"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class FashionCLIPEncoder:
|
| 26 |
+
"""
|
| 27 |
+
v3.1 — Handles models that return BaseModelOutputWithPooling
|
| 28 |
+
instead of raw tensors from get_text_features / get_image_features.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, config: SearchConfig):
|
| 32 |
+
self.config = config
|
| 33 |
+
self.device = config.device
|
| 34 |
+
self.model = None
|
| 35 |
+
self.processor = None
|
| 36 |
+
self.model_name = None
|
| 37 |
+
self._load_model()
|
| 38 |
+
|
| 39 |
+
def _load_model(self):
|
| 40 |
+
models_to_try = [self.config.primary_model, self.config.fallback_model]
|
| 41 |
+
for model_name in models_to_try:
|
| 42 |
+
try:
|
| 43 |
+
logger.info(f"Loading model: {model_name}")
|
| 44 |
+
kwargs = {}
|
| 45 |
+
if self.config.hf_token:
|
| 46 |
+
kwargs['token'] = self.config.hf_token
|
| 47 |
+
self.model = CLIPModel.from_pretrained(model_name, **kwargs)
|
| 48 |
+
self.processor = CLIPProcessor.from_pretrained(model_name, **kwargs)
|
| 49 |
+
self.model = self.model.to(self.device)
|
| 50 |
+
self.model.eval()
|
| 51 |
+
self.model_name = model_name
|
| 52 |
+
|
| 53 |
+
# ── Probe the model to find actual embedding dim ──
|
| 54 |
+
test_inputs = self.processor(
|
| 55 |
+
text=["test"], return_tensors="pt",
|
| 56 |
+
padding=True, truncation=True, max_length=77,
|
| 57 |
+
)
|
| 58 |
+
test_inputs = {k: v.to(self.device) for k, v in test_inputs.items()}
|
| 59 |
+
with torch.no_grad():
|
| 60 |
+
test_out = self.model.get_text_features(**test_inputs)
|
| 61 |
+
test_tensor = self._to_tensor(test_out)
|
| 62 |
+
actual_dim = test_tensor.shape[-1]
|
| 63 |
+
if actual_dim != self.config.embedding_dim:
|
| 64 |
+
logger.info(
|
| 65 |
+
f"Model embedding dim = {actual_dim} "
|
| 66 |
+
f"(config said {self.config.embedding_dim}). Updating config."
|
| 67 |
+
)
|
| 68 |
+
self.config.embedding_dim = actual_dim
|
| 69 |
+
|
| 70 |
+
logger.info(f"Model loaded: {model_name} on {self.device} (dim={actual_dim})")
|
| 71 |
+
return
|
| 72 |
+
except Exception as e:
|
| 73 |
+
logger.warning(f"Failed to load {model_name}: {e}")
|
| 74 |
+
continue
|
| 75 |
+
raise RuntimeError(
|
| 76 |
+
"Could not load any CLIP model. Check internet connection and HF_TOKEN."
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
@staticmethod
|
| 80 |
+
def _to_tensor(output) -> torch.Tensor:
|
| 81 |
+
if isinstance(output, torch.Tensor):
|
| 82 |
+
return output
|
| 83 |
+
if hasattr(output, 'pooler_output') and output.pooler_output is not None:
|
| 84 |
+
return output.pooler_output
|
| 85 |
+
if hasattr(output, 'last_hidden_state'):
|
| 86 |
+
return output.last_hidden_state.mean(dim=1)
|
| 87 |
+
if hasattr(output, 'text_embeds'):
|
| 88 |
+
return output.text_embeds
|
| 89 |
+
if hasattr(output, 'image_embeds'):
|
| 90 |
+
return output.image_embeds
|
| 91 |
+
if isinstance(output, (tuple, list)) and len(output) > 0:
|
| 92 |
+
return output[0] if isinstance(output[0], torch.Tensor) else output[1]
|
| 93 |
+
raise TypeError(
|
| 94 |
+
f"Cannot extract tensor from model output of type {type(output)}. "
|
| 95 |
+
f"Available attributes: {[a for a in dir(output) if not a.startswith('_')]}"
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
@torch.no_grad()
|
| 99 |
+
def encode_texts(self, texts: List[str], batch_size: Optional[int] = None) -> np.ndarray:
|
| 100 |
+
batch_size = batch_size or min(self.config.embed_batch_size * 4, 256)
|
| 101 |
+
texts = [str(t) if t and str(t) != 'nan' else '' for t in texts]
|
| 102 |
+
all_emb = []
|
| 103 |
+
for i in range(0, len(texts), batch_size):
|
| 104 |
+
batch = texts[i:i + batch_size]
|
| 105 |
+
inputs = self.processor(
|
| 106 |
+
text=batch, return_tensors="pt",
|
| 107 |
+
padding=True, truncation=True, max_length=77,
|
| 108 |
+
)
|
| 109 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 110 |
+
raw = self.model.get_text_features(**inputs)
|
| 111 |
+
feats = self._to_tensor(raw)
|
| 112 |
+
feats = F.normalize(feats, p=2, dim=-1).cpu().numpy()
|
| 113 |
+
all_emb.append(feats)
|
| 114 |
+
return np.vstack(all_emb).astype(np.float32)
|
| 115 |
+
|
| 116 |
+
@torch.no_grad()
|
| 117 |
+
def encode_images_from_paths(
|
| 118 |
+
self, paths: List[Path], batch_size: Optional[int] = None,
|
| 119 |
+
) -> np.ndarray:
|
| 120 |
+
batch_size = batch_size or self.config.embed_batch_size
|
| 121 |
+
n = len(paths)
|
| 122 |
+
dim = self.config.embedding_dim
|
| 123 |
+
embeddings = np.zeros((n, dim), dtype=np.float32)
|
| 124 |
+
|
| 125 |
+
for start in range(0, n, batch_size):
|
| 126 |
+
end = min(start + batch_size, n)
|
| 127 |
+
batch_paths = paths[start:end]
|
| 128 |
+
|
| 129 |
+
images = []
|
| 130 |
+
valid_in_batch = []
|
| 131 |
+
for j, p in enumerate(batch_paths):
|
| 132 |
+
try:
|
| 133 |
+
img = Image.open(p).convert("RGB")
|
| 134 |
+
images.append(img)
|
| 135 |
+
valid_in_batch.append(start + j)
|
| 136 |
+
except Exception:
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
if not images:
|
| 140 |
+
continue
|
| 141 |
+
|
| 142 |
+
try:
|
| 143 |
+
inputs = self.processor(images=images, return_tensors="pt", padding=True)
|
| 144 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 145 |
+
if self.device == "cuda":
|
| 146 |
+
with torch.amp.autocast("cuda"):
|
| 147 |
+
raw = self.model.get_image_features(**inputs)
|
| 148 |
+
else:
|
| 149 |
+
raw = self.model.get_image_features(**inputs)
|
| 150 |
+
feats = self._to_tensor(raw)
|
| 151 |
+
feats = F.normalize(feats, p=2, dim=-1).cpu().numpy()
|
| 152 |
+
for local_j, global_j in enumerate(valid_in_batch):
|
| 153 |
+
embeddings[global_j] = feats[local_j]
|
| 154 |
+
except Exception as e:
|
| 155 |
+
logger.warning(f"Batch encoding failed at {start}: {e}")
|
| 156 |
+
|
| 157 |
+
if self.device == "cuda" and start % (batch_size * 10) == 0:
|
| 158 |
+
torch.cuda.empty_cache()
|
| 159 |
+
|
| 160 |
+
return embeddings
|
| 161 |
+
|
| 162 |
+
@torch.no_grad()
|
| 163 |
+
def encode_images(self, images: List[Image.Image], batch_size: Optional[int] = None) -> np.ndarray:
|
| 164 |
+
batch_size = batch_size or self.config.embed_batch_size
|
| 165 |
+
all_emb = []
|
| 166 |
+
for i in range(0, len(images), batch_size):
|
| 167 |
+
batch = images[i:i + batch_size]
|
| 168 |
+
inputs = self.processor(images=batch, return_tensors="pt", padding=True)
|
| 169 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 170 |
+
if self.device == "cuda":
|
| 171 |
+
with torch.amp.autocast("cuda"):
|
| 172 |
+
raw = self.model.get_image_features(**inputs)
|
| 173 |
+
else:
|
| 174 |
+
raw = self.model.get_image_features(**inputs)
|
| 175 |
+
feats = self._to_tensor(raw)
|
| 176 |
+
all_emb.append(F.normalize(feats, p=2, dim=-1).cpu().numpy())
|
| 177 |
+
return np.vstack(all_emb).astype(np.float32)
|
| 178 |
+
|
| 179 |
+
@torch.no_grad()
|
| 180 |
+
def encode_query_text(self, query: str) -> np.ndarray:
|
| 181 |
+
prompted = [tmpl.format(query) for tmpl in self.config.prompt_templates]
|
| 182 |
+
embeddings = self.encode_texts(prompted)
|
| 183 |
+
avg = embeddings.mean(axis=0, keepdims=True)
|
| 184 |
+
avg = avg / (np.linalg.norm(avg, axis=-1, keepdims=True) + 1e-8)
|
| 185 |
+
return avg.astype(np.float32)
|
| 186 |
+
|
| 187 |
+
@torch.no_grad()
|
| 188 |
+
def encode_multimodal_query(
|
| 189 |
+
self, text: str, image: Image.Image, text_weight: float = 0.5,
|
| 190 |
+
) -> np.ndarray:
|
| 191 |
+
text_emb = self.encode_query_text(text)
|
| 192 |
+
img_emb = self.encode_images([image])
|
| 193 |
+
fused = text_weight * text_emb + (1 - text_weight) * img_emb
|
| 194 |
+
fused = fused / (np.linalg.norm(fused, axis=-1, keepdims=True) + 1e-8)
|
| 195 |
+
return fused.astype(np.float32)
|
backend/app/engine/evaluator.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import time
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Dict, List, Set
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from tqdm.auto import tqdm
|
| 8 |
+
|
| 9 |
+
from backend.app.engine.search_engine import ASOSSearchEngine
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger("asos_search")
|
| 12 |
+
|
| 13 |
+
__all__ = ["EvalResult", "SearchEvaluator"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
|
| 17 |
+
class EvalResult:
|
| 18 |
+
query: str
|
| 19 |
+
recall_at_k: Dict[int, float]
|
| 20 |
+
precision_at_k: Dict[int, float]
|
| 21 |
+
mrr: float
|
| 22 |
+
latency_ms: float
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class SearchEvaluator:
|
| 26 |
+
def __init__(self, engine: ASOSSearchEngine):
|
| 27 |
+
self.engine = engine
|
| 28 |
+
|
| 29 |
+
def evaluate_single(
|
| 30 |
+
self, query: str, relevant_skus: Set[str], k_values: List[int] = [5, 10, 20]
|
| 31 |
+
) -> EvalResult:
|
| 32 |
+
max_k = max(k_values)
|
| 33 |
+
t0 = time.time()
|
| 34 |
+
results = self.engine.search(query, top_n=max_k)
|
| 35 |
+
latency = (time.time() - t0) * 1000
|
| 36 |
+
|
| 37 |
+
retrieved = results["sku"].astype(str).tolist()
|
| 38 |
+
relevant = set(str(s) for s in relevant_skus)
|
| 39 |
+
|
| 40 |
+
recall_at, precision_at = {}, {}
|
| 41 |
+
for k in k_values:
|
| 42 |
+
top_k = retrieved[:k]
|
| 43 |
+
found = len(set(top_k) & relevant)
|
| 44 |
+
recall_at[k] = found / len(relevant) if relevant else 0.0
|
| 45 |
+
precision_at[k] = found / k if k > 0 else 0.0
|
| 46 |
+
|
| 47 |
+
mrr = 0.0
|
| 48 |
+
for rank, sku in enumerate(retrieved, 1):
|
| 49 |
+
if sku in relevant:
|
| 50 |
+
mrr = 1.0 / rank
|
| 51 |
+
break
|
| 52 |
+
|
| 53 |
+
return EvalResult(
|
| 54 |
+
query=query,
|
| 55 |
+
recall_at_k=recall_at,
|
| 56 |
+
precision_at_k=precision_at,
|
| 57 |
+
mrr=mrr,
|
| 58 |
+
latency_ms=latency,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def evaluate(
|
| 62 |
+
self, test_queries: List[Dict], k_values: List[int] = [5, 10, 20]
|
| 63 |
+
) -> Dict:
|
| 64 |
+
results = []
|
| 65 |
+
for tq in tqdm(test_queries, desc="Evaluating"):
|
| 66 |
+
try:
|
| 67 |
+
res = self.evaluate_single(
|
| 68 |
+
tq["query"],
|
| 69 |
+
set(str(s) for s in tq["relevant_skus"]),
|
| 70 |
+
k_values,
|
| 71 |
+
)
|
| 72 |
+
results.append(res)
|
| 73 |
+
except Exception as e:
|
| 74 |
+
logger.warning(f"Eval failed for '{tq['query']}': {e}")
|
| 75 |
+
|
| 76 |
+
if not results:
|
| 77 |
+
return {"error": "No successful evaluations"}
|
| 78 |
+
|
| 79 |
+
agg = {
|
| 80 |
+
"n_queries": len(results),
|
| 81 |
+
"avg_latency_ms": float(np.mean([r.latency_ms for r in results])),
|
| 82 |
+
"median_latency_ms": float(np.median([r.latency_ms for r in results])),
|
| 83 |
+
"mean_mrr": float(np.mean([r.mrr for r in results])),
|
| 84 |
+
}
|
| 85 |
+
for k in k_values:
|
| 86 |
+
agg[f"mean_recall@{k}"] = float(
|
| 87 |
+
np.mean([r.recall_at_k.get(k, 0) for r in results])
|
| 88 |
+
)
|
| 89 |
+
agg[f"mean_precision@{k}"] = float(
|
| 90 |
+
np.mean([r.precision_at_k.get(k, 0) for r in results])
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
return {"aggregate": agg, "per_query": [
|
| 94 |
+
{"query": r.query, "mrr": r.mrr, "latency_ms": r.latency_ms,
|
| 95 |
+
"recall_at_k": r.recall_at_k, "precision_at_k": r.precision_at_k}
|
| 96 |
+
for r in results
|
| 97 |
+
]}
|
| 98 |
+
|
| 99 |
+
@staticmethod
|
| 100 |
+
def print_report(report: Dict):
|
| 101 |
+
agg = report.get("aggregate", {})
|
| 102 |
+
print("\n" + "=" * 65)
|
| 103 |
+
print(" SEARCH ENGINE EVALUATION REPORT")
|
| 104 |
+
print("=" * 65)
|
| 105 |
+
print(f" Queries evaluated: {agg.get('n_queries', 0)}")
|
| 106 |
+
print(f" Avg latency: {agg.get('avg_latency_ms', 0):.1f} ms")
|
| 107 |
+
print(f" Mean MRR: {agg.get('mean_mrr', 0):.4f}")
|
| 108 |
+
for key, val in sorted(agg.items()):
|
| 109 |
+
if "recall" in key or "precision" in key:
|
| 110 |
+
print(f" {key:25s} {val:.4f}")
|
| 111 |
+
print("=" * 65)
|
backend/app/engine/index.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import faiss
|
| 6 |
+
|
| 7 |
+
from backend.app.config import SearchConfig
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DualFAISSIndex:
    """
    Two parallel FAISS indices (image + text) over the same catalog rows,
    fused at query time via weighted Reciprocal Rank Fusion (RRF).

    Both indices use inner-product similarity; embeddings are presumably
    L2-normalized by the caller so inner product equals cosine similarity
    — TODO confirm against the encoder.
    """

    def __init__(self, dim: int, config: SearchConfig):
        # dim: embedding dimensionality shared by both modalities.
        self.dim = dim
        self.config = config
        self.image_index = None  # set by build() or load()
        self.text_index = None   # set by build() or load()

    def _create_index(self, n_vectors: int) -> faiss.Index:
        """Choose an index type by corpus size.

        Small corpora (< 5000 vectors) get an exact IndexFlatIP; larger ones
        get an approximate IndexIVFFlat with ~40 vectors per cluster,
        capped by config.n_clusters and floored at 16 clusters.
        """
        if n_vectors < 5000:
            logger.info(f"Using IndexFlatIP (exact, n={n_vectors:,})")
            return faiss.IndexFlatIP(self.dim)

        n_clusters = min(self.config.n_clusters, max(16, n_vectors // 40))
        logger.info(f"Using IndexIVFFlat (n={n_vectors:,}, clusters={n_clusters})")
        quantizer = faiss.IndexFlatIP(self.dim)
        index = faiss.IndexIVFFlat(
            quantizer, self.dim, n_clusters, faiss.METRIC_INNER_PRODUCT
        )
        return index

    def _train_and_fill(self, index: faiss.Index, embeddings: np.ndarray, label: str) -> faiss.Index:
        """Train (IVF only — flat train is a no-op) and populate one index.

        Falls back to an exact IndexFlatIP if IVF training fails
        (e.g. too few vectors per cluster) instead of aborting the build.
        """
        if hasattr(index, 'train'):
            try:
                index.train(embeddings)
            except Exception:
                logger.warning(f"Training failed for {label} index; falling back to IndexFlatIP")
                index = faiss.IndexFlatIP(self.dim)
        index.add(embeddings)
        return index

    def build(self, image_embeddings: np.ndarray, text_embeddings: np.ndarray):
        """Build both indices from parallel (n, dim) embedding arrays.

        Raises:
            AssertionError: if the two arrays differ in shape.
        """
        image_embeddings = image_embeddings.astype(np.float32)
        text_embeddings = text_embeddings.astype(np.float32)

        assert image_embeddings.shape == text_embeddings.shape, (
            f"Shape mismatch: images {image_embeddings.shape} vs text {text_embeddings.shape}"
        )

        n = image_embeddings.shape[0]

        logger.info("Building image FAISS index...")
        self.image_index = self._train_and_fill(self._create_index(n), image_embeddings, "image")

        logger.info("Building text FAISS index...")
        self.text_index = self._train_and_fill(self._create_index(n), text_embeddings, "text")

        logger.info(
            f"Dual index built: {self.image_index.ntotal:,} image, "
            f"{self.text_index.ntotal:,} text vectors"
        )

    def search_image_index(self, query: np.ndarray, top_k: int):
        """Search the image index; returns raw faiss (scores, ids) arrays of shape (1, top_k)."""
        q = query.astype(np.float32).reshape(1, -1)
        if hasattr(self.image_index, 'nprobe'):
            # IVF recall/speed knob; harmless no-op branch for flat indices.
            self.image_index.nprobe = self.config.n_probe
        return self.image_index.search(q, top_k)

    def search_text_index(self, query: np.ndarray, top_k: int):
        """Search the text index; returns raw faiss (scores, ids) arrays of shape (1, top_k)."""
        q = query.astype(np.float32).reshape(1, -1)
        if hasattr(self.text_index, 'nprobe'):
            self.text_index.nprobe = self.config.n_probe
        return self.text_index.search(q, top_k)

    def search_fused(
        self, query: np.ndarray, top_k: int,
        image_weight: Optional[float] = None,
        text_weight: Optional[float] = None,
    ) -> Tuple[List[int], List[float]]:
        """Fuse image- and text-index rankings with weighted RRF.

        Each candidate scores sum(weight / (rrf_k + rank)) over the
        modalities it appears in, with 1-based ranks.

        Args:
            query: a single embedding vector.
            top_k: number of fused results to return.
            image_weight / text_weight: per-call overrides; None means
                "use the configured default".

        Returns:
            (ids, scores) sorted by descending fused score.
        """
        # FIX: the previous `weight or default` treated an explicit 0.0 as
        # "unset", making it impossible to disable a modality per call.
        iw = image_weight if image_weight is not None else self.config.image_index_weight
        tw = text_weight if text_weight is not None else self.config.text_index_weight
        rrf_k = self.config.rrf_k

        # Over-fetch from each index so the fusion step has candidates to merge.
        broad_k = min(top_k * 3, self.image_index.ntotal)

        _, img_ids = self.search_image_index(query, broad_k)
        _, txt_ids = self.search_text_index(query, broad_k)
        img_ids = img_ids[0]
        txt_ids = txt_ids[0]

        # faiss pads missing results with -1; skip those. Ranks are 1-based.
        img_rank = {int(idx): rank + 1 for rank, idx in enumerate(img_ids) if idx >= 0}
        txt_rank = {int(idx): rank + 1 for rank, idx in enumerate(txt_ids) if idx >= 0}

        all_candidates = set(img_rank.keys()) | set(txt_rank.keys())
        scores = {}
        for idx in all_candidates:
            score = 0.0
            if idx in img_rank:
                score += iw / (rrf_k + img_rank[idx])
            if idx in txt_rank:
                score += tw / (rrf_k + txt_rank[idx])
            scores[idx] = score

        ranked = sorted(scores.items(), key=lambda x: -x[1])[:top_k]
        return [r[0] for r in ranked], [r[1] for r in ranked]

    def save(self, image_path: str, text_path: str):
        """Persist both indices to disk."""
        faiss.write_index(self.image_index, image_path)
        faiss.write_index(self.text_index, text_path)
        logger.info(f"Saved dual index to {image_path} and {text_path}")

    def load(self, image_path: str, text_path: str):
        """Load both indices from disk, replacing any in-memory ones."""
        self.image_index = faiss.read_index(image_path)
        self.text_index = faiss.read_index(text_path)
        logger.info(
            f"Loaded dual index: {self.image_index.ntotal:,} image, "
            f"{self.text_index.ntotal:,} text vectors"
        )
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
__all__ = ["DualFAISSIndex"]
|
backend/app/engine/nlp.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
engine/nlp.py — Multilingual query handling and spell correction.
|
| 3 |
+
|
| 4 |
+
Extracted from finalized_search_engine_full_script.py (lines 80-364).
|
| 5 |
+
Contains:
|
| 6 |
+
- MultilingualHandler: language detection + dictionary-based translation
|
| 7 |
+
- SpellCorrector: Norvig-style spell correction built from the product catalog
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import re
|
| 11 |
+
import logging
|
| 12 |
+
from typing import List, Tuple, Set
|
| 13 |
+
from collections import Counter
|
| 14 |
+
|
| 15 |
+
__all__ = ["MultilingualHandler", "SpellCorrector"]
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("asos_search")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 21 |
+
# MULTILINGUAL SUPPORT — lightweight language detection + translation
|
| 22 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 23 |
+
class MultilingualHandler:
    """
    Detects non-English queries and translates them to English using a
    dictionary-based approach for common fashion terms in major languages.
    For production, swap this with a proper translation API (Google Translate,
    DeepL, or a local model like Helsinki-NLP/opus-mt-*).
    """

    # Common fashion terms in multiple languages → English.
    # Keys must be lowercase; multi-word keys are matched as bigrams first.
    FASHION_DICT = {
        # French
        'robe': 'dress', 'jupe': 'skirt', 'chemise': 'shirt', 'pantalon': 'trousers',
        'veste': 'jacket', 'manteau': 'coat', 'chaussures': 'shoes',
        'bottes': 'boots', 'sac': 'bag', 'ceinture': 'belt',
        'rouge': 'red', 'bleu': 'blue', 'noir': 'black', 'blanc': 'white',
        'vert': 'green', 'jaune': 'yellow', 'rose': 'pink', 'gris': 'grey',
        'violet': 'purple', 'marron': 'brown', 'orange': 'orange',
        'élégant': 'elegant', 'décontracté': 'casual', 'chic': 'chic',
        'femme': 'women', 'homme': 'men', 'fille': 'girl',
        'soie': 'silk', 'coton': 'cotton', 'cuir': 'leather', 'lin': 'linen',
        'floral': 'floral', 'rayé': 'striped', 'imprimé': 'printed',
        'été': 'summer', 'hiver': 'winter', 'printemps': 'spring', 'automne': 'autumn',
        'mini': 'mini', 'maxi': 'maxi', 'midi': 'midi',
        'pas cher': 'budget', 'luxe': 'luxury', 'bon marché': 'cheap',

        # Spanish
        # FIX: removed duplicate 'gris' key — it was already defined in the
        # French section with the same value; duplicate literal dict keys are
        # silently overwritten in Python.
        'vestido': 'dress', 'falda': 'skirt', 'camisa': 'shirt',
        'pantalón': 'trousers', 'pantalones': 'trousers', 'chaqueta': 'jacket',
        'abrigo': 'coat', 'zapatos': 'shoes', 'botas': 'boots',
        'bolso': 'bag', 'cinturón': 'belt', 'sombrero': 'hat',
        'rojo': 'red', 'azul': 'blue', 'negro': 'black', 'blanco': 'white',
        'verde': 'green', 'amarillo': 'yellow', 'rosado': 'pink', 'morado': 'purple',
        'marrón': 'brown', 'naranja': 'orange',
        'elegante': 'elegant', 'informal': 'casual', 'moderno': 'modern',
        'mujer': 'women', 'hombre': 'men', 'barato': 'cheap',
        'algodón': 'cotton', 'seda': 'silk', 'cuero': 'leather',
        'verano': 'summer', 'invierno': 'winter',

        # German
        'kleid': 'dress', 'rock': 'skirt', 'hemd': 'shirt', 'bluse': 'blouse',
        'hose': 'trousers', 'jacke': 'jacket', 'mantel': 'coat',
        'schuhe': 'shoes', 'stiefel': 'boots', 'tasche': 'bag',
        'gürtel': 'belt', 'hut': 'hat', 'pullover': 'sweater',
        'rot': 'red', 'blau': 'blue', 'schwarz': 'black', 'weiß': 'white',
        'weiss': 'white', 'grün': 'green', 'gelb': 'yellow', 'rosa': 'pink',
        'lila': 'purple', 'braun': 'brown', 'grau': 'grey',
        'frau': 'women', 'herren': 'men', 'damen': 'women',
        'seide': 'silk', 'baumwolle': 'cotton', 'leder': 'leather',
        'sommer': 'summer', 'winter': 'winter',

        # Italian
        'abito': 'dress', 'gonna': 'skirt', 'camicia': 'shirt',
        'giacca': 'jacket', 'cappotto': 'coat', 'scarpe': 'shoes',
        'stivali': 'boots', 'borsa': 'bag', 'cintura': 'belt',
        'rosso': 'red', 'blu': 'blue', 'nero': 'black', 'bianco': 'white',
        'grigio': 'grey', 'giallo': 'yellow', 'donna': 'women', 'uomo': 'men',
        'seta': 'silk', 'cotone': 'cotton', 'pelle': 'leather',
        'estate': 'summer', 'inverno': 'winter',

        # Portuguese
        # FIX: removed duplicate 'vestido' key — already mapped to 'dress'
        # in the Spanish section.
        'saia': 'skirt', 'calça': 'trousers',
        'jaqueta': 'jacket', 'casaco': 'coat', 'sapatos': 'shoes',
        'bolsa': 'bag', 'vermelho': 'red', 'preto': 'black', 'branco': 'white',
        'mulher': 'women', 'homem': 'men',

        # Japanese (romaji)
        'doresu': 'dress', 'sukato': 'skirt', 'shatsu': 'shirt',
        'zubon': 'trousers', 'jaketto': 'jacket', 'kutsu': 'shoes',
        'baggu': 'bag', 'aka': 'red', 'ao': 'blue', 'kuro': 'black',
        'shiro': 'white',

        # Common multilingual fashion terms
        'kimono': 'kimono', 'sari': 'sari', 'hijab': 'hijab',
        'kaftan': 'kaftan', 'poncho': 'poncho',
    }

    # Character-range heuristics for script detection.
    # NOTE(review): _LATIN_EXTENDED is currently unused by any method here;
    # kept for backward compatibility in case external code references it.
    _LATIN_EXTENDED = re.compile(r'[àáâãäåæçèéêëìíîïðñòóôõöùúûüýþÿ]', re.I)
    _CJK = re.compile(r'[\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff]')
    _CYRILLIC = re.compile(r'[\u0400-\u04ff]')
    _ARABIC = re.compile(r'[\u0600-\u06ff]')
    _DEVANAGARI = re.compile(r'[\u0900-\u097f]')

    @classmethod
    def detect_language(cls, text: str) -> str:
        """Return a rough language tag: 'en', 'fr', 'es', 'de', 'it', 'pt',
        'ja', 'zh', 'ar', 'hi', 'ru', or 'other'."""
        # Non-Latin scripts are unambiguous — decide from character ranges.
        if cls._CJK.search(text):
            # Kana present → Japanese; CJK ideographs alone → assume Chinese.
            return 'ja' if re.search(r'[\u3040-\u30ff]', text) else 'zh'
        if cls._CYRILLIC.search(text):
            return 'ru'
        if cls._ARABIC.search(text):
            return 'ar'
        if cls._DEVANAGARI.search(text):
            return 'hi'

        words = set(re.findall(r'\b[a-zàáâãäåæçèéêëìíîïñòóôõöùúûüýÿ]+\b', text.lower()))
        # Function-word markers for the Latin-script languages we support.
        fr_markers = {'le', 'la', 'les', 'un', 'une', 'des', 'du', 'de', 'et', 'en', 'pour', 'avec', 'je', 'ce', 'cette'}
        es_markers = {'el', 'la', 'los', 'las', 'un', 'una', 'de', 'en', 'y', 'para', 'con', 'por', 'que', 'muy'}
        de_markers = {'der', 'die', 'das', 'ein', 'eine', 'und', 'für', 'mit', 'ich', 'ist', 'nicht', 'auch'}
        it_markers = {'il', 'lo', 'la', 'gli', 'le', 'un', 'una', 'di', 'e', 'per', 'con', 'che', 'sono'}
        pt_markers = {'o', 'a', 'os', 'as', 'um', 'uma', 'de', 'em', 'para', 'com', 'que', 'não'}

        scores = {
            'fr': len(words & fr_markers),
            'es': len(words & es_markers),
            'de': len(words & de_markers),
            'it': len(words & it_markers),
            'pt': len(words & pt_markers),
        }
        best = max(scores, key=scores.get)
        if scores[best] >= 2:  # require two marker hits before committing to a language
            return best

        # Short queries rarely contain function words: if a token is a known
        # foreign fashion term and no English function word appears, report
        # 'other' so translate_query() still runs the dictionary pass.
        dict_words = words & set(cls.FASHION_DICT.keys())
        en_words = {'the', 'a', 'an', 'in', 'on', 'for', 'with', 'and', 'or', 'is', 'are'}
        if dict_words and not (words & en_words):
            return 'other'

        return 'en'

    @classmethod
    def translate_query(cls, query: str) -> Tuple[str, str, bool]:
        """
        Translate a query to English using the fashion dictionary.

        Returns: (translated_query, detected_language, was_translated)
        """
        lang = cls.detect_language(query)

        if lang == 'en':
            return query, 'en', False

        # For non-Latin scripts, we can't do dictionary translation
        if lang in ('ja', 'zh', 'ar', 'hi', 'ru'):
            logging.getLogger("asos_search").info(
                f"Non-Latin script detected ({lang}). Passing through to CLIP."
            )
            return query, lang, False

        # Dictionary-based word-by-word translation for Latin-script languages.
        words = query.lower().split()
        translated = []
        was_translated = False

        i = 0
        while i < len(words):
            # Try 2-word phrases first (e.g. 'pas cher' → 'budget').
            if i + 1 < len(words):
                bigram = f"{words[i]} {words[i+1]}"
                if bigram in cls.FASHION_DICT:
                    translated.append(cls.FASHION_DICT[bigram])
                    was_translated = True
                    i += 2
                    continue

            word = words[i]
            if word in cls.FASHION_DICT:
                translated.append(cls.FASHION_DICT[word])
                was_translated = True
            else:
                translated.append(word)  # pass unknown tokens through unchanged
            i += 1

        result = ' '.join(translated)
        if was_translated:
            logging.getLogger("asos_search").info(
                f"Translated [{lang}]: \"{query}\" → \"{result}\""
            )

        return result, lang, was_translated
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 194 |
+
# QUERY SPELL-CORRECTION
|
| 195 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 196 |
+
class SpellCorrector:
    """
    Lightweight spell correction for fashion search queries.
    Uses a vocabulary built from the product catalog + common fashion terms.
    Based on Peter Norvig's spell corrector algorithm.
    """

    def __init__(self):
        # word -> frequency observed in the catalog (plus boosts from fit()).
        self.word_freq: Counter = Counter()
        self._ready = False  # becomes True after fit(); correction is a no-op before that

    def fit(self, texts: List[str]):
        """Build vocabulary from product catalog texts."""
        for text in texts:
            words = re.findall(r'\b[a-z]+\b', str(text).lower())
            self.word_freq.update(words)

        # Boost common fashion terms so they out-rank rare catalog noise
        # when several candidates are one edit away.
        fashion_boost = [
            'dress', 'dresses', 'skirt', 'shirt', 'blouse', 'jacket', 'coat',
            'jeans', 'trousers', 'shorts', 'hoodie', 'sweater', 'cardigan',
            'boots', 'sneakers', 'trainers', 'sandals', 'heels', 'shoes',
            'bag', 'handbag', 'tote', 'backpack', 'clutch',
            'black', 'white', 'blue', 'red', 'green', 'pink', 'yellow',
            'purple', 'brown', 'grey', 'gray', 'navy', 'beige', 'cream',
            'casual', 'formal', 'elegant', 'vintage', 'boho', 'minimalist',
            'streetwear', 'oversized', 'cropped', 'fitted', 'floral',
            'leather', 'denim', 'satin', 'silk', 'cotton', 'linen',
            'summer', 'winter', 'spring', 'autumn', 'party', 'office',
            'midi', 'mini', 'maxi', 'sequin', 'lace', 'velvet',
        ]
        for w in fashion_boost:
            self.word_freq[w] += 1000

        self._ready = True
        logging.getLogger("asos_search").info(
            f"SpellCorrector fitted with {len(self.word_freq):,} words"
        )

    def _edits1(self, word: str) -> Set[str]:
        """All edits that are one edit distance away from `word`."""
        letters = 'abcdefghijklmnopqrstuvwxyz'
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(deletes + transposes + replaces + inserts)

    def _edits2(self, word: str) -> Set[str]:
        """All edits that are two edits away from `word`."""
        return set(e2 for e1 in self._edits1(word) for e2 in self._edits1(e1))

    def _known(self, words: Set[str]) -> Set[str]:
        """Subset of `words` that are in the vocabulary.

        FIX: test membership per word instead of materializing the entire
        vocabulary as a set on every call — the old
        `words & set(self.word_freq.keys())` was O(|vocab|) per invocation.
        """
        return {w for w in words if w in self.word_freq}

    def correct_word(self, word: str) -> str:
        """Return the most likely spelling correction for a single word."""
        # Too short to correct safely, or vocabulary not built yet.
        if not self._ready or len(word) <= 2:
            return word

        word_lower = word.lower()

        # Already known — return the caller's original casing untouched.
        if word_lower in self.word_freq:
            return word

        # Edit distance 1
        candidates = self._known(self._edits1(word_lower))
        if candidates:
            best = max(candidates, key=self.word_freq.get)
            if self.word_freq[best] > 10:  # Only correct if the candidate is common enough
                return best

        # Edit distance 2 (only for longer words — short words generate too many false hits)
        if len(word_lower) >= 5:
            candidates = self._known(self._edits2(word_lower))
            if candidates:
                best = max(candidates, key=self.word_freq.get)
                if self.word_freq[best] > 50:
                    return best

        return word

    def correct_query(self, query: str) -> Tuple[str, bool]:
        """
        Correct a full query string.
        Returns: (corrected_query, was_corrected)
        """
        if not self._ready:
            return query, False

        words = query.split()
        corrected = []
        was_corrected = False

        for word in words:
            # Don't correct price tokens, numbers, or currency symbols
            if re.match(r'^[£$€]?\d', word) or len(word) <= 2:
                corrected.append(word)
                continue

            fixed = self.correct_word(word)
            if fixed != word:
                was_corrected = True
            corrected.append(fixed)

        result = ' '.join(corrected)
        if was_corrected:
            logging.getLogger("asos_search").info(
                f"Spell-corrected: \"{query}\" → \"{result}\""
            )
        return result, was_corrected
|
backend/app/engine/query_parser.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
engine/query_parser.py
|
| 3 |
+
|
| 4 |
+
Extracted from finalized_search_engine_full_script.py (lines 776-1056).
|
| 5 |
+
Contains the ParsedQuery dataclass and QueryParser class responsible for
|
| 6 |
+
converting natural-language fashion queries into structured filter intents.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import re
|
| 10 |
+
import logging
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
__all__ = ["ParsedQuery", "QueryParser"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class ParsedQuery:
    """Structured result of parsing a natural-language fashion query.

    Produced by QueryParser.parse(); consumed downstream by the search
    pipeline to apply hard filters and to embed the residual "vibe" text.
    """
    raw_query: str   # the query as received (stripped of surrounding whitespace)
    vibe_text: str   # residual free text after filter phrases are removed

    # Hard filters extracted from the query; None means "no constraint".
    category_filter: Optional[str] = None
    color_filter: Optional[str] = None    # lowercase color family (see QueryParser.COLOR_MAP)
    gender_filter: Optional[str] = None   # 'Men' / 'Women' / 'Unisex'
    price_min: Optional[float] = None
    price_max: Optional[float] = None
    brand_filter: Optional[str] = None
    size_filter: Optional[str] = None
    material_filter: Optional[str] = None
    exclusions: List[str] = field(default_factory=list)  # terms after 'not'/'without'/'no'/'excluding'
    in_stock_only: bool = True

    # Soft signals: style keywords found in the query (boost, not filter).
    style_tags: List[str] = field(default_factory=list)

    has_image: bool = False   # True when the request includes a reference image
    text_weight: float = 0.5  # presumably text-vs-image blend weight for fused search — confirm against search engine

    # Multilingual / correction metadata
    original_query: Optional[str] = None  # pre-translation/pre-correction query, when it differs
    detected_language: str = "en"
    was_translated: bool = False
    was_spell_corrected: bool = False
    spell_correction_suggestion: Optional[str] = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class QueryParser:
|
| 49 |
+
"""Parses natural language fashion queries into structured intents."""
|
| 50 |
+
|
| 51 |
+
PRICE_PATTERNS = [
|
| 52 |
+
(r'[£$€]?\s*(\d+(?:\.\d+)?)\s*[-–to]+\s*[£$€]?\s*(\d+(?:\.\d+)?)', 'range'),
|
| 53 |
+
(r'(?:under|below|less\s+than|max|up\s+to|cheaper\s+than)\s*[£$€]?\s*(\d+(?:\.\d+)?)', 'max'),
|
| 54 |
+
(r'(?:over|above|more\s+than|min|at\s+least|from)\s*[£$€]?\s*(\d+(?:\.\d+)?)', 'min'),
|
| 55 |
+
(r'\b(?:budget|cheap|affordable|bargain|inexpensive|value)\b', 'budget'),
|
| 56 |
+
(r'\b(?:luxury|premium|high[\s-]?end|designer|expensive|splurge)\b', 'luxury'),
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
CATEGORY_TRIGGERS = {
|
| 60 |
+
'midi dress': 'Dresses', 'maxi dress': 'Dresses',
|
| 61 |
+
'mini dress': 'Dresses', 'slip dress': 'Dresses',
|
| 62 |
+
'bodycon': 'Dresses', 'dress': 'Dresses',
|
| 63 |
+
'dresses': 'Dresses', 'gown': 'Dresses',
|
| 64 |
+
|
| 65 |
+
'trench coat': 'Coats & Jackets', 'puffer jacket': 'Coats & Jackets',
|
| 66 |
+
'leather jacket': 'Coats & Jackets', 'denim jacket': 'Coats & Jackets',
|
| 67 |
+
'bomber jacket': 'Coats & Jackets',
|
| 68 |
+
'jacket': 'Coats & Jackets', 'coat': 'Coats & Jackets',
|
| 69 |
+
'blazer': 'Coats & Jackets', 'parka': 'Coats & Jackets',
|
| 70 |
+
|
| 71 |
+
't-shirt': 'Tops', 'tee': 'Tops',
|
| 72 |
+
'blouse': 'Tops', 'shirt': 'Tops',
|
| 73 |
+
'crop top': 'Tops', 'cami': 'Tops',
|
| 74 |
+
'bodysuit': 'Tops', 'top': 'Tops', 'tops': 'Tops',
|
| 75 |
+
|
| 76 |
+
'cardigan': 'Knitwear', 'jumper': 'Knitwear',
|
| 77 |
+
'sweater': 'Knitwear', 'pullover': 'Knitwear', 'knitwear': 'Knitwear',
|
| 78 |
+
|
| 79 |
+
'hoodie': 'Hoodies & Sweatshirts', 'sweatshirt': 'Hoodies & Sweatshirts',
|
| 80 |
+
|
| 81 |
+
'jeans': 'Jeans',
|
| 82 |
+
'trousers': 'Trousers', 'pants': 'Trousers',
|
| 83 |
+
'joggers': 'Trousers', 'leggings': 'Trousers', 'cargo': 'Trousers',
|
| 84 |
+
|
| 85 |
+
'shorts': 'Shorts',
|
| 86 |
+
|
| 87 |
+
'skirt': 'Skirts', 'midi skirt': 'Skirts', 'mini skirt': 'Skirts',
|
| 88 |
+
|
| 89 |
+
'trainers': 'Shoes', 'sneakers': 'Shoes',
|
| 90 |
+
'boots': 'Shoes', 'heels': 'Shoes',
|
| 91 |
+
'sandals': 'Shoes', 'loafers': 'Shoes',
|
| 92 |
+
'shoes': 'Shoes', 'mules': 'Shoes',
|
| 93 |
+
'platforms': 'Shoes', 'flats': 'Shoes',
|
| 94 |
+
|
| 95 |
+
'bag': 'Bags', 'handbag': 'Bags',
|
| 96 |
+
'tote': 'Bags', 'backpack': 'Bags',
|
| 97 |
+
'clutch': 'Bags', 'crossbody': 'Bags',
|
| 98 |
+
|
| 99 |
+
'watch': 'Accessories', 'sunglasses': 'Accessories',
|
| 100 |
+
'hat': 'Accessories', 'cap': 'Accessories',
|
| 101 |
+
'scarf': 'Accessories', 'belt': 'Accessories',
|
| 102 |
+
'jewellery': 'Accessories', 'jewelry': 'Accessories',
|
| 103 |
+
'necklace': 'Accessories', 'bracelet': 'Accessories',
|
| 104 |
+
'earrings': 'Accessories', 'ring': 'Accessories',
|
| 105 |
+
|
| 106 |
+
'swimsuit': 'Swimwear', 'bikini': 'Swimwear', 'swim': 'Swimwear',
|
| 107 |
+
'suit': 'Suits & Tailoring', 'waistcoat': 'Suits & Tailoring',
|
| 108 |
+
'jumpsuit': 'Jumpsuits & Playsuits', 'playsuit': 'Jumpsuits & Playsuits',
|
| 109 |
+
'romper': 'Jumpsuits & Playsuits',
|
| 110 |
+
'lingerie': 'Underwear & Socks', 'bra': 'Underwear & Socks',
|
| 111 |
+
'briefs': 'Underwear & Socks', 'boxers': 'Underwear & Socks',
|
| 112 |
+
'socks': 'Underwear & Socks',
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
# ── FIX: COLOR_MAP now outputs LOWERCASE to match actual data values ──
|
| 116 |
+
COLOR_MAP = {
|
| 117 |
+
'red': 'red', 'scarlet': 'red', 'crimson': 'red',
|
| 118 |
+
'blue': 'blue', 'cobalt': 'blue',
|
| 119 |
+
'sky blue': 'blue', 'teal': 'blue', 'aqua': 'blue',
|
| 120 |
+
'navy': 'navy', # data has 'navy' as its own family
|
| 121 |
+
'green': 'green', 'olive': 'green', 'emerald': 'green',
|
| 122 |
+
'sage': 'green', 'mint': 'green',
|
| 123 |
+
'khaki': 'khaki', # data has 'khaki' as its own family
|
| 124 |
+
'black': 'black', 'charcoal': 'black',
|
| 125 |
+
'white': 'white', 'cream': 'white', 'ivory': 'white',
|
| 126 |
+
'pink': 'pink', 'blush': 'pink', 'rose': 'pink',
|
| 127 |
+
'fuchsia': 'pink', 'magenta': 'pink', 'coral': 'pink',
|
| 128 |
+
'yellow': 'yellow', 'gold': 'yellow', 'mustard': 'yellow',
|
| 129 |
+
'orange': 'orange', 'rust': 'orange', 'terracotta': 'orange',
|
| 130 |
+
'brown': 'brown', 'tan': 'brown', 'camel': 'brown',
|
| 131 |
+
'beige': 'beige', 'taupe': 'beige', # data has 'beige' as its own family
|
| 132 |
+
'chocolate': 'brown',
|
| 133 |
+
'purple': 'purple', 'lilac': 'purple', 'plum': 'purple',
|
| 134 |
+
'lavender': 'purple', 'violet': 'purple', 'mauve': 'purple',
|
| 135 |
+
'burgundy': 'burgundy', # data has 'burgundy' as its own family
|
| 136 |
+
'grey': 'grey', 'gray': 'grey', 'silver': 'grey',
|
| 137 |
+
'multi': 'multi', 'rainbow': 'multi', 'multicolour': 'multi',
|
| 138 |
+
'multicolor': 'multi',
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
GENDER_TRIGGERS = {
|
| 142 |
+
"men's": "Men", "mens": "Men", "male": "Men", "for men": "Men",
|
| 143 |
+
"for him": "Men", "boys": "Men", "masculine": "Men",
|
| 144 |
+
"women's": "Women", "womens": "Women", "female": "Women",
|
| 145 |
+
"for women": "Women", "for her": "Women", "girls": "Women",
|
| 146 |
+
"ladies": "Women", "feminine": "Women",
|
| 147 |
+
"unisex": "Unisex",
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
STYLE_TAGS = [
|
| 151 |
+
'casual', 'formal', 'streetwear', 'boho', 'bohemian', 'minimalist',
|
| 152 |
+
'vintage', 'retro', 'y2k', 'goth', 'gothic', 'punk', 'preppy',
|
| 153 |
+
'athleisure', 'sporty', 'elegant', 'chic', 'edgy', 'romantic',
|
| 154 |
+
'classic', 'modern', 'oversized', 'cropped', 'fitted', 'relaxed',
|
| 155 |
+
'floral', 'striped', 'plaid', 'animal print', 'leopard', 'sequin',
|
| 156 |
+
'lace', 'denim', 'leather', 'satin', 'silk', 'velvet', 'knit',
|
| 157 |
+
'sustainable', 'eco', 'organic', 'recycled',
|
| 158 |
+
'festival', 'party', 'office', 'workwear', 'loungewear', 'sleepwear',
|
| 159 |
+
'coastal', 'cottagecore', 'grunge', 'cyber', 'futuristic',
|
| 160 |
+
'western', 'nautical', 'tropical', 'safari',
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
MATERIAL_KEYWORDS = {
|
| 164 |
+
'silk': 'silk', 'satin': 'satin', 'velvet': 'velvet',
|
| 165 |
+
'leather': 'leather', 'faux leather': 'faux leather',
|
| 166 |
+
'denim': 'denim', 'cotton': 'cotton', 'linen': 'linen',
|
| 167 |
+
'wool': 'wool', 'cashmere': 'cashmere', 'polyester': 'polyester',
|
| 168 |
+
'nylon': 'nylon', 'suede': 'suede', 'chiffon': 'chiffon',
|
| 169 |
+
'mesh': 'mesh', 'jersey': 'jersey',
|
| 170 |
+
'tweed': 'tweed', 'corduroy': 'corduroy', 'fleece': 'fleece',
|
| 171 |
+
'crochet': 'crochet', 'organza': 'organza', 'tulle': 'tulle',
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
SIZE_PATTERNS = [
|
| 175 |
+
(r'\bsize\s+(xx?s|xx?l|small|medium|large)\b', 'named'),
|
| 176 |
+
(r'\b(xx?s|xx?l)\b', 'named_bare'),
|
| 177 |
+
(r'\bsize\s+(\d{1,2})\b', 'numeric'),
|
| 178 |
+
(r'\buk\s+(\d{1,2})\b', 'numeric'),
|
| 179 |
+
(r'\beu\s+(\d{2})\b', 'eu'),
|
| 180 |
+
]
|
| 181 |
+
|
| 182 |
+
_SIZE_NORMALIZE = {
|
| 183 |
+
'xxs': 'XXS', 'xs': 'XS', 'x-small': 'XS', 'xsmall': 'XS',
|
| 184 |
+
's': 'S', 'small': 'S',
|
| 185 |
+
'm': 'M', 'medium': 'M',
|
| 186 |
+
'l': 'L', 'large': 'L',
|
| 187 |
+
'xl': 'XL', 'x-large': 'XL', 'xlarge': 'XL',
|
| 188 |
+
'xxl': 'XXL',
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
EXCLUSION_PATTERNS = [
|
| 192 |
+
r'\bnot\s+(\w+(?:\s+\w+)?)',
|
| 193 |
+
r'\bwithout\s+(\w+(?:\s+\w+)?)',
|
| 194 |
+
r'\bno\s+(\w+)',
|
| 195 |
+
r'\bexcluding\s+(\w+(?:\s+\w+)?)',
|
| 196 |
+
]
|
| 197 |
+
|
| 198 |
+
def parse(self, query: str) -> ParsedQuery:
    """Parse a free-text query into structured filters plus residual vibe text.

    All detection (price, category, color, gender, style tags, material,
    size, exclusions) runs against the unmodified lowercase query ``q``.
    The character spans of matched structured tokens are collected and only
    stripped from the vibe text at the very end.

    Bug fix: the previous version edited the vibe string immediately after
    each match while reusing match offsets computed on ``q``; once the price
    span (or the gender trigger) had been removed, every later span-based
    removal (size, exclusions) cut the wrong characters out of the vibe.

    Returns:
        ParsedQuery with the extracted filters; ``vibe_text`` falls back to
        the raw query when stripping leaves an empty string.
    """
    raw = query.strip()
    q = raw.lower()

    # (start, end) offsets into the *unmodified* `q`, stripped at the end.
    filter_spans = []

    # ── Price ──
    price_min, price_max = None, None
    for pattern, ptype in self.PRICE_PATTERNS:
        m = re.search(pattern, q)
        if m:
            if ptype == 'range':
                price_min, price_max = float(m.group(1)), float(m.group(2))
            elif ptype == 'max':
                price_max = float(m.group(1))
            elif ptype == 'min':
                price_min = float(m.group(1))
            elif ptype == 'budget':
                price_max = 30.0
            elif ptype == 'luxury':
                price_min = 100.0
            filter_spans.append((m.start(), m.end()))
            break

    # ── Category (longest trigger wins) ──
    category = None
    for trigger, cat in sorted(self.CATEGORY_TRIGGERS.items(), key=lambda x: -len(x[0])):
        if re.search(r'\b' + re.escape(trigger) + r'\b', q):
            category = cat
            break

    # ── Color (longest term wins) ──
    color = None
    for color_term, family in sorted(self.COLOR_MAP.items(), key=lambda x: -len(x[0])):
        if re.search(r'\b' + re.escape(color_term) + r'\b', q):
            color = family
            break

    # ── Gender — trigger remembered; removed from vibe by substring below ──
    gender = None
    gender_trigger = None
    for trigger, gen in self.GENDER_TRIGGERS.items():
        if trigger in q:
            gender = gen
            gender_trigger = trigger
            break

    # ── Style tags ──
    tags = [t for t in self.STYLE_TAGS if re.search(r'\b' + re.escape(t) + r'\b', q)]

    # ── Material (longest keyword wins) ──
    material = None
    for mat_term, mat_val in sorted(self.MATERIAL_KEYWORDS.items(), key=lambda x: -len(x[0])):
        if re.search(r'\b' + re.escape(mat_term) + r'\b', q):
            material = mat_val
            break

    # ── Size ──
    size = None
    for pattern, stype in self.SIZE_PATTERNS:
        m = re.search(pattern, q)
        if m:
            raw_size = m.group(1).lower()
            if stype in ('named', 'named_bare'):
                size = self._SIZE_NORMALIZE.get(raw_size, raw_size.upper())
            elif stype == 'numeric':
                size = raw_size  # keep as string "10", "12", etc.
            elif stype == 'eu':
                size = f"EU {raw_size}"
            filter_spans.append((m.start(), m.end()))
            break

    # ── Exclusions ("not floral", "without black", "no heels") ──
    exclusions = []
    for exc_pattern in self.EXCLUSION_PATTERNS:
        for m in re.finditer(exc_pattern, q):
            excluded_term = m.group(1).strip()
            if excluded_term and excluded_term not in exclusions:
                exclusions.append(excluded_term)
                filter_spans.append((m.start(), m.end()))

    # "no cotton" means cotton is excluded, not desired as a material filter.
    if material and material.lower() in [e.lower() for e in exclusions]:
        material = None

    # Strip all structured spans from `q` in reverse offset order so earlier
    # removals cannot shift the positions of later ones.
    vibe = q
    for start, end in sorted(filter_spans, reverse=True):
        vibe = vibe[:start] + vibe[end:]
    if gender_trigger:
        vibe = vibe.replace(gender_trigger, '')

    # Clean up leftover price symbols/comparatives and collapse whitespace.
    vibe = re.sub(r'[£$€]\s*\d+', '', vibe)
    vibe = re.sub(r'\b(under|below|over|above|less than|more than|up to)\b', '', vibe)
    vibe = re.sub(r'\s+', ' ', vibe).strip()
    if not vibe:
        vibe = raw

    return ParsedQuery(
        raw_query=raw, vibe_text=vibe,
        category_filter=category, color_filter=color,
        gender_filter=gender, price_min=price_min, price_max=price_max,
        style_tags=tags, material_filter=material,
        size_filter=size, exclusions=exclusions,
    )
|
backend/app/engine/reranker.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from collections import Counter
|
| 3 |
+
from typing import Dict, List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from backend.app.config import SearchConfig
|
| 9 |
+
from backend.app.engine.query_parser import ParsedQuery
|
| 10 |
+
from backend.app.engine.bm25 import SimpleBM25
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger("asos_search")
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"apply_filters",
|
| 16 |
+
"relax_and_retry",
|
| 17 |
+
"hybrid_rerank",
|
| 18 |
+
"generate_suggestions",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def apply_filters(candidates: pd.DataFrame, parsed: ParsedQuery) -> pd.DataFrame:
    """Apply the hard structured filters from a parsed query to candidates.

    Filters are applied only when the corresponding column exists, so the
    function degrades gracefully on partial metadata. Returns a (possibly
    empty) view-like subset of ``candidates``; the input frame is never
    mutated.

    Args:
        candidates: product rows; relevant columns are ``category``,
            ``color_family``, ``gender``, ``price``, ``brand``,
            ``sizes_available`` (list), ``materials`` (list), ``name``,
            ``color_clean``, ``style_tags`` (list), ``any_in_stock``.
        parsed: structured query (category/color/gender/price/brand/size/
            material filters, exclusion terms, in-stock flag).

    Bug fix: exclusion terms come from free user text and may contain regex
    metacharacters ('+', '(' …); ``str.contains`` defaulted to regex mode,
    which could raise ``re.error`` or silently mismatch. All substring
    checks now pass ``regex=False``.
    """
    df = candidates
    if parsed.category_filter and 'category' in df.columns:
        df = df[df['category'] == parsed.category_filter]
    if parsed.color_filter and 'color_family' in df.columns:
        df = df[df['color_family'].str.lower() == parsed.color_filter.lower()]
    if parsed.gender_filter and 'gender' in df.columns:
        # Unisex items satisfy any gender filter.
        df = df[(df['gender'] == parsed.gender_filter) | (df['gender'] == 'Unisex')]
    if parsed.price_min is not None and 'price' in df.columns:
        df = df[df['price'] >= parsed.price_min]
    if parsed.price_max is not None and 'price' in df.columns:
        df = df[df['price'] <= parsed.price_max]
    if parsed.brand_filter and 'brand' in df.columns:
        df = df[df['brand'].str.lower() == parsed.brand_filter.lower()]

    # ── Size filtering: exact (case/whitespace-insensitive) match against
    # the product's list of available sizes. Non-list cells never match. ──
    if parsed.size_filter and 'sizes_available' in df.columns:
        size_val = parsed.size_filter.lower().strip()
        df = df[df['sizes_available'].apply(
            lambda sizes: isinstance(sizes, list) and any(
                size_val == str(s).lower().strip() for s in sizes
            )
        )]

    # ── Material filtering: substring match inside any listed material;
    # falls back to matching a scalar cell's string form. ──
    if parsed.material_filter and 'materials' in df.columns:
        mat = parsed.material_filter.lower()
        df = df[df['materials'].apply(
            lambda mats: (
                any(mat in str(m).lower() for m in mats)
                if isinstance(mats, list) and len(mats) > 0
                else (mat in str(mats).lower() if mats else False)
            )
        )]

    # ── Exclusion filtering: drop rows mentioning the excluded term in
    # name, cleaned color, color family, style tags, or materials. ──
    if parsed.exclusions:
        for excl in parsed.exclusions:
            excl_lower = excl.lower()
            mask = pd.Series(True, index=df.index)
            if 'name' in df.columns:
                mask &= ~df['name'].str.lower().str.contains(
                    excl_lower, na=False, regex=False)
            if 'color_clean' in df.columns:
                mask &= ~df['color_clean'].str.lower().str.contains(
                    excl_lower, na=False, regex=False)
            if 'color_family' in df.columns:
                mask &= ~(df['color_family'].str.lower() == excl_lower)
            if 'style_tags' in df.columns:
                mask &= ~df['style_tags'].apply(
                    lambda tags: isinstance(tags, list) and any(
                        excl_lower in str(t).lower() for t in tags
                    )
                )
            if 'materials' in df.columns:
                mask &= ~df['materials'].apply(
                    lambda mats: any(
                        excl_lower in str(m).lower()
                        for m in (mats if isinstance(mats, list) else [])
                    )
                )
            df = df[mask]

    if parsed.in_stock_only and 'any_in_stock' in df.columns:
        df = df[df['any_in_stock']]
    return df
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def relax_and_retry(candidates: pd.DataFrame, parsed: ParsedQuery,
                    min_results: int = 10) -> pd.DataFrame:
    """
    Smart progressive filter relaxation.

    Key improvement: instead of dropping price_max entirely (which shows
    £200 items for "under £10"), we progressively expand the budget in
    steps (×1.5, ×2, ×3, ×5) so the user sees the cheapest viable options.

    Relaxation order (each phase returns early once >= min_results):
      0.  size / material filters (least important constraints)
      0b. exclusion terms
      1.  color, gender, in-stock filters
      2.  price_max expanded stepwise, then dropped; then price_min dropped
      3.  category dropped as a last resort
    Falls back to the largest partial result seen, then to `candidates`
    unfiltered. The caller's `parsed` object is never mutated.
    """
    # Work on a fresh copy so relaxation never touches the caller's query.
    relaxed = ParsedQuery(
        raw_query=parsed.raw_query, vibe_text=parsed.vibe_text,
        category_filter=parsed.category_filter, color_filter=parsed.color_filter,
        gender_filter=parsed.gender_filter, price_min=parsed.price_min,
        price_max=parsed.price_max, brand_filter=parsed.brand_filter,
        in_stock_only=parsed.in_stock_only, style_tags=parsed.style_tags,
        material_filter=parsed.material_filter, size_filter=parsed.size_filter,
        exclusions=parsed.exclusions,
    )

    # Largest result set seen across all phases; used as a fallback.
    best_so_far = pd.DataFrame()

    # Phase 0: Try relaxing size and material first (least important constraints)
    # Try each independently before committing
    for attr in ('size_filter', 'material_filter'):
        if getattr(relaxed, attr) is not None:
            saved = getattr(relaxed, attr)
            setattr(relaxed, attr, None)
            result = apply_filters(candidates, relaxed)
            if len(result) >= min_results:
                logger.info(f"Relaxed filter '{attr}' -> {len(result)} results")
                return result
            if len(result) > len(best_so_far):
                # Relaxation helped: keep the filter dropped for later phases.
                best_so_far = result
            else:
                setattr(relaxed, attr, saved)  # restore if it didn't help

    # Phase 0b: Relax exclusions if they're too restrictive
    if relaxed.exclusions:
        relaxed.exclusions = []
        result = apply_filters(candidates, relaxed)
        if len(result) > len(best_so_far):
            best_so_far = result
        if len(result) >= min_results:
            logger.info(f"Relaxed exclusions -> {len(result)} results")
            return result

    # Phase 1: Try relaxing non-price filters one by one
    # (cumulative: each dropped filter stays dropped for later attempts)
    non_price_relaxations = [
        ('color_filter', None), ('gender_filter', None), ('in_stock_only', False),
    ]
    for attr, val in non_price_relaxations:
        if getattr(relaxed, attr) is not None and getattr(relaxed, attr) != val:
            setattr(relaxed, attr, val)
            result = apply_filters(candidates, relaxed)
            if len(result) > len(best_so_far):
                best_so_far = result
            if len(result) >= min_results:
                logger.info(f"Relaxed filter '{attr}' -> {len(result)} results")
                return result

    # Phase 2: Progressive price expansion (keep category if possible)
    if parsed.price_max is not None:
        original_max = parsed.price_max
        expansion_factors = [1.5, 2.0, 3.0, 5.0, 10.0]
        for factor in expansion_factors:
            relaxed.price_max = original_max * factor
            result = apply_filters(candidates, relaxed)
            if len(result) > len(best_so_far):
                best_so_far = result
            if len(result) >= min_results:
                logger.info(
                    f"Expanded price_max: £{original_max:.0f} -> "
                    f"£{relaxed.price_max:.0f} ({factor}×) -> {len(result)} results"
                )
                return result

        # If even 10× doesn't work, drop the price filter
        relaxed.price_max = None
        result = apply_filters(candidates, relaxed)
        if len(result) > len(best_so_far):
            best_so_far = result
        if len(result) >= min_results:
            logger.info(f"Dropped price_max entirely -> {len(result)} results")
            return result

    if parsed.price_min is not None:
        relaxed.price_min = None
        result = apply_filters(candidates, relaxed)
        if len(result) > len(best_so_far):
            best_so_far = result
        if len(result) >= min_results:
            logger.info(f"Dropped price_min -> {len(result)} results")
            return result

    # Phase 3: Drop category as last resort
    if relaxed.category_filter is not None:
        relaxed.category_filter = None
        result = apply_filters(candidates, relaxed)
        if len(result) > len(best_so_far):
            best_so_far = result
        if len(result) >= min_results:
            logger.info(f"Relaxed category_filter -> {len(result)} results")
            return result

    # Return best partial result even if < min_results
    if len(best_so_far) > 0:
        logger.info(f"Returning best available: {len(best_so_far)} results (wanted {min_results})")
        return best_so_far

    logger.warning("All filters relaxed. Returning unfiltered results.")
    return candidates
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def hybrid_rerank(candidates: pd.DataFrame, parsed: ParsedQuery,
                  config: SearchConfig, bm25: Optional[SimpleBM25] = None) -> pd.DataFrame:
    """Re-rank candidates with a weighted blend of ranking signals.

    Signals (each written back as a column on the returned frame):
      rrf_norm        — min-max normalised reciprocal-rank-fusion score
      tag_score       — fraction of the query's style tags on the product
      bm25_norm       — lexical BM25 score / max over candidates (0 if no index)
      stock_bonus     — 1.0 in stock, 0.0 not, 0.5 when stock is unknown
      material_bonus  — 1.0 when the requested material appears in `materials`
      price_proximity — Gaussian bump centred on the stated budget

    Weight sets: a fixed price-aware blend when the query has a budget,
    a fixed material-aware blend when it names a fabric, otherwise the
    configurable defaults from `config`. Returns the frame sorted by
    `hybrid_score`, highest first.
    """
    scored = candidates.copy()
    if len(scored) == 0:
        return scored

    # Min-max normalise the RRF fusion score; all-equal scores map to 1.0.
    rrf = scored['rrf_score'].values
    lo, hi = rrf.min(), rrf.max()
    scored['rrf_norm'] = ((rrf - lo) / (hi - lo)) if hi > lo else 1.0

    # Fraction of the query's style tags carried by each product.
    wanted_tags = set(parsed.style_tags)
    if wanted_tags and 'style_tags' in scored.columns:
        def _tag_overlap(tags):
            if isinstance(tags, list):
                return len(set(tags) & wanted_tags) / len(wanted_tags)
            return 0.0
        scored['tag_score'] = scored['style_tags'].apply(_tag_overlap)
    else:
        scored['tag_score'] = 0.0

    # Lexical signal, normalised by its max over this candidate set.
    if bm25 is not None and '_orig_idx' in scored.columns:
        lexical = bm25.score_candidates(parsed.raw_query, scored['_orig_idx'].tolist())
        lexical_peak = lexical.max()
        scored['bm25_norm'] = (lexical / lexical_peak) if lexical_peak > 0 else 0.0
    else:
        scored['bm25_norm'] = 0.0

    # Stock signal: neutral 0.5 when availability is unknown.
    if 'any_in_stock' in scored.columns:
        scored['stock_bonus'] = scored['any_in_stock'].astype(float)
    else:
        scored['stock_bonus'] = 0.5

    # Binary bonus when the requested material appears in the product's list.
    material_hits = np.zeros(len(scored), dtype=np.float32)
    if parsed.material_filter and 'materials' in scored.columns:
        wanted_mat = parsed.material_filter.lower()
        def _has_material(mats):
            if isinstance(mats, list) and any(wanted_mat in str(m).lower() for m in mats):
                return 1.0
            return 0.0
        material_hits = scored['materials'].apply(_has_material).values.astype(np.float32)
    scored['material_bonus'] = material_hits

    # Gaussian price-proximity bonus centred on the stated budget, so a £20
    # item outranks a £200 one when the user said "under £10".
    proximity = np.zeros(len(scored), dtype=np.float32)
    target_price = parsed.price_max or parsed.price_min
    if target_price is not None and 'price' in scored.columns:
        price_arr = scored['price'].values.astype(np.float32)
        sigma = max(target_price * 0.5, 10.0)  # half the budget, floor of £10
        proximity = np.exp(-((price_arr - target_price) ** 2) / (2 * sigma ** 2))
    scored['price_proximity'] = proximity

    # Pick the weight profile; insertion order of terms mirrors the signal
    # definitions above so the floating-point sum is deterministic.
    if target_price is not None:
        blend = [('rrf_norm', 0.40), ('tag_score', 0.18), ('bm25_norm', 0.10),
                 ('stock_bonus', 0.05), ('price_proximity', 0.20),
                 ('material_bonus', 0.07)]
    elif parsed.material_filter is not None:
        blend = [('rrf_norm', 0.45), ('tag_score', 0.20), ('bm25_norm', 0.12),
                 ('stock_bonus', 0.05), ('material_bonus', 0.18)]
    else:
        blend = [('rrf_norm', config.alpha_clip), ('tag_score', config.beta_tags),
                 ('bm25_norm', config.gamma_text), ('stock_bonus', config.delta_freshness)]

    total = blend[0][1] * scored[blend[0][0]]
    for column, weight in blend[1:]:
        total = total + weight * scored[column]
    scored['hybrid_score'] = total

    return scored.sort_values('hybrid_score', ascending=False)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def generate_suggestions(results: pd.DataFrame, parsed: ParsedQuery,
                         max_suggestions: int = 5) -> List[str]:
    """Build up to `max_suggestions` human-readable related-search queries.

    Draws on the result set itself: popular colours, a budget derived from
    the price distribution, a frequent style tag, the dominant brand, and a
    companion category. Suggestions are deduplicated case-insensitively and
    never echo the original query. Returns [] for an empty result set.
    """
    if len(results) == 0:
        return []

    raw_suggestions = []

    # A short noun phrase to anchor each suggestion: the category's display
    # name, or the first 30 chars of the vibe text when no category was set.
    category = parsed.category_filter
    display_names = {
        'Dresses': 'dresses', 'Tops': 'tops', 'Coats & Jackets': 'jackets',
        'Knitwear': 'knitwear', 'Jeans': 'jeans', 'Trousers': 'trousers',
        'Shoes': 'shoes', 'Bags': 'bags', 'Accessories': 'accessories',
        'Skirts': 'skirts', 'Shorts': 'shorts', 'Swimwear': 'swimwear',
        'Hoodies & Sweatshirts': 'hoodies', 'Suits & Tailoring': 'suits',
        'Jumpsuits & Playsuits': 'jumpsuits',
    }
    base_term = display_names.get(category, parsed.vibe_text.strip()[:30])

    # (1) Colour refinement when the user gave no colour: up to two of the
    # most common colour families in the results.
    if 'color_family' in results.columns and not parsed.color_filter:
        frequent = results['color_family'].value_counts().head(4).index.tolist()
        for shade in frequent[:2]:
            if shade and shade not in ('other', 'multi'):
                raw_suggestions.append(f"{shade} {base_term}")

    # (2) A single alternative colour when the user did specify one.
    if parsed.color_filter and 'color_family' in results.columns:
        for alt_shade in ('black', 'white', 'navy', 'beige'):
            if alt_shade != parsed.color_filter:
                raw_suggestions.append(f"{alt_shade} {base_term}")
                break

    # (3) Budget suggestion from the 25th price percentile (skipped when the
    # user already constrained price, or when the figure would be trivial).
    if parsed.price_max is None and parsed.price_min is None and 'price' in results.columns:
        budget = results['price'].quantile(0.25)
        if budget > 5:
            raw_suggestions.append(f"{base_term} under \u00a3{int(budget)}")

    # (4) The most frequent style tag not already in the query.
    if 'style_tags' in results.columns:
        tally = Counter()
        for tag_list in results['style_tags']:
            if isinstance(tag_list, list):
                for tag in tag_list:
                    if tag not in parsed.style_tags and tag not in parsed.vibe_text:
                        tally[tag] += 1
        if tally:
            raw_suggestions.append(f"{tally.most_common(1)[0][0]} {base_term}")

    # (5) The dominant brand, unless it's unknown or already in the query.
    if 'brand' in results.columns:
        leaders = results['brand'].value_counts().head(1).index.tolist()
        if leaders and leaders[0] and leaders[0] != 'Unknown':
            leading_brand = leaders[0]
            if leading_brand.lower() not in parsed.vibe_text.lower():
                raw_suggestions.append(f"{leading_brand} {base_term}")

    # (6) A companion category, keeping the user's colour if they gave one.
    if category:
        companions = {
            'Dresses': 'jumpsuits', 'Tops': 'blouses',
            'Jeans': 'trousers', 'Trousers': 'jeans',
            'Coats & Jackets': 'blazers', 'Knitwear': 'cardigans',
            'Skirts': 'dresses', 'Shorts': 'skirts',
        }
        companion = companions.get(category)
        if companion:
            colour_prefix = f"{parsed.color_filter} " if parsed.color_filter else ""
            raw_suggestions.append(f"{colour_prefix}{companion}".strip())

    # Case-insensitive dedup; never echo the original query back.
    seen = set()
    unique = []
    for candidate in raw_suggestions:
        key = candidate.strip().lower()
        if key not in seen and key != parsed.raw_query.lower():
            seen.add(key)
            unique.append(candidate.strip())
    return unique[:max_suggestions]
|
backend/app/engine/search_engine.py
ADDED
|
@@ -0,0 +1,636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__all__ = ["ASOSSearchEngine"]
|
| 2 |
+
|
| 3 |
+
import ast
|
| 4 |
+
import logging
|
| 5 |
+
import time
|
| 6 |
+
from collections import Counter
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Dict, List, Optional, Set
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
+
|
| 15 |
+
from backend.app.config import SearchConfig
|
| 16 |
+
from backend.app.engine.encoder import FashionCLIPEncoder
|
| 17 |
+
from backend.app.engine.index import DualFAISSIndex
|
| 18 |
+
from backend.app.engine.query_parser import QueryParser, ParsedQuery
|
| 19 |
+
from backend.app.engine.bm25 import SimpleBM25
|
| 20 |
+
from backend.app.engine.nlp import MultilingualHandler, SpellCorrector
|
| 21 |
+
from backend.app.exceptions import EngineNotReadyError, SKUNotFoundError
|
| 22 |
+
from backend.app.engine import reranker
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ASOSSearchEngine:
|
| 28 |
+
"""
|
| 29 |
+
v3.3 — Multimodal + multilingual fashion search engine.
|
| 30 |
+
|
| 31 |
+
build_index() encodes ALL product text in ~3-5 min (no image downloading).
|
| 32 |
+
Both FAISS indices (image + text) are populated from text embeddings.
|
| 33 |
+
Image URLs from metadata are preserved for website card display.
|
| 34 |
+
|
| 35 |
+
New in v3.1:
|
| 36 |
+
- Multilingual query support (auto-detect + translate)
|
| 37 |
+
- Spell correction for typos
|
| 38 |
+
- Fixed SKU type handling (int/str agnostic)
|
| 39 |
+
- Fixed color filter case matching
|
| 40 |
+
|
| 41 |
+
New in v3.3:
|
| 42 |
+
- Style-coherent Complete the Look outfit recommendations
|
| 43 |
+
- Improved suggested related searches
|
| 44 |
+
- Size-aware filtering
|
| 45 |
+
- Material/fabric filtering
|
| 46 |
+
- Negative/exclusion query support
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
# ── Outfit category pairings for "Complete the Look" ──
# Maps a seed product's category to the categories worth recommending
# alongside it when assembling an outfit.
OUTFIT_PAIRS = {
    'Dresses': ['Shoes', 'Coats & Jackets', 'Bags', 'Accessories'],
    'Tops': ['Trousers', 'Jeans', 'Skirts', 'Shoes', 'Accessories'],
    'Knitwear': ['Trousers', 'Jeans', 'Skirts', 'Shoes'],
    'Hoodies & Sweatshirts': ['Trousers', 'Jeans', 'Shoes', 'Accessories'],
    'Coats & Jackets': ['Tops', 'Trousers', 'Jeans', 'Shoes', 'Accessories'],
    'Trousers': ['Tops', 'Knitwear', 'Shoes', 'Coats & Jackets', 'Accessories'],
    'Jeans': ['Tops', 'Knitwear', 'Shoes', 'Coats & Jackets', 'Accessories'],
    'Shorts': ['Tops', 'Shoes', 'Accessories'],
    'Skirts': ['Tops', 'Knitwear', 'Shoes', 'Accessories'],
    'Shoes': ['Bags', 'Accessories'],
    'Suits & Tailoring': ['Tops', 'Shoes', 'Accessories'],
    'Swimwear': ['Shoes', 'Accessories', 'Bags'],
    'Jumpsuits & Playsuits': ['Shoes', 'Coats & Jackets', 'Bags', 'Accessories'],
    'Bags': ['Shoes', 'Accessories'],
    'Accessories': ['Bags', 'Shoes'],
}

# Colors that pair well together for outfit coherence.
# Keyed by the seed item's lowercase color family; values list companion
# color families considered acceptable matches.
COLOR_HARMONY = {
    'black': ['white', 'red', 'pink', 'grey', 'navy', 'beige', 'multi'],
    'white': ['black', 'navy', 'blue', 'beige', 'pink', 'red'],
    'navy': ['white', 'beige', 'grey', 'pink', 'red', 'brown'],
    'blue': ['white', 'navy', 'beige', 'brown', 'grey'],
    'red': ['black', 'white', 'navy', 'grey', 'beige'],
    'pink': ['black', 'white', 'navy', 'grey', 'beige', 'blue'],
    'green': ['white', 'beige', 'brown', 'navy', 'black'],
    'grey': ['black', 'white', 'navy', 'pink', 'blue', 'red'],
    'brown': ['white', 'beige', 'navy', 'green', 'blue'],
    'beige': ['navy', 'brown', 'white', 'black', 'blue', 'green'],
    'yellow': ['navy', 'white', 'grey', 'black', 'blue'],
    'orange': ['navy', 'white', 'black', 'brown', 'beige'],
    'purple': ['white', 'black', 'grey', 'beige', 'pink'],
    'burgundy': ['black', 'white', 'navy', 'beige', 'grey'],
    'khaki': ['white', 'navy', 'brown', 'black', 'beige'],
    'multi': ['black', 'white', 'navy', 'beige'],
}

# ── Sort options for frontend ──
# Maps an API sort key to (DataFrame column, ascending?) suitable for
# passing to DataFrame.sort_values.
SORT_OPTIONS = {
    'relevance': ('hybrid_score', False),  # highest relevance first
    'price_asc': ('price', True),  # cheapest first
    'price_desc': ('price', False),  # most expensive first
    'name_asc': ('name', True),  # alphabetical A-Z
    'name_desc': ('name', False),  # alphabetical Z-A
}
|
| 96 |
+
|
| 97 |
+
    def __init__(self, config: SearchConfig):
        """Wire up engine components; heavy loading happens in load_data()/build_index()."""
        self.config = config
        # Populated lazily by load_data()/build_index() (or from the on-disk cache):
        self.encoder: Optional[FashionCLIPEncoder] = None
        self.dual_index: Optional[DualFAISSIndex] = None
        self.metadata: Optional[pd.DataFrame] = None
        self.image_embeddings: Optional[np.ndarray] = None
        self.text_embeddings: Optional[np.ndarray] = None
        self.bm25: Optional[SimpleBM25] = None
        # Lightweight helpers, constructed eagerly.
        self.query_parser = QueryParser()
        self.multilingual = MultilingualHandler()
        self.spell_corrector = SpellCorrector()
        # Flipped to True once build_index() completes or loads the cache.
        self._is_ready = False
|
| 109 |
+
|
| 110 |
+
def load_data(self, path: Optional[str] = None):
|
| 111 |
+
path = path or self.config.data_path
|
| 112 |
+
logger.info(f"Loading metadata: {path}")
|
| 113 |
+
|
| 114 |
+
if path.endswith('.parquet'):
|
| 115 |
+
self.metadata = pd.read_parquet(path)
|
| 116 |
+
else:
|
| 117 |
+
self.metadata = pd.read_csv(path)
|
| 118 |
+
|
| 119 |
+
list_cols = ['style_tags', 'materials', 'image_urls',
|
| 120 |
+
'sizes_available', 'sizes_out_of_stock']
|
| 121 |
+
for col in list_cols:
|
| 122 |
+
if col in self.metadata.columns and self.metadata[col].dtype == object:
|
| 123 |
+
self.metadata[col] = self.metadata[col].apply(
|
| 124 |
+
lambda x: ast.literal_eval(x)
|
| 125 |
+
if isinstance(x, str) and x.startswith('[') else (
|
| 126 |
+
x if isinstance(x, list) else []
|
| 127 |
+
)
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
required = ['sku', 'name', 'price', 'primary_image_url', 'search_text']
|
| 131 |
+
missing = [c for c in required if c not in self.metadata.columns]
|
| 132 |
+
if missing:
|
| 133 |
+
raise ValueError(f"Missing required columns: {missing}")
|
| 134 |
+
|
| 135 |
+
# ── Normalize SKU to string for consistent lookups ──
|
| 136 |
+
self.metadata['sku'] = self.metadata['sku'].astype(str)
|
| 137 |
+
|
| 138 |
+
# ── Normalize color_family to lowercase for consistent filtering ──
|
| 139 |
+
if 'color_family' in self.metadata.columns:
|
| 140 |
+
self.metadata['color_family'] = self.metadata['color_family'].str.lower().str.strip()
|
| 141 |
+
|
| 142 |
+
self.metadata = self.metadata.reset_index(drop=True)
|
| 143 |
+
logger.info(f"Loaded {len(self.metadata):,} products")
|
| 144 |
+
|
| 145 |
+
    def build_index(self, force_rebuild: bool = False):
        """
        Build search index from text only. ~3-5 min for 30K products.

        Tries to load cached embeddings + FAISS indices first (all four
        artifacts must exist and their row count must match the metadata);
        otherwise encodes everything from scratch and persists the results.
        Flips ``self._is_ready`` on success.

        Parameters
        ----------
        force_rebuild : skip the on-disk cache and re-encode everything.
        """
        img_emb_path = Path(self.config.image_embeddings_path)
        txt_emb_path = Path(self.config.text_embeddings_path)
        img_idx_path = Path(self.config.image_index_path)
        txt_idx_path = Path(self.config.text_index_path)

        # ── Try loading from cache ──
        if (not force_rebuild
                and img_emb_path.exists() and txt_emb_path.exists()
                and img_idx_path.exists() and txt_idx_path.exists()):

            logger.info("Loading cached embeddings and indices...")
            self.image_embeddings = np.load(str(img_emb_path))
            self.text_embeddings = np.load(str(txt_emb_path))

            # Cache is valid only when row counts still match the metadata.
            n_meta = len(self.metadata)
            if (self.image_embeddings.shape[0] == n_meta
                    and self.text_embeddings.shape[0] == n_meta):

                # Trust the cached embedding width over the configured one.
                actual_dim = self.text_embeddings.shape[1]
                if actual_dim != self.config.embedding_dim:
                    logger.info(
                        f"Updating embedding_dim: {self.config.embedding_dim} -> {actual_dim}"
                    )
                    self.config.embedding_dim = actual_dim

                self.dual_index = DualFAISSIndex(actual_dim, self.config)
                self.dual_index.load(str(img_idx_path), str(txt_idx_path))
                self._fit_bm25()
                self._fit_spell_corrector()
                self._is_ready = True

                # All-zero rows indicate entries that failed to encode.
                n_zero_img = int(np.sum(np.all(self.image_embeddings == 0, axis=1)))
                n_zero_txt = int(np.sum(np.all(self.text_embeddings == 0, axis=1)))
                logger.info(
                    f"Engine ready (from cache). "
                    f"Zero-vectors: {n_zero_img} img, {n_zero_txt} txt"
                )
                return
            else:
                logger.warning(
                    f"Cache shape mismatch: emb={self.image_embeddings.shape[0]} "
                    f"vs metadata={n_meta}. Rebuilding..."
                )

        # ── Initialize encoder (lazy: only needed for a rebuild) ──
        if self.encoder is None:
            self.encoder = FashionCLIPEncoder(self.config)

        n = len(self.metadata)
        dim = self.config.embedding_dim
        t_start = time.time()

        # ── Step 1: Encode product text ──
        logger.info(f"Step 1/4: Encoding {n:,} product texts...")
        texts = self.metadata['search_text'].fillna(self.metadata['name']).tolist()
        product_texts = [f"a fashion product: {t}" for t in texts]
        self.text_embeddings = self._encode_texts_with_progress(product_texts, "Text embeddings")
        logger.info(f" Text encoding done in {time.time()-t_start:.1f}s")

        # ── Step 2: Image-proxy embeddings from text ──
        # No real product images are encoded at build time; each product gets
        # a "photo of <name>" text embedding standing in for the image side.
        logger.info("Step 2/4: Creating image-proxy embeddings from text...")
        image_proxy_texts = []
        for i in range(n):
            row = self.metadata.iloc[i]
            name = row.get('search_text', row['name'])
            # Guard against NaN / the literal string 'nan' from CSV loads.
            if pd.isna(name) or str(name) == 'nan':
                name = row['name']
            image_proxy_texts.append(f"a fashion product photo of {name}")
        self.image_embeddings = self._encode_texts_with_progress(image_proxy_texts, "Image proxies")

        # ── Step 3: Build FAISS ──
        logger.info("Step 3/4: Building dual FAISS index...")
        self.dual_index = DualFAISSIndex(dim, self.config)
        self.dual_index.build(self.image_embeddings, self.text_embeddings)

        # ── Step 4: BM25 + Spell Corrector ──
        logger.info("Step 4/4: Fitting BM25 lexical index + spell corrector...")
        self._fit_bm25()
        self._fit_spell_corrector()

        # ── Save to persistent storage (so the next boot hits the cache) ──
        np.save(str(img_emb_path), self.image_embeddings)
        np.save(str(txt_emb_path), self.text_embeddings)
        self.dual_index.save(str(img_idx_path), str(txt_idx_path))

        self._is_ready = True

        elapsed = time.time() - t_start
        n_zero = int(np.sum(np.all(self.text_embeddings == 0, axis=1)))
        logger.info(
            f"\n{'='*60}\n"
            f" ENGINE READY in {elapsed:.0f}s ({elapsed/60:.1f} min)\n"
            f" Products indexed: {n:,}\n"
            f" Embedding dim: {dim}\n"
            f" Zero-vector texts: {n_zero}\n"
            f"{'='*60}"
        )
|
| 247 |
+
|
| 248 |
+
def _encode_texts_with_progress(self, texts: List[str], desc: str) -> np.ndarray:
|
| 249 |
+
batch_size = min(self.config.embed_batch_size * 4, 256)
|
| 250 |
+
texts = [str(t) if t and str(t) != 'nan' else '' for t in texts]
|
| 251 |
+
all_emb = []
|
| 252 |
+
|
| 253 |
+
n_batches = (len(texts) + batch_size - 1) // batch_size
|
| 254 |
+
for i in tqdm(range(0, len(texts), batch_size), total=n_batches, desc=desc):
|
| 255 |
+
batch = texts[i:i + batch_size]
|
| 256 |
+
emb = self.encoder.encode_texts(batch, batch_size=len(batch))
|
| 257 |
+
all_emb.append(emb)
|
| 258 |
+
|
| 259 |
+
return np.vstack(all_emb).astype(np.float32)
|
| 260 |
+
|
| 261 |
+
def _fit_bm25(self):
|
| 262 |
+
texts = self.metadata['search_text'].fillna(self.metadata['name']).tolist()
|
| 263 |
+
self.bm25 = SimpleBM25()
|
| 264 |
+
self.bm25.fit(texts)
|
| 265 |
+
|
| 266 |
+
def _fit_spell_corrector(self):
|
| 267 |
+
if self.config.enable_spell_correction:
|
| 268 |
+
texts = self.metadata['search_text'].fillna(self.metadata['name']).tolist()
|
| 269 |
+
self.spell_corrector.fit(texts)
|
| 270 |
+
|
| 271 |
+
# ── Search ──
|
| 272 |
+
|
| 273 |
+
    def search(
        self, query: str,
        query_image: Optional[Image.Image] = None,
        top_n: Optional[int] = None,
        text_weight: float = 0.5,
        sort_by: str = 'relevance',
    ) -> pd.DataFrame:
        """
        Run the full search pipeline for a text (optionally multimodal) query.

        Pipeline: translate -> spell-correct -> parse intent -> encode ->
        dual-index RRF retrieval -> metadata filter (with relaxation) ->
        hybrid re-rank -> optional sort -> attach query metadata.

        Parameters
        ----------
        query : raw user query text.
        query_image : optional PIL image for multimodal search.
        top_n : result count; defaults to ``config.final_top_n``.
        text_weight : text-vs-image blend for multimodal queries (0-1).
        sort_by : one of ``SORT_OPTIONS``; 'relevance' keeps the ranked order.

        Returns
        -------
        DataFrame indexed 1..N (index name 'rank'), with parsed-query info
        attached in ``results.attrs['query_info']``; empty frame when FAISS
        returns no candidates.

        Raises
        ------
        EngineNotReadyError when called before ``build_index()``.
        """
        if not self._is_ready:
            raise EngineNotReadyError("Engine not ready. Call build_index() first.")
        # Encoder may be absent when the index was loaded from cache.
        if self.encoder is None:
            self.encoder = FashionCLIPEncoder(self.config)

        top_n = top_n or self.config.final_top_n
        t_start = time.time()

        original_query = query

        # ── Multilingual: translate if needed ──
        if self.config.enable_multilingual:
            query, detected_lang, was_translated = self.multilingual.translate_query(query)
        else:
            detected_lang, was_translated = 'en', False

        # ── Spell correction (only once a vocabulary has been fitted) ──
        was_spell_corrected = False
        spell_suggestion = None
        if self.config.enable_spell_correction and self.spell_corrector._ready:
            corrected, was_spell_corrected = self.spell_corrector.correct_query(query)
            if was_spell_corrected:
                spell_suggestion = corrected
                query = corrected

        # Parse intent (filters, exclusions, vibe text) from the query.
        parsed = self.query_parser.parse(query)
        parsed.has_image = query_image is not None
        parsed.text_weight = text_weight
        parsed.original_query = original_query
        parsed.detected_language = detected_lang
        parsed.was_translated = was_translated
        parsed.was_spell_corrected = was_spell_corrected
        parsed.spell_correction_suggestion = spell_suggestion

        logger.info(
            f"Query: \"{query}\" -> "
            f"cat={parsed.category_filter}, col={parsed.color_filter}, "
            f"price=[{parsed.price_min},{parsed.price_max}], "
            f"gen={parsed.gender_filter}, tags={parsed.style_tags}, "
            f"mat={parsed.material_filter}, size={parsed.size_filter}, "
            f"excl={parsed.exclusions}"
            f"{' [translated from ' + detected_lang + ']' if was_translated else ''}"
            f"{' [spell-corrected]' if was_spell_corrected else ''}"
        )

        # Encode query (multimodal when an image was supplied)
        if query_image is not None:
            query_vec = self.encoder.encode_multimodal_query(
                parsed.vibe_text, query_image, text_weight
            )
        else:
            query_vec = self.encoder.encode_query_text(parsed.vibe_text)

        # Dual-index retrieval with RRF (reciprocal-rank fusion)
        candidate_indices, rrf_scores = self.dual_index.search_fused(
            query_vec[0], top_k=self.config.retrieval_top_k,
        )

        if not candidate_indices:
            logger.warning("No candidates from FAISS.")
            return pd.DataFrame()

        candidates = self.metadata.iloc[candidate_indices].copy()
        candidates['rrf_score'] = rrf_scores
        candidates['_orig_idx'] = candidate_indices

        # Metadata filter; relax constraints rather than return nothing.
        filtered = reranker.apply_filters(candidates, parsed)
        if len(filtered) == 0:
            logger.warning("Zero results after filtering. Relaxing constraints...")
            filtered = reranker.relax_and_retry(candidates, parsed, min_results=top_n)

        # Hybrid re-ranking (semantic + lexical BM25 signals)
        ranked = reranker.hybrid_rerank(filtered, parsed, self.config, self.bm25)

        # Build result frame from whichever of these columns exist.
        result_cols = [
            'sku', 'name', 'brand', 'price', 'color_clean', 'color_family',
            'category', 'gender', 'style_tags', 'primary_image_url', 'image_urls',
            'rrf_score', 'hybrid_score', 'any_in_stock', 'sizes_available',
            'product_details', 'materials', 'url',
        ]
        available_cols = [c for c in result_cols if c in ranked.columns]
        results = ranked[available_cols].head(top_n).reset_index(drop=True)

        # ── Apply sort ('relevance' keeps the hybrid ranking order) ──
        if sort_by != 'relevance' and sort_by in self.SORT_OPTIONS:
            sort_col, ascending = self.SORT_OPTIONS[sort_by]
            if sort_col in results.columns:
                results = results.sort_values(sort_col, ascending=ascending).reset_index(drop=True)

        # 1-based rank index for the frontend.
        results.index = range(1, len(results) + 1)
        results.index.name = 'rank'

        # ── Generate suggested related searches ──
        suggested_searches = reranker.generate_suggestions(results, parsed)

        # Attach query metadata for frontend use
        results.attrs['query_info'] = {
            'original_query': original_query,
            'processed_query': query,
            'detected_language': detected_lang,
            'was_translated': was_translated,
            'was_spell_corrected': was_spell_corrected,
            'spell_suggestion': spell_suggestion,
            'parsed_category': parsed.category_filter,
            'parsed_color': parsed.color_filter,
            'parsed_price_range': [parsed.price_min, parsed.price_max],
            'parsed_gender': parsed.gender_filter,
            'parsed_style_tags': parsed.style_tags,
            'parsed_material': parsed.material_filter,
            'parsed_size': parsed.size_filter,
            'parsed_exclusions': parsed.exclusions,
            'sort_by': sort_by,
            'available_sorts': list(self.SORT_OPTIONS.keys()),
            'suggested_searches': suggested_searches,
        }

        elapsed = time.time() - t_start
        logger.info(
            f"Search complete: {len(results)} results in {elapsed:.2f}s "
            f"(from {len(candidates)} candidates -> {len(filtered)} filtered)"
        )
        return results
|
| 405 |
+
|
| 406 |
+
def search_similar(self, sku, top_n: int = 10) -> pd.DataFrame:
|
| 407 |
+
"""Find visually similar products to a given SKU."""
|
| 408 |
+
if not self._is_ready:
|
| 409 |
+
raise EngineNotReadyError("Engine not ready.")
|
| 410 |
+
|
| 411 |
+
# ── FIX: compare as strings consistently ──
|
| 412 |
+
sku_str = str(sku)
|
| 413 |
+
match = self.metadata[self.metadata['sku'] == sku_str]
|
| 414 |
+
if len(match) == 0:
|
| 415 |
+
raise SKUNotFoundError(sku_str)
|
| 416 |
+
|
| 417 |
+
idx = match.index[0]
|
| 418 |
+
query_vec = self.image_embeddings[idx]
|
| 419 |
+
|
| 420 |
+
dists, ids = self.dual_index.search_image_index(query_vec, top_n + 1)
|
| 421 |
+
ids, dists = ids[0], dists[0]
|
| 422 |
+
mask = ids != idx
|
| 423 |
+
ids, dists = ids[mask][:top_n], dists[mask][:top_n]
|
| 424 |
+
|
| 425 |
+
results = self.metadata.iloc[ids].copy()
|
| 426 |
+
results['similarity_score'] = dists
|
| 427 |
+
return results.reset_index(drop=True)
|
| 428 |
+
|
| 429 |
+
def search_by_image(self, image: Image.Image, top_n: int = 20) -> pd.DataFrame:
|
| 430 |
+
"""Search using an uploaded image only (no text query)."""
|
| 431 |
+
if self.encoder is None:
|
| 432 |
+
self.encoder = FashionCLIPEncoder(self.config)
|
| 433 |
+
img_emb = self.encoder.encode_images([image])
|
| 434 |
+
indices, scores = self.dual_index.search_fused(
|
| 435 |
+
img_emb[0], top_n, image_weight=0.8, text_weight=0.2,
|
| 436 |
+
)
|
| 437 |
+
results = self.metadata.iloc[indices].copy()
|
| 438 |
+
results['score'] = scores
|
| 439 |
+
results.index = range(1, len(results) + 1)
|
| 440 |
+
results.index.name = 'rank'
|
| 441 |
+
return results
|
| 442 |
+
|
| 443 |
+
def get_product_detail(self, sku) -> Optional[Dict]:
|
| 444 |
+
"""
|
| 445 |
+
Get full product detail for a single SKU — used when a user clicks a card.
|
| 446 |
+
Returns all metadata + all image URLs for the product detail page.
|
| 447 |
+
"""
|
| 448 |
+
sku_str = str(sku)
|
| 449 |
+
match = self.metadata[self.metadata['sku'] == sku_str]
|
| 450 |
+
if len(match) == 0:
|
| 451 |
+
return None
|
| 452 |
+
|
| 453 |
+
row = match.iloc[0]
|
| 454 |
+
detail = row.to_dict()
|
| 455 |
+
|
| 456 |
+
# Ensure image_urls is a proper list
|
| 457 |
+
if 'image_urls' in detail and isinstance(detail['image_urls'], str):
|
| 458 |
+
try:
|
| 459 |
+
detail['image_urls'] = ast.literal_eval(detail['image_urls'])
|
| 460 |
+
except (ValueError, SyntaxError):
|
| 461 |
+
detail['image_urls'] = [detail.get('primary_image_url', '')]
|
| 462 |
+
|
| 463 |
+
return detail
|
| 464 |
+
|
| 465 |
+
# ── "Complete the Look" — cross-category outfit recommendation ──
|
| 466 |
+
|
| 467 |
+
    def complete_the_look(self, sku, n_per_category: int = 3) -> Dict[str, pd.DataFrame]:
        """
        Given a product SKU, suggest complementary items from DIFFERENT categories
        to help the user build a complete outfit.

        v3.3 improvements:
        - Searches per-category pools (not just top-200 global neighbors)
        - Scores by style coherence (tag overlap), color harmony, price tier,
          gender consistency, and embedding similarity
        - Returns genuinely complementary items, not just same-category lookalikes

        Returns a dict mapping category names to DataFrames of recommendations.

        Raises EngineNotReadyError / SKUNotFoundError.
        """
        if not self._is_ready:
            raise EngineNotReadyError("Engine not ready.")

        sku_str = str(sku)
        match = self.metadata[self.metadata['sku'] == sku_str]
        if len(match) == 0:
            raise SKUNotFoundError(sku_str)

        # Source-product attributes used by the coherence scorer below.
        source = match.iloc[0]
        source_category = source.get('category', '')
        source_idx = match.index[0]
        source_color = str(source.get('color_family', '')).lower()
        source_gender = source.get('gender', '')
        source_price = source.get('price', 0)
        source_tags = source.get('style_tags', [])
        if not isinstance(source_tags, list):
            source_tags = []
        source_tag_set = set(source_tags)

        # Which categories complement the source category (generic fallback).
        target_categories = self.OUTFIT_PAIRS.get(
            source_category, ['Shoes', 'Accessories', 'Bags']
        )

        # Get compatible colors for the source product
        compatible_colors = set(self.COLOR_HARMONY.get(source_color, []))
        compatible_colors.add(source_color)  # same color is always ok

        # Get a broad set of candidates from fused search (both indices).
        # Guards i >= 0 to drop FAISS's -1 padding ids.
        query_vec = self.image_embeddings[source_idx]
        _, img_ids = self.dual_index.search_image_index(query_vec, 800)
        _, txt_ids = self.dual_index.search_text_index(query_vec, 800)
        img_ids = set(int(i) for i in img_ids[0] if i >= 0 and i != source_idx)
        txt_ids = set(int(i) for i in txt_ids[0] if i >= 0 and i != source_idx)
        all_candidate_ids = img_ids | txt_ids

        # Price tier: items within 0.3x-3x of source price
        price_low = max(source_price * 0.3, 3.0)
        price_high = source_price * 3.0

        outfit = {}
        for target_cat in target_categories:
            # Filter candidates to this category
            cat_mask = self.metadata['category'] == target_cat
            cat_indices = set(self.metadata.index[cat_mask].tolist())
            pool = list(all_candidate_ids & cat_indices)

            if not pool:
                continue

            # Score each candidate with a multi-factor outfit coherence score
            scores = []
            for cidx in pool:
                row = self.metadata.iloc[cidx]

                # 1. Embedding similarity (0-1, already normalized)
                sim = max(0.0, float(np.dot(query_vec, self.image_embeddings[cidx])))

                # 2. Style tag overlap (fraction of the source's tags shared)
                c_tags = row.get('style_tags', [])
                if isinstance(c_tags, list) and source_tag_set:
                    tag_overlap = len(set(c_tags) & source_tag_set) / max(len(source_tag_set), 1)
                else:
                    tag_overlap = 0.0

                # 3. Color harmony
                c_color = str(row.get('color_family', '')).lower()
                if c_color in compatible_colors:
                    color_score = 1.0
                elif c_color in ('black', 'white', 'grey', 'navy', 'beige'):
                    color_score = 0.7  # neutrals always work
                else:
                    color_score = 0.2

                # 4. Gender consistency (empty gender = universal match)
                c_gender = row.get('gender', '')
                if (not c_gender or not source_gender or
                        c_gender == source_gender or
                        c_gender == 'Unisex' or source_gender == 'Unisex'):
                    gender_score = 1.0
                else:
                    gender_score = 0.0

                # 5. Price tier proximity (outside the tier -> small constant)
                c_price = row.get('price', 0)
                price_range = price_high - price_low
                if price_range > 0 and price_low <= c_price <= price_high:
                    price_score = 1.0 - abs(c_price - source_price) / price_range
                else:
                    price_score = 0.1

                # 6. In-stock bonus
                stock_score = 1.0 if row.get('any_in_stock', False) else 0.3

                # Weighted combination
                outfit_score = (
                    0.30 * sim +
                    0.25 * tag_overlap +
                    0.15 * color_score +
                    0.15 * gender_score +
                    0.10 * price_score +
                    0.05 * stock_score
                )
                scores.append((cidx, outfit_score))

            # Sort by outfit coherence score and take top n
            scores.sort(key=lambda x: -x[1])
            top_items = scores[:n_per_category]

            if top_items:
                indices = [s[0] for s in top_items]
                df = self.metadata.iloc[indices].copy()
                df['outfit_score'] = [s[1] for s in top_items]
                outfit[target_cat] = df.reset_index(drop=True)

        return outfit
|
| 595 |
+
|
| 596 |
+
# ── Audit ──
|
| 597 |
+
|
| 598 |
+
    def audit(self) -> Dict:
        """Print a diagnostic report of engine state and return it as a dict."""
        report = {"status": "ready" if self._is_ready else "not_ready"}
        if self.metadata is not None:
            n = len(self.metadata)
            report["products"] = n
            report["has_price"] = int(self.metadata['price'].notna().sum())
            # Count rows whose primary image URL looks like a real link.
            report["has_image_url"] = int(
                self.metadata['primary_image_url'].apply(
                    lambda x: isinstance(x, str) and x.startswith('http')
                ).sum()
            )
            report["has_search_text"] = int(self.metadata['search_text'].notna().sum())
            if 'color_family' in self.metadata.columns:
                report["color_families"] = sorted(self.metadata['color_family'].dropna().unique().tolist())
            if 'category' in self.metadata.columns:
                report["categories"] = sorted(self.metadata['category'].dropna().unique().tolist())

        # Embedding health: all-zero rows indicate failed encodings.
        if self.text_embeddings is not None:
            report["text_embeddings"] = self.text_embeddings.shape
            report["zero_text_emb"] = int(np.sum(np.all(self.text_embeddings == 0, axis=1)))
        if self.image_embeddings is not None:
            report["image_embeddings"] = self.image_embeddings.shape
            report["zero_img_emb"] = int(np.sum(np.all(self.image_embeddings == 0, axis=1)))
        if self.dual_index and self.dual_index.image_index:
            report["faiss_image_vectors"] = self.dual_index.image_index.ntotal
            report["faiss_text_vectors"] = self.dual_index.text_index.ntotal

        report["multilingual_enabled"] = self.config.enable_multilingual
        report["spell_correction_enabled"] = self.config.enable_spell_correction
        report["spell_corrector_vocab_size"] = len(self.spell_corrector.word_freq) if self.spell_corrector._ready else 0

        print("\n" + "=" * 55)
        print(" ENGINE AUDIT")
        print("=" * 55)
        for k, v in report.items():
            print(f" {k:30s} {v}")
        print("=" * 55)
        return report
|
backend/app/exceptions.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class EngineNotReadyError(Exception):
    """Signals that the search engine hasn't finished loading its index yet."""
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SKUNotFoundError(Exception):
    """Signals that a requested SKU doesn't exist in the catalogue.

    The offending SKU is kept on ``self.sku`` for handlers to inspect.
    """

    def __init__(self, sku: str):
        message = f"SKU '{sku}' not found"
        self.sku = sku
        super().__init__(message)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class InvalidQueryError(Exception):
    """Signals that a search query failed validation."""
|
backend/app/main.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from contextlib import asynccontextmanager
|
| 3 |
+
|
| 4 |
+
from fastapi import FastAPI, Request
|
| 5 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 6 |
+
from fastapi.responses import JSONResponse
|
| 7 |
+
|
| 8 |
+
from backend.app.config import Settings, SearchConfig
|
| 9 |
+
from backend.app.exceptions import EngineNotReadyError, SKUNotFoundError, InvalidQueryError
|
| 10 |
+
from backend.app.engine.search_engine import ASOSSearchEngine
|
| 11 |
+
from backend.app.routers import health, search, products
|
| 12 |
+
|
| 13 |
+
# Application settings (environment-driven; see backend/app/config.py).
settings = Settings()

# Root logging config; level comes from settings (falls back to INFO when
# the configured name isn't a valid logging level).
logging.basicConfig(
    level=getattr(logging, settings.log_level.upper(), logging.INFO),
    format="%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("asos_search")

# Silence chatty third-party loggers so application logs stay readable.
for _noisy in (
    "urllib3", "urllib3.connectionpool", "requests", "PIL",
    "transformers", "transformers.modeling_utils",
):
    logging.getLogger(_noisy).setLevel(logging.ERROR)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load engine at startup, clean up on shutdown.

    Builds (or cache-loads) the search index before the app starts serving,
    and exposes the engine to routers via ``app.state.engine``.
    """
    logger.info("Starting ASOS Search Engine...")
    config = SearchConfig.from_settings(settings)
    engine = ASOSSearchEngine(config)
    engine.load_data()
    engine.build_index()
    app.state.engine = engine
    logger.info(f"Engine ready with {len(engine.metadata):,} products")
    yield
    logger.info("Shutting down ASOS Search Engine.")
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
app = FastAPI(
    title="ASOS Fashion Search API",
    description="Multimodal, intent-driven semantic search engine for fashion products",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: allowed origins come from settings so each deployment can restrict them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register routers (all mounted under the /api/v1 prefix)
app.include_router(health.router, prefix="/api/v1")
app.include_router(search.router, prefix="/api/v1")
app.include_router(products.router, prefix="/api/v1")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Exception handlers
|
| 65 |
+
@app.exception_handler(EngineNotReadyError)
async def engine_not_ready_handler(request: Request, exc: EngineNotReadyError):
    # 503: engine still warming up — clients may retry later.
    return JSONResponse(status_code=503, content={"detail": str(exc)})


@app.exception_handler(SKUNotFoundError)
async def sku_not_found_handler(request: Request, exc: SKUNotFoundError):
    # 404: unknown product SKU.
    return JSONResponse(status_code=404, content={"detail": str(exc)})


@app.exception_handler(InvalidQueryError)
async def invalid_query_handler(request: Request, exc: InvalidQueryError):
    # 422: query failed validation.
    return JSONResponse(status_code=422, content={"detail": str(exc)})


@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
    # Catch-all: log the full traceback but never leak internals to clients.
    logger.exception(f"Unhandled error: {exc}")
    return JSONResponse(status_code=500, content={"detail": "Internal server error"})
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
    # Direct-run entry point for local development.
    import uvicorn
    uvicorn.run(app, host=settings.host, port=settings.port)
|
backend/app/models/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from backend.app.models.search import (
|
| 2 |
+
SearchRequest, SearchResponse, SearchResultItem, QueryInfo,
|
| 3 |
+
ImageSearchRequest, EvaluateRequest,
|
| 4 |
+
)
|
| 5 |
+
from backend.app.models.product import (
|
| 6 |
+
ProductDetail, OutfitResponse, OutfitItem,
|
| 7 |
+
SimilarProductItem, SimilarResponse,
|
| 8 |
+
)
|
backend/app/models/product.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ProductDetail(BaseModel):
    """Full product payload returned for the product-detail view."""
    sku: str
    name: str
    brand: str
    price: float
    color: str
    color_family: str
    category: str
    gender: str
    image_url: str
    url: str = ""
    image_urls: list[str] = []
    style_tags: list[str] = []
    materials: list[str] = []
    sizes_available: list[str] = []
    product_details: str = ""
    in_stock: bool = True
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class OutfitItem(BaseModel):
    """A single complementary product in a "complete the look" response."""
    sku: str
    name: str
    brand: str
    price: float
    color_family: str
    category: str
    image_url: str
    # Multi-factor outfit coherence score assigned by the engine.
    outfit_score: float
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class OutfitResponse(BaseModel):
    """Response for "complete the look": source product plus per-category suggestions."""
    source: ProductDetail
    # Maps target category name -> ranked complementary items.
    outfit: dict[str, list[OutfitItem]]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class SimilarProductItem(BaseModel):
    """A single entry in a similar-products response."""
    sku: str
    name: str
    brand: str
    price: float
    color: str
    category: str
    image_url: str
    # Visual similarity of this product to the source item.
    similarity_score: float
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class SimilarResponse(BaseModel):
    """Similar-products response: source product plus its ranked neighbors."""
    source: ProductDetail
    results: list[SimilarProductItem]
    total: int
|
backend/app/models/search.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal, Optional
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
|
| 4 |
+
# Sort orders accepted by the search endpoints.
SortOption = Literal["relevance", "price_asc", "price_desc", "name_asc", "name_desc"]


class SearchRequest(BaseModel):
    """Body of POST /search: a text query plus optional multimodal inputs."""

    query: str = Field(..., min_length=1, max_length=500, description="Search query text")
    top_n: int = Field(20, ge=1, le=100, description="Number of results to return")
    sort_by: SortOption = Field("relevance", description="Sort order")
    text_weight: float = Field(0.5, ge=0.0, le=1.0, description="Text vs image weight for multimodal queries")
    image_b64: Optional[str] = Field(None, description="Base64-encoded image for multimodal search")


class SearchResultItem(BaseModel):
    """One product hit in a search response."""

    sku: str
    name: str
    brand: str
    price: float
    color: str
    color_family: str
    category: str
    gender: str
    image_url: str
    url: Optional[str] = None
    score: float  # relevance score (hybrid/rrf/score depending on engine path)
    style_tags: list[str] = []
    in_stock: bool = True


class QueryInfo(BaseModel):
    """Diagnostics describing how the engine interpreted the query."""

    original_query: str
    processed_query: str  # query after translation / spell correction
    detected_language: str = "en"
    was_translated: bool = False
    was_spell_corrected: bool = False
    spell_suggestion: Optional[str] = None
    parsed_category: Optional[str] = None
    parsed_color: Optional[str] = None
    parsed_price_range: list[Optional[float]] = [None, None]  # [min, max]
    parsed_gender: Optional[str] = None
    parsed_style_tags: list[str] = []
    parsed_material: Optional[str] = None
    parsed_size: Optional[str] = None
    parsed_exclusions: list[str] = []  # exclusion terms extracted by the query parser
    sort_by: str = "relevance"
    available_sorts: list[str] = []
    suggested_searches: list[str] = []


class SearchResponse(BaseModel):
    """Search results plus the query diagnostics attached by the engine."""

    results: list[SearchResultItem]
    query_info: QueryInfo
    total: int


class ImageSearchRequest(BaseModel):
    """Parameters accepted by the image-only search endpoint."""

    top_n: int = Field(20, ge=1, le=100)


class EvaluateRequest(BaseModel):
    """Body of POST /search/evaluate."""

    test_queries: list[dict]  # labelled queries passed to SearchEvaluator.evaluate
    k_values: list[int] = [5, 10, 20]  # ranking-metric cutoffs forwarded to the evaluator
|
backend/app/routers/__init__.py
ADDED
|
File without changes
|
backend/app/routers/health.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Request
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
router = APIRouter(tags=["health"])


class HealthResponse(BaseModel):
    """Payload returned by GET /health."""

    status: str  # "ok" once the engine is loaded, "loading" before that
    products: int  # catalogue size (0 while loading)
    engine_ready: bool


@router.get("/health", response_model=HealthResponse)
def health_check(request: Request) -> HealthResponse:
    """Report whether the search engine has finished loading."""
    engine = getattr(request.app.state, "engine", None)
    # NOTE(review): reads the engine's private _is_ready flag — consider
    # exposing a public readiness property on ASOSSearchEngine.
    if engine is None or not engine._is_ready:
        return HealthResponse(status="loading", products=0, engine_ready=False)
    return HealthResponse(
        status="ok",
        products=len(engine.metadata),
        engine_ready=True,
    )


@router.get("/audit")
def audit(request: Request) -> dict:
    """Expose the engine's internal audit report (for debugging/ops)."""
    engine = getattr(request.app.state, "engine", None)
    if engine is None:
        return {"status": "engine_not_loaded"}
    return engine.audit()
|
backend/app/routers/products.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, Query
|
| 2 |
+
|
| 3 |
+
from backend.app.dependencies import get_engine
|
| 4 |
+
from backend.app.engine.search_engine import ASOSSearchEngine
|
| 5 |
+
from backend.app.models.product import ProductDetail, OutfitResponse
|
| 6 |
+
from backend.app.services import search_service
|
| 7 |
+
|
| 8 |
+
router = APIRouter(prefix="/products", tags=["products"])


@router.get("/{sku}", response_model=ProductDetail)
def product_detail(
    sku: str,
    engine: ASOSSearchEngine = Depends(get_engine),
) -> ProductDetail:
    """Get full product details for a single SKU.

    The service layer raises SKUNotFoundError for unknown SKUs, which the
    app maps to a 404 response (see tests in test_api_products.py).
    """
    return search_service.get_product_detail(engine, sku)


@router.get("/{sku}/outfit", response_model=OutfitResponse)
def complete_the_look(
    sku: str,
    n_per_category: int = Query(3, ge=1, le=10),  # items returned per outfit category
    engine: ASOSSearchEngine = Depends(get_engine),
) -> OutfitResponse:
    """Get outfit recommendations for a product."""
    return search_service.get_outfit(engine, sku, n_per_category=n_per_category)
|
backend/app/routers/search.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, File, Query, UploadFile
|
| 4 |
+
|
| 5 |
+
from backend.app.dependencies import get_engine
|
| 6 |
+
from backend.app.engine.search_engine import ASOSSearchEngine
|
| 7 |
+
from backend.app.models.search import (
|
| 8 |
+
SearchRequest, SearchResponse, EvaluateRequest,
|
| 9 |
+
)
|
| 10 |
+
from backend.app.models.product import SimilarResponse
|
| 11 |
+
from backend.app.services import search_service
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger("asos_search")

router = APIRouter(prefix="/search", tags=["search"])


@router.post("", response_model=SearchResponse)
def text_search(
    request: SearchRequest,
    engine: ASOSSearchEngine = Depends(get_engine),
) -> SearchResponse:
    """Text search with optional base64 image for multimodal queries."""
    return search_service.search(engine, request)


@router.post("/image", response_model=SearchResponse)
async def image_search(
    file: UploadFile = File(...),
    top_n: int = Query(20, ge=1, le=100),
    engine: ASOSSearchEngine = Depends(get_engine),
) -> SearchResponse:
    """Image-only search via file upload."""
    image_bytes = await file.read()
    # decode_image raises InvalidQueryError itself for undecodable bytes;
    # the None branch below is a defensive fallback (decode_image returns
    # None only when no payload is supplied at all).
    image = search_service.decode_image(image_bytes=image_bytes)
    if image is None:
        from backend.app.exceptions import InvalidQueryError
        raise InvalidQueryError("Uploaded file could not be decoded as an image")
    return search_service.search_by_image(engine, image, top_n=top_n)


@router.get("/similar/{sku}", response_model=SimilarResponse)
def similar_search(
    sku: str,
    top_n: int = Query(10, ge=1, le=100),
    engine: ASOSSearchEngine = Depends(get_engine),
) -> SimilarResponse:
    """Find visually similar products to a given SKU."""
    return search_service.get_similar(engine, sku, top_n=top_n)


@router.post("/evaluate")
def evaluate(
    request: EvaluateRequest,
    engine: ASOSSearchEngine = Depends(get_engine),
) -> dict:
    """Run evaluation suite against the engine."""
    # Imported lazily so the evaluator (and its dependencies) are loaded
    # only when this endpoint is actually used.
    from backend.app.engine.evaluator import SearchEvaluator
    evaluator = SearchEvaluator(engine)
    return evaluator.evaluate(request.test_queries, k_values=request.k_values)
|
backend/app/services/__init__.py
ADDED
|
File without changes
|
backend/app/services/search_service.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import io
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
from backend.app.engine.search_engine import ASOSSearchEngine
|
| 10 |
+
from backend.app.exceptions import SKUNotFoundError, InvalidQueryError
|
| 11 |
+
from backend.app.models.search import SearchRequest, SearchResponse, SearchResultItem, QueryInfo
|
| 12 |
+
from backend.app.models.product import ProductDetail, OutfitResponse, OutfitItem, SimilarProductItem, SimilarResponse
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def decode_image(
    image_b64: Optional[str] = None,
    image_bytes: Optional[bytes] = None,
) -> Optional[Image.Image]:
    """Turn a base64 payload or raw byte string into an RGB PIL image.

    Returns None when neither argument is provided; raises
    InvalidQueryError when a payload was supplied but cannot be decoded.
    """
    if image_b64 is None and image_bytes is None:
        return None

    try:
        if image_b64 is None:
            raw = image_bytes  # type: ignore[assignment]
        else:
            # Drop an optional data-URI header ("data:image/jpeg;base64,...").
            _head, sep, tail = image_b64.partition(",")
            raw = base64.b64decode(tail if sep else image_b64)
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as exc:
        raise InvalidQueryError(f"Could not decode image: {exc}") from exc
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _row_to_search_item(row: pd.Series) -> SearchResultItem:
    """Convert a single DataFrame row to a SearchResultItem Pydantic model.

    Handles numpy scalar types, missing optional columns, and NaN scores
    (a NaN would pass float() and later serialise to invalid JSON).
    """
    # Score: prefer hybrid_score -> rrf_score -> score -> 0.0.
    score = 0.0
    for score_col in ("hybrid_score", "rrf_score", "score"):
        raw_score = row.get(score_col)
        if raw_score is None:
            continue
        try:
            candidate = float(raw_score)
        except (TypeError, ValueError):
            continue
        # NaN != NaN: skip NaN values so we fall through to the next column
        # (or the 0.0 default) instead of emitting an unserialisable score.
        if candidate == candidate:
            score = candidate
            break

    # style_tags: ensure it is always a plain Python list of strings.
    raw_tags = row.get("style_tags", [])
    if isinstance(raw_tags, list):
        style_tags = [str(t) for t in raw_tags]
    elif isinstance(raw_tags, str) and raw_tags:
        # Could be a repr()-serialised list stored as a string (CSV round-trip).
        try:
            import ast
            parsed = ast.literal_eval(raw_tags)
            style_tags = [str(t) for t in parsed] if isinstance(parsed, list) else [raw_tags]
        except Exception:
            style_tags = [raw_tags]
    else:
        style_tags = []

    return SearchResultItem(
        sku=str(row["sku"]),
        name=str(row["name"]),
        brand=str(row["brand"]),
        price=float(row["price"]),
        color=str(row.get("color_clean", "")),
        color_family=str(row.get("color_family", "")),
        category=str(row.get("category", "")),
        gender=str(row.get("gender", "")),
        image_url=str(row.get("primary_image_url", "")),
        url=str(row.get("url", "")),
        score=score,
        style_tags=style_tags,
        in_stock=bool(row.get("any_in_stock", True)),
    )
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _coerce_str_list(value) -> list[str]:
    """Normalise a catalogue field that should be a list of strings.

    Lists are stringified element-wise. Non-empty strings may be a
    repr()-serialised list (e.g. after a CSV round-trip), so they are parsed
    with ast.literal_eval and fall back to a single-element list on failure.
    Anything else maps to an empty list.
    """
    if isinstance(value, list):
        return [str(v) for v in value]
    if isinstance(value, str) and value:
        try:
            import ast
            parsed = ast.literal_eval(value)
            return [str(v) for v in parsed] if isinstance(parsed, list) else [value]
        except Exception:
            return [value]
    return []


def _row_to_product_detail(detail: dict) -> ProductDetail:
    """Convert a product detail dict (from engine.get_product_detail) to a
    ProductDetail Pydantic model.

    Coerces numpy/pandas scalars to plain Python types and normalises every
    list-valued field (image_urls, sizes_available, style_tags, materials),
    each of which may arrive as a list or as its string representation.
    """
    return ProductDetail(
        sku=str(detail["sku"]),
        name=str(detail.get("name", "")),
        brand=str(detail.get("brand", "")),
        price=float(detail.get("price", 0.0)),
        color=str(detail.get("color_clean", "")),
        color_family=str(detail.get("color_family", "")),
        category=str(detail.get("category", "")),
        gender=str(detail.get("gender", "")),
        image_url=str(detail.get("primary_image_url", "")),
        url=str(detail.get("url", "")),
        image_urls=_coerce_str_list(detail.get("image_urls", [])),
        style_tags=_coerce_str_list(detail.get("style_tags", [])),
        materials=_coerce_str_list(detail.get("materials", [])),
        sizes_available=_coerce_str_list(detail.get("sizes_available", [])),
        product_details=str(detail.get("product_details", "")),
        in_stock=bool(detail.get("any_in_stock", True)),
    )
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def search(engine: ASOSSearchEngine, request: SearchRequest) -> SearchResponse:
    """Execute a text (or multimodal) search and return a SearchResponse.

    Decodes the optional base64 image, calls the engine, converts the
    resulting DataFrame, and wraps query metadata into the response.

    Raises:
        InvalidQueryError: if request.image_b64 is set but cannot be decoded
            (propagated from decode_image).
    """
    query_image: Optional[Image.Image] = None
    if request.image_b64:
        query_image = decode_image(image_b64=request.image_b64)

    results_df: pd.DataFrame = engine.search(
        query=request.query,
        query_image=query_image,
        top_n=request.top_n,
        text_weight=request.text_weight,
        sort_by=request.sort_by,
    )

    items = [_row_to_search_item(row) for _, row in results_df.iterrows()]

    # Extract query_info dict attached by the engine; every field gets a
    # safe default so a missing/partial dict still yields a valid QueryInfo.
    raw_qi: dict = results_df.attrs.get("query_info", {})
    query_info = QueryInfo(
        original_query=str(raw_qi.get("original_query", request.query)),
        processed_query=str(raw_qi.get("processed_query", request.query)),
        detected_language=str(raw_qi.get("detected_language", "en")),
        was_translated=bool(raw_qi.get("was_translated", False)),
        was_spell_corrected=bool(raw_qi.get("was_spell_corrected", False)),
        spell_suggestion=raw_qi.get("spell_suggestion"),
        parsed_category=raw_qi.get("parsed_category"),
        parsed_color=raw_qi.get("parsed_color"),
        parsed_price_range=list(raw_qi.get("parsed_price_range", [None, None])),
        parsed_gender=raw_qi.get("parsed_gender"),
        parsed_style_tags=list(raw_qi.get("parsed_style_tags", [])),
        parsed_material=raw_qi.get("parsed_material"),
        parsed_size=raw_qi.get("parsed_size"),
        parsed_exclusions=list(raw_qi.get("parsed_exclusions", [])),
        sort_by=str(raw_qi.get("sort_by", request.sort_by)),
        available_sorts=list(raw_qi.get("available_sorts", [])),
        suggested_searches=list(raw_qi.get("suggested_searches", [])),
    )

    return SearchResponse(results=items, query_info=query_info, total=len(items))
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def search_by_image(
    engine: ASOSSearchEngine,
    image: Image.Image,
    top_n: int = 20,
) -> SearchResponse:
    """Run a pure image search against the engine.

    There is no text query to parse, so the attached QueryInfo carries
    empty query strings and default values everywhere else.
    """
    df: pd.DataFrame = engine.search_by_image(image=image, top_n=top_n)

    hits = [_row_to_search_item(r) for _, r in df.iterrows()]
    empty_info = QueryInfo(original_query="", processed_query="")

    return SearchResponse(results=hits, query_info=empty_info, total=len(hits))
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def get_product_detail(engine: ASOSSearchEngine, sku: str) -> ProductDetail:
    """Look up one SKU and return its full ProductDetail.

    Raises:
        SKUNotFoundError: if the SKU is not present in the catalogue.
    """
    raw = engine.get_product_detail(sku)
    if raw is None:
        raise SKUNotFoundError(sku)

    return _row_to_product_detail(raw)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def get_similar(
    engine: ASOSSearchEngine,
    sku: str,
    top_n: int = 10,
) -> SimilarResponse:
    """Return up to top_n products visually similar to the given SKU.

    Raises:
        SKUNotFoundError: if the source SKU does not exist.
    """
    raw_detail = engine.get_product_detail(sku)
    if raw_detail is None:
        raise SKUNotFoundError(sku)

    results_df: pd.DataFrame = engine.search_similar(sku=sku, top_n=top_n)

    items = [
        SimilarProductItem(
            sku=str(r["sku"]),
            name=str(r.get("name", "")),
            brand=str(r.get("brand", "")),
            price=float(r.get("price", 0.0)),
            color=str(r.get("color_clean", "")),
            category=str(r.get("category", "")),
            image_url=str(r.get("primary_image_url", "")),
            similarity_score=float(r.get("similarity_score", 0.0)),
        )
        for _, r in results_df.iterrows()
    ]

    return SimilarResponse(
        source=_row_to_product_detail(raw_detail),
        results=items,
        total=len(items),
    )
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def get_outfit(
    engine: ASOSSearchEngine,
    sku: str,
    n_per_category: int = 3,
) -> OutfitResponse:
    """Build a "complete the look" outfit around the given SKU.

    Raises:
        SKUNotFoundError: if the source SKU does not exist.
    """
    raw_detail = engine.get_product_detail(sku)
    if raw_detail is None:
        raise SKUNotFoundError(sku)

    recommendations: dict[str, pd.DataFrame] = engine.complete_the_look(
        sku=sku, n_per_category=n_per_category
    )

    def _to_item(r: pd.Series) -> OutfitItem:
        # Per-row conversion; missing columns fall back to neutral defaults.
        return OutfitItem(
            sku=str(r["sku"]),
            name=str(r.get("name", "")),
            brand=str(r.get("brand", "")),
            price=float(r.get("price", 0.0)),
            color_family=str(r.get("color_family", "")),
            category=str(r.get("category", "")),
            image_url=str(r.get("primary_image_url", "")),
            outfit_score=float(r.get("outfit_score", 0.0)),
        )

    outfit: dict[str, list[OutfitItem]] = {
        category: [_to_item(r) for _, r in frame.iterrows()]
        for category, frame in recommendations.items()
    }

    return OutfitResponse(source=_row_to_product_detail(raw_detail), outfit=outfit)
|
backend/pyproject.toml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "asos-search-backend"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "ASOS Multimodal Fashion Search Engine API"
|
| 5 |
+
requires-python = ">=3.10"
|
| 6 |
+
|
| 7 |
+
[tool.pytest.ini_options]
|
| 8 |
+
testpaths = ["tests"]
|
| 9 |
+
pythonpath = [".."]
|
backend/requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110.0
|
| 2 |
+
uvicorn[standard]>=0.27.0
|
| 3 |
+
pydantic>=2.0
|
| 4 |
+
pydantic-settings>=2.0
|
| 5 |
+
python-multipart>=0.0.6
|
| 6 |
+
numpy>=1.24.0
|
| 7 |
+
pandas>=2.0.0
|
| 8 |
+
torch>=2.0.0
|
| 9 |
+
transformers>=4.30.0
|
| 10 |
+
faiss-cpu>=1.7.4
|
| 11 |
+
Pillow>=10.0.0
|
| 12 |
+
tqdm>=4.65.0
|
| 13 |
+
pytest>=7.0.0
|
| 14 |
+
httpx>=0.24.0
|
backend/tests/__init__.py
ADDED
|
File without changes
|
backend/tests/conftest.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
@pytest.fixture
def query_parser():
    """Fresh QueryParser instance for parser unit tests."""
    # Imported inside the fixture so test collection does not load the
    # engine package for unrelated test modules.
    from backend.app.engine.query_parser import QueryParser
    return QueryParser()
|
backend/tests/test_api_health.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import MagicMock
|
| 2 |
+
from fastapi.testclient import TestClient
|
| 3 |
+
|
| 4 |
+
from backend.app.main import app
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _make_mock_engine(ready=True, n_products=100):
    """Build a MagicMock standing in for the search engine.

    Mirrors exactly what the health endpoints read: the private _is_ready
    flag, len(engine.metadata), and engine.audit().
    """
    engine = MagicMock()
    engine._is_ready = ready
    engine.metadata = MagicMock()
    engine.metadata.__len__ = MagicMock(return_value=n_products)
    engine.audit.return_value = {"status": "ready", "products": n_products}
    return engine
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestHealthEndpoints:
    """Behaviour of /api/v1/health and /api/v1/audit with a mocked engine."""

    def test_health_ok(self):
        # Ready engine -> status ok with the product count.
        client = TestClient(app, raise_server_exceptions=False)
        app.state.engine = _make_mock_engine()
        response = client.get("/api/v1/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ok"
        assert data["engine_ready"] is True
        assert data["products"] == 100

    def test_health_not_ready(self):
        # Engine present but still loading.
        client = TestClient(app, raise_server_exceptions=False)
        app.state.engine = _make_mock_engine(ready=False)
        response = client.get("/api/v1/health")
        data = response.json()
        assert data["engine_ready"] is False

    def test_health_no_engine(self):
        # No engine attached to app state at all.
        client = TestClient(app, raise_server_exceptions=False)
        app.state.engine = None
        response = client.get("/api/v1/health")
        data = response.json()
        assert data["engine_ready"] is False

    def test_audit(self):
        # Audit endpoint proxies engine.audit() verbatim.
        client = TestClient(app, raise_server_exceptions=False)
        app.state.engine = _make_mock_engine()
        response = client.get("/api/v1/audit")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "ready"
|
backend/tests/test_api_products.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import MagicMock
|
| 2 |
+
from fastapi.testclient import TestClient
|
| 3 |
+
|
| 4 |
+
from backend.app.main import app
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def _make_product_detail():
    """Canonical product-detail dict in the shape engine.get_product_detail returns."""
    return {
        "sku": "12345", "name": "Test Dress", "brand": "ASOS",
        "price": 29.99, "color_clean": "black", "color_family": "black",
        "category": "Dresses", "gender": "Women",
        "primary_image_url": "https://example.com/img.jpg",
        "image_urls": ["https://example.com/img.jpg", "https://example.com/img2.jpg"],
        "style_tags": ["casual", "summer"], "materials": ["cotton"],
        "sizes_available": ["S", "M", "L"],
        "product_details": "A nice black dress", "any_in_stock": True,
    }
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestProductEndpoints:
    """Behaviour of /api/v1/products/{sku} and its /outfit sub-resource."""

    def test_product_detail(self):
        # Known SKU -> 200 with all fields converted from the engine dict.
        engine = MagicMock()
        engine._is_ready = True
        engine.get_product_detail.return_value = _make_product_detail()
        app.state.engine = engine

        client = TestClient(app, raise_server_exceptions=False)
        response = client.get("/api/v1/products/12345")
        assert response.status_code == 200
        data = response.json()
        assert data["sku"] == "12345"
        assert data["name"] == "Test Dress"
        assert len(data["image_urls"]) == 2
        assert data["materials"] == ["cotton"]

    def test_product_not_found(self):
        # Engine returning None maps to a 404 via SKUNotFoundError.
        engine = MagicMock()
        engine._is_ready = True
        engine.get_product_detail.return_value = None
        app.state.engine = engine

        client = TestClient(app, raise_server_exceptions=False)
        response = client.get("/api/v1/products/99999")
        assert response.status_code == 404

    def test_outfit(self):
        # Empty recommendation dict still yields a valid OutfitResponse.
        engine = MagicMock()
        engine._is_ready = True
        engine.get_product_detail.return_value = _make_product_detail()
        engine.complete_the_look.return_value = {}
        app.state.engine = engine

        client = TestClient(app, raise_server_exceptions=False)
        response = client.get("/api/v1/products/12345/outfit")
        assert response.status_code == 200
        data = response.json()
        assert "source" in data
        assert "outfit" in data
        assert data["source"]["sku"] == "12345"

    def test_outfit_not_found(self):
        # Unknown source SKU -> 404 before complete_the_look is consulted.
        engine = MagicMock()
        engine._is_ready = True
        engine.get_product_detail.return_value = None
        app.state.engine = engine

        client = TestClient(app, raise_server_exceptions=False)
        response = client.get("/api/v1/products/99999/outfit")
        assert response.status_code == 404

    def test_outfit_with_categories(self):
        # One recommended category with one item; verifies per-row conversion.
        import pandas as pd
        engine = MagicMock()
        engine._is_ready = True
        engine.get_product_detail.return_value = _make_product_detail()
        engine.complete_the_look.return_value = {
            "Shoes": pd.DataFrame([{
                "sku": "55555", "name": "Black Heels", "brand": "ASOS",
                "price": 45.00, "color_family": "black", "category": "Shoes",
                "primary_image_url": "https://example.com/shoes.jpg",
                "outfit_score": 0.82,
            }])
        }
        app.state.engine = engine

        client = TestClient(app, raise_server_exceptions=False)
        response = client.get("/api/v1/products/12345/outfit")
        assert response.status_code == 200
        data = response.json()
        assert "Shoes" in data["outfit"]
        assert len(data["outfit"]["Shoes"]) == 1
        assert data["outfit"]["Shoes"][0]["outfit_score"] == 0.82
|
backend/tests/test_api_search.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import MagicMock
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from fastapi.testclient import TestClient
|
| 4 |
+
|
| 5 |
+
from backend.app.main import app
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _make_search_results():
|
| 9 |
+
df = pd.DataFrame([{
|
| 10 |
+
"sku": "12345",
|
| 11 |
+
"name": "Test Dress",
|
| 12 |
+
"brand": "ASOS",
|
| 13 |
+
"price": 29.99,
|
| 14 |
+
"color_clean": "black",
|
| 15 |
+
"color_family": "black",
|
| 16 |
+
"category": "Dresses",
|
| 17 |
+
"gender": "Women",
|
| 18 |
+
"primary_image_url": "https://example.com/img.jpg",
|
| 19 |
+
"hybrid_score": 0.95,
|
| 20 |
+
"style_tags": ["casual"],
|
| 21 |
+
"any_in_stock": True,
|
| 22 |
+
}])
|
| 23 |
+
df.attrs["query_info"] = {
|
| 24 |
+
"original_query": "black dress",
|
| 25 |
+
"processed_query": "black dress",
|
| 26 |
+
"detected_language": "en",
|
| 27 |
+
"was_translated": False,
|
| 28 |
+
"was_spell_corrected": False,
|
| 29 |
+
"spell_suggestion": None,
|
| 30 |
+
"parsed_category": "Dresses",
|
| 31 |
+
"parsed_color": "black",
|
| 32 |
+
"parsed_price_range": [None, None],
|
| 33 |
+
"parsed_gender": None,
|
| 34 |
+
"parsed_style_tags": [],
|
| 35 |
+
"parsed_material": None,
|
| 36 |
+
"parsed_size": None,
|
| 37 |
+
"parsed_exclusions": [],
|
| 38 |
+
"sort_by": "relevance",
|
| 39 |
+
"available_sorts": ["relevance", "price_asc", "price_desc"],
|
| 40 |
+
"suggested_searches": ["navy dresses"],
|
| 41 |
+
}
|
| 42 |
+
return df
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _make_mock_engine():
    """Return a ready MagicMock engine whose search() yields canned results."""
    fake_engine = MagicMock()
    fake_engine._is_ready = True
    fake_engine.search.return_value = _make_search_results()
    return fake_engine
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class TestSearchEndpoints:
    """HTTP-level tests for /api/v1/search and /api/v1/search/similar."""

    def test_text_search(self):
        app.state.engine = _make_mock_engine()
        http = TestClient(app, raise_server_exceptions=False)
        resp = http.post("/api/v1/search", json={"query": "black dress"})
        assert resp.status_code == 200
        body = resp.json()
        assert body["total"] == 1
        first_hit = body["results"][0]
        assert first_hit["sku"] == "12345"
        assert first_hit["name"] == "Test Dress"
        assert body["query_info"]["parsed_category"] == "Dresses"

    def test_search_with_params(self):
        app.state.engine = _make_mock_engine()
        http = TestClient(app, raise_server_exceptions=False)
        payload = {"query": "red shoes", "top_n": 5, "sort_by": "price_asc"}
        resp = http.post("/api/v1/search", json=payload)
        assert resp.status_code == 200

    def test_empty_query_rejected(self):
        # Empty query strings should be caught by request validation (422).
        app.state.engine = _make_mock_engine()
        http = TestClient(app, raise_server_exceptions=False)
        resp = http.post("/api/v1/search", json={"query": ""})
        assert resp.status_code == 422

    def test_engine_not_ready(self):
        # A missing engine must surface as 503 Service Unavailable.
        app.state.engine = None
        http = TestClient(app, raise_server_exceptions=False)
        resp = http.post("/api/v1/search", json={"query": "dress"})
        assert resp.status_code == 503

    def test_similar_search(self):
        source_detail = {
            "sku": "12345", "name": "Test Dress", "brand": "ASOS",
            "price": 29.99, "color_clean": "black", "color_family": "black",
            "category": "Dresses", "gender": "Women",
            "primary_image_url": "https://example.com/img.jpg",
            "image_urls": [], "style_tags": [], "materials": [],
            "sizes_available": [], "product_details": "", "any_in_stock": True,
        }
        similar_row = {
            "sku": "67890", "name": "Similar Dress", "brand": "ASOS",
            "price": 35.00, "color_clean": "navy", "category": "Dresses",
            "primary_image_url": "https://example.com/img2.jpg",
            "similarity_score": 0.89,
        }
        mock_engine = _make_mock_engine()
        mock_engine.get_product_detail.return_value = source_detail
        mock_engine.search_similar.return_value = pd.DataFrame([similar_row])
        app.state.engine = mock_engine

        http = TestClient(app, raise_server_exceptions=False)
        resp = http.get("/api/v1/search/similar/12345")
        assert resp.status_code == 200
        body = resp.json()
        assert body["total"] == 1
        assert body["results"][0]["similarity_score"] == 0.89
|
backend/tests/test_bm25.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from backend.app.engine.bm25 import SimpleBM25
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class TestSimpleBM25:
    """Unit tests for the lightweight BM25 lexical scorer."""

    def test_fit_and_score(self):
        index = SimpleBM25()
        corpus = [
            "black leather jacket mens",
            "red floral dress womens",
            "blue denim jeans casual",
        ]
        index.fit(corpus)
        assert index.n_docs == 3
        got = index.score_candidates("black leather", [0, 1, 2])
        # Doc 0 contains both query terms, so it must outrank the others.
        assert got[0] > got[1]
        assert got[0] > got[2]

    def test_empty_query(self):
        index = SimpleBM25()
        index.fit(["black dress", "red shoes"])
        assert np.all(index.score_candidates("", [0, 1]) == 0.0)

    def test_unknown_terms(self):
        index = SimpleBM25()
        index.fit(["black dress"])
        # A term absent from the vocabulary contributes no score.
        assert index.score_candidates("xyznotaword", [0])[0] == 0.0

    def test_out_of_range_index(self):
        index = SimpleBM25()
        index.fit(["black dress"])
        got = index.score_candidates("black", [0, 999])
        assert got[0] > 0
        # Invalid candidate indices should score zero rather than raise.
        assert got[1] == 0.0

    def test_exact_match_scores_higher(self):
        index = SimpleBM25()
        index.fit(["black leather jacket", "red silk dress", "blue cotton shirt"])
        got = index.score_candidates("black leather jacket", [0, 1, 2])
        assert got[0] > got[1]
        assert got[0] > got[2]

    def test_doc_frequency_computed(self):
        index = SimpleBM25()
        index.fit(["black dress", "black shoes", "red dress"])
        assert index.df["black"] == 2
        assert index.df["dress"] == 2
        assert index.df["red"] == 1
|
backend/tests/test_nlp.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from backend.app.engine.nlp import MultilingualHandler, SpellCorrector
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TestMultilingualHandler:
    """Language detection and query translation behaviour."""

    def test_english_detected(self):
        assert MultilingualHandler.detect_language("black leather jacket") == "en"

    def test_french_detected(self):
        assert MultilingualHandler.detect_language("robe noir pour la femme") == "fr"

    def test_german_detected(self):
        assert MultilingualHandler.detect_language("das kleid ist für die frau") == "de"

    def test_spanish_detected(self):
        assert MultilingualHandler.detect_language("el vestido rojo para la mujer") == "es"

    def test_translate_french(self):
        out, _lang, was_translated = MultilingualHandler.translate_query("robe noir")
        assert "dress" in out
        assert "black" in out
        assert was_translated is True

    def test_english_passthrough(self):
        out, lang, was_translated = MultilingualHandler.translate_query("black dress")
        assert out == "black dress"
        assert was_translated is False
        assert lang == "en"

    def test_spanish_translate(self):
        out, _lang, _was_translated = MultilingualHandler.translate_query("vestido rojo")
        assert "dress" in out
        assert "red" in out

    def test_cjk_detected(self):
        # CJK scripts may resolve to either Japanese or Chinese.
        assert MultilingualHandler.detect_language("黒いドレス") in ("ja", "zh")

    def test_non_latin_passthrough(self):
        _out, _lang, was_translated = MultilingualHandler.translate_query("黒いドレス")
        assert was_translated is False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TestSpellCorrector:
    """Vocabulary-driven spell correction behaviour."""

    def test_correct_known_word(self):
        corrector = SpellCorrector()
        corrector.fit(["black leather jacket dress shoes boots"])
        assert corrector.correct_word("blak") == "black"

    def test_no_correction_needed(self):
        corrector = SpellCorrector()
        corrector.fit(["black leather jacket"])
        # Words already in the vocabulary come back unchanged.
        assert corrector.correct_word("black") == "black"

    def test_query_correction(self):
        corrector = SpellCorrector()
        corrector.fit(["black leather jacket dress shoes boots trainers hoodie"])
        fixed, was_corrected = corrector.correct_query("blak lether jaket")
        assert was_corrected is True
        assert "black" in fixed

    def test_short_words_skipped(self):
        corrector = SpellCorrector()
        corrector.fit(["black leather jacket"])
        # Very short tokens are left alone to avoid false corrections.
        assert corrector.correct_word("an") == "an"

    def test_price_tokens_skipped(self):
        corrector = SpellCorrector()
        corrector.fit(["black dress"])
        fixed, _ = corrector.correct_query("dress £40")
        assert "£40" in fixed

    def test_not_ready_passthrough(self):
        # Without fit(), the corrector must pass queries through untouched.
        corrector = SpellCorrector()
        fixed, was_corrected = corrector.correct_query("blak dress")
        assert fixed == "blak dress"
        assert was_corrected is False
|
backend/tests/test_query_parser.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from backend.app.engine.query_parser import QueryParser, ParsedQuery
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Shared fixture: a fresh QueryParser per test keeps cases independent
# (presumably the parser is stateless — confirm in query_parser.py).
@pytest.fixture
def parser():
    return QueryParser()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestCategoryParsing:
    """Product nouns in a query should map onto catalogue categories."""

    def test_dress(self, parser):
        assert parser.parse("black midi dress").category_filter == "Dresses"

    def test_jacket(self, parser):
        assert parser.parse("leather jacket").category_filter == "Coats & Jackets"

    def test_jeans(self, parser):
        assert parser.parse("blue jeans").category_filter == "Jeans"

    def test_hoodie(self, parser):
        assert parser.parse("oversized hoodie").category_filter == "Hoodies & Sweatshirts"

    def test_trainers(self, parser):
        assert parser.parse("white trainers").category_filter == "Shoes"

    def test_bag(self, parser):
        assert parser.parse("leather bag").category_filter == "Bags"

    def test_no_category(self, parser):
        # No recognisable product noun -> no category filter.
        assert parser.parse("something nice").category_filter is None

    def test_multi_word_category(self, parser):
        assert parser.parse("puffer jacket warm").category_filter == "Coats & Jackets"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TestColorParsing:
    """Colour words (including synonyms) should become colour filters."""

    def test_basic_color(self, parser):
        assert parser.parse("black dress").color_filter == "black"

    def test_synonym_color(self, parser):
        # "scarlet" normalises to the base colour "red".
        assert parser.parse("scarlet top").color_filter == "red"

    def test_navy(self, parser):
        assert parser.parse("navy blazer").color_filter == "navy"

    def test_multi_word_color(self, parser):
        assert parser.parse("sky blue dress").color_filter == "blue"

    def test_no_color(self, parser):
        assert parser.parse("casual hoodie").color_filter is None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class TestPriceParsing:
    """Price phrases should populate price_min / price_max bounds."""

    def test_under(self, parser):
        parsed = parser.parse("dress under £40")
        assert parsed.price_max == 40.0
        assert parsed.price_min is None

    def test_over(self, parser):
        assert parser.parse("jacket over £100").price_min == 100.0

    def test_range(self, parser):
        parsed = parser.parse("shoes £20-£50")
        assert parsed.price_min == 20.0
        assert parsed.price_max == 50.0

    def test_budget(self, parser):
        # "budget" implies a £30 ceiling.
        assert parser.parse("budget dress").price_max == 30.0

    def test_luxury(self, parser):
        # "luxury" implies a £100 floor.
        assert parser.parse("luxury jacket").price_min == 100.0

    def test_no_price(self, parser):
        parsed = parser.parse("blue dress")
        assert parsed.price_min is None
        assert parsed.price_max is None
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class TestGenderParsing:
    """Gendered wording should set the gender filter."""

    def test_mens(self, parser):
        assert parser.parse("mens hoodie").gender_filter == "Men"

    def test_womens(self, parser):
        assert parser.parse("for women dress").gender_filter == "Women"

    def test_ladies(self, parser):
        assert parser.parse("ladies dress elegant").gender_filter == "Women"

    def test_no_gender(self, parser):
        assert parser.parse("casual hoodie").gender_filter is None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class TestMaterialParsing:
    """Fabric / material words should set the material filter."""

    def test_silk(self, parser):
        assert parser.parse("silk midi dress").material_filter == "silk"

    def test_leather(self, parser):
        assert parser.parse("leather jacket").material_filter == "leather"

    def test_denim(self, parser):
        assert parser.parse("denim jacket").material_filter == "denim"

    def test_no_material(self, parser):
        assert parser.parse("black dress").material_filter is None
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class TestSizeParsing:
    """Size tokens (named, numeric, abbreviated) should set the size filter."""

    def test_named_size(self, parser):
        # "small" normalises to the letter code "S".
        assert parser.parse("size small hoodie").size_filter == "S"

    def test_numeric_size(self, parser):
        assert parser.parse("size 10 dress").size_filter == "10"

    def test_xl(self, parser):
        assert parser.parse("XL casual shirt").size_filter == "XL"

    def test_no_size(self, parser):
        assert parser.parse("black dress").size_filter is None
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class TestExclusions:
    """Negation phrases ("not", "without", "no") should produce exclusions."""

    def test_not(self, parser):
        assert "floral" in parser.parse("black dress not floral").exclusions

    def test_without(self, parser):
        assert "leather" in parser.parse("jacket without leather").exclusions

    def test_no_keyword(self, parser):
        assert "black" in parser.parse("summer top no black").exclusions

    def test_material_exclusion_conflict(self, parser):
        # An excluded material must win over material detection.
        parsed = parser.parse("jacket not cotton")
        assert parsed.material_filter is None
        assert "cotton" in parsed.exclusions

    def test_no_exclusions(self, parser):
        assert parser.parse("black dress").exclusions == []
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class TestStyleTags:
    """Style adjectives should be collected as style tags."""

    def test_single_tag(self, parser):
        assert "casual" in parser.parse("casual hoodie").style_tags

    def test_multiple_tags(self, parser):
        tags = parser.parse("vintage boho dress").style_tags
        assert "vintage" in tags
        assert "boho" in tags

    def test_no_tags(self, parser):
        assert "casual" not in parser.parse("blue jeans").style_tags
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class TestVibeText:
    """The parser should keep the raw query and a non-empty vibe text."""

    def test_preserves_raw_query(self, parser):
        assert parser.parse("black dress under £40").raw_query == "black dress under £40"

    def test_vibe_not_empty(self, parser):
        assert len(parser.parse("black dress").vibe_text) > 0
|