# Hugging Face Space app (page-scrape status header removed)
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from rank_bm25 import BM25Okapi | |
| import re | |
| from typing import List, Tuple, Dict | |
| import threading | |
| import time | |
| # Configuration | |
| DATASET_NAME = "hoololi/AI_Act_with_embeddings" | |
| EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2" | |
| TOP_K = 5 | |
class AIActSearchEngine:
    """Search engine over AI Act articles with three retrieval methods.

    Loads the dataset once, builds a TF-IDF matrix and a BM25 index over the
    'processed_content' column, and compares query embeddings against the
    precomputed 'embedding' column. All three search methods share the same
    ranking/deduplication logic (see _top_k_unique).
    """

    def __init__(self):
        self.dataset = None
        self.embedding_model = None
        self.tfidf_vectorizer = None
        self.tfidf_matrix = None
        self.bm25_model = None
        self.processed_docs = None
        # Cached dense matrix of stored article embeddings (built in setup_models
        # so it is not rebuilt on every query).
        self.stored_embeddings = None
        self.load_data()
        self.setup_models()

    def load_data(self):
        """Load the dataset from Hugging Face into a pandas DataFrame."""
        print("Loading dataset...")
        dataset = load_dataset(DATASET_NAME, split="train")
        self.dataset = dataset.to_pandas()
        print(f"Dataset loaded: {len(self.dataset)} articles")

    def setup_models(self):
        """Initialize the embedding model, TF-IDF vectorizer and BM25 index."""
        print("Initializing models...")
        # Embedding model (encodes incoming queries only; article vectors are precomputed)
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        # TF-IDF
        self.tfidf_vectorizer = TfidfVectorizer(
            max_features=10000,
            stop_words='english',
            lowercase=True,
            ngram_range=(1, 2)
        )
        self.tfidf_matrix = self.tfidf_vectorizer.fit_transform(self.dataset['processed_content'])
        # Hoist the stored-embedding matrix out of search_embeddings: building it
        # per query was pure overhead, since the dataset never changes.
        self.stored_embeddings = np.array(self.dataset['embedding'].tolist())
        # BM25 (using optimized library)
        self.setup_bm25()
        print("Models initialized successfully!")

    def setup_bm25(self):
        """Tokenize documents and build the BM25Okapi index."""
        print("Setting up BM25...")
        # Whitespace tokenization matches how queries are tokenized in search_bm25.
        self.processed_docs = [doc.split() for doc in self.dataset['processed_content']]
        self.bm25_model = BM25Okapi(self.processed_docs)
        print("BM25 setup complete!")

    def _top_k_unique(self, scores, include_zero: bool = False) -> List[Tuple[str, str, float]]:
        """Rank documents by score and return the top TOP_K unique articles.

        Shared by all three search methods (this logic was previously
        copy-pasted three times). Sorts by descending score with ascending
        document index as a stable tie-break, then deduplicates on
        article_number so each article appears at most once.

        Args:
            scores: per-document relevance scores aligned with self.dataset rows.
            include_zero: keep documents with score <= 0 (embedding search
                ranks every document; the lexical methods drop zero scores).

        Returns:
            At most TOP_K (article_number, article_content, score) tuples.
        """
        scored_docs = [
            (score, i) for i, score in enumerate(scores)
            if include_zero or score > 0
        ]
        # Descending score, ascending index for deterministic tie-breaking.
        scored_docs.sort(key=lambda x: (-x[0], x[1]))
        seen_articles = set()
        results = []
        for score, idx in scored_docs:
            row = self.dataset.iloc[idx]
            article_num = row['article_number']
            if article_num in seen_articles:
                continue
            seen_articles.add(article_num)
            results.append((article_num, row['article_content'], score))
            if len(results) >= TOP_K:
                break  # early exit instead of scanning the entire ranking
        return results

    def search_tfidf(self, query: str) -> List[Tuple[str, str, float]]:
        """TF-IDF search: cosine similarity between query and document vectors."""
        query_vector = self.tfidf_vectorizer.transform([query])
        similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()
        return self._top_k_unique(similarities)

    def search_bm25(self, query: str) -> List[Tuple[str, str, float]]:
        """BM25 search using the rank_bm25 library."""
        query_tokens = query.lower().split()
        scores = self.bm25_model.get_scores(query_tokens)
        return self._top_k_unique(scores)

    def search_embeddings(self, query: str) -> List[Tuple[str, str, float]]:
        """Semantic search: cosine similarity against precomputed embeddings."""
        query_embedding = self.embedding_model.encode([query])
        similarities = cosine_similarity(query_embedding, self.stored_embeddings).flatten()
        # include_zero preserves the original behavior: embedding search ranked
        # every document, even those with non-positive similarity.
        return self._top_k_unique(similarities, include_zero=True)

    def search_all(self, query: str) -> Dict[str, List[Tuple[str, str, float]]]:
        """Run all three searches; blank queries short-circuit to empty results."""
        if not query.strip():
            return {
                'tfidf': [],
                'bm25': [],
                'embeddings': []
            }
        return {
            'tfidf': self.search_tfidf(query),
            'bm25': self.search_bm25(query),
            'embeddings': self.search_embeddings(query)
        }
