Cudd1es committed on
Commit
e20c58d
·
1 Parent(s): f7d6f25

fixed app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -6
app.py CHANGED
@@ -7,17 +7,35 @@ from retriever import (
7
  )
8
  from sentence_transformers import CrossEncoder
9
 
10
- api_key = os.getenv("OPENROUTER_API_KEY")
11
- client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
 
 
 
 
 
 
 
12
 
13
  collection = load_collection()
14
  encoder = load_encoder()
15
  reranker = CrossEncoder("BAAI/bge-reranker-large")
16
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def build_rag_prompt(query, context, system_message):
19
  prompt = f"""{system_message}
20
-
21
  已知资料如下:
22
  {context}
23
 
@@ -49,10 +67,24 @@ def respond(
49
  system_msg = (system_message or default_system_message).strip()
50
  chat_history = [{"role": "system", "content": system_msg}]
51
 
 
 
 
 
 
 
 
52
  query_vec = encode_query(encoder, message)
53
- results = retrieve_docs(collection, query_vec, top_k=30)
54
  reranked = query_rerank(reranker, message, results, top_n=10)
55
- deduped = dedup_by_chapter_event(reranked, max_per_group=1)
 
 
 
 
 
 
 
56
  expanded_results = expand_with_neighbors(deduped[:3], collection)
57
  context = expanded_results[0][0] if expanded_results else ""
58
 
@@ -64,7 +96,7 @@ def respond(
64
 
65
  response = ""
66
  stream = client.chat.completions.create(
67
- model="qwen/qwen3-235b-a22b:free",
68
  messages=messages,
69
  temperature=temperature,
70
  max_tokens=max_tokens,
 
7
  )
8
  from sentence_transformers import CrossEncoder
9
 
10
# NOTE(review): QWEN_MODEL appears unused in this chunk after the switch to
# OpenAI below — only the commented-out OpenRouter setup referenced Qwen.
QWEN_MODEL="qwen/qwen3-235b-a22b:free"
# Previous OpenRouter-based client, kept for reference:
#api_key = os.getenv("OPENROUTER_API_KEY")
#client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)

# OpenAI client setup
# NOTE(review): no fallback if OPENAI_API_KEY is unset — the client will
# fail at call time rather than at startup; confirm this is acceptable.
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=api_key)
# Default chat model; used by reformulate_query and by the answer stream.
GPT_MODEL="gpt-4o"


# Load retrieval components at import time (downloads models on first run).
collection = load_collection()  # presumably the vector store of KB documents — verify in retriever.py
encoder = load_encoder()        # query embedding model used by encode_query
reranker = CrossEncoder("BAAI/bge-reranker-large")  # cross-encoder for reranking retrieved docs
23
 
24
def reformulate_query(user_question, model_name=GPT_MODEL):
    """Rewrite a user question into a retrieval-friendly search query.

    Asks the chat model to expand/paraphrase the question with synonyms and
    alternative phrasings so the semantic search over the knowledge base
    recalls more relevant passages.

    Args:
        user_question: The raw question typed by the user.
        model_name: Chat model identifier; defaults to ``GPT_MODEL``.

    Returns:
        The reformulated query text with surrounding whitespace stripped.
    """
    rewrite_instruction = f"""你是一个BangDream知识检索助手。请把用户的问题扩写或转写为适合知识库语义检索的检索语句,涵盖所有可能的提问方式或同义关键词。
用户问题:{user_question}
"""
    completion = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": rewrite_instruction}],
        temperature=0.1,
        max_tokens=4096,
    )
    return completion.choices[0].message.content.strip()
35
 
36
  def build_rag_prompt(query, context, system_message):
37
  prompt = f"""{system_message}
38
+ 你将获得多个独立的资料片段,请充分查阅每一条资料.
39
  已知资料如下:
40
  {context}
41
 
 
67
  system_msg = (system_message or default_system_message).strip()
68
  chat_history = [{"role": "system", "content": system_msg}]
69
 
70
+ # reformulate query
71
+ print("Reformulating...")
72
+ reformulated_query_text = reformulate_query(message)
73
+ print(f"[DEBUG] reformulated query: {reformulated_query_text}")
74
+
75
+ print("Thinking...\n...")
76
+ # rerank original query
77
  query_vec = encode_query(encoder, message)
78
+ results = retrieve_docs(collection, query_vec, top_k=20)
79
  reranked = query_rerank(reranker, message, results, top_n=10)
80
+
81
+ # rerank reformulated query
82
+ reformulated_query_vec = encode_query(encoder, reformulated_query_text)
83
+ reformulated_results = retrieve_docs(collection, reformulated_query_vec, top_k=20)
84
+ reformulated_reranked = query_rerank(reranker, reformulated_query_text, reformulated_results, top_n=10)
85
+
86
+ total_reranked = reranked + reformulated_reranked
87
+ deduped = dedup_by_chapter_event(total_reranked, max_per_group=1)
88
  expanded_results = expand_with_neighbors(deduped[:3], collection)
89
  context = expanded_results[0][0] if expanded_results else ""
90
 
 
96
 
97
  response = ""
98
  stream = client.chat.completions.create(
99
+ model=GPT_MODEL,
100
  messages=messages,
101
  temperature=temperature,
102
  max_tokens=max_tokens,