File size: 5,683 Bytes
f552653
05ddcd3
 
 
 
 
 
 
f552653
e20c58d
b16f34e
cfe26d9
7405997
1b601c8
7405997
e20c58d
 
7405997
 
 
e20c58d
f552653
3939c99
05ddcd3
 
 
f552653
7405997
e20c58d
 
 
 
 
 
 
 
 
 
c881ef3
 
 
e20c58d
c881ef3
05ddcd3
f552653
05ddcd3
e595bf1
 
 
445d9bd
 
e595bf1
 
05ddcd3
f552653
 
c881ef3
 
 
 
 
 
 
 
 
 
 
 
 
ffefccd
ebd4559
 
c881ef3
e20c58d
 
 
 
 
 
 
c881ef3
e20c58d
c881ef3
e20c58d
 
 
 
 
 
 
 
489cfae
8d407ac
 
 
f552653
a49d6b5
c881ef3
a49d6b5
c881ef3
 
 
 
 
7405997
c881ef3
 
 
 
 
05ddcd3
c881ef3
 
 
 
 
f552653
1369018
 
 
 
 
 
 
f552653
c881ef3
 
 
 
 
8c14604
6c4c86b
c881ef3
 
 
 
 
 
 
 
 
f7d6f25
 
 
 
445d9bd
c881ef3
 
f552653
05ddcd3
c881ef3
 
 
489cfae
05ddcd3
c881ef3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import gradio as gr
import os
from openai import OpenAI
from retriever import (
    load_collection, load_encoder, encode_query, retrieve_docs,
    query_rerank, expand_with_neighbors, dedup_by_chapter_event
)
from sentence_transformers import CrossEncoder

# OpenRouter model ids (":free" variants). Only QWEN_MODEL is used in this file;
# the others are kept as drop-in alternatives.
QWEN_MODEL = "qwen/qwen3-235b-a22b:free"
DEEPSEEK_MODEL = "deepseek/deepseek-chat-v3.1:free"
GPT_OSS_MODEL = "openai/gpt-oss-20b:free"

# All LLM calls go through OpenRouter's OpenAI-compatible endpoint.
api_key = os.getenv("OPENROUTER_API_KEY")
if not api_key:
    # Fail fast: with api_key=None the OpenAI SDK silently falls back to the
    # OPENAI_API_KEY environment variable, which would send an OpenAI
    # credential to openrouter.ai instead of reporting the missing key.
    raise RuntimeError(
        "OPENROUTER_API_KEY is not set; set it in the environment before "
        "starting the app."
    )
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)

# To call OpenAI directly instead of OpenRouter, build the client with
# OpenAI(api_key=os.getenv("OPENAI_API_KEY")) and use a model such as "gpt-4o".



# Retrieval components, loaded once at import time and shared by all requests:
# the vector collection, the query embedding model, and a cross-encoder used to
# rerank retrieved candidates against the query text.
collection = load_collection()
encoder = load_encoder()
reranker = CrossEncoder("BAAI/bge-reranker-large")

def reformulate_query(user_question, model_name=QWEN_MODEL):
    """Ask the LLM to rewrite *user_question* into a retrieval-friendly query.

    The model is prompted to expand or rephrase the question, covering
    alternative phrasings and synonymous keywords, so that semantic search
    over the knowledge base has a better chance of matching.

    Args:
        user_question: The raw question typed by the user.
        model_name: OpenRouter model id used for the rewrite.

    Returns:
        The stripped rewrite. If the model returns no usable text (no
        choices, or ``None``/blank content), *user_question* is returned
        unchanged so downstream retrieval never receives an empty query.
    """
    prompt = f"""你是一个BangDream知识检索助手。请把用户的问题扩写或转写为适合知识库语义检索的检索语句,涵盖所有可能的提问方式或同义关键词。
    用户问题:{user_question}
    """
    resp = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,
        max_tokens=4096,
    )
    # The SDK types ``message.content`` as optional; calling .strip() on None
    # would raise, so treat missing content as "no rewrite".
    content = resp.choices[0].message.content if resp.choices else None
    rewritten = (content or "").strip()
    return rewritten or user_question

def build_rag_prompt(query, context, system_message):
    """Assemble the user-turn prompt that grounds the answer in retrieved passages.

    Args:
        query: The user's original question.
        context: Retrieved passages, either as one pre-formatted string or as
            a sequence of passage strings. A sequence is rendered as numbered
            ``[资料N]`` sections so the model can tell passages apart and cite
            them (rule 3 in the prompt).
        system_message: Persona/system text placed at the top of the prompt.

    Returns:
        The complete prompt string.
    """
    if isinstance(context, str):
        context_text = context
    else:
        # Interpolating a list directly would embed its Python repr (brackets,
        # quotes, escaped "\n"), garbling the passages; number them instead.
        context_text = "\n\n".join(
            f"[资料{i}]\n{passage}" for i, passage in enumerate(context, 1)
        )
    prompt = f"""{system_message}
你将获得多个独立的资料片段,请充分查阅每一条资料.
已知资料如下:
{context_text}

用户提问:{query}

规则:
1. 请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。
2. 如果有多个相关答案或不同观点,可以考虑是否全部分点列出
3. 如果只能在部分资料里找到答案,也请说明是参考哪些资料内容
4. 如果不能确定答案,请如实说明理由,不要凭空编造。
"""
    return prompt