def format_results_table(results: List[Tuple[str, str, float]],
                         search_type: str,
                         highlight_articles: set) -> str:
    """Format search results as an HTML table fragment.

    Args:
        results: (article_number, article_content, score) tuples.
        search_type: heading shown above the table (e.g. "TF-IDF").
        highlight_articles: article numbers found by all three methods;
            their article cell gets a light-green background.

    Returns:
        An HTML fragment, or a placeholder message when results is empty.
    """
    from html import escape  # local import: `html` is used as a variable name below

    if not results:
        return f"""
        <div style="text-align: center; padding: 20px;">
            <h3>{escape(search_type)}</h3>
            <p>No results found</p>
        </div>
        """
    html = f"""
    <div style="margin: 10px;">
        <h3 style="text-align: center; margin-bottom: 15px;">{escape(search_type)}</h3>
        <table style="width: 100%; border-collapse: collapse; font-size: 12px;">
            <thead>
                <tr style="background-color: #f0f0f0;">
                    <th style="border: 1px solid #ddd; padding: 8px; width: 80px;">Score</th>
                    <th style="border: 1px solid #ddd; padding: 8px; width: 80px;">Article</th>
                    <th style="border: 1px solid #ddd; padding: 8px;">Content</th>
                </tr>
            </thead>
            <tbody>
    """
    for article_num, content, score in results:
        # Highlight if the article was returned by all three search methods.
        bg_color = "background-color: #90EE90;" if article_num in highlight_articles else ""
        # Escape dataset text before interpolating into HTML: article content is
        # external data and may contain HTML-special characters. (A previously
        # computed-but-unused `truncated_content` variable was removed.)
        safe_article = escape(str(article_num))
        safe_content = escape(content)
        html += f"""
            <tr>
                <td style="border: 1px solid #ddd; padding: 8px; text-align: center;">{score:.3f}</td>
                <td style="border: 1px solid #ddd; padding: 8px; text-align: center; {bg_color}">{safe_article}</td>
                <td style="border: 1px solid #ddd; padding: 8px; max-width: 400px; overflow: hidden;">
                    <details>
                        <summary style="cursor: pointer; font-weight: bold;">View content</summary>
                        <div style="margin-top: 10px; white-space: pre-wrap;">{safe_content}</div>
                    </details>
                </td>
            </tr>
        """
    html += """
            </tbody>
        </table>
    </div>
    """
    return html
