myraggradio / app.py
Soha85's picture
Upload 3 files
3f86984 verified
"""
Gradio UI – two pages: Index & Retrieve.
"""
import gradio as gr
from rag_engine import RagEngine
# One module-level engine instance; both tab callbacks below read and update it.
engine = RagEngine()
# -------------------- INDEX TAB --------------------
def index_fn(url: str, chunk_size: float, overlap: float) -> str:
    """Scrape one article, split it into chunks and build the search index.

    Args:
        url: Address of the article to scrape.
        chunk_size: Chunk length handed to ``engine.chunk_text`` (the UI
            labels it in words).
        overlap: Overlap between consecutive chunks, in the same unit.

    Returns:
        A one-line status message: the number of chunks indexed, or a
        description of what went wrong.
    """
    url = (url or "").strip()
    if not url:
        return "Please enter an article URL."

    # Slider values may arrive as floats; pass whole numbers to the chunker.
    chunk_size, overlap = int(chunk_size), int(overlap)
    # The sliders allow overlap == chunk_size (100/100). A window that overlaps
    # its whole length cannot advance, so reject that before doing any work.
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        return "Overlap must be smaller than the chunk size."

    try:
        text = engine.scrape_bbc_article(url)
    except Exception as e:
        # Boundary with the UI: report the failure in the status box.
        return f"Scrape error: {e}"

    chunks = engine.chunk_text(text, chunk_size, overlap)
    if len(chunks) == 0:
        # Nothing to index; leave any previously built index untouched.
        return "No text could be extracted from that article."
    engine.build_index(chunks)
    return f"Indexed {len(chunks)} chunks."
# -------------------- RETRIEVE TAB --------------------
def retrieve_fn(query, k, rerank, llm_choice, num_answers, temperature):
    """Retrieve chunks for a question and generate scored answers.

    Args:
        query: The user's question.
        k: Number of chunks to retrieve (forwarded to ``engine.retrieve``).
        rerank: Re-rank setting forwarded to ``engine.retrieve``.
        llm_choice: Name of the generator model, forwarded to
            ``engine.generate``.
        num_answers: How many answers to generate.
        temperature: Sampling temperature forwarded to ``engine.generate``.

    Returns:
        A 3-tuple of strings ``(retrieval_markdown, answers_markdown, "")``.
        When nothing can be answered, the first element carries a message
        for the user and the other two are empty.
    """
    if not query or not query.strip():
        return "Please enter a question.", "", ""
    if engine.faiss_idx is None:
        return "Please index articles first.", "", ""

    # Slider values may arrive as floats; range() below requires an int.
    k, rerank, num_answers = int(k), int(rerank), int(num_answers)

    retrieved = engine.retrieve(query, k, rerank)
    if not retrieved:
        # Without this guard retrieved[0] below raises IndexError.
        return "No matching chunks were found.", "", ""

    retrieval_md = "\n\n".join(
        f"**Chunk {i}** (cross-score={h['cross_score']:.3f})\n{h['chunk'][:300]}..."
        for i, h in enumerate(retrieved, start=1)
    )

    # Use the first retrieved chunk as a pseudo-reference for the metric demo.
    reference = retrieved[0]["chunk"]
    answer_blocks = []
    for i in range(1, num_answers + 1):
        ans = engine.generate(query, retrieved, llm_choice, temperature=temperature)
        s = engine.compute_metrics(reference, ans)
        answer_blocks.append(
            f"**Answer {i}** \nBLEU={s['bleu']:.2f} ROUGE-L={s['rouge-l']:.2f} \n{ans}"
        )
    answers_md = "\n\n".join(answer_blocks)
    return retrieval_md, answers_md, ""
# -------------------- UI --------------------
# Two-tab layout: "Index" builds the search index from one article URL,
# "Retrieve" answers a question against whatever has been indexed.
with gr.Blocks(title="BBC RAG Demo") as demo:
    gr.Markdown("## BBC Culture / Travel RAG Demo")

    with gr.Tab("Index"):
        url_in = gr.Textbox(
            label="BBC article URL",
            placeholder="https://www.bbc.com/culture/article/20231013-xxx",
        )
        with gr.Row():
            chunk_size = gr.Slider(100, 800, 300, step=50, label="Chunk size (words)")
            overlap = gr.Slider(0, 100, 20, step=10, label="Overlap (words)")
        index_btn = gr.Button("Index article")
        index_out = gr.Textbox(label="Status")
        index_btn.click(index_fn, inputs=[url_in, chunk_size, overlap], outputs=index_out)

    with gr.Tab("Retrieve"):
        query_in = gr.Textbox(label="Question")
        with gr.Row():
            k = gr.Slider(1, 30, 5, step=1, label="Retrieve k")
            rerank = gr.Slider(1, 20, 3, step=1, label="Re-rank top")
            num_ans = gr.Slider(1, 10, 2, step=1, label="Number of answers")
            temperature = gr.Slider(0.0, 1.0, 0.0, step=0.05, label="Temperature")
        llm_choice = gr.Dropdown(
            choices=["gpt2", "distilgpt2", "flan-t5-small", "flan-t5-base"],
            value="gpt2",
            label="LLM",
        )
        go_btn = gr.Button("Answer")
        retrieval_out = gr.Markdown(label="Retrieved chunks")
        answers_out = gr.Markdown(label="Generated answers")
        # retrieve_fn returns three values; this hidden component receives the
        # third so the number of outputs matches the handler's return arity.
        extra_out = gr.Markdown(visible=False)
        go_btn.click(
            retrieve_fn,
            inputs=[query_in, k, rerank, llm_choice, num_ans, temperature],
            outputs=[retrieval_out, answers_out, extra_out],
        )

demo.launch()