Cyantist8208 committed on
Commit
c0c0f5a
·
1 Parent(s): 2102b2f
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -116,19 +116,12 @@ def build_llm_prompt(system: str, context: list[str], user_question: str) -> str
116
  conversation.append({"role": "user", "content": user_question.strip()})
117
 
118
  # 套用 LLaMA-style prompt 格式
119
- input_token = tokenizer.apply_chat_template(
120
  conversation,
121
- add_generation_prompt=True,
122
- return_tensors="pt"
123
  )
124
 
125
- terminators = [
126
- tokenizer.eos_token_id,
127
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
128
- ]
129
-
130
- return input_token, terminators
131
-
132
  # ---------- 4. Gradio playground (same UI as before) --------------------------
133
  def store_doc(doc_text: str,user_id="demo",chunk_size=DEFAULT_CHUNK_SIZE,chunk_overlap=DEFAULT_CHUNK_OVERLAP):
134
  try:
@@ -171,13 +164,25 @@ def answer(system: str, context: str, question: str,
171
  context_list += store["texts"]
172
 
173
  # 2. Build a Qwen-chat prompt (helper defined earlier)
174
- input_ids, terminators = build_llm_prompt(system, context_list, question)
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  # --- generate ------------------------------------------------------
177
  output = chat_model.generate(
178
- input_ids,
179
  max_new_tokens=512,
180
- eos_token_id=terminators,
181
  max_length=MAX_PROMPT_TOKENS + 512,
182
  do_sample=True,
183
  temperature=temperature,
@@ -185,7 +190,7 @@ def answer(system: str, context: str, question: str,
185
  top_k=top_k_tok
186
  )
187
  full = tokenizer.decode(output[0], skip_special_tokens=True)
188
- reply = full.split("<|im_start|>assistant")[-1].strip()
189
  return reply
190
 
191
  except Exception as e:
 
116
  conversation.append({"role": "user", "content": user_question.strip()})
117
 
118
  # 套用 LLaMA-style prompt 格式
119
+ return tokenizer.apply_chat_template(
120
  conversation,
121
+ tokenize=False,
122
+ add_generation_prompt=False
123
  )
124
 
 
 
 
 
 
 
 
125
  # ---------- 4. Gradio playground (same UI as before) --------------------------
126
  def store_doc(doc_text: str,user_id="demo",chunk_size=DEFAULT_CHUNK_SIZE,chunk_overlap=DEFAULT_CHUNK_OVERLAP):
127
  try:
 
164
  context_list += store["texts"]
165
 
166
  # 2. Build a Qwen-chat prompt (helper defined earlier)
167
+ prompt = build_llm_prompt(system, context_list, question)
168
+
169
+ # 3. Tokenise & cap
170
+ load_chat()
171
+ tokens = tokenizer(
172
+ prompt,
173
+ return_tensors="pt",
174
+ add_special_tokens=False, # we built the chat template ourselves
175
+ )
176
+
177
+ if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS:
178
+ tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
179
+
180
+ tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}
181
 
182
  # --- generate ------------------------------------------------------
183
  output = chat_model.generate(
184
+ **tokens,
185
  max_new_tokens=512,
 
186
  max_length=MAX_PROMPT_TOKENS + 512,
187
  do_sample=True,
188
  temperature=temperature,
 
190
  top_k=top_k_tok
191
  )
192
  full = tokenizer.decode(output[0], skip_special_tokens=True)
193
+ reply = full.split("<|im_start|>assistant")[-1].strip() + tokenizer.chat_template
194
  return reply
195
 
196
  except Exception as e: