# Source: Hugging Face Space by entropy25 -- commit 8aa94a5 ("Update app.py", verified).
import gradio as gr
import numpy as np
import pickle
import faiss
import torch
from huggingface_hub import hf_hub_download
from sentence_transformers import CrossEncoder, SentenceTransformer
class SearchEngine:
    """Product search over a FAISS vector index and a BM25 model.

    Supports three retrieval methods ("vector", "bm25", "hybrid") and an
    optional cross-encoder reranking pass. Every search method returns a list
    of dicts with the keys ``'text'``, ``'score'`` and ``'metadata'``; a higher
    ``'score'`` means a better match.
    """

    def __init__(self):
        # Nothing is downloaded or loaded here; call load_models() first.
        self.loaded = False
        self.bm25 = None
        self.documents = None
        self.faiss_index = None
        self.reranker = None
        self.encoder = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def load_models(self):
        """Download the index files and load all models.

        Returns:
            A human-readable status string. Failures are reported in the
            returned string instead of being raised, and ``self.loaded``
            is only set to True when every step succeeded.
        """
        try:
            repo_id = "entropy25/product-vectorstore"
            faiss_path = hf_hub_download(repo_id, "index.faiss", repo_type="dataset")
            bm25_path = hf_hub_download(repo_id, "bm25_model.pkl", repo_type="dataset")
            docs_path = hf_hub_download(repo_id, "documents.pkl", repo_type="dataset")

            self.faiss_index = faiss.read_index(faiss_path)
            self.encoder = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
            self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', device=self.device)

            # SECURITY: pickle.load can execute arbitrary code from the file.
            # These files are downloaded from the Hub, so this is only safe
            # while the dataset repository above is trusted.
            with open(bm25_path, 'rb') as f:
                self.bm25 = pickle.load(f)
            with open(docs_path, 'rb') as f:
                self.documents = pickle.load(f)

            self.loaded = True
            return f"✅ Loaded {self.faiss_index.ntotal:,} vectors, {len(self.documents):,} docs"
        except Exception as e:
            # Top-level boundary for the UI: surface the failure as status text.
            return f"❌ Error: {str(e)}"

    @staticmethod
    def _min_max(values):
        """Scale ``values`` to [0, 1]; a constant list maps to all zeros."""
        low, high = min(values), max(values)
        span = high - low or 1.0
        return [(value - low) / span for value in values]

    @staticmethod
    def _doc_key(result):
        """Return the key used to merge results: product_id, else a hash of the text."""
        return result['metadata'].get('product_id', hash(result['text']))

    @staticmethod
    def _split_doc(doc):
        """Return ``(text, metadata)`` for a stored document.

        Dict documents use their ``'content'`` and ``'metadata'`` keys; any
        other object is stringified and its ``metadata`` attribute (if any)
        is used.
        """
        if isinstance(doc, dict):
            return doc.get('content', str(doc)), doc.get('metadata', {})
        return str(doc), getattr(doc, 'metadata', {})

    def hybrid_search(self, query, k=10, vector_weight=0.7, bm25_weight=0.3):
        """Fuse vector and BM25 results into one ranked list.

        Each retriever is asked for ``2 * k`` candidates; its scores are
        min-max normalised and weighted, and documents found by both
        retrievers get the sum of the two weighted scores.

        Args:
            query: Search text.
            k: Maximum number of fused results to return.
            vector_weight: Weight of the normalised vector score.
            bm25_weight: Weight of the normalised BM25 score.

        Returns:
            Up to ``k`` result dicts sorted by fused score, best first.
        """
        vector_results = self._vector_search(query, k * 2)
        bm25_results = self._bm25_search(query, k * 2)
        fused = {}

        if vector_results:
            # _vector_search already returns "higher is better" scores, so
            # they are normalised as-is. Applying another `1 - score` here
            # would turn them back into distances and invert the ranking.
            normalized = self._min_max([r['score'] for r in vector_results])
            for result, norm in zip(vector_results, normalized):
                fused[self._doc_key(result)] = {
                    'metadata': result['metadata'],
                    'text': result['text'],
                    'score': vector_weight * norm,
                }

        if bm25_results:
            normalized = self._min_max([r['score'] for r in bm25_results])
            for result, norm in zip(bm25_results, normalized):
                key = self._doc_key(result)
                if key in fused:
                    fused[key]['score'] += bm25_weight * norm
                else:
                    fused[key] = {
                        'metadata': result['metadata'],
                        'text': result['text'],
                        'score': bm25_weight * norm,
                    }

        return sorted(fused.values(), key=lambda x: x['score'], reverse=True)[:k]

    def rerank(self, query, results, k=10):
        """Score ``results`` against ``query`` with the cross-encoder.

        Adds a ``'rerank_score'`` key to each result dict (in place) and
        returns the top ``k`` results sorted by that score, best first.
        """
        if not results:
            return []
        pairs = [[query, doc['text']] for doc in results]
        with torch.no_grad():
            scores = self.reranker.predict(pairs)
        for doc, score in zip(results, scores):
            doc['rerank_score'] = float(score)
        return sorted(results, key=lambda x: x['rerank_score'], reverse=True)[:k]

    def search(self, query, k=10, use_rerank=True, method="hybrid"):
        """Run a search and return at most ``k`` result dicts.

        Args:
            query: Search text.
            k: Maximum number of results.
            use_rerank: When True, ``3 * k`` candidates are retrieved and
                reranked, and each result's ``'score'`` is replaced by its
                rerank score.
            method: "vector", "bm25", or anything else for hybrid search.

        Returns:
            A list of result dicts; empty when the models are not loaded.
        """
        if not self.loaded:
            return []

        # Over-fetch so the reranker has more candidates to choose from.
        candidates = k * 3 if use_rerank else k
        if method == "vector":
            results = self._vector_search(query, k=candidates)
        elif method == "bm25":
            results = self._bm25_search(query, k=candidates)
        else:  # hybrid
            results = self.hybrid_search(query, k=candidates)

        if use_rerank and results:
            results = self.rerank(query, results, k=k)
            for result in results:
                result['score'] = result['rerank_score']
        return results[:k]

    def _vector_search(self, query, k):
        """Return up to ``k`` nearest documents from the FAISS index."""
        query_vec = self.encoder.encode([query])
        scores, indices = self.faiss_index.search(query_vec, k)
        doc_count = len(self.documents)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbours with label -1; without the lower
            # bound that would silently select the last document.
            if not 0 <= idx < doc_count:
                continue
            text, metadata = self._split_doc(self.documents[idx])
            results.append({
                'text': text,
                # NOTE(review): treats the FAISS value as a distance (smaller
                # is closer) and flips it so higher is better -- confirm the
                # index is not an inner-product index.
                'score': float(1 - score),
                'metadata': metadata,
            })
        return results

    def _bm25_search(self, query, k):
        """Return up to ``k`` documents with a positive BM25 score, best first."""
        tokens = query.lower().split()
        scores = self.bm25.get_scores(tokens)
        top_indices = np.argsort(scores)[::-1][:k]
        doc_count = len(self.documents)
        results = []
        for idx in top_indices:
            # A score of 0 means no query token matched the document.
            if idx >= doc_count or not scores[idx] > 0:
                continue
            text, metadata = self._split_doc(self.documents[idx])
            results.append({
                'text': text,
                'score': float(scores[idx]),
                'metadata': metadata,
            })
        return results
# Module-level engine instance used by the functions below. Constructing it
# loads nothing; models are loaded when engine.load_models() is called.
engine = SearchEngine()
def init_system():
    """Load the shared engine's models and return the resulting status text."""
    status_message = engine.load_models()
    return status_message
def search_products(query, method="hybrid", use_rerank=True):
    """Search for ``query`` and return the result cards as an HTML string.

    A blank query, or a search that yields no results, returns the sample
    cards instead. At most 8 results are requested from the engine.
    """
    has_query = bool(query.strip())
    # The engine is not consulted at all for a blank query.
    hits = engine.search(query, k=8, use_rerank=use_rerank, method=method) if has_query else []
    if hits:
        return create_product_cards(hits, query)
    return create_sample_cards()
# Brand names looked for in the document text when no brand can be derived
# from the metadata or the title.
_KNOWN_BRANDS = (
    'apple', 'samsung', 'sony', 'lg', 'dell', 'hp', 'lenovo', 'asus', 'acer',
    'microsoft', 'google', 'amazon', 'nike', 'adidas', 'jarlif', 'sketchers',
    'puma', 'reebok',
)

# Values that are field labels picked up by accident, not real brand names.
_PLACEHOLDER_BRANDS = ('title:', 'product')


def _unwrap_content(raw_text):
    """Return the 'content' value if raw_text is the repr of a {'content': ...} dict."""
    import ast

    if not (raw_text.startswith("{'content': '") and raw_text.endswith("'}")):
        return raw_text
    try:
        parsed = ast.literal_eval(raw_text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a valid literal after all: fall back to the raw text.
        return raw_text
    if isinstance(parsed, dict):
        return parsed.get('content', raw_text)
    return raw_text


def _derive_title(text, metadata):
    """Return the metadata title, else one taken from the first line of text."""
    title = metadata.get('title', metadata.get('product_title', ''))
    if title:
        return title
    if text.startswith('Title: '):
        title = text[7:].split('\n')[0]  # text after the "Title: " label
        return title[:80] + ('...' if len(title) > 80 else '')
    first_line = text.split('\n')[0]
    return first_line[:60] + ('...' if len(first_line) > 60 else '')


def _derive_brand(text, title, metadata):
    """Return the metadata brand, else a guess from the title or text, else "Generic"."""
    brand = metadata.get('brand', metadata.get('product_brand', ''))
    if not brand and title and not title.startswith('{'):
        # Use the first word of the title, unless it is a single character
        # or has no lowercase letters (e.g. all caps or digits).
        title_words = title.split()
        first_word = title_words[0] if title_words else ''
        if first_word and len(first_word) > 1 and first_word.upper() != first_word:
            brand = first_word
    if not brand or brand.lower() in _PLACEHOLDER_BRANDS:
        text_words = set(text.lower().split())
        for known in _KNOWN_BRANDS:
            if known in text_words:
                brand = known.capitalize()
                break
    if not brand or brand.lower() in _PLACEHOLDER_BRANDS:
        brand = "Generic"
    return brand


def _build_description(text, brand):
    """Return a one-line description (at most 150 chars plus "...") without title/brand noise."""
    if text.startswith('Title: '):
        body = text[text.find('\n') + 1:] if '\n' in text else text[7:]
    else:
        body = text

    brand_lower = brand.lower()
    kept = []
    for line in body.split('\n'):
        line = line.strip()
        if not line:
            continue
        lowered = line.lower()
        # Drop "Brand: <brand>" lines: the brand is shown separately.
        if lowered.startswith('brand:') and brand_lower in lowered:
            continue
        # Strip a leading brand name and the separator that follows it.
        if lowered.startswith(brand_lower):
            line = line[len(brand):].strip()
            if line.startswith((':', '-')):
                line = line[1:].strip()
        if line:
            kept.append(line)

    joined = ' '.join(kept)
    return joined[:150] + "..." if len(joined) > 150 else joined


def _render_product_card(title, brand, description, score):
    """Return the HTML for one product card; all text fields are HTML-escaped."""
    import html

    return f"""
    <div style="border: 1px solid #ddd; border-radius: 12px; padding: 20px; width: 280px; background: white; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
    <div style="background: #f0f0f0; height: 120px; border-radius: 8px; display: flex; align-items: center; justify-content: center; margin-bottom: 15px; color: #666;">
    📱 Product Image
    </div>
    <h4 style="margin: 0 0 8px 0; color: #333; font-size: 16px; line-height: 1.3;">{html.escape(str(title))}</h4>
    <p style="margin: 0 0 8px 0; color: #666; font-size: 14px;">Brand: {html.escape(str(brand))}</p>
    <p style="margin: 0 0 12px 0; color: #888; font-size: 13px; line-height: 1.4;">{html.escape(str(description))}</p>
    <div style="display: flex; justify-content: space-between; align-items: center;">
    <span style="background: #007bff; color: white; padding: 4px 8px; border-radius: 4px; font-size: 12px;">Score: {score:.3f}</span>
    <button style="background: #28a745; color: white; border: none; padding: 6px 12px; border-radius: 4px; cursor: pointer; font-size: 12px;">View</button>
    </div>
    </div>
    """


def create_product_cards(results, query):
    """Render search results as an HTML grid of product cards.

    Args:
        results: Iterable of result dicts with a ``'text'`` key and optional
            ``'metadata'`` (dict) and ``'score'`` (number) keys.
        query: The user's search text, shown in the heading.

    Returns:
        An HTML string. The query and every value taken from the documents
        are HTML-escaped, so markup inside them is displayed as text instead
        of being interpreted by the browser.
    """
    import html

    cards = []
    for result in results:
        metadata = result.get('metadata', {})
        text = _unwrap_content(result['text'])
        title = _derive_title(text, metadata)
        brand = _derive_brand(text, title, metadata)
        description = _build_description(text, brand)
        cards.append(_render_product_card(title, brand, description, result.get('score', 0)))

    heading = (
        f"<h3>Search Results for '{html.escape(str(query))}'</h3>"
        "<div style='display: flex; flex-wrap: wrap; gap: 20px; justify-content: center;'>"
    )
    return heading + "".join(cards) + "</div>"
def create_sample_cards():
    """Return the HTML for a fixed grid of eight sample product cards.

    The cards are built from hard-coded placeholder products, not from the
    search index.
    """
    sample_products = [
        {"title": "Wireless Bluetooth Headphones", "brand": "TechSound", "desc": "High-quality wireless headphones with noise cancellation", "score": 0.95},
        {"title": "Gaming Mechanical Keyboard", "brand": "GamePro", "desc": "RGB backlit mechanical keyboard for gaming", "score": 0.92},
        {"title": "Smartphone Case", "brand": "ProtectMax", "desc": "Durable protection case for latest smartphones", "score": 0.89},
        {"title": "Running Shoes", "brand": "SportFit", "desc": "Comfortable running shoes for daily exercise", "score": 0.87},
        {"title": "Coffee Machine", "brand": "BrewMaster", "desc": "Automatic coffee machine with multiple settings", "score": 0.84},
        {"title": "Laptop Stand", "brand": "DeskPro", "desc": "Adjustable aluminum laptop stand for ergonomic work", "score": 0.82},
        {"title": "Wireless Mouse", "brand": "ClickPro", "desc": "Ergonomic wireless mouse with precision tracking", "score": 0.80},
        {"title": "USB-C Hub", "brand": "ConnectMax", "desc": "Multi-port USB-C hub with 4K HDMI output", "score": 0.78},
    ]

    # The values above are trusted constants, so they are inserted unescaped.
    cards = [
        f"""
    <div style="border: 1px solid #ddd; border-radius: 12px; padding: 20px; width: 280px; background: white; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
    <div style="background: #f0f0f0; height: 120px; border-radius: 8px; display: flex; align-items: center; justify-content: center; margin-bottom: 15px; color: #666;">
    📱 Product Image
    </div>
    <h4 style="margin: 0 0 8px 0; color: #333; font-size: 16px; line-height: 1.3;">{product['title']}</h4>
    <p style="margin: 0 0 8px 0; color: #666; font-size: 14px;">Brand: {product['brand']}</p>
    <p style="margin: 0 0 12px 0; color: #888; font-size: 13px; line-height: 1.4;">{product['desc']}</p>
    <div style="display: flex; justify-content: space-between; align-items: center;">
    <span style="background: #007bff; color: white; padding: 4px 8px; border-radius: 4px; font-size: 12px;">Score: {product['score']:.2f}</span>
    <button style="background: #28a745; color: white; border: none; padding: 6px 12px; border-radius: 4px; cursor: pointer; font-size: 12px;">View</button>
    </div>
    </div>
    """
        for product in sample_products
    ]

    heading = (
        "<h3>Recommended Products</h3>"
        "<div style='display: flex; flex-wrap: wrap; gap: 20px; justify-content: center;'>"
    )
    return heading + "".join(cards) + "</div>"
# Quick search function
def quick_search(search_term):
    """Search for ``search_term`` using search_products' default method and reranking."""
    cards_html = search_products(search_term)
    return cards_html
# UI
# Gradio interface: an initialise/status column next to the search controls,
# a row of quick-search buttons, and an HTML area for the result cards.
# NOTE(review): the nesting of the Row/Column containers is a best-effort
# reconstruction -- confirm the layout renders as intended.
with gr.Blocks(title="Product Search", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; margin-bottom: 20px; border-radius: 10px;'>
    <h1 style='margin: 0; font-size: 2.5em;'>🛍️ Product Search</h1>
    <p style='margin: 10px 0 0 0; opacity: 0.9;'>Find products with AI-powered search</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            init_btn = gr.Button("🚀 Initialize System", variant="primary", size="lg")
            status = gr.Textbox(label="System Status", lines=2, interactive=False, value="Click to initialize...")
        with gr.Column(scale=3):
            search_input = gr.Textbox(
                label="Search Products",
                placeholder="Search for headphones, keyboards, cases...",
                lines=1,
            )
            with gr.Row():
                method_dropdown = gr.Dropdown(
                    choices=["hybrid", "vector", "bm25"],
                    value="hybrid",
                    label="Search Method",
                    scale=1,
                )
                rerank_checkbox = gr.Checkbox(
                    value=True,
                    label="Use Reranking",
                    scale=1,
                )
            search_btn = gr.Button("🔍 Search", variant="primary", size="lg")

    gr.Markdown("### Quick Search")
    quick_searches = ["bluetooth headphones", "gaming keyboard", "phone case", "running shoes"]
    with gr.Row():
        quick_buttons = [gr.Button(term, size="sm") for term in quick_searches]

    # Results display: shows the sample cards until the first search.
    results_display = gr.HTML(value=create_sample_cards())

    # Events
    init_btn.click(init_system, outputs=status)
    search_inputs = [search_input, method_dropdown, rerank_checkbox]
    search_btn.click(search_products, inputs=search_inputs, outputs=results_display)
    search_input.submit(search_products, inputs=search_inputs, outputs=results_display)

    # Quick-search buttons always run a reranked hybrid search for their own
    # term. `term=term` binds the current value; a plain closure would make
    # every button search for the last term in the list.
    for quick_button, term in zip(quick_buttons, quick_searches):
        quick_button.click(
            lambda term=term: search_products(term, "hybrid", True),
            outputs=results_display,
        )

demo.launch()