Cyantist8208 committed on
Commit
acbd024
·
1 Parent(s): e07811d

parse context

Browse files
Files changed (1) hide show
  1. app.py +28 -19
app.py CHANGED
@@ -76,17 +76,22 @@ def add_docs(user_id: str, docs: list[str]) -> int:
76
  )
77
  return len(docs)
78
  # ----- Qwen-chat prompt helper ---------------------------------------------
79
- def build_qwen_prompt(context: str, user_question: str) -> str:
80
- """Return a string that follows Qwen-Chat’s template."""
81
- load_chat() # make sure tokenizer is ready
82
- if context != "":
83
- context = f"Context:\n{context}\n\n"
84
  conversation = [
85
- {"role": "system",
86
- "content": "You are an email assistant. Use ONLY the context provided."},
87
- {"role": "user",
88
- "content": f"{context}{user_question}"}
89
  ]
 
 
 
 
 
 
 
 
 
90
  return tokenizer.apply_chat_template(
91
  conversation, tokenize=False, add_generation_prompt=True
92
  )
@@ -103,7 +108,7 @@ def store_doc(doc_text: str, user_id="demo"):
103
  except Exception as e:
104
  return f"Error during storing: {e}"
105
 
106
- def answer(question: str, user_id="demo", history="None"):
107
  """UI callback: retrieve, build prompt with Qwen tags, generate answer."""
108
  try:
109
  if not question.strip():
@@ -111,7 +116,6 @@ def answer(question: str, user_id="demo", history="None"):
111
  if history != "None" and not kb[user_id]["texts"]:
112
  return "No reference passage yet. Add one first."
113
 
114
- context = ""
115
  # 1. Retrieve top-k similar passages
116
  if history == "Some":
117
  q_vec = embed(question)
@@ -119,13 +123,13 @@ def answer(question: str, user_id="demo", history="None"):
119
  sims = torch.matmul(store["vecs"], q_vec) # [N]
120
  k = min(4, sims.numel())
121
  idxs = torch.topk(sims, k=k).indices.tolist()
122
- context = "\n".join(store["texts"][i] for i in idxs)
123
  elif history == "All":
124
  store = kb[user_id]
125
- context = "\n".join(store["texts"])
126
 
127
  # 2. Build a Qwen-chat prompt (helper defined earlier)
128
- prompt = build_qwen_prompt(history + context, question)
129
 
130
  # 3. Generate and strip everything before the assistant tag
131
  load_chat()
@@ -168,16 +172,21 @@ with gr.Blocks() as demo:
168
 
169
  # ---- Q & A ----
170
  question_box = gr.Textbox(lines=2, label="Ask a question")
171
- history_cb = gr.Textbox(value="None", label="Use chat history")
 
 
172
 
173
  answer_btn = gr.Button("Answer")
174
  answer_box = gr.Textbox(lines=6, label="Assistant reply")
175
 
176
  answer_btn.click(
177
- fn=answer,
178
- inputs=[question_box, user_id_box, history_cb],
179
- outputs=answer_box
180
- )
 
 
 
181
 
182
  # ---------- 3. FastAPI layer --------------------------------------------------
183
  class IngestReq(BaseModel):
 
76
  )
77
  return len(docs)
78
  # ----- Qwen-chat prompt helper ---------------------------------------------
79
def build_qwen_prompt(system: str, context: list[str], user_question: str) -> str:
    """Build a serialized Qwen chat prompt from system text, context passages, and a question.

    The system text becomes the single system turn; every non-blank
    context passage is appended as its own user turn, and the question
    is appended as the final user turn. The conversation is then
    rendered with the tokenizer's chat template (untokenized, with the
    generation prompt appended).
    """
    load_chat()  # make sure the tokenizer is loaded before templating

    # System turn first, then one user turn per non-empty context passage.
    turns = [{"role": "system", "content": system}]
    turns.extend(
        {"role": "user", "content": passage}
        for passage in context
        if passage.strip()  # skip blank passages
    )

    # The actual question is always the last user turn.
    turns.append({"role": "user", "content": user_question})

    return tokenizer.apply_chat_template(
        turns, tokenize=False, add_generation_prompt=True
    )
 
108
  except Exception as e:
109
  return f"Error during storing: {e}"
110
 
111
+ def answer(system: str, context: list[str], question: str, user_id="demo", history="None"):
112
  """UI callback: retrieve, build prompt with Qwen tags, generate answer."""
113
  try:
114
  if not question.strip():
 
116
  if history != "None" and not kb[user_id]["texts"]:
117
  return "No reference passage yet. Add one first."
118
 
 
119
  # 1. Retrieve top-k similar passages
120
  if history == "Some":
121
  q_vec = embed(question)
 
123
  sims = torch.matmul(store["vecs"], q_vec) # [N]
124
  k = min(4, sims.numel())
125
  idxs = torch.topk(sims, k=k).indices.tolist()
126
+ context += [store["texts"][i] for i in idxs]
127
  elif history == "All":
128
  store = kb[user_id]
129
+ context += store["texts"]
130
 
131
  # 2. Build a Qwen-chat prompt (helper defined earlier)
132
+ prompt = build_qwen_prompt(system, context, question)
133
 
134
  # 3. Generate and strip everything before the assistant tag
135
  load_chat()
 
172
 
173
  # ---- Q & A ----
174
  question_box = gr.Textbox(lines=2, label="Ask a question")
175
+ history_cb = gr.Textbox(value="None", label="Use chat history")
176
+ system_box = gr.Textbox(lines=2, label="System prompt")
177
+ context_box = gr.Textbox(lines=6, label="Context passages (each line one passage)")
178
 
179
  answer_btn = gr.Button("Answer")
180
  answer_box = gr.Textbox(lines=6, label="Assistant reply")
181
 
182
  answer_btn.click(
183
+ fn=lambda sys, ctx, q, uid, h: answer(sys,
184
+ [line.strip() for line in ctx.splitlines() if line.strip()],
185
+ q, uid, h
186
+ ),
187
+ inputs=[system_box, context_box, question_box, user_id_box, history_cb],
188
+ outputs=answer_box
189
+ )
190
 
191
  # ---------- 3. FastAPI layer --------------------------------------------------
192
  class IngestReq(BaseModel):