# Search-AI-Act / app.py
# Uploaded by hoololi (commit 9aa44ca, verified)
import html
import re
import threading
import time
from typing import List, Tuple, Dict

import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Configuration
DATASET_NAME = "hoololi/AI_Act_with_embeddings"  # HF dataset with articles + precomputed embeddings
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"  # sentence-transformers model used to encode queries
TOP_K = 5  # maximum number of (deduplicated) results returned per search method
class AIActSearchEngine:
    """Loads the AI Act articles dataset and answers queries with three
    retrieval methods: TF-IDF, BM25, and sentence-embedding similarity.

    All three methods share the same ranking step (`_rank_and_dedup`):
    sort by score, drop duplicate article numbers, return the top K.
    """

    def __init__(self):
        self.dataset = None            # pandas DataFrame of AI Act articles
        self.embedding_model = None    # SentenceTransformer for encoding queries
        self.tfidf_vectorizer = None
        self.tfidf_matrix = None       # sparse TF-IDF matrix over processed_content
        self.bm25_model = None
        self.processed_docs = None     # tokenized documents fed to BM25
        self.embedding_matrix = None   # cached dense matrix of stored article embeddings
        self.load_data()
        self.setup_models()

    def load_data(self):
        """Load the dataset from the Hugging Face Hub into a DataFrame."""
        print("Loading dataset...")
        dataset = load_dataset(DATASET_NAME, split="train")
        self.dataset = dataset.to_pandas()
        print(f"Dataset loaded: {len(self.dataset)} articles")

    def setup_models(self):
        """Initialize the embedding model, TF-IDF vectorizer and BM25 index."""
        print("Initializing models...")
        # Embedding model (only encodes incoming queries; article embeddings
        # are precomputed and stored in the dataset's 'embedding' column).
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        # TF-IDF over the preprocessed article text.
        self.tfidf_vectorizer = TfidfVectorizer(
            max_features=10000,
            stop_words='english',
            lowercase=True,
            ngram_range=(1, 2)
        )
        self.tfidf_matrix = self.tfidf_vectorizer.fit_transform(self.dataset['processed_content'])
        # Cache the stored embeddings as one dense matrix so each query does
        # not rebuild it from the per-row lists (was done per-search before).
        self.embedding_matrix = np.array(self.dataset['embedding'].tolist())
        # BM25 (using optimized library)
        self.setup_bm25()
        print("Models initialized successfully!")

    def setup_bm25(self):
        """Build the BM25 index from whitespace-tokenized documents."""
        print("Setting up BM25...")
        # Tokenize documents for BM25
        self.processed_docs = [doc.split() for doc in self.dataset['processed_content']]
        # Create BM25 model (much faster than manual implementation)
        self.bm25_model = BM25Okapi(self.processed_docs)
        print("BM25 setup complete!")

    def _rank_and_dedup(self, scores, require_positive=True, top_k=None):
        """Shared ranking step for all three search methods.

        Sorts documents by descending score (ties broken by ascending row
        index, matching the original stable sort), skips duplicate article
        numbers, and returns at most ``top_k`` (default: TOP_K) tuples of
        (article_number, article_content, score).

        When ``require_positive`` is True only strictly positive scores are
        kept (original TF-IDF/BM25 behavior); embeddings pass False.
        """
        limit = TOP_K if top_k is None else top_k
        scores = np.asarray(scores, dtype=float)
        # Stable argsort on negated scores: descending score, ascending
        # row index on ties.
        order = np.argsort(-scores, kind="stable")
        seen_articles = set()
        results = []
        for idx in order:
            score = scores[idx]
            if require_positive and score <= 0:
                break  # sorted descending, so every remaining score is <= 0 too
            row = self.dataset.iloc[int(idx)]
            article_num = row['article_number']
            if article_num in seen_articles:
                continue
            seen_articles.add(article_num)
            results.append((article_num, row['article_content'], float(score)))
            if len(results) == limit:
                break  # original kept scanning the whole list; stopping early is equivalent
        return results

    def search_tfidf(self, query: str) -> List[Tuple[str, str, float]]:
        """TF-IDF cosine-similarity search."""
        query_vector = self.tfidf_vectorizer.transform([query])
        similarities = cosine_similarity(query_vector, self.tfidf_matrix).flatten()
        return self._rank_and_dedup(similarities)

    def search_bm25(self, query: str) -> List[Tuple[str, str, float]]:
        """BM25 search using the optimized rank_bm25 library."""
        # Tokenize query the same way the documents were tokenized (whitespace).
        query_tokens = query.lower().split()
        return self._rank_and_dedup(self.bm25_model.get_scores(query_tokens))

    def search_embeddings(self, query: str) -> List[Tuple[str, str, float]]:
        """Dense-embedding cosine-similarity search."""
        query_embedding = self.embedding_model.encode([query])
        similarities = cosine_similarity(query_embedding, self.embedding_matrix).flatten()
        # Keep non-positive similarities too, as the original implementation did.
        return self._rank_and_dedup(similarities, require_positive=False)

    def search_all(self, query: str) -> Dict[str, List[Tuple[str, str, float]]]:
        """Run all three search methods; a blank query yields empty results."""
        if not query.strip():
            return {'tfidf': [], 'bm25': [], 'embeddings': []}
        return {
            'tfidf': self.search_tfidf(query),
            'bm25': self.search_bm25(query),
            'embeddings': self.search_embeddings(query)
        }
