import os
import pickle
import time

import faiss
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
from concurrent.futures import ThreadPoolExecutor  # NOTE(review): unused in this chunk — kept in case other code relies on it
from huggingface_hub import snapshot_download
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer

# ---------------------------------------------------
# 1. SETUP & SELECTIVE DOWNLOAD
# ---------------------------------------------------
REPO_ID = "Pandeymp29/Amazon-Fashion-Semantic-Search"
LOCAL_DIR = "artifacts"

# Optimization: only download the NEW artifacts; skip everything else in the repo.
ALLOW_PATTERNS = [
    "OptModel/*",                    # the new ONNX model folder
    "amazon_IndexHNSWFlat.faiss",    # the new HNSW graph index
    "product_lookup_optimised.pkl",  # the new product-metadata dictionary
    "item_ids.npy",                  # the product IDs
]

# Download once; if the directory already exists we assume a prior full download.
if not os.path.exists(LOCAL_DIR):
    snapshot_download(
        repo_id=REPO_ID,
        local_dir=LOCAL_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=ALLOW_PATTERNS,
    )

# ---------------------------------------------------
# 2. DEFINE PATHS
# ---------------------------------------------------
MODEL_PATH = os.path.join(LOCAL_DIR, "OptModel")
INDEX_PATH = os.path.join(LOCAL_DIR, "amazon_IndexHNSWFlat.faiss")
LOOKUP_PATH = os.path.join(LOCAL_DIR, "product_lookup_optimised.pkl")
IDS_PATH = os.path.join(LOCAL_DIR, "item_ids.npy")

# ---------------------------------------------------
# 3. LOAD OPTIMIZED MODEL (ONNX)
# ---------------------------------------------------
print("Load ONNX Model")
model = ORTModelForFeatureExtraction.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)


