# app.py — Gradio RAG QA demo (runs as a Hugging Face Space)
| from __future__ import annotations | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| import gradio as gr | |
| from rag.retrieve import Retriever | |
| from rag.llm import answer_with_provider | |
def ensure_retriever(state):
    """Return the session's Retriever, building it lazily on first use."""
    return Retriever() if state is None else state
def defaults_for_provider(provider_name: str) -> tuple[str, str]:
    """Map a provider display name to its ``(base_url, default_model)`` pair.

    Matching is by display-name prefix; anything unrecognized falls back to
    the Groq defaults.
    """
    groq_defaults = ("https://api.groq.com/openai/v1", "llama-3.1-8b-instant")
    by_prefix = (
        ("Groq", groq_defaults),
        ("OpenRouter", ("https://openrouter.ai/api/v1", "meta-llama/llama-3.1-8b-instruct:free")),
    )
    for prefix, defaults in by_prefix:
        if provider_name.startswith(prefix):
            return defaults
    # Unknown provider string: fall back to Groq.
    return groq_defaults
def on_provider_change(provider_name: str):
    """Gradio change-handler: refresh the (base_url, model) textboxes."""
    return defaults_for_provider(provider_name)
def run_qa(
    provider: str,
    base_url: str,
    api_key: str,
    model: str,
    question: str,
    use_bm25: bool,
    use_dense: bool,
    use_rerank: bool,
    state,
):
    """Answer *question* with retrieval-augmented generation.

    Returns a 3-tuple ``(answer_text, context_text, state)`` matching the
    Gradio outputs ``[answer, context, state]``. ``state`` is the lazily
    built Retriever, threaded through so it is created only once per session.
    Errors (missing key, provider failures) are reported as the answer text
    rather than raised, so the UI never crashes.
    """
    state = ensure_retriever(state)
    if not question or not question.strip():
        return "Write a question 🙂", "", state

    if not use_bm25 and not use_dense:
        # "No retrieval" mode: previously we still called state.retrieve()
        # and formatted chunks only to discard them — skip that work entirely.
        ctx_text = "(retrieval is OFF: the model will answer without any context)"
        chunks_for_llm = []
    else:
        chunks = state.retrieve(
            question,
            use_bm25=use_bm25,
            use_dense=use_dense,
            use_rerank=use_rerank,
        )
        ctx_text = _format_context(chunks)
        chunks_for_llm = [
            {"chunk_id": c.chunk_id, "source_id": c.source_id, "text": c.text}
            for c in chunks
        ]

    # Key check is deliberately AFTER retrieval: the user gets to inspect the
    # retrieved chunks even before pasting an API key.
    if not api_key or not api_key.strip():
        return f"Paste your {provider} API key first.", ctx_text, state

    # Provider call (OpenAI-compatible Chat Completions)
    try:
        ans = answer_with_provider(
            api_key=api_key.strip(),
            base_url=(base_url or "").strip(),
            model=(model or "").strip(),
            question=question,
            chunks=chunks_for_llm,
        )
    except Exception as e:  # surface provider/network errors in the UI
        return f"LLM error: {type(e).__name__}: {e}", ctx_text, state
    return ans, ctx_text, state


def _format_context(chunks) -> str:
    """Render retrieved chunks as a numbered, '---'-separated debug string."""
    parts = [
        f"[{i}] ({c.why}, score={c.score:.4f}) source_id={c.source_id}, chunk_id={c.chunk_id}\n{c.text}"
        for i, c in enumerate(chunks, start=1)
    ]
    return "\n\n---\n\n".join(parts) if parts else "(nothing retrieved)"
# UI layout and event wiring. Statement order inside the Blocks context
# determines on-screen order.
with gr.Blocks(title="RAG QA (BM25 + Dense + OpenAI-compatible providers)") as demo:
    gr.Markdown(
        "# RAG QA (HF dataset + BM25 + Dense)\n"
        "Use a **free-tier OpenAI-compatible provider** (Groq / OpenRouter).\n"
        "1) Build index: `python -m rag.index`\n"
        "2) Run UI: `python app.py`\n"
    )
    # Per-session holder for the lazily built Retriever (see ensure_retriever).
    state = gr.State(None)
    provider = gr.Dropdown(
        ["Groq (free tier)", "OpenRouter (free models)"],
        value="Groq (free tier)",
        label="Provider",
    )
    # Base URL / model defaults match the Groq entry selected above.
    base_url = gr.Textbox(
        label="Base URL",
        value="https://api.groq.com/openai/v1",
        placeholder="https://api.groq.com/openai/v1",
    )
    api_key = gr.Textbox(
        label="API key",
        type="password",
        placeholder="paste provider key here",
    )
    model = gr.Textbox(
        label="Model",
        value="llama-3.1-8b-instant",
    )
    # Switching provider rewrites the Base URL and Model textboxes.
    provider.change(
        fn=on_provider_change,
        inputs=[provider],
        outputs=[base_url, model],
    )
    question = gr.Textbox(label="Question", placeholder="Ask something...", lines=2)
    # Retrieval toggles: turning both BM25 and Dense off puts run_qa in
    # "no retrieval" mode.
    with gr.Row():
        use_bm25 = gr.Checkbox(value=True, label="Use BM25")
        use_dense = gr.Checkbox(value=True, label="Use Dense")
        use_rerank = gr.Checkbox(value=False, label="Use Reranker (optional)")
    btn = gr.Button("Answer")
    answer = gr.Textbox(label="Answer", lines=8)
    context = gr.Textbox(label="Retrieved chunks", lines=12)
    # Input order must match run_qa's parameter order; outputs map to
    # its (answer, context, state) return tuple.
    btn.click(
        fn=run_qa,
        inputs=[provider, base_url, api_key, model, question, use_bm25, use_dense, use_rerank, state],
        outputs=[answer, context, state],
    )
if __name__ == "__main__":
    # Hosting platforms (HF Spaces, Docker) inject PORT; default to Gradio's
    # standard 7860. `os` is already imported at module top, so the previous
    # redundant local `import os` was dropped.
    port = int(os.getenv("PORT", "7860"))
    # Bind to all interfaces so the app is reachable inside a container.
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)