def search_articles_progressive(query: str, search_engine: "AIActSearchEngine"):
    """Run all three searches and return their formatted HTML tables.

    NOTE(review): despite the name, results are computed sequentially and
    returned in one shot — no intermediate "loading" state is ever emitted.
    The previous version built a spinner-HTML string (`loading_table`) that
    was never used anywhere; it has been removed as dead code.

    Args:
        query: user search string; blank input short-circuits to a placeholder.
        search_engine: initialized AIActSearchEngine instance.

    Returns:
        Tuple of three HTML fragments: (tfidf_html, bm25_html, embeddings_html).
    """
    if not query.strip():
        empty_table = """
        <div style="text-align: center; padding: 20px;">
            <p>Enter a query to start searching</p>
        </div>
        """
        return empty_table, empty_table, empty_table

    results = {}

    # 1. TF-IDF Search (fastest)
    print("Starting TF-IDF search...")
    start_time = time.time()
    results['tfidf'] = search_engine.search_tfidf(query)
    print(f"TF-IDF completed in {time.time() - start_time:.2f}s")

    # 2. BM25 Search (fast with optimized library)
    print("Starting BM25 search...")
    start_time = time.time()
    results['bm25'] = search_engine.search_bm25(query)
    print(f"BM25 completed in {time.time() - start_time:.2f}s")

    # 3. Embeddings Search (slowest)
    print("Starting embeddings search...")
    start_time = time.time()
    results['embeddings'] = search_engine.search_embeddings(query)
    print(f"Embeddings completed in {time.time() - start_time:.2f}s")

    # Articles present in the top results of all three methods get highlighted.
    tfidf_articles = {r[0] for r in results['tfidf']}
    bm25_articles = {r[0] for r in results['bm25']}
    embedding_articles = {r[0] for r in results['embeddings']}
    highlight_articles = tfidf_articles & bm25_articles & embedding_articles

    tfidf_html = format_results_table(results['tfidf'], "TF-IDF", highlight_articles)
    bm25_html = format_results_table(results['bm25'], "BM25", highlight_articles)
    embeddings_html = format_results_table(results['embeddings'], "Embeddings", highlight_articles)
    return tfidf_html, bm25_html, embeddings_html
def search_articles(query: str, search_engine: AIActSearchEngine):
    """Main search entry point, kept for backward compatibility.

    Delegates unchanged to search_articles_progressive.
    """
    return search_articles_progressive(query, search_engine)
def main():
    """Build and launch the Gradio search UI.

    Initializes the search engine (dataset + models) up front, then wires a
    single handler to both the search button and the textbox submit event
    (previously two identical lambdas were duplicated).
    """
    print("Initializing application...")
    # Initialize search engine
    search_engine = AIActSearchEngine()

    def run_search(query: str):
        # Shared handler for both the button click and Enter-key submit.
        return search_articles(query, search_engine)

    # Create Gradio interface
    with gr.Blocks(title="AI Act Search Tool", theme=gr.themes.Default()) as app:
        # "🔍" restores the magnifier emoji that was mojibake'd to "π".
        gr.Markdown("# 🔍 AI Act Textual Search Tool")
        gr.Markdown("Compare results from different textual search methods on AI Act articles")
        with gr.Row():
            with gr.Column(scale=3):
                query_input = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter your keywords...",
                    lines=2
                )
            with gr.Column(scale=1):
                search_button = gr.Button("🔍 Search", variant="primary")
        gr.Markdown("---")
        gr.Markdown("### Search Results")
        gr.Markdown("Articles found by all 3 methods are highlighted in **light green**")
        with gr.Row():
            tfidf_output = gr.HTML(label="TF-IDF")
            bm25_output = gr.HTML(label="BM25")
            embeddings_output = gr.HTML(label="Embeddings")
        # Button click and Enter-key submit trigger the same search.
        search_button.click(
            fn=run_search,
            inputs=query_input,
            outputs=[tfidf_output, bm25_output, embeddings_output]
        )
        query_input.submit(
            fn=run_search,
            inputs=query_input,
            outputs=[tfidf_output, bm25_output, embeddings_output]
        )

    # Launch application
    print("Launching Gradio application...")
    app.launch(debug=True, share=False)


if __name__ == "__main__":
    main()