# Amazon Fashion Semantic Search — Gradio app for a Hugging Face Space.
| import os | |
| import gradio as gr | |
| import pickle | |
| import numpy as np | |
| import faiss | |
| import time | |
| import requests | |
| import matplotlib.pyplot as plt | |
| from concurrent.futures import ThreadPoolExecutor | |
| from huggingface_hub import snapshot_download | |
| from optimum.onnxruntime import ORTModelForFeatureExtraction | |
| from transformers import AutoTokenizer | |
# ---------------------------------------------------
# 1. SETUP & SELECTIVE DOWNLOAD
# ---------------------------------------------------
REPO_ID = "Pandeymp29/Amazon-Fashion-Semantic-Search"
LOCAL_DIR = "artifacts"

# Only pull the artifacts this app actually uses; the repo also contains
# older files we don't want to download.
ALLOW_PATTERNS = [
    "OptModel/*",                   # ONNX model folder (model + tokenizer files)
    "amazon_IndexHNSWFlat.faiss",   # FAISS HNSW graph index
    "product_lookup_optimised.pkl", # product-id -> metadata dictionary
    "item_ids.npy",                 # index row -> product-id mapping
]

# Download once; later restarts reuse the local copy.
# NOTE: `local_dir_use_symlinks` is deprecated in recent huggingface_hub
# releases (files are copied into local_dir by default), so it is omitted
# here to silence the deprecation warning.
if not os.path.exists(LOCAL_DIR):
    snapshot_download(
        repo_id=REPO_ID,
        local_dir=LOCAL_DIR,
        allow_patterns=ALLOW_PATTERNS,
    )
