|
|
import gradio as gr |
|
|
import os |
|
|
from openai import OpenAI |
|
|
from retriever import ( |
|
|
load_collection, load_encoder, encode_query, retrieve_docs, |
|
|
query_rerank, expand_with_neighbors, dedup_by_chapter_event |
|
|
) |
|
|
from sentence_transformers import CrossEncoder |
|
|
|
|
|
# --- Shared clients and models, created once when the module is imported ---

# OpenAI credentials are read from the environment (None when unset).
api_key = os.environ.get("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)

# Retrieval components from the local ``retriever`` module plus the
# cross-encoder used by ``query_rerank``.
# NOTE(review): ``collection`` is presumably the vector store and ``encoder``
# the query embedding model — confirm against retriever.py.
RERANKER_MODEL_NAME = "BAAI/bge-reranker-large"

collection = load_collection()
encoder = load_encoder()
reranker = CrossEncoder(RERANKER_MODEL_NAME)
|
|
|
|
|
def build_rag_prompt(query, context):
    """Build the user-turn prompt that grounds the model in retrieved text.

    The retrieved ``context`` comes first, then the user's ``query``, then a
    fixed Chinese instruction. The instruction asks the model to answer
    concisely from the material, optionally list several answers, and
    explain rather than invent when it is unsure.

    Args:
        query: The user's question.
        context: Retrieved reference text to answer from.

    Returns:
        The complete prompt string.
    """
    instruction = (
        "请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。"
        "如果有多个符合的答案, 可以根据你是否确定而决定是否分别陈述这些答案."
        "如果不能确定答案,请如实说明理由,不要凭空编造。"
    )
    return f"已知资料如下:\n{context}\n\n用户提问:{query}\n{instruction}"
|
|
|
|
|
|
|
|
def _format_references(expanded_results, preview_chars=300):
    """Render retrieved passages as a numbered, human-readable reference list.

    Args:
        expanded_results: Iterable of ``(doc, score, meta)`` tuples, where
            ``meta`` is a dict (or None) that may carry ``chapterTitle`` and
            ``eventName``.
        preview_chars: Maximum number of characters of each passage to show.

    Returns:
        One string with a header line and a (possibly truncated) preview per
        passage.
    """
    parts = []
    for idx, (doc, score, meta) in enumerate(expanded_results, 1):
        meta = meta or {}  # metadata may be missing for some chunks
        chapter = meta.get("chapterTitle", "UnknownChapter")
        event = meta.get("eventName", "UnknownEvent")
        parts.append(
            f"\n--- reference: {idx} (chapter: {chapter}, event: {event}, "
            f"score={score:.4f}) ---\n"
        )
        # Only add an ellipsis when the passage was actually cut short.
        if len(doc) > preview_chars:
            parts.append(doc[:preview_chars] + "...\n")
        else:
            parts.append(doc + "\n")
    return "".join(parts)


def answer_fn(query, history=None):
    """Answer a BangDream question with retrieval-augmented generation.

    Steps:

    1. Encode the query.
    2. Retrieve 30 candidates from ``collection``.
    3. Rerank them with ``reranker`` (``top_n=10``).
    4. Deduplicate by chapter/event (``max_per_group=1``).
    5. Expand the top 3 with their neighbours (presumably adjacent chunks —
       see ``retriever.expand_with_neighbors``).
    6. Ask the chat model to answer from that material.

    Args:
        query: The user's question.
        history: Chat history passed by ``gr.ChatInterface``; unused, since
            each question is answered independently.

    Returns:
        ``(answer, references)``: the model's answer text, and a listing of
        the retrieved passages with chapter, event and score.
    """
    query_vec = encode_query(encoder, query)
    results = retrieve_docs(collection, query_vec, top_k=30)
    reranked = query_rerank(reranker, query, results, top_n=10)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    # Each item is a (doc, score, meta) tuple (unpacked that way below).
    expanded_results = expand_with_neighbors(deduped[:3], collection)

    # Give the model every retrieved passage, not only the first one: the
    # prompt tells it to consult all the material, and all of these
    # passages are shown to the user as references.
    context = "\n\n".join(doc for doc, _score, _meta in expanded_results)
    rag_prompt = build_rag_prompt(query, context)

    system_prompt = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": rag_prompt},
        ],
        temperature=0.2,
        max_tokens=512,
    )
    # ``message.content`` is Optional in the OpenAI SDK (e.g. refusals), so
    # guard before stripping instead of raising AttributeError.
    answer = (response.choices[0].message.content or "").strip()

    return answer, _format_references(expanded_results)
|
|
|
|
|
|
|
|
# Web UI.
# Targets Gradio 5.x, where ``gr.ChatInterface`` routes extra return values
# from ``fn`` through ``additional_outputs``. ChatInterface has no
# ``outputs`` parameter, and ``retry_btn``/``undo_btn``/``clear_btn`` were
# removed in Gradio 5 (the chatbot provides its own clear control).
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    gr.Markdown("# Dr-Bang RAG QA\n\n基于BangDream知识库的RAG问答系统。")

    # answer_fn returns (answer, references). The answer is shown in the
    # chat and the references text goes to this box. render=False lets the
    # box be handed to ChatInterface first and placed in the layout below.
    references_box = gr.Textbox(
        label="Reference", lines=8, interactive=False, render=False
    )

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.ChatInterface(
                fn=answer_fn,
                title="Dr-Bang RAG Chat",
                description="输入你有关BangDream的问题,邦学家会基于资料库为你检索并作答。",
                # With no additional_inputs, examples are plain message strings.
                examples=[
                    "乐奈为什么喜欢吉他?",
                    "LOCK和CHU²第一次见面是什么情节?",
                    "谁是RAS的初代成员?",
                ],
                additional_outputs=[references_box],
            )
        with gr.Column(scale=1):
            references_box.render()

demo.launch()