Cudd1es committed on
Commit
c881ef3
·
1 Parent(s): d6922e3

fixed app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -40
app.py CHANGED
@@ -14,8 +14,11 @@ collection = load_collection()
14
  encoder = load_encoder()
15
  reranker = CrossEncoder("BAAI/bge-reranker-large")
16
 
17
- def build_rag_prompt(query, context):
18
- prompt = f"""已知资料如下:
 
 
 
19
  {context}
20
 
21
  用户提问:{query}
@@ -23,53 +26,86 @@ def build_rag_prompt(query, context):
23
  return prompt
24
 
25
 
26
- def answer_fn(query, history=None):
27
- query_vec = encode_query(encoder, query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  results = retrieve_docs(collection, query_vec, top_k=30)
29
- reranked = query_rerank(reranker, query, results, top_n=10)
30
  deduped = dedup_by_chapter_event(reranked, max_per_group=1)
31
  expanded_results = expand_with_neighbors(deduped[:3], collection)
32
-
33
  context = expanded_results[0][0] if expanded_results else ""
34
- rag_prompt = build_rag_prompt(query, context)
35
- system_prompt = "你是BangDream知识问答助手, 也就是邦学家. 只能基于提供的资料内容作答。"
36
 
37
- response = client.chat.completions.create(
 
 
 
 
 
 
 
38
  model="gpt-4o",
39
- messages=[
40
- {"role": "system", "content": system_prompt},
41
- {"role": "user", "content": rag_prompt}
42
- ],
43
- temperature=0.2,
44
- max_tokens=512,
45
  )
46
- answer = response.choices[0].message.content.strip()
 
 
 
 
47
 
48
- references = ""
49
- for idx, (doc, score, meta) in enumerate(expanded_results, 1):
50
- chapter = meta.get("chapterTitle", "UnknownChapter")
51
- event = meta.get("eventName", "UnknownEvent")
52
- references += f"\n--- reference: {idx} (chapter: {chapter}, event: {event}, score={score:.4f}) ---\n"
53
- references += doc[:300] + "...\n"
54
 
55
- return answer, references
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- # Gradio UI
58
  with gr.Blocks(title="Dr-Bang RAG QA") as demo:
59
- gr.Markdown("# Dr-Bang RAG QA\n\n基于BangDream知识库的RAG问答系统。")
60
- with gr.Row():
61
- chatbot = gr.ChatInterface(
62
- fn=answer_fn,
63
- title="Dr-Bang RAG Chat",
64
- description="输入你的BangDream问题,AI助手会基于资料库为你检索并作答。",
65
- examples=[
66
- ["乐奈为什么喜欢吉他?"],
67
- ["LOCK和CHU²第一次见面是什么情节?"],
68
- ["谁是RAS的初代成员?"],
69
- ],
70
- outputs=[
71
- gr.Textbox(label="Answer", lines=6, interactive=False),
72
- gr.Textbox(label="Reference", lines=8, interactive=False)
73
- ]
74
  )
75
- demo.launch()
 
 
 
 
 
14
  encoder = load_encoder()
15
  reranker = CrossEncoder("BAAI/bge-reranker-large")
16
 
17
+
18
+ def build_rag_prompt(query, context, system_message):
19
+ prompt = f"""{system_message}
20
+
21
+ 已知资料如下:
22
  {context}
23
 
24
  用户提问:{query}
 
26
  return prompt
27
 
28
 
29
+ def respond(
30
+ message,
31
+ history: list[dict[str, str]],
32
+ system_message,
33
+ max_tokens,
34
+ temperature,
35
+ top_p,
36
+ ):
37
+ """
38
+ message: 当前输入内容
39
+ history: [{"role": "user", "content": ...}, {"role": "assistant", "content": ...}, ...]
40
+ system_message: 自定义 System Prompt
41
+ """
42
+ chat_history = [
43
+ {"role": "system", "content": system_message.strip() or "你是BangDream知识问答助手, 只能基于提供资料作答。"}
44
+ ]
45
+ chat_history.extend(history)
46
+ chat_history.append({"role": "user", "content": message})
47
+
48
+ query_vec = encode_query(encoder, message)
49
  results = retrieve_docs(collection, query_vec, top_k=30)
50
+ reranked = query_rerank(reranker, message, results, top_n=10)
51
  deduped = dedup_by_chapter_event(reranked, max_per_group=1)
52
  expanded_results = expand_with_neighbors(deduped[:3], collection)
 
53
  context = expanded_results[0][0] if expanded_results else ""
 
 
54
 
55
+ rag_prompt = build_rag_prompt(message, context, system_message)
56
+ messages = [
57
+ {"role": "system", "content": system_message.strip() or "你是BangDream知识问答助手, 只能基于提供资料作答。"},
58
+ {"role": "user", "content": rag_prompt}
59
+ ]
60
+
61
+ response = ""
62
+ stream = client.chat.completions.create(
63
  model="gpt-4o",
64
+ messages=messages,
65
+ temperature=temperature,
66
+ max_tokens=max_tokens,
67
+ top_p=top_p,
68
+ stream=True
 
69
  )
70
+ for chunk in stream:
71
+ delta = getattr(chunk.choices[0].delta, "content", None)
72
+ if delta:
73
+ response += delta
74
+ yield response
75
 
 
 
 
 
 
 
76
 
77
+ # ========== Gradio ChatInterface with extra sidebar inputs ==========
78
+ chatbot = gr.ChatInterface(
79
+ respond,
80
+ type="messages",
81
+ additional_inputs=[
82
+ gr.Textbox(value="你是BangDream知识问答助手, 只能基于提供资料内容作答。", label="System message"),
83
+ gr.Slider(minimum=64, maximum=1024, value=512, step=1, label="Max new tokens"),
84
+ gr.Slider(minimum=0.1, maximum=2.0, value=0.2, step=0.05, label="Temperature"),
85
+ gr.Slider(
86
+ minimum=0.1,
87
+ maximum=1.0,
88
+ value=0.95,
89
+ step=0.05,
90
+ label="Top-p (nucleus sampling)",
91
+ ),
92
+ ],
93
+ examples=[
94
+ ["乐奈为什么喜欢吉他?"],
95
+ ["LOCK和CHU²第一次见面是什么情节?"],
96
+ ["谁是RAS的初代成员?"],
97
+ ],
98
+ description="输入你关于BangDream的问题,邦学家会基于资料库为你检索并作答",
99
+ title="Dr-Bang RAG QA Chatbot"
100
+ )
101
 
 
102
  with gr.Blocks(title="Dr-Bang RAG QA") as demo:
103
+ with gr.Sidebar():
104
+ gr.Markdown(
105
+ "## Dr-Bang QA\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
106
  )
107
+ gr.LoginButton()
108
+ chatbot.render()
109
+
110
+ if __name__ == "__main__":
111
+ demo.launch()