Spaces:
Sleeping
Sleeping
| """ | |
| Gradio UI – two pages: Index & Retrieve. | |
| """ | |
| import gradio as gr | |
| from rag_engine import RagEngine | |
# Single RagEngine instance, created at import time and shared by the
# index and retrieve callbacks defined in this module.
engine = RagEngine()
# -------------------- INDEX TAB --------------------
def index_fn(url, chunk_size, overlap):
    """Scrape one article, split it into chunks and build the search index.

    Args:
        url: Article URL typed into the "Index" tab.
        chunk_size: Chunk length chosen on the slider (labelled in words).
        overlap: Overlap between chunks chosen on the slider (labelled in
            words).

    Returns:
        A status string for the UI: the number of chunks indexed on success,
        otherwise a short description of what went wrong.
    """
    url = (url or "").strip()
    if not url:
        return "Please enter an article URL."

    # Slider values can arrive as floats; the chunker is given whole numbers.
    chunk_size, overlap = int(chunk_size), int(overlap)

    # The slider ranges allow overlap == chunk_size (100/100), which leaves a
    # sliding window with no forward step, so reject it before chunking.
    if overlap >= chunk_size:
        return "Overlap must be smaller than the chunk size."

    try:
        text = engine.scrape_bbc_article(url)
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return f"Scrape error: {e}"

    chunks = engine.chunk_text(text, chunk_size, overlap)
    if len(chunks) == 0:
        # Nothing to index; do not hand an empty list to build_index.
        return "No text could be extracted from that article."

    engine.build_index(chunks)
    return f"Indexed {len(chunks)} chunks."
# -------------------- RETRIEVE TAB --------------------
def retrieve_fn(query, k, rerank, llm_choice, num_answers, temperature):
    """Retrieve chunks for a question and generate scored answers.

    Args:
        query: The user's question.
        k: Value of the "Retrieve k" slider, passed to ``engine.retrieve``.
        rerank: Value of the "Re-rank top" slider, passed to
            ``engine.retrieve``.
        llm_choice: Model name passed to ``engine.generate``.
        num_answers: How many answers to generate.
        temperature: Sampling temperature passed to ``engine.generate``.

    Returns:
        A 3-tuple ``(retrieval_markdown, answers_markdown, "")``. The third
        element is always an empty string. When nothing can be answered, the
        first element carries a message and the others are empty.
    """
    if engine.faiss_idx is None:
        return "Please index articles first.", "", ""

    query = (query or "").strip()
    if not query:
        return "Please enter a question.", "", ""

    # Slider values can arrive as floats; range() below needs an int.
    k, rerank, num_answers = int(k), int(rerank), int(num_answers)

    retrieved = engine.retrieve(query, k, rerank)
    if len(retrieved) == 0:
        # retrieved[0] is needed below as the metric reference.
        return "No chunks were retrieved for this question.", "", ""

    chunk_sections = []
    for i, hit in enumerate(retrieved, start=1):
        text = hit["chunk"]
        # Only add the ellipsis when the preview is actually truncated.
        preview = (text[:300] + "...") if len(text) > 300 else text
        chunk_sections.append(
            f"**Chunk {i}** (cross-score={hit['cross_score']:.3f})\n{preview}"
        )
    retrieval_md = "\n\n".join(chunk_sections)

    # Use the first retrieved chunk as pseudo-reference for the metric demo.
    reference = retrieved[0]["chunk"]

    answer_sections = []
    for i in range(1, num_answers + 1):
        ans = engine.generate(query, retrieved, llm_choice, temperature=temperature)
        metrics = engine.compute_metrics(reference, ans)
        # Two spaces before "\n" form a Markdown hard line break, so the
        # heading, the metrics and the answer render on separate lines.
        answer_sections.append(
            f"**Answer {i}**  \n"
            f"BLEU={metrics['bleu']:.2f} ROUGE-L={metrics['rouge-l']:.2f}  \n"
            f"{ans}"
        )
    answers_md = "\n\n".join(answer_sections)

    return retrieval_md, answers_md, ""
# -------------------- UI --------------------
# Two-tab layout: "Index" scrapes and indexes one article, "Retrieve" answers
# a question against the indexed chunks.
with gr.Blocks(title="BBC RAG Demo") as demo:
    gr.Markdown("## BBC Culture / Travel RAG Demo")

    with gr.Tab("Index"):
        url_in = gr.Textbox(
            label="BBC article URL",
            placeholder="https://www.bbc.com/culture/article/20231013-xxx",
        )
        with gr.Row():
            chunk_size = gr.Slider(100, 800, 300, step=50, label="Chunk size (words)")
            overlap = gr.Slider(0, 100, 20, step=10, label="Overlap (words)")
        index_btn = gr.Button("Index article")
        index_out = gr.Textbox(label="Status")
        index_btn.click(index_fn, inputs=[url_in, chunk_size, overlap], outputs=index_out)

    with gr.Tab("Retrieve"):
        query_in = gr.Textbox(label="Question")
        with gr.Row():
            k = gr.Slider(1, 30, 5, step=1, label="Retrieve k")
            rerank = gr.Slider(1, 20, 3, step=1, label="Re-rank top")
            num_ans = gr.Slider(1, 10, 2, step=1, label="Number of answers")
            temperature = gr.Slider(0.0, 1.0, 0.0, step=0.05, label="Temperature")
        llm_choice = gr.Dropdown(
            choices=["gpt2", "distilgpt2", "flan-t5-small", "flan-t5-base"],
            value="gpt2",
            label="LLM",
        )
        go_btn = gr.Button("Answer")
        retrieval_out = gr.Markdown(label="Retrieved chunks")
        answers_out = gr.Markdown(label="Generated answers")

        def _answer(*args):
            """Run retrieve_fn and keep only the values that have a component.

            retrieve_fn returns more values than the two output components
            wired below; trimming here keeps the handler's return count equal
            to the number of outputs.
            """
            return retrieve_fn(*args)[:2]

        go_btn.click(
            _answer,
            inputs=[query_in, k, rerank, llm_choice, num_ans, temperature],
            outputs=[retrieval_out, answers_out],
        )

# Start the server only when run as a script, so importing this module
# (e.g. to reuse index_fn / retrieve_fn) does not launch the app.
if __name__ == "__main__":
    demo.launch()