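"""Gradio UI for RAG question answering over a Hugging Face dataset.

Combines BM25 and dense retrieval (with optional reranking) and answers via an
OpenAI-compatible Chat Completions provider (Groq or OpenRouter free tiers).
"""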
from __future__ import annotations
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import gradio as gr
from rag.retrieve import Retriever
from rag.llm import answer_with_provider
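
# Local modules: Retriever serves the BM25/dense indexes built by
# `python -m rag.index`; answer_with_provider wraps an OpenAI-compatible
# Chat Completions request against the configured base URL.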


def ensure_retriever(state):
    """Lazily build the Retriever on first use and cache it in session state."""
    if state is None:
        state = Retriever()
    return state


def defaults_for_provider(provider_name: str) -> tuple[str, str]:
    """Return (base_url, default_model) for the given provider."""
    if provider_name.startswith("Groq"):
        return "https://api.groq.com/openai/v1", "llama-3.1-8b-instant"
    if provider_name.startswith("OpenRouter"):
        return "https://openrouter.ai/api/v1", "meta-llama/llama-3.1-8b-instruct:free"
    # Fallback: Groq defaults.
    return "https://api.groq.com/openai/v1", "llama-3.1-8b-instant"


def on_provider_change(provider_name: str):
    """Reset the Base URL and Model fields when the provider dropdown changes."""
    base_url, model = defaults_for_provider(provider_name)
    return base_url, model


def run_qa(
    provider: str,
    base_url: str,
    api_key: str,
    model: str,
    question: str,
    use_bm25: bool,
    use_dense: bool,
    use_rerank: bool,
    state,
):
    """Retrieve context for the question, then answer with the chosen provider."""
    state = ensure_retriever(state)
    if not question or not question.strip():
        return "Write a question 🙂", "", state

    # Retrieval toggles
    chunks = state.retrieve(
        question,
        use_bm25=use_bm25,
        use_dense=use_dense,
        use_rerank=use_rerank,
    )

    # Show the retrieved context alongside the answer.
    ctx = []
    for i, c in enumerate(chunks, start=1):
        ctx.append(
            f"[{i}] ({c.why}, score={c.score:.4f}) "
            f"source_id={c.source_id}, chunk_id={c.chunk_id}\n{c.text}"
        )
    ctx_text = "\n\n---\n\n".join(ctx) if ctx else "(nothing retrieved)"

    # If both retrievers are off => "no retrieval" mode: the LLM gets no chunks.
    if not use_bm25 and not use_dense:
        ctx_text = "(retrieval is OFF: the model will answer without any context)"
        chunks_for_llm = []
    else:
        chunks_for_llm = [
            {"chunk_id": c.chunk_id, "source_id": c.source_id, "text": c.text}
            for c in chunks
        ]

    if not api_key or not api_key.strip():
        return f"Paste your {provider} API key first.", ctx_text, state

    # Provider call (OpenAI-compatible Chat Completions).
    try:
        ans = answer_with_provider(
            api_key=api_key.strip(),
            base_url=(base_url or "").strip(),
            model=(model or "").strip(),
            question=question,
            chunks=chunks_for_llm,
        )
    except Exception as e:
        return f"LLM error: {type(e).__name__}: {e}", ctx_text, state

    return ans, ctx_text, state
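

# Illustrative (not executed): the same pipeline can be driven without the UI,
# assuming the index was built first and GROQ_API_KEY is a hypothetical env var
# holding your key:
#
#     ans, ctx, state = run_qa(
#         provider="Groq (free tier)",
#         base_url="https://api.groq.com/openai/v1",
#         api_key=os.environ["GROQ_API_KEY"],
#         model="llama-3.1-8b-instant",
#         question="What does the dataset say about X?",
#         use_bm25=True,
#         use_dense=True,
#         use_rerank=False,
#         state=None,
#     )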


with gr.Blocks(title="RAG QA (BM25 + Dense + OpenAI-compatible providers)") as demo:
    gr.Markdown(
        "# RAG QA (HF dataset + BM25 + Dense)\n"
        "Use a **free-tier OpenAI-compatible provider** (Groq / OpenRouter).\n"
        "1) Build the index: `python -m rag.index`\n"
        "2) Run the UI: `python app.py`\n"
    )

    state = gr.State(None)

    provider = gr.Dropdown(
        ["Groq (free tier)", "OpenRouter (free models)"],
        value="Groq (free tier)",
        label="Provider",
    )
    base_url = gr.Textbox(
        label="Base URL",
        value="https://api.groq.com/openai/v1",
        placeholder="https://api.groq.com/openai/v1",
    )
    api_key = gr.Textbox(
        label="API key",
        type="password",
        placeholder="paste provider key here",
    )
    model = gr.Textbox(
        label="Model",
        value="llama-3.1-8b-instant",
    )

    # Keep the Base URL / Model defaults in sync with the selected provider.
    provider.change(
        fn=on_provider_change,
        inputs=[provider],
        outputs=[base_url, model],
    )

    question = gr.Textbox(label="Question", placeholder="Ask something...", lines=2)
    with gr.Row():
        use_bm25 = gr.Checkbox(value=True, label="Use BM25")
        use_dense = gr.Checkbox(value=True, label="Use Dense")
        use_rerank = gr.Checkbox(value=False, label="Use Reranker (optional)")

    btn = gr.Button("Answer")
    answer = gr.Textbox(label="Answer", lines=8)
    context = gr.Textbox(label="Retrieved chunks", lines=12)

    btn.click(
        fn=run_qa,
        inputs=[provider, base_url, api_key, model, question, use_bm25, use_dense, use_rerank, state],
        outputs=[answer, context, state],
    )


if __name__ == "__main__":
    # Respect a host-provided PORT (e.g. on container platforms); Gradio's
    # default port is 7860.
    port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)