Update app.py
Browse files
app.py
CHANGED
|
@@ -105,9 +105,21 @@ def generate_completion(prompt, max_tokens=256, temperature=0.7):
|
|
| 105 |
# 构建提示词 - 使用Qwen模型的对话格式
|
| 106 |
text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
| 107 |
|
| 108 |
-
#
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
# 生成响应
|
| 113 |
with torch.no_grad():
|
|
@@ -121,7 +133,7 @@ def generate_completion(prompt, max_tokens=256, temperature=0.7):
|
|
| 121 |
)
|
| 122 |
|
| 123 |
# 解码响应
|
| 124 |
-
response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
|
| 125 |
|
| 126 |
# 清理特殊标记
|
| 127 |
response = response.replace("<|im_end|>", "").strip()
|
|
@@ -264,10 +276,22 @@ async def create_chat_completion(
|
|
| 264 |
max_tokens = data.get("max_tokens", MAX_TOKENS)
|
| 265 |
temperature = data.get("temperature", 0.7)
|
| 266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
# 从消息中提取用户提示
|
| 268 |
user_message = ""
|
| 269 |
for msg in messages:
|
| 270 |
-
if msg.get("role") == "user":
|
| 271 |
user_message = msg.get("content", "")
|
| 272 |
break
|
| 273 |
|
|
@@ -343,10 +367,19 @@ async def openclaw_chat_api(
|
|
| 343 |
data = await request.json()
|
| 344 |
messages = data.get("messages", [])
|
| 345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
# 提取用户消息
|
| 347 |
user_message = ""
|
| 348 |
for msg in messages:
|
| 349 |
-
if msg.get("role") == "user":
|
| 350 |
user_message = msg.get("content", "")
|
| 351 |
break
|
| 352 |
|
|
|
|
| 105 |
# 构建提示词 - 使用Qwen模型的对话格式
|
| 106 |
text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
| 107 |
|
| 108 |
+
# 记录文本类型和长度
|
| 109 |
+
logger.info(f"输入文本类型: {type(text)}, 长度: {len(text)}")
|
| 110 |
+
|
| 111 |
+
# 编码输入 - 使用更安全的编码方式
|
| 112 |
+
encoding = tokenizer.encode_plus(
|
| 113 |
+
text,
|
| 114 |
+
return_tensors="pt",
|
| 115 |
+
truncation=True,
|
| 116 |
+
max_length=1024,
|
| 117 |
+
padding="max_length" if tokenizer.pad_token_id is not None else False
|
| 118 |
+
)
|
| 119 |
+
inputs = {
|
| 120 |
+
"input_ids": encoding["input_ids"].to(DEVICE),
|
| 121 |
+
"attention_mask": encoding["attention_mask"].to(DEVICE)
|
| 122 |
+
}
|
| 123 |
|
| 124 |
# 生成响应
|
| 125 |
with torch.no_grad():
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
# 解码响应
|
| 136 |
+
response = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
|
| 137 |
|
| 138 |
# 清理特殊标记
|
| 139 |
response = response.replace("<|im_end|>", "").strip()
|
|
|
|
| 276 |
max_tokens = data.get("max_tokens", MAX_TOKENS)
|
| 277 |
temperature = data.get("temperature", 0.7)
|
| 278 |
|
| 279 |
+
# 检查消息格式
|
| 280 |
+
if not messages or not isinstance(messages, list):
|
| 281 |
+
return JSONResponse(
|
| 282 |
+
status_code=400,
|
| 283 |
+
content={
|
| 284 |
+
"error": {
|
| 285 |
+
"message": "无效的消息格式",
|
| 286 |
+
"type": "invalid_request_error"
|
| 287 |
+
}
|
| 288 |
+
}
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
# 从消息中提取用户提示
|
| 292 |
user_message = ""
|
| 293 |
for msg in messages:
|
| 294 |
+
if isinstance(msg, dict) and msg.get("role") == "user":
|
| 295 |
user_message = msg.get("content", "")
|
| 296 |
break
|
| 297 |
|
|
|
|
| 367 |
data = await request.json()
|
| 368 |
messages = data.get("messages", [])
|
| 369 |
|
| 370 |
+
# 检查消息格式
|
| 371 |
+
if not messages or not isinstance(messages, list):
|
| 372 |
+
return JSONResponse(
|
| 373 |
+
status_code=400,
|
| 374 |
+
content={
|
| 375 |
+
"error": "无效的消息格式"
|
| 376 |
+
}
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
# 提取用户消息
|
| 380 |
user_message = ""
|
| 381 |
for msg in messages:
|
| 382 |
+
if isinstance(msg, dict) and msg.get("role") == "user":
|
| 383 |
user_message = msg.get("content", "")
|
| 384 |
break
|
| 385 |
|