File size: 4,352 Bytes
7c2e31a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
from __future__ import annotations

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import gradio as gr

from rag.retrieve import Retriever
from rag.llm import answer_with_provider


def ensure_retriever(state):
    """Lazily build the Retriever on first use; otherwise reuse the cached one.

    `state` is the gr.State payload (None before the first question).
    """
    return Retriever() if state is None else state


def defaults_for_provider(provider_name: str) -> tuple[str, str]:
    """
    Returns (base_url, default_model) for a given provider.

    Matching is by label prefix so dropdown suffixes like "(free tier)"
    don't matter. Unknown labels fall back to the Groq defaults (the
    original code duplicated that tuple; it is now defined once).
    """
    groq_defaults = ("https://api.groq.com/openai/v1", "llama-3.1-8b-instant")
    if provider_name.startswith("OpenRouter"):
        return "https://openrouter.ai/api/v1", "meta-llama/llama-3.1-8b-instruct:free"
    # Groq labels and any unrecognized provider share the same defaults.
    return groq_defaults


def on_provider_change(provider_name: str):
    """Gradio change-handler: sync the Base URL and Model fields to the provider."""
    return defaults_for_provider(provider_name)


def run_qa(
    provider: str,
    base_url: str,
    api_key: str,
    model: str,
    question: str,
    use_bm25: bool,
    use_dense: bool,
    use_rerank: bool,
    state,
):
    """Answer *question* via retrieval-augmented generation.

    Returns a 3-tuple for the Gradio outputs: (answer_text, context_text, state).
    `state` carries the lazily-built Retriever across calls.

    Fix vs. the original: when both retrievers are toggled off, the original
    still called `state.retrieve(...)` and then unconditionally discarded the
    results; retrieval is now skipped entirely in that mode.
    """
    state = ensure_retriever(state)

    if not question or not question.strip():
        return "Write a question 🙂", "", state

    if use_bm25 or use_dense:
        # Retrieval toggles are forwarded to the Retriever as-is.
        chunks = state.retrieve(
            question,
            use_bm25=use_bm25,
            use_dense=use_dense,
            use_rerank=use_rerank,
        )

        # Human-readable dump of the retrieved chunks for the context box.
        ctx = [
            f"[{i}] ({c.why}, score={c.score:.4f}) source_id={c.source_id}, chunk_id={c.chunk_id}\n{c.text}"
            for i, c in enumerate(chunks, start=1)
        ]
        ctx_text = "\n\n---\n\n".join(ctx) if ctx else "(nothing retrieved)"

        # Plain-dict view of the chunks for the LLM prompt builder.
        chunks_for_llm = [
            {"chunk_id": c.chunk_id, "source_id": c.source_id, "text": c.text}
            for c in chunks
        ]
    else:
        # Both retrievers off => "no retrieval" mode: don't touch the index.
        ctx_text = "(retrieval is OFF: the model will answer without any context)"
        chunks_for_llm = []

    # Key check happens after retrieval so the user still sees the context.
    if not api_key or not api_key.strip():
        return f"Paste your {provider} API key first.", ctx_text, state

    # Provider call (OpenAI-compatible Chat Completions). Catch broadly at
    # this UI boundary so failures land in the answer box, not a traceback.
    try:
        ans = answer_with_provider(
            api_key=api_key.strip(),
            base_url=(base_url or "").strip(),
            model=(model or "").strip(),
            question=question,
            chunks=chunks_for_llm,
        )
    except Exception as e:
        return f"LLM error: {type(e).__name__}: {e}", ctx_text, state

    return ans, ctx_text, state


# UI layout. NOTE: in gr.Blocks, component creation order IS the visual
# layout, so the statement order below is behavior — do not reorder.
with gr.Blocks(title="RAG QA (BM25 + Dense + OpenAI-compatible providers)") as demo:
    gr.Markdown(
        "# RAG QA (HF dataset + BM25 + Dense)\n"
        "Use a **free-tier OpenAI-compatible provider** (Groq / OpenRouter).\n"
        "1) Build index: `python -m rag.index`\n"
        "2) Run UI: `python app.py`\n"
    )

    # Per-session holder for the Retriever; starts as None and is populated
    # lazily by ensure_retriever() on the first question.
    state = gr.State(None)

    provider = gr.Dropdown(
        ["Groq (free tier)", "OpenRouter (free models)"],
        value="Groq (free tier)",
        label="Provider",
    )

    # Initial values match defaults_for_provider("Groq (free tier)").
    base_url = gr.Textbox(
        label="Base URL",
        value="https://api.groq.com/openai/v1",
        placeholder="https://api.groq.com/openai/v1",
    )

    api_key = gr.Textbox(
        label="API key",
        type="password",
        placeholder="paste provider key here",
    )

    model = gr.Textbox(
        label="Model",
        value="llama-3.1-8b-instant",
    )

    # Switching providers overwrites the Base URL and Model fields with the
    # provider's defaults (any manual edits in those fields are replaced).
    provider.change(
        fn=on_provider_change,
        inputs=[provider],
        outputs=[base_url, model],
    )

    question = gr.Textbox(label="Question", placeholder="Ask something...", lines=2)

    # Retrieval toggles, forwarded verbatim to run_qa / Retriever.retrieve.
    with gr.Row():
        use_bm25 = gr.Checkbox(value=True, label="Use BM25")
        use_dense = gr.Checkbox(value=True, label="Use Dense")
        use_rerank = gr.Checkbox(value=False, label="Use Reranker (optional)")

    btn = gr.Button("Answer")

    answer = gr.Textbox(label="Answer", lines=8)
    context = gr.Textbox(label="Retrieved chunks", lines=12)

    # Input order must match run_qa's positional parameters; state is both
    # an input and an output so the Retriever persists across clicks.
    btn.click(
        fn=run_qa,
        inputs=[provider, base_url, api_key, model, question, use_bm25, use_dense, use_rerank, state],
        outputs=[answer, context, state],
    )


if __name__ == "__main__":
    # `os` is already imported at module top (L-level import near the
    # TOKENIZERS_PARALLELISM setup); the redundant local import is removed.
    # Honor $PORT so hosted platforms (HF Spaces, Render, etc.) can bind it.
    port = int(os.getenv("PORT", "7860"))
    demo.launch(server_name="0.0.0.0", server_port=port, share=False)