File size: 3,977 Bytes
f552653
05ddcd3
 
 
 
 
 
 
f552653
e595bf1
 
f552653
05ddcd3
 
 
f552653
c881ef3
 
 
 
 
05ddcd3
f552653
05ddcd3
e595bf1
 
 
 
 
 
 
05ddcd3
f552653
 
c881ef3
 
 
 
 
 
 
 
 
 
 
 
 
ffefccd
ebd4559
 
c881ef3
 
05ddcd3
c881ef3
05ddcd3
 
 
f552653
a49d6b5
c881ef3
a49d6b5
c881ef3
 
 
 
 
e595bf1
c881ef3
 
 
 
 
05ddcd3
c881ef3
 
 
 
 
f552653
1369018
 
 
 
 
 
 
f552653
c881ef3
 
 
 
 
 
6c4c86b
c881ef3
 
 
 
 
 
 
 
 
 
1369018
 
c881ef3
 
 
 
f552653
05ddcd3
c881ef3
 
 
05ddcd3
c881ef3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import gradio as gr
import os
from openai import OpenAI
from retriever import (
    load_collection, load_encoder, encode_query, retrieve_docs,
    query_rerank, expand_with_neighbors, dedup_by_chapter_event
)
from sentence_transformers import CrossEncoder

# LLM client: the stock OpenAI SDK is pointed at OpenRouter's base URL, and the
# key is read from the OPENROUTER_API_KEY environment variable (None if unset).
# NOTE(review): if the variable is unset, the OpenAI SDK presumably falls back to
# its own OPENAI_API_KEY lookup and raises at construction when that is missing
# too — confirm this is the intended failure mode.
api_key = os.getenv("OPENROUTER_API_KEY")
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)

# Retrieval components, created once at import time and reused by every call to
# `respond` (which reads these module-level names directly).
collection = load_collection()
encoder = load_encoder()
# Cross-encoder model handed to `query_rerank` to reorder the dense-retrieval hits.
reranker = CrossEncoder("BAAI/bge-reranker-large")


def build_rag_prompt(query, context, system_message):
    """Assemble the user-turn prompt for retrieval-augmented answering.

    The prompt is laid out line by line as: the system message, a blank line,
    the "known materials" header followed by ``context``, a blank line, the
    user's question, a blank line, and four numbered answering rules (in
    Chinese). The rules tell the model to:

    1. use all supplied material;
    2. state multiple answers when appropriate;
    3. enumerate possibilities at its own discretion;
    4. admit uncertainty instead of fabricating.

    The result always ends with a trailing newline.

    Args:
        query: the user's question.
        context: retrieved reference text, inserted verbatim.
        system_message: text placed at the very top of the prompt.

    Returns:
        The complete prompt string.
    """
    answering_rules = (
        "1. 请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。",
        "2. 如果有多个符合的答案, 可以根据你是否确定而决定是否分别陈述这些答案.",
        "3. 如果有多个可能性, 可以自行判断是否将其全部列举和解释",
        "4. 如果不能确定答案,请如实说明理由,不要凭空编造。",
    )
    prompt_lines = [
        f"{system_message}",
        "",
        "已知资料如下:",
        f"{context}",
        "",
        f"用户提问:{query}",
        "",
        "规则:",
        *answering_rules,
    ]
    # Joining adds no newline after the last rule; append it to keep the
    # trailing newline of the original triple-quoted template.
    return "\n".join(prompt_lines) + "\n"


def _retrieve_documents(query):
    """Run the retrieval pipeline for *query* and return the expanded hits.

    Steps:

    1. Dense retrieval of the top 30 candidates.
    2. Rerank them down to the top 10 with the cross-encoder.
    3. Deduplicate with ``max_per_group=1`` — presumably one hit per
       chapter/event, per the helper's name.
    4. Expand the best three with neighbouring chunks via
       ``expand_with_neighbors``.

    Items of the returned sequence are consumed as ``(text, score, meta)`` by
    the callers in this file.
    """
    query_vec = encode_query(encoder, query)
    candidates = retrieve_docs(collection, query_vec, top_k=30)
    reranked = query_rerank(reranker, query, candidates, top_n=10)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    return expand_with_neighbors(deduped[:3], collection)


