# discover_rag/app.py — RAG pedagogical demo (initial attempt, commit 8a18ce0)
import gradio as gr
import spaces
from rag_system import RAGSystem
from i18n import get_text
# Module-level RAG system singleton, shared by all event handlers below
rag = RAGSystem()

# Current UI language code ("en" or "fr"); mutated by switch_language()
language = "en"
def switch_language(lang):
    """Switch the module-level UI language and recompute interface texts.

    Args:
        lang: Language code to activate ("en" or "fr").

    Returns:
        The result of update_interface() for the newly selected language.
    """
    global language
    language = lang
    return update_interface()
def update_interface():
    """Produce component updates reflecting the current global language.

    NOTE(review): currently a stub — it builds a translator but returns an
    empty mapping, so switching languages does not relabel any widget yet.
    """
    def t(key):
        return get_text(key, language)

    # TODO: map each interface element to a localized update, e.g.
    # {some_component: gr.update(label=t('...')), ...}
    return {}
@spaces.GPU
def process_pdf(pdf_file, chunk_size, chunk_overlap):
    """Process uploaded PDF and create embeddings.

    Args:
        pdf_file: Uploaded Gradio file object, or None to use the bundled
            default corpus instead.
        chunk_size: Target chunk length passed to the RAG system.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        A human-readable status string; on failure, a localized error message.
    """
    def t(key):
        return get_text(key, language)

    try:
        if pdf_file is None:
            # No upload: fall back to the default corpus
            return rag.load_default_corpus(chunk_size, chunk_overlap)
        return rag.process_document(pdf_file.name, chunk_size, chunk_overlap)
    except Exception as e:
        # UI boundary: surface any failure as localized status text
        return f"{t('error')}: {str(e)}"
@spaces.GPU
def perform_query(
    query,
    embedding_model,
    top_k,
    similarity_threshold,
    llm_model,
    temperature,
    max_tokens
):
    """Run a full RAG round-trip: retrieve chunks, then generate an answer.

    Args:
        query: The user's question.
        embedding_model: Name of the sentence-embedding model to use.
        top_k: Number of chunks to retrieve.
        similarity_threshold: Minimum similarity score for retrieved chunks.
        llm_model: Name of the generation model.
        temperature: Sampling temperature for generation.
        max_tokens: Generation length cap.

    Returns:
        A 4-tuple (answer, chunks_markdown, prompt, error_text); exactly one
        of answer/error_text is populated depending on success.
    """
    def t(key):
        return get_text(key, language)

    # Guard: a corpus must have been processed first
    if not rag.is_ready():
        return t("no_corpus"), "", "", ""

    try:
        # Configure models, retrieve, render, then generate
        rag.set_embedding_model(embedding_model)
        rag.set_llm_model(llm_model)
        results = rag.retrieve(query, top_k, similarity_threshold)
        chunks_display = format_chunks(results, t)
        answer, prompt = rag.generate(query, results, temperature, max_tokens)
        return answer, chunks_display, prompt, ""
    except Exception as e:
        # UI boundary: report the failure in the dedicated error output
        return "", "", "", f"{t('error')}: {str(e)}"
def format_chunks(results, t):
    """Render retrieved chunks as a Markdown section with similarity scores.

    Args:
        results: Iterable of (chunk_text, similarity_score) pairs.
        t: Translator callable mapping an i18n key to localized text.

    Returns:
        A Markdown string: a "retrieved chunks" header followed by one
        fenced code block per chunk with its score. With no results, only
        the header is returned (unchanged behavior).
    """
    # Accumulate parts and join once: avoids quadratic str += growth
    parts = [f"### {t('retrieved_chunks')}\n\n"]
    for i, (chunk, score) in enumerate(results, 1):
        parts.append(f"**Chunk {i}** - {t('similarity_score')}: {score:.4f}\n")
        parts.append(f"```\n{chunk}\n```\n\n")
    return "".join(parts)
