File size: 3,027 Bytes
f552653
05ddcd3
 
 
 
 
 
 
f552653
05ddcd3
 
f552653
05ddcd3
 
 
f552653
05ddcd3
 
 
f552653
05ddcd3
 
 
f552653
 
05ddcd3
 
 
 
 
 
f552653
05ddcd3
 
 
f552653
05ddcd3
 
 
 
 
 
 
 
 
 
f552653
05ddcd3
 
 
 
 
 
f552653
05ddcd3
f552653
05ddcd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import os
from openai import OpenAI
from retriever import (
    load_collection, load_encoder, encode_query, retrieve_docs,
    query_rerank, expand_with_neighbors, dedup_by_chapter_event
)
from sentence_transformers import CrossEncoder

# Shared resources, created once at import time and reused by every call to
# answer_fn.
# NOTE(review): api_key is None when OPENAI_API_KEY is unset; confirm how the
# OpenAI client handles that (it may fall back to the env var itself or raise).
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)

# Document collection searched by retrieve_docs, and the query encoder used by
# encode_query; both are provided by the local `retriever` module.
collection = load_collection()
encoder = load_encoder()
# Cross-encoder passed to query_rerank to re-score the retrieved candidates.
reranker = CrossEncoder("BAAI/bge-reranker-large")

def build_rag_prompt(query, context):
    """Return the user-turn RAG prompt sent to the chat model.

    The prompt has three parts, separated by newlines: the retrieved
    ``context`` under a "known material" heading, then the user's ``query``,
    then (Chinese) instructions. The instructions tell the model to answer
    concisely and accurately from the material, to list several answers when
    more than one fits, and to say why it cannot answer rather than invent
    one.

    Args:
        query: The user's question.
        context: Retrieved reference text. It is interpolated with str
            formatting.

    Returns:
        The assembled prompt string.
    """
    # Adjacent literals concatenate into one instruction sentence.
    instruction = (
        "请参考所有已知资料, 并结合资料内容,简明、准确地回答问题。"
        "如果有多个符合的答案, 可以根据你是否确定而决定是否分别陈述这些答案."
        "如果不能确定答案,请如实说明理由,不要凭空编造。"
    )
    sections = [
        "已知资料如下:",
        f"{context}",
        "",  # blank line between the material and the question
        f"用户提问:{query}",
        instruction,
    ]
    return "\n".join(sections)


def answer_fn(query, history=None):
    """Answer a BangDream question with retrieval-augmented generation.

    Steps:

    1. Encode the query and fetch 30 candidates from ``collection``.
    2. Rerank them with the cross-encoder (``top_n=10``).
    3. Pass them through ``dedup_by_chapter_event`` (``max_per_group=1``).
    4. Expand the first 3 with ``expand_with_neighbors``.
    5. Ask the chat model to answer, using every resulting chunk as context.

    Args:
        query: The user's question. An empty or whitespace-only query
            short-circuits without any embedding or LLM call.
        history: Chat history. Accepted for compatibility with Gradio chat
            callbacks; not used.

    Returns:
        A ``(answer, references)`` tuple of strings:
        - ``answer`` is the model's reply, stripped.
        - ``references`` lists each retrieved chunk (first 300 characters)
          with its chapter, event and score.
        Both are empty strings for an empty query.
    """
    if not query or not query.strip():
        # Nothing to retrieve for; avoid paying for an embedding and an LLM call.
        return "", ""

    query_vec = encode_query(encoder, query)
    results = retrieve_docs(collection, query_vec, top_k=30)
    reranked = query_rerank(reranker, query, results, top_n=10)
    deduped = dedup_by_chapter_event(reranked, max_per_group=1)
    # Each entry is unpacked below as (doc, score, meta).
    expanded_results = expand_with_neighbors(deduped[:3], collection)

    # Give the model *every* retrieved chunk, not only the first one.
    # The prompt asks it to consult all provided material, and the
    # references shown to the user list all of these chunks.
    context = "\n\n---\n\n".join(doc for doc, _score, _meta in expanded_results)
    rag_prompt = build_rag_prompt(query, context)
    system_prompt = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": rag_prompt},
        ],
        temperature=0.2,
        max_tokens=512,
    )
    # The Chat Completions API allows message.content to be None
    # (e.g. on a refusal), so guard it before stripping.
    answer = (response.choices[0].message.content or "").strip()

    preview_chars = 300
    reference_parts = []
    for idx, (doc, score, meta) in enumerate(expanded_results, 1):
        meta = meta or {}  # tolerate chunks stored without metadata
        chapter = meta.get("chapterTitle", "UnknownChapter")
        event = meta.get("eventName", "UnknownEvent")
        # Only mark the preview with "..." when the chunk was actually cut.
        ellipsis = "..." if len(doc) > preview_chars else ""
        reference_parts.append(
            f"\n--- reference: {idx} (chapter: {chapter}, event: {event}, score={score:.4f}) ---\n"
            f"{doc[:preview_chars]}{ellipsis}\n"
        )
    references = "".join(reference_parts)

    return answer, references

# Gradio UI
# answer_fn returns two strings (answer, references), so the app is built from
# plain Blocks components with one output Textbox for each.
# gr.ChatInterface is not used because:
# - it has no `outputs` parameter;
# - it expects `fn` to return a single chat reply, not a tuple.
with gr.Blocks(title="Dr-Bang RAG QA") as demo:
    gr.Markdown("# Dr-Bang RAG QA\n\n基于BangDream知识库的RAG问答系统。")
    gr.Markdown("## Dr-Bang RAG Chat\n\n输入你有关BangDream的问题,邦学家会基于资料库为你检索并作答。")
    with gr.Row():
        with gr.Column():
            question_box = gr.Textbox(label="Question")
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.ClearButton(value="clear")
            gr.Examples(
                examples=[
                    ["乐奈为什么喜欢吉他?"],
                    ["LOCK和CHU²第一次见面是什么情节?"],
                    ["谁是RAS的初代成员?"],
                ],
                inputs=[question_box],
            )
        with gr.Column():
            answer_box = gr.Textbox(label="Answer", lines=6, interactive=False)
            reference_box = gr.Textbox(label="Reference", lines=8, interactive=False)

    # "clear" resets the question and both outputs.
    clear_btn.add([question_box, answer_box, reference_box])

    # Submit via the button or by pressing Enter in the question box.
    # answer_fn receives only the question; its `history` parameter
    # defaults to None.
    submit_btn.click(
        fn=answer_fn,
        inputs=question_box,
        outputs=[answer_box, reference_box],
    )
    question_box.submit(
        fn=answer_fn,
        inputs=question_box,
        outputs=[answer_box, reference_box],
    )
demo.launch()