Lưu Quang Vũ Nativu5 committed on
Commit
77830e0
unverified
1 Parent(s): fe5ae62

:sparkles: Add tool call support and fix image input and output (#47)

Browse files

* Add tool_calls support

* Add tool_calls support

* Add tool_calls support

* Add tool_calls support

* Add tool_calls support

* Add endpoint generate and edit images

* Add endpoint generate and edit images

* Add support structured output

* Add support structured output

* Incorrect logging

* Force LLM to follow tool_call format

* Format the code using Black.

* Fixes "Error handling message: No image returned"

* Return image dimensions

* Fixes Pylance warning

* Format by ruff

* uv run directly

* Fixes XML_WRAP_HINT leaked

* Instruct an LLM to return code snippets enclosed within Markdown fenced code blocks.

* Adjust the streaming response logic to ensure important sections remain intact and are not fragmented.

* Change chunk_size to 64

* ruff check

* Fixes for im_start/im_end hints leaking from responses.

* Ensure all endpoints are fully compliant with OpenAI compatibility standards.

* :memo: Fix doc

---------

Co-authored-by: Nativu5 <44155313+Nativu5@users.noreply.github.com>

app/models/models.py CHANGED
@@ -1,5 +1,7 @@
 
 
1
  from datetime import datetime
2
- from typing import Dict, List, Literal, Optional, Union
3
 
4
  from pydantic import BaseModel, Field
5
 
@@ -17,8 +19,9 @@ class Message(BaseModel):
17
  """Message model"""
18
 
19
  role: str
20
- content: Union[str, List[ContentItem]]
21
  name: Optional[str] = None
 
22
 
23
 
24
  class Choice(BaseModel):
@@ -29,6 +32,49 @@ class Choice(BaseModel):
29
  finish_reason: str
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  class Usage(BaseModel):
33
  """Usage statistics model"""
34
 
@@ -51,14 +97,16 @@ class ChatCompletionRequest(BaseModel):
51
 
52
  model: str
53
  messages: List[Message]
 
 
54
  temperature: Optional[float] = 0.7
55
  top_p: Optional[float] = 1.0
56
- n: Optional[int] = 1
57
- stream: Optional[bool] = False
58
  max_tokens: Optional[int] = None
59
- presence_penalty: Optional[float] = 0
60
- frequency_penalty: Optional[float] = 0
61
- user: Optional[str] = None
 
 
62
 
63
 
64
  class ChatCompletionResponse(BaseModel):
@@ -101,3 +149,130 @@ class ConversationInStore(BaseModel):
101
  ..., description="Metadata for Gemini API to locate the conversation"
102
  )
103
  messages: list[Message] = Field(..., description="Message contents in the conversation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
  from datetime import datetime
4
+ from typing import Any, Dict, List, Literal, Optional, Union
5
 
6
  from pydantic import BaseModel, Field
7
 
 
19
  """Message model"""
20
 
21
  role: str
22
+ content: Union[str, List[ContentItem], None] = None
23
  name: Optional[str] = None
24
+ tool_calls: Optional[List["ToolCall"]] = None
25
 
26
 
27
  class Choice(BaseModel):
 
32
  finish_reason: str
33
 
34
 
35
class FunctionCall(BaseModel):
    """Function call payload"""

    name: str  # Name of the function the model wants to invoke.
    arguments: str  # JSON-encoded argument string, kept as text per the OpenAI wire format.


class ToolCall(BaseModel):
    """Tool call item"""

    id: str  # Unique call identifier.
    type: Literal["function"]  # Only function-type tool calls are supported.
    function: FunctionCall  # The function name + arguments being invoked.


class ToolFunctionDefinition(BaseModel):
    """Function definition for tool."""

    name: str
    description: Optional[str] = None  # Human-readable purpose shown to the model.
    parameters: Optional[Dict[str, Any]] = None  # JSON Schema describing the arguments.


class Tool(BaseModel):
    """Tool specification."""

    type: Literal["function"]
    function: ToolFunctionDefinition


class ToolChoiceFunctionDetail(BaseModel):
    """Detail of a tool choice function."""

    name: str  # Name of the function the client forces the model to call.


class ToolChoiceFunction(BaseModel):
    """Tool choice forcing a specific function."""

    type: Literal["function"]
    function: ToolChoiceFunctionDetail
76
+
77
+
78
  class Usage(BaseModel):
79
  """Usage statistics model"""
80
 
 
97
 
98
  model: str
99
  messages: List[Message]
100
+ stream: Optional[bool] = False
101
+ user: Optional[str] = None
102
  temperature: Optional[float] = 0.7
103
  top_p: Optional[float] = 1.0
 
 
104
  max_tokens: Optional[int] = None
105
+ tools: Optional[List["Tool"]] = None
106
+ tool_choice: Optional[
107
+ Union[Literal["none"], Literal["auto"], Literal["required"], "ToolChoiceFunction"]
108
+ ] = None
109
+ response_format: Optional[Dict[str, Any]] = None
110
 
111
 
112
  class ChatCompletionResponse(BaseModel):
 
149
  ..., description="Metadata for Gemini API to locate the conversation"
150
  )
151
  messages: list[Message] = Field(..., description="Message contents in the conversation")
152
+
153
+
154
class ResponseInputContent(BaseModel):
    """Content item for Responses API input."""

    type: Literal["input_text", "input_image"]
    text: Optional[str] = None  # Present for "input_text" parts.
    image_url: Optional[str] = None  # http(s) or data: URL for "input_image" parts.
    image_base64: Optional[str] = None  # Raw base64 alternative to image_url.
    mime_type: Optional[str] = None  # MIME type used when synthesizing a data: URL.


class ResponseInputItem(BaseModel):
    """Single input item for Responses API."""

    type: Optional[Literal["message"]] = "message"
    role: Literal["user", "assistant", "system", "developer"]
    content: Union[str, List[ResponseInputContent]]


class ResponseToolChoice(BaseModel):
    """Tool choice enforcing a specific tool in Responses API."""

    type: Literal["image_generation"]


class ResponseImageTool(BaseModel):
    """Image generation tool specification for Responses API."""

    type: Literal["image_generation"]
    model: Optional[str] = None  # Optional override of the image model.
    output_format: Optional[str] = None  # Desired output encoding — presumably "png"/"jpeg"; verify against producer.
184
+
185
+
186
class ResponseCreateRequest(BaseModel):
    """Responses API request payload."""

    model: str
    input: Union[str, List[ResponseInputItem]]  # Plain prompt string or structured items.
    instructions: Optional[Union[str, List[ResponseInputItem]]] = None  # System/developer guidance.
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_output_tokens: Optional[int] = None
    stream: Optional[bool] = False
    tool_choice: Optional[ResponseToolChoice] = None  # Only "image_generation" is modeled here.
    tools: Optional[List[ResponseImageTool]] = None
    store: Optional[bool] = None
    user: Optional[str] = None
    response_format: Optional[Dict[str, Any]] = None  # OpenAI-style structured-output spec.
    metadata: Optional[Dict[str, Any]] = None
202
+
203
+
204
class ResponseUsage(BaseModel):
    """Usage statistics for Responses API."""

    input_tokens: int
    output_tokens: int
    total_tokens: int


class ResponseOutputContent(BaseModel):
    """Content item for Responses API output."""

    type: Literal["output_text", "output_image"]
    text: Optional[str] = None  # Present for "output_text".
    image_base64: Optional[str] = None  # Present for "output_image".
    mime_type: Optional[str] = None
    width: Optional[int] = None  # Image dimensions, when known.
    height: Optional[int] = None


class ResponseOutputMessage(BaseModel):
    """Assistant message returned by Responses API."""

    id: str
    type: Literal["message"]
    role: Literal["assistant"]
    content: List[ResponseOutputContent]
230
+
231
+
232
class ResponseImageGenerationCall(BaseModel):
    """Image generation call record emitted in Responses API."""

    id: str
    type: Literal["image_generation_call"] = "image_generation_call"
    status: Literal["completed", "in_progress", "generating", "failed"] = "completed"
    result: Optional[str] = None  # presumably base64 image data — verify against producer
    output_format: Optional[str] = None
    size: Optional[str] = None  # presumably "WxH" string — verify against producer
    revised_prompt: Optional[str] = None


class ResponseToolCall(BaseModel):
    """Tool call record emitted in Responses API."""

    id: str
    type: Literal["tool_call"] = "tool_call"
    status: Literal["in_progress", "completed", "failed", "requires_action"] = "completed"
    function: FunctionCall
251
+
252
+
253
class ResponseCreateResponse(BaseModel):
    """Responses API response payload."""

    id: str
    object: Literal["response"] = "response"
    created: int  # Unix timestamp, seconds.
    model: str
    output: List[Union[ResponseOutputMessage, ResponseImageGenerationCall, ResponseToolCall]]
    output_text: Optional[str] = None  # NOTE(review): looks like a flattened text convenience field — confirm how it is filled
    status: Literal[
        "in_progress",
        "completed",
        "failed",
        "incomplete",
        "requires_action",
    ] = "completed"
    usage: ResponseUsage
    metadata: Optional[Dict[str, Any]] = None
    system_fingerprint: Optional[str] = None
    input: Optional[Union[str, List[ResponseInputItem]]] = None  # Echo of the normalized request input.


# Rebuild models with forward references
Message.model_rebuild()
ToolCall.model_rebuild()
ChatCompletionRequest.model_rebuild()
app/server/chat.py CHANGED
@@ -1,26 +1,44 @@
 
 
 
 
1
  import uuid
 
