howard9963 committed on
Commit
22db8a4
·
verified ·
1 Parent(s): ee74030

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -2
app.py CHANGED
@@ -131,14 +131,17 @@ def call_llm(messages: List[dict], model: str, logs: List[str]) -> dict:
131
  sys_txt = messages[0].get("content", "") if messages else ""
132
  usr_txt = messages[1].get("content", "") if len(messages) > 1 else ""
133
  extra_rules = "\n\n請務必只輸出單一 JSON 物件,不得包含任何 JSON 之外的文字或符號。"
134
-
135
  chat = [
136
  {"role": "system", "content": sys_txt},
137
  {"role": "user", "content": usr_txt + extra_rules}
138
  ]
 
 
139
  prompt = _hf_tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
140
 
141
  inputs = _hf_tok(prompt, return_tensors="pt").to(_hf_model.device)
 
142
  with torch.no_grad():
143
  out_ids = _hf_model.generate(
144
  **inputs,
@@ -148,10 +151,13 @@ def call_llm(messages: List[dict], model: str, logs: List[str]) -> dict:
148
  eos_token_id=_hf_tok.eos_token_id,
149
  pad_token_id=_hf_tok.eos_token_id
150
  )
 
151
  full = _hf_tok.decode(out_ids[0], skip_special_tokens=True)
152
  gen = full[len(prompt):] if full.startswith(prompt) else full
 
153
  logs.append(f"[LOCAL LLM] Gen chars={len(gen)}")
154
-
 
155
  # 嘗試解析 JSON
156
  try:
157
  data = json.loads(gen)
 
131
  sys_txt = messages[0].get("content", "") if messages else ""
132
  usr_txt = messages[1].get("content", "") if len(messages) > 1 else ""
133
  extra_rules = "\n\n請務必只輸出單一 JSON 物件,不得包含任何 JSON 之外的文字或符號。"
134
+ print('準備 chat prompt(加上 JSON 輸出約束)')
135
  chat = [
136
  {"role": "system", "content": sys_txt},
137
  {"role": "user", "content": usr_txt + extra_rules}
138
  ]
139
+ print(f"user content:{usr_txt + extra_rules}")
140
+
141
  prompt = _hf_tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
142
 
143
  inputs = _hf_tok(prompt, return_tensors="pt").to(_hf_model.device)
144
+ print("inputs")
145
  with torch.no_grad():
146
  out_ids = _hf_model.generate(
147
  **inputs,
 
151
  eos_token_id=_hf_tok.eos_token_id,
152
  pad_token_id=_hf_tok.eos_token_id
153
  )
154
+ print("torch.no_grad")
155
  full = _hf_tok.decode(out_ids[0], skip_special_tokens=True)
156
  gen = full[len(prompt):] if full.startswith(prompt) else full
157
+ print("gen")
158
  logs.append(f"[LOCAL LLM] Gen chars={len(gen)}")
159
+ print(gen)
160
+ logs.append(gen)
161
  # 嘗試解析 JSON
162
  try:
163
  data = json.loads(gen)