from __future__ import annotations
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import gradio as gr
from rag.retrieve import Retriever
from rag.llm import answer_with_provider
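
# Helpers: the Retriever is created lazily and cached in gr.State, so it is
# constructed only once per user session.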
def ensure_retriever(state):
    if state is None:
        state = Retriever()
    return state

def defaults_for_provider(provider_name: str) -> tuple[str, str]:
"""
Returns (base_url, default_model) for a given provider.
"""
if provider_name.startswith("Groq"):
return "https://api.groq.com/openai/v1", "llama-3.1-8b-instant"
if provider_name.startswith("OpenRouter"):
return "https://openrouter.ai/api/v1", "meta-llama/llama-3.1-8b-instruct:free"
# fallback
return "https://api.groq.com/openai/v1", "llama-3.1-8b-instant"
def on_provider_change(provider_name: str):
    base_url, model = defaults_for_provider(provider_name)
    return base_url, model

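
# Main QA callback: retrieve context (BM25 / dense / optional rerank), then ask
# the selected OpenAI-compatible provider to answer using that context.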
def run_qa(
    provider: str,
    base_url: str,
    api_key: str,
    model: str,
    question: str,
    use_bm25: bool,
    use_dense: bool,
    use_rerank: bool,
    state,
):
    state = ensure_retriever(state)
    if not question or not question.strip():
        return "Write a question 🙂", "", state

    # Retrieval toggles
    chunks = state.retrieve(
        question,
        use_bm25=use_bm25,
        use_dense=use_dense,
        use_rerank=use_rerank,
    )

    # Show retrieved context
    ctx = []
    for i, c in enumerate(chunks, start=1):
        ctx.append(
            f"[{i}] ({c.why}, score={c.score:.4f}) source_id={c.source_id}, chunk_id={c.chunk_id}\n{c.text}"
        )
    ctx_text = "\n\n---\n\n".join(ctx) if ctx else "(nothing retrieved)"

    # If both retrievers off => "no retrieval" mode
    if not use_bm25 and not use_dense:
        ctx_text = "(retrieval is OFF: the model will answer without any context)"
        chunks_for_llm = []
    else:
        chunks_for_llm = [
            {"chunk_id": c.chunk_id, "source_id": c.source_id, "text": c.text} for c in chunks
        ]

    if not api_key or not api_key.strip():
        return f"Paste your {provider} API key first.", ctx_text, state

    # Provider call (OpenAI-compatible Chat Completions)
    try:
        ans = answer_with_provider(
            api_key=api_key.strip(),
            base_url=(base_url or "").strip(),
            model=(model or "").strip(),
            question=question,
            chunks=chunks_for_llm,
        )
    except Exception as e:
        return f"LLM error: {type(e).__name__}: {e}", ctx_text, state

    return ans, ctx_text, state
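
# Gradio UI: provider settings, retrieval toggles, question input, and outputs.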
with gr.Blocks(title="RAG QA (BM25 + Dense + OpenAI-compatible providers)") as demo:
    gr.Markdown(
        "# RAG QA (HF dataset + BM25 + Dense)\n"
        "Use a **free-tier OpenAI-compatible provider** (Groq / OpenRouter).\n"
        "1) Build index: `python -m rag.index`\n"
        "2) Run UI: `python app.py`\n"
    )

    state = gr.State(None)

    provider = gr.Dropdown(
        ["Groq (free tier)", "OpenRouter (free models)"],
        value="Groq (free tier)",
        label="Provider",
    )
    base_url = gr.Textbox(
        label="Base URL",
        value="https://api.groq.com/openai/v1",
        placeholder="https://api.groq.com/openai/v1",
    )
    api_key = gr.Textbox(
        label="API key",
        type="password",
        placeholder="paste provider key here",
    )
    model = gr.Textbox(
        label="Model",
        value="llama-3.1-8b-instant",
    )

    provider.change(
        fn=on_provider_change,
        inputs=[provider],
        outputs=[base_url, model],
    )

    question = gr.Textbox(label="Question", placeholder="Ask something...", lines=2)

    with gr.Row():
        use_bm25 = gr.Checkbox(value=True, label="Use BM25")
        use_dense = gr.Checkbox(value=True, label="Use Dense")
        use_rerank = gr.Checkbox(value=False, label="Use Reranker (optional)")

    btn = gr.Button("Answer")
    answer = gr.Textbox(label="Answer", lines=8)
    context = gr.Textbox(label="Retrieved chunks", lines=12)

    btn.click(
        fn=run_qa,
        inputs=[provider, base_url, api_key, model, question, use_bm25, use_dense, use_rerank, state],
        outputs=[answer, context, state],
    )
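
# Entrypoint: bind to all interfaces and respect the PORT env var
# (defaults to 7860, Gradio's standard port).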
if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)