Cyantist8208 committed on
Commit 2102b2f · Parent: ce2f7c2
Files changed (1)
  1. app.py +12 -26
app.py CHANGED
@@ -15,16 +15,9 @@ from transformers import (
 import torch.nn.functional as F
 from collections import defaultdict
 HF_TOKEN = os.getenv("HF_token")
-CHAT_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
+CHAT_MODEL_ID = "NousResearch/Meta-Llama-3-8B-Instruct"
 EMB_MODEL_ID = "mixedbread-ai/mxbai-embed-large-v1"
 MAX_PROMPT_TOKENS = 8192
-import transformers
-pipeline = transformers.pipeline(
-    "text-generation",
-    model=CHAT_MODEL_ID,
-    model_kwargs={"torch_dtype": torch.bfloat16},
-    device_map="auto",
-)
 
 # ---------- new defaults & helper ------------------
 DEFAULT_TEMP = 0.7
@@ -125,11 +118,16 @@ def build_llm_prompt(system: str, context: list[str], user_question: str) -> str
     # Apply the LLaMA-style prompt format
     input_token = tokenizer.apply_chat_template(
         conversation,
-        tokenize=False,
-        add_generation_prompt=False
+        add_generation_prompt=True,
+        return_tensors="pt"
     )
 
-    return tokenizer.decode(input_token)
+    terminators = [
+        tokenizer.eos_token_id,
+        tokenizer.convert_tokens_to_ids("<|eot_id|>")
+    ]
+
+    return input_token, terminators
 
 # ---------- 4. Gradio playground (same UI as before) --------------------------
 def store_doc(doc_text: str,user_id="demo",chunk_size=DEFAULT_CHUNK_SIZE,chunk_overlap=DEFAULT_CHUNK_OVERLAP):
@@ -173,25 +171,13 @@ def answer(system: str, context: str, question: str,
     context_list += store["texts"]
 
     # 2. Build a Qwen-chat prompt (helper defined earlier)
-    prompt = build_llm_prompt(system, context_list, question)
-
-    # 3. Tokenise & cap
-    load_chat()
-    tokens = tokenizer(
-        prompt,
-        return_tensors="pt",
-        add_special_tokens=False,  # we built the chat template ourselves
-    )
-
-    if tokens["input_ids"].size(1) > MAX_PROMPT_TOKENS:
-        tokens = {k: v[:, -MAX_PROMPT_TOKENS:] for k, v in tokens.items()}
-
-    tokens = {k: v.to(chat_model.device) for k, v in tokens.items()}
+    input_ids, terminators = build_llm_prompt(system, context_list, question)
 
     # --- generate ------------------------------------------------------
     output = chat_model.generate(
-        **tokens,
+        input_ids,
         max_new_tokens=512,
+        eos_token_id=terminators,
         max_length=MAX_PROMPT_TOKENS + 512,
         do_sample=True,
         temperature=temperature,
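For context on the first hunk: the commit points CHAT_MODEL_ID at NousResearch/Meta-Llama-3-8B-Instruct and drops the transformers.pipeline(...) that was built eagerly at import time. The diff never shows load_chat() (it is only referenced in the removed lines of the last hunk), so the loader below is just a hedged sketch of a lazy replacement that keeps the removed pipeline's settings; the function body, the module-level globals, and the token= keyword are assumptions, not code from this commit.

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

HF_TOKEN = os.getenv("HF_token")
CHAT_MODEL_ID = "NousResearch/Meta-Llama-3-8B-Instruct"

tokenizer = None
chat_model = None

def load_chat():
    # Assumed lazy loader: load the chat tokenizer/model once, on first use.
    global tokenizer, chat_model
    if chat_model is not None:
        return
    tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_ID, token=HF_TOKEN)
    chat_model = AutoModelForCausalLM.from_pretrained(
        CHAT_MODEL_ID,
        torch_dtype=torch.bfloat16,  # same dtype the removed pipeline requested
        device_map="auto",           # same placement the removed pipeline used
        token=HF_TOKEN,
    )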
 
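The second and third hunks change the contract of build_llm_prompt: instead of returning a decoded prompt string that answer() re-tokenises and truncates, the helper now returns the tokenised chat-template tensor together with the Llama 3 terminator ids, and chat_model.generate() consumes both directly. Below is a minimal end-to-end sketch of that flow, assuming tokenizer and chat_model are the globals loaded above; the conversation layout, the sample inputs, and the final decode step are assumptions, since the diff only shows the tail of the real helper.

def build_llm_prompt(system: str, context: list[str], user_question: str):
    # Assumed message layout: retrieved chunks folded into the user turn.
    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": "\n\n".join(context + [user_question])},
    ]
    # Apply the Llama-style chat template, already tokenised for generate().
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Llama 3 ends an assistant turn with <|eot_id|>, so stop on it as well.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    return input_ids, terminators

# Hypothetical inputs standing in for the Gradio fields passed to answer().
system = "Answer using only the provided context."
context_list = ["Paris is the capital of France."]
question = "What is the capital of France?"
temperature = 0.7

load_chat()  # assumed lazy loader, sketched above
input_ids, terminators = build_llm_prompt(system, context_list, question)
output = chat_model.generate(
    input_ids.to(chat_model.device),
    max_new_tokens=512,
    eos_token_id=terminators,
    do_sample=True,
    temperature=temperature,
)
# Decode only the newly generated tokens, not the echoed prompt.
reply = tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
print(reply)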