Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import pickle | |
| import faiss | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from sentence_transformers import CrossEncoder, SentenceTransformer | |
class SearchEngine:
    """Product search over a FAISS index and a BM25 model, with optional reranking.

    Nothing is loaded at construction time: call :meth:`load_models` first.
    Until it succeeds, :meth:`search` returns an empty list.

    Every search method returns a list of dicts with the keys ``'text'``,
    ``'score'`` (higher is better) and ``'metadata'``.
    """

    # Hugging Face dataset repository that holds the prebuilt search artefacts.
    REPO_ID = "entropy25/product-vectorstore"
    # Relative weights of the two retrievers in the fused hybrid score.
    VECTOR_WEIGHT = 0.7
    BM25_WEIGHT = 0.3

    def __init__(self):
        self.loaded = False
        self.bm25 = None
        self.documents = None
        self.faiss_index = None
        self.reranker = None
        self.encoder = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def load_models(self):
        """Download the index artefacts and load every model.

        Returns:
            A human-readable status string. Failures are reported in the
            returned string rather than raised, because the value is shown
            directly in the UI status box.
        """
        try:
            # Download files
            faiss_path = hf_hub_download(self.REPO_ID, "index.faiss", repo_type="dataset")
            bm25_path = hf_hub_download(self.REPO_ID, "bm25_model.pkl", repo_type="dataset")
            docs_path = hf_hub_download(self.REPO_ID, "documents.pkl", repo_type="dataset")

            # Load models
            self.faiss_index = faiss.read_index(faiss_path)
            self.encoder = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
            self.reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2', device=self.device)

            # NOTE(security): unpickling executes arbitrary code embedded in the
            # file, so these artefacts must only come from a trusted repository.
            with open(bm25_path, 'rb') as f:
                self.bm25 = pickle.load(f)
            with open(docs_path, 'rb') as f:
                self.documents = pickle.load(f)

            self.loaded = True
            return f"✅ Loaded {self.faiss_index.ntotal:,} vectors, {len(self.documents):,} docs"
        except Exception as e:
            return f"❌ Error: {str(e)}"

    @staticmethod
    def _min_max_normalize(values):
        """Rescale *values* to ``[0, 1]``; a constant list maps to all zeros."""
        low, high = min(values), max(values)
        span = (high - low) or 1.0
        return [(value - low) / span for value in values]

    @staticmethod
    def _unpack_document(doc):
        """Return ``(text, metadata)`` for a stored document.

        Dict documents contribute their ``'content'`` and ``'metadata'``
        entries; any other object is stringified and its ``metadata``
        attribute is used when present.
        """
        if isinstance(doc, dict):
            return doc.get('content', str(doc)), doc.get('metadata', {})
        return str(doc), getattr(doc, 'metadata', {})

    def hybrid_search(self, query, k=10):
        """Fuse vector and BM25 results into one ranking.

        Each retriever's scores are min-max normalised and then combined as
        ``VECTOR_WEIGHT * vector + BM25_WEIGHT * bm25``. Documents are keyed by
        ``metadata['product_id']``, falling back to a hash of their text.

        Args:
            query: Free-text search query.
            k: Maximum number of results to return.

        Returns:
            Up to *k* result dicts sorted by fused score, best first.
        """
        vector_results = self._vector_search(query, k * 2)
        bm25_results = self._bm25_search(query, k * 2)
        doc_scores = {}

        # Process vector results.
        # `_vector_search` already reports a "higher is better" score, so it is
        # normalised as-is; inverting it again would rank the farthest vectors first.
        if vector_results:
            normalized_scores = self._min_max_normalize([r['score'] for r in vector_results])
            for result, normalized in zip(vector_results, normalized_scores):
                doc_id = result['metadata'].get('product_id', hash(result['text']))
                doc_scores[doc_id] = {
                    'metadata': result['metadata'],
                    'text': result['text'],
                    'score': self.VECTOR_WEIGHT * normalized,
                }

        # Process BM25 results
        if bm25_results:
            normalized_scores = self._min_max_normalize([r['score'] for r in bm25_results])
            for result, normalized in zip(bm25_results, normalized_scores):
                doc_id = result['metadata'].get('product_id', hash(result['text']))
                if doc_id in doc_scores:
                    doc_scores[doc_id]['score'] += self.BM25_WEIGHT * normalized
                else:
                    doc_scores[doc_id] = {
                        'metadata': result['metadata'],
                        'text': result['text'],
                        'score': self.BM25_WEIGHT * normalized,
                    }

        return sorted(doc_scores.values(), key=lambda x: x['score'], reverse=True)[:k]

    def rerank(self, query, results, k=10):
        """Rescore *results* with the cross-encoder and keep the best *k*.

        Each result dict gains a ``'rerank_score'`` key (the input dicts are
        modified in place).
        """
        if not results:
            return []
        pairs = [[query, doc['text']] for doc in results]
        with torch.no_grad():
            scores = self.reranker.predict(pairs)
        for doc, score in zip(results, scores):
            doc['rerank_score'] = float(score)
        return sorted(results, key=lambda x: x['rerank_score'], reverse=True)[:k]

    def search(self, query, k=10, use_rerank=True, method="hybrid"):
        """Run a search and return up to *k* results.

        Args:
            query: Free-text search query.
            k: Maximum number of results to return.
            use_rerank: When true, retrieve ``3 * k`` candidates and rerank
                them; ``'score'`` is then replaced by the rerank score.
            method: ``"vector"``, ``"bm25"``, or anything else for hybrid.

        Returns:
            A list of result dicts, or ``[]`` when the models are not loaded.
        """
        if not self.loaded:
            return []

        candidate_count = k * 3 if use_rerank else k
        if method == "vector":
            results = self._vector_search(query, k=candidate_count)
        elif method == "bm25":
            results = self._bm25_search(query, k=candidate_count)
        else:  # hybrid
            results = self.hybrid_search(query, k=candidate_count)

        if use_rerank and results:
            results = self.rerank(query, results, k=k)
            for result in results:
                result['score'] = result['rerank_score']
        return results[:k]

    def _vector_search(self, query, k):
        """Return up to *k* nearest documents from the FAISS index.

        The reported score is ``1 - raw_index_score``.
        NOTE(review): this is "higher is better" only if the index reports
        distances (smaller is closer) -- confirm the index metric.
        """
        query_vec = self.encoder.encode([query])
        scores, indices = self.faiss_index.search(query_vec, k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbours with -1; without the lower bound a
            # negative index would silently select the last document.
            if 0 <= idx < len(self.documents):
                text, metadata = self._unpack_document(self.documents[idx])
                results.append({
                    'text': text,
                    'score': float(1 - score),
                    'metadata': metadata,
                })
        return results

    def _bm25_search(self, query, k):
        """Return up to *k* documents with a positive BM25 score, best first.

        The query is lower-cased and split on whitespace before scoring.
        """
        tokens = query.lower().split()
        scores = self.bm25.get_scores(tokens)
        top_indices = np.argsort(scores)[::-1][:k]
        results = []
        for idx in top_indices:
            if idx < len(self.documents) and scores[idx] > 0:
                text, metadata = self._unpack_document(self.documents[idx])
                results.append({
                    'text': text,
                    'score': float(scores[idx]),
                    'metadata': metadata,
                })
        return results
# Single engine instance shared by every UI callback in this module.
engine = SearchEngine()


def init_system():
    """Load the shared engine's models and return the resulting status message."""
    load_status = engine.load_models()
    return load_status
def search_products(query, method="hybrid", use_rerank=True):
    """Search for *query* and render the hits as HTML product cards.

    Args:
        query: Text entered by the user; may be ``None`` or blank.
        method: Search method forwarded to the engine.
        use_rerank: Whether the engine should rerank the candidates.

    Returns:
        HTML for up to eight result cards. The sample cards are returned
        instead when the query is missing or blank, or when the engine yields
        no results.
    """
    # Guard against None as well as blank input: `None.strip()` would raise.
    if not query or not query.strip():
        return create_sample_cards()

    results = engine.search(query, k=8, use_rerank=use_rerank, method=method)
    if not results:
        return create_sample_cards()
    return create_product_cards(results, query)
# Brand names looked for in the document text when the title gives no usable brand.
_KNOWN_BRANDS = (
    'apple', 'samsung', 'sony', 'lg', 'dell', 'hp', 'lenovo', 'asus', 'acer',
    'microsoft', 'google', 'amazon', 'nike', 'adidas', 'jarlif', 'sketchers',
    'puma', 'reebok',
)
# Values that come from field labels rather than real brand names.
_PLACEHOLDER_BRANDS = ('title:', 'product')


def _unwrap_document_text(raw_text):
    """Return the ``content`` value if *raw_text* is a stringified ``{'content': ...}`` dict."""
    import ast

    if not (raw_text.startswith("{'content': '") and raw_text.endswith("'}")):
        return raw_text
    try:
        doc_dict = ast.literal_eval(raw_text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a valid literal after all: fall back to the raw text.
        return raw_text
    return doc_dict.get('content', raw_text)


def _derive_title(text, metadata):
    """Pick the title from metadata, else from the first line of *text* (truncated)."""
    title = metadata.get('title') or metadata.get('product_title') or ''
    if title:
        return str(title)
    if text.startswith('Title: '):
        first_line = text[7:].split('\n')[0]  # Text after the "Title: " label
        limit = 80
    else:
        first_line = text.split('\n')[0]
        limit = 60
    return first_line[:limit] + ('...' if len(first_line) > limit else '')


def _derive_brand(text, title, metadata):
    """Pick the brand from metadata, else guess it from the title or the text."""
    brand = metadata.get('brand') or metadata.get('product_brand') or ''
    if brand:
        return str(brand)

    # Use the first word of the title as a potential brand (all-caps words are skipped).
    if title and not title.startswith('{'):
        title_words = title.split()
        first_word = title_words[0] if title_words else ''
        if first_word and len(first_word) > 1 and first_word.upper() != first_word:
            brand = first_word

    # Fallback: look for a well-known brand anywhere in the text.
    if not brand or brand.lower() in _PLACEHOLDER_BRANDS:
        words = set(text.lower().split())
        for brand_name in _KNOWN_BRANDS:
            if brand_name in words:
                brand = brand_name.capitalize()
                break

    if not brand or brand.lower() in _PLACEHOLDER_BRANDS:
        brand = "Generic"
    return brand


def _derive_description(text, brand):
    """Build a one-line description (at most 150 characters plus an ellipsis)."""
    # Drop the title line when the text starts with one.
    if text.startswith('Title: '):
        body = text[text.find('\n') + 1:] if '\n' in text else text[7:]
    else:
        body = text

    brand_lower = brand.lower()
    cleaned_lines = []
    for line in body.split('\n'):
        line = line.strip()
        if not line:
            continue
        # Skip lines of the form "Brand: <brand>": the brand is shown separately.
        if line.lower().startswith('brand:') and brand_lower in line.lower():
            continue
        # Remove the brand name when it leads the line, plus a ':' or '-' separator.
        if line.lower().startswith(brand_lower):
            line = line[len(brand):].strip()
            if line.startswith((':', '-')):
                line = line[1:].strip()
        if line:  # Only keep non-empty lines
            cleaned_lines.append(line)

    description = ' '.join(cleaned_lines)
    return description[:150] + "..." if len(description) > 150 else description


def create_product_cards(results, query):
    """Render search results as a flex grid of HTML product cards.

    Title, brand and description are taken from each result's metadata when
    available and otherwise inferred from its text. All dynamic values are
    HTML-escaped because both the query and the document text are untrusted.

    Args:
        results: Result dicts with ``'text'`` and optional ``'metadata'`` and
            ``'score'`` keys.
        query: The search query, shown in the heading.

    Returns:
        An HTML string containing the heading and one card per result.
    """
    import html

    header = (
        f"<h3>Search Results for '{html.escape(str(query))}'</h3>"
        "<div style='display: flex; flex-wrap: wrap; gap: 20px; justify-content: center;'>"
    )

    cards = []
    for result in results:
        metadata = result.get('metadata') or {}
        text = _unwrap_document_text(result['text'])
        title = _derive_title(text, metadata)
        brand = _derive_brand(text, title, metadata)
        description = _derive_description(text, brand)
        score = result.get('score', 0)

        cards.append(f"""
<div style="border: 1px solid #ddd; border-radius: 12px; padding: 20px; width: 280px; background: white; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
    <div style="background: #f0f0f0; height: 120px; border-radius: 8px; display: flex; align-items: center; justify-content: center; margin-bottom: 15px; color: #666;">
        📱 Product Image
    </div>
    <h4 style="margin: 0 0 8px 0; color: #333; font-size: 16px; line-height: 1.3;">{html.escape(title)}</h4>
    <p style="margin: 0 0 8px 0; color: #666; font-size: 14px;">Brand: {html.escape(brand)}</p>
    <p style="margin: 0 0 12px 0; color: #888; font-size: 13px; line-height: 1.4;">{html.escape(description)}</p>
    <div style="display: flex; justify-content: space-between; align-items: center;">
        <span style="background: #007bff; color: white; padding: 4px 8px; border-radius: 4px; font-size: 12px;">Score: {score:.3f}</span>
        <button style="background: #28a745; color: white; border: none; padding: 6px 12px; border-radius: 4px; cursor: pointer; font-size: 12px;">View</button>
    </div>
</div>
""")

    return header + "".join(cards) + "</div>"
def create_sample_cards():
    """Render a fixed set of eight sample products as HTML cards.

    The cards are placeholder content; the values are hard-coded here and do
    not come from the search engine.

    Returns:
        An HTML string with a "Recommended Products" heading and one card per
        sample product.
    """
    sample_products = [
        {"title": "Wireless Bluetooth Headphones", "brand": "TechSound", "desc": "High-quality wireless headphones with noise cancellation", "score": 0.95},
        {"title": "Gaming Mechanical Keyboard", "brand": "GamePro", "desc": "RGB backlit mechanical keyboard for gaming", "score": 0.92},
        {"title": "Smartphone Case", "brand": "ProtectMax", "desc": "Durable protection case for latest smartphones", "score": 0.89},
        {"title": "Running Shoes", "brand": "SportFit", "desc": "Comfortable running shoes for daily exercise", "score": 0.87},
        {"title": "Coffee Machine", "brand": "BrewMaster", "desc": "Automatic coffee machine with multiple settings", "score": 0.84},
        {"title": "Laptop Stand", "brand": "DeskPro", "desc": "Adjustable aluminum laptop stand for ergonomic work", "score": 0.82},
        {"title": "Wireless Mouse", "brand": "ClickPro", "desc": "Ergonomic wireless mouse with precision tracking", "score": 0.80},
        {"title": "USB-C Hub", "brand": "ConnectMax", "desc": "Multi-port USB-C hub with 4K HDMI output", "score": 0.78},
    ]

    header = "<h3>Recommended Products</h3><div style='display: flex; flex-wrap: wrap; gap: 20px; justify-content: center;'>"

    # Build every card first and join once instead of growing a string in the loop.
    cards = [
        f"""
<div style="border: 1px solid #ddd; border-radius: 12px; padding: 20px; width: 280px; background: white; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
    <div style="background: #f0f0f0; height: 120px; border-radius: 8px; display: flex; align-items: center; justify-content: center; margin-bottom: 15px; color: #666;">
        📱 Product Image
    </div>
    <h4 style="margin: 0 0 8px 0; color: #333; font-size: 16px; line-height: 1.3;">{product['title']}</h4>
    <p style="margin: 0 0 8px 0; color: #666; font-size: 14px;">Brand: {product['brand']}</p>
    <p style="margin: 0 0 12px 0; color: #888; font-size: 13px; line-height: 1.4;">{product['desc']}</p>
    <div style="display: flex; justify-content: space-between; align-items: center;">
        <span style="background: #007bff; color: white; padding: 4px 8px; border-radius: 4px; font-size: 12px;">Score: {product['score']:.2f}</span>
        <button style="background: #28a745; color: white; border: none; padding: 6px 12px; border-radius: 4px; cursor: pointer; font-size: 12px;">View</button>
    </div>
</div>
"""
        for product in sample_products
    ]

    return header + "".join(cards) + "</div>"
# Quick search function
def quick_search(search_term):
    """Search for *search_term* using ``search_products`` with its default options."""
    cards_html = search_products(search_term)
    return cards_html
# UI
# Layout: a header banner, an initialise/status column next to the search
# controls, a row of quick-search shortcuts, and the HTML results area.
with gr.Blocks(title="Product Search", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; margin-bottom: 20px; border-radius: 10px;'>
        <h1 style='margin: 0; font-size: 2.5em;'>🛍️ Product Search</h1>
        <p style='margin: 10px 0 0 0; opacity: 0.9;'>Find products with AI-powered search</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            init_btn = gr.Button("🚀 Initialize System", variant="primary", size="lg")
            status = gr.Textbox(label="System Status", lines=2, interactive=False, value="Click to initialize...")
        with gr.Column(scale=3):
            search_input = gr.Textbox(
                label="Search Products",
                placeholder="Search for headphones, keyboards, cases...",
                lines=1
            )
            with gr.Row():
                method_dropdown = gr.Dropdown(
                    choices=["hybrid", "vector", "bm25"],
                    value="hybrid",
                    label="Search Method",
                    scale=1
                )
                rerank_checkbox = gr.Checkbox(
                    value=True,
                    label="Use Reranking",
                    scale=1
                )
            search_btn = gr.Button("🔍 Search", variant="primary", size="lg")

    gr.Markdown("### Quick Search")
    with gr.Row():
        quick_searches = ["bluetooth headphones", "gaming keyboard", "phone case", "running shoes"]
        quick_btn1 = gr.Button(quick_searches[0], size="sm")
        quick_btn2 = gr.Button(quick_searches[1], size="sm")
        quick_btn3 = gr.Button(quick_searches[2], size="sm")
        quick_btn4 = gr.Button(quick_searches[3], size="sm")

    # Results display (sample cards are shown until the first search)
    results_display = gr.HTML(value=create_sample_cards())

    # Events
    init_btn.click(init_system, outputs=status)
    search_btn.click(search_products, inputs=[search_input, method_dropdown, rerank_checkbox], outputs=results_display)
    search_input.submit(search_products, inputs=[search_input, method_dropdown, rerank_checkbox], outputs=results_display)

    # Quick search button events: each shortcut runs a hybrid, reranked search
    # for its own label. The term is bound as a default argument so every
    # callback keeps its own value instead of the loop's last one.
    for quick_btn, quick_term in zip((quick_btn1, quick_btn2, quick_btn3, quick_btn4), quick_searches):
        quick_btn.click(
            lambda term=quick_term: search_products(term, "hybrid", True),
            outputs=results_display,
        )

demo.launch()