from __future__ import annotations

import os

# Avoid tokenizer fork warnings/deadlocks when the UI spawns worker threads.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import gradio as gr

from rag.retrieve import Retriever
from rag.llm import answer_with_provider


def ensure_retriever(state):
    """Lazily build the Retriever on first use and cache it in Gradio state."""
    if state is None:
        state = Retriever()
    return state


def defaults_for_provider(provider_name: str) -> tuple[str, str]:
    """Return (base_url, default_model) for a given provider."""
    if provider_name.startswith("Groq"):
        return "https://api.groq.com/openai/v1", "llama-3.1-8b-instant"
    if provider_name.startswith("OpenRouter"):
        return "https://openrouter.ai/api/v1", "meta-llama/llama-3.1-8b-instruct:free"
    # Fallback: Groq defaults.
    return "https://api.groq.com/openai/v1", "llama-3.1-8b-instant"


def on_provider_change(provider_name: str):
    base_url, model = defaults_for_provider(provider_name)
    return base_url, model


def run_qa(
    provider: str,
    base_url: str,
    api_key: str,
    model: str,
    question: str,
    use_bm25: bool,
    use_dense: bool,
    use_rerank: bool,
    state,
):
    """Retrieve context for `question`, then answer via the selected provider."""
    state = ensure_retriever(state)

    if not question or not question.strip():
        return "Write a question 🙂", "", state

    # Retrieval toggles
    chunks = state.retrieve(
        question,
        use_bm25=use_bm25,
        use_dense=use_dense,
        use_rerank=use_rerank,
    )

    # Show retrieved context
    ctx = []
    for i, c in enumerate(chunks, start=1):
        ctx.append(
            f"[{i}] ({c.why}, score={c.score:.4f}) "
            f"source_id={c.source_id}, chunk_id={c.chunk_id}\n{c.text}"
        )
    ctx_text = "\n\n---\n\n".join(ctx) if ctx else "(nothing retrieved)"

    # If both retrievers are off => "no retrieval" mode
    if not use_bm25 and not use_dense:
        ctx_text = "(retrieval is OFF: the model will answer without any context)"
        chunks_for_llm = []
    else:
        chunks_for_llm = [
            {"chunk_id": c.chunk_id, "source_id": c.source_id, "text": c.text}
            for c in chunks
        ]

    if not api_key or not api_key.strip():
        return f"Paste your {provider} API key first.", ctx_text, state

    # Provider call (OpenAI-compatible Chat Completions)
    try:
        ans = answer_with_provider(
            api_key=api_key.strip(),
            base_url=(base_url or "").strip(),
            model=(model or "").strip(),
            question=question,
            chunks=chunks_for_llm,
        )
    except Exception as e:
        return f"LLM error: {type(e).__name__}: {e}", ctx_text, state

    return ans, ctx_text, state


with gr.Blocks(title="RAG QA (BM25 + Dense + OpenAI-compatible providers)") as demo:
    gr.Markdown(
        "# RAG QA (HF dataset + BM25 + Dense)\n"
        "Use a **free-tier OpenAI-compatible provider** (Groq / OpenRouter).\n"
        "1) Build index: `python -m rag.index`\n"
        "2) Run UI: `python app.py`\n"
    )

    state = gr.State(None)

    provider = gr.Dropdown(
        ["Groq (free tier)", "OpenRouter (free models)"],
        value="Groq (free tier)",
        label="Provider",
    )
    base_url = gr.Textbox(
        label="Base URL",
        value="https://api.groq.com/openai/v1",
        placeholder="https://api.groq.com/openai/v1",
    )
    api_key = gr.Textbox(
        label="API key",
        type="password",
        placeholder="paste provider key here",
    )
    model = gr.Textbox(
        label="Model",
        value="llama-3.1-8b-instant",
    )

    # Keep Base URL and Model in sync with the selected provider.
    provider.change(
        fn=on_provider_change,
        inputs=[provider],
        outputs=[base_url, model],
    )

    question = gr.Textbox(label="Question", placeholder="Ask something...", lines=2)

    with gr.Row():
        use_bm25 = gr.Checkbox(value=True, label="Use BM25")
        use_dense = gr.Checkbox(value=True, label="Use Dense")
        use_rerank = gr.Checkbox(value=False, label="Use Reranker (optional)")

    btn = gr.Button("Answer")
    answer = gr.Textbox(label="Answer", lines=8)
    context = gr.Textbox(label="Retrieved chunks", lines=12)

    btn.click(
        fn=run_qa,
        inputs=[provider, base_url, api_key, model, question, use_bm25, use_dense, use_rerank, state],
        outputs=[answer, context, state],
    )

if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)