Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import numpy as np | |
| import gradio as gr | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
# -----------------------
# Settings
# -----------------------
# Hugging Face dataset that is indexed, and the split read from it.
DATASET_NAME = "LukeSajkowski/products_ecommerce_embeddings"
SPLIT = "train"

# A strong, common text-embedding model (768-d); used for items and queries.
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# Cache files (saved inside the Space container) so embeddings are not
# recomputed on every start.
CACHE_DIR = "cache"
EMB_PATH = os.path.join(CACHE_DIR, "item_embeddings.npy")
TXT_PATH = os.path.join(CACHE_DIR, "item_texts.json")
os.makedirs(CACHE_DIR, exist_ok=True)
def _l2_normalize(x: np.ndarray, axis: int = 1, eps: float = 1e-12) -> np.ndarray:
    """Scale *x* to unit Euclidean length along *axis*.

    Norms smaller than *eps* are replaced by *eps* before dividing, so
    all-zero vectors come back as zeros instead of producing NaN/inf.
    """
    lengths = np.linalg.norm(x, axis=axis, keepdims=True)
    safe_lengths = np.maximum(lengths, eps)
    return x / safe_lengths
def _pick_text_column(df_cols):
    """Return the first preferred text-column name present in *df_cols*.

    Candidates are tried in priority order; ``None`` is returned when
    none of them is present.
    """
    preferred = (
        "title", "product_title", "name", "product_name", "text", "description",
        "caption", "item_name",
    )
    return next((name for name in preferred if name in df_cols), None)
# -----------------------
# Load the dataset, choose the text column and load the model (once)
# -----------------------
ds = load_dataset(DATASET_NAME, split=SPLIT)
df = ds.to_pandas()

# The column whose text is embedded and shown in the results table.
text_col = _pick_text_column(df.columns)
if text_col is None:
    _missing_column_msg = (
        f"Could not find a suitable text column in dataset. Columns are: {list(df.columns)}. "
        f"Rename/choose a column like 'title' or 'name'."
    )
    raise ValueError(_missing_column_msg)

# Missing values become empty strings so every row keeps its position.
_text_series = df[text_col].fillna("")
titles = _text_series.astype(str).tolist()

# The same model embeds the catalogue and, later, each search query.
model = SentenceTransformer(MODEL_NAME)
def _load_cached_embeddings(expected_titles):
    """Return the cached embedding matrix, or ``None`` if it cannot be used.

    The cache is rejected (and ``None`` returned) when either cache file is
    missing or unreadable, when the matrix is not 2-D, when its row count does
    not match *expected_titles*, or when the cached texts differ from
    *expected_titles*. Rejecting instead of raising lets the caller rebuild
    immediately rather than requiring a manual restart.
    """
    try:
        cached = np.load(EMB_PATH)
        with open(TXT_PATH, "r", encoding="utf-8") as f:
            cached_titles = json.load(f)
    except (OSError, ValueError, EOFError):
        # Missing, truncated or otherwise corrupt cache files.
        return None

    if cached.ndim != 2 or cached.shape[0] != len(expected_titles):
        return None
    if cached_titles != expected_titles:
        # Same length but different texts: the embeddings are stale.
        return None
    return cached


def _build_embeddings(texts):
    """Encode *texts* with ``model`` and return L2-normalized float32 rows."""
    embs = model.encode(
        texts,
        batch_size=64,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=False,
    ).astype(np.float32)
    return _l2_normalize(embs, axis=1).astype(np.float32, copy=False)


# Reuse the on-disk cache when it matches the current dataset; otherwise
# (re)compute the embeddings and overwrite the cache files.
# NOTE: the cache does not record MODEL_NAME, so delete the cache directory
# after switching to a different embedding model.
item_embeddings = _load_cached_embeddings(titles)
if item_embeddings is None:
    item_embeddings = _build_embeddings(titles)
    np.save(EMB_PATH, item_embeddings)
    with open(TXT_PATH, "w", encoding="utf-8") as f:
        json.dump(titles, f, ensure_ascii=False)

EMB_DIM = int(item_embeddings.shape[1])
def search(query: str, top_k: int):
    """Return the items most similar to *query* as ``[similarity, title]`` rows.

    Args:
        query: Free-text search string; ``None`` or blank yields no results.
        top_k: Requested number of results. It is clamped to the range
            1..20 and additionally to the number of indexed items.

    Returns:
        A list of ``[float similarity, str title]`` rows sorted by descending
        similarity; titles are truncated to 200 characters. The list is empty
        when the query is blank or there are no indexed items.
    """
    query = (query or "").strip()
    if not query:
        return []

    n_items = int(item_embeddings.shape[0])
    if n_items == 0:
        return []

    q_emb = model.encode([query], convert_to_numpy=True).astype(np.float32)
    q_emb = _l2_normalize(q_emb, axis=1)[0]  # shape (dim,)
    sims = item_embeddings @ q_emb  # cosine similarity (both sides normalized)

    # argpartition requires kth < len(sims), so k must never exceed the
    # number of items; 20 mirrors the upper bound of the "Top K" slider.
    k = max(1, min(int(top_k), 20, n_items))
    idx = np.argpartition(-sims, kth=k - 1)[:k]
    idx = idx[np.argsort(-sims[idx])]  # order the k winners best-first

    return [[float(sims[i]), titles[int(i)][:200]] for i in idx]
# -----------------------
# Gradio UI
# -----------------------
# Page header summarising what is being searched and with which model.
_HEADER_MD = (
    f"# Ecommerce Text Search (computed embeddings)\n"
    f"Dataset: `{DATASET_NAME}` | Split: `{SPLIT}` | Text col: `{text_col}` | "
    f"Model: `{MODEL_NAME}` | Dim: `{EMB_DIM}`"
)

# Centered YouTube embed shown underneath the search controls.
_VIDEO_HTML = """
<div style="display:flex; justify-content:center; margin: 10px 0 25px 0;">
<iframe width="720" height="405"
src="https://www.youtube.com/embed/oGHEj-QKCCI"
title="Project presentation video"
frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
allowfullscreen>
</iframe>
</div>
"""

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: query inputs.
        with gr.Column():
            query_in = gr.Textbox(
                label="Search query",
                placeholder="e.g., xerox transfer roller, epson document feeder",
            )
            topk_in = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Top K")
            btn = gr.Button("Submit")
        # Right column: read-only results table.
        with gr.Column():
            out = gr.Dataframe(
                headers=["similarity", "title"],
                datatype=["number", "str"],
                interactive=False,
            )

    # video at the bottom (inside same Blocks)
    gr.Markdown("### Video presentation")
    gr.HTML(_VIDEO_HTML)

    # Event wiring has to happen inside the Blocks context.
    btn.click(fn=search, inputs=[query_in, topk_in], outputs=out)

demo.launch()