han145 commited on
Commit
5643b00
·
verified ·
1 Parent(s): 174f98d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -6
app.py CHANGED
@@ -105,9 +105,21 @@ def generate_completion(prompt, max_tokens=256, temperature=0.7):
105
  # 构建提示词 - 使用Qwen模型的对话格式
106
  text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
107
 
108
- # 编码输入
109
- inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
110
- inputs = inputs.to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  # 生成响应
113
  with torch.no_grad():
@@ -121,7 +133,7 @@ def generate_completion(prompt, max_tokens=256, temperature=0.7):
121
  )
122
 
123
  # 解码响应
124
- response = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
125
 
126
  # 清理特殊标记
127
  response = response.replace("<|im_end|>", "").strip()
@@ -264,10 +276,22 @@ async def create_chat_completion(
264
  max_tokens = data.get("max_tokens", MAX_TOKENS)
265
  temperature = data.get("temperature", 0.7)
266
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  # 从消息中提取用户提示
268
  user_message = ""
269
  for msg in messages:
270
- if msg.get("role") == "user":
271
  user_message = msg.get("content", "")
272
  break
273
 
@@ -343,10 +367,19 @@ async def openclaw_chat_api(
343
  data = await request.json()
344
  messages = data.get("messages", [])
345
 
 
 
 
 
 
 
 
 
 
346
  # 提取用户消息
347
  user_message = ""
348
  for msg in messages:
349
- if msg.get("role") == "user":
350
  user_message = msg.get("content", "")
351
  break
352
 
 
105
  # 构建提示词 - 使用Qwen模型的对话格式
106
  text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
107
 
108
+ # 记录文本类型和长度
109
+ logger.info(f"输入文本类型: {type(text)}, 长度: {len(text)}")
110
+
111
+ # 编码输入 - 使用更安全的编码方式
112
+ encoding = tokenizer.encode_plus(
113
+ text,
114
+ return_tensors="pt",
115
+ truncation=True,
116
+ max_length=1024,
117
+ padding="max_length" if tokenizer.pad_token_id is not None else False
118
+ )
119
+ inputs = {
120
+ "input_ids": encoding["input_ids"].to(DEVICE),
121
+ "attention_mask": encoding["attention_mask"].to(DEVICE)
122
+ }
123
 
124
  # 生成响应
125
  with torch.no_grad():
 
133
  )
134
 
135
  # 解码响应
136
+ response = tokenizer.decode(outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
137
 
138
  # 清理特殊标记
139
  response = response.replace("<|im_end|>", "").strip()
 
276
  max_tokens = data.get("max_tokens", MAX_TOKENS)
277
  temperature = data.get("temperature", 0.7)
278
 
279
+ # 检查消息格式
280
+ if not messages or not isinstance(messages, list):
281
+ return JSONResponse(
282
+ status_code=400,
283
+ content={
284
+ "error": {
285
+ "message": "无效的消息格式",
286
+ "type": "invalid_request_error"
287
+ }
288
+ }
289
+ )
290
+
291
  # 从消息中提取用户提示
292
  user_message = ""
293
  for msg in messages:
294
+ if isinstance(msg, dict) and msg.get("role") == "user":
295
  user_message = msg.get("content", "")
296
  break
297
 
 
367
  data = await request.json()
368
  messages = data.get("messages", [])
369
 
370
+ # 检查消息格式
371
+ if not messages or not isinstance(messages, list):
372
+ return JSONResponse(
373
+ status_code=400,
374
+ content={
375
+ "error": "无效的消息格式"
376
+ }
377
+ )
378
+
379
  # 提取用户消息
380
  user_message = ""
381
  for msg in messages:
382
+ if isinstance(msg, dict) and msg.get("role") == "user":
383
  user_message = msg.get("content", "")
384
  break
385