# Pandeymp29's picture
# integrated tab for latency check
# e7ff65e verified
import os
import gradio as gr
import pickle
import numpy as np
import faiss
import time
import requests
import matplotlib.pyplot as plt
from concurrent.futures import ThreadPoolExecutor
from huggingface_hub import snapshot_download
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer
# ---------------------------------------------------
# 1. SETUP & SELECTIVE DOWNLOAD
# ---------------------------------------------------
# Hugging Face Hub repo that holds the pre-built search artifacts.
REPO_ID = "Pandeymp29/Amazon-Fashion-Semantic-Search"
# Local folder the artifacts are downloaded into.
LOCAL_DIR = "artifacts"
# Optimization: Only download the NEW files. Ignore the old junk.
ALLOW_PATTERNS = [
    "OptModel/*",                    # The New ONNX Model folder
    "amazon_IndexHNSWFlat.faiss",    # The New Graph Index
    "product_lookup_optimised.pkl",  # The New Data Dictionary
    "item_ids.npy"                   # The Product IDs
]
# Download only on first start-up; the folder's existence is the
# "already downloaded" marker. NOTE(review): a partial/failed download
# would be skipped on restart — confirm this is acceptable.
if not os.path.exists(LOCAL_DIR):
    snapshot_download(
        repo_id=REPO_ID,
        local_dir=LOCAL_DIR,
        local_dir_use_symlinks=False,  # materialize real files, not cache symlinks
        allow_patterns=ALLOW_PATTERNS
    )
# ---------------------------------------------------
# 2. DEFINE PATHS
# ---------------------------------------------------
MODEL_PATH = os.path.join(LOCAL_DIR, "OptModel")                      # ONNX model + tokenizer files
INDEX_PATH = os.path.join(LOCAL_DIR, "amazon_IndexHNSWFlat.faiss")    # FAISS HNSW graph index
LOOKUP_PATH = os.path.join(LOCAL_DIR, "product_lookup_optimised.pkl") # product id -> details dict
IDS_PATH = os.path.join(LOCAL_DIR, "item_ids.npy")                    # index row -> product id
# ---------------------------------------------------
# 3. LOAD OPTIMIZED MODEL (ONNX)
# ---------------------------------------------------
print("Load ONNX Model")
# Exported transformer served by ONNX Runtime, plus its matching tokenizer,
# both loaded from the downloaded OptModel folder.
model = ORTModelForFeatureExtraction.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
def encode_onnx(text):
    """Embed *text* with the ONNX model and L2-normalise the result.

    Tokenizes with padding/truncation, mean-pools the last hidden state
    over the sequence axis, and returns a (batch, dim) numpy array of
    unit-length vectors ready for FAISS search.
    """
    tokenized = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    hidden = model(**tokenized).last_hidden_state
    # Mean-pool across tokens, then detach to a plain numpy array.
    pooled = hidden.mean(dim=1).detach().numpy()
    norms = np.linalg.norm(pooled, axis=1, keepdims=True)
    return pooled / norms
# ---------------------------------------------------
# 4. LOAD DATA
# ---------------------------------------------------
# HNSW graph index over the product embeddings.
index = faiss.read_index(INDEX_PATH)
# item_ids[i] is the product id for row i of the FAISS index.
item_ids = np.load(IDS_PATH)
# Trusted artifact from our own repo; note pickle is unsafe on untrusted data.
with open(LOOKUP_PATH, "rb") as f:
    product_lookup = pickle.load(f)
# ---------------------------------------------------
# 5. SEARCH LOGIC (Main App)
# ---------------------------------------------------
def search_products(query):
    """Return gallery tuples ``(image_url, title)`` for products matching *query*.

    Encodes the query with the ONNX model, retrieves the top-k nearest
    neighbours from the FAISS HNSW index, de-duplicates results by title,
    and resolves a display image for each hit. Blank queries return [].
    """
    if not query.strip():
        return []
    # 1. Encode (Fast ONNX)
    query_vector = encode_onnx(query)
    # 2. Search (HNSW Graph) — over-fetch so de-duplication still yields plenty
    k = 30
    _, neighbours = index.search(query_vector, k)
    gallery_data = []
    seen_titles = set()
    for idx in neighbours[0]:
        # FAISS pads with -1 when fewer than k neighbours exist
        if idx == -1:
            continue
        product_id = item_ids[idx]
        details = product_lookup.get(product_id, {})
        # Deduplication: several index rows can map to the same product title
        title_raw = details.get('title', 'Unknown')
        if title_raw in seen_titles:
            continue
        seen_titles.add(title_raw)
        # Truncate long titles for the gallery caption
        title_display = title_raw[:60] + "..." if len(title_raw) > 60 else title_raw
        # Image Logic: fall back to a placeholder when no usable image exists.
        # (Removed a dead no-op branch that re-assigned title_display to itself
        # when details had 'videos'.)
        img_url = "https://via.placeholder.com/200?text=No+Image"
        images = details.get('images', [])
        if images:
            if isinstance(images[0], dict):
                img_url = images[0].get('large', images[0].get('thumb', img_url))
            elif isinstance(images[0], str):
                img_url = images[0]
        gallery_data.append((img_url, title_display))
    return gallery_data
