File size: 6,219 Bytes
f69e608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d624b44
 
f69e608
 
 
 
 
 
 
 
 
d624b44
f69e608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d624b44
f69e608
 
 
 
 
 
 
 
 
d624b44
 
f69e608
 
 
 
 
 
 
 
d624b44
 
f69e608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16b9e90
f69e608
 
16b9e90
 
f69e608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Hybrid retrieval engine for RetailMind.

Combines dense semantic search (SentenceTransformers) with structured
metadata filtering (price range, category, tags) so that queries like
"eco-friendly bag under $30" actually return relevant, correctly-priced items.
"""

from __future__ import annotations

import logging
import re
from typing import Any

import numpy as np

from modules.shared import get_embedding_model

logger = logging.getLogger(__name__)


class HybridRetriever:
    """Two-stage retriever: metadata pre-filter -> semantic re-rank.

    Stage 1 narrows the catalog using hard constraints parsed from the
    query (price ceiling, category hint). Stage 2 ranks the surviving
    products by cosine similarity between the query embedding and the
    precomputed product embeddings.
    """

    def __init__(self, catalog: list[dict]) -> None:
        """Index *catalog* for hybrid search.

        Each product dict must provide 'title', 'desc', 'category' and
        'price'; 'materials' and 'tags' are optional.
        """
        self.catalog = catalog
        self.model = get_embedding_model()

        # Build rich embedding texts that capture all searchable facets so
        # that material/tag mentions in queries can match semantically.
        texts = [
            (
                f"{p['title']}. {p['desc']} "
                f"Category: {p['category']}. "
                f"Materials: {p.get('materials', 'N/A')}. "
                f"Tags: {', '.join(p.get('tags', []))}."
            )
            for p in catalog
        ]
        logger.info("Encoding %d products...", len(catalog))
        self.embeddings = self.model.encode(texts, show_progress_bar=False)
        # Precompute row norms once; reused for cosine similarity in search().
        self._norms = np.linalg.norm(self.embeddings, axis=1)
        logger.info("Catalog indexed successfully.")

    # -- Public API ----------------------------------------------------------

    def search(
        self,
        query: str,
        top_k: int = 4,
        category_filter: str | None = None,
        query_emb: np.ndarray | None = None,
    ) -> list[dict[str, Any]]:
        """
        Retrieve top-k products for *query*.

        Pipeline:
        1. Extract price ceiling from natural language (e.g. "under $50").
        2. Pre-filter catalog by price / category if applicable.
        3. Rank remaining items by cosine similarity.
        4. Return up to *top_k* dicts with "product" and "score" keys.

        Pass *query_emb* to skip re-encoding when the caller already has it.
        """
        price_cap = self._extract_price_cap(query)
        cat_hint = category_filter or self._extract_category_hint(query)

        # Stage 1 -- metadata pre-filter on hard constraints.
        candidate_indices = self._prefilter(price_cap, cat_hint)
        if not candidate_indices:
            # Fallback: rank the entire catalog if the filters eliminate
            # everything (better to show near-misses than nothing).
            candidate_indices = list(range(len(self.catalog)))

        # Stage 2 -- semantic ranking over the candidates.
        if query_emb is None:
            query_emb = self.model.encode([query], show_progress_bar=False)[0]
        query_norm = np.linalg.norm(query_emb)

        cand_embs = self.embeddings[candidate_indices]
        cand_norms = self._norms[candidate_indices]

        # Cosine similarity; the epsilon guards against zero-norm vectors.
        scores = np.dot(cand_embs, query_emb) / (cand_norms * query_norm + 1e-10)
        top_local = np.argsort(scores)[::-1][:top_k]

        results = [
            {
                "product": self.catalog[candidate_indices[li]],
                "score": float(scores[li]),
            }
            for li in top_local
        ]

        logger.debug(
            "Query: %r | price_cap=%s | cat=%s | candidates=%d | top=%d",
            query, price_cap, cat_hint, len(candidate_indices), len(results),
        )
        return results

    # -- Private helpers -----------------------------------------------------

    @staticmethod
    def _extract_price_cap(query: str) -> float | None:
        """Parse 'under $50', 'below 30', 'less than $25', 'budget' etc.

        Returns the parsed ceiling, a default of 50.0 for budget-flavored
        queries with no explicit amount, or None when no cap is implied.
        """
        patterns = [
            r"under\s*\$?\s*(\d+(?:\.\d+)?)",
            r"below\s*\$?\s*(\d+(?:\.\d+)?)",
            r"less\s+than\s*\$?\s*(\d+(?:\.\d+)?)",
            r"cheaper\s+than\s*\$?\s*(\d+(?:\.\d+)?)",
            r"max(?:imum)?\s*\$?\s*(\d+(?:\.\d+)?)",
            r"\$(\d+(?:\.\d+)?)\s*(?:or\s+less|max|budget)",
            r"only\s+have\s*\$?\s*(\d+)",
            r"(?:spend|budget)\s*(?:of|is)?\s*\$?\s*(\d+)",
        ]
        for pat in patterns:
            m = re.search(pat, query, re.IGNORECASE)
            if m:
                return float(m.group(1))

        # Heuristic: very budget-oriented queries with no explicit amount.
        budget_keywords = {"cheapest", "budget", "affordable", "inexpensive", "bargain"}
        if any(kw in query.lower() for kw in budget_keywords):
            return 50.0  # Reasonable default budget ceiling

        return None

    def _extract_category_hint(self, query: str) -> str | None:
        """Map common query terms to catalog categories.

        When keywords from several categories match, the category whose
        matched keyword is *longest* wins, so a specific phrase such as
        "warm weather" (summer) beats the shorter "warm" (winter).
        Previously the first dict entry to match won, which made multi-word
        keywords like "warm weather" unreachable.
        """
        category_keywords: dict[str, list[str]] = {
            "winter": ["winter", "cold", "snow", "warm", "insulated", "thermal"],
            "summer": ["summer", "beach", "hot", "heat", "sun", "warm weather"],
            "eco-friendly": ["eco", "sustainable", "organic", "recycled", "green", "environment", "plant-based"],
            "sports": ["sport", "fitness", "running", "gym", "training", "workout", "athletic"],
            "electronics": ["tech", "electronic", "gadget", "headphone", "speaker", "charger", "smart"],
            "premium": ["luxury", "premium", "high-end", "designer", "artisan"],
            "home": ["home", "kitchen", "desk", "candle", "bath", "decor"],
            "health": ["health", "beauty", "sunscreen", "lipstick", "serum", "balm", "skincare", "makeup"],
        }
        best_cat: str | None = None
        best_len = 0
        for cat, keywords in category_keywords.items():
            for kw in keywords:
                # re.escape keeps hyphenated keywords (e.g. "high-end") literal.
                if len(kw) > best_len and re.search(
                    r"\b" + re.escape(kw) + r"\b", query, re.IGNORECASE
                ):
                    best_cat, best_len = cat, len(kw)
        return best_cat

    def _prefilter(
        self, price_cap: float | None, category: str | None
    ) -> list[int]:
        """Return indices of products matching both hard constraints.

        A constraint of None means "no restriction" on that dimension.
        """
        return [
            i
            for i, p in enumerate(self.catalog)
            if (price_cap is None or p["price"] <= price_cap)
            and (category is None or p["category"] == category)
        ]