# ---------------------------------------------------
# 2. DEFINE PATHS
# ---------------------------------------------------
# Resolve every artifact path relative to the download directory in one
# place, so the loading code below never repeats the join logic.
MODEL_PATH, INDEX_PATH, LOOKUP_PATH, IDS_PATH = (
    os.path.join(LOCAL_DIR, artifact)
    for artifact in (
        "OptModel",
        "amazon_IndexHNSWFlat.faiss",
        "product_lookup_optimised.pkl",
        "item_ids.npy",
    )
)
# ---------------------------------------------------
# 3. LOAD OPTIMIZED MODEL (ONNX)
# ---------------------------------------------------
print("Load ONNX Model")
# Feature-extraction model exported to ONNX; runs on CPU via ONNX Runtime.
model = ORTModelForFeatureExtraction.from_pretrained(MODEL_PATH)
# Tokenizer files live alongside the exported model in the same folder.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
def encode_onnx(text):
    """Embed text into unit-norm vectors using the ONNX model.

    Args:
        text: A query string (or list of strings) to embed.

    Returns:
        float32 ndarray of shape (batch, dim), L2-normalised per row so
        distances from the FAISS index are comparable across queries.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    # Attention-mask-weighted mean pooling: a plain .mean(dim=1) would also
    # average the embeddings of padding tokens when several texts of
    # different lengths are batched together. For a single query (no
    # padding) this is identical to the simple mean.
    mask = inputs["attention_mask"].unsqueeze(-1)
    summed = (outputs.last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
    # FAISS requires float32 query vectors.
    embeddings = (summed / counts).detach().numpy().astype(np.float32)
    return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
# ---------------------------------------------------
# 4. LOAD DATA
# ---------------------------------------------------
# HNSW graph index over the product embeddings.
index = faiss.read_index(INDEX_PATH)
# item_ids[i] is the product id for row i of the FAISS index.
item_ids = np.load(IDS_PATH)
# product-id -> details dict (title, images, ...).
# NOTE(review): pickle.load runs arbitrary code on malicious files; this is
# acceptable only because the artifact comes from our own model repo.
with open(LOOKUP_PATH, "rb") as f:
    product_lookup = pickle.load(f)
# ---------------------------------------------------
# 5. SEARCH LOGIC (Main App)
# ---------------------------------------------------
def search_products(query):
    """Return gallery entries ``(image_url, title)`` for a free-text query.

    Args:
        query: User search text; blank/whitespace queries yield no results.

    Returns:
        List of (image_url, display_title) tuples for gr.Gallery, with
        exact-duplicate titles removed.
    """
    if not query.strip():
        return []

    # Encode with ONNX Runtime, then search the HNSW graph index.
    query_vector = encode_onnx(query)
    k = 30
    _, neighbours = index.search(query_vector, k)

    gallery_data = []
    seen_titles = set()
    for idx in neighbours[0]:
        # FAISS pads the result with -1 when fewer than k neighbours exist.
        if idx == -1:
            continue
        details = product_lookup.get(item_ids[idx], {})

        # Deduplicate listings that share the exact same title.
        title_raw = details.get('title', 'Unknown')
        if title_raw in seen_titles:
            continue
        seen_titles.add(title_raw)

        # Truncate long titles for the gallery caption.
        title_display = title_raw[:60] + "..." if len(title_raw) > 60 else title_raw

        # Pick the best available image; fall back to a placeholder.
        # (Entries may store images as dicts with size keys or as raw URLs.)
        img_url = "https://via.placeholder.com/200?text=No+Image"
        images = details.get('images', [])
        if images:
            if isinstance(images[0], dict):
                img_url = images[0].get('large', images[0].get('thumb', img_url))
            elif isinstance(images[0], str):
                img_url = images[0]

        gallery_data.append((img_url, title_display))
    return gallery_data
# ---------------------------------------------------
# 6. BENCHMARK LOGIC (Added Feature)
# ---------------------------------------------------
# URL of the Old System (for comparison)
OLD_API_URL = "https://pandeymp29-amazon-fashion-recommedation-system.hf.space/api/predict"

def raw_api_call(url, query):
    """Measure one round-trip latency of the old system's HTTP API.

    Args:
        url: Gradio ``/api/predict`` endpoint of the old Space.
        query: Search text forwarded in the request payload.

    Returns:
        Latency in milliseconds, or 0 on any network failure so callers
        can filter out failed probes.
    """
    # perf_counter is monotonic — preferred over time.time() for intervals.
    start = time.perf_counter()
    try:
        # Timeout set to 10s to avoid hanging forever on a cold/sleeping Space.
        requests.post(url, json={"data": [query]}, timeout=10)
        return (time.perf_counter() - start) * 1000  # ms
    except requests.RequestException:
        # Narrowed from a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.
        return 0
def run_benchmark(query):
    """Compare the old remote (PyTorch) Space against this local ONNX search.

    Args:
        query: Search text used for both systems.

    Returns:
        (matplotlib Figure, summary string) for the Gradio Plot/Textbox pair.
    """
    runs = 3  # average a few runs to smooth out jitter

    # 1. Old system: remote HTTP probes; failed calls return 0 and are dropped.
    lat_old = [t for t in (raw_api_call(OLD_API_URL, query) for _ in range(runs)) if t > 0]
    avg_old = np.mean(lat_old) if lat_old else 0

    # 2. New system: time the in-process search function directly.
    lat_new = []
    for _ in range(runs):
        start = time.perf_counter()  # monotonic clock for interval timing
        search_products(query)       # calls the actual search function above
        lat_new.append((time.perf_counter() - start) * 1000)
    avg_new = np.mean(lat_new)

    # 3. Generate plot. The object-oriented API avoids mutating global
    # pyplot state, which matters when Gradio serves concurrent requests.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(['Old System', 'New System'], [avg_old, avg_new], color=['#ff9999', '#66b3ff'])
    ax.set_title(f"Latency Comparison: '{query}'")
    ax.set_ylabel("Time (ms)")
    # Add value labels above each bar.
    for i, v in enumerate([avg_old, avg_new]):
        ax.text(i, v, f"{int(v)} ms", ha='center', va='bottom', fontweight='bold')

    speedup = avg_old / avg_new if avg_new > 0 else 0
    return fig, f" Result: {speedup:.1f}x Faster"
# ---------------------------------------------------
# 7. GRADIO UI (Tabs)
# ---------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Amazon Fashion Semantic Search") as demo:
    gr.Markdown("# Amazon Fashion Semantic Search (Optimized)")

    with gr.Tabs():
        # --- TAB 1: MAIN APP ---
        with gr.TabItem("Optimised RecSys"):
            gr.Markdown("Powered by **ONNX Runtime** & **HNSW Index**.")
            with gr.Row():
                query_box = gr.Textbox(
                    placeholder="Try 'red floral summer dress'...",
                    label="Search Query",
                    scale=4,
                )
                search_btn = gr.Button("Search", variant="primary", scale=1)
            results_gallery = gr.Gallery(
                label="Recommendations",
                columns=[4],
                rows=[2],
                height="auto",
                object_fit="contain",
            )
            # Run the same search from the button and from pressing Enter.
            search_btn.click(fn=search_products, inputs=query_box, outputs=results_gallery)
            query_box.submit(fn=search_products, inputs=query_box, outputs=results_gallery)

        # --- TAB 2: BENCHMARK ---
        with gr.TabItem("Benchmark"):
            gr.Markdown("Compare this **New ONNX Space** against the **Old PyTorch Space** in real-time.")
            with gr.Row():
                benchmark_query = gr.Textbox(value="Red dress for wedding", label="Test Query")
                benchmark_btn = gr.Button("Run Comparison", variant="stop")
            with gr.Row():
                latency_plot = gr.Plot(label="Latency Comparison")
                speedup_box = gr.Textbox(label="Speedup Score", interactive=False)
            benchmark_btn.click(run_benchmark, benchmark_query, [latency_plot, speedup_box])

demo.launch()