Dr-Bang / app.py
Cudd1es's picture
app.py updated
05ddcd3
raw
history blame
3.03 kB
import gradio as gr
import os
from openai import OpenAI
from retriever import (
load_collection, load_encoder, encode_query, retrieve_docs,
query_rerank, expand_with_neighbors, dedup_by_chapter_event
)
from sentence_transformers import CrossEncoder
# The OpenAI credential is read from the environment and handed to the client.
api_key = os.environ.get("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)

# Retrieval resources are created once at import time; answer_fn reads these
# module-level objects on every call instead of rebuilding them.
collection = load_collection()
encoder = load_encoder()
reranker = CrossEncoder("BAAI/bge-reranker-large")
def build_rag_prompt(query: str, context: str) -> str:
    """Build the user message sent to the chat model.

    The message has four lines: a header announcing the known material,
    the material itself, the user's question, and a fixed instruction
    telling the model to answer concisely from the material and to say so
    honestly instead of inventing an answer when it is unsure.

    Args:
        query: The question typed by the user.
        context: Retrieved reference text to ground the answer on.

    Returns:
        The fully formatted prompt string.
    """
    return (
        f"已知资料如下:\n"
        f"{context}\n"
        f"用户提问:{query}\n"
        "请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。如果有多个符合的答案, 可以根据你是否确定而决定是否分别陈述这些答案.如果不能确定答案,请如实说明理由,不要凭空编造。"
    )
def _format_references(results, max_chars=300):
    """Render (doc, score, meta) results as a plain-text reference list."""
    sections = []
    for idx, (doc, score, meta) in enumerate(results, 1):
        meta = meta or {}  # tolerate results that carry no metadata
        chapter = meta.get("chapterTitle", "UnknownChapter")
        event = meta.get("eventName", "UnknownEvent")
        # Only mark the excerpt as truncated when text was actually cut off.
        excerpt = (doc[:max_chars] + "...") if len(doc) > max_chars else doc
        sections.append(
            f"\n--- reference: {idx} (chapter: {chapter}, event: {event}, score={score:.4f}) ---\n"
            f"{excerpt}\n"
        )
    return "".join(sections)


def answer_fn(query, history=None):
    """Answer a question with retrieval-augmented generation.

    The query is encoded, 30 candidates are retrieved from the collection,
    the reranker keeps 10, duplicates are reduced to one per group, and the
    first three survivors are passed to ``expand_with_neighbors``. Every
    expanded passage is placed in the prompt so that the model sees the same
    material that is reported back to the user as references.

    Args:
        query: The user's question.
        history: Unused; accepted so chat-style callers can pass it.

    Returns:
        A ``(answer, references)`` tuple of strings. Both are empty when
        the query is blank.
    """
    # Skip retrieval and the API call for a blank submission.
    if not query or not query.strip():
        return "", ""

    query_vec = encode_query(encoder, query)
    candidates = retrieve_docs(collection, query_vec, top_k=30)
    reranked = query_rerank(reranker, query, candidates, top_n=10)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    expanded_results = expand_with_neighbors(deduped[:3], collection)

    # Feed all retrieved passages to the model, not just the first one:
    # the prompt asks it to consult all material, and all of these passages
    # are shown to the user as the references for the answer.
    context = "\n\n".join(doc for doc, _score, _meta in expanded_results)
    rag_prompt = build_rag_prompt(query, context)
    system_prompt = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": rag_prompt},
        ],
        temperature=0.2,
        max_tokens=512,
    )
    # message.content is optional in the API response; avoid None.strip().
    content = response.choices[0].message.content
    answer = (content or "").strip()

    return answer, _format_references(expanded_results)
# Gradio UI
# Built from plain Blocks components: answer_fn returns two strings
# (answer, references), which maps onto two output textboxes. ChatInterface
# has no `outputs` parameter and expects a single chat reply from its fn.
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    gr.Markdown("# Dr-Bang RAG QA\n\n基于BangDream知识库的RAG问答系统。")
    gr.Markdown(
        "## Dr-Bang RAG Chat\n\n"
        "输入你有关BangDream的问题,邦学家会基于资料库为你检索并作答。"
    )
    with gr.Row():
        query_box = gr.Textbox(label="Question", lines=2)
    with gr.Row():
        submit_btn = gr.Button("submit")
        clear_btn = gr.Button("clear")
    answer_box = gr.Textbox(label="Answer", lines=6, interactive=False)
    reference_box = gr.Textbox(label="Reference", lines=8, interactive=False)

    gr.Examples(
        examples=[
            ["乐奈为什么喜欢吉他?"],
            ["LOCK和CHU²第一次见面是什么情节?"],
            ["谁是RAS的初代成员?"],
        ],
        inputs=[query_box],
    )

    # Both the button and pressing Enter in the question box run the query.
    submit_btn.click(
        fn=answer_fn,
        inputs=[query_box],
        outputs=[answer_box, reference_box],
    )
    query_box.submit(
        fn=answer_fn,
        inputs=[query_box],
        outputs=[answer_box, reference_box],
    )
    # Reset the question and both result panes.
    clear_btn.click(
        fn=lambda: ("", "", ""),
        inputs=None,
        outputs=[query_box, answer_box, reference_box],
    )

demo.launch()