# ---------------------------------------------------
# 6. BENCHMARK LOGIC (Added Feature)
# ---------------------------------------------------
# URL of the Old System (for comparison)
OLD_API_URL = "https://pandeymp29-amazon-fashion-recommedation-system.hf.space/api/predict"
def raw_api_call(url, query):
    """Measure one POST round-trip to *url* in milliseconds.

    Sends *query* in the Gradio API payload shape. Returns 0 when the
    request fails or times out, so callers can drop failed samples.
    """
    start = time.time()
    try:
        # Timeout set to 10s to avoid hanging forever
        requests.post(url, json={"data": [query]}, timeout=10)
        return (time.time() - start) * 1000  # ms
    except requests.RequestException:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate instead of being reported as "0 ms".
        return 0
def run_benchmark(query, runs=3):
    """Benchmark the old remote API against the local ONNX search.

    Runs each system *runs* times (default 3, matching the original
    hard-coded count), averages the latencies, and returns a matplotlib
    bar chart plus a short speedup summary string.
    """
    # 1. Measure Old System (Remote API); failed calls return 0 and are dropped
    lat_old = [t for t in (raw_api_call(OLD_API_URL, query) for _ in range(runs)) if t > 0]
    avg_old = np.mean(lat_old) if lat_old else 0
    # 2. Measure New System (Local Internal Function)
    lat_new = []
    for _ in range(runs):
        st = time.time()
        search_products(query)  # Calls the actual search function above
        lat_new.append((time.time() - st) * 1000)
    avg_new = np.mean(lat_new)
    # 3. Generate Plot
    fig = plt.figure(figsize=(10, 5))
    plt.bar(['Old System', 'New System'], [avg_old, avg_new], color=['#ff9999', '#66b3ff'])
    plt.title(f"Latency Comparison: '{query}'")
    plt.ylabel("Time (ms)")
    # Add text labels
    for i, v in enumerate([avg_old, avg_new]):
        plt.text(i, v, f"{int(v)} ms", ha='center', va='bottom', fontweight='bold')
    if avg_old <= 0:
        # Old system unreachable: the previous code reported a misleading
        # " Result: 0.0x Faster" in this case.
        return fig, "Old system unreachable - no speedup score"
    speedup = avg_old / avg_new if avg_new > 0 else 0
    return fig, f" Result: {speedup:.1f}x Faster"
# ---------------------------------------------------
# 7. GRADIO UI (Tabs)
# ---------------------------------------------------
# Two tabs: the search app itself, and a live latency benchmark against
# the old Space. `demo.launch()` at the bottom starts the server.
with gr.Blocks(theme=gr.themes.Soft(), title="Amazon Fashion Semantic Search") as demo:
    gr.Markdown("# Amazon Fashion Semantic Search (Optimized)")
    with gr.Tabs():
        # --- TAB 1: MAIN APP ---
        with gr.TabItem("Optimised RecSys"):
            gr.Markdown("Powered by **ONNX Runtime** & **HNSW Index**.")
            with gr.Row():
                inp = gr.Textbox(placeholder="Try 'red floral summer dress'...", label="Search Query", scale=4)
                btn = gr.Button("Search", variant="primary", scale=1)
            # Gallery of (image_url, caption) tuples produced by search_products.
            gallery = gr.Gallery(
                label="Recommendations",
                columns=[4],
                rows=[2],
                height="auto",
                object_fit="contain"
            )
            # Both the Search button and pressing Enter trigger the same search.
            btn.click(fn=search_products, inputs=inp, outputs=gallery)
            inp.submit(fn=search_products, inputs=inp, outputs=gallery)
        # --- TAB 2: BENCHMARK ---
        with gr.TabItem("Benchmark"):
            gr.Markdown("Compare this **New ONNX Space** against the **Old PyTorch Space** in real-time.")
            with gr.Row():
                bench_inp = gr.Textbox(value="Red dress for wedding", label="Test Query")
                bench_btn = gr.Button("Run Comparison", variant="stop")
            with gr.Row():
                bench_plot = gr.Plot(label="Latency Comparison")
                bench_txt = gr.Textbox(label="Speedup Score", interactive=False)
            # run_benchmark returns (matplotlib figure, summary string).
            bench_btn.click(run_benchmark, bench_inp, [bench_plot, bench_txt])
demo.launch()