def format_results_table(results: List[Tuple[str, str, float]],
                         search_type: str,
                         highlight_articles: set) -> str:
    """Render one search method's results as an HTML table fragment.

    Args:
        results: (article_number, article_content, score) tuples, already
            ranked and deduplicated.
        search_type: heading shown above the table (e.g. "BM25").
        highlight_articles: article numbers found by all methods; those rows
            get a light-green Article cell.

    Returns:
        An HTML string (a placeholder message when ``results`` is empty).
    """
    if not results:
        return f"""
        <div style="text-align: center; padding: 20px;">
            <h3>{search_type}</h3>
            <p>No results found</p>
        </div>
        """
    # Accumulate fragments and join once instead of quadratic string +=.
    html_parts = [f"""
    <div style="margin: 10px;">
        <h3 style="text-align: center; margin-bottom: 15px;">{search_type}</h3>
        <table style="width: 100%; border-collapse: collapse; font-size: 12px;">
            <thead>
                <tr style="background-color: #f0f0f0;">
                    <th style="border: 1px solid #ddd; padding: 8px; width: 80px;">Score</th>
                    <th style="border: 1px solid #ddd; padding: 8px; width: 80px;">Article</th>
                    <th style="border: 1px solid #ddd; padding: 8px;">Content</th>
                </tr>
            </thead>
            <tbody>
    """]
    for article_num, content, score in results:
        # Highlight if article is in all 3 searches
        bg_color = "background-color: #90EE90;" if article_num in highlight_articles else ""
        # Escape article text so '<', '>' or '&' in the source data cannot
        # break (or inject markup into) the rendered table.
        safe_article = html.escape(str(article_num))
        safe_content = html.escape(content)
        html_parts.append(f"""
            <tr>
                <td style="border: 1px solid #ddd; padding: 8px; text-align: center;">{score:.3f}</td>
                <td style="border: 1px solid #ddd; padding: 8px; text-align: center; {bg_color}">{safe_article}</td>
                <td style="border: 1px solid #ddd; padding: 8px; max-width: 400px; overflow: hidden;">
                    <details>
                        <summary style="cursor: pointer; font-weight: bold;">View content</summary>
                        <div style="margin-top: 10px; white-space: pre-wrap;">{safe_content}</div>
                    </details>
                </td>
            </tr>
        """)
    html_parts.append("""
            </tbody>
        </table>
    </div>
    """)
    return "".join(html_parts)
