feat: Add support for custom Gemini models and model loading strategies (#86)
* feat: Add support for custom Gemini models and model loading strategies
- Introduced `model_strategy` configuration for "append" (default + custom models) or "overwrite" (custom models only).
- Enhanced `/v1/models` endpoint to return models based on the configured strategy.
- Improved model loading with environment variable overrides and validation.
- Refactored model handling logic for improved modularity and error handling.
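The two strategies can be sketched as follows; `DEFAULT_MODELS` and `resolve_models` are illustrative names, not the project's actual API:

```python
# Illustrative sketch only: DEFAULT_MODELS and resolve_models are hypothetical,
# modelling what "append" vs "overwrite" implies for the advertised model list.
DEFAULT_MODELS = ["gemini-2.5-flash", "gemini-2.5-pro"]  # hypothetical built-ins

def resolve_models(strategy: str, custom: list[str]) -> list[str]:
    if strategy == "overwrite":
        # Only the user-defined models are exposed.
        return list(custom)
    if strategy == "append":
        # Custom models first, then any built-ins not shadowed by a custom name.
        return custom + [m for m in DEFAULT_MODELS if m not in custom]
    raise ValueError(f"unknown model_strategy: {strategy!r}")
```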
* feat: Improve Gemini model environment variable parsing and nested field support
- Enhanced `extract_gemini_models_env` to handle nested fields within environment variables.
- Updated type hints for more flexibility in model overrides.
- Improved `_merge_models_with_env` to better support field-level updates and appending new models.
* refactor: Consolidate utility functions and clean up unused code
- Moved utility functions like `strip_code_fence`, `extract_tool_calls`, and `iter_stream_segments` to a centralized helper module.
- Removed unused and redundant private methods from `chat.py`, including `_strip_code_fence`, `_strip_tagged_blocks`, and `_strip_system_hints`.
- Updated imports and references across modules for consistency.
- Simplified tool call and streaming logic by replacing inline implementations with shared helper functions.
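The stream-segmentation helper moved in this refactor can be sketched as a self-contained generator, reconstructed from the removed private implementation (treat details as approximate):

```python
import re
from collections.abc import Iterator

def iter_stream_segments(model_output: str, chunk_size: int = 64) -> Iterator[str]:
    """Yield chunks of at most ``chunk_size`` characters, without splitting
    a word or a <think>/</think> marker across chunks."""
    if not model_output:
        return

    token_pattern = re.compile(r"\s+|\S+\s*")
    pending = ""

    def flush() -> Iterator[str]:
        nonlocal pending
        if pending:
            yield pending
            pending = ""

    # Split on <think> boundaries so the markers are never fragmented.
    for part in re.split(r"(</?think>)", model_output):
        if not part:
            continue
        if part in {"<think>", "</think>"}:
            yield from flush()
            yield part
            continue
        for match in token_pattern.finditer(part):
            token = match.group(0)
            if len(token) > chunk_size:
                # A single oversized token is split hard at chunk_size.
                yield from flush()
                for i in range(0, len(token), chunk_size):
                    yield token[i : i + chunk_size]
                continue
            if pending and len(pending) + len(token) > chunk_size:
                yield from flush()
            pending += token
    yield from flush()
```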
* fix: Handle None input in `estimate_tokens` and return 0 for empty text
* refactor: Simplify model configuration and add JSON parsing validators
- Replaced unused model placeholder in `config.yaml` with an empty list.
- Added JSON parsing validators for `model_header` and `models` to enhance flexibility and error handling.
- Improved validation to filter out incomplete model configurations.
* refactor: Simplify Gemini model environment variable parsing with JSON support
- Replaced prefix-based parsing with a root key approach.
- Added JSON parsing to handle list-based model configurations.
- Improved handling of errors and cleanup of environment variables.
* fix: Enhance Gemini model environment variable parsing with fallback to Python literals
- Added `ast.literal_eval` as a fallback for parsing environment variables when JSON decoding fails.
- Improved error handling and logging for invalid configurations.
- Ensured proper cleanup of environment variables post-parsing.
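The parse-with-fallback-and-cleanup flow described above can be sketched like this (function name and logging details are illustrative, and only the list form is handled in this sketch):

```python
import ast
import json
import logging
import os

logger = logging.getLogger(__name__)

def parse_models_env(var: str = "CONFIG_GEMINI__MODELS") -> list[dict]:
    """Parse a model list from an environment variable: try JSON first,
    then ast.literal_eval for Python-style literals (e.g. single-quoted
    dicts). Invalid values are logged and yield an empty list."""
    raw = os.environ.pop(var, None)  # pop() performs the post-parse cleanup
    if raw is None:
        return []
    try:
        value = json.loads(raw)
    except json.JSONDecodeError:
        try:
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            logger.warning("Invalid %s value; ignoring.", var)
            return []
    return value if isinstance(value, list) else []
```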
* fix: Improve regex patterns in helper module
- Adjusted `TOOL_CALL_RE` regex pattern for better accuracy.
* docs: Update README files to include custom model configuration and environment variable setup
* fix: Remove unused headers from HTTP client in helper module
* fix: Update README and README.zh to clarify model configuration via environment variables; enhance error logging in config validation
* Update README and README.zh to clarify model configuration via JSON string or list structure for enhanced flexibility in automated environments
- README.md +25 -1
- README.zh.md +27 -3
- app/server/chat.py +84 -288
- app/services/client.py +5 -11
- app/utils/config.py +129 -3
- app/utils/helper.py +329 -6
- config/config.yaml +2 -0
**README.md**

````diff
@@ -118,7 +118,7 @@ services:
       - CONFIG_GEMINI__CLIENTS__0__SECURE_1PSID=${SECURE_1PSID}
       - CONFIG_GEMINI__CLIENTS__0__SECURE_1PSIDTS=${SECURE_1PSIDTS}
       - GEMINI_COOKIE_PATH=/app/cache # must match the cache volume mount above
-    restart: on-failure:3
+    restart: on-failure:3 # Avoid retrying too many times
 ```

 Then run:
@@ -187,6 +187,30 @@ To use Gemini-FastAPI, you need to extract your Gemini session cookies:

 Each client entry can be configured with a different proxy to work around rate limits. Omit the `proxy` field or set it to `null` or an empty string to keep a direct connection.

+### Custom Models
+
+You can define custom models in `config/config.yaml` or via environment variables.
+
+#### YAML Configuration
+
+```yaml
+gemini:
+  model_strategy: "append" # "append" (default + custom) or "overwrite" (custom only)
+  models:
+    - model_name: "gemini-3.0-pro"
+      model_header:
+        x-goog-ext-525001261-jspb: '[1,null,null,null,"9d8ca3786ebdfbea",null,null,0,[4],null,null,1]'
+```
+
+#### Environment Variables
+
+You can supply models as a JSON string or list structure via `CONFIG_GEMINI__MODELS`. This provides a flexible way to override settings via the shell or in automated environments (e.g. Docker) without modifying the configuration file.
+
+```bash
+export CONFIG_GEMINI__MODEL_STRATEGY="overwrite"
+export CONFIG_GEMINI__MODELS='[{"model_name": "gemini-3.0-pro", "model_header": {"x-goog-ext-525001261-jspb": "[1,null,null,null,\"9d8ca3786ebdfbea\",null,null,0,[4],null,null,1]"}}]'
+```
+
 ## Acknowledgments

 - [HanaokaYuzu/Gemini-API](https://github.com/HanaokaYuzu/Gemini-API) - The underlying Gemini web API client
````
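The escaped quoting in the `CONFIG_GEMINI__MODELS` example above is easy to get wrong; a quick sanity check that the value is valid JSON and that the header value is itself a JSON-encoded array:

```python
import json

# The export line's value, written as a Python string (note the escaped inner quotes).
raw = (
    '[{"model_name": "gemini-3.0-pro", "model_header": '
    '{"x-goog-ext-525001261-jspb": '
    '"[1,null,null,null,\\"9d8ca3786ebdfbea\\",null,null,0,[4],null,null,1]"}}]'
)

models = json.loads(raw)  # the env value must parse as JSON
header = models[0]["model_header"]["x-goog-ext-525001261-jspb"]
inner = json.loads(header)  # the header string itself encodes a JSON array
```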
**README.zh.md**

````diff
@@ -4,7 +4,6 @@
 [](https://fastapi.tiangolo.com/)
 [](LICENSE)

-
 [ [English](README.md) | 中文 ]

 将 Gemini 网页端模型封装为兼容 OpenAI API 的 API Server。基于 [HanaokaYuzu/Gemini-API](https://github.com/HanaokaYuzu/Gemini-API) 实现。
@@ -50,6 +49,7 @@ pip install -e .
 ### 配置

 编辑 `config/config.yaml` 并提供至少一组凭证:
+
 ```yaml
 gemini:
   clients:
@@ -118,7 +118,7 @@ services:
       - CONFIG_GEMINI__CLIENTS__0__SECURE_1PSID=${SECURE_1PSID}
       - CONFIG_GEMINI__CLIENTS__0__SECURE_1PSIDTS=${SECURE_1PSIDTS}
       - GEMINI_COOKIE_PATH=/app/cache # must match the cache volume mount above
-    restart: on-failure:3
+    restart: on-failure:3 # Avoid retrying too many times
 ```

 然后运行:
@@ -186,6 +186,30 @@ export CONFIG_STORAGE__MAX_SIZE=268435456 # 256 MB

 每个客户端条目可以配置不同的代理,从而规避速率限制。省略 `proxy` 字段或将其设置为 `null` 或空字符串以保持直连。

+### 自定义模型
+
+你可以在 `config/config.yaml` 中或通过环境变量定义自定义模型。
+
+#### YAML 配置
+
+```yaml
+gemini:
+  model_strategy: "append" # "append" (默认 + 自定义) 或 "overwrite" (仅限自定义)
+  models:
+    - model_name: "gemini-3.0-pro"
+      model_header:
+        x-goog-ext-525001261-jspb: '[1,null,null,null,"9d8ca3786ebdfbea",null,null,0,[4],null,null,1]'
+```
+
+#### 环境变量
+
+你可以通过 `CONFIG_GEMINI__MODELS` 以 JSON 字符串或列表结构的形式提供模型。这为通过 shell 或在自动化环境(例如 Docker)中覆盖设置提供了一种灵活的方式,而无需修改配置文件。
+
+```bash
+export CONFIG_GEMINI__MODEL_STRATEGY="overwrite"
+export CONFIG_GEMINI__MODELS='[{"model_name": "gemini-3.0-pro", "model_header": {"x-goog-ext-525001261-jspb": "[1,null,null,null,\"9d8ca3786ebdfbea\",null,null,0,[4],null,null,1]"}}]'
+```
+
 ## 鸣谢

 - [HanaokaYuzu/Gemini-API](https://github.com/HanaokaYuzu/Gemini-API) - 底层 Gemini Web API 客户端
@@ -193,4 +217,4 @@ export CONFIG_STORAGE__MAX_SIZE=268435456 # 256 MB

 ## 免责声明

-本项目与 Google 或 OpenAI 无关,仅供学习和研究使用。本项目使用了逆向工程 API,可能不符合 Google 服务条款。使用风险自负。
+本项目与 Google 或 OpenAI 无关,仅供学习和研究使用。本项目使用了逆向工程 API,可能不符合 Google 服务条款。使用风险自负。
````
@@ -1,12 +1,11 @@
|
|
| 1 |
import base64
|
| 2 |
import json
|
| 3 |
import re
|
| 4 |
-
import struct
|
| 5 |
import uuid
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from datetime import datetime, timezone
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import Any
|
| 10 |
|
| 11 |
import orjson
|
| 12 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
@@ -21,7 +20,6 @@ from ..models import (
|
|
| 21 |
ChatCompletionRequest,
|
| 22 |
ContentItem,
|
| 23 |
ConversationInStore,
|
| 24 |
-
FunctionCall,
|
| 25 |
Message,
|
| 26 |
ModelData,
|
| 27 |
ModelListResponse,
|
|
@@ -37,26 +35,28 @@ from ..models import (
|
|
| 37 |
ResponseToolChoice,
|
| 38 |
ResponseUsage,
|
| 39 |
Tool,
|
| 40 |
-
ToolCall,
|
| 41 |
ToolChoiceFunction,
|
| 42 |
)
|
| 43 |
from ..services import GeminiClientPool, GeminiClientWrapper, LMDBConversationStore
|
| 44 |
-
from ..services.client import CODE_BLOCK_HINT, XML_WRAP_HINT
|
| 45 |
from ..utils import g_config
|
| 46 |
-
from ..utils.helper import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
from .middleware import get_image_store_dir, get_image_token, get_temp_dir, verify_api_key
|
| 48 |
|
| 49 |
# Maximum characters Gemini Web can accept in a single request (configurable)
|
| 50 |
MAX_CHARS_PER_REQUEST = int(g_config.gemini.max_chars_per_request * 0.9)
|
| 51 |
CONTINUATION_HINT = "\n(More messages to come, please reply with just 'ok.')"
|
| 52 |
-
TOOL_BLOCK_RE = re.compile(r"```xml\s*(.*?)```", re.DOTALL | re.IGNORECASE)
|
| 53 |
-
TOOL_CALL_RE = re.compile(
|
| 54 |
-
r"<tool_call\s+name=\"([^\"]+)\">(.*?)</tool_call>", re.DOTALL | re.IGNORECASE
|
| 55 |
-
)
|
| 56 |
-
JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)
|
| 57 |
-
CONTROL_TOKEN_RE = re.compile(r"<\|im_(?:start|end)\|>")
|
| 58 |
-
XML_HINT_STRIPPED = XML_WRAP_HINT.strip()
|
| 59 |
-
CODE_HINT_STRIPPED = CODE_BLOCK_HINT.strip()
|
| 60 |
|
| 61 |
router = APIRouter()
|
| 62 |
|
|
@@ -118,14 +118,6 @@ def _build_structured_requirement(
|
|
| 118 |
)
|
| 119 |
|
| 120 |
|
| 121 |
-
def _strip_code_fence(text: str) -> str:
|
| 122 |
-
"""Remove surrounding ```json fences if present."""
|
| 123 |
-
match = JSON_FENCE_RE.match(text.strip())
|
| 124 |
-
if match:
|
| 125 |
-
return match.group(1).strip()
|
| 126 |
-
return text.strip()
|
| 127 |
-
|
| 128 |
-
|
| 129 |
def _build_tool_prompt(
|
| 130 |
tools: list[Tool],
|
| 131 |
tool_choice: str | ToolChoiceFunction | None,
|
|
@@ -312,75 +304,6 @@ def _prepare_messages_for_model(
|
|
| 312 |
return prepared
|
| 313 |
|
| 314 |
|
| 315 |
-
def _strip_system_hints(text: str) -> str:
|
| 316 |
-
"""Remove system-level hint text from a given string."""
|
| 317 |
-
if not text:
|
| 318 |
-
return text
|
| 319 |
-
cleaned = _strip_tagged_blocks(text)
|
| 320 |
-
cleaned = cleaned.replace(XML_WRAP_HINT, "").replace(XML_HINT_STRIPPED, "")
|
| 321 |
-
cleaned = cleaned.replace(CODE_BLOCK_HINT, "").replace(CODE_HINT_STRIPPED, "")
|
| 322 |
-
cleaned = CONTROL_TOKEN_RE.sub("", cleaned)
|
| 323 |
-
return cleaned.strip()
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
def _strip_tagged_blocks(text: str) -> str:
|
| 327 |
-
"""Remove <|im_start|>role ... <|im_end|> sections, dropping tool blocks entirely.
|
| 328 |
-
- tool blocks are removed entirely (if missing end marker, drop to EOF).
|
| 329 |
-
- other roles: remove markers and role, keep inner content (if missing end marker, keep to EOF).
|
| 330 |
-
"""
|
| 331 |
-
if not text:
|
| 332 |
-
return text
|
| 333 |
-
|
| 334 |
-
result: list[str] = []
|
| 335 |
-
idx = 0
|
| 336 |
-
length = len(text)
|
| 337 |
-
start_marker = "<|im_start|>"
|
| 338 |
-
end_marker = "<|im_end|>"
|
| 339 |
-
|
| 340 |
-
while idx < length:
|
| 341 |
-
start = text.find(start_marker, idx)
|
| 342 |
-
if start == -1:
|
| 343 |
-
result.append(text[idx:])
|
| 344 |
-
break
|
| 345 |
-
|
| 346 |
-
# append any content before this block
|
| 347 |
-
result.append(text[idx:start])
|
| 348 |
-
|
| 349 |
-
role_start = start + len(start_marker)
|
| 350 |
-
newline = text.find("\n", role_start)
|
| 351 |
-
if newline == -1:
|
| 352 |
-
# malformed block; keep remainder as-is (safe behavior)
|
| 353 |
-
result.append(text[start:])
|
| 354 |
-
break
|
| 355 |
-
|
| 356 |
-
role = text[role_start:newline].strip().lower()
|
| 357 |
-
|
| 358 |
-
end = text.find(end_marker, newline + 1)
|
| 359 |
-
if end == -1:
|
| 360 |
-
# missing end marker
|
| 361 |
-
if role == "tool":
|
| 362 |
-
# drop from start marker to EOF (skip remainder)
|
| 363 |
-
break
|
| 364 |
-
else:
|
| 365 |
-
# keep inner content from after the role newline to EOF
|
| 366 |
-
result.append(text[newline + 1 :])
|
| 367 |
-
break
|
| 368 |
-
|
| 369 |
-
block_end = end + len(end_marker)
|
| 370 |
-
|
| 371 |
-
if role == "tool":
|
| 372 |
-
# drop whole block
|
| 373 |
-
idx = block_end
|
| 374 |
-
continue
|
| 375 |
-
|
| 376 |
-
# keep the content without role markers
|
| 377 |
-
content = text[newline + 1 : end]
|
| 378 |
-
result.append(content)
|
| 379 |
-
idx = block_end
|
| 380 |
-
|
| 381 |
-
return "".join(result)
|
| 382 |
-
|
| 383 |
-
|
| 384 |
def _response_items_to_messages(
|
| 385 |
items: str | list[ResponseInputItem],
|
| 386 |
) -> tuple[list[Message], str | list[ResponseInputItem]]:
|
|
@@ -509,77 +432,64 @@ def _instructions_to_messages(
|
|
| 509 |
return instruction_messages
|
| 510 |
|
| 511 |
|
| 512 |
-
def
|
| 513 |
-
"""
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
|
|
|
| 518 |
|
|
|
|
|
|
|
| 519 |
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
if not text:
|
| 523 |
-
return text, []
|
| 524 |
|
| 525 |
-
|
| 526 |
|
| 527 |
-
def _replace(match: re.Match[str]) -> str:
|
| 528 |
-
block_content = match.group(1)
|
| 529 |
-
if not block_content:
|
| 530 |
-
return ""
|
| 531 |
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
continue
|
| 540 |
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
|
| 547 |
-
|
| 548 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 549 |
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
id=
|
| 553 |
-
|
| 554 |
-
|
| 555 |
)
|
| 556 |
)
|
| 557 |
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
cleaned = TOOL_BLOCK_RE.sub(_replace, text)
|
| 561 |
-
cleaned = _strip_system_hints(cleaned)
|
| 562 |
-
return cleaned, tool_calls
|
| 563 |
|
| 564 |
|
| 565 |
@router.get("/v1/models", response_model=ModelListResponse)
|
| 566 |
async def list_models(api_key: str = Depends(verify_api_key)):
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
models = []
|
| 570 |
-
for model in Model:
|
| 571 |
-
m_name = model.model_name
|
| 572 |
-
if not m_name or m_name == "unspecified":
|
| 573 |
-
continue
|
| 574 |
-
|
| 575 |
-
models.append(
|
| 576 |
-
ModelData(
|
| 577 |
-
id=m_name,
|
| 578 |
-
created=now,
|
| 579 |
-
owned_by="gemini-web",
|
| 580 |
-
)
|
| 581 |
-
)
|
| 582 |
-
|
| 583 |
return ModelListResponse(data=models)
|
| 584 |
|
| 585 |
|
|
@@ -592,7 +502,11 @@ async def create_chat_completion(
|
|
| 592 |
):
|
| 593 |
pool = GeminiClientPool()
|
| 594 |
db = LMDBConversationStore()
|
| 595 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
|
| 597 |
if len(request.messages) == 0:
|
| 598 |
raise HTTPException(
|
|
@@ -698,12 +612,12 @@ async def create_chat_completion(
|
|
| 698 |
detail="Gemini output parsing failed unexpectedly.",
|
| 699 |
) from exc
|
| 700 |
|
| 701 |
-
visible_output, tool_calls =
|
| 702 |
-
storage_output =
|
| 703 |
tool_calls_payload = [call.model_dump(mode="json") for call in tool_calls]
|
| 704 |
|
| 705 |
if structured_requirement:
|
| 706 |
-
cleaned_visible =
|
| 707 |
if not cleaned_visible:
|
| 708 |
raise HTTPException(
|
| 709 |
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
@@ -849,7 +763,7 @@ async def create_response(
|
|
| 849 |
db = LMDBConversationStore()
|
| 850 |
|
| 851 |
try:
|
| 852 |
-
model =
|
| 853 |
except ValueError as exc:
|
| 854 |
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
| 855 |
|
|
@@ -938,12 +852,12 @@ async def create_response(
|
|
| 938 |
detail="Gemini output parsing failed unexpectedly.",
|
| 939 |
) from exc
|
| 940 |
|
| 941 |
-
visible_text, detected_tool_calls =
|
| 942 |
-
storage_output =
|
| 943 |
assistant_text = LMDBConversationStore.remove_think_tags(visible_text.strip())
|
| 944 |
|
| 945 |
if structured_requirement:
|
| 946 |
-
cleaned_visible =
|
| 947 |
if not cleaned_visible:
|
| 948 |
raise HTTPException(
|
| 949 |
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
@@ -1010,7 +924,7 @@ async def create_response(
|
|
| 1010 |
|
| 1011 |
image_call_items.append(
|
| 1012 |
ResponseImageGenerationCall(
|
| 1013 |
-
id=
|
| 1014 |
status="completed",
|
| 1015 |
result=image_base64,
|
| 1016 |
output_format=img_format,
|
|
@@ -1045,7 +959,7 @@ async def create_response(
|
|
| 1045 |
response_id = f"resp_{uuid.uuid4().hex}"
|
| 1046 |
message_id = f"msg_{uuid.uuid4().hex}"
|
| 1047 |
|
| 1048 |
-
input_tokens = sum(estimate_tokens(
|
| 1049 |
tool_arg_text = "".join(call.function.arguments or "" for call in detected_tool_calls)
|
| 1050 |
completion_basis = assistant_text or ""
|
| 1051 |
if tool_arg_text:
|
|
@@ -1108,25 +1022,6 @@ async def create_response(
|
|
| 1108 |
return response_payload
|
| 1109 |
|
| 1110 |
|
| 1111 |
-
def _text_from_message(message: Message) -> str:
|
| 1112 |
-
"""Return text content from a message for token estimation."""
|
| 1113 |
-
base_text = ""
|
| 1114 |
-
if isinstance(message.content, str):
|
| 1115 |
-
base_text = message.content
|
| 1116 |
-
elif isinstance(message.content, list):
|
| 1117 |
-
base_text = "\n".join(
|
| 1118 |
-
item.text or "" for item in message.content if getattr(item, "type", "") == "text"
|
| 1119 |
-
)
|
| 1120 |
-
elif message.content is None:
|
| 1121 |
-
base_text = ""
|
| 1122 |
-
|
| 1123 |
-
if message.tool_calls:
|
| 1124 |
-
tool_arg_text = "".join(call.function.arguments or "" for call in message.tool_calls)
|
| 1125 |
-
base_text = f"{base_text}\n{tool_arg_text}" if base_text else tool_arg_text
|
| 1126 |
-
|
| 1127 |
-
return base_text
|
| 1128 |
-
|
| 1129 |
-
|
| 1130 |
async def _find_reusable_session(
|
| 1131 |
db: LMDBConversationStore,
|
| 1132 |
pool: GeminiClientPool,
|
|
@@ -1224,47 +1119,6 @@ async def _send_with_split(session: ChatSession, text: str, files: list[Path | s
|
|
| 1224 |
raise
|
| 1225 |
|
| 1226 |
|
| 1227 |
-
def _iter_stream_segments(model_output: str, chunk_size: int = 64):
|
| 1228 |
-
"""Yield stream segments while keeping <think> markers and words intact."""
|
| 1229 |
-
if not model_output:
|
| 1230 |
-
return
|
| 1231 |
-
|
| 1232 |
-
token_pattern = re.compile(r"\s+|\S+\s*")
|
| 1233 |
-
pending = ""
|
| 1234 |
-
|
| 1235 |
-
def _flush_pending() -> Iterator[str]:
|
| 1236 |
-
nonlocal pending
|
| 1237 |
-
if pending:
|
| 1238 |
-
yield pending
|
| 1239 |
-
pending = ""
|
| 1240 |
-
|
| 1241 |
-
# Split on <think> boundaries so the markers are never fragmented.
|
| 1242 |
-
parts = re.split(r"(</?think>)", model_output)
|
| 1243 |
-
for part in parts:
|
| 1244 |
-
if not part:
|
| 1245 |
-
continue
|
| 1246 |
-
if part in {"<think>", "</think>"}:
|
| 1247 |
-
yield from _flush_pending()
|
| 1248 |
-
yield part
|
| 1249 |
-
continue
|
| 1250 |
-
|
| 1251 |
-
for match in token_pattern.finditer(part):
|
| 1252 |
-
token = match.group(0)
|
| 1253 |
-
|
| 1254 |
-
if len(token) > chunk_size:
|
| 1255 |
-
yield from _flush_pending()
|
| 1256 |
-
for idx in range(0, len(token), chunk_size):
|
| 1257 |
-
yield token[idx : idx + chunk_size]
|
| 1258 |
-
continue
|
| 1259 |
-
|
| 1260 |
-
if pending and len(pending) + len(token) > chunk_size:
|
| 1261 |
-
yield from _flush_pending()
|
| 1262 |
-
|
| 1263 |
-
pending += token
|
| 1264 |
-
|
| 1265 |
-
yield from _flush_pending()
|
| 1266 |
-
|
| 1267 |
-
|
| 1268 |
def _create_streaming_response(
|
| 1269 |
model_output: str,
|
| 1270 |
tool_calls: list[dict],
|
|
@@ -1276,7 +1130,7 @@ def _create_streaming_response(
|
|
| 1276 |
"""Create streaming response with `usage` calculation included in the final chunk."""
|
| 1277 |
|
| 1278 |
# Calculate token usage
|
| 1279 |
-
prompt_tokens = sum(estimate_tokens(
|
| 1280 |
tool_args = "".join(call.get("function", {}).get("arguments", "") for call in tool_calls or [])
|
| 1281 |
completion_tokens = estimate_tokens(model_output + tool_args)
|
| 1282 |
total_tokens = prompt_tokens + completion_tokens
|
|
@@ -1294,7 +1148,7 @@ def _create_streaming_response(
|
|
| 1294 |
yield f"data: {orjson.dumps(data).decode('utf-8')}\n\n"
|
| 1295 |
|
| 1296 |
# Stream output text in chunks for efficiency
|
| 1297 |
-
for chunk in
|
| 1298 |
data = {
|
| 1299 |
"id": completion_id,
|
| 1300 |
"object": "chat.completion.chunk",
|
|
@@ -1408,7 +1262,7 @@ def _create_responses_streaming_response(
|
|
| 1408 |
content_text += c.text
|
| 1409 |
|
| 1410 |
if content_text:
|
| 1411 |
-
for chunk in
|
| 1412 |
delta_event = {
|
| 1413 |
**base_event,
|
| 1414 |
"type": "response.output_text.delta",
|
|
@@ -1457,7 +1311,7 @@ def _create_standard_response(
|
|
| 1457 |
) -> dict:
|
| 1458 |
"""Create standard response"""
|
| 1459 |
# Calculate token usage
|
| 1460 |
-
prompt_tokens = sum(estimate_tokens(
|
| 1461 |
tool_args = "".join(call.get("function", {}).get("arguments", "") for call in tool_calls or [])
|
| 1462 |
completion_tokens = estimate_tokens(model_output + tool_args)
|
| 1463 |
total_tokens = prompt_tokens + completion_tokens
|
|
@@ -1490,74 +1344,16 @@ def _create_standard_response(
|
|
| 1490 |
return result
|
| 1491 |
|
| 1492 |
|
| 1493 |
-
def _extract_image_dimensions(data: bytes) -> tuple[int | None, int | None]:
|
| 1494 |
-
"""Return image dimensions (width, height) if PNG or JPEG headers are present."""
|
| 1495 |
-
# PNG: dimensions stored in bytes 16..24 of the IHDR chunk
|
| 1496 |
-
if len(data) >= 24 and data.startswith(b"\x89PNG\r\n\x1a\n"):
|
| 1497 |
-
try:
|
| 1498 |
-
width, height = struct.unpack(">II", data[16:24])
|
| 1499 |
-
return int(width), int(height)
|
| 1500 |
-
except struct.error:
|
| 1501 |
-
return None, None
|
| 1502 |
-
|
| 1503 |
-
# JPEG: dimensions stored in SOF segment; iterate through markers to locate it
|
| 1504 |
-
if len(data) >= 4 and data[0:2] == b"\xff\xd8":
|
| 1505 |
-
idx = 2
|
| 1506 |
-
length = len(data)
|
| 1507 |
-
sof_markers = {
|
| 1508 |
-
0xC0,
|
| 1509 |
-
0xC1,
|
| 1510 |
-
0xC2,
|
| 1511 |
-
0xC3,
|
| 1512 |
-
0xC5,
|
| 1513 |
-
0xC6,
|
| 1514 |
-
0xC7,
|
| 1515 |
-
0xC9,
|
| 1516 |
-
0xCA,
|
| 1517 |
-
0xCB,
|
| 1518 |
-
0xCD,
|
| 1519 |
-
0xCE,
|
| 1520 |
-
0xCF,
|
| 1521 |
-
}
|
| 1522 |
-
while idx < length:
|
| 1523 |
-
# Find marker alignment (markers are prefixed with 0xFF bytes)
|
| 1524 |
-
if data[idx] != 0xFF:
|
| 1525 |
-
idx += 1
|
| 1526 |
-
continue
|
| 1527 |
-
while idx < length and data[idx] == 0xFF:
|
| 1528 |
-
idx += 1
|
| 1529 |
-
if idx >= length:
|
| 1530 |
-
break
|
| 1531 |
-
marker = data[idx]
|
| 1532 |
-
idx += 1
|
| 1533 |
-
|
| 1534 |
-
if marker in (0xD8, 0xD9, 0x01) or 0xD0 <= marker <= 0xD7:
|
| 1535 |
-
continue
|
| 1536 |
-
|
| 1537 |
-
if idx + 1 >= length:
|
| 1538 |
-
break
|
| 1539 |
-
segment_length = (data[idx] << 8) + data[idx + 1]
|
| 1540 |
-
idx += 2
|
| 1541 |
-
if segment_length < 2:
|
| 1542 |
-
break
|
| 1543 |
-
|
| 1544 |
-
if marker in sof_markers:
|
| 1545 |
-
if idx + 4 < length:
|
| 1546 |
-
# Skip precision byte at idx, then read height/width (big-endian)
|
| 1547 |
-
height = (data[idx + 1] << 8) + data[idx + 2]
|
| 1548 |
-
width = (data[idx + 3] << 8) + data[idx + 4]
|
| 1549 |
-
return int(width), int(height)
|
| 1550 |
-
break
|
| 1551 |
-
|
| 1552 |
-
idx += segment_length - 2
|
| 1553 |
-
|
| 1554 |
-
return None, None
|
| 1555 |
-
|
| 1556 |
-
|
| 1557 |
async def _image_to_base64(image: Image, temp_dir: Path) -> tuple[str, int | None, int | None, str]:
|
| 1558 |
"""Persist an image provided by gemini_webapi and return base64 plus dimensions and filename."""
|
| 1559 |
if isinstance(image, GeneratedImage):
|
| 1560 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1561 |
else:
|
| 1562 |
saved_path = await image.save(path=str(temp_dir))
|
| 1563 |
|
|
@@ -1571,6 +1367,6 @@ async def _image_to_base64(image: Image, temp_dir: Path) -> tuple[str, int | Non
|
|
| 1571 |
original_path.rename(new_path)
|
| 1572 |
|
| 1573 |
data = new_path.read_bytes()
|
| 1574 |
-
width, height =
|
| 1575 |
filename = random_name
|
| 1576 |
return base64.b64encode(data).decode("ascii"), width, height, filename
|
|
|
|
| 1 |
import base64
|
| 2 |
import json
|
| 3 |
import re
|
|
|
|
| 4 |
import uuid
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from datetime import datetime, timezone
|
| 7 |
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
|
| 10 |
import orjson
|
| 11 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
|
|
| 20 |
ChatCompletionRequest,
|
| 21 |
ContentItem,
|
| 22 |
ConversationInStore,
|
|
|
|
| 23 |
Message,
|
| 24 |
ModelData,
|
| 25 |
ModelListResponse,
|
|
|
|
| 35 |
ResponseToolChoice,
|
| 36 |
ResponseUsage,
|
| 37 |
Tool,
|
|
|
|
| 38 |
ToolChoiceFunction,
|
| 39 |
)
|
| 40 |
from ..services import GeminiClientPool, GeminiClientWrapper, LMDBConversationStore
|
|
|
|
| 41 |
from ..utils import g_config
|
| 42 |
+
from ..utils.helper import (
|
| 43 |
+
CODE_BLOCK_HINT,
|
| 44 |
+
CODE_HINT_STRIPPED,
|
| 45 |
+
XML_HINT_STRIPPED,
|
| 46 |
+
XML_WRAP_HINT,
|
| 47 |
+
estimate_tokens,
|
| 48 |
+
extract_image_dimensions,
|
| 49 |
+
extract_tool_calls,
|
| 50 |
+
iter_stream_segments,
|
| 51 |
+
remove_tool_call_blocks,
|
| 52 |
+
strip_code_fence,
|
| 53 |
+
text_from_message,
|
| 54 |
+
)
|
| 55 |
from .middleware import get_image_store_dir, get_image_token, get_temp_dir, verify_api_key
|
| 56 |
|
| 57 |
# Maximum characters Gemini Web can accept in a single request (configurable)
|
| 58 |
MAX_CHARS_PER_REQUEST = int(g_config.gemini.max_chars_per_request * 0.9)
|
| 59 |
CONTINUATION_HINT = "\n(More messages to come, please reply with just 'ok.')"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
router = APIRouter()
|
| 62 |
|
|
|
|
| 118 |
)
|
| 119 |
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
def _build_tool_prompt(
|
| 122 |
tools: list[Tool],
|
| 123 |
tool_choice: str | ToolChoiceFunction | None,
|
|
|
|
| 304 |
return prepared
|
| 305 |
|
| 306 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
def _response_items_to_messages(
|
| 308 |
items: str | list[ResponseInputItem],
|
| 309 |
) -> tuple[list[Message], str | list[ResponseInputItem]]:
|
|
|
|
| 432 |
return instruction_messages
|
| 433 |
|
| 434 |
|
| 435 |
+
def _get_model_by_name(name: str) -> Model:
|
| 436 |
+
"""
|
| 437 |
+
Retrieve a Model instance by name, considering custom models from config
|
| 438 |
+
and the update strategy (append or overwrite).
|
| 439 |
+
"""
|
| 440 |
+
strategy = g_config.gemini.model_strategy
|
| 441 |
+
custom_models = {m.model_name: m for m in g_config.gemini.models if m.model_name}
|
| 442 |
|
| 443 |
+
if name in custom_models:
|
| 444 |
+
return Model.from_dict(custom_models[name].model_dump())
|
| 445 |
|
| 446 |
+
if strategy == "overwrite":
|
| 447 |
+
raise ValueError(f"Model '{name}' not found in custom models (strategy='overwrite').")
|
|
|
|
|
|
|
| 448 |
|
| 449 |
+
return Model.from_name(name)
|
| 450 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
|
| 452 |
+
def _get_available_models() -> list[ModelData]:
|
| 453 |
+
"""
|
| 454 |
+
Return a list of available models based on configuration strategy.
|
| 455 |
+
"""
|
| 456 |
+
now = int(datetime.now(tz=timezone.utc).timestamp())
|
| 457 |
+
strategy = g_config.gemini.model_strategy
|
| 458 |
+
models_data = []
|
|
|
|
| 459 |
|
| 460 |
+
custom_models = [m for m in g_config.gemini.models if m.model_name]
|
| 461 |
+
for m in custom_models:
|
| 462 |
+
models_data.append(
|
| 463 |
+
ModelData(
|
| 464 |
+
id=m.model_name,
|
| 465 |
+
created=now,
|
| 466 |
+
owned_by="custom",
|
| 467 |
+
)
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
if strategy == "append":
|
| 471 |
+
custom_ids = {m.model_name for m in custom_models}
|
| 472 |
+
for model in Model:
|
| 473 |
+
m_name = model.model_name
|
| 474 |
+
if not m_name or m_name == "unspecified":
|
| 475 |
+
continue
|
| 476 |
+
if m_name in custom_ids:
|
| 477 |
+
continue
|
| 478 |
|
| 479 |
+
models_data.append(
|
| 480 |
+
ModelData(
|
| 481 |
+
id=m_name,
|
| 482 |
+
created=now,
|
| 483 |
+
owned_by="gemini-web",
|
| 484 |
)
|
| 485 |
)
|
| 486 |
|
| 487 |
+
return models_data
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
|
| 489 |
|
| 490 |
@router.get("/v1/models", response_model=ModelListResponse)
|
| 491 |
async def list_models(api_key: str = Depends(verify_api_key)):
|
| 492 |
+
models = _get_available_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
return ModelListResponse(data=models)
|
| 494 |
|
| 495 |
|
|
|
|
| 502 |
):
|
| 503 |
pool = GeminiClientPool()
|
| 504 |
db = LMDBConversationStore()
|
| 505 |
+
|
| 506 |
+
try:
|
| 507 |
+
model = _get_model_by_name(request.model)
|
| 508 |
+
except ValueError as exc:
|
| 509 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
| 510 |
|
| 511 |
if len(request.messages) == 0:
|
| 512 |
raise HTTPException(
|
|
|
|
| 612 |
detail="Gemini output parsing failed unexpectedly.",
|
| 613 |
) from exc
|
| 614 |
|
| 615 |
+
visible_output, tool_calls = extract_tool_calls(raw_output_with_think)
|
| 616 |
+
storage_output = remove_tool_call_blocks(raw_output_clean).strip()
|
| 617 |
tool_calls_payload = [call.model_dump(mode="json") for call in tool_calls]
|
| 618 |
|
| 619 |
if structured_requirement:
|
| 620 |
+
cleaned_visible = strip_code_fence(visible_output or "")
|
| 621 |
if not cleaned_visible:
|
| 622 |
raise HTTPException(
|
| 623 |
status_code=status.HTTP_502_BAD_GATEWAY,
|
|
|
|
     db = LMDBConversationStore()

     try:
+        model = _get_model_by_name(request_data.model)
     except ValueError as exc:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
             detail="Gemini output parsing failed unexpectedly.",
         ) from exc

+    visible_text, detected_tool_calls = extract_tool_calls(text_with_think)
+    storage_output = remove_tool_call_blocks(text_without_think).strip()
     assistant_text = LMDBConversationStore.remove_think_tags(visible_text.strip())

     if structured_requirement:
+        cleaned_visible = strip_code_fence(assistant_text or "")
         if not cleaned_visible:
             raise HTTPException(
                 status_code=status.HTTP_502_BAD_GATEWAY,
         image_call_items.append(
             ResponseImageGenerationCall(
+                id=filename.rsplit(".", 1)[0],
                 status="completed",
                 result=image_base64,
                 output_format=img_format,
     response_id = f"resp_{uuid.uuid4().hex}"
     message_id = f"msg_{uuid.uuid4().hex}"

+    input_tokens = sum(estimate_tokens(text_from_message(msg)) for msg in messages)
     tool_arg_text = "".join(call.function.arguments or "" for call in detected_tool_calls)
     completion_basis = assistant_text or ""
     if tool_arg_text:
     return response_payload


 async def _find_reusable_session(
     db: LMDBConversationStore,
     pool: GeminiClientPool,
         raise
 def _create_streaming_response(
     model_output: str,
     tool_calls: list[dict],
...
     """Create streaming response with `usage` calculation included in the final chunk."""

     # Calculate token usage
+    prompt_tokens = sum(estimate_tokens(text_from_message(msg)) for msg in messages)
     tool_args = "".join(call.get("function", {}).get("arguments", "") for call in tool_calls or [])
     completion_tokens = estimate_tokens(model_output + tool_args)
     total_tokens = prompt_tokens + completion_tokens
         yield f"data: {orjson.dumps(data).decode('utf-8')}\n\n"

     # Stream output text in chunks for efficiency
+    for chunk in iter_stream_segments(model_output):
         data = {
             "id": completion_id,
             "object": "chat.completion.chunk",
             content_text += c.text

     if content_text:
+        for chunk in iter_stream_segments(content_text):
             delta_event = {
                 **base_event,
                 "type": "response.output_text.delta",
 ) -> dict:
     """Create standard response"""
     # Calculate token usage
+    prompt_tokens = sum(estimate_tokens(text_from_message(msg)) for msg in messages)
     tool_args = "".join(call.get("function", {}).get("arguments", "") for call in tool_calls or [])
     completion_tokens = estimate_tokens(model_output + tool_args)
     total_tokens = prompt_tokens + completion_tokens
     return result
 async def _image_to_base64(image: Image, temp_dir: Path) -> tuple[str, int | None, int | None, str]:
     """Persist an image provided by gemini_webapi and return base64 plus dimensions and filename."""
     if isinstance(image, GeneratedImage):
+        try:
+            saved_path = await image.save(path=str(temp_dir), full_size=True)
+        except Exception as e:
+            logger.warning(
+                f"Failed to download full-size GeneratedImage, retrying with default size: {e}"
+            )
+            saved_path = await image.save(path=str(temp_dir), full_size=False)
     else:
         saved_path = await image.save(path=str(temp_dir))
...
     original_path.rename(new_path)

     data = new_path.read_bytes()
+    width, height = extract_image_dimensions(data)
     filename = random_name
     return base64.b64encode(data).decode("ascii"), width, height, filename
@@ -9,18 +9,12 @@ from loguru import logger
 
 from ..models import Message
 from ..utils import g_config
-from ..utils.helper import
-    '```xml\n<tool_call name="tool_name">{"arg": "value"}</tool_call>\n```\n'
-    "Do not surround the fence with any other text or whitespace; otherwise the call will be ignored.\n"
-)
-CODE_BLOCK_HINT = (
-    "\nWhenever you include code, markup, or shell snippets, wrap each snippet in a Markdown fenced "
-    "block and supply the correct language label (for example, ```python ... ``` or ```html ... ```).\n"
-    "Fence ONLY the actual code/markup; keep all narrative or explanatory text outside the fences.\n"
-)
+from ..utils.helper import (
+    add_tag,
+    save_file_to_tempfile,
+    save_url_to_tempfile,
+)
+
 HTML_ESCAPE_RE = re.compile(r"&(?:lt|gt|amp|quot|apos|#[0-9]+|#x[0-9a-fA-F]+);")
 MARKDOWN_ESCAPE_RE = re.compile(r"\\(?=[-\\`*_{}\[\]()#+.!<>])")
 CODE_FENCE_RE = re.compile(r"(```.*?```|`[^`\n]+?`)", re.DOTALL)
@@ -1,6 +1,8 @@
 import os
 import sys
-from typing import Literal, Optional
 
 from loguru import logger
 from pydantic import BaseModel, Field, ValidationError, field_validator
@@ -50,12 +52,37 @@ class GeminiClientSettings(BaseModel):
         return stripped or None
 
 
 class GeminiConfig(BaseModel):
     """Gemini API configuration"""
 
     clients: list[GeminiClientSettings] = Field(
         ..., description="List of Gemini client credential pairs"
     )
     timeout: int = Field(default=120, ge=1, description="Init timeout")
     auto_refresh: bool = Field(True, description="Enable auto-refresh for Gemini cookies")
     refresh_interval: int = Field(
@@ -68,6 +95,36 @@ class GeminiConfig(BaseModel):
         description="Maximum characters Gemini Web can accept per request",
     )
 
 
 class CORSConfig(BaseModel):
     """CORS configuration"""
@@ -207,10 +264,74 @@ def _merge_clients_with_env(
             new_client = GeminiClientSettings(**overrides)
             result_clients.append(new_client)
         else:
-            raise IndexError(
     return result_clients if result_clients else base_clients
 
 
 def initialize_config() -> Config:
     """
     Initialize the configuration.
@@ -221,6 +342,8 @@ def initialize_config() -> Config:
     try:
         # First, extract and remove Gemini clients related environment variables
         env_clients_overrides = extract_gemini_clients_env()
 
         # Then, initialize Config with pydantic_settings
         config = Config()  # type: ignore
@@ -228,7 +351,10 @@ def initialize_config() -> Config:
         # Synthesize clients
         config.gemini.clients = _merge_clients_with_env(
             config.gemini.clients, env_clients_overrides
-        )
 
         return config
     except ValidationError as e:
+import ast
+import json
 import os
 import sys
+from typing import Any, Literal, Optional
 
 from loguru import logger
 from pydantic import BaseModel, Field, ValidationError, field_validator
     return stripped or None
 
 
+class GeminiModelConfig(BaseModel):
+    """Configuration for a custom Gemini model."""
+
+    model_name: Optional[str] = Field(default=None, description="Name of the model")
+    model_header: Optional[dict[str, Optional[str]]] = Field(
+        default=None, description="Header for the model"
+    )
+
+    @field_validator("model_header", mode="before")
+    @classmethod
+    def _parse_json_string(cls, v: Any) -> Any:
+        if isinstance(v, str) and v.strip().startswith("{"):
+            try:
+                return json.loads(v)
+            except json.JSONDecodeError:
+                # Return the original value to let Pydantic handle the error or type mismatch
+                return v
+        return v
+
+
 class GeminiConfig(BaseModel):
     """Gemini API configuration"""
 
     clients: list[GeminiClientSettings] = Field(
         ..., description="List of Gemini client credential pairs"
     )
+    models: list[GeminiModelConfig] = Field(default=[], description="List of custom Gemini models")
+    model_strategy: Literal["append", "overwrite"] = Field(
+        default="append",
+        description="Strategy for loading models: 'append' merges custom with default, 'overwrite' uses only custom",
+    )
     timeout: int = Field(default=120, ge=1, description="Init timeout")
     auto_refresh: bool = Field(True, description="Enable auto-refresh for Gemini cookies")
     refresh_interval: int = Field(
...
         description="Maximum characters Gemini Web can accept per request",
     )
 
+    @field_validator("models", mode="before")
+    @classmethod
+    def _parse_models_json(cls, v: Any) -> Any:
+        if isinstance(v, str) and v.strip().startswith("["):
+            try:
+                return json.loads(v)
+            except json.JSONDecodeError as e:
+                logger.warning(f"Failed to parse models JSON string: {e}")
+                return v
+        return v
+
+    @field_validator("models")
+    @classmethod
+    def _filter_valid_models(cls, v: list[GeminiModelConfig]) -> list[GeminiModelConfig]:
+        """Filter out models that don't have all required fields set."""
+        valid_models = []
+        for model in v:
+            if model.model_name and model.model_header:
+                valid_models.append(model)
+            else:
+                missing = []
+                if not model.model_name:
+                    missing.append("model_name")
+                if not model.model_header:
+                    missing.append("model_header")
+                logger.warning(
+                    f"Discarding custom model due to missing {', '.join(missing)}: {model}"
+                )
+        return valid_models
+
 
 class CORSConfig(BaseModel):
     """CORS configuration"""
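The `mode="before"` validators above boil down to one small pre-parse step. Isolated here without Pydantic so it can be run standalone; the function name `parse_json_field` is ours, not the project's:

```python
import json
from typing import Any


def parse_json_field(v: Any) -> Any:
    """JSON-decode strings that look like objects; pass everything else
    through untouched so downstream validation can report a type mismatch."""
    if isinstance(v, str) and v.strip().startswith("{"):
        try:
            return json.loads(v)
        except json.JSONDecodeError:
            return v
    return v


print(parse_json_field('{"x-custom-header": "value"}'))  # becomes a dict
print(parse_json_field("{not json"))                     # passed through as-is
```

Returning the original string on a decode failure (rather than raising) is what lets Pydantic produce its normal validation error instead of an opaque parse exception.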
             new_client = GeminiClientSettings(**overrides)
             result_clients.append(new_client)
         else:
+            raise IndexError(
+                f"Client index {idx} in env is out of range (current count: {len(result_clients)}). "
+                "Client indices must be contiguous starting from 0."
+            )
     return result_clients if result_clients else base_clients
 
 
+def extract_gemini_models_env() -> dict[int, dict[str, Any]]:
+    """Extract and remove all Gemini models related environment variables, supporting nested fields."""
+    root_key = "CONFIG_GEMINI__MODELS"
+    env_overrides: dict[int, dict[str, Any]] = {}
+
+    if root_key in os.environ:
+        val = os.environ[root_key]
+        models_list = None
+        parsed_successfully = False
+
+        try:
+            models_list = json.loads(val)
+            parsed_successfully = True
+        except json.JSONDecodeError:
+            try:
+                models_list = ast.literal_eval(val)
+                parsed_successfully = True
+            except (ValueError, SyntaxError) as e:
+                logger.warning(f"Failed to parse {root_key} as JSON or Python literal: {e}")
+
+        if parsed_successfully and isinstance(models_list, list):
+            for idx, model_data in enumerate(models_list):
+                if isinstance(model_data, dict):
+                    env_overrides[idx] = model_data
+
+        # Remove the environment variable to avoid Pydantic parsing errors
+        del os.environ[root_key]
+
+    return env_overrides
+
+
+def _merge_models_with_env(
+    base_models: list[GeminiModelConfig] | None,
+    env_overrides: dict[int, dict[str, Any]],
+):
+    """Override base_models with env_overrides using standard update (replace whole fields)."""
+    if not env_overrides:
+        return base_models or []
+    result_models: list[GeminiModelConfig] = []
+    if base_models:
+        result_models = [model.model_copy() for model in base_models]
+
+    for idx in sorted(env_overrides):
+        overrides = env_overrides[idx]
+        if idx < len(result_models):
+            # Update existing model: overwrite fields found in env
+            model_dict = result_models[idx].model_dump()
+            model_dict.update(overrides)
+            result_models[idx] = GeminiModelConfig(**model_dict)
+        elif idx == len(result_models):
+            # Append new models
+            new_model = GeminiModelConfig(**overrides)
+            result_models.append(new_model)
+        else:
+            raise IndexError(
+                f"Model index {idx} in env is out of range (current count: {len(result_models)}). "
+                "Model indices must be contiguous starting from 0."
+            )
+    return result_models
+
+
 def initialize_config() -> Config:
     """
     Initialize the configuration.
...
     try:
         # First, extract and remove Gemini clients related environment variables
         env_clients_overrides = extract_gemini_clients_env()
+        # Extract and remove Gemini models related environment variables
+        env_models_overrides = extract_gemini_models_env()
 
         # Then, initialize Config with pydantic_settings
         config = Config()  # type: ignore
 
         # Synthesize clients
         config.gemini.clients = _merge_clients_with_env(
             config.gemini.clients, env_clients_overrides
+        )
+
+        # Synthesize models
+        config.gemini.models = _merge_models_with_env(config.gemini.models, env_models_overrides)
 
         return config
     except ValidationError as e:

@@ -1,12 +1,38 @@
 import base64
 import mimetypes
 import tempfile
 from pathlib import Path
 
 import httpx
 from loguru import logger
 
 VALID_TAG_ROLES = {"user", "assistant", "system", "tool"}
 
 
 def add_tag(role: str, content: str, unclose: bool = False) -> str:
@@ -18,8 +44,10 @@ def add_tag(role: str, content: str, unclose: bool = False) -> str:
     return f"<|im_start|>{role}\n{content}" + ("\n<|im_end|>" if not unclose else "")
 
 
-def estimate_tokens(text: str) -> int:
     """Estimate the number of tokens heuristically based on character count"""
     return int(len(text) / 3)
 
@@ -36,7 +64,7 @@ async def save_file_to_tempfile(
     return path
 
 
-async def save_url_to_tempfile(url: str, tempdir: Path | None = None):
     data: bytes | None = None
     suffix: str | None = None
     if url.startswith("data:image/"):
@@ -47,20 +75,315 @@ async def save_url_to_tempfile(url: str, tempdir: Path | None = None):
         base64_data = url.split(",")[1]
         data = base64.b64decode(base64_data)
 
-        # Guess extension from mime type, default to the subtype if not found
         suffix = mimetypes.guess_extension(mime_type)
         if not suffix:
             suffix = f".{mime_type.split('/')[1]}"
     else:
-        async with httpx.AsyncClient() as client:
             resp = await client.get(url)
             resp.raise_for_status()
             data = resp.content
 
     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=tempdir) as tmp:
         tmp.write(data)
         path = Path(tmp.name)
 
     return path
 import base64
+import json
 import mimetypes
+import re
+import struct
 import tempfile
+import uuid
 from pathlib import Path
+from typing import Iterator
+from urllib.parse import urlparse
 
 import httpx
 from loguru import logger
 
+from ..models import FunctionCall, Message, ToolCall
+
 VALID_TAG_ROLES = {"user", "assistant", "system", "tool"}
+XML_WRAP_HINT = (
+    "\nYou MUST wrap every tool call response inside a single fenced block exactly like:\n"
+    '```xml\n<tool_call name="tool_name">{"arg": "value"}</tool_call>\n```\n'
+    "Do not surround the fence with any other text or whitespace; otherwise the call will be ignored.\n"
+)
+CODE_BLOCK_HINT = (
+    "\nWhenever you include code, markup, or shell snippets, wrap each snippet in a Markdown fenced "
+    "block and supply the correct language label (for example, ```python ... ``` or ```html ... ```).\n"
+    "Fence ONLY the actual code/markup; keep all narrative or explanatory text outside the fences.\n"
+)
+TOOL_BLOCK_RE = re.compile(r"```xml\s*(.*?)\s*```", re.DOTALL | re.IGNORECASE)
+TOOL_CALL_RE = re.compile(
+    r"<tool_call\s+name=\"([^\"]+)\"\s*>(.*?)</tool_call>", re.DOTALL | re.IGNORECASE
+)
+JSON_FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL | re.IGNORECASE)
+CONTROL_TOKEN_RE = re.compile(r"<\|im_(?:start|end)\|>")
+XML_HINT_STRIPPED = XML_WRAP_HINT.strip()
+CODE_HINT_STRIPPED = CODE_BLOCK_HINT.strip()
 def add_tag(role: str, content: str, unclose: bool = False) -> str:
 
     return f"<|im_start|>{role}\n{content}" + ("\n<|im_end|>" if not unclose else "")
 
 
+def estimate_tokens(text: str | None) -> int:
     """Estimate the number of tokens heuristically based on character count"""
+    if not text:
+        return 0
     return int(len(text) / 3)
     return path
 
 
+async def save_url_to_tempfile(url: str, tempdir: Path | None = None) -> Path:
     data: bytes | None = None
     suffix: str | None = None
     if url.startswith("data:image/"):
...
         base64_data = url.split(",")[1]
         data = base64.b64decode(base64_data)
 
         suffix = mimetypes.guess_extension(mime_type)
         if not suffix:
             suffix = f".{mime_type.split('/')[1]}"
     else:
+        async with httpx.AsyncClient(follow_redirects=True) as client:
             resp = await client.get(url)
             resp.raise_for_status()
             data = resp.content
+            content_type = resp.headers.get("content-type")
+
+        if content_type:
+            mime_type = content_type.split(";")[0].strip()
+            suffix = mimetypes.guess_extension(mime_type)
+
+        if not suffix:
+            path_url = urlparse(url).path
+            suffix = Path(path_url).suffix
+
+        if not suffix:
+            suffix = ".bin"
 
     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=tempdir) as tmp:
         tmp.write(data)
         path = Path(tmp.name)
 
     return path
+def strip_code_fence(text: str) -> str:
+    """Remove surrounding ```json fences if present."""
+    match = JSON_FENCE_RE.match(text.strip())
+    if match:
+        return match.group(1).strip()
+    return text.strip()
+
+
+def strip_tagged_blocks(text: str) -> str:
+    """Remove <|im_start|>role ... <|im_end|> sections, dropping tool blocks entirely.
+
+    - tool blocks are removed entirely (if missing end marker, drop to EOF).
+    - other roles: remove markers and role, keep inner content (if missing end marker, keep to EOF).
+    """
+    if not text:
+        return text
+
+    result: list[str] = []
+    idx = 0
+    length = len(text)
+    start_marker = "<|im_start|>"
+    end_marker = "<|im_end|>"
+
+    while idx < length:
+        start = text.find(start_marker, idx)
+        if start == -1:
+            result.append(text[idx:])
+            break
+
+        # append any content before this block
+        result.append(text[idx:start])
+
+        role_start = start + len(start_marker)
+        newline = text.find("\n", role_start)
+        if newline == -1:
+            # malformed block; keep the remainder as-is (safe behavior)
+            result.append(text[start:])
+            break
+
+        role = text[role_start:newline].strip().lower()
+
+        end = text.find(end_marker, newline + 1)
+        if end == -1:
+            # missing end marker
+            if role == "tool":
+                # drop from the start marker to EOF (skip the remainder)
+                break
+            else:
+                # keep inner content from after the role newline to EOF
+                result.append(text[newline + 1 :])
+                break
+
+        block_end = end + len(end_marker)
+
+        if role == "tool":
+            # drop the whole block
+            idx = block_end
+            continue
+
+        # keep the content without role markers
+        content = text[newline + 1 : end]
+        result.append(content)
+        idx = block_end
+
+    return "".join(result)
+def strip_system_hints(text: str) -> str:
+    """Remove system-level hint text from a given string."""
+    if not text:
+        return text
+    cleaned = strip_tagged_blocks(text)
+    cleaned = cleaned.replace(XML_WRAP_HINT, "").replace(XML_HINT_STRIPPED, "")
+    cleaned = cleaned.replace(CODE_BLOCK_HINT, "").replace(CODE_HINT_STRIPPED, "")
+    cleaned = CONTROL_TOKEN_RE.sub("", cleaned)
+    return cleaned.strip()
+
+
+def remove_tool_call_blocks(text: str) -> str:
+    """Strip tool call code blocks from text."""
+    if not text:
+        return text
+
+    # 1. Remove fenced blocks ONLY if they contain tool calls
+    def _replace_block(match: re.Match[str]) -> str:
+        block_content = match.group(1)
+        if not block_content:
+            return match.group(0)
+
+        # Check if the block contains any tool call tag
+        if TOOL_CALL_RE.search(block_content):
+            return ""
+
+        # Preserve the block if no tool call found
+        return match.group(0)
+
+    cleaned = TOOL_BLOCK_RE.sub(_replace_block, text)
+
+    # 2. Remove orphaned tool calls
+    cleaned = TOOL_CALL_RE.sub("", cleaned)
+
+    return strip_system_hints(cleaned)
+def extract_tool_calls(text: str) -> tuple[str, list[ToolCall]]:
+    """Extract tool call definitions and return cleaned text."""
+    if not text:
+        return text, []
+
+    tool_calls: list[ToolCall] = []
+
+    def _create_tool_call(name: str, raw_args: str) -> None:
+        """Helper to parse args and append to the tool_calls list."""
+        if not name:
+            logger.warning("Encountered tool_call without a function name.")
+            return
+
+        arguments = raw_args
+        try:
+            parsed_args = json.loads(raw_args)
+            arguments = json.dumps(parsed_args, ensure_ascii=False)
+        except json.JSONDecodeError:
+            logger.warning(f"Failed to parse tool call arguments for '{name}'. Passing raw string.")
+
+        tool_calls.append(
+            ToolCall(
+                id=f"call_{uuid.uuid4().hex}",
+                type="function",
+                function=FunctionCall(name=name, arguments=arguments),
+            )
+        )
+
+    def _replace_block(match: re.Match[str]) -> str:
+        block_content = match.group(1)
+        if not block_content:
+            return match.group(0)
+
+        found_in_block = False
+        for call_match in TOOL_CALL_RE.finditer(block_content):
+            found_in_block = True
+            name = (call_match.group(1) or "").strip()
+            raw_args = (call_match.group(2) or "").strip()
+            _create_tool_call(name, raw_args)
+
+        if found_in_block:
+            return ""
+        else:
+            return match.group(0)
+
+    cleaned = TOOL_BLOCK_RE.sub(_replace_block, text)
+
+    def _replace_orphan(match: re.Match[str]) -> str:
+        name = (match.group(1) or "").strip()
+        raw_args = (match.group(2) or "").strip()
+        _create_tool_call(name, raw_args)
+        return ""
+
+    cleaned = TOOL_CALL_RE.sub(_replace_orphan, cleaned)
+
+    cleaned = strip_system_hints(cleaned)
+    return cleaned, tool_calls
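The two regexes driving `extract_tool_calls` can be exercised in isolation. The `find_tool_calls` helper below is a simplified stand-in of ours (it skips the `ToolCall` model, id generation, and JSON re-serialization), using the same two patterns:

```python
import json
import re

FENCE = "`" * 3  # built at runtime to avoid literal backtick fences in this example
TOOL_BLOCK_RE = re.compile(FENCE + r"xml\s*(.*?)\s*" + FENCE, re.DOTALL | re.IGNORECASE)
TOOL_CALL_RE = re.compile(
    r"<tool_call\s+name=\"([^\"]+)\"\s*>(.*?)</tool_call>", re.DOTALL | re.IGNORECASE
)


def find_tool_calls(text: str) -> list[tuple[str, dict]]:
    """Collect (name, parsed-args) pairs from fenced xml tool-call blocks."""
    calls = []
    for block in TOOL_BLOCK_RE.findall(text):
        for name, raw in TOOL_CALL_RE.findall(block):
            calls.append((name, json.loads(raw)))
    return calls


sample = FENCE + 'xml\n<tool_call name="get_weather">{"city": "Paris"}</tool_call>\n' + FENCE
print(find_tool_calls(sample))
```

The outer fence regex is what enforces the XML_WRAP_HINT contract: a `<tool_call>` must sit inside a fenced `xml` block, while the second pass in the real code still catches orphaned calls outside fences.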
+def iter_stream_segments(model_output: str, chunk_size: int = 64) -> Iterator[str]:
+    """Yield stream segments while keeping <think> markers and words intact."""
+    if not model_output:
+        return
+
+    token_pattern = re.compile(r"\s+|\S+\s*")
+    pending = ""
+
+    def _flush_pending() -> Iterator[str]:
+        nonlocal pending
+        if pending:
+            yield pending
+            pending = ""
+
+    # Split on <think> boundaries so the markers are never fragmented.
+    parts = re.split(r"(</?think>)", model_output)
+    for part in parts:
+        if not part:
+            continue
+        if part in {"<think>", "</think>"}:
+            yield from _flush_pending()
+            yield part
+            continue
+
+        for match in token_pattern.finditer(part):
+            token = match.group(0)
+
+            if len(token) > chunk_size:
+                yield from _flush_pending()
+                for idx in range(0, len(token), chunk_size):
+                    yield token[idx : idx + chunk_size]
+                continue
+
+            if pending and len(pending) + len(token) > chunk_size:
+                yield from _flush_pending()
+
+            pending += token
+
+    yield from _flush_pending()
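The word-preserving contract of `iter_stream_segments` reduces to a greedy accumulator. This simplified version is our own sketch (no `<think>` handling, and unlike the original it emits oversized single tokens whole instead of splitting them):

```python
import re
from typing import Iterator


def simple_segments(text: str, chunk_size: int = 8) -> Iterator[str]:
    """Greedy chunker that never splits a word across segment boundaries."""
    pending = ""
    for token in re.findall(r"\s+|\S+\s*", text):
        if pending and len(pending) + len(token) > chunk_size:
            yield pending
            pending = ""
        pending += token
    if pending:
        yield pending


segments = list(simple_segments("one two three four", chunk_size=8))
print(segments)                                      # ['one two ', 'three ', 'four']
print("".join(segments) == "one two three four")     # True: lossless concatenation
```

The invariant worth testing is the second print: however the text is segmented, concatenating the segments must reproduce the input exactly, otherwise streamed output would drift from the non-streamed response.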
+def text_from_message(message: Message) -> str:
+    """Return text content from a message for token estimation."""
+    base_text = ""
+    if isinstance(message.content, str):
+        base_text = message.content
+    elif isinstance(message.content, list):
+        base_text = "\n".join(
+            item.text or "" for item in message.content if getattr(item, "type", "") == "text"
+        )
+    elif message.content is None:
+        base_text = ""
+
+    if message.tool_calls:
+        tool_arg_text = "".join(call.function.arguments or "" for call in message.tool_calls)
+        base_text = f"{base_text}\n{tool_arg_text}" if base_text else tool_arg_text
+
+    return base_text
+def extract_image_dimensions(data: bytes) -> tuple[int | None, int | None]:
+    """Return image dimensions (width, height) if PNG or JPEG headers are present."""
+    # PNG: dimensions stored in bytes 16..24 of the IHDR chunk
+    if len(data) >= 24 and data.startswith(b"\x89PNG\r\n\x1a\n"):
+        try:
+            width, height = struct.unpack(">II", data[16:24])
+            return int(width), int(height)
+        except struct.error:
+            return None, None
+
+    # JPEG: dimensions stored in SOF segment; iterate through markers to locate it
+    if len(data) >= 4 and data[0:2] == b"\xff\xd8":
+        idx = 2
+        length = len(data)
+        sof_markers = {
+            0xC0, 0xC1, 0xC2, 0xC3, 0xC5, 0xC6, 0xC7,
+            0xC9, 0xCA, 0xCB, 0xCD, 0xCE, 0xCF,
+        }
+        while idx < length:
+            # Find marker alignment (markers are prefixed with 0xFF bytes)
+            if data[idx] != 0xFF:
+                idx += 1
+                continue
+            while idx < length and data[idx] == 0xFF:
+                idx += 1
+            if idx >= length:
+                break
+            marker = data[idx]
+            idx += 1
+
+            if marker in (0xD8, 0xD9, 0x01) or 0xD0 <= marker <= 0xD7:
+                continue
+
+            if idx + 1 >= length:
+                break
+            segment_length = (data[idx] << 8) + data[idx + 1]
+            idx += 2
+            if segment_length < 2:
+                break
+
+            if marker in sof_markers:
+                if idx + 4 < length:
+                    # Skip precision byte at idx, then read height/width (big-endian)
+                    height = (data[idx + 1] << 8) + data[idx + 2]
+                    width = (data[idx + 3] << 8) + data[idx + 4]
+                    return int(width), int(height)
+                break
+
+            idx += segment_length - 2
+
+    return None, None
|
@@ -27,6 +27,8 @@ gemini:
   refresh_interval: 540 # Refresh interval in seconds
   verbose: false # Enable verbose logging for Gemini requests
   max_chars_per_request: 1000000 # Maximum characters Gemini Web accepts per request. Non-pro users might have a lower limit
+  model_strategy: "append" # Strategy: 'append' (default + custom) or 'overwrite' (custom only)
+  models: []
 
 storage:
   path: "data/lmdb" # Database storage path