# app.py -- last change: "Update app.py" by maorsoul (commit e405ca3, verified)
import os
import json
import numpy as np
import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# -----------------------
# Settings
# -----------------------
# Dataset identifier and split that supply the product texts.
DATASET_NAME = "LukeSajkowski/products_ecommerce_embeddings"
SPLIT = "train"

# Sentence-embedding model used for both items and queries (768-d).
MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"

# On-disk cache (lives inside the Space container); created up front so the
# save calls later in the file never hit a missing directory.
CACHE_DIR = "cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Embedding matrix and the texts it was computed from, stored side by side.
EMB_PATH, TXT_PATH = (
    os.path.join(CACHE_DIR, filename)
    for filename in ("item_embeddings.npy", "item_texts.json")
)
def _l2_normalize(x: np.ndarray, axis: int = 1, eps: float = 1e-12) -> np.ndarray:
    """Scale ``x`` to unit L2 norm along ``axis``.

    Args:
        x: Array holding the vectors to normalize.
        axis: Axis along which each vector's norm is computed.
        eps: Lower bound applied to the norm before dividing, so an all-zero
            vector comes back as zeros instead of NaN/inf.

    Returns:
        ``x`` divided by its clamped norm; same shape as ``x``.
    """
    # keepdims=True keeps the norm broadcastable against x for the division.
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)
def _pick_text_column(df_cols):
    """Return the first known text-column name present in ``df_cols``.

    Args:
        df_cols: Container of column names that supports ``in``
            (for example ``df.columns``).

    Returns:
        The first candidate, in priority order, that occurs in ``df_cols``,
        or ``None`` when none of them is present.
    """
    # Order is the priority: earlier names win when several are present.
    candidates = (
        "title", "product_title", "name", "product_name", "text", "description",
        "caption", "item_name",
    )
    return next((name for name in candidates if name in df_cols), None)
# -----------------------
# Load dataset and build embeddings (once)
# -----------------------
def _load_cached_embeddings(expected_titles):
    """Return the cached embedding matrix if it matches ``expected_titles``.

    The cache is accepted only when both files load cleanly, the stored texts
    are equal to ``expected_titles``, and the matrix is 2-D with one row per
    text. In every other case ``None`` is returned so the caller recomputes.

    Args:
        expected_titles: The list of texts taken from the current dataset.

    Returns:
        The cached ``np.ndarray`` of embeddings, or ``None`` when the cache is
        missing, unreadable, or stale.
    """
    try:
        cached_embeddings = np.load(EMB_PATH)
        with open(TXT_PATH, "r", encoding="utf-8") as f:
            cached_titles = json.load(f)
    except (OSError, ValueError, EOFError):
        # Missing, truncated or corrupt cache files: fall back to a rebuild.
        return None

    # Compare contents, not just lengths: the dataset may change its texts
    # while keeping the same number of rows.
    if cached_titles != expected_titles:
        return None
    if cached_embeddings.ndim != 2 or cached_embeddings.shape[0] != len(expected_titles):
        return None
    return cached_embeddings


ds = load_dataset(DATASET_NAME, split=SPLIT)
df = ds.to_pandas()

text_col = _pick_text_column(df.columns)
if text_col is None:
    raise ValueError(
        f"Could not find a suitable text column in dataset. Columns are: {list(df.columns)}. "
        "Rename/choose a column like 'title' or 'name'."
    )

# Missing values become empty strings so every row has something to encode.
titles = df[text_col].fillna("").astype(str).tolist()
model = SentenceTransformer(MODEL_NAME)

# NOTE(review): the cache file names do not depend on MODEL_NAME, so changing
# the model requires clearing CACHE_DIR by hand.
item_embeddings = _load_cached_embeddings(titles)
if item_embeddings is None:
    # No usable cache: encode everything now and overwrite both cache files.
    embs = model.encode(
        titles,
        batch_size=64,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=False,
    ).astype(np.float32)
    # Normalize once here so search can use a plain dot product.
    item_embeddings = _l2_normalize(embs, axis=1).astype(np.float32)
    np.save(EMB_PATH, item_embeddings)
    with open(TXT_PATH, "w", encoding="utf-8") as f:
        json.dump(titles, f, ensure_ascii=False)

EMB_DIM = int(item_embeddings.shape[1])
def search(query: str, top_k: int):
    """Return the catalog titles most similar to ``query``.

    Args:
        query: Free-text search string. ``None``, empty or whitespace-only
            input produces no results.
        top_k: Requested number of results. Clamped to the range 1..20 and to
            the number of indexed items.

    Returns:
        A list of ``[similarity, title]`` rows ordered by descending
        similarity, with each title truncated to 200 characters. Empty when
        the query is blank or nothing is indexed.
    """
    query = (query or "").strip()
    if not query:
        return []

    q_emb = model.encode([query], convert_to_numpy=True).astype(np.float32)
    q_emb = _l2_normalize(q_emb, axis=1)[0]  # shape (dim,)
    sims = item_embeddings @ q_emb  # cosine similarity (because normalized)

    # Clamp to the slider range, then to the index size: argpartition raises
    # ValueError when kth is not a valid position in the array.
    k = min(max(1, min(int(top_k), 20)), sims.shape[0])
    if k == 0:
        return []

    # Partial selection of the k best, then an exact sort of just those k.
    idx = np.argpartition(-sims, kth=k - 1)[:k]
    idx = idx[np.argsort(-sims[idx])]
    return [[float(sims[i]), titles[int(i)][:200]] for i in idx]
# UI: query controls on the left, results table on the right, and the
# presentation video underneath, all inside a single Blocks app.
with gr.Blocks() as demo:
    gr.Markdown(
        f"# Ecommerce Text Search (computed embeddings)\n"
        f"Dataset: `{DATASET_NAME}` | Split: `{SPLIT}` | Text col: `{text_col}` | "
        f"Model: `{MODEL_NAME}` | Dim: `{EMB_DIM}`"
    )
    with gr.Row():
        with gr.Column():
            query_in = gr.Textbox(
                label="Search query",
                placeholder="e.g., xerox transfer roller, epson document feeder"
            )
            # Slider bounds mirror the 1..20 clamp applied inside search().
            topk_in = gr.Slider(1, 20, value=5, step=1, label="Top K")
            btn = gr.Button("Submit")
        with gr.Column():
            out = gr.Dataframe(
                headers=["similarity", "title"],
                datatype=["number", "str"],
                interactive=False,
            )
    # video at the bottom (inside same Blocks)
    gr.Markdown("### Video presentation")
    gr.HTML("""
<div style="display:flex; justify-content:center; margin: 10px 0 25px 0;">
<iframe width="720" height="405"
src="https://www.youtube.com/embed/oGHEj-QKCCI"
title="Project presentation video"
frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
allowfullscreen>
</iframe>
</div>
""")
    # Event wiring is done while the Blocks context is still open.
    btn.click(fn=search, inputs=[query_in, topk_in], outputs=out)

demo.launch()