# Uploaded by DeathBlade020 - commit 5a18a8e (verified): "Update app.py"
import gradio as gr
import os
import hashlib
import json
import pickle
from datetime import datetime, timedelta
from pathlib import Path
from dotenv import load_dotenv
# Import your original RAG technique modules
from Hyde import get_answer_using_hyde
from QueryDecomposition import get_answer_using_query_decomposition
from QueryExpansion import get_answer_using_query_expansion
from RagFusion import get_answer_using_rag_fusion
from StepBackQuery import get_answer
# Import new advanced retrieval techniques
from AdvancedRag import (
get_answer_using_multi_query,
get_answer_using_parent_child,
get_answer_using_contextual_compression,
get_answer_using_cross_encoder,
get_answer_using_semantic_routing
)
# Load settings from a .env file into the environment before anything reads
# them (the launch code at the bottom of this file checks OPENAI_API_KEY).
load_dotenv()
# Cache configuration
# Directory that holds one "<cache_key>.pkl" file per cached entry. The path
# is relative, so it resolves against the current working directory.
CACHE_DIR = Path("rag_cache")
# Executed at import time; exist_ok=True makes repeated start-ups harmless.
CACHE_DIR.mkdir(exist_ok=True)
# Age limit compared against each cache file's modification time.
CACHE_EXPIRY_HOURS = 24 # Cache expires after 24 hours
# Extended dictionary mapping technique names to their corresponding functions
# The keys double as the labels offered in the UI dropdown (in this order) and
# as the lookup keys in process_rag_query; each value is invoked there as
# fn(link, question).
RAG_TECHNIQUES = {
    # Original Techniques
    "HyDE (Hypothetical Document Embeddings)": get_answer_using_hyde,
    "Query Decomposition": get_answer_using_query_decomposition,
    "Query Expansion": get_answer_using_query_expansion,
    "RAG Fusion": get_answer_using_rag_fusion,
    # get_answer is the entry point imported from the StepBackQuery module
    "Step Back Query": get_answer,
    # Advanced Retrieval Techniques
    "Multi-Query Retrieval": get_answer_using_multi_query,
    "Parent-Child Retrieval": get_answer_using_parent_child,
    "Contextual Compression": get_answer_using_contextual_compression,
    "Cross-Encoder Reranking": get_answer_using_cross_encoder,
    "Semantic Routing": get_answer_using_semantic_routing,
}
def generate_cache_key(link, technique):
    """
    Build the cache key for a (link, technique) pair.

    The two values are joined with an underscore, encoded as UTF-8 and
    hashed with MD5; the 32-character hex digest is returned. MD5 is used
    purely as a compact file-name fingerprint here.
    """
    hasher = hashlib.md5()
    hasher.update("{}_{}".format(link, technique).encode())
    return hasher.hexdigest()
def get_cache_file_path(cache_key):
    """
    Return the location of the pickle file that backs ``cache_key``.

    The file sits directly inside ``CACHE_DIR`` and is named
    ``<cache_key>.pkl``.
    """
    file_name = "{}.pkl".format(cache_key)
    return CACHE_DIR.joinpath(file_name)
def is_cache_valid(cache_file_path):
    """
    Report whether a cache file is present and still fresh.

    Args:
        cache_file_path: ``pathlib.Path`` of the cache file to inspect.

    Returns:
        True when the file can be stat'ed and its modification time is less
        than ``CACHE_EXPIRY_HOURS`` hours old; False when it is missing,
        cannot be stat'ed, or has expired.
    """
    # Stat once instead of exists()-then-stat(): another request may delete
    # the file between the two calls, which would raise FileNotFoundError.
    try:
        modified_at = datetime.fromtimestamp(cache_file_path.stat().st_mtime)
    except OSError:
        return False
    oldest_allowed = datetime.now() - timedelta(hours=CACHE_EXPIRY_HOURS)
    return modified_at > oldest_allowed
