Upload app.py
app.py (CHANGED)
@@ -131,14 +131,17 @@ def call_llm(messages: List[dict], model: str, logs: List[str]) -> dict:
     sys_txt = messages[0].get("content", "") if messages else ""
     usr_txt = messages[1].get("content", "") if len(messages) > 1 else ""
     extra_rules = "\n\n請務必只輸出單一 JSON 物件,不得包含任何 JSON 之外的文字或符號。"
-
+    print('準備 chat prompt(加上 JSON 輸出約束)')
     chat = [
         {"role": "system", "content": sys_txt},
         {"role": "user", "content": usr_txt + extra_rules}
     ]
+    print(f"user content:{usr_txt + extra_rules}")
+
     prompt = _hf_tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
 
     inputs = _hf_tok(prompt, return_tensors="pt").to(_hf_model.device)
+    print("inputs")
     with torch.no_grad():
         out_ids = _hf_model.generate(
             **inputs,
@@ -148,10 +151,13 @@ def call_llm(messages: List[dict], model: str, logs: List[str]) -> dict:
             eos_token_id=_hf_tok.eos_token_id,
             pad_token_id=_hf_tok.eos_token_id
         )
+    print("torch.no_grad")
     full = _hf_tok.decode(out_ids[0], skip_special_tokens=True)
     gen = full[len(prompt):] if full.startswith(prompt) else full
+    print("gen")
     logs.append(f"[LOCAL LLM] Gen chars={len(gen)}")
-
+    print(gen)
+    logs.append(gen)
     # 嘗試解析 JSON
     try:
         data = json.loads(gen)

Translations of the Chinese strings above: extra_rules appends "Please output only a single JSON object; do not include any text or symbols outside the JSON."; the first added print reports "Preparing the chat prompt (with the JSON output constraint added)"; and the comment before the try block reads "Attempt to parse the JSON".
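The parse step still feeds the raw generation straight into json.loads(gen), so output that wraps the object in Markdown fences or stray prose raises JSONDecodeError even when a valid JSON object is present. Below is a minimal fallback sketch under that assumption; the extract_json helper is hypothetical and not part of this commit:

import json
import re
from typing import Optional

def extract_json(gen: str) -> Optional[dict]:
    """Parse strictly first; otherwise recover the first JSON object in the text."""
    try:
        return json.loads(gen)
    except json.JSONDecodeError:
        pass
    # Strip Markdown code fences such as ```json ... ``` that models often add.
    cleaned = re.sub(r"```(?:json)?", "", gen)
    # raw_decode() parses one JSON value starting at a given index, so we can
    # try each "{" until one yields a complete, balanced object.
    decoder = json.JSONDecoder()
    start = cleaned.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(cleaned, start)
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass
        start = cleaned.find("{", start + 1)
    return None

In call_llm this would slot in where the try/except around json.loads(gen) sits; the caller would then treat a None return as a parse failure instead of catching JSONDecodeError directly.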