def _search_and_rerank(query_text, top_k=20, top_n=10):
    """Embed *query_text*, fetch ``top_k`` candidates, and rerank them down to ``top_n``.

    Uses the module-level ``encoder``, ``collection`` and ``reranker``; returns
    whatever ``query_rerank`` returns (the caller concatenates two of them).
    """
    query_vec = encode_query(encoder, query_text)
    candidates = retrieve_docs(collection, query_vec, top_k=top_k)
    return query_rerank(reranker, query_text, candidates, top_n=top_n)


def respond(
        message,
        history: list[dict[str, str]],
        system_message,
        max_tokens,
        temperature,
        top_p,
):
    """Answer *message* with retrieval-augmented generation, streaming the reply.

    Pipeline:
    1. Ask the LLM to rewrite the question into a retrieval-friendly query.
    2. Retrieve and rerank passages for both the original and rewritten query.
    3. Deduplicate the results and expand the top 5 with neighbouring passages.
    4. Stream an answer grounded in those passages.

    Args:
        message: The user's current input.
        history: Prior turns as ``[{"role": "user", "content": ...},
            {"role": "assistant", "content": ...}, ...]``. Not used: each
            question is answered independently.
        system_message: Custom system prompt from the UI; a built-in default
            is used when it is empty or None.
        max_tokens: Forwarded to the chat-completions call.
        temperature: Forwarded to the chat-completions call.
        top_p: Forwarded to the chat-completions call.

    Yields:
        The accumulated answer text after each streamed content delta.
    """
    default_system_message = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"
    system_msg = (system_message or default_system_message).strip()

    print("Reformulating...")
    reformulated_query_text = reformulate_query(message)
    print(f"[DEBUG] reformulated query: {reformulated_query_text}")

    print("Thinking...\n...")
    # Search with both the raw question (keeps the user's exact wording) and
    # the rewrite (adds paraphrases / synonymous keywords).
    reranked = _search_and_rerank(message)
    reformulated_reranked = _search_and_rerank(reformulated_query_text)

    total_reranked = reranked + reformulated_reranked
    deduped = dedup_by_chapter_event(total_reranked, max_per_group=1)
    # Drop empty entries once so the prompt context and the debug dump below
    # see the same documents. Entries are unpacked as (text, score, meta).
    expanded_results = [
        doc for doc in expand_with_neighbors(deduped[:5], collection) if doc
    ]
    passages = [doc[0] for doc in expanded_results]

    rag_prompt = build_rag_prompt(message, passages, system_msg)
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": rag_prompt},
    ]

    response = ""
    stream = client.chat.completions.create(
        model=QWEN_MODEL,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        # Some chunks carry an empty ``choices`` list (e.g. a trailing usage
        # chunk); indexing [0] on those would raise IndexError.
        if not chunk.choices:
            continue
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            response += delta
            yield response

    print("\n=== Answer ===")
    print(response)
    print("\n=== retrieved documents ===")
    for rank, (passage, score, meta) in enumerate(expanded_results, 1):
        print(f"\n--- document {rank} (Score={score:.4f}) ---\n{passage[:200]}...")
        print(meta)


# ========== Gradio ChatInterface with extra sidebar inputs ==========
# Values from `additional_inputs` are passed to `respond` after
# (message, history), in list order: system_message, max_tokens, temperature,
# top_p. Keep this list in sync with respond's signature.
chatbot = gr.ChatInterface(
    respond,
    # History is delivered as a list of {"role": ..., "content": ...} dicts,
    # matching the `history` annotation on `respond`.
    type="messages",
    additional_inputs=[
        # This default takes precedence over respond's built-in default, which
        # is only used when the box is cleared.
        gr.Textbox(value="你是BangDream知识问答助手, 只能基于提供资料内容作答。", label="System message"),
        gr.Slider(minimum=64, maximum=8192, value=1536, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    # Disabled example prompts ("What happened between Anon and Tomori at the
    # aquarium?", "What is RAS's goal?"):
    #examples=[
    #    ["在水族馆里爱音和灯发生了什么?"],
    #    ["RAS的目标是什么?"],
    #],
    description="输入你关于BangDream的问题,邦学家会基于资料库为你检索并作答\n\nGitHub项目地址: [GitHub repo](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)",
    title="Dr-Bang RAG QA Chatbot"
)

# Top-level app layout: a sidebar with project info and a login button, and the
# `chatbot` ChatInterface rendered into the main area.
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    with gr.Sidebar():
        gr.Markdown(
            "## Dr-Bang QA\n\n"
            "[GitHub project](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)\n\n"
        )
        # NOTE(review): LoginButton relies on Hugging Face OAuth; nothing in
        # this file reads the login state. Confirm it is still needed.
        gr.LoginButton()
    # `chatbot` was constructed outside this Blocks context; render() places it here.
    chatbot.render()

if __name__ == "__main__":
    demo.launch()