def encode_onnx(text):
    """Encode *text* into L2-normalized embeddings via ONNX Runtime.

    Accepts a single string or a list of strings; returns an array of shape
    (batch, dim) with unit-norm rows.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    outputs = model(**inputs)
    # FIX: mask-aware mean pooling. The previous plain mean over dim=1 also
    # averaged padding positions, which skews embeddings for batched inputs
    # of different lengths. For a single query (no padding) the result is
    # identical to the old behavior.
    token_embeddings = outputs.last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
    embeddings = (summed / counts).detach().numpy()
    return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
# ---------------------------------------------------
# 4. LOAD DATA
# ---------------------------------------------------
index = faiss.read_index(INDEX_PATH)
item_ids = np.load(IDS_PATH)
with open(LOOKUP_PATH, "rb") as f:
    product_lookup = pickle.load(f)


# ---------------------------------------------------
# 5. SEARCH LOGIC (Main App)
# ---------------------------------------------------
def search_products(query):
    """Semantic product search.

    Encodes *query*, searches the HNSW index, deduplicates by title and
    returns a list of (image_url, display_title) tuples for a gr.Gallery.
    Returns [] for blank queries.
    """
    if not query.strip():
        return []

    # 1. Encode (fast ONNX path)
    query_vector = encode_onnx(query)

    # 2. Search (HNSW graph)
    k = 30
    D, I = index.search(query_vector, k)

    gallery_data = []
    seen_titles = set()
    for idx in I[0]:
        if idx == -1:  # FAISS pads missing neighbours with -1
            continue
        product_id = item_ids[idx]
        details = product_lookup.get(product_id, {})

        # Deduplication: skip results whose raw title was already shown.
        title_raw = details.get('title', 'Unknown')
        if title_raw in seen_titles:
            continue
        seen_titles.add(title_raw)

        # Formatting: clip long titles for the gallery caption.
        title_display = title_raw[:60] + "..." if len(title_raw) > 60 else title_raw
        # FIX: removed dead branch `if details.get('videos'): title_display = title_display`
        # (a no-op self-assignment with no effect).

        # Image logic: prefer 'large', fall back to 'thumb', then placeholder.
        img_url = "https://via.placeholder.com/200?text=No+Image"
        images = details.get('images', [])
        if images:
            first = images[0]
            if isinstance(first, dict):
                img_url = first.get('large', first.get('thumb', img_url))
            elif isinstance(first, str):
                img_url = first

        gallery_data.append((img_url, title_display))
    return gallery_data


# ---------------------------------------------------
# 6. BENCHMARK LOGIC (Added Feature)
# ---------------------------------------------------
# URL of the Old System (for comparison)
OLD_API_URL = "https://pandeymp29-amazon-fashion-recommedation-system.hf.space/api/predict"


def raw_api_call(url, query):
    """Measure one round-trip latency (ms) of the old system via its API.

    Returns 0 on any network failure/timeout so callers can discard the sample.
    """
    start = time.time()
    try:
        # Timeout set to 10s to avoid hanging forever
        requests.post(url, json={"data": [query]}, timeout=10)
        return (time.time() - start) * 1000  # ms
    except requests.RequestException:
        # FIX: narrowed from a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; network errors report 0.
        return 0


def run_benchmark(query):
    """Benchmark old (remote API) vs new (local) search; return (figure, summary)."""
    # 1. Measure Old System (Remote API) — 3 runs, drop failed (0 ms) samples.
    lat_old = []
    for _ in range(3):
        t = raw_api_call(OLD_API_URL, query)
        if t > 0:
            lat_old.append(t)
    avg_old = np.mean(lat_old) if lat_old else 0

    # 2. Measure New System (local internal function) — 3 runs.
    lat_new = []
    for _ in range(3):
        st = time.time()
        search_products(query)  # calls the actual search function above
        lat_new.append((time.time() - st) * 1000)
    avg_new = np.mean(lat_new)

    # 3. Generate the comparison plot.
    fig = plt.figure(figsize=(10, 5))
    plt.bar(['Old System', 'New System'], [avg_old, avg_new],
            color=['#ff9999', '#66b3ff'])
    plt.title(f"Latency Comparison: '{query}'")
    plt.ylabel("Time (ms)")
    # Add numeric labels on top of each bar.
    for i, v in enumerate([avg_old, avg_new]):
        plt.text(i, v, f"{int(v)} ms", ha='center', va='bottom', fontweight='bold')

    # Guard against division by zero if the local timing somehow reads 0.
    speedup = avg_old / avg_new if avg_new > 0 else 0
    return fig, f" Result: {speedup:.1f}x Faster"


# ---------------------------------------------------
# 7. GRADIO UI (Tabs)
# ---------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Amazon Fashion Semantic Search") as demo:
    gr.Markdown("# Amazon Fashion Semantic Search (Optimized)")
    with gr.Tabs():
        # --- TAB 1: MAIN APP ---
        with gr.TabItem("Optimised RecSys"):
            gr.Markdown("Powered by **ONNX Runtime** & **HNSW Index**.")
            with gr.Row():
                inp = gr.Textbox(placeholder="Try 'red floral summer dress'...",
                                 label="Search Query", scale=4)
                btn = gr.Button("Search", variant="primary", scale=1)
            gallery = gr.Gallery(
                label="Recommendations",
                columns=[4],
                rows=[2],
                height="auto",
                object_fit="contain",
            )
            btn.click(fn=search_products, inputs=inp, outputs=gallery)
            inp.submit(fn=search_products, inputs=inp, outputs=gallery)

        # --- TAB 2: BENCHMARK ---
        with gr.TabItem("Benchmark"):
            gr.Markdown("Compare this **New ONNX Space** against the **Old PyTorch Space** in real-time.")
            with gr.Row():
                bench_inp = gr.Textbox(value="Red dress for wedding", label="Test Query")
                bench_btn = gr.Button("Run Comparison", variant="stop")
            with gr.Row():
                bench_plot = gr.Plot(label="Latency Comparison")
                bench_txt = gr.Textbox(label="Speedup Score", interactive=False)
            bench_btn.click(run_benchmark, bench_inp, [bench_plot, bench_txt])

demo.launch()