import gradio as gr
import os
from openai import OpenAI
from retriever import (
load_collection, load_encoder, encode_query, retrieve_docs,
query_rerank, expand_with_neighbors, dedup_by_chapter_event
)
from sentence_transformers import CrossEncoder
# --- Module-level setup (runs once at import) ---
# OpenRouter is accessed through the OpenAI-compatible client; the key is
# read from the environment. NOTE(review): if OPENROUTER_API_KEY is unset,
# api_key is None and requests will fail at call time, not here — confirm
# that is the intended failure mode.
api_key = os.getenv("OPENROUTER_API_KEY")
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
# Vector store and query encoder come from the project-local retriever module.
collection = load_collection()
encoder = load_encoder()
# Cross-encoder used to rerank the initial dense-retrieval hits.
reranker = CrossEncoder("BAAI/bge-reranker-large")
def build_rag_prompt(query, context, system_message):
    """Assemble the grounded prompt sent as the user message.

    Layout: system text, the retrieved reference material, the user's
    question, then the answering rules. Returns a single string ending
    with a trailing newline.
    """
    rule_lines = (
        "规则:",
        "1. 请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。",
        "2. 如果有多个符合的答案, 可以根据你是否确定而决定是否分别陈述这些答案.",
        "3. 如果有多个可能性, 可以自行判断是否将其全部列举和解释",
        "4. 如果不能确定答案,请如实说明理由,不要凭空编造。",
    )
    sections = [
        system_message,
        "已知资料如下:",
        context,
        f"用户提问:{query}",
        *rule_lines,
    ]
    return "\n".join(sections) + "\n"
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a RAG-grounded answer for one chat turn.

    Args:
        message: current user input.
        history: prior turns as
            [{"role": "user", "content": ...}, {"role": "assistant", ...}].
            Accepted from gr.ChatInterface but intentionally not forwarded
            to the model — each turn is answered only from fresh retrieval.
        system_message: custom system prompt; falls back to the built-in
            default when empty.
        max_tokens: generation cap forwarded to the completion API.
        temperature: sampling temperature forwarded to the completion API.
        top_p: nucleus-sampling value forwarded to the completion API.

    Yields:
        The accumulated answer text, re-yielded after each stream chunk so
        Gradio can render incremental output.
    """
    default_system_message = (
        "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"
    )
    system_msg = (system_message or default_system_message).strip()

    # Retrieval pipeline: dense retrieve -> cross-encoder rerank ->
    # dedup per chapter/event -> expand the top hits with neighbor chunks.
    query_vec = encode_query(encoder, message)
    results = retrieve_docs(collection, query_vec, top_k=30)
    reranked = query_rerank(reranker, message, results, top_n=10)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:3], collection)
    # Only the single best expanded document is used as grounding context.
    context = expanded_results[0][0] if expanded_results else ""

    # The system prompt appears both as the system message and inside the
    # RAG prompt; build_rag_prompt embeds it alongside context and rules.
    rag_prompt = build_rag_prompt(message, context, system_msg)
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": rag_prompt},
    ]

    response = ""
    stream = client.chat.completions.create(
        model="qwen/qwen3-235b-a22b:free",
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        # Some chunks (e.g. role-only deltas) carry no content; skip them.
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            response += delta
            yield response

    # Server-side debug trace of the final answer and the retrieved docs.
    # Loop variable renamed from `context` to avoid shadowing the grounding
    # context above.
    print("\n=== Answer ===")
    print(response)
    print("\n=== retrieved documents ===")
    for idx, (doc_text, score, meta) in enumerate(expanded_results, 1):
        print(f"\n--- document {idx} (Score={score:.4f}) ---\n{doc_text[:200]}...")
        print(meta)
# ========== Gradio ChatInterface with extra sidebar inputs ==========
# Chat UI wired to respond(); the widgets below map 1:1 onto respond()'s
# extra parameters (system_message, max_tokens, temperature, top_p).
# NOTE(review): this Textbox default ("...只能基于提供资料内容作答。") differs
# slightly from respond()'s internal default wording — confirm which is
# canonical and align them.
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # history delivered as [{"role": ..., "content": ...}]
    additional_inputs=[
        gr.Textbox(value="你是BangDream知识问答助手, 只能基于提供资料内容作答。", label="System message"),
        gr.Slider(minimum=64, maximum=8192, value=1536, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    # Canned example questions shown beneath the input box.
    examples=[
        ["在水族馆里爱音和灯发生了什么?"],
        ["RAS的目标是什么?"],
    ],
    description="输入你关于BangDream的问题,邦学家会基于资料库为你检索并作答",
    title="Dr-Bang RAG QA Chatbot"
)
# Top-level page layout: a sidebar (title + login button) plus the chat
# interface rendered as the main area.
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    with gr.Sidebar():
        gr.Markdown(
            "## Dr-Bang QA\n\n"
        )
        gr.LoginButton()
    chatbot.render()

# Launch the app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()