def save_to_cache(cache_key, data):
    """
    Persist ``data`` in the cache file that belongs to ``cache_key``.

    The payload is pickled as a dict with the keys ``'data'``,
    ``'timestamp'`` (ISO-8601 string of the current local time) and
    ``'cache_key'``.

    Args:
        cache_key: Key identifying the cache entry (see ``generate_cache_key``).
        data: Any picklable object to store.

    Returns:
        True on success, False on any failure. Errors are printed, never
        raised, so a broken cache cannot take down a query.
    """
    import tempfile  # local import: only this function needs it

    tmp_name = None
    try:
        cache_file_path = get_cache_file_path(cache_key)
        cache_data = {
            'data': data,
            'timestamp': datetime.now().isoformat(),
            'cache_key': cache_key,
        }
        # Write to a temp file in the same directory and rename it into
        # place: os.replace is atomic, so a concurrent reader (or a crash
        # mid-write) never sees a truncated pickle. The ".tmp" suffix keeps
        # the file out of the "*.pkl" expiry sweep.
        with tempfile.NamedTemporaryFile(
            mode='wb',
            dir=cache_file_path.parent,
            suffix='.tmp',
            delete=False,
        ) as tmp:
            tmp_name = tmp.name
            pickle.dump(cache_data, tmp)
        os.replace(tmp_name, cache_file_path)
        tmp_name = None  # renamed into place; nothing left to clean up
        print(f"✅ Cached result for key: {cache_key}")
        return True
    except Exception as e:
        print(f"❌ Failed to save cache: {e}")
        return False
    finally:
        if tmp_name is not None:
            # Best-effort removal of a half-written temp file.
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
def load_from_cache(cache_key):
    """
    Fetch the payload stored for ``cache_key``.

    Returns the value saved under the ``'data'`` key of the pickled entry,
    or None when the cache file is missing, expired, or cannot be read
    (read errors are printed, not raised).
    """
    try:
        path = get_cache_file_path(cache_key)
        if is_cache_valid(path):
            # NOTE: pickle.load executes whatever the file contains; this is
            # only safe while the cache directory is written by this app alone.
            with open(path, 'rb') as handle:
                entry = pickle.load(handle)
            print(f"🎯 Cache hit for key: {cache_key}")
            return entry['data']
        return None
    except Exception as err:
        print(f"❌ Failed to load cache: {err}")
        return None
def clear_expired_cache():
    """
    Delete every expired ``*.pkl`` file in ``CACHE_DIR``.

    Each file is handled independently: a file that vanished in the meantime
    is skipped silently and any other per-file error is printed, so one bad
    file no longer aborts the rest of the sweep. Nothing is raised and
    nothing is returned; the number of removed files is printed when it is
    greater than zero.
    """
    expired_count = 0
    try:
        # Snapshot the listing before deleting anything from the directory.
        cache_files = list(CACHE_DIR.glob("*.pkl"))
        for cache_file in cache_files:
            try:
                if is_cache_valid(cache_file):
                    continue
                cache_file.unlink()
            except FileNotFoundError:
                # Already removed, e.g. by a sweep running for another request.
                continue
            except OSError as e:
                print(f"❌ Failed to auto-clear expired cache: {e}")
                continue
            expired_count += 1
    except Exception as e:
        print(f"❌ Failed to auto-clear expired cache: {e}")
    if expired_count > 0:
        print(f"🧹 Auto-cleared {expired_count} expired cache files")
def process_rag_query(link, question, technique):
    """
    Answer ``question`` about the page at ``link`` with the chosen technique.

    Results are cached per (link, technique) pair as a dict mapping each
    question to its answer, so repeating a question is served from disk.

    Args:
        link: Page URL; must start with ``http://`` or ``https://``.
            Surrounding whitespace is ignored.
        question: The user's question. Surrounding whitespace is ignored and
            a blank question is rejected.
        technique: A key of ``RAG_TECHNIQUES``.

    Returns:
        The answer produced by the technique function (or its cached copy),
        or a human-readable message when the input is invalid or processing
        fails. Exceptions are never propagated to the caller.
    """
    try:
        # Normalise first: a pasted URL often carries stray whitespace, which
        # would otherwise fail the scheme check and fragment the cache.
        link = (link or "").strip()
        question = (question or "").strip()
        if not link or not question:
            return "Please provide both a link and a question."
        if not link.startswith(('http://', 'https://')):
            return "Please provide a valid URL starting with http:// or https://"

        # Resolve the technique before touching the cache so an invalid
        # selection costs no disk I/O.
        rag_function = RAG_TECHNIQUES.get(technique)
        if not rag_function:
            return "Invalid technique selected."

        # Auto-clear expired cache files
        clear_expired_cache()

        cache_key = generate_cache_key(link, technique)
        cached_result = load_from_cache(cache_key)
        # A missing or malformed entry is treated as an empty cache.
        if not isinstance(cached_result, dict):
            cached_result = {}
        if question in cached_result:
            return cached_result[question]

        print(f"🔄 Processing new query: {technique} for {link}")
        answer = rag_function(link, question)

        # Store the new answer next to those already cached for this key.
        cached_result[question] = answer
        save_to_cache(cache_key, cached_result)
        return answer
    except Exception as e:
        return f"Error processing query: {str(e)}\n\nNote: Advanced techniques require additional dependencies. Make sure you have installed: sentence-transformers, scikit-learn"