2
  from datetime import datetime, timezone
3
  from pathlib import Path
 
4
 
5
  import orjson
6
  from fastapi import APIRouter, Depends, HTTPException, status
7
  from fastapi.responses import StreamingResponse
8
  from gemini_webapi.client import ChatSession
9
  from gemini_webapi.constants import Model
 
10
  from loguru import logger
11
 
12
  from ..models import (
13
  ChatCompletionRequest,
 
14
  ConversationInStore,
 
15
  Message,
16
  ModelData,
17
  ModelListResponse,
 
 
 
 
 
 
 
 
 
 
 
 
18
  )
19
- from ..services import (
20
- GeminiClientPool,
21
- GeminiClientWrapper,
22
- LMDBConversationStore,
23
- )
24
  from ..utils import g_config
25
  from ..utils.helper import estimate_tokens
26
  from .middleware import get_temp_dir, verify_api_key
@@ -30,10 +48,396 @@ MAX_CHARS_PER_REQUEST = int(g_config.gemini.max_chars_per_request * 0.9)
30
 
31
  CONTINUATION_HINT = "\n(More messages to come, please reply with just 'ok.')"
32
 
 
 
 
 
 
 
 
 
33
 
34
  router = APIRouter()
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  @router.get("/v1/models", response_model=ModelListResponse)
38
  async def list_models(api_key: str = Depends(verify_api_key)):
39
  now = int(datetime.now(tz=timezone.utc).timestamp())
@@ -71,29 +475,51 @@ async def create_chat_completion(
71
  detail="At least one message is required in the conversation.",
72
  )
73
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  # Check if conversation is reusable
75
  session, client, remaining_messages = _find_reusable_session(db, pool, model, request.messages)
76
 
77
  if session:
78
- # Prepare the model input depending on how many turns are missing.
79
- if len(remaining_messages) == 1:
 
 
 
 
 
 
 
80
  model_input, files = await GeminiClientWrapper.process_message(
81
- remaining_messages[0], tmp_dir, tagged=False
82
  )
83
  else:
84
  model_input, files = await GeminiClientWrapper.process_conversation(
85
- remaining_messages, tmp_dir
86
  )
87
  logger.debug(
88
- f"Reused session {session.metadata} - sending {len(remaining_messages)} new messages."
89
  )
90
  else:
91
  # Start a new session and concat messages into a single string
92
  try:
93
  client = pool.acquire()
94
  session = client.start_chat(model=model)
 
 
 
95
  model_input, files = await GeminiClientWrapper.process_conversation(
96
- request.messages, tmp_dir
97
  )
98
  except ValueError as e:
99
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -114,12 +540,46 @@ async def create_chat_completion(
114
  raise
115
 
116
  # Format the response from API
117
- model_output = GeminiClientWrapper.extract_output(response, include_thoughts=True)
118
- stored_output = GeminiClientWrapper.extract_output(response, include_thoughts=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # After formatting, persist the conversation to LMDB
121
  try:
122
- last_message = Message(role="assistant", content=stored_output)
 
 
 
 
123
  cleaned_history = db.sanitize_assistant_messages(request.messages)
124
  conv = ConversationInStore(
125
  model=model.model_name,
@@ -138,7 +598,8 @@ async def create_chat_completion(
138
  timestamp = int(datetime.now(tz=timezone.utc).timestamp())
139
  if request.stream:
140
  return _create_streaming_response(
141
- model_output,
 
142
  completion_id,
143
  timestamp,
144
  request.model,
@@ -146,17 +607,277 @@ async def create_chat_completion(
146
  )
147
  else:
148
  return _create_standard_response(
149
- model_output, completion_id, timestamp, request.model, request.messages
 
 
 
 
 
150
  )
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def _text_from_message(message: Message) -> str:
154
  """Return text content from a message for token estimation."""
 
155
  if isinstance(message.content, str):
156
- return message.content
157
- return "\n".join(
158
- item.text or "" for item in message.content if getattr(item, "type", "") == "text"
159
- )
 
 
 
 
 
 
 
 
 
160
 
161
 
162
  def _find_reusable_session(
@@ -172,7 +893,7 @@ def _find_reusable_session(
172
  ---------
173
  When a reply was generated by *another* server instance, the local LMDB may
174
  only contain an older part of the conversation. However, as long as we can
175
- line-up **any** earlier assistant/system response, we can restore the
176
  corresponding Gemini session and replay the *remaining* turns locally
177
  (including that missing assistant reply and the subsequent user prompts).
178
 
@@ -248,8 +969,50 @@ async def _send_with_split(session: ChatSession, text: str, files: list[Path | s
248
  return await session.send_message(chunks[-1], files=files)
249
 
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  def _create_streaming_response(
252
  model_output: str,
 
253
  completion_id: str,
254
  created_time: int,
255
  model: str,
@@ -259,8 +1022,10 @@ def _create_streaming_response(
259
 
260
  # Calculate token usage
261
  prompt_tokens = sum(estimate_tokens(_text_from_message(msg)) for msg in messages)
262
- completion_tokens = estimate_tokens(model_output)
 
263
  total_tokens = prompt_tokens + completion_tokens
 
264
 
265
  async def generate_stream():
266
  # Send start event
@@ -274,9 +1039,7 @@ def _create_streaming_response(
274
  yield f"data: {orjson.dumps(data).decode('utf-8')}\n\n"
275
 
276
  # Stream output text in chunks for efficiency
277
- chunk_size = 32
278
- for i in range(0, len(model_output), chunk_size):
279
- chunk = model_output[i : i + chunk_size]
280
  data = {
281
  "id": completion_id,
282
  "object": "chat.completion.chunk",
@@ -286,13 +1049,30 @@ def _create_streaming_response(
286
  }
287
  yield f"data: {orjson.dumps(data).decode('utf-8')}\n\n"
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  # Send end event
290
  data = {
291
  "id": completion_id,
292
  "object": "chat.completion.chunk",
293
  "created": created_time,
294
  "model": model,
295
- "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
296
  "usage": {
297
  "prompt_tokens": prompt_tokens,
298
  "completion_tokens": completion_tokens,
@@ -305,8 +1085,89 @@ def _create_streaming_response(
305
  return StreamingResponse(generate_stream(), media_type="text/event-stream")
306
 
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  def _create_standard_response(
309
  model_output: str,
 
310
  completion_id: str,
311
  created_time: int,
312
  model: str,
@@ -315,8 +1176,14 @@ def _create_standard_response(
315
  """Create standard response"""
316
  # Calculate token usage
317
  prompt_tokens = sum(estimate_tokens(_text_from_message(msg)) for msg in messages)
318
- completion_tokens = estimate_tokens(model_output)
 
319
  total_tokens = prompt_tokens + completion_tokens
 
 
 
 
 
320
 
321
  result = {
322
  "id": completion_id,
@@ -326,8 +1193,8 @@ def _create_standard_response(
326
  "choices": [
327
  {
328
  "index": 0,
329
- "message": {"role": "assistant", "content": model_output},
330
- "finish_reason": "stop",
331
  }
332
  ],
333
  "usage": {
@@ -339,3 +1206,82 @@ def _create_standard_response(
339
 
340
  logger.debug(f"Response created with {total_tokens} total tokens")
341
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import re
4
+ import struct
5
  import uuid
6
+ from dataclasses import dataclass
7
  from datetime import datetime, timezone
8
  from pathlib import Path
9
+ from typing import Any, Iterator
10
 
11
  import orjson
12
  from fastapi import APIRouter, Depends, HTTPException, status
13
  from fastapi.responses import StreamingResponse
14
  from gemini_webapi.client import ChatSession
15
  from gemini_webapi.constants import Model
16
+ from gemini_webapi.types.image import GeneratedImage, Image
17
  from loguru import logger
18
 
19
  from ..models import (
20
  ChatCompletionRequest,
21
+ ContentItem,
22
  ConversationInStore,
23
+ FunctionCall,
24
  Message,
25
  ModelData,
26
  ModelListResponse,
27
+ ResponseCreateRequest,
28
+ ResponseCreateResponse,
29
+ ResponseImageGenerationCall,
30
+ ResponseInputContent,
31
+ ResponseInputItem,
32
+ ResponseOutputContent,
33
+ ResponseOutputMessage,
34
+ ResponseToolCall,
35
+ ResponseUsage,
36
+ Tool,
37
+ ToolCall,
38
+ ToolChoiceFunction,
39
  )
40
+ from ..services import GeminiClientPool, GeminiClientWrapper, LMDBConversationStore
41
+ from ..services.client import CODE_BLOCK_HINT, XML_WRAP_HINT
 
 
 
42
  from ..utils import g_config
43
  from ..utils.helper import estimate_tokens
44
  from .middleware import get_temp_dir, verify_api_key
 
48
 
49
  CONTINUATION_HINT = "\n(More messages to come, please reply with just 'ok.')"
50
 
51
+ TOOL_BLOCK_RE = re.compile(r"```xml\s*(.*?)```", re.DOTALL | re.IGNORECASE)
52
+ TOOL_CALL_RE = re.compile(
53
+ r"<tool_call\s+name=\"([^\"]+)\">(.*?)</tool_call>", re.DOTALL | re.IGNORECASE
54
+ )
55
+ JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)
56
+ CONTROL_TOKEN_RE = re.compile(r"<\|im_(?:start|end)\|>")
57
+ XML_HINT_STRIPPED = XML_WRAP_HINT.strip()
58
+ CODE_HINT_STRIPPED = CODE_BLOCK_HINT.strip()
59
 
60
  router = APIRouter()
61
 
62
 
63
+ @dataclass
64
+ class StructuredOutputRequirement:
65
+ """Represents a structured response request from the client."""
66
+
67
+ schema_name: str
68
+ schema: dict[str, Any]
69
+ instruction: str
70
+ raw_format: dict[str, Any]
71
+
72
+
73
+ def _build_structured_requirement(
74
+ response_format: dict[str, Any] | None,
75
+ ) -> StructuredOutputRequirement | None:
76
+ """Translate OpenAI-style response_format into internal instructions."""
77
+ if not response_format or not isinstance(response_format, dict):
78
+ return None
79
+
80
+ if response_format.get("type") != "json_schema":
81
+ logger.warning(f"Unsupported response_format type requested: {response_format}")
82
+ return None
83
+
84
+ json_schema = response_format.get("json_schema")
85
+ if not isinstance(json_schema, dict):
86
+ logger.warning(f"Invalid json_schema payload in response_format: {response_format}")
87
+ return None
88
+
89
+ schema = json_schema.get("schema")
90
+ if not isinstance(schema, dict):
91
+ logger.warning(f"Missing `schema` object in response_format payload: {response_format}")
92
+ return None
93
+
94
+ schema_name = json_schema.get("name") or "response"
95
+ strict = json_schema.get("strict", True)
96
+
97
+ pretty_schema = json.dumps(schema, ensure_ascii=False, indent=2, sort_keys=True)
98
+ instruction_parts = [
99
+ "You must respond with a single valid JSON document that conforms to the schema shown below.",
100
+ "Do not include explanations, comments, or any text before or after the JSON.",
101
+ f'Schema name: "{schema_name}"',
102
+ "JSON Schema:",
103
+ pretty_schema,
104
+ ]
105
+ if not strict:
106
+ instruction_parts.insert(
107
+ 1,
108
+ "The schema allows unspecified fields, but include only what is necessary to satisfy the user's request.",
109
+ )
110
+
111
+ instruction = "\n\n".join(instruction_parts)
112
+ return StructuredOutputRequirement(
113
+ schema_name=schema_name,
114
+ schema=schema,
115
+ instruction=instruction,
116
+ raw_format=response_format,
117
+ )
118
+
119
+
120
def _strip_code_fence(text: str) -> str:
    """Remove surrounding ```json fences if present.

    The result is always stripped of outer whitespace, fence or not.
    """
    # JSON_FENCE_RE is anchored (^```...```$), so only a fence wrapping the
    # entire stripped payload is removed — embedded fences are left alone.
    match = JSON_FENCE_RE.match(text.strip())
    if match:
        return match.group(1).strip()
    return text.strip()
126
+
127
+
128
+ def _build_tool_prompt(
129
+ tools: list[Tool],
130
+ tool_choice: str | ToolChoiceFunction | None,
131
+ ) -> str:
132
+ """Generate a system prompt chunk describing available tools."""
133
+ if not tools:
134
+ return ""
135
+
136
+ lines: list[str] = [
137
+ "You can invoke the following developer tools. Call a tool only when it is required and follow the JSON schema exactly when providing arguments."
138
+ ]
139
+
140
+ for tool in tools:
141
+ function = tool.function
142
+ description = function.description or "No description provided."
143
+ lines.append(f"Tool `{function.name}`: {description}")
144
+ if function.parameters:
145
+ schema_text = json.dumps(function.parameters, ensure_ascii=False, indent=2)
146
+ lines.append("Arguments JSON schema:")
147
+ lines.append(schema_text)
148
+ else:
149
+ lines.append("Arguments JSON schema: {}")
150
+
151
+ if tool_choice == "none":
152
+ lines.append(
153
+ "For this request you must not call any tool. Provide the best possible natural language answer."
154
+ )
155
+ elif tool_choice == "required":
156
+ lines.append(
157
+ "You must call at least one tool before responding to the user. Do not provide a final user-facing answer until a tool call has been issued."
158
+ )
159
+ elif isinstance(tool_choice, ToolChoiceFunction):
160
+ target = tool_choice.function.name
161
+ lines.append(
162
+ f"You are required to call the tool named `{target}`. Do not call any other tool."
163
+ )
164
+ # `auto` or None fall back to default instructions.
165
+
166
+ lines.append(
167
+ "When you decide to call a tool you MUST respond with nothing except a single fenced block exactly like the template below."
168
+ )
169
+ lines.append(
170
+ "The fenced block MUST use ```xml as the opening fence and ``` as the closing fence. Do not add text before or after it."
171
+ )
172
+ lines.append("```xml")
173
+ lines.append('<tool_call name="tool_name">{"argument": "value"}</tool_call>')
174
+ lines.append("```")
175
+ lines.append(
176
+ "Use double quotes for JSON keys and values. If you omit the fenced block or include any extra text, the system will assume you are NOT calling a tool and your request will fail."
177
+ )
178
+ lines.append(
179
+ "If multiple tool calls are required, include multiple <tool_call> entries inside the same fenced block. Without a tool call, reply normally and do NOT emit any ```xml fence."
180
+ )
181
+
182
+ return "\n".join(lines)
183
+
184
+
185
def _append_xml_hint_to_last_user_message(messages: list[Message]) -> None:
    """Ensure the last user message carries the XML wrap hint.

    Mutates `messages` in place; at most one message is modified.
    """
    # Walk backwards so only the most recent user turn gets annotated.
    for msg in reversed(messages):
        if msg.role != "user" or msg.content is None:
            continue

        if isinstance(msg.content, str):
            # Plain-text content: append the hint unless it is already present.
            if XML_HINT_STRIPPED not in msg.content:
                msg.content = f"{msg.content}{XML_WRAP_HINT}"
            return

        if isinstance(msg.content, list):
            # Multi-part content: extend the last text part when there is one.
            for part in reversed(msg.content):
                if getattr(part, "type", None) != "text":
                    continue
                text_value = part.text or ""
                if XML_HINT_STRIPPED in text_value:
                    return
                part.text = f"{text_value}{XML_WRAP_HINT}"
                return

            # No text part at all: add the hint as a standalone text item
            # (stripped, since it stands alone rather than trailing text).
            messages_text = XML_WRAP_HINT.strip()
            msg.content.append(ContentItem(type="text", text=messages_text))
            return

    # No user message to annotate; nothing to do.
211
+
212
+
213
def _conversation_has_code_hint(messages: list[Message]) -> bool:
    """Return True if any system message already includes the code block hint.

    Used to avoid injecting CODE_BLOCK_HINT twice into the same conversation.
    """
    for msg in messages:
        if msg.role != "system" or msg.content is None:
            continue

        if isinstance(msg.content, str):
            if CODE_HINT_STRIPPED in msg.content:
                return True
            continue

        if isinstance(msg.content, list):
            # Scan every text part of a multi-part system message.
            for part in msg.content:
                if getattr(part, "type", None) != "text":
                    continue
                if part.text and CODE_HINT_STRIPPED in part.text:
                    return True

    return False
232
+
233
+
234
def _prepare_messages_for_model(
    source_messages: list[Message],
    tools: list[Tool] | None,
    tool_choice: str | ToolChoiceFunction | None,
    extra_instructions: list[str] | None = None,
) -> list[Message]:
    """Return a copy of messages enriched with tool instructions when needed.

    The caller's messages are never mutated; all hint/prompt injection happens
    on deep copies.
    """
    prepared = [msg.model_copy(deep=True) for msg in source_messages]

    # Collect every system-level instruction to inject (tool prompt,
    # structured-output instructions, default code-block hint).
    instructions: list[str] = []
    if tools:
        tool_prompt = _build_tool_prompt(tools, tool_choice)
        if tool_prompt:
            instructions.append(tool_prompt)

    if extra_instructions:
        instructions.extend(instr for instr in extra_instructions if instr)
        logger.debug(
            f"Applied {len(extra_instructions)} extra instructions for tool/structured output."
        )

    if not _conversation_has_code_hint(prepared):
        instructions.append(CODE_BLOCK_HINT)
        logger.debug("Injected default code block hint for Gemini conversation.")

    if not instructions:
        return prepared

    combined_instructions = "\n\n".join(instructions)

    # Merge into an existing leading string-content system message when
    # possible; otherwise prepend a fresh system message.
    if prepared and prepared[0].role == "system" and isinstance(prepared[0].content, str):
        existing = prepared[0].content or ""
        separator = "\n\n" if existing else ""
        prepared[0].content = f"{existing}{separator}{combined_instructions}"
    else:
        prepared.insert(0, Message(role="system", content=combined_instructions))

    # Remind the model of the XML tool-call format unless tools are disabled.
    if tools and tool_choice != "none":
        _append_xml_hint_to_last_user_message(prepared)

    return prepared
275
+
276
+
277
def _strip_system_hints(text: str) -> str:
    """Remove system-level hint text from a given string.

    Scrubs the injected XML/code-block hints (raw and stripped variants, since
    the model may echo either form) and any leaked chat control tokens, then
    trims surrounding whitespace.
    """
    if not text:
        return text
    cleaned = text.replace(XML_WRAP_HINT, "").replace(XML_HINT_STRIPPED, "")
    cleaned = cleaned.replace(CODE_BLOCK_HINT, "").replace(CODE_HINT_STRIPPED, "")
    cleaned = CONTROL_TOKEN_RE.sub("", cleaned)
    return cleaned.strip()
285
+
286
+
287
+ def _ensure_data_url(part: ResponseInputContent) -> str | None:
288
+ image_url = part.image_url
289
+ if not image_url and part.image_base64:
290
+ mime_type = part.mime_type or "image/png"
291
+ image_url = f"data:{mime_type};base64,{part.image_base64}"
292
+ return image_url
293
+
294
+
295
def _response_items_to_messages(
    items: str | list[ResponseInputItem],
) -> tuple[list[Message], str | list[ResponseInputItem]]:
    """Convert Responses API input items into internal Message objects and normalized input.

    Returns (messages, normalized_input): messages feed the Gemini pipeline,
    normalized_input is the cleaned echo returned to the client.
    """
    messages: list[Message] = []

    # Bare string input becomes a single user message, echoed back unchanged.
    if isinstance(items, str):
        messages.append(Message(role="user", content=items))
        logger.debug("Normalized Responses input: single string message.")
        return messages, items

    normalized_input: list[ResponseInputItem] = []
    for item in items:
        # The internal pipeline has no "developer" role; map it to "system".
        role = item.role
        if role == "developer":
            role = "system"

        content = item.content
        normalized_contents: list[ResponseInputContent] = []
        if isinstance(content, str):
            normalized_contents.append(ResponseInputContent(type="input_text", text=content))
            messages.append(Message(role=role, content=content))
        else:
            converted: list[ContentItem] = []
            for part in content:
                if part.type == "input_text":
                    text_value = part.text or ""
                    normalized_contents.append(
                        ResponseInputContent(type="input_text", text=text_value)
                    )
                    if text_value:
                        converted.append(ContentItem(type="text", text=text_value))
                elif part.type == "input_image":
                    # Inline base64 images are rewritten as data: URLs.
                    image_url = _ensure_data_url(part)
                    if image_url:
                        normalized_contents.append(
                            ResponseInputContent(type="input_image", image_url=image_url)
                        )
                        converted.append(
                            ContentItem(type="image_url", image_url={"url": image_url})
                        )
            # An all-empty part list collapses to content=None.
            messages.append(Message(role=role, content=converted or None))

        # The normalized echo keeps the original (pre-mapping) role.
        normalized_input.append(
            ResponseInputItem(type="message", role=item.role, content=normalized_contents or [])
        )

    logger.debug(
        f"Normalized Responses input: {len(normalized_input)} message items (developer roles mapped to system)."
    )
    return messages, normalized_input
346
+
347
+
348
def _instructions_to_messages(
    instructions: str | list[ResponseInputItem] | None,
) -> list[Message]:
    """Normalize instructions payload into Message objects.

    Returns an empty list when no instructions were provided.
    """
    if not instructions:
        return []

    # A bare string is a single system instruction.
    if isinstance(instructions, str):
        return [Message(role="system", content=instructions)]

    instruction_messages: list[Message] = []
    for item in instructions:
        # Ignore anything that is not a "message" item.
        if item.type and item.type != "message":
            continue

        # Map the Responses-only "developer" role onto "system".
        role = item.role
        if role == "developer":
            role = "system"

        content = item.content
        if isinstance(content, str):
            instruction_messages.append(Message(role=role, content=content))
        else:
            converted: list[ContentItem] = []
            for part in content:
                if part.type == "input_text":
                    text_value = part.text or ""
                    if text_value:
                        converted.append(ContentItem(type="text", text=text_value))
                elif part.type == "input_image":
                    # Normalize base64 payloads to data: URLs.
                    image_url = _ensure_data_url(part)
                    if image_url:
                        converted.append(
                            ContentItem(type="image_url", image_url={"url": image_url})
                        )
            # Empty part lists collapse to content=None.
            instruction_messages.append(Message(role=role, content=converted or None))

    return instruction_messages
386
+
387
+
388
def _remove_tool_call_blocks(text: str) -> str:
    """Strip tool call code blocks from text.

    Deletes every ```xml fenced block matched by TOOL_BLOCK_RE, then scrubs
    any remaining system hints via _strip_system_hints.
    """
    if not text:
        return text
    cleaned = TOOL_BLOCK_RE.sub("", text)
    return _strip_system_hints(cleaned)
394
+
395
+
396
def _extract_tool_calls(text: str) -> tuple[str, list[ToolCall]]:
    """Extract tool call definitions and return cleaned text.

    Returns (cleaned_text, tool_calls): every ```xml fenced block is removed
    from the text, and each <tool_call name="...">{...}</tool_call> entry
    found inside one becomes a ToolCall with a fresh call id.
    """
    if not text:
        return text, []

    tool_calls: list[ToolCall] = []

    def _replace(match: re.Match[str]) -> str:
        # Substitution callback: harvest <tool_call> entries from the fenced
        # block, appending to the enclosing tool_calls list, then delete the
        # block from the visible output by returning "".
        block_content = match.group(1)
        if not block_content:
            return ""

        for call_match in TOOL_CALL_RE.finditer(block_content):
            name = (call_match.group(1) or "").strip()
            raw_args = (call_match.group(2) or "").strip()
            if not name:
                logger.warning(
                    f"Encountered tool_call block without a function name: {block_content}"
                )
                continue

            # Round-trip the arguments through json to canonicalize them;
            # fall back to the raw string when they are not valid JSON.
            arguments = raw_args
            try:
                parsed_args = json.loads(raw_args)
                arguments = json.dumps(parsed_args, ensure_ascii=False)
            except json.JSONDecodeError:
                logger.warning(
                    f"Failed to parse tool call arguments for '{name}'. Passing raw string."
                )

            tool_calls.append(
                ToolCall(
                    id=f"call_{uuid.uuid4().hex}",
                    type="function",
                    function=FunctionCall(name=name, arguments=arguments),
                )
            )

        return ""

    cleaned = TOOL_BLOCK_RE.sub(_replace, text)
    cleaned = _strip_system_hints(cleaned)
    return cleaned, tool_calls
439
+
440
+
441
  @router.get("/v1/models", response_model=ModelListResponse)
442
  async def list_models(api_key: str = Depends(verify_api_key)):
443
  now = int(datetime.now(tz=timezone.utc).timestamp())
 
475
  detail="At least one message is required in the conversation.",
476
  )
477
 
478
+ structured_requirement = _build_structured_requirement(request.response_format)
479
+ if structured_requirement and request.stream:
480
+ logger.debug(
481
+ "Structured response requested with streaming enabled; will stream canonical JSON once ready."
482
+ )
483
+ if structured_requirement:
484
+ logger.debug(
485
+ f"Structured response requested for /v1/chat/completions (schema={structured_requirement.schema_name})."
486
+ )
487
+
488
+ extra_instructions = [structured_requirement.instruction] if structured_requirement else None
489
+
490
  # Check if conversation is reusable
491
  session, client, remaining_messages = _find_reusable_session(db, pool, model, request.messages)
492
 
493
  if session:
494
+ messages_to_send = _prepare_messages_for_model(
495
+ remaining_messages, request.tools, request.tool_choice, extra_instructions
496
+ )
497
+ if not messages_to_send:
498
+ raise HTTPException(
499
+ status_code=status.HTTP_400_BAD_REQUEST,
500
+ detail="No new messages to send for the existing session.",
501
+ )
502
+ if len(messages_to_send) == 1:
503
  model_input, files = await GeminiClientWrapper.process_message(
504
+ messages_to_send[0], tmp_dir, tagged=False
505
  )
506
  else:
507
  model_input, files = await GeminiClientWrapper.process_conversation(
508
+ messages_to_send, tmp_dir
509
  )
510
  logger.debug(
511
+ f"Reused session {session.metadata} - sending {len(messages_to_send)} prepared messages."
512
  )
513
  else:
514
  # Start a new session and concat messages into a single string
515
  try:
516
  client = pool.acquire()
517
  session = client.start_chat(model=model)
518
+ messages_to_send = _prepare_messages_for_model(
519
+ request.messages, request.tools, request.tool_choice, extra_instructions
520
+ )
521
  model_input, files = await GeminiClientWrapper.process_conversation(
522
+ messages_to_send, tmp_dir
523
  )
524
  except ValueError as e:
525
  raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
 
540
  raise
541
 
542
  # Format the response from API
543
+ raw_output_with_think = GeminiClientWrapper.extract_output(response, include_thoughts=True)
544
+ raw_output_clean = GeminiClientWrapper.extract_output(response, include_thoughts=False)
545
+
546
+ visible_output, tool_calls = _extract_tool_calls(raw_output_with_think)
547
+ storage_output = _remove_tool_call_blocks(raw_output_clean).strip()
548
+ tool_calls_payload = [call.model_dump(mode="json") for call in tool_calls]
549
+
550
+ if structured_requirement:
551
+ cleaned_visible = _strip_code_fence(visible_output or "")
552
+ if not cleaned_visible:
553
+ raise HTTPException(
554
+ status_code=status.HTTP_502_BAD_GATEWAY,
555
+ detail="LLM returned an empty response while JSON schema output was requested.",
556
+ )
557
+ try:
558
+ structured_payload = json.loads(cleaned_visible)
559
+ except json.JSONDecodeError as exc:
560
+ logger.warning(
561
+ f"Failed to decode JSON for structured response (schema={structured_requirement.schema_name}): "
562
+ f"{cleaned_visible}"
563
+ )
564
+ raise HTTPException(
565
+ status_code=status.HTTP_502_BAD_GATEWAY,
566
+ detail="LLM returned invalid JSON for the requested response_format.",
567
+ ) from exc
568
+
569
+ canonical_output = json.dumps(structured_payload, ensure_ascii=False)
570
+ visible_output = canonical_output
571
+ storage_output = canonical_output
572
+
573
+ if tool_calls_payload:
574
+ logger.debug(f"Detected tool calls: {tool_calls_payload}")
575
 
576
  # After formatting, persist the conversation to LMDB
577
  try:
578
+ last_message = Message(
579
+ role="assistant",
580
+ content=storage_output or None,
581
+ tool_calls=tool_calls or None,
582
+ )
583
  cleaned_history = db.sanitize_assistant_messages(request.messages)
584
  conv = ConversationInStore(
585
  model=model.model_name,
 
598
  timestamp = int(datetime.now(tz=timezone.utc).timestamp())
599
  if request.stream:
600
  return _create_streaming_response(
601
+ visible_output,
602
+ tool_calls_payload,
603
  completion_id,
604
  timestamp,
605
  request.model,
 
607
  )
608
  else:
609
  return _create_standard_response(
610
+ visible_output,
611
+ tool_calls_payload,
612
+ completion_id,
613
+ timestamp,
614
+ request.model,
615
+ request.messages,
616
  )
617
 
618
 
619
@router.post("/v1/responses")
async def create_response(
    request: ResponseCreateRequest,
    api_key: str = Depends(verify_api_key),
    tmp_dir: Path = Depends(get_temp_dir),
):
    """Handle OpenAI-compatible ``/v1/responses`` requests.

    Normalizes the Responses-API input into chat messages, reuses or
    starts a Gemini session, sends the prompt, then converts the model
    output (text, tool calls, generated images) into a Responses payload.
    The resulting conversation is persisted to LMDB; streaming requests
    get an SSE envelope around the same payload.
    """
    # Convert Responses-API input items into internal Message objects.
    messages, normalized_input = _response_items_to_messages(request.input)
    if not messages:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="No message input provided."
        )

    # response_format handling (JSON schema etc.).
    structured_requirement = _build_structured_requirement(request.response_format)
    if structured_requirement and request.stream:
        logger.debug(
            "Structured response requested with streaming enabled; streaming not supported for Responses."
        )

    # Instructions become system messages injected ahead of the input;
    # the structured-output instruction (if any) goes first.
    preface_messages = _instructions_to_messages(request.instructions)
    if structured_requirement:
        preface_messages.insert(
            0, Message(role="system", content=structured_requirement.instruction)
        )
        logger.debug(
            f"Structured response requested for /v1/responses (schema={structured_requirement.schema_name})."
        )
    if preface_messages:
        messages = [*preface_messages, *messages]
        logger.debug(
            f"Injected {len(preface_messages)} instruction messages before sending to Gemini."
        )

    pool = GeminiClientPool()
    db = LMDBConversationStore()

    try:
        model = Model.from_name(request.model)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

    # Try to resume a previously stored Gemini session for this history.
    session, client, remaining_messages = _find_reusable_session(db, pool, model, messages)

    if session:
        messages_to_send = remaining_messages
        if not messages_to_send:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No new messages to send for the existing session.",
            )
        # A single new message is sent untagged; multiple turns are
        # serialized as a tagged conversation transcript.
        if len(messages_to_send) == 1:
            model_input, files = await GeminiClientWrapper.process_message(
                messages_to_send[0], tmp_dir, tagged=False
            )
        else:
            model_input, files = await GeminiClientWrapper.process_conversation(
                messages_to_send, tmp_dir
            )
        logger.debug(
            f"Reused session {session.metadata} - sending {len(messages_to_send)} prepared messages."
        )
    else:
        # No reusable session: acquire a client and replay the whole history.
        try:
            client = pool.acquire()
            session = client.start_chat(model=model)
            model_input, files = await GeminiClientWrapper.process_conversation(messages, tmp_dir)
        except ValueError as e:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
        except Exception as e:
            logger.exception(f"Error in preparing conversation for responses API: {e}")
            raise
        logger.debug("New session started for /v1/responses request.")

    try:
        assert session and client, "Session and client not available"
        logger.debug(
            f"Client ID: {client.id}, Input length: {len(model_input)}, files count: {len(files)}"
        )
        # Splits oversized prompts into multiple sends if necessary.
        model_output = await _send_with_split(session, model_input, files=files)
    except Exception as e:
        logger.exception(f"Error generating content from Gemini API for responses: {e}")
        raise

    # Two extractions: one keeping <think> content (for the client-visible
    # text) and one without (for LMDB storage).
    text_with_think = GeminiClientWrapper.extract_output(model_output, include_thoughts=True)
    text_without_think = GeminiClientWrapper.extract_output(model_output, include_thoughts=False)

    visible_text, detected_tool_calls = _extract_tool_calls(text_with_think)
    storage_output = _remove_tool_call_blocks(text_without_think).strip()
    assistant_text = LMDBConversationStore.remove_think_tags(visible_text.strip())

    if structured_requirement:
        # Enforce the requested JSON schema output: unwrap code fences,
        # validate JSON, and canonicalize the serialization.
        cleaned_visible = _strip_code_fence(assistant_text or "")
        if not cleaned_visible:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="LLM returned an empty response while JSON schema output was requested.",
            )
        try:
            structured_payload = json.loads(cleaned_visible)
        except json.JSONDecodeError as exc:
            logger.warning(
                f"Failed to decode JSON for structured response (schema={structured_requirement.schema_name}): "
                f"{cleaned_visible}"
            )
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail="LLM returned invalid JSON for the requested response_format.",
            ) from exc

        canonical_output = json.dumps(structured_payload, ensure_ascii=False)
        assistant_text = canonical_output
        storage_output = canonical_output
        logger.debug(
            f"Structured response fulfilled for /v1/responses (schema={structured_requirement.schema_name})."
        )

    # When the caller explicitly requested image generation but no image
    # came back, fail with 502 and a short summary of the text reply.
    expects_image = (
        request.tool_choice is not None and request.tool_choice.type == "image_generation"
    )
    if expects_image and not model_output.images:
        summary = assistant_text.strip() if assistant_text else ""
        if summary:
            summary = re.sub(r"\s+", " ", summary)
            if len(summary) > 200:
                summary = f"{summary[:197]}..."
        logger.warning(
            "Image generation was requested via tool_choice but Gemini returned no images."
        )
        detail = "LLM returned no images for the requested image_generation tool."
        if summary:
            detail = f"{detail} Assistant response: {summary}"
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=detail)

    # Convert returned images to base64 output items; failures to fetch a
    # single image are logged and skipped rather than failing the request.
    image_contents: list[ResponseOutputContent] = []
    image_call_items: list[ResponseImageGenerationCall] = []
    for image in model_output.images:
        try:
            image_base64, width, height = await _image_to_base64(image, tmp_dir)
        except Exception as exc:
            logger.warning(f"Failed to download generated image: {exc}")
            continue
        # GeneratedImage saves as PNG; other (web) images are JPEG.
        mime_type = "image/png" if isinstance(image, GeneratedImage) else "image/jpeg"
        image_contents.append(
            ResponseOutputContent(
                type="output_image",
                image_base64=image_base64,
                mime_type=mime_type,
                width=width,
                height=height,
            )
        )
        image_call_items.append(
            ResponseImageGenerationCall(
                id=f"img_{uuid.uuid4().hex}",
                status="completed",
                result=image_base64,
                output_format="png" if isinstance(image, GeneratedImage) else "jpeg",
                size=f"{width}x{height}" if width and height else None,
            )
        )

    # Wrap detected tool calls as Responses-API tool-call output items.
    tool_call_items: list[ResponseToolCall] = []
    if detected_tool_calls:
        tool_call_items = [
            ResponseToolCall(
                id=call.id,
                status="completed",
                function=call.function,
            )
            for call in detected_tool_calls
        ]

    response_contents: list[ResponseOutputContent] = []
    if assistant_text:
        response_contents.append(ResponseOutputContent(type="output_text", text=assistant_text))
    response_contents.extend(image_contents)

    # The message must carry at least one content item.
    if not response_contents:
        response_contents.append(ResponseOutputContent(type="output_text", text=""))

    created_time = int(datetime.now(tz=timezone.utc).timestamp())
    response_id = f"resp_{uuid.uuid4().hex}"
    message_id = f"msg_{uuid.uuid4().hex}"

    # Token accounting is estimated from text lengths (no real tokenizer);
    # tool-call arguments count toward output tokens.
    input_tokens = sum(estimate_tokens(_text_from_message(msg)) for msg in messages)
    tool_arg_text = "".join(call.function.arguments or "" for call in detected_tool_calls)
    completion_basis = assistant_text or ""
    if tool_arg_text:
        completion_basis = (
            f"{completion_basis}\n{tool_arg_text}" if completion_basis else tool_arg_text
        )
    output_tokens = estimate_tokens(completion_basis)
    usage = ResponseUsage(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=input_tokens + output_tokens,
    )

    response_payload = ResponseCreateResponse(
        id=response_id,
        created=created_time,
        model=request.model,
        output=[
            ResponseOutputMessage(
                id=message_id,
                type="message",
                role="assistant",
                content=response_contents,
            ),
            *tool_call_items,
            *image_call_items,
        ],
        output_text=assistant_text or None,
        status="completed",
        usage=usage,
        input=normalized_input or None,
        metadata=request.metadata or None,
    )

    # Persist the full conversation (history + assistant reply) so future
    # requests can resume this Gemini session. Best-effort: failures are
    # logged but never fail the request.
    try:
        last_message = Message(
            role="assistant",
            content=storage_output or None,
            tool_calls=detected_tool_calls or None,
        )
        cleaned_history = db.sanitize_assistant_messages(messages)
        conv = ConversationInStore(
            model=model.model_name,
            client_id=client.id,
            metadata=session.metadata,
            messages=[*cleaned_history, last_message],
        )
        key = db.store(conv)
        logger.debug(f"Conversation saved to LMDB with key: {key}")
    except Exception as exc:
        logger.warning(f"Failed to save Responses conversation to LMDB: {exc}")

    if request.stream:
        logger.debug(
            f"Streaming Responses API payload (response_id={response_payload.id}, text_chunks={bool(assistant_text)})."
        )
        return _create_responses_streaming_response(response_payload, assistant_text or "")

    return response_payload
862
+
863
+
864
  def _text_from_message(message: Message) -> str:
865
  """Return text content from a message for token estimation."""
866
+ base_text = ""
867
  if isinstance(message.content, str):
868
+ base_text = message.content
869
+ elif isinstance(message.content, list):
870
+ base_text = "\n".join(
871
+ item.text or "" for item in message.content if getattr(item, "type", "") == "text"
872
+ )
873
+ elif message.content is None:
874
+ base_text = ""
875
+
876
+ if message.tool_calls:
877
+ tool_arg_text = "".join(call.function.arguments or "" for call in message.tool_calls)
878
+ base_text = f"{base_text}\n{tool_arg_text}" if base_text else tool_arg_text
879
+
880
+ return base_text
881
 
882
 
883
  def _find_reusable_session(
 
893
  ---------
894
  When a reply was generated by *another* server instance, the local LMDB may
895
  only contain an older part of the conversation. However, as long as we can
896
+ line up **any** earlier assistant/system response, we can restore the
897
  corresponding Gemini session and replay the *remaining* turns locally
898
  (including that missing assistant reply and the subsequent user prompts).
899
 
 
969
  return await session.send_message(chunks[-1], files=files)
970
 
971
 
972
+ def _iter_stream_segments(model_output: str, chunk_size: int = 64):
973
+ """Yield stream segments while keeping <think> markers and words intact."""
974
+ if not model_output:
975
+ return
976
+
977
+ token_pattern = re.compile(r"\s+|\S+\s*")
978
+ pending = ""
979
+
980
+ def _flush_pending() -> Iterator[str]:
981
+ nonlocal pending
982
+ if pending:
983
+ yield pending
984
+ pending = ""
985
+
986
+ # Split on <think> boundaries so the markers are never fragmented.
987
+ parts = re.split(r"(</?think>)", model_output)
988
+ for part in parts:
989
+ if not part:
990
+ continue
991
+ if part in {"<think>", "</think>"}:
992
+ yield from _flush_pending()
993
+ yield part
994
+ continue
995
+
996
+ for match in token_pattern.finditer(part):
997
+ token = match.group(0)
998
+
999
+ if len(token) > chunk_size:
1000
+ yield from _flush_pending()
1001
+ for idx in range(0, len(token), chunk_size):
1002
+ yield token[idx : idx + chunk_size]
1003
+ continue
1004
+
1005
+ if pending and len(pending) + len(token) > chunk_size:
1006
+ yield from _flush_pending()
1007
+
1008
+ pending += token
1009
+
1010
+ yield from _flush_pending()
1011
+
1012
+
1013
  def _create_streaming_response(
1014
  model_output: str,
1015
+ tool_calls: list[dict],
1016
  completion_id: str,
1017
  created_time: int,
1018
  model: str,
 
1022
 
1023
  # Calculate token usage
1024
  prompt_tokens = sum(estimate_tokens(_text_from_message(msg)) for msg in messages)
1025
+ tool_args = "".join(call.get("function", {}).get("arguments", "") for call in tool_calls or [])
1026
+ completion_tokens = estimate_tokens(model_output + tool_args)
1027
  total_tokens = prompt_tokens + completion_tokens
1028
+ finish_reason = "tool_calls" if tool_calls else "stop"
1029
 
1030
  async def generate_stream():
1031
  # Send start event
 
1039
  yield f"data: {orjson.dumps(data).decode('utf-8')}\n\n"
1040
 
1041
  # Stream output text in chunks for efficiency
1042
+ for chunk in _iter_stream_segments(model_output):
 
 
1043
  data = {
1044
  "id": completion_id,
1045
  "object": "chat.completion.chunk",
 
1049
  }
1050
  yield f"data: {orjson.dumps(data).decode('utf-8')}\n\n"
1051
 
1052
+ if tool_calls:
1053
+ tool_calls_delta = [{**call, "index": idx} for idx, call in enumerate(tool_calls)]
1054
+ data = {
1055
+ "id": completion_id,
1056
+ "object": "chat.completion.chunk",
1057
+ "created": created_time,
1058
+ "model": model,
1059
+ "choices": [
1060
+ {
1061
+ "index": 0,
1062
+ "delta": {"tool_calls": tool_calls_delta},
1063
+ "finish_reason": None,
1064
+ }
1065
+ ],
1066
+ }
1067
+ yield f"data: {orjson.dumps(data).decode('utf-8')}\n\n"
1068
+
1069
  # Send end event
1070
  data = {
1071
  "id": completion_id,
1072
  "object": "chat.completion.chunk",
1073
  "created": created_time,
1074
  "model": model,
1075
+ "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}],
1076
  "usage": {
1077
  "prompt_tokens": prompt_tokens,
1078
  "completion_tokens": completion_tokens,
 
1085
  return StreamingResponse(generate_stream(), media_type="text/event-stream")
1086
 
1087
 
1088
def _create_responses_streaming_response(
    response_payload: ResponseCreateResponse,
    assistant_text: str | None,
) -> StreamingResponse:
    """Create streaming response for Responses API using event types defined by OpenAI.

    Emits ``response.created``, zero or more ``response.output_text.delta``
    events, a single ``response.output_text.done``, then
    ``response.completed`` carrying the full payload, followed by the
    ``[DONE]`` sentinel.
    """
    response_dict = response_payload.model_dump(mode="json")
    response_id = response_payload.id
    created_time = response_payload.created
    model = response_payload.model

    logger.debug(
        f"Preparing streaming envelope for /v1/responses (response_id={response_id}, model={model})."
    )

    # Envelope fields repeated on every SSE event.
    base_event = {
        "id": response_id,
        "object": "response",
        "created": created_time,
        "model": model,
    }

    # Minimal in-progress snapshot sent with the initial event.
    created_snapshot: dict[str, Any] = {
        "id": response_id,
        "object": "response",
        "created": created_time,
        "model": model,
        "status": "in_progress",
    }
    if response_dict.get("metadata") is not None:
        created_snapshot["metadata"] = response_dict["metadata"]
    if response_dict.get("input") is not None:
        created_snapshot["input"] = response_dict["input"]

    async def generate_stream():
        # Emit creation event
        data = {
            **base_event,
            "type": "response.created",
            "response": created_snapshot,
        }
        yield f"data: {orjson.dumps(data).decode('utf-8')}\n\n"

        # Stream textual content, if any
        if assistant_text:
            for chunk in _iter_stream_segments(assistant_text):
                delta_event = {
                    **base_event,
                    "type": "response.output_text.delta",
                    "output_index": 0,
                    "delta": chunk,
                }
                yield f"data: {orjson.dumps(delta_event).decode('utf-8')}\n\n"

        # The done event fires exactly once whether or not text was
        # streamed. (Previously this dict was duplicated in both branches
        # of an if/else; the behavior is identical.)
        done_event = {
            **base_event,
            "type": "response.output_text.done",
            "output_index": 0,
        }
        yield f"data: {orjson.dumps(done_event).decode('utf-8')}\n\n"

        # Emit completed event with full payload
        completed_event = {
            **base_event,
            "type": "response.completed",
            "response": response_dict,
        }
        yield f"data: {orjson.dumps(completed_event).decode('utf-8')}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate_stream(), media_type="text/event-stream")
1166
+
1167
+
1168
  def _create_standard_response(
1169
  model_output: str,
1170
+ tool_calls: list[dict],
1171
  completion_id: str,
1172
  created_time: int,
1173
  model: str,
 
1176
  """Create standard response"""
1177
  # Calculate token usage
1178
  prompt_tokens = sum(estimate_tokens(_text_from_message(msg)) for msg in messages)
1179
+ tool_args = "".join(call.get("function", {}).get("arguments", "") for call in tool_calls or [])
1180
+ completion_tokens = estimate_tokens(model_output + tool_args)
1181
  total_tokens = prompt_tokens + completion_tokens
1182
+ finish_reason = "tool_calls" if tool_calls else "stop"
1183
+
1184
+ message_payload: dict = {"role": "assistant", "content": model_output or None}
1185
+ if tool_calls:
1186
+ message_payload["tool_calls"] = tool_calls
1187
 
1188
  result = {
1189
  "id": completion_id,
 
1193
  "choices": [
1194
  {
1195
  "index": 0,
1196
+ "message": message_payload,
1197
+ "finish_reason": finish_reason,
1198
  }
1199
  ],
1200
  "usage": {
 
1206
 
1207
  logger.debug(f"Response created with {total_tokens} total tokens")
1208
  return result
1209
+
1210
+
1211
+ def _extract_image_dimensions(data: bytes) -> tuple[int | None, int | None]:
1212
+ """Return image dimensions (width, height) if PNG or JPEG headers are present."""
1213
+ # PNG: dimensions stored in bytes 16..24 of the IHDR chunk
1214
+ if len(data) >= 24 and data.startswith(b"\x89PNG\r\n\x1a\n"):
1215
+ try:
1216
+ width, height = struct.unpack(">II", data[16:24])
1217
+ return int(width), int(height)
1218
+ except struct.error:
1219
+ return None, None
1220
+
1221
+ # JPEG: dimensions stored in SOF segment; iterate through markers to locate it
1222
+ if len(data) >= 4 and data[0:2] == b"\xff\xd8":
1223
+ idx = 2
1224
+ length = len(data)
1225
+ sof_markers = {
1226
+ 0xC0,
1227
+ 0xC1,
1228
+ 0xC2,
1229
+ 0xC3,
1230
+ 0xC5,
1231
+ 0xC6,
1232
+ 0xC7,
1233
+ 0xC9,
1234
+ 0xCA,
1235
+ 0xCB,
1236
+ 0xCD,
1237
+ 0xCE,
1238
+ 0xCF,
1239
+ }
1240
+ while idx < length:
1241
+ # Find marker alignment (markers are prefixed with 0xFF bytes)
1242
+ if data[idx] != 0xFF:
1243
+ idx += 1
1244
+ continue
1245
+ while idx < length and data[idx] == 0xFF:
1246
+ idx += 1
1247
+ if idx >= length:
1248
+ break
1249
+ marker = data[idx]
1250
+ idx += 1
1251
+
1252
+ if marker in (0xD8, 0xD9, 0x01) or 0xD0 <= marker <= 0xD7:
1253
+ continue
1254
+
1255
+ if idx + 1 >= length:
1256
+ break
1257
+ segment_length = (data[idx] << 8) + data[idx + 1]
1258
+ idx += 2
1259
+ if segment_length < 2:
1260
+ break
1261
+
1262
+ if marker in sof_markers:
1263
+ if idx + 4 < length:
1264
+ # Skip precision byte at idx, then read height/width (big-endian)
1265
+ height = (data[idx + 1] << 8) + data[idx + 2]
1266
+ width = (data[idx + 3] << 8) + data[idx + 4]
1267
+ return int(width), int(height)
1268
+ break
1269
+
1270
+ idx += segment_length - 2
1271
+
1272
+ return None, None
1273
+
1274
+
1275
async def _image_to_base64(image: Image, temp_dir: Path) -> tuple[str, int | None, int | None]:
    """Persist an image provided by gemini_webapi and return base64 plus dimensions.

    Generated images are saved at full resolution; other images use the
    default save behavior. Raises ``ValueError`` when saving yields no path.
    """
    save_kwargs: dict = {"path": str(temp_dir)}
    if isinstance(image, GeneratedImage):
        # Generated images support (and deserve) the full-size download.
        save_kwargs["full_size"] = True
    saved_path = await image.save(**save_kwargs)

    if not saved_path:
        raise ValueError("Failed to save generated image")

    raw_bytes = Path(saved_path).read_bytes()
    width, height = _extract_image_dimensions(raw_bytes)
    encoded = base64.b64encode(raw_bytes).decode("utf-8")
    return encoded, width, height
app/server/middleware.py CHANGED
@@ -17,7 +17,8 @@ def global_exception_handler(request: Request, exc: Exception):
17
  )
18
 
19
  return ORJSONResponse(
20
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"error": {"message": str(exc)}}
 
21
  )
22
 
23
 
 
17
  )
18
 
19
  return ORJSONResponse(
20
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
21
+ content={"error": {"message": str(exc)}},
22
  )
23
 
24
 
app/services/client.py CHANGED
@@ -1,6 +1,9 @@
1
  import asyncio
 
 
2
  import re
3
  from pathlib import Path
 
4
 
5
  from gemini_webapi import GeminiClient, ModelOutput
6
  from gemini_webapi.client import ChatSession
@@ -12,7 +15,28 @@ from ..models import Message
12
  from ..utils import g_config
13
  from ..utils.helper import add_tag, save_file_to_tempfile, save_url_to_tempfile
14
 
15
- XML_WRAP_HINT = "\nFor any xml block, e.g. tool call, always wrap it with: \n`````xml\n...\n`````\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
 
18
  class GeminiClientWrapper(GeminiClient):
@@ -22,16 +46,32 @@ class GeminiClientWrapper(GeminiClient):
22
  super().__init__(**kwargs)
23
  self.id = client_id
24
 
25
- async def init(self, **kwargs):
 
 
 
 
 
 
 
 
26
  """
27
  Inject default configuration values.
28
  """
29
- kwargs.setdefault("timeout", g_config.gemini.timeout)
30
- kwargs.setdefault("auto_refresh", g_config.gemini.auto_refresh)
31
- kwargs.setdefault("verbose", g_config.gemini.verbose)
32
- kwargs.setdefault("refresh_interval", g_config.gemini.refresh_interval)
 
33
 
34
- await super().init(**kwargs)
 
 
 
 
 
 
 
35
 
36
  async def generate_content(
37
  self,
@@ -41,22 +81,23 @@ class GeminiClientWrapper(GeminiClient):
41
  gem: Gem | str | None = None,
42
  chat: ChatSession | None = None,
43
  **kwargs,
44
- ):
45
  cnt = 2 # Try 2 times before giving up
46
- last_exception = None
47
  while cnt:
48
  cnt -= 1
49
  try:
50
  return await super().generate_content(prompt, files, model, gem, chat, **kwargs)
51
  except ModelInvalid as e:
52
- # This is not always caused by model selection. Instead it can be solved by retrying.
53
  # So we catch it and retry as a workaround.
54
  await asyncio.sleep(1)
55
  last_exception = e
56
 
57
  # If retrying failed, re-raise ModelInvalid
58
- if last_exception:
59
  raise last_exception
 
60
 
61
  @staticmethod
62
  async def process_message(
@@ -65,22 +106,21 @@ class GeminiClientWrapper(GeminiClient):
65
  """
66
  Process a single message and return model input.
67
  """
68
- model_input = ""
69
  files: list[Path | str] = []
 
 
70
  if isinstance(message.content, str):
71
  # Pure text content
72
- model_input = message.content
73
- else:
 
74
  # Mixed content
75
  # TODO: Use Pydantic to enforce the value checking
76
  for item in message.content:
77
  if item.type == "text":
78
  # Append multiple text fragments
79
  if item.text:
80
- if model_input:
81
- model_input += "\n" + item.text
82
- else:
83
- model_input = item.text
84
 
85
  elif item.type == "image_url":
86
  if not item.image_url:
@@ -98,20 +138,33 @@ class GeminiClientWrapper(GeminiClient):
98
  files.append(await save_file_to_tempfile(file_data, filename, tempdir))
99
  else:
100
  raise ValueError("File must contain 'file_data' key")
 
 
101
 
102
- # This is a workaround for Gemini Web's displaying issues with XML blocks.
103
- # Add this for tool calling
104
- if re.search(r"<\s*[^>]+>", model_input):
105
- hint = XML_WRAP_HINT
106
- else:
107
- hint = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  # Add role tag if needed
110
  if model_input:
111
  if tagged:
112
- model_input = add_tag(message.role, model_input + hint)
113
- else:
114
- model_input += hint
115
 
116
  return model_input, files
117
 
@@ -161,7 +214,36 @@ class GeminiClientWrapper(GeminiClient):
161
  text += str(response)
162
 
163
  # Fix some escaped characters
164
- text = text.replace("&lt;", "<").replace("\\<", "<").replace("\\_", "_").replace("\\>", ">")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  def simplify_link_target(text_content: str) -> str:
167
  match_colon_num = re.match(r"([^:]+:\d+)", text_content)
@@ -181,7 +263,7 @@ class GeminiClientWrapper(GeminiClient):
181
  else:
182
  return new_link_segment
183
 
184
- # Replace Google search links with simplified markdown links
185
  pattern = r"(\()?\[`([^`]+?)`\]\((https://www.google.com/search\?q=)(.*?)(?<!\\)\)\)*(\))?"
186
  text = re.sub(pattern, replacer, text)
187
 
 
1
  import asyncio
2
+ import html
3
+ import json
4
  import re
5
  from pathlib import Path
6
+ from typing import Any, cast
7
 
8
  from gemini_webapi import GeminiClient, ModelOutput
9
  from gemini_webapi.client import ChatSession
 
15
  from ..utils import g_config
16
  from ..utils.helper import add_tag, save_file_to_tempfile, save_url_to_tempfile
17
 
18
+ XML_WRAP_HINT = (
19
+ "\nYou MUST wrap every tool call response inside a single fenced block exactly like:\n"
20
+ '```xml\n<tool_call name="tool_name">{"arg": "value"}</tool_call>\n```\n'
21
+ "Do not surround the fence with any other text or whitespace; otherwise the call will be ignored.\n"
22
+ )
23
+
24
+ CODE_BLOCK_HINT = (
25
+ "\nWhenever you include code, markup, or shell snippets, wrap each snippet in a Markdown fenced "
26
+ "block and supply the correct language label (for example, ```python ... ``` or ```html ... ```).\n"
27
+ "Fence ONLY the actual code/markup; keep all narrative or explanatory text outside the fences.\n"
28
+ )
29
+
30
+ HTML_ESCAPE_RE = re.compile(r"&(?:lt|gt|amp|quot|apos|#[0-9]+|#x[0-9a-fA-F]+);")
31
+ MARKDOWN_ESCAPE_RE = re.compile(r"\\(?=\s*[-\\`*_{}\[\]()#+.!<>])")
32
+ CODE_FENCE_RE = re.compile(r"(```.*?```|`[^`]*`)", re.DOTALL)
33
+
34
+
35
+ _UNSET = object()
36
+
37
+
38
+ def _resolve(value: Any, fallback: Any):
39
+ return fallback if value is _UNSET else value
40
 
41
 
42
  class GeminiClientWrapper(GeminiClient):
 
46
  super().__init__(**kwargs)
47
  self.id = client_id
48
 
49
+ async def init(
50
+ self,
51
+ timeout: float = cast(float, _UNSET),
52
+ auto_close: bool = False,
53
+ close_delay: float = 300,
54
+ auto_refresh: bool = cast(bool, _UNSET),
55
+ refresh_interval: float = cast(float, _UNSET),
56
+ verbose: bool = cast(bool, _UNSET),
57
+ ) -> None:
58
  """
59
  Inject default configuration values.
60
  """
61
+ config = g_config.gemini
62
+ timeout = cast(float, _resolve(timeout, config.timeout))
63
+ auto_refresh = cast(bool, _resolve(auto_refresh, config.auto_refresh))
64
+ refresh_interval = cast(float, _resolve(refresh_interval, config.refresh_interval))
65
+ verbose = cast(bool, _resolve(verbose, config.verbose))
66
 
67
+ await super().init(
68
+ timeout=timeout,
69
+ auto_close=auto_close,
70
+ close_delay=close_delay,
71
+ auto_refresh=auto_refresh,
72
+ refresh_interval=refresh_interval,
73
+ verbose=verbose,
74
+ )
75
 
76
  async def generate_content(
77
  self,
 
81
  gem: Gem | str | None = None,
82
  chat: ChatSession | None = None,
83
  **kwargs,
84
+ ) -> ModelOutput:
85
  cnt = 2 # Try 2 times before giving up
86
+ last_exception: ModelInvalid | None = None
87
  while cnt:
88
  cnt -= 1
89
  try:
90
  return await super().generate_content(prompt, files, model, gem, chat, **kwargs)
91
  except ModelInvalid as e:
92
+ # This is not always caused by model selection. Instead, it can be solved by retrying.
93
  # So we catch it and retry as a workaround.
94
  await asyncio.sleep(1)
95
  last_exception = e
96
 
97
  # If retrying failed, re-raise ModelInvalid
98
+ if last_exception is not None:
99
  raise last_exception
100
+ raise RuntimeError("generate_content failed without receiving a ModelInvalid error.")
101
 
102
  @staticmethod
103
  async def process_message(
 
106
  """
107
  Process a single message and return model input.
108
  """
 
109
  files: list[Path | str] = []
110
+ text_fragments: list[str] = []
111
+
112
  if isinstance(message.content, str):
113
  # Pure text content
114
+ if message.content:
115
+ text_fragments.append(message.content)
116
+ elif isinstance(message.content, list):
117
  # Mixed content
118
  # TODO: Use Pydantic to enforce the value checking
119
  for item in message.content:
120
  if item.type == "text":
121
  # Append multiple text fragments
122
  if item.text:
123
+ text_fragments.append(item.text)
 
 
 
124
 
125
  elif item.type == "image_url":
126
  if not item.image_url:
 
138
  files.append(await save_file_to_tempfile(file_data, filename, tempdir))
139
  else:
140
  raise ValueError("File must contain 'file_data' key")
141
+ elif message.content is not None:
142
+ raise ValueError("Unsupported message content type.")
143
 
144
+ if message.tool_calls:
145
+ tool_blocks: list[str] = []
146
+ for call in message.tool_calls:
147
+ args_text = call.function.arguments.strip()
148
+ try:
149
+ parsed_args = json.loads(args_text)
150
+ args_text = json.dumps(parsed_args, ensure_ascii=False)
151
+ except (json.JSONDecodeError, TypeError):
152
+ # Leave args_text as is if it is not valid JSON
153
+ pass
154
+ tool_blocks.append(
155
+ f'<tool_call name="{call.function.name}">{args_text}</tool_call>'
156
+ )
157
+
158
+ if tool_blocks:
159
+ tool_section = "```xml\n" + "\n".join(tool_blocks) + "\n```"
160
+ text_fragments.append(tool_section)
161
+
162
+ model_input = "\n".join(fragment for fragment in text_fragments if fragment)
163
 
164
  # Add role tag if needed
165
  if model_input:
166
  if tagged:
167
+ model_input = add_tag(message.role, model_input)
 
 
168
 
169
  return model_input, files
170
 
 
214
  text += str(response)
215
 
216
  # Fix some escaped characters
217
+ def _unescape_html(text_content: str) -> str:
218
+ parts: list[str] = []
219
+ last_index = 0
220
+ for match in CODE_FENCE_RE.finditer(text_content):
221
+ non_code = text_content[last_index : match.start()]
222
+ if non_code:
223
+ parts.append(HTML_ESCAPE_RE.sub(lambda m: html.unescape(m.group(0)), non_code))
224
+ parts.append(match.group(0))
225
+ last_index = match.end()
226
+ tail = text_content[last_index:]
227
+ if tail:
228
+ parts.append(HTML_ESCAPE_RE.sub(lambda m: html.unescape(m.group(0)), tail))
229
+ return "".join(parts)
230
+
231
+ def _unescape_markdown(text_content: str) -> str:
232
+ parts: list[str] = []
233
+ last_index = 0
234
+ for match in CODE_FENCE_RE.finditer(text_content):
235
+ non_code = text_content[last_index : match.start()]
236
+ if non_code:
237
+ parts.append(MARKDOWN_ESCAPE_RE.sub("", non_code))
238
+ parts.append(match.group(0))
239
+ last_index = match.end()
240
+ tail = text_content[last_index:]
241
+ if tail:
242
+ parts.append(MARKDOWN_ESCAPE_RE.sub("", tail))
243
+ return "".join(parts)
244
+
245
+ text = _unescape_html(text)
246
+ text = _unescape_markdown(text)
247
 
248
  def simplify_link_target(text_content: str) -> str:
249
  match_colon_num = re.match(r"([^:]+:\d+)", text_content)
 
263
  else:
264
  return new_link_segment
265
 
266
+ # Replace Google search links with simplified Markdown links
267
  pattern = r"(\()?\[`([^`]+?)`\]\((https://www.google.com/search\?q=)(.*?)(?<!\\)\)\)*(\))?"
268
  text = re.sub(pattern, replacer, text)
269
 
app/utils/config.py CHANGED
@@ -174,7 +174,8 @@ def extract_gemini_clients_env() -> dict[int, dict[str, str]]:
174
 
175
 
176
  def _merge_clients_with_env(
177
- base_clients: list[GeminiClientSettings] | None, env_overrides: dict[int, dict[str, str]]
 
178
  ):
179
  """Override base_clients with env_overrides, return the new clients list."""
180
  if not env_overrides:
 
174
 
175
 
176
  def _merge_clients_with_env(
177
+ base_clients: list[GeminiClientSettings] | None,
178
+ env_overrides: dict[int, dict[str, str]],
179
  ):
180
  """Override base_clients with env_overrides, return the new clients list."""
181
  if not env_overrides:
app/utils/helper.py CHANGED
@@ -5,10 +5,12 @@ from pathlib import Path
5
  import httpx
6
  from loguru import logger
7
 
 
 
8
 
9
  def add_tag(role: str, content: str, unclose: bool = False) -> str:
10
  """Surround content with role tags"""
11
- if role not in ["user", "assistant", "system"]:
12
  logger.warning(f"Unknown role: {role}, returning content without tags")
13
  return content
14
 
@@ -34,6 +36,8 @@ async def save_file_to_tempfile(
34
 
35
 
36
  async def save_url_to_tempfile(url: str, tempdir: Path | None = None):
 
 
37
  if url.startswith("data:image/"):
38
  # Base64 encoded image
39
  base64_data = url.split(",")[1]
 
5
  import httpx
6
  from loguru import logger
7
 
8
+ VALID_TAG_ROLES = {"user", "assistant", "system", "tool"}
9
+
10
 
11
  def add_tag(role: str, content: str, unclose: bool = False) -> str:
12
  """Surround content with role tags"""
13
+ if role not in VALID_TAG_ROLES:
14
  logger.warning(f"Unknown role: {role}, returning content without tags")
15
  return content
16
 
 
36
 
37
 
38
  async def save_url_to_tempfile(url: str, tempdir: Path | None = None):
39
+ data: bytes | None = None
40
+ suffix: str | None = None
41
  if url.startswith("data:image/"):
42
  # Base64 encoded image
43
  base64_data = url.split(",")[1]
run.py CHANGED
@@ -20,7 +20,9 @@ if __name__ == "__main__":
20
 
21
  # Check if the certificate files exist
22
  if not os.path.exists(key_path) or not os.path.exists(cert_path):
23
- logger.critical(f"HTTPS enabled but SSL certificate files not found: {key_path}, {cert_path}")
 
 
24
  sys.exit(1)
25
 
26
  logger.info(f"Starting server at https://{g_config.server.host}:{g_config.server.port} ...")
 
20
 
21
  # Check if the certificate files exist
22
  if not os.path.exists(key_path) or not os.path.exists(cert_path):
23
+ logger.critical(
24
+ f"HTTPS enabled but SSL certificate files not found: {key_path}, {cert_path}"
25
+ )
26
  sys.exit(1)
27
 
28
  logger.info(f"Starting server at https://{g_config.server.host}:{g_config.server.port} ...")