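"""Gradio app for the Dr-Bang RAG QA chatbot.

Pipeline per question: reformulate the query with an LLM, embed and retrieve
candidates for both the original and reformulated queries, rerank with a
cross-encoder, dedup and expand with neighboring chunks, then stream an
answer grounded in the retrieved context.
"""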
import gradio as gr
import os
from openai import OpenAI
from retriever import (
    load_collection, load_encoder, encode_query, retrieve_docs,
    query_rerank, expand_with_neighbors, dedup_by_chapter_event
)
from sentence_transformers import CrossEncoder
QWEN_MODEL = "qwen/qwen3-235b-a22b:free"
DEEPSEEK_MODEL = "deepseek/deepseek-chat-v3.1:free"
GPT_OSS_MODEL = "openai/gpt-oss-20b:free"
api_key = os.getenv("OPENROUTER_API_KEY")
# deepseek_key = os.getenv("DEEPSEEK_API_KEY")
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)

# OpenAI method (alternative):
# api_key = os.getenv("OPENAI_API_KEY")
# client = OpenAI(api_key=api_key)
# GPT_MODEL = "gpt-4o"
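
# Load retrieval resources once at startup: vector collection, query encoder,
# and a cross-encoder reranker.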
collection = load_collection()
encoder = load_encoder()
reranker = CrossEncoder("BAAI/bge-reranker-large")
def reformulate_query(user_question, model_name=QWEN_MODEL):
    # Rewrite the user's question into a retrieval-friendly query;
    # low temperature keeps the rewrite stable.
    prompt = f"""You are a BangDream knowledge-retrieval assistant. Expand or rewrite the user's question into a query suited to semantic search over the knowledge base, covering all likely phrasings and synonymous keywords.
User question: {user_question}
"""
    resp = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.1,
        max_tokens=4096,
    )
    return resp.choices[0].message.content.strip()
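
# Assemble the final RAG prompt: system message, retrieved snippets,
# the user's question, and answering rules.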
def build_rag_prompt(query, context, system_message):
    prompt = f"""{system_message}
You will receive several independent reference snippets; consult every one of them carefully.
Reference material:
{context}
User question: {query}
Rules:
1. Draw on all of the reference material and answer the question concisely and accurately.
2. If there are several relevant answers or differing viewpoints, consider listing them all as separate points.
3. If the answer is only found in some of the snippets, state which material you relied on.
4. If you cannot determine the answer, explain why honestly; do not make anything up.
"""
    return prompt
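
# Streaming chat handler wired into gr.ChatInterface below.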
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """
    message: the current user input
    history: [{"role": "user", "content": ...}, {"role": "assistant", "content": ...}, ...]
        (accepted from gr.ChatInterface but not currently used; every turn is answered from retrieval alone)
    system_message: custom system prompt
    """
    default_system_message = (
        "You are the BangDream QA assistant, the 邦学家. "
        "Answer only from the provided reference material."
    )
    system_msg = (system_message or default_system_message).strip()
    # Step 1: reformulate the query for better recall.
    print("Reformulating...")
    reformulated_query_text = reformulate_query(message)
    print(f"[DEBUG] reformulated query: {reformulated_query_text}")
    print("Thinking...\n...")

    # Step 2: retrieve and rerank for the original query.
    query_vec = encode_query(encoder, message)
    results = retrieve_docs(collection, query_vec, top_k=20)
    reranked = query_rerank(reranker, message, results, top_n=10)

    # Step 3: retrieve and rerank for the reformulated query.
    reformulated_query_vec = encode_query(encoder, reformulated_query_text)
    reformulated_results = retrieve_docs(collection, reformulated_query_vec, top_k=20)
    reformulated_reranked = query_rerank(reranker, reformulated_query_text, reformulated_results, top_n=10)

    # Step 4: merge both result lists, dedup, and expand the top hits with neighboring chunks.
    total_reranked = reranked + reformulated_reranked
    deduped = dedup_by_chapter_event(total_reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:5], collection)
    # Join the snippet texts into a single context string for the prompt.
    context = "\n\n".join(text[0] for text in expanded_results if text)
    rag_prompt = build_rag_prompt(message, context, system_msg)
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": rag_prompt}
    ]

    # Stream the answer back chunk by chunk.
    response = ""
    stream = client.chat.completions.create(
        model=QWEN_MODEL,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True
    )
    for chunk in stream:
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            response += delta
            yield response
print("\n=== Answer ===")
print(response)
print("\n=== retrieved documents ===")
for idx, (context, score, meta) in enumerate(expanded_results, 1):
print(f"\n--- document {idx} (Score={score:.4f}) ---\n{context[:200]}...")
print(meta)
# ========== Gradio ChatInterface with extra sidebar inputs ==========
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(
            value="You are the BangDream QA assistant. Answer only from the provided reference material.",
            label="System message",
        ),
        gr.Slider(minimum=64, maximum=8192, value=1536, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    # examples=[
    #     ["What happened between Anon and Tomori at the aquarium?"],
    #     ["What is RAS's goal?"],
    # ],
    description=(
        "Ask a question about BangDream and the assistant (邦学家) will retrieve from "
        "the knowledge base and answer.\n\n"
        "GitHub project: [GitHub repo](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)"
    ),
    title="Dr-Bang RAG QA Chatbot",
)
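
# Wrap the chat interface in a Blocks layout with a sidebar (project link + HF login).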
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    with gr.Sidebar():
        gr.Markdown(
            "## Dr-Bang QA\n\n"
            "[GitHub project](https://github.com/Cudd1es/dr-bang?tab=readme-ov-file)\n\n"
        )
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()