def create_interface():
    """Build and return the full Gradio Blocks UI (four tabs plus footer).

    Labels are localized via the module-level `language` at build time;
    components created in the Retrieval/Generation tabs feed the query
    handler wired in the Query tab.
    """
    # Translator bound to the module-level language current at build time
    t = lambda key: get_text(key, language)
    with gr.Blocks(title="RAG Pedagogical Demo", theme=gr.themes.Soft()) as demo:
        # Header row: title plus language selector
        with gr.Row():
            gr.Markdown("# 🎓 RAG Pedagogical Demo / Démo Pédagogique RAG")
            # NOTE(review): lang_radio has no event handler attached here, so
            # changing it does not call switch_language() — confirm intended.
            lang_radio = gr.Radio(
                choices=["en", "fr"],
                value="en",
                label="Language / Langue"
            )
        with gr.Tabs() as tabs:
            # Tab 1: Corpus Management — PDF upload and chunking parameters
            with gr.Tab(label="📚 Corpus"):
                gr.Markdown(f"## {t('corpus_management')}")
                gr.Markdown(t('corpus_description'))
                pdf_upload = gr.File(
                    label=t('upload_pdf'),
                    file_types=[".pdf"]
                )
                with gr.Row():
                    chunk_size = gr.Slider(
                        minimum=100,
                        maximum=1000,
                        value=500,
                        step=50,
                        label=t('chunk_size')
                    )
                    chunk_overlap = gr.Slider(
                        minimum=0,
                        maximum=200,
                        value=50,
                        step=10,
                        label=t('chunk_overlap')
                    )
                process_btn = gr.Button(t('process_corpus'), variant="primary")
                corpus_status = gr.Textbox(label=t('status'), interactive=False)
                # Wire the processing handler to the button
                process_btn.click(
                    fn=process_pdf,
                    inputs=[pdf_upload, chunk_size, chunk_overlap],
                    outputs=corpus_status
                )
            # Tab 2: Retrieval Configuration — embedding model and search knobs
            with gr.Tab(label="🔍 Retrieval"):
                gr.Markdown(f"## {t('retrieval_config')}")
                embedding_model = gr.Dropdown(
                    choices=[
                        "sentence-transformers/all-MiniLM-L6-v2",
                        "sentence-transformers/all-mpnet-base-v2",
                        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
                    ],
                    value="sentence-transformers/all-MiniLM-L6-v2",
                    label=t('embedding_model')
                )
                with gr.Row():
                    top_k = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=1,
                        label=t('top_k')
                    )
                    similarity_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.0,
                        step=0.05,
                        label=t('similarity_threshold')
                    )
            # Tab 3: Generation Configuration — LLM choice and sampling knobs
            with gr.Tab(label="🤖 Generation"):
                gr.Markdown(f"## {t('generation_config')}")
                llm_model = gr.Dropdown(
                    choices=[
                        "HuggingFaceH4/zephyr-7b-beta",
                        "mistralai/Mistral-7B-Instruct-v0.2",
                        "meta-llama/Llama-2-7b-chat-hf",
                    ],
                    value="HuggingFaceH4/zephyr-7b-beta",
                    label=t('llm_model')
                )
                with gr.Row():
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.7,
                        step=0.1,
                        label=t('temperature')
                    )
                    max_tokens = gr.Slider(
                        minimum=50,
                        maximum=1000,
                        value=300,
                        step=50,
                        label=t('max_tokens')
                    )
            # Tab 4: Query & Results — question input, answer, and diagnostics
            with gr.Tab(label="💬 Query"):
                gr.Markdown(f"## {t('ask_question')}")
                query_input = gr.Textbox(
                    label=t('your_question'),
                    placeholder=t('question_placeholder'),
                    lines=3
                )
                examples = gr.Examples(
                    examples=[
                        ["What is Retrieval Augmented Generation?"],
                        ["How does RAG improve language models?"],
                        ["What are the main components of a RAG system?"],
                    ],
                    inputs=query_input,
                    label=t('example_questions')
                )
                query_btn = gr.Button(t('submit_query'), variant="primary")
                gr.Markdown(f"### {t('answer')}")
                answer_output = gr.Markdown()
                # Collapsible diagnostics: retrieved chunks and the raw prompt
                with gr.Accordion(t('retrieved_chunks'), open=True):
                    chunks_output = gr.Markdown()
                with gr.Accordion(t('prompt_sent'), open=False):
                    prompt_output = gr.Code(language="text")
                error_output = gr.Textbox(label=t('errors'), visible=False)
                # Wire the query handler; inputs come from the other tabs
                query_btn.click(
                    fn=perform_query,
                    inputs=[
                        query_input,
                        embedding_model,
                        top_k,
                        similarity_threshold,
                        llm_model,
                        temperature,
                        max_tokens
                    ],
                    outputs=[answer_output, chunks_output, prompt_output, error_output]
                )
        # Footer (bilingual note)
        gr.Markdown("""
        ---
        **Note**: This is a pedagogical demonstration of RAG systems.
        Models run on HuggingFace ZeroGPU infrastructure.
        **Note** : Ceci est une démonstration pédagogique des systèmes RAG.
        Les modèles tournent sur l'infrastructure HuggingFace ZeroGPU.
        """)
    return demo
# Entry point: build the UI and start the Gradio server
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()