"""
Hybrid retrieval engine for RetailMind.
Combines dense semantic search (SentenceTransformers) with structured
metadata filtering (price range, category, tags) so that queries like
"eco-friendly bag under $30" actually return relevant, correctly-priced items.
"""
from __future__ import annotations
import logging
import re
from typing import Any
import numpy as np
from modules.shared import get_embedding_model
# Module-level logger: HybridRetriever reports indexing progress at INFO and
# per-query filter/candidate diagnostics at DEBUG.
logger = logging.getLogger(__name__)
class HybridRetriever:
    """Two-stage retriever: metadata pre-filter -> semantic re-rank.

    Each catalog entry is a dict. The code below reads ``title``, ``desc``,
    ``category`` and ``price`` directly, and ``materials`` / ``tags``
    optionally.
    """

    # Query keyword -> catalog category map used by _extract_category_hint.
    # Categories are tried in this order. Multi-word phrases (e.g.
    # "warm weather") are tried before single words, so a shorter keyword of
    # an earlier category ("warm" -> winter) cannot shadow them.
    _CATEGORY_KEYWORDS = {
        "winter": ("winter", "cold", "snow", "warm", "insulated", "thermal"),
        "summer": ("summer", "beach", "hot", "heat", "sun", "warm weather"),
        "eco-friendly": (
            "eco", "sustainable", "organic", "recycled", "green",
            "environment", "plant-based",
        ),
        "sports": (
            "sport", "fitness", "running", "gym", "training", "workout",
            "athletic",
        ),
        "electronics": (
            "tech", "electronic", "gadget", "headphone", "speaker",
            "charger", "smart",
        ),
        "premium": ("luxury", "premium", "high-end", "designer", "artisan"),
        "home": ("home", "kitchen", "desk", "candle", "bath", "decor"),
        "health": (
            "health", "beauty", "sunscreen", "lipstick", "serum", "balm",
            "skincare", "makeup",
        ),
    }

    def __init__(self, catalog: list[dict]) -> None:
        """Embed every product in *catalog* once, up front.

        Args:
            catalog: Product dicts; see the class docstring for the keys read.
        """
        self.catalog = catalog
        self.model = get_embedding_model()
        # Rich embedding texts that capture all searchable facets.
        texts = [self._product_text(p) for p in catalog]
        logger.info("Encoding %d products...", len(catalog))
        if texts:
            self.embeddings = np.asarray(
                self.model.encode(texts, show_progress_bar=False)
            )
        else:
            # Nothing to encode; keep a 2-D array so the norm below is valid.
            self.embeddings = np.empty((0, 0), dtype=np.float32)
        # Per-product norms, cached for cosine similarity in search().
        self._norms = np.linalg.norm(self.embeddings, axis=1)
        logger.info("Catalog indexed successfully.")

    # -- Public API ---------------------------------------------------------

    def search(
        self,
        query: str,
        top_k: int = 4,
        category_filter: str | None = None,
        query_emb=None,
    ) -> list[dict[str, Any]]:
        """
        Retrieve top-k products for *query*.

        Pipeline:
        1. Extract a price ceiling from natural language (e.g. "under $50").
        2. Pre-filter the catalog by price / category. If nothing survives,
           first drop the category (keeping the price cap), then all filters.
        3. Rank remaining items by cosine similarity.
        4. Return the top-k as ``{"product": ..., "score": float}`` dicts,
           best first. Equal scores keep catalog order.

        *category_filter* overrides the category inferred from the query.
        Pass *query_emb* to skip re-encoding when the caller already has it.
        A non-positive *top_k* or an empty catalog yields ``[]``.
        """
        if not self.catalog:
            return []

        price_cap = self._extract_price_cap(query)
        cat_hint = category_filter or self._extract_category_hint(query)

        # Stage 1 - metadata pre-filter, relaxed step by step if it empties.
        candidate_indices = self._prefilter(price_cap, cat_hint)
        if not candidate_indices and price_cap is not None and cat_hint is not None:
            # Keep the price constraint and drop only the category, so a query
            # like "eco bag under $30" still returns items priced under $30.
            candidate_indices = self._prefilter(price_cap, None)
        if not candidate_indices:
            # Last resort: rank the entire catalog.
            candidate_indices = list(range(len(self.catalog)))

        # Stage 2 - semantic ranking over candidates.
        if query_emb is None:
            query_emb = self.model.encode([query], show_progress_bar=False)[0]
        query_norm = np.linalg.norm(query_emb)
        cand_embs = self.embeddings[candidate_indices]
        cand_norms = self._norms[candidate_indices]
        # Epsilon guards against division by zero for all-zero vectors.
        scores = np.dot(cand_embs, query_emb) / (cand_norms * query_norm + 1e-10)

        # Stable sort on negated scores: descending, ties in catalog order.
        top_local = np.argsort(-scores, kind="stable")[: max(top_k, 0)]
        results = [
            {
                "product": self.catalog[candidate_indices[li]],
                "score": float(scores[li]),
            }
            for li in top_local
        ]
        logger.debug(
            "Query: %r | price_cap=%s | cat=%s | candidates=%d | top=%d",
            query, price_cap, cat_hint, len(candidate_indices), len(results),
        )
        return results

    # -- Private helpers ----------------------------------------------------

    @staticmethod
    def _product_text(p: dict) -> str:
        """Build the text embedded for one product (title, desc, category, materials, tags)."""
        return (
            f"{p['title']}. {p['desc']} "
            f"Category: {p['category']}. "
            f"Materials: {p.get('materials', 'N/A')}. "
            f"Tags: {', '.join(p.get('tags', []))}."
        )

    @staticmethod
    def _extract_price_cap(query: str) -> float | None:
        """Parse 'under $50', 'below 30', 'less than $25', 'budget' etc.

        Returns the first explicit amount found, in pattern order. Failing
        that, returns 50.0 for budget-oriented wording ("cheapest",
        "affordable", ...). Otherwise returns ``None``.
        """
        # An amount is either comma-grouped thousands ("1,200") or plain
        # digits, with optional decimals. Without the grouped form,
        # "under $1,200" would parse as 1.0 and filter out nearly everything.
        num = r"(\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+(?:\.\d+)?)"
        # Leading \b keeps keywords from matching inside words
        # (e.g. "thunder 3" must not read as "under 3").
        patterns = (
            rf"\bunder\s*\$?\s*{num}",
            rf"\bbelow\s*\$?\s*{num}",
            rf"\bless\s+than\s*\$?\s*{num}",
            rf"\bcheaper\s+than\s*\$?\s*{num}",
            rf"\bmax(?:imum)?\s*\$?\s*{num}",
            rf"\${num}\s*(?:or\s+less|max|budget)",
            rf"\bonly\s+have\s*\$?\s*{num}",
            rf"\b(?:spend|budget)\s*(?:of|is)?\s*\$?\s*{num}",
        )
        for pat in patterns:
            m = re.search(pat, query, re.IGNORECASE)
            if m:
                return float(m.group(1).replace(",", ""))
        # Heuristic: very budget-oriented queries.
        budget_keywords = {"cheapest", "budget", "affordable", "inexpensive", "bargain"}
        lowered = query.lower()
        if any(kw in lowered for kw in budget_keywords):
            return 50.0  # Reasonable default budget ceiling
        return None

    def _extract_category_hint(self, query: str) -> str | None:
        """Map common query terms to a catalog category, or ``None``.

        Keywords match whole words, case-insensitively, with an optional
        trailing plural "s" ("sports", "headphones", "candles").
        Multi-word phrases are checked in a first pass over all categories,
        and single words in a second pass. Within a pass, the first category
        in _CATEGORY_KEYWORDS order wins.
        """
        for want_phrases in (True, False):
            for cat, keywords in self._CATEGORY_KEYWORDS.items():
                terms = [kw for kw in keywords if (" " in kw) == want_phrases]
                if not terms:
                    continue
                # Escape each keyword; let phrases match any run of whitespace.
                alternation = "|".join(
                    r"\s+".join(re.escape(word) for word in kw.split())
                    for kw in terms
                )
                if re.search(rf"\b(?:{alternation})s?\b", query, re.IGNORECASE):
                    return cat
        return None

    def _prefilter(
        self, price_cap: float | None, category: str | None
    ) -> list[int]:
        """Return catalog indices of products matching the hard constraints.

        A product passes when its ``price`` is at most *price_cap* (if given)
        and its ``category`` equals *category* (if given).
        """
        # NOTE(review): the category comparison is exact and case-sensitive;
        # confirm catalog categories use the lowercase names listed in
        # _CATEGORY_KEYWORDS.
        return [
            i
            for i, p in enumerate(self.catalog)
            if (price_cap is None or p["price"] <= price_cap)
            and (category is None or p["category"] == category)
        ]