def _format_context(expanded_results):
    """Join the text of every retrieved document into one context string.

    Documents are separated by a blank line. An empty result list yields "".
    """
    return "\n\n".join(text for text, _score, _meta in expanded_results)


def respond(
        message,
        history: list[dict[str, str]],
        system_message,
        max_tokens,
        temperature,
        top_p,
):
    """Answer *message* with retrieval-augmented generation, streaming the reply.

    Gradio ``ChatInterface`` callback. It retrieves reference documents for the
    question, builds a RAG prompt from all of them, and streams the model's
    answer. After the stream ends it prints the answer and the retrieved
    documents to stdout for debugging.

    Args:
        message: the user's current question.
        history: prior turns as
            ``[{"role": "user", "content": ...}, {"role": "assistant", "content": ...}, ...]``.
            Currently NOT forwarded to the model; each question is answered
            from the retrieved context alone.
        system_message: custom system prompt. ``None``, empty or whitespace-only
            falls back to the built-in default.
        max_tokens: maximum number of tokens to generate (coerced to ``int``).
        temperature: sampling temperature forwarded to the API.
        top_p: nucleus-sampling threshold forwarded to the API.

    Yields:
        The accumulated answer text after each non-empty streamed delta.
    """
    default_system_message = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"
    # Strip before falling back so a whitespace-only input does not produce an
    # empty system prompt.
    system_msg = (system_message or "").strip() or default_system_message

    expanded_results = _retrieve_documents(message)
    # Give the model every retrieved document, not only the top one: three are
    # expanded and logged as retrieved, and the prompt rules ask the model to
    # consult all materials and list multiple answers.
    context = _format_context(expanded_results)

    rag_prompt = build_rag_prompt(message, context, system_msg)
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": rag_prompt},
    ]

    response = ""
    stream = client.chat.completions.create(
        model="qwen/qwen3-235b-a22b:free",
        messages=messages,
        temperature=temperature,
        # Slider values may arrive as floats; the API expects an integer here.
        max_tokens=int(max_tokens),
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        # Some streamed chunks (e.g. usage-only ones) can carry no choices;
        # skip them instead of raising IndexError.
        if not chunk.choices:
            continue
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            response += delta
            yield response

    print("\n=== Answer ===")
    print(response)
    print("\n=== retrieved documents ===")
    for idx, (doc_text, score, meta) in enumerate(expanded_results, 1):
        print(f"\n--- document {idx} (Score={score:.4f}) ---\n{doc_text[:200]}...")
        print(meta)


# ========== Gradio ChatInterface with extra settings inputs ==========
# Extra controls. Their values are passed to `respond` in this order, after
# (message, history): system_message, max_tokens, temperature, top_p.
# NOTE(review): this Textbox default differs slightly from the fallback
# system message hard-coded inside `respond` — confirm which one is intended.
_chat_settings = [
    gr.Textbox(value="你是BangDream知识问答助手, 只能基于提供资料内容作答。", label="System message"),
    gr.Slider(minimum=64, maximum=8192, value=1536, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=2.0, value=0.2, step=0.05, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

# One-click sample questions offered with the chat box.
_sample_questions = [["在水族馆里爱音和灯发生了什么?"], ["RAS的目标是什么?"]]

chatbot = gr.ChatInterface(
    respond,
    type="messages",
    title="Dr-Bang RAG QA Chatbot",
    description="输入你关于BangDream的问题,邦学家会基于资料库为你检索并作答",
    additional_inputs=_chat_settings,
    examples=_sample_questions,
)

# Top-level page: a sidebar holding a heading and a login button, with the
# `chatbot` ChatInterface rendered in the main area.
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    with gr.Sidebar():
        gr.Markdown(
            "## Dr-Bang QA\n\n"
        )
        # NOTE(review): gr.LoginButton is presumably Hugging Face OAuth login;
        # nothing visible here reads the login state — confirm it is needed.
        gr.LoginButton()
    # `chatbot` was constructed outside this Blocks context; render() places
    # it at this point in the layout.
    chatbot.render()

# Launch the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()