nacho commited on
Commit
3be6e48
·
1 Parent(s): 9fba9a0

fix: support array content blocks from AstrBot (multi-modal message format)

Browse files
Files changed (1) hide show
  1. main.py +17 -4
main.py CHANGED
@@ -8,7 +8,7 @@ import re
8
  import time
9
  import uuid
10
  from pathlib import Path
11
- from typing import Optional
12
 
13
  LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
14
  LOG_BUFFER_SIZE = 500
@@ -75,9 +75,14 @@ manager = AccountManager(
75
  )
76
 
77
 
 
 
 
 
 
78
  class Message(BaseModel):
79
  role: str
80
- content: str
81
 
82
 
83
  class ChatCompletionRequest(BaseModel):
@@ -89,6 +94,14 @@ class ChatCompletionRequest(BaseModel):
89
  tools: Optional[list[dict]] = None
90
 
91
 
 
 
 
 
 
 
 
 
92
  def verify_api_key(authorization: Optional[str] = Header(None)) -> str:
93
  if not authorization:
94
  raise HTTPException(status_code=401, detail="Missing API key")
@@ -225,7 +238,7 @@ async def chat_completions(
225
  if not request.messages:
226
  raise HTTPException(status_code=400, detail="No messages provided")
227
 
228
- prompt = request.messages[-1].content
229
 
230
  if request.tools:
231
  tool_desc = json.dumps(request.tools, ensure_ascii=False)
@@ -505,7 +518,7 @@ async def admin_chat(request: Request, admin_key: str = Header(...)):
505
  if not req.messages:
506
  raise HTTPException(status_code=400, detail="No messages provided")
507
 
508
- prompt = req.messages[-1].content
509
 
510
  if req.tools:
511
  tool_desc = json.dumps(req.tools, ensure_ascii=False)
 
8
  import time
9
  import uuid
10
  from pathlib import Path
11
+ from typing import Optional, Union
12
 
13
  LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
14
  LOG_BUFFER_SIZE = 500
 
75
  )
76
 
77
 
78
+ class ContentBlock(BaseModel):
79
+ type: str = "text"
80
+ text: str = ""
81
+
82
+
83
  class Message(BaseModel):
84
  role: str
85
+ content: Union[str, list[ContentBlock]]
86
 
87
 
88
  class ChatCompletionRequest(BaseModel):
 
94
  tools: Optional[list[dict]] = None
95
 
96
 
97
+ def _get_message_text(msg: Message) -> str:
98
+ """Extract plain text from a message, handling both string and array content."""
99
+ c = msg.content
100
+ if isinstance(c, str):
101
+ return c
102
+ return "".join(b.text for b in c if b.text and b.type == "text")
103
+
104
+
105
  def verify_api_key(authorization: Optional[str] = Header(None)) -> str:
106
  if not authorization:
107
  raise HTTPException(status_code=401, detail="Missing API key")
 
238
  if not request.messages:
239
  raise HTTPException(status_code=400, detail="No messages provided")
240
 
241
+ prompt = _get_message_text(request.messages[-1])
242
 
243
  if request.tools:
244
  tool_desc = json.dumps(request.tools, ensure_ascii=False)
 
518
  if not req.messages:
519
  raise HTTPException(status_code=400, detail="No messages provided")
520
 
521
+ prompt = _get_message_text(req.messages[-1])
522
 
523
  if req.tools:
524
  tool_desc = json.dumps(req.tools, ensure_ascii=False)