ray-lei committed on
Commit
95e44f0
·
verified ·
1 Parent(s): 9cc89b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -4
app.py CHANGED
@@ -302,14 +302,20 @@ async def health_check():
302
  "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
303
  }
304
 
305
- # Anthropic Claude 格式的请求体
 
 
 
 
 
306
  class MessagesRequest(BaseModel):
307
  model: str
308
- messages: List[Message]
309
  max_tokens: Optional[int] = 2048
310
  temperature: Optional[float] = 0.7
311
  top_p: Optional[float] = 0.9
312
 
 
313
  @app.post("/v1/messages")
314
  async def messages_endpoint(request: MessagesRequest):
315
  """
@@ -322,8 +328,25 @@ async def messages_endpoint(request: MessagesRequest):
322
  except:
323
  raise HTTPException(status_code=503, detail="Model not loaded")
324
 
325
- # 使用已有的格式化和生逻辑
326
- prompt = format_messages(request.messages)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  response_text = generate_response(
328
  prompt,
329
  request.temperature,
 
302
  "device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
303
  }
304
 
305
from typing import Any

class AnthropicMessage(BaseModel):
    """One message in the Anthropic Messages API request format.

    ``content`` is deliberately typed as ``Any`` rather than ``str``:
    the Claude CLI sends a list of content blocks
    (e.g. ``{"type": "text", "text": "..."}``), while plain clients may
    send a bare string. The endpoint normalizes both forms downstream.
    """
    role: str     # "user" / "assistant" (not validated here — TODO confirm accepted values)
    content: Any  # str, or a list of content-block dicts from the Claude CLI
310
+
311
class MessagesRequest(BaseModel):
    """Request body for the Anthropic-compatible ``/v1/messages`` endpoint."""
    model: str                            # requested model identifier
    messages: List[AnthropicMessage]      # conversation turns; content may be str or a block list
    max_tokens: Optional[int] = 2048      # cap on generated tokens
    temperature: Optional[float] = 0.7    # sampling temperature
    top_p: Optional[float] = 0.9          # nucleus-sampling threshold
317
 
318
+
319
  @app.post("/v1/messages")
320
  async def messages_endpoint(request: MessagesRequest):
321
  """
 
328
  except:
329
  raise HTTPException(status_code=503, detail="Model not loaded")
330
 
331
+ # 把 content 数组拼接纯文本
332
+ converted_messages = []
333
+ for msg in request.messages:
334
+ if isinstance(msg.content, list):
335
+ # 把每个 {"type":"text","text":"..."} 拼接
336
+ texts = []
337
+ for block in msg.content:
338
+ if isinstance(block, dict) and "text" in block:
339
+ texts.append(block["text"])
340
+ merged = "\n".join(texts)
341
+ elif isinstance(msg.content, str):
342
+ merged = msg.content
343
+ else:
344
+ merged = str(msg.content)
345
+
346
+ converted_messages.append(Message(role=msg.role, content=merged))
347
+
348
+ # 使用原本的格式化和生成逻辑
349
+ prompt = format_messages(converted_messages)
350
  response_text = generate_response(
351
  prompt,
352
  request.temperature,