rishadaz committed on
Commit
e06d5a0
·
verified ·
1 Parent(s): 270ea62

Create utils/semantic.py

Browse files
Files changed (1) hide show
  1. utils/semantic.py +295 -0
utils/semantic.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ semantic.py
3
+ -----------
4
+ Semantic search over an Amazon product catalogue using FAISS + HuggingFace embeddings.
5
+
6
+ Expected inputs
7
+ ---------------
8
+ - metadata_dataset : datasets.Dataset — one row per product (raw_metadata["full"])
9
+ - reviews_dataset : datasets.Dataset — passed to get_best_reviews(reviews, asin, k)
10
+
11
+ Typical usage
12
+ -------------
13
+ docs = build_documents(raw_metadata["full"], raw_reviews, n=100)
14
+ store = build_vector_store(docs)
15
+ results = semantic_search("noise cancelling headphones", store, k=5)
16
+ """
17
+
18
+ import logging
19
+ from typing import Any
20
+ import torch
21
+ import json, os, sys
22
+ from pathlib import Path
23
+
24
+ import faiss
25
+ from datasets import Dataset
26
+ from langchain_community.docstore.in_memory import InMemoryDocstore
27
+ from langchain_community.vectorstores import FAISS
28
+ from langchain_core.documents import Document
29
+ from langchain_huggingface import HuggingFaceEmbeddings
30
+ ROOT_FOLDER = Path(__file__).resolve().parent.parent
31
+
32
+ sys.path.append(str(ROOT_FOLDER))
33
+ from utils.eda_helpers import get_best_reviews
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Constants
39
+ # ---------------------------------------------------------------------------
40
+
41
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_TOP_REVIEWS = 5
DEFAULT_TOP_K = 5

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only requested on CUDA. On CPU, fp16 inference is slow and
# some kernels are not implemented in every torch build, so use float32 there.
_EMBEDDING_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Shared embedding model, instantiated once at import time and reused by every
# function in this module. Vectors are L2-normalised at encode time.
EMBEDDINGS = HuggingFaceEmbeddings(
    model_name=DEFAULT_EMBEDDING_MODEL,
    model_kwargs={
        "device": DEVICE,
        "model_kwargs": {"torch_dtype": _EMBEDDING_DTYPE},
    },
    encode_kwargs={
        # Larger batches on the GPU, smaller ones on the CPU.
        "batch_size": 512 if DEVICE == "cuda" else 128,
        "normalize_embeddings": True,
    },
)
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Document construction
60
+ # ---------------------------------------------------------------------------
61
+
62
+ def _format_review(review) -> str:
63
+ """Return a concise single-line string for one review."""
64
+ rating = review.get("rating", "?")
65
+ title = (review.get("title") or "").strip()
66
+ text = (review.get("text") or "").strip()
67
+ return f"[{rating}★] {title} — {text}"
68
+
69
+
70
def _build_reviews_block(
    reviews: Dataset,
    parent_asin: str,
    k: int = DEFAULT_TOP_REVIEWS,
) -> tuple[int, str]:
    """
    Fetch the top-*k* reviews for *parent_asin* and format them as text.

    Args:
        reviews: The reviews Dataset, forwarded to ``get_best_reviews``.
        parent_asin: Identifier of the product whose reviews are requested.
        k: Number of reviews requested from ``get_best_reviews``.

    Returns:
        A ``(total, block)`` pair. ``total`` is the first value returned by
        ``get_best_reviews`` (presumably the product's overall review count —
        confirm against ``utils.eda_helpers``). ``block`` holds one formatted
        review per line. When no reviews come back the pair is ``(0, "")``.
    """
    total, product_reviews = get_best_reviews(reviews, parent_asin, k)
    if not product_reviews:
        return 0, ""
    # Continuation lines carry a one-space indent so they line up under the
    # "Top Reviews:\n " prefix that _build_page_content puts before the block.
    return total, "\n ".join(_format_review(r) for r in product_reviews)
84
+
85
+
86
+ def _build_page_content(product, review_block: str) -> str:
87
+ """Assemble the text that will be embedded. Empty sections are omitted."""
88
+ title = (product.get("title") or "").strip()
89
+ main_category = (product.get("main_category") or "").strip()
90
+ categories = main_category +" >> " + " > ".join(product.get("categories") or [])
91
+ features = "\n ".join(product.get("features") or [])
92
+ description = " ".join(product.get("description") or [])
93
+ details = (product.get("details") or "").strip()
94
+
95
+ parts = [f"Product: {title}"]
96
+ if categories:
97
+ parts.append(f"Category Path: {categories}")
98
+ if features:
99
+ parts.append(f"Features:\n {features}")
100
+ if description:
101
+ parts.append(f"Description:\n {description}")
102
+ if review_block:
103
+ parts.append(f"Top Reviews:\n {review_block}")
104
+ if details:
105
+ parts.append(f"Details:\n {details}")
106
+
107
+ return "\n".join(parts)
108
+
109
+
110
def create_document(product, reviews: Dataset) -> Document | None:
    """
    Build a :class:`~langchain_core.documents.Document` from one product row.

    Args:
        product: A single row from a HuggingFace metadata Dataset (dict-like).
        reviews: The full reviews Dataset, forwarded to ``get_best_reviews``
            via ``_build_reviews_block``.

    Returns:
        A Document, or ``None`` if the row has no ``parent_asin`` (a warning
        is logged in that case).

    Notes:
        *page_content* contains only the text that influences embeddings.
        *metadata* stores the values used for filtering and display after
        retrieval. All entries are copied from the row as-is except
        ``main_category`` (``None`` becomes ``""``), ``categories`` (a list,
        ``[]`` when missing) and ``details`` (coerced to ``str``).
    """
    parent_asin = product.get("parent_asin")
    if not parent_asin:
        logger.warning(
            "Skipping product with missing parent_asin: %s", product.get("title")
        )
        return None

    total_reviews, review_block = _build_reviews_block(reviews, parent_asin)
    page_content = _build_page_content(product, review_block)

    metadata = {
        # --- identifiers ---
        "parent_asin": parent_asin,
        # --- numeric (filterable / rankable) ---
        "price": product.get("price"),
        "average_rating": product.get("average_rating"),
        "rating_number": product.get("rating_number"),
        # --- categorical (filterable) ---
        # `or ""` rather than a .get() default: a key that is present with a
        # None value would otherwise bypass the default.
        "main_category": product.get("main_category") or "",
        "categories": product.get("categories") or [],
        # --- free-form (display only; coerced to str) ---
        "details": str(product.get("details") or ""),
        "total_reviews": total_reviews,
    }

    return Document(page_content=page_content, metadata=metadata)
151
+
152
+
153
+ # ---------------------------------------------------------------------------
154
+ # Vector store
155
+ # ---------------------------------------------------------------------------
156
+
157
+ # Embed a list of documents in one call, creating a new store or extending an existing one
158
def build_vector_store(
    docs: list[Document],
    existing_store: FAISS | None = None,
) -> FAISS:
    """
    Embed *docs* into a FAISS vector store and return that store.

    Args:
        docs: Documents to index; each must carry ``metadata["parent_asin"]``,
            which is used as its document ID.
        existing_store: Store to add the documents to. When ``None``, a fresh
            store backed by a flat L2 index and an in-memory docstore is
            created.

    Returns:
        The store the documents were added to.

    Raises:
        ValueError: If *docs* is empty.
    """
    if not docs:
        raise ValueError("Cannot build a vector store from an empty document list.")

    logger.info("Embedding on %s", DEVICE)

    store = existing_store
    if store is None:
        # Embed a throwaway string once to learn the model's output dimension.
        embedding_dim = len(EMBEDDINGS.embed_query("probe"))
        store = FAISS(
            embedding_function=EMBEDDINGS,
            index=faiss.IndexFlatL2(embedding_dim),
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
        )

    doc_ids = [document.metadata["parent_asin"] for document in docs]
    store.add_documents(documents=docs, ids=doc_ids)

    logger.info("Indexed %d documents into FAISS.", len(docs))
    return store
195
+
196
+ # Run build_vector_store batch by batch, saving the index to disk after every batch
197
def build_and_save_vector_store(
    metadata_dataset: Dataset,
    reviews: Dataset,
    save_path: str,
    batch_size: int = 500,
) -> FAISS | None:
    """
    Index *metadata_dataset* into FAISS in batches, checkpointing to disk.

    After every batch the store is written to *save_path* and the next row
    offset is recorded in ``progress.json`` inside that directory, so an
    interrupted run can be resumed by calling the function again with the
    same arguments. The progress file is removed once all rows are processed.

    Args:
        metadata_dataset: Product metadata, one row per product.
        reviews: The reviews Dataset, forwarded to ``create_document``.
        save_path: Directory holding the saved index and the progress file;
            created if it does not exist.
        batch_size: Number of metadata rows processed per batch.

    Returns:
        The populated store, or ``None`` when no document was indexed and no
        saved index existed (for example an empty dataset).
    """
    index_file = os.path.join(save_path, "index.faiss")
    progress_file = os.path.join(save_path, "progress.json")

    # --- Resume / initialize ---
    if os.path.exists(index_file):
        # NOTE(security): load_local unpickles the saved docstore, so
        # save_path must only ever point at a directory this code wrote.
        vector_store = FAISS.load_local(
            save_path, EMBEDDINGS, allow_dangerous_deserialization=True
        )
        already_indexed = set(vector_store.index_to_docstore_id.values())
        print(f"Resuming — {len(already_indexed)} docs already indexed.")
    else:
        os.makedirs(save_path, exist_ok=True)
        vector_store = None  # build_vector_store creates it on first use
        already_indexed = set()

    # --- Resume progress ---
    if os.path.exists(progress_file):
        with open(progress_file, encoding="utf-8") as f:
            resume_start = json.load(f).get("next_start", 0)
        print(f"Resuming from row {resume_start}.")
    else:
        resume_start = 0

    total = len(metadata_dataset)

    for start in range(resume_start, total, batch_size):
        stop = min(start + batch_size, total)
        batch = metadata_dataset.select(range(start, stop))

        docs = []
        batch_ids = set()
        for row in batch:
            asin = row.get("parent_asin")
            # Skip known products before building the document: this avoids
            # the review lookup for rows that would be discarded anyway, and
            # keeps duplicate IDs within one batch out of the FAISS insert.
            if asin and (asin in already_indexed or asin in batch_ids):
                continue
            doc = create_document(row, reviews)
            if doc is None:
                continue
            docs.append(doc)
            batch_ids.add(asin)

        if docs:
            vector_store = build_vector_store(
                docs=docs,
                existing_store=vector_store,
            )
            already_indexed |= batch_ids

        # --- Save after each batch ---
        # The store is still None if no batch so far produced a document.
        if vector_store is not None:
            vector_store.save_local(save_path)
        with open(progress_file, "w", encoding="utf-8") as f:
            json.dump({"next_start": stop}, f)

        print(f"Indexed {stop} / {total} rows")

    if os.path.exists(progress_file):
        os.remove(progress_file)

    return vector_store
255
+
256
+ # ---------------------------------------------------------------------------
257
+ # Search
258
+ # ---------------------------------------------------------------------------
259
+
260
def semantic_search(
    query: str,
    vector_store: FAISS,
    k: int = DEFAULT_TOP_K,
    filter = None,
) -> list[tuple[Document, float]]:
    """
    Run a semantic similarity search against a pre-built *vector_store*.

    Args:
        query: Natural-language search query.
        vector_store: A FAISS store, e.g. one built with
            :func:`build_vector_store`.
        k: Number of results to return.
        filter: Optional metadata filter, forwarded unchanged to
            ``similarity_search_with_score``, e.g.
            ``{"main_category": "Electronics"}``.

    Returns:
        ``(Document, score)`` pairs exactly as returned by
        ``FAISS.similarity_search_with_score``. For stores created by
        :func:`build_vector_store` (flat L2 index) the score is a distance,
        so lower values should mean closer matches — confirm against the
        LangChain FAISS documentation.
    """
    results = vector_store.similarity_search_with_score(query, k=k, filter=filter)
    logger.info("'%s' -> %d results", query, len(results))
    return results
282
+
283
+ # ---------------------------------------------------------------------------
284
+ # Read existing vector store
285
+ # ---------------------------------------------------------------------------
286
+
287
def load_vector_store(
    load_path: str,
) -> FAISS:
    """
    Load a FAISS store previously saved under *load_path*.

    The module-level ``EMBEDDINGS`` model is attached to the loaded store.
    Deserialization is explicitly allowed, so *load_path* should only point
    at trusted data.
    """
    store = FAISS.load_local(
        load_path,
        embeddings=EMBEDDINGS,
        allow_dangerous_deserialization=True,
    )
    return store