def search_articles_progressive(query: str, search_engine: AIActSearchEngine):
    """Run the three searches sequentially and return one HTML table per method.

    Returns a (tfidf_html, bm25_html, embeddings_html) tuple for the three
    Gradio HTML outputs. Articles found by all three methods are highlighted
    in the rendered tables.

    NOTE(review): the original built a CSS-spinner "loading" table that was
    never returned or yielded — this function is not a generator, so no
    progressive update ever reached the UI. That dead code is removed here;
    turning this into a true progressive generator would change the caller
    contract and is left as a follow-up.
    """
    if not query.strip():
        empty_table = """
        <div style="text-align: center; padding: 20px;">
            <p>Enter a query to start searching</p>
        </div>
        """
        return empty_table, empty_table, empty_table

    results = {}

    # 1. TF-IDF Search (fastest)
    print("Starting TF-IDF search...")
    start_time = time.time()
    results['tfidf'] = search_engine.search_tfidf(query)
    print(f"TF-IDF completed in {time.time() - start_time:.2f}s")

    # 2. BM25 Search (fast with optimized library)
    print("Starting BM25 search...")
    start_time = time.time()
    results['bm25'] = search_engine.search_bm25(query)
    print(f"BM25 completed in {time.time() - start_time:.2f}s")

    # 3. Embeddings Search (slowest: encodes the query)
    print("Starting embeddings search...")
    start_time = time.time()
    results['embeddings'] = search_engine.search_embeddings(query)
    print(f"Embeddings completed in {time.time() - start_time:.2f}s")

    # Articles present in all 3 methods get the green highlight.
    highlight_articles = (
        {r[0] for r in results['tfidf']}
        & {r[0] for r in results['bm25']}
        & {r[0] for r in results['embeddings']}
    )

    # Format final results
    tfidf_html = format_results_table(results['tfidf'], "TF-IDF", highlight_articles)
    bm25_html = format_results_table(results['bm25'], "BM25", highlight_articles)
    embeddings_html = format_results_table(results['embeddings'], "Embeddings", highlight_articles)
    return tfidf_html, bm25_html, embeddings_html
def search_articles(query: str, search_engine: AIActSearchEngine):
    """Compatibility entry point: delegates to the progressive search.

    Kept so existing callers (the Gradio click/submit handlers) keep
    working; returns the same (tfidf, bm25, embeddings) HTML tuple.
    """
    return search_articles_progressive(query, search_engine)
def main():
    """Build the Gradio comparison UI around one shared search engine and launch it."""
    print("Initializing application...")
    # One engine instance is shared by every request.
    engine = AIActSearchEngine()

    with gr.Blocks(title="AI Act Search Tool", theme=gr.themes.Default()) as demo:
        gr.Markdown("# πŸ” AI Act Textual Search Tool")
        gr.Markdown("Compare results from different textual search methods on AI Act articles")

        with gr.Row():
            with gr.Column(scale=3):
                query_box = gr.Textbox(
                    label="Search Query",
                    placeholder="Enter your keywords...",
                    lines=2
                )
            with gr.Column(scale=1):
                run_btn = gr.Button("πŸ” Search", variant="primary")

        gr.Markdown("---")
        gr.Markdown("### Search Results")
        gr.Markdown("Articles found by all 3 methods are highlighted in **light green**")

        with gr.Row():
            panel_tfidf = gr.HTML(label="TF-IDF")
            panel_bm25 = gr.HTML(label="BM25")
            panel_embed = gr.HTML(label="Embeddings")

        result_panels = [panel_tfidf, panel_bm25, panel_embed]

        def run_search(text):
            # Single handler shared by the button click and the Enter key.
            return search_articles(text, engine)

        run_btn.click(fn=run_search, inputs=query_box, outputs=result_panels)
        query_box.submit(fn=run_search, inputs=query_box, outputs=result_panels)

    print("Launching Gradio application...")
    demo.launch(debug=True, share=False)
# Launch the app only when executed as a script (not when imported).
if __name__ == "__main__":
    main()