|
|
import gradio as gr |
|
|
import os |
|
|
from openai import OpenAI |
|
|
from retriever import ( |
|
|
load_collection, load_encoder, encode_query, retrieve_docs, |
|
|
query_rerank, expand_with_neighbors, dedup_by_chapter_event |
|
|
) |
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
# --- Module-level initialisation (runs once at import time) ---

# OpenRouter credentials; the OpenAI SDK is pointed at OpenRouter's
# OpenAI-compatible endpoint.
# NOTE(review): if OPENROUTER_API_KEY is unset this is None — client creation
# still succeeds and every request fails later; consider failing fast here.
api_key = os.getenv("OPENROUTER_API_KEY")
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)

# Retrieval components shared by every request (see retriever.py).
collection = load_collection()  # vector store / document collection
encoder = load_encoder()  # query embedding model
# Cross-encoder used to rerank the vector-search hits; downloads the
# model on first run.
reranker = CrossEncoder("BAAI/bge-reranker-large")
|
|
|
|
|
|
|
|
def build_rag_prompt(query, context, system_message):
    """Assemble the retrieval-augmented prompt sent as the user message.

    The prompt stacks four blank-line-separated sections: the system
    message, the retrieved material, the user's question, and the fixed
    answering rules (answer only from the material, don't fabricate).

    Args:
        query: The user's question.
        context: Retrieved reference text to ground the answer in.
        system_message: Persona/instruction text placed at the top.

    Returns:
        The full prompt string.
    """
    rules = (
        "规则:\n"
        "1. 请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。\n"
        "2. 如果有多个符合的答案, 可以根据你是否确定而决定是否分别陈述这些答案.\n"
        "3. 如果有多个可能性, 可以自行判断是否将其全部列举和解释\n"
        "4. 如果不能确定答案,请如实说明理由,不要凭空编造。\n"
    )
    sections = [
        system_message,
        f"已知资料如下:\n{context}",
        f"用户提问:{query}",
        rules,
    ]
    return "\n\n".join(sections)
|
|
|
|
|
|
|
|
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Answer *message* with retrieval-augmented generation, streaming text.

    Generator used as the Gradio ChatInterface callback: retrieves
    supporting documents, builds a grounded prompt, and yields the
    accumulated answer after every streamed delta.

    Args:
        message: Current user question.
        history: Prior turns as [{"role": ..., "content": ...}] dicts.
            NOTE(review): currently unused — every question is answered
            without conversational context; confirm this is intended.
        system_message: Custom system prompt; empty/None falls back to
            the built-in default persona.
        max_tokens: Generation cap forwarded to the model.
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling parameter forwarded to the model.

    Yields:
        The answer text accumulated so far (Gradio re-renders each yield).
    """
    default_system_message = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"
    system_msg = (system_message or default_system_message).strip()

    # Retrieval pipeline: embed query -> vector search (top 30) ->
    # cross-encoder rerank (top 10) -> one doc per chapter/event ->
    # expand the best 3 with their neighboring chunks.
    query_vec = encode_query(encoder, message)
    results = retrieve_docs(collection, query_vec, top_k=30)
    reranked = query_rerank(reranker, message, results, top_n=10)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:3], collection)
    # NOTE(review): only the first expanded document reaches the prompt even
    # though up to 3 are expanded (the rest are only printed below) — confirm
    # whether all contexts should be joined into the prompt instead.
    context = expanded_results[0][0] if expanded_results else ""

    rag_prompt = build_rag_prompt(message, context, system_msg)
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": rag_prompt},
    ]

    response = ""
    stream = client.chat.completions.create(
        model="qwen/qwen3-235b-a22b:free",
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        # Some OpenAI-compatible providers emit keep-alive/usage chunks
        # with an empty `choices` list; indexing those raises IndexError.
        if not chunk.choices:
            continue
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            response += delta
            yield response

    # Server-side debug trace: final answer plus the retrieved evidence.
    print("\n=== Answer ===")
    print(response)
    print("\n=== retrieved documents ===")
    for idx, (doc_text, score, meta) in enumerate(expanded_results, 1):
        print(f"\n--- document {idx} (Score={score:.4f}) ---\n{doc_text[:200]}...")
        print(meta)
|
|
|
|
|
|
|
|
|
|
|
# Chat UI definition. `additional_inputs` map positionally onto respond's
# (system_message, max_tokens, temperature, top_p) parameters after
# (message, history).
# NOTE(review): this Textbox default differs slightly from respond's
# default_system_message ("也就是邦学家" is missing) — confirm which wording
# is canonical.
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # history delivered as [{"role", "content"}] dicts
    additional_inputs=[
        gr.Textbox(value="你是BangDream知识问答助手, 只能基于提供资料内容作答。", label="System message"),
        gr.Slider(minimum=64, maximum=8192, value=1536, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    # Sample questions shown beneath the input box.
    examples=[
        ["在水族馆里爱音和灯发生了什么?"],
        ["RAS的目标是什么?"],
    ],
    description="输入你关于BangDream的问题,邦学家会基于资料库为你检索并作答",
    title="Dr-Bang RAG QA Chatbot"
)
|
|
|
|
|
# Page layout: a sidebar (title + login button) next to the chat interface.
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    with gr.Sidebar():
        gr.Markdown(
            "## Dr-Bang QA\n\n"
        )
        # Hugging Face / OAuth login widget provided by Gradio.
        gr.LoginButton()
    # Mount the ChatInterface defined above into this Blocks page.
    chatbot.render()


if __name__ == "__main__":
    demo.launch()