# Dr-Bang / app.py
# Cudd1es's picture
# fixed app.py
# 6c4c86b
# raw
# history blame
# 3.98 kB
import gradio as gr
import os
from openai import OpenAI
from retriever import (
    load_collection, load_encoder, encode_query, retrieve_docs,
    query_rerank, expand_with_neighbors, dedup_by_chapter_event
)
from sentence_transformers import CrossEncoder
# --- Module-level setup: runs once at import time ---
# LLM is served via OpenRouter through the OpenAI-compatible client;
# the key must be provided in the OPENROUTER_API_KEY environment variable.
api_key = os.getenv("OPENROUTER_API_KEY")
client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
# Vector store and bi-encoder come from the local `retriever` module.
collection = load_collection()
encoder = load_encoder()
# Cross-encoder used to rerank the first-stage retrieval results.
reranker = CrossEncoder("BAAI/bge-reranker-large")
def build_rag_prompt(query, context, system_message):
    """Assemble the single-turn RAG prompt.

    Concatenates the system message, the retrieved reference material,
    the user's question, and the fixed answering rules into one prompt
    string for the LLM.
    """
    return f"""{system_message}
已知资料如下:
{context}
用户提问:{query}
规则:
1. 请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。
2. 如果有多个符合的答案, 可以根据你是否确定而决定是否分别陈述这些答案.
3. 如果有多个可能性, 可以自行判断是否将其全部列举和解释
4. 如果不能确定答案,请如实说明理由,不要凭空编造。
"""
def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Answer *message* with retrieval-augmented generation, streaming the text.

    Parameters
    ----------
    message: current user input text.
    history: prior turns as [{"role": ..., "content": ...}, ...]; currently
        unused — every question is answered independently of the chat history.
    system_message: custom system prompt; empty/None falls back to the default.
    max_tokens, temperature, top_p: sampling parameters forwarded to the model.

    Yields
    ------
    The accumulated answer string after each streamed chunk (Gradio renders
    each yield as the progressively growing reply).
    """
    default_system_message = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"
    system_msg = (system_message or default_system_message).strip()
    # Retrieval pipeline: encode -> top-30 ANN search -> cross-encoder rerank
    # to 10 -> keep one doc per chapter/event -> expand top-3 with neighbors.
    query_vec = encode_query(encoder, message)
    results = retrieve_docs(collection, query_vec, top_k=30)
    reranked = query_rerank(reranker, message, results, top_n=10)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:3], collection)
    # NOTE(review): only the single best expanded document is passed to the
    # model even though up to three are prepared — confirm whether all of
    # them should be concatenated into the context.
    context = expanded_results[0][0] if expanded_results else ""
    rag_prompt = build_rag_prompt(message, context, system_msg)
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": rag_prompt}
    ]
    response = ""
    stream = client.chat.completions.create(
        model="qwen/qwen3-235b-a22b:free",
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stream=True
    )
    for chunk in stream:
        # Deltas may arrive with no content (e.g. role-only chunks); skip them.
        delta = getattr(chunk.choices[0].delta, "content", None)
        if delta:
            response += delta
            yield response
    # Server-side debug output: final answer plus the retrieved evidence.
    print("\n=== Answer ===")
    print(response)
    print("\n=== retrieved documents ===")
    for idx, (doc_text, score, meta) in enumerate(expanded_results, 1):
        print(f"\n--- document {idx} (Score={score:.4f}) ---\n{doc_text[:200]}...")
        print(meta)
# ---- Gradio ChatInterface: chat UI plus extra tuning controls ----
# Values of the additional inputs are forwarded to `respond` after
# (message, history), in the order listed below.
chatbot = gr.ChatInterface(
    respond,
    type="messages",
    title="Dr-Bang RAG QA Chatbot",
    description="输入你关于BangDream的问题,邦学家会基于资料库为你检索并作答",
    examples=[
        ["在水族馆里爱音和灯发生了什么?"],
        ["RAS的目标是什么?"],
    ],
    additional_inputs=[
        gr.Textbox(
            value="你是BangDream知识问答助手, 只能基于提供资料内容作答。",
            label="System message",
        ),
        gr.Slider(64, 8192, value=1536, step=1, label="Max new tokens"),
        gr.Slider(0.1, 2.0, value=0.2, step=0.05, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
# Wrap the chat interface in a Blocks layout so a sidebar (title + login
# button) can sit alongside the rendered chat.
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    with gr.Sidebar():
        gr.Markdown("## Dr-Bang QA\n\n")
        gr.LoginButton()
    chatbot.render()

if __name__ == "__main__":
    demo.launch()