def create_webpage_preview(link):
    """
    Create an HTML iframe to preview the webpage.

    Args:
        link: URL typed by the user. Surrounding whitespace is ignored.

    Returns:
        An empty string for empty input, a red error paragraph when the URL
        does not start with ``http://`` or ``https://``, otherwise an HTML
        snippet embedding the page in an iframe with a fallback link.
    """
    import html  # local import: only this function needs it

    link = (link or "").strip()
    if not link:
        return ""
    if not link.startswith(('http://', 'https://')):
        return "<p style='color: red;'>Please provide a valid URL starting with http:// or https://</p>"

    # The URL is user input placed inside HTML attributes: escape it so that
    # quotes or angle brackets cannot break out of the attribute and inject
    # markup.
    safe_link = html.escape(link, quote=True)

    # Create an iframe to display the webpage
    return f"""
    <div style="width: 100%; height: 500px; border: 1px solid #ccc; border-radius: 5px;">
    <iframe src="{safe_link}" width="100%" height="100%" frameborder="0"
    style="border-radius: 5px;">
    <p>Your browser does not support iframes.
    <a href="{safe_link}" target="_blank">Click here to open the link</a></p>
    </iframe>
    </div>
    """
# Create the Gradio interface
def create_interface():
    """
    Build the Gradio Blocks UI and wire up its event handlers.

    Layout: a title, then one row with two columns (both ``scale=1``). The
    first column holds the URL box, the question box, the technique
    dropdown, the submit button and the read-only answer box; the second
    column holds an HTML preview of the page.

    Wiring:
        * changing the URL box calls ``create_webpage_preview`` and renders
          the returned HTML in the preview pane;
        * clicking the submit button calls ``process_rag_query`` with the
          URL, the question and the selected technique, and shows the
          returned text in the answer box.

    Returns:
        The ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(title="Advanced RAG Techniques", theme=gr.themes.Soft()) as demo:  # type: ignore
        # Single-line literal so source indentation can never leak into the
        # rendered Markdown heading.
        gr.Markdown("# 🚀 Advanced RAG Techniques Comparison Tool")
        # Disabled remainder of the intro text (kept for reference):
        # This tool now includes **5 advanced retrieval techniques** alongside the original methods:
        # **🔥 New Advanced Techniques:**
        # - **Multi-Query Retrieval** - Generate diverse queries for comprehensive results
        # - **Parent-Child Retrieval** - Search with small chunks, return large context
        # - **Contextual Compression** - AI-powered relevance filtering
        # - **Cross-Encoder Reranking** - Superior relevance scoring
        # - **Semantic Routing** - Smart query classification and routing
        # **Instructions:**
        # 1. Enter a valid URL in the link box
        # 2. Preview the webpage content
        # 3. Enter your question about the content
        # 4. Select a RAG technique from the dropdown (try the new advanced ones!)
        # 5. Click Submit to get your answer
        # """)

        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                gr.Markdown("## 📝 Input Section")
                link_input = gr.Textbox(
                    label="Website URL",
                    placeholder="https://example.com/article",
                    info="Enter the URL of the webpage you want to analyze",
                )
                question_input = gr.Textbox(
                    label="Your Question",
                    placeholder="What is the main topic discussed in this article?",
                    info="Ask any question about the content of the webpage",
                )
                technique_dropdown = gr.Dropdown(
                    choices=list(RAG_TECHNIQUES.keys()),
                    label="RAG Technique",
                    # Must stay equal to one of the RAG_TECHNIQUES keys.
                    value="Multi-Query Retrieval",
                    info="Choose the RAG technique - try the new advanced techniques!",
                )
                submit_btn = gr.Button("🚀 Submit Query", variant="primary", size="lg")

                # Output section
                gr.Markdown("## 💡 Answer")
                answer_output = gr.Textbox(
                    label="Generated Answer",
                    lines=10,
                    interactive=False,
                    placeholder="Your answer will appear here...",
                )

            with gr.Column(scale=1):
                # Webpage preview section
                gr.Markdown("## 🌐 Webpage Preview")
                webpage_preview = gr.HTML(
                    label="Webpage Content",
                    value="<p style='text-align: center; color: #666; padding: 50px;'>Enter a URL to preview the webpage</p>",
                )

        # Event handlers
        link_input.change(
            fn=create_webpage_preview,
            inputs=[link_input],
            outputs=[webpage_preview],
        )
        submit_btn.click(
            fn=process_rag_query,
            inputs=[link_input, question_input, technique_dropdown],
            outputs=[answer_output],
        )

        # Disabled help sections (kept for reference):
        # Add some example links and questions
        # gr.Markdown("""
        # ## 📚 Example Usage & Technique Comparison
        # **Sample URLs to try:**
        # - `https://lilianweng.github.io/posts/2023-06-23-agent/` (AI Agents blog post)
        # - `https://docs.python.org/3/tutorial/` (Python Tutorial)
        # - `https://en.wikipedia.org/wiki/Machine_learning` (Machine Learning Wikipedia)
        # **Sample Questions:**
        # - "What is task decomposition for LLM agents?"
        # - "What are the main components of an AI agent?"
        # - "How does retrieval-augmented generation work?"
        # **💡 Pro Tip:** Try the same question with different techniques to see how results vary!
        # """)
        # # Add advanced technique descriptions
        # with gr.Accordion("🔧 Advanced RAG Techniques Explained", open=False):
        # gr.Markdown("""
        # ## Original Techniques:
        # **HyDE:** Generates a hypothetical answer first, then uses it to retrieve relevant documents.
        # **Query Decomposition:** Breaks down complex questions into simpler sub-questions that are answered sequentially.
        # **Query Expansion:** Generates multiple variations of the original query to improve retrieval coverage.
        # **RAG Fusion:** Creates multiple related queries and uses reciprocal rank fusion to combine results.
        # **Step Back Query:** Transforms specific questions into more general ones to retrieve broader context.
        # ## 🚀 Advanced Techniques:
        # **Multi-Query Retrieval:** Generates 4+ diverse query perspectives and merges results for comprehensive coverage.
        # **Parent-Child Retrieval:** Uses small chunks for precise matching but returns larger parent chunks for better context.
        # **Contextual Compression:** Uses LLM to compress retrieved chunks, keeping only information relevant to your question.
        # **Cross-Encoder Reranking:** Uses specialized neural models to score and rerank documents for superior relevance.
        # **Semantic Routing:** Automatically classifies your query type (factual, conceptual, comparative, analytical) and routes to the best retrieval strategy.
        # """)
        # # Installation requirements
        # with gr.Accordion("📦 Additional Dependencies for Advanced Techniques", open=False):
        # gr.Markdown("""
        # To use the advanced retrieval techniques, install these additional packages:
        # ```bash
        # pip install sentence-transformers scikit-learn
        # ```
        # If you encounter errors with advanced techniques, make sure these dependencies are installed.
        # """)
    return demo
# Launch the application
if __name__ == "__main__":
    # A missing OpenAI key only produces a warning; start-up continues.
    if not os.environ.get("OPENAI_API_KEY"):
        print(
            "Warning: OPENAI_API_KEY not found in environment variables.",
            "Please make sure to set your OpenAI API key in your .env file.",
            sep="\n",
        )
    # Build the UI, then serve it. share=True asks Gradio for a public link.
    demo = create_interface()
    demo.launch(share=True)