Lưu Quang Vũ committed on
Commit a583ded · unverified · 1 Parent(s): 79b9136

feat: Add support for custom Gemini models and model loading strategies (#86)


* feat: Add support for custom Gemini models and model loading strategies

- Introduced `model_strategy` configuration for "append" (default + custom models) or "overwrite" (custom models only).
- Enhanced `/v1/models` endpoint to return models based on the configured strategy.
- Improved model loading with environment variable overrides and validation.
- Refactored model handling logic for improved modularity and error handling.
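The two loading strategies can be sketched roughly as follows. This is an illustrative sketch only: `resolve_models` and `DEFAULT_MODELS` are hypothetical names, not the project's actual API, and the real implementation works on model objects rather than bare name strings.

```python
# Hypothetical sketch of "append" vs "overwrite" model loading.
DEFAULT_MODELS = ["gemini-2.5-flash", "gemini-2.5-pro"]  # placeholder defaults

def resolve_models(custom: list[str], strategy: str = "append") -> list[str]:
    """Return the model list exposed by /v1/models under the given strategy."""
    if strategy == "overwrite":
        # custom models only; defaults are hidden entirely
        return list(custom)
    if strategy == "append":
        # defaults first, then custom models that don't shadow a default
        return DEFAULT_MODELS + [m for m in custom if m not in DEFAULT_MODELS]
    raise ValueError(f"unknown model_strategy: {strategy!r}")
```

With `strategy="append"` a custom model is added alongside the defaults; with `"overwrite"` only the custom list is served.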

* feat: Improve Gemini model environment variable parsing and nested field support

- Enhanced `extract_gemini_models_env` to handle nested fields within environment variables.
- Updated type hints for more flexibility in model overrides.
- Improved `_merge_models_with_env` to better support field-level updates and appending new models.

* refactor: Consolidate utility functions and clean up unused code

- Moved utility functions like `strip_code_fence`, `extract_tool_calls`, and `iter_stream_segments` to a centralized helper module.
- Removed unused and redundant private methods from `chat.py`, including `_strip_code_fence`, `_strip_tagged_blocks`, and `_strip_system_hints`.
- Updated imports and references across modules for consistency.
- Simplified tool call and streaming logic by replacing inline implementations with shared helper functions.

* fix: Handle None input in `estimate_tokens` and return 0 for empty text

* refactor: Simplify model configuration and add JSON parsing validators

- Replaced unused model placeholder in `config.yaml` with an empty list.
- Added JSON parsing validators for `model_header` and `models` to enhance flexibility and error handling.
- Improved validation to filter out incomplete model configurations.
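A plain-Python sketch of what such validators do (the real code presumably hooks this into the config model's validation; the function names here are hypothetical):

```python
import json

def parse_json_field(value):
    """Coerce a JSON string into a structure; pass through already-parsed values."""
    if isinstance(value, str):
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return None  # treated as invalid and filtered out downstream
    return value

def filter_complete_models(models):
    """Drop entries that lack a model_name, mirroring the validation described above."""
    return [m for m in models if isinstance(m, dict) and m.get("model_name")]
```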

* refactor: Simplify Gemini model environment variable parsing with JSON support

- Replaced prefix-based parsing with a root key approach.
- Added JSON parsing to handle list-based model configurations.
- Improved handling of errors and cleanup of environment variables.

* fix: Enhance Gemini model environment variable parsing with fallback to Python literals

- Added `ast.literal_eval` as a fallback for parsing environment variables when JSON decoding fails.
- Improved error handling and logging for invalid configurations.
- Ensured proper cleanup of environment variables post-parsing.
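The fallback chain can be sketched like this (function name is illustrative; the real code also logs and cleans up the environment variable):

```python
import ast
import json

def parse_models_env(raw: str):
    """Parse a CONFIG_GEMINI__MODELS value: try JSON first, then fall back to
    a Python literal (handles single-quoted dict/list strings)."""
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        return None  # invalid configuration; the real code logs a warning here
```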

* fix: Improve regex patterns in helper module

- Adjusted `TOOL_CALL_RE` regex pattern for better accuracy.

* docs: Update README files to include custom model configuration and environment variable setup

* fix: Remove unused headers from HTTP client in helper module

* fix: Update README and README.zh to clarify model configuration via environment variables; enhance error logging in config validation

* Update README and README.zh to clarify model configuration via JSON string or list structure for enhanced flexibility in automated environments

README.md CHANGED
@@ -118,7 +118,7 @@ services:
       - CONFIG_GEMINI__CLIENTS__0__SECURE_1PSID=${SECURE_1PSID}
       - CONFIG_GEMINI__CLIENTS__0__SECURE_1PSIDTS=${SECURE_1PSIDTS}
       - GEMINI_COOKIE_PATH=/app/cache # must match the cache volume mount above
-    restart: on-failure:3 # Avoid retrying too many times
+    restart: on-failure:3 # Avoid retrying too many times
 ```
 
 Then run:
@@ -187,6 +187,30 @@ To use Gemini-FastAPI, you need to extract your Gemini session cookies:
 
 Each client entry can be configured with a different proxy to work around rate limits. Omit the `proxy` field or set it to `null` or an empty string to keep a direct connection.
 
+### Custom Models
+
+You can define custom models in `config/config.yaml` or via environment variables.
+
+#### YAML Configuration
+
+```yaml
+gemini:
+  model_strategy: "append" # "append" (default + custom) or "overwrite" (custom only)
+  models:
+    - model_name: "gemini-3.0-pro"
+      model_header:
+        x-goog-ext-525001261-jspb: '[1,null,null,null,"9d8ca3786ebdfbea",null,null,0,[4],null,null,1]'
+```
+
+#### Environment Variables
+
+You can supply models as a JSON string or list structure via `CONFIG_GEMINI__MODELS`. This provides a flexible way to override settings via the shell or in automated environments (e.g. Docker) without modifying the configuration file.
+
+```bash
+export CONFIG_GEMINI__MODEL_STRATEGY="overwrite"
+export CONFIG_GEMINI__MODELS='[{"model_name": "gemini-3.0-pro", "model_header": {"x-goog-ext-525001261-jspb": "[1,null,null,null,\"9d8ca3786ebdfbea\",null,null,0,[4],null,null,1]"}}]'
+```
+
 ## Acknowledgments
 
 - [HanaokaYuzu/Gemini-API](https://github.com/HanaokaYuzu/Gemini-API) - The underlying Gemini web API client
README.zh.md CHANGED
@@ -4,7 +4,6 @@
 [![FastAPI](https://img.shields.io/badge/FastAPI-0.115+-green.svg)](https://fastapi.tiangolo.com/)
 [![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
 
-
 [ [English](README.md) | 中文 ]
 
 将 Gemini 网页端模型封装为兼容 OpenAI API 的 API Server。基于 [HanaokaYuzu/Gemini-API](https://github.com/HanaokaYuzu/Gemini-API) 实现。
@@ -50,6 +49,7 @@ pip install -e .
 ### 配置
 
 编辑 `config/config.yaml` 并提供至少一组凭证:
+
 ```yaml
 gemini:
   clients:
@@ -118,7 +118,7 @@ services:
       - CONFIG_GEMINI__CLIENTS__0__SECURE_1PSID=${SECURE_1PSID}
       - CONFIG_GEMINI__CLIENTS__0__SECURE_1PSIDTS=${SECURE_1PSIDTS}
       - GEMINI_COOKIE_PATH=/app/cache # must match the cache volume mount above
-    restart: on-failure:3 # Avoid retrying too many times
+    restart: on-failure:3 # Avoid retrying too many times
 ```
 
 然后运行:
@@ -186,6 +186,30 @@ export CONFIG_STORAGE__MAX_SIZE=268435456 # 256 MB
 
 每个客户端条目可以配置不同的代理,从而规避速率限制。省略 `proxy` 字段或将其设置为 `null` 或空字符串以保持直连。
 
+### 自定义模型
+
+你可以在 `config/config.yaml` 中或通过环境变量定义自定义模型。
+
+#### YAML 配置
+
+```yaml
+gemini:
+  model_strategy: "append" # "append" (默认 + 自定义) 或 "overwrite" (仅限自定义)
+  models:
+    - model_name: "gemini-3.0-pro"
+      model_header:
+        x-goog-ext-525001261-jspb: '[1,null,null,null,"9d8ca3786ebdfbea",null,null,0,[4],null,null,1]'
+```
+
+#### 环境变量
+
+你可以通过 `CONFIG_GEMINI__MODELS` 以 JSON 字符串或列表结构的形式提供模型。这为通过 shell 或在自动化环境(例如 Docker)中覆盖设置提供了一种灵活的方式,而无需修改配置文件。
+
+```bash
+export CONFIG_GEMINI__MODEL_STRATEGY="overwrite"
+export CONFIG_GEMINI__MODELS='[{"model_name": "gemini-3.0-pro", "model_header": {"x-goog-ext-525001261-jspb": "[1,null,null,null,\"9d8ca3786ebdfbea\",null,null,0,[4],null,null,1]"}}]'
+```
+
 ## 鸣谢
 
 - [HanaokaYuzu/Gemini-API](https://github.com/HanaokaYuzu/Gemini-API) - 底层 Gemini Web API 客户端
@@ -193,4 +217,4 @@ export CONFIG_STORAGE__MAX_SIZE=268435456 # 256 MB
 
 ## 免责声明
 
-本项目与 Google 或 OpenAI 无关,仅供学习和研究使用。本项目使用了逆向工程 API,可能不符合 Google 服务条款。使用风险自负。
+本项目与 Google 或 OpenAI 无关,仅供学习和研究使用。本项目使用了逆向工程 API,可能不符合 Google 服务条款。使用风险自负。
app/server/chat.py CHANGED
@@ -1,12 +1,11 @@
 import base64
 import json
 import re
-import struct
 import uuid
 from dataclasses import dataclass
 from datetime import datetime, timezone
 from pathlib import Path
-from typing import Any, Iterator
+from typing import Any
 
 import orjson
 from fastapi import APIRouter, Depends, HTTPException, Request, status
@@ -21,7 +20,6 @@ from ..models import (
     ChatCompletionRequest,
     ContentItem,
     ConversationInStore,
-    FunctionCall,
     Message,
     ModelData,
     ModelListResponse,
@@ -37,26 +35,28 @@ from ..models import (
     ResponseToolChoice,
     ResponseUsage,
     Tool,
-    ToolCall,
     ToolChoiceFunction,
 )
 from ..services import GeminiClientPool, GeminiClientWrapper, LMDBConversationStore
-from ..services.client import CODE_BLOCK_HINT, XML_WRAP_HINT
 from ..utils import g_config
-from ..utils.helper import estimate_tokens
+from ..utils.helper import (
+    CODE_BLOCK_HINT,
+    CODE_HINT_STRIPPED,
+    XML_HINT_STRIPPED,
+    XML_WRAP_HINT,
+    estimate_tokens,
+    extract_image_dimensions,
+    extract_tool_calls,
+    iter_stream_segments,
+    remove_tool_call_blocks,
+    strip_code_fence,
+    text_from_message,
+)
 from .middleware import get_image_store_dir, get_image_token, get_temp_dir, verify_api_key
 
 # Maximum characters Gemini Web can accept in a single request (configurable)
 MAX_CHARS_PER_REQUEST = int(g_config.gemini.max_chars_per_request * 0.9)
 CONTINUATION_HINT = "\n(More messages to come, please reply with just 'ok.')"
-TOOL_BLOCK_RE = re.compile(r"```xml\s*(.*?)```", re.DOTALL | re.IGNORECASE)
-TOOL_CALL_RE = re.compile(
-    r"<tool_call\s+name=\"([^\"]+)\">(.*?)</tool_call>", re.DOTALL | re.IGNORECASE
-)
-JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)
-CONTROL_TOKEN_RE = re.compile(r"<\|im_(?:start|end)\|>")
-XML_HINT_STRIPPED = XML_WRAP_HINT.strip()
-CODE_HINT_STRIPPED = CODE_BLOCK_HINT.strip()
 
 router = APIRouter()
 
@@ -118,14 +118,6 @@ def _build_structured_requirement(
     )
 
 
-def _strip_code_fence(text: str) -> str:
-    """Remove surrounding ```json fences if present."""
-    match = JSON_FENCE_RE.match(text.strip())
-    if match:
-        return match.group(1).strip()
-    return text.strip()
-
-
 def _build_tool_prompt(
     tools: list[Tool],
     tool_choice: str | ToolChoiceFunction | None,
@@ -312,75 +304,6 @@ def _prepare_messages_for_model(
     return prepared
 
 
-def _strip_system_hints(text: str) -> str:
-    """Remove system-level hint text from a given string."""
-    if not text:
-        return text
-    cleaned = _strip_tagged_blocks(text)
-    cleaned = cleaned.replace(XML_WRAP_HINT, "").replace(XML_HINT_STRIPPED, "")
-    cleaned = cleaned.replace(CODE_BLOCK_HINT, "").replace(CODE_HINT_STRIPPED, "")
-    cleaned = CONTROL_TOKEN_RE.sub("", cleaned)
-    return cleaned.strip()
-
-
-def _strip_tagged_blocks(text: str) -> str:
-    """Remove <|im_start|>role ... <|im_end|> sections, dropping tool blocks entirely.
-    - tool blocks are removed entirely (if missing end marker, drop to EOF).
-    - other roles: remove markers and role, keep inner content (if missing end marker, keep to EOF).
-    """
-    if not text:
-        return text
-
-    result: list[str] = []
-    idx = 0
-    length = len(text)
-    start_marker = "<|im_start|>"
-    end_marker = "<|im_end|>"
-
-    while idx < length:
-        start = text.find(start_marker, idx)
-        if start == -1:
-            result.append(text[idx:])
-            break
-
-        # append any content before this block
-        result.append(text[idx:start])
-
-        role_start = start + len(start_marker)
-        newline = text.find("\n", role_start)
-        if newline == -1:
-            # malformed block; keep remainder as-is (safe behavior)
-            result.append(text[start:])
-            break
-
-        role = text[role_start:newline].strip().lower()
-
-        end = text.find(end_marker, newline + 1)
-        if end == -1:
-            # missing end marker
-            if role == "tool":
-                # drop from start marker to EOF (skip remainder)
-                break
-            else:
-                # keep inner content from after the role newline to EOF
-                result.append(text[newline + 1 :])
-                break
-
-        block_end = end + len(end_marker)
-
-        if role == "tool":
-            # drop whole block
-            idx = block_end
-            continue
-
-        # keep the content without role markers
-        content = text[newline + 1 : end]
-        result.append(content)
-        idx = block_end
-
-    return "".join(result)
-
-
 def _response_items_to_messages(
     items: str | list[ResponseInputItem],
 ) -> tuple[list[Message], str | list[ResponseInputItem]]:
@@ -509,77 +432,64 @@ def _instructions_to_messages(
     return instruction_messages
 
 
-def _remove_tool_call_blocks(text: str) -> str:
-    """Strip tool call code blocks from text."""
-    if not text:
-        return text
-    cleaned = TOOL_BLOCK_RE.sub("", text)
-    return _strip_system_hints(cleaned)
+def _get_model_by_name(name: str) -> Model:
+    """
+    Retrieve a Model instance by name, considering custom models from config
+    and the update strategy (append or overwrite).
+    """
+    strategy = g_config.gemini.model_strategy
+    custom_models = {m.model_name: m for m in g_config.gemini.models if m.model_name}
 
+    if name in custom_models:
+        return Model.from_dict(custom_models[name].model_dump())
 
+    if strategy == "overwrite":
+        raise ValueError(f"Model '{name}' not found in custom models (strategy='overwrite').")
 
-def _extract_tool_calls(text: str) -> tuple[str, list[ToolCall]]:
-    """Extract tool call definitions and return cleaned text."""
-    if not text:
-        return text, []
+    return Model.from_name(name)
 
-    tool_calls: list[ToolCall] = []
 
+def _get_available_models() -> list[ModelData]:
+    """
+    Return a list of available models based on configuration strategy.
+    """
+    now = int(datetime.now(tz=timezone.utc).timestamp())
+    strategy = g_config.gemini.model_strategy
+    models_data = []
 
-    def _replace(match: re.Match[str]) -> str:
-        block_content = match.group(1)
-        if not block_content:
-            return ""
+    custom_models = [m for m in g_config.gemini.models if m.model_name]
+    for m in custom_models:
+        models_data.append(
+            ModelData(
+                id=m.model_name,
+                created=now,
+                owned_by="custom",
+            )
+        )
+
+    if strategy == "append":
+        custom_ids = {m.model_name for m in custom_models}
+        for model in Model:
+            m_name = model.model_name
+            if not m_name or m_name == "unspecified":
+                continue
+            if m_name in custom_ids:
+                continue
 
-        for call_match in TOOL_CALL_RE.finditer(block_content):
-            name = (call_match.group(1) or "").strip()
-            raw_args = (call_match.group(2) or "").strip()
-            if not name:
-                logger.warning(
-                    f"Encountered tool_call block without a function name: {block_content}"
-                )
-                continue
-
-            arguments = raw_args
-            try:
-                parsed_args = json.loads(raw_args)
-                arguments = json.dumps(parsed_args, ensure_ascii=False)
-            except json.JSONDecodeError:
-                logger.warning(
-                    f"Failed to parse tool call arguments for '{name}'. Passing raw string."
-                )
-
-            tool_calls.append(
-                ToolCall(
-                    id=f"call_{uuid.uuid4().hex}",
-                    type="function",
-                    function=FunctionCall(name=name, arguments=arguments),
+            models_data.append(
+                ModelData(
+                    id=m_name,
+                    created=now,
+                    owned_by="gemini-web",
                 )
             )
 
-        return ""
-
-    cleaned = TOOL_BLOCK_RE.sub(_replace, text)
-    cleaned = _strip_system_hints(cleaned)
-    return cleaned, tool_calls
+    return models_data
 
 
 @router.get("/v1/models", response_model=ModelListResponse)
 async def list_models(api_key: str = Depends(verify_api_key)):
-    now = int(datetime.now(tz=timezone.utc).timestamp())
-
-    models = []
-    for model in Model:
-        m_name = model.model_name
-        if not m_name or m_name == "unspecified":
-            continue
-
-        models.append(
-            ModelData(
-                id=m_name,
-                created=now,
-                owned_by="gemini-web",
-            )
-        )
-
+    models = _get_available_models()
     return ModelListResponse(data=models)
 
 
@@ -592,7 +502,11 @@ async def create_chat_completion(
 ):
     pool = GeminiClientPool()
     db = LMDBConversationStore()
-    model = Model.from_name(request.model)
+
+    try:
+        model = _get_model_by_name(request.model)
+    except ValueError as exc:
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
 
     if len(request.messages) == 0:
         raise HTTPException(
@@ -698,12 +612,12 @@ async def create_chat_completion(
             detail="Gemini output parsing failed unexpectedly.",
         ) from exc
 
-    visible_output, tool_calls = _extract_tool_calls(raw_output_with_think)
-    storage_output = _remove_tool_call_blocks(raw_output_clean).strip()
+    visible_output, tool_calls = extract_tool_calls(raw_output_with_think)
+    storage_output = remove_tool_call_blocks(raw_output_clean).strip()
    tool_calls_payload = [call.model_dump(mode="json") for call in tool_calls]
 
     if structured_requirement:
-        cleaned_visible = _strip_code_fence(visible_output or "")
+        cleaned_visible = strip_code_fence(visible_output or "")
         if not cleaned_visible:
             raise HTTPException(
                 status_code=status.HTTP_502_BAD_GATEWAY,
@@ -849,7 +763,7 @@ async def create_response(
     db = LMDBConversationStore()
 
     try:
-        model = Model.from_name(request_data.model)
+        model = _get_model_by_name(request_data.model)
     except ValueError as exc:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
 
@@ -938,12 +852,12 @@ async def create_response(
             detail="Gemini output parsing failed unexpectedly.",
         ) from exc
 
-    visible_text, detected_tool_calls = _extract_tool_calls(text_with_think)
-    storage_output = _remove_tool_call_blocks(text_without_think).strip()
+    visible_text, detected_tool_calls = extract_tool_calls(text_with_think)
+    storage_output = remove_tool_call_blocks(text_without_think).strip()
     assistant_text = LMDBConversationStore.remove_think_tags(visible_text.strip())
 
     if structured_requirement:
-        cleaned_visible = _strip_code_fence(assistant_text or "")
+        cleaned_visible = strip_code_fence(assistant_text or "")
         if not cleaned_visible:
             raise HTTPException(
                 status_code=status.HTTP_502_BAD_GATEWAY,
@@ -1010,7 +924,7 @@ async def create_response(
 
         image_call_items.append(
             ResponseImageGenerationCall(
-                id=f"img_{uuid.uuid4().hex}",
+                id=filename.rsplit(".", 1)[0],
                 status="completed",
                 result=image_base64,
                 output_format=img_format,
@@ -1045,7 +959,7 @@ async def create_response(
     response_id = f"resp_{uuid.uuid4().hex}"
     message_id = f"msg_{uuid.uuid4().hex}"
 
-    input_tokens = sum(estimate_tokens(_text_from_message(msg)) for msg in messages)
+    input_tokens = sum(estimate_tokens(text_from_message(msg)) for msg in messages)
     tool_arg_text = "".join(call.function.arguments or "" for call in detected_tool_calls)
     completion_basis = assistant_text or ""
     if tool_arg_text:
@@ -1108,25 +1022,6 @@ async def create_response(
     return response_payload
 
 
-def _text_from_message(message: Message) -> str:
-    """Return text content from a message for token estimation."""
-    base_text = ""
-    if isinstance(message.content, str):
-        base_text = message.content
-    elif isinstance(message.content, list):
-        base_text = "\n".join(
-            item.text or "" for item in message.content if getattr(item, "type", "") == "text"
-        )
-    elif message.content is None:
-        base_text = ""
-
-    if message.tool_calls:
-        tool_arg_text = "".join(call.function.arguments or "" for call in message.tool_calls)
-        base_text = f"{base_text}\n{tool_arg_text}" if base_text else tool_arg_text
-
-    return base_text
-
-
 async def _find_reusable_session(
     db: LMDBConversationStore,
     pool: GeminiClientPool,
@@ -1224,47 +1119,6 @@ async def _send_with_split(session: ChatSession, text: str, files: list[Path | s
     raise
 
 
-def _iter_stream_segments(model_output: str, chunk_size: int = 64):
-    """Yield stream segments while keeping <think> markers and words intact."""
-    if not model_output:
-        return
-
-    token_pattern = re.compile(r"\s+|\S+\s*")
-    pending = ""
-
-    def _flush_pending() -> Iterator[str]:
-        nonlocal pending
-        if pending:
-            yield pending
-            pending = ""
-
-    # Split on <think> boundaries so the markers are never fragmented.
-    parts = re.split(r"(</?think>)", model_output)
-    for part in parts:
-        if not part:
-            continue
-        if part in {"<think>", "</think>"}:
-            yield from _flush_pending()
-            yield part
-            continue
-
-        for match in token_pattern.finditer(part):
-            token = match.group(0)
-
-            if len(token) > chunk_size:
-                yield from _flush_pending()
-                for idx in range(0, len(token), chunk_size):
-                    yield token[idx : idx + chunk_size]
-                continue
-
-            if pending and len(pending) + len(token) > chunk_size:
-                yield from _flush_pending()
-
-            pending += token
-
-    yield from _flush_pending()
-
-
 def _create_streaming_response(
     model_output: str,
     tool_calls: list[dict],
@@ -1276,7 +1130,7 @@ def _create_streaming_response(
     """Create streaming response with `usage` calculation included in the final chunk."""
 
     # Calculate token usage
-    prompt_tokens = sum(estimate_tokens(_text_from_message(msg)) for msg in messages)
+    prompt_tokens = sum(estimate_tokens(text_from_message(msg)) for msg in messages)
     tool_args = "".join(call.get("function", {}).get("arguments", "") for call in tool_calls or [])
     completion_tokens = estimate_tokens(model_output + tool_args)
     total_tokens = prompt_tokens + completion_tokens
@@ -1294,7 +1148,7 @@ def _create_streaming_response(
     yield f"data: {orjson.dumps(data).decode('utf-8')}\n\n"
 
     # Stream output text in chunks for efficiency
-    for chunk in _iter_stream_segments(model_output):
+    for chunk in iter_stream_segments(model_output):
         data = {
             "id": completion_id,
             "object": "chat.completion.chunk",
@@ -1408,7 +1262,7 @@ def _create_responses_streaming_response(
                 content_text += c.text
 
             if content_text:
-                for chunk in _iter_stream_segments(content_text):
+                for chunk in iter_stream_segments(content_text):
                     delta_event = {
                         **base_event,
                         "type": "response.output_text.delta",
@@ -1457,7 +1311,7 @@ def _create_standard_response(
 ) -> dict:
     """Create standard response"""
     # Calculate token usage
-    prompt_tokens = sum(estimate_tokens(_text_from_message(msg)) for msg in messages)
+    prompt_tokens = sum(estimate_tokens(text_from_message(msg)) for msg in messages)
     tool_args = "".join(call.get("function", {}).get("arguments", "") for call in tool_calls or [])
     completion_tokens = estimate_tokens(model_output + tool_args)
     total_tokens = prompt_tokens + completion_tokens
@@ -1490,74 +1344,16 @@ def _create_standard_response(
     return result
 
 
-def _extract_image_dimensions(data: bytes) -> tuple[int | None, int | None]:
-    """Return image dimensions (width, height) if PNG or JPEG headers are present."""
-    # PNG: dimensions stored in bytes 16..24 of the IHDR chunk
-    if len(data) >= 24 and data.startswith(b"\x89PNG\r\n\x1a\n"):
-        try:
-            width, height = struct.unpack(">II", data[16:24])
-            return int(width), int(height)
-        except struct.error:
-            return None, None
-
-    # JPEG: dimensions stored in SOF segment; iterate through markers to locate it
-    if len(data) >= 4 and data[0:2] == b"\xff\xd8":
-        idx = 2
-        length = len(data)
-        sof_markers = {
-            0xC0,
-            0xC1,
-            0xC2,
-            0xC3,
-            0xC5,
-            0xC6,
-            0xC7,
-            0xC9,
-            0xCA,
-            0xCB,
-            0xCD,
-            0xCE,
-            0xCF,
-        }
-        while idx < length:
-            # Find marker alignment (markers are prefixed with 0xFF bytes)
-            if data[idx] != 0xFF:
-                idx += 1
-                continue
-            while idx < length and data[idx] == 0xFF:
-                idx += 1
-            if idx >= length:
-                break
-            marker = data[idx]
-            idx += 1
-
-            if marker in (0xD8, 0xD9, 0x01) or 0xD0 <= marker <= 0xD7:
-                continue
-
-            if idx + 1 >= length:
-                break
-            segment_length = (data[idx] << 8) + data[idx + 1]
-            idx += 2
-            if segment_length < 2:
-                break
-
-            if marker in sof_markers:
-                if idx + 4 < length:
-                    # Skip precision byte at idx, then read height/width (big-endian)
-                    height = (data[idx + 1] << 8) + data[idx + 2]
-                    width = (data[idx + 3] << 8) + data[idx + 4]
-                    return int(width), int(height)
-                break
-
-            idx += segment_length - 2
-
-    return None, None
-
-
 async def _image_to_base64(image: Image, temp_dir: Path) -> tuple[str, int | None, int | None, str]:
     """Persist an image provided by gemini_webapi and return base64 plus dimensions and filename."""
     if isinstance(image, GeneratedImage):
-        saved_path = await image.save(path=str(temp_dir), full_size=True)
+        try:
+            saved_path = await image.save(path=str(temp_dir), full_size=True)
+        except Exception as e:
+            logger.warning(
+                f"Failed to download full-size GeneratedImage, retrying with default size: {e}"
+            )
+            saved_path = await image.save(path=str(temp_dir), full_size=False)
     else:
         saved_path = await image.save(path=str(temp_dir))
 
@@ -1571,6 +1367,6 @@ async def _image_to_base64(image: Image, temp_dir: Path) -> tuple[str, int | None, int | None, str]:
     original_path.rename(new_path)
 
     data = new_path.read_bytes()
-    width, height = _extract_image_dimensions(data)
+    width, height = extract_image_dimensions(data)
     filename = random_name
     return base64.b64encode(data).decode("ascii"), width, height, filename
app/services/client.py CHANGED
@@ -9,18 +9,12 @@ from loguru import logger
 
 from ..models import Message
 from ..utils import g_config
-from ..utils.helper import add_tag, save_file_to_tempfile, save_url_to_tempfile
-
-XML_WRAP_HINT = (
-    "\nYou MUST wrap every tool call response inside a single fenced block exactly like:\n"
-    '```xml\n<tool_call name="tool_name">{"arg": "value"}</tool_call>\n```\n'
-    "Do not surround the fence with any other text or whitespace; otherwise the call will be ignored.\n"
-)
-CODE_BLOCK_HINT = (
-    "\nWhenever you include code, markup, or shell snippets, wrap each snippet in a Markdown fenced "
-    "block and supply the correct language label (for example, ```python ... ``` or ```html ... ```).\n"
-    "Fence ONLY the actual code/markup; keep all narrative or explanatory text outside the fences.\n"
-)
+from ..utils.helper import (
+    add_tag,
+    save_file_to_tempfile,
+    save_url_to_tempfile,
+)
 
 HTML_ESCAPE_RE = re.compile(r"&(?:lt|gt|amp|quot|apos|#[0-9]+|#x[0-9a-fA-F]+);")
 MARKDOWN_ESCAPE_RE = re.compile(r"\\(?=[-\\`*_{}\[\]()#+.!<>])")
 CODE_FENCE_RE = re.compile(r"(```.*?```|`[^`\n]+?`)", re.DOTALL)
app/utils/config.py CHANGED
@@ -1,6 +1,8 @@
+import ast
+import json
 import os
 import sys
-from typing import Literal, Optional
+from typing import Any, Literal, Optional
 
 from loguru import logger
 from pydantic import BaseModel, Field, ValidationError, field_validator
@@ -50,12 +52,37 @@ class GeminiClientSettings(BaseModel):
         return stripped or None
 
 
+class GeminiModelConfig(BaseModel):
+    """Configuration for a custom Gemini model."""
+
+    model_name: Optional[str] = Field(default=None, description="Name of the model")
+    model_header: Optional[dict[str, Optional[str]]] = Field(
+        default=None, description="Header for the model"
+    )
+
+    @field_validator("model_header", mode="before")
+    @classmethod
+    def _parse_json_string(cls, v: Any) -> Any:
+        if isinstance(v, str) and v.strip().startswith("{"):
+            try:
+                return json.loads(v)
+            except json.JSONDecodeError:
+                # Return the original value to let Pydantic handle the error or type mismatch
+                return v
+        return v
+
+
 class GeminiConfig(BaseModel):
     """Gemini API configuration"""
 
     clients: list[GeminiClientSettings] = Field(
         ..., description="List of Gemini client credential pairs"
     )
+    models: list[GeminiModelConfig] = Field(default=[], description="List of custom Gemini models")
+    model_strategy: Literal["append", "overwrite"] = Field(
+        default="append",
+        description="Strategy for loading models: 'append' merges custom with default, 'overwrite' uses only custom",
+    )
     timeout: int = Field(default=120, ge=1, description="Init timeout")
     auto_refresh: bool = Field(True, description="Enable auto-refresh for Gemini cookies")
     refresh_interval: int = Field(
@@ -68,6 +95,36 @@ class GeminiConfig(BaseModel):
         description="Maximum characters Gemini Web can accept per request",
     )
 
+    @field_validator("models", mode="before")
+    @classmethod
+    def _parse_models_json(cls, v: Any) -> Any:
+        if isinstance(v, str) and v.strip().startswith("["):
+            try:
+                return json.loads(v)
+            except json.JSONDecodeError as e:
+                logger.warning(f"Failed to parse models JSON string: {e}")
+                return v
+        return v
+
+    @field_validator("models")
+    @classmethod
+    def _filter_valid_models(cls, v: list[GeminiModelConfig]) -> list[GeminiModelConfig]:
+        """Filter out models that don't have all required fields set."""
+        valid_models = []
+        for model in v:
+            if model.model_name and model.model_header:
+                valid_models.append(model)
+            else:
+                missing = []
+                if not model.model_name:
+                    missing.append("model_name")
+                if not model.model_header:
+                    missing.append("model_header")
+                logger.warning(
+                    f"Discarding custom model due to missing {', '.join(missing)}: {model}"
+                )
+        return valid_models
+
 
 class CORSConfig(BaseModel):
     """CORS configuration"""
@@ -207,10 +264,74 @@ def _merge_clients_with_env(
             new_client = GeminiClientSettings(**overrides)
             result_clients.append(new_client)
         else:
-            raise IndexError(f"Client index {idx} in env is out of range.")
+            raise IndexError(
+                f"Client index {idx} in env is out of range (current count: {len(result_clients)}). "
+                "Client indices must be contiguous starting from 0."
+            )
     return result_clients if result_clients else base_clients
 
 
+def extract_gemini_models_env() -> dict[int, dict[str, Any]]:
+    """Extract and remove all Gemini models related environment variables, supporting nested fields."""
+    root_key = "CONFIG_GEMINI__MODELS"
+    env_overrides: dict[int, dict[str, Any]] = {}
+
+    if root_key in os.environ:
+        val = os.environ[root_key]
+        models_list = None
+        parsed_successfully = False
+
+        try:
+            models_list = json.loads(val)
+            parsed_successfully = True
+        except json.JSONDecodeError:
+            try:
+                models_list = ast.literal_eval(val)
+                parsed_successfully = True
+            except (ValueError, SyntaxError) as e:
+                logger.warning(f"Failed to parse {root_key} as JSON or Python literal: {e}")
+
+        if parsed_successfully and isinstance(models_list, list):
+            for idx, model_data in enumerate(models_list):
+                if isinstance(model_data, dict):
+                    env_overrides[idx] = model_data
+
+        # Remove the environment variable to avoid Pydantic parsing errors
+        del os.environ[root_key]
+
+    return env_overrides
+
+
+def _merge_models_with_env(
+    base_models: list[GeminiModelConfig] | None,
+    env_overrides: dict[int, dict[str, Any]],
+) -> list[GeminiModelConfig]:
+    """Override base_models with env_overrides using standard update (replace whole fields)."""
+    if not env_overrides:
+        return base_models or []
+    result_models: list[GeminiModelConfig] = []
+    if base_models:
+        result_models = [model.model_copy() for model in base_models]
+
+    for idx in sorted(env_overrides):
+        overrides = env_overrides[idx]
+        if idx < len(result_models):
+            # Update existing model: overwrite fields found in env
+            model_dict = result_models[idx].model_dump()
+            model_dict.update(overrides)
+            result_models[idx] = GeminiModelConfig(**model_dict)
+        elif idx == len(result_models):
+            # Append a new model
+            new_model = GeminiModelConfig(**overrides)
+            result_models.append(new_model)
+        else:
+            raise IndexError(
+                f"Model index {idx} in env is out of range (current count: {len(result_models)}). "
+                "Model indices must be contiguous starting from 0."
+            )
    return result_models
+
+
 def initialize_config() -> Config:
     """
     Initialize the configuration.
@@ -221,6 +342,8 @@ def initialize_config() -> Config:
     try:
         # First, extract and remove Gemini clients related environment variables
         env_clients_overrides = extract_gemini_clients_env()
+        # Extract and remove Gemini models related environment variables
+        env_models_overrides = extract_gemini_models_env()
 
         # Then, initialize Config with pydantic_settings
         config = Config()  # type: ignore
@@ -228,7 +351,10 @@ def initialize_config() -> Config:
         # Synthesize clients
         config.gemini.clients = _merge_clients_with_env(
             config.gemini.clients, env_clients_overrides
-        )  # type: ignore
+        )
+
+        # Synthesize models
+        config.gemini.models = _merge_models_with_env(config.gemini.models, env_models_overrides)
 
         return config
     except ValidationError as e:
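The env-override flow can be sketched with plain dicts standing in for `GeminiModelConfig` (the model names and header key below are hypothetical). Parsing falls back from JSON to `ast.literal_eval` as in `extract_gemini_models_env`, and the merge mirrors `_merge_models_with_env`: existing indices get field-level updates, the next index appends, anything further raises:

```python
import ast
import json


def parse_models_env(val: str) -> dict[int, dict]:
    # JSON first; fall back to Python-literal syntax (single quotes etc.).
    try:
        models = json.loads(val)
    except json.JSONDecodeError:
        models = ast.literal_eval(val)
    return {i: m for i, m in enumerate(models) if isinstance(m, dict)}


def merge_models(base: list[dict], overrides: dict[int, dict]) -> list[dict]:
    result = [dict(m) for m in base]
    for idx in sorted(overrides):
        if idx < len(result):
            result[idx].update(overrides[idx])   # overwrite only the given fields
        elif idx == len(result):
            result.append(dict(overrides[idx]))  # contiguous append
        else:
            raise IndexError(f"Model index {idx} out of range")
    return result


# Hypothetical values for illustration only:
base = [{"model_name": "gemini-web", "model_header": {"x-goog-ext": "1"}}]
env = parse_models_env('[{"model_name": "gemini-web-pro"}, {"model_name": "extra", "model_header": {}}]')
merged = merge_models(base, env)
```

Note the field-level semantics: the first override only supplies `model_name`, so the base model's `model_header` survives the merge.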
app/utils/helper.py CHANGED
@@ -1,12 +1,38 @@
 import base64
+import json
 import mimetypes
+import re
+import struct
 import tempfile
+import uuid
 from pathlib import Path
+from typing import Iterator
+from urllib.parse import urlparse
 
 import httpx
 from loguru import logger
 
+from ..models import FunctionCall, Message, ToolCall
+
 VALID_TAG_ROLES = {"user", "assistant", "system", "tool"}
+XML_WRAP_HINT = (
+    "\nYou MUST wrap every tool call response inside a single fenced block exactly like:\n"
+    '```xml\n<tool_call name="tool_name">{"arg": "value"}</tool_call>\n```\n'
+    "Do not surround the fence with any other text or whitespace; otherwise the call will be ignored.\n"
+)
+CODE_BLOCK_HINT = (
+    "\nWhenever you include code, markup, or shell snippets, wrap each snippet in a Markdown fenced "
+    "block and supply the correct language label (for example, ```python ... ``` or ```html ... ```).\n"
+    "Fence ONLY the actual code/markup; keep all narrative or explanatory text outside the fences.\n"
+)
+TOOL_BLOCK_RE = re.compile(r"```xml\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
+TOOL_CALL_RE = re.compile(
+    r"<tool_call\s+name=\"([^\"]+)\"\s*>(.*?)</tool_call>", re.DOTALL | re.IGNORECASE
+)
+JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)
+CONTROL_TOKEN_RE = re.compile(r"<\|im_(?:start|end)\|>")
+XML_HINT_STRIPPED = XML_WRAP_HINT.strip()
+CODE_HINT_STRIPPED = CODE_BLOCK_HINT.strip()
 
 
 def add_tag(role: str, content: str, unclose: bool = False) -> str:
@@ -18,8 +44,10 @@ def add_tag(role: str, content: str, unclose: bool = False) -> str:
     return f"<|im_start|>{role}\n{content}" + ("\n<|im_end|>" if not unclose else "")
 
 
-def estimate_tokens(text: str) -> int:
+def estimate_tokens(text: str | None) -> int:
     """Estimate the number of tokens heuristically based on character count"""
+    if not text:
+        return 0
     return int(len(text) / 3)
 
 
@@ -36,7 +64,7 @@ async def save_file_to_tempfile(
     return path
 
 
-async def save_url_to_tempfile(url: str, tempdir: Path | None = None):
+async def save_url_to_tempfile(url: str, tempdir: Path | None = None) -> Path:
     data: bytes | None = None
     suffix: str | None = None
     if url.startswith("data:image/"):
@@ -47,20 +75,315 @@ async def save_url_to_tempfile(url: str, tempdir: Path | None = None):
         base64_data = url.split(",")[1]
         data = base64.b64decode(base64_data)
 
-        # Guess extension from mime type, default to the subtype if not found
         suffix = mimetypes.guess_extension(mime_type)
         if not suffix:
             suffix = f".{mime_type.split('/')[1]}"
     else:
-        # http files
-        async with httpx.AsyncClient() as client:
+        async with httpx.AsyncClient(follow_redirects=True) as client:
             resp = await client.get(url)
             resp.raise_for_status()
             data = resp.content
-            suffix = Path(url).suffix or ".bin"
+            content_type = resp.headers.get("content-type")
+
+            if content_type:
+                mime_type = content_type.split(";")[0].strip()
+                suffix = mimetypes.guess_extension(mime_type)
+
+            if not suffix:
+                path_url = urlparse(url).path
+                suffix = Path(path_url).suffix
+
+            if not suffix:
+                suffix = ".bin"
 
     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=tempdir) as tmp:
         tmp.write(data)
         path = Path(tmp.name)
 
     return path
+
+
+def strip_code_fence(text: str) -> str:
+    """Remove surrounding ```json fences if present."""
+    match = JSON_FENCE_RE.match(text.strip())
+    if match:
+        return match.group(1).strip()
+    return text.strip()
+
+
+def strip_tagged_blocks(text: str) -> str:
+    """Remove <|im_start|>role ... <|im_end|> sections, dropping tool blocks entirely.
+
+    - tool blocks are removed entirely (if missing end marker, drop to EOF).
+    - other roles: remove markers and role, keep inner content (if missing end marker, keep to EOF).
+    """
+    if not text:
+        return text
+
+    result: list[str] = []
+    idx = 0
+    length = len(text)
+    start_marker = "<|im_start|>"
+    end_marker = "<|im_end|>"
+
+    while idx < length:
+        start = text.find(start_marker, idx)
+        if start == -1:
+            result.append(text[idx:])
+            break
+
+        # append any content before this block
+        result.append(text[idx:start])
+
+        role_start = start + len(start_marker)
+        newline = text.find("\n", role_start)
+        if newline == -1:
+            # malformed block; keep the remainder as-is (safe behavior)
+            result.append(text[start:])
+            break
+
+        role = text[role_start:newline].strip().lower()
+
+        end = text.find(end_marker, newline + 1)
+        if end == -1:
+            # missing end marker
+            if role == "tool":
+                # drop from the start marker to EOF (skip the remainder)
+                break
+            else:
+                # keep inner content from after the role newline to EOF
+                result.append(text[newline + 1 :])
+                break
+
+        block_end = end + len(end_marker)
+
+        if role == "tool":
+            # drop the whole block
+            idx = block_end
+            continue
+
+        # keep the content without role markers
+        content = text[newline + 1 : end]
+        result.append(content)
+        idx = block_end
+
+    return "".join(result)
+
+
+def strip_system_hints(text: str) -> str:
+    """Remove system-level hint text from a given string."""
+    if not text:
+        return text
+    cleaned = strip_tagged_blocks(text)
+    cleaned = cleaned.replace(XML_WRAP_HINT, "").replace(XML_HINT_STRIPPED, "")
+    cleaned = cleaned.replace(CODE_BLOCK_HINT, "").replace(CODE_HINT_STRIPPED, "")
+    cleaned = CONTROL_TOKEN_RE.sub("", cleaned)
+    return cleaned.strip()
+
+
+def remove_tool_call_blocks(text: str) -> str:
+    """Strip tool call code blocks from text."""
+    if not text:
+        return text
+
+    # 1. Remove fenced blocks ONLY if they contain tool calls
+    def _replace_block(match: re.Match[str]) -> str:
+        block_content = match.group(1)
+        if not block_content:
+            return match.group(0)
+
+        # Check if the block contains any tool call tag
+        if TOOL_CALL_RE.search(block_content):
+            return ""
+
+        # Preserve the block if no tool call found
+        return match.group(0)
+
+    cleaned = TOOL_BLOCK_RE.sub(_replace_block, text)
+
+    # 2. Remove orphaned tool calls
+    cleaned = TOOL_CALL_RE.sub("", cleaned)
+
+    return strip_system_hints(cleaned)
+
+
+def extract_tool_calls(text: str) -> tuple[str, list[ToolCall]]:
+    """Extract tool call definitions and return cleaned text."""
+    if not text:
+        return text, []
+
+    tool_calls: list[ToolCall] = []
+
+    def _create_tool_call(name: str, raw_args: str) -> None:
+        """Helper to parse args and append to the tool_calls list."""
+        if not name:
+            logger.warning("Encountered tool_call without a function name.")
+            return
+
+        arguments = raw_args
+        try:
+            parsed_args = json.loads(raw_args)
+            arguments = json.dumps(parsed_args, ensure_ascii=False)
+        except json.JSONDecodeError:
+            logger.warning(f"Failed to parse tool call arguments for '{name}'. Passing raw string.")
+
+        tool_calls.append(
+            ToolCall(
+                id=f"call_{uuid.uuid4().hex}",
+                type="function",
+                function=FunctionCall(name=name, arguments=arguments),
+            )
+        )
+
+    def _replace_block(match: re.Match[str]) -> str:
+        block_content = match.group(1)
+        if not block_content:
+            return match.group(0)
+
+        found_in_block = False
+        for call_match in TOOL_CALL_RE.finditer(block_content):
+            found_in_block = True
+            name = (call_match.group(1) or "").strip()
+            raw_args = (call_match.group(2) or "").strip()
+            _create_tool_call(name, raw_args)
+
+        if found_in_block:
+            return ""
+        else:
+            return match.group(0)
+
+    cleaned = TOOL_BLOCK_RE.sub(_replace_block, text)
+
+    def _replace_orphan(match: re.Match[str]) -> str:
+        name = (match.group(1) or "").strip()
+        raw_args = (match.group(2) or "").strip()
+        _create_tool_call(name, raw_args)
+        return ""
+
+    cleaned = TOOL_CALL_RE.sub(_replace_orphan, cleaned)
+
+    cleaned = strip_system_hints(cleaned)
+    return cleaned, tool_calls
+
+
+def iter_stream_segments(model_output: str, chunk_size: int = 64) -> Iterator[str]:
+    """Yield stream segments while keeping <think> markers and words intact."""
+    if not model_output:
+        return
+
+    token_pattern = re.compile(r"\s+|\S+\s*")
+    pending = ""
+
+    def _flush_pending() -> Iterator[str]:
+        nonlocal pending
+        if pending:
+            yield pending
+            pending = ""
+
+    # Split on <think> boundaries so the markers are never fragmented.
+    parts = re.split(r"(</?think>)", model_output)
+    for part in parts:
+        if not part:
+            continue
+        if part in {"<think>", "</think>"}:
+            yield from _flush_pending()
+            yield part
+            continue
+
+        for match in token_pattern.finditer(part):
+            token = match.group(0)
+
+            if len(token) > chunk_size:
+                yield from _flush_pending()
+                for idx in range(0, len(token), chunk_size):
+                    yield token[idx : idx + chunk_size]
+                continue
+
+            if pending and len(pending) + len(token) > chunk_size:
+                yield from _flush_pending()
+
+            pending += token
+
+    yield from _flush_pending()
+
+
+def text_from_message(message: Message) -> str:
+    """Return text content from a message for token estimation."""
+    base_text = ""
+    if isinstance(message.content, str):
+        base_text = message.content
+    elif isinstance(message.content, list):
+        base_text = "\n".join(
+            item.text or "" for item in message.content if getattr(item, "type", "") == "text"
+        )
+    elif message.content is None:
+        base_text = ""
+
+    if message.tool_calls:
+        tool_arg_text = "".join(call.function.arguments or "" for call in message.tool_calls)
+        base_text = f"{base_text}\n{tool_arg_text}" if base_text else tool_arg_text
+
+    return base_text
+
+
+def extract_image_dimensions(data: bytes) -> tuple[int | None, int | None]:
+    """Return image dimensions (width, height) if PNG or JPEG headers are present."""
+    # PNG: dimensions stored in bytes 16..24 of the IHDR chunk
+    if len(data) >= 24 and data.startswith(b"\x89PNG\r\n\x1a\n"):
+        try:
+            width, height = struct.unpack(">II", data[16:24])
+            return int(width), int(height)
+        except struct.error:
+            return None, None
+
+    # JPEG: dimensions stored in SOF segment; iterate through markers to locate it
+    if len(data) >= 4 and data[0:2] == b"\xff\xd8":
+        idx = 2
+        length = len(data)
+        sof_markers = {
+            0xC0, 0xC1, 0xC2, 0xC3, 0xC5, 0xC6, 0xC7,
+            0xC9, 0xCA, 0xCB, 0xCD, 0xCE, 0xCF,
+        }
+        while idx < length:
+            # Find marker alignment (markers are prefixed with 0xFF bytes)
+            if data[idx] != 0xFF:
+                idx += 1
+                continue
+            while idx < length and data[idx] == 0xFF:
+                idx += 1
+            if idx >= length:
+                break
+            marker = data[idx]
+            idx += 1
+
+            if marker in (0xD8, 0xD9, 0x01) or 0xD0 <= marker <= 0xD7:
+                continue
+
+            if idx + 1 >= length:
+                break
+            segment_length = (data[idx] << 8) + data[idx + 1]
+            idx += 2
+            if segment_length < 2:
+                break
+
+            if marker in sof_markers:
+                if idx + 4 < length:
+                    # Skip precision byte at idx, then read height/width (big-endian)
+                    height = (data[idx + 1] << 8) + data[idx + 2]
+                    width = (data[idx + 3] << 8) + data[idx + 4]
+                    return int(width), int(height)
+                break
+
+            idx += segment_length - 2
+
+    return None, None
config/config.yaml CHANGED
@@ -27,6 +27,8 @@ gemini:
   refresh_interval: 540 # Refresh interval in seconds
   verbose: false # Enable verbose logging for Gemini requests
   max_chars_per_request: 1000000 # Maximum characters Gemini Web accepts per request. Non-pro users might have a lower limit
+  model_strategy: "append" # Strategy: 'append' (default + custom) or 'overwrite' (custom only)
+  models: []
 
 storage:
   path: "data/lmdb" # Database storage path
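For reference, a populated custom-model list might look like the following (the model name and header key here are purely illustrative, not values the project ships; `model_header` accepts any string-to-string mapping per `GeminiModelConfig`):

```yaml
gemini:
  model_strategy: "overwrite" # serve only the models listed below
  models:
    - model_name: "gemini-custom"
      model_header:
        x-custom-header: "example-value"
```

With `model_strategy: "append"` the same list would be served alongside the built-in defaults instead of replacing them; entries missing either `model_name` or `model_header` are discarded by `_filter_valid_models` with a warning.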