Upload 42 files
Browse files- README.md +25 -0
- api_types.py +41 -0
- app.py +874 -99
- config.local.yaml +3 -0
- config.production-modelscope.yaml +8 -0
- config.production.yaml +15 -1
- config.py +62 -13
- models/.cache/huggingface/download/rwkv7-g1a-0.1b-20250728-ctx4096.pth.metadata +1 -1
- tests/api_test.py +11 -0
- tests/run_api_single_request.py +12 -0
- tests/run_autodetect_flags.py +50 -0
- tests/run_chat_response.py +11 -0
- tests/run_chat_response_out.txt +0 -0
- tests/run_detect.py +7 -0
- tests/run_injected_tools.py +68 -0
- tests/test_client_api.py +50 -0
- tests/test_universal_and_detect.py +42 -0
- utils.py +188 -5
README.md
CHANGED
|
@@ -71,6 +71,11 @@ Advanced features:
|
|
| 71 |
}
|
| 72 |
```
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
Example: POST with `web_search` and reasoning enabled
|
| 75 |
|
| 76 |
```json
|
|
@@ -85,5 +90,25 @@ Example: POST with `web_search` and reasoning enabled
|
|
| 85 |
|
| 86 |
The server will perform a web search for the prompt, aggregate the top 3 results, and inject those into the prompt, then run the model with reasoning enabled — all using the same model instead of an external reasoning or search model.
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
Streaming behavior:
|
| 89 |
- The API streams responses token-by-token by default (`stream: true`) and persists a `state_name` for the generation if requested (or will generate one). Provide `state_name` to resume continuation from where the previous stream stopped. The server stores model state in memory under `(model, state_name)` so subsequent requests with the same `state_name` can continue generation from that exact point.
|
|
|
|
| 71 |
}
|
| 72 |
```
|
| 73 |
|
| 74 |
+
API endpoints and model listing:
|
| 75 |
+
- `GET /api/v1/models` — returns a JSON list of configured models, sampler defaults, and ALLOW_* flags. This lets clients build per-model UI toggles (web search, tools, reasoning) based on server-provided capabilities.
|
| 76 |
+
|
| 77 |
+
Examples:
|
| 78 |
+
- `curl http://127.0.0.1:7860/api/v1/models` will show configured models and their sampler defaults.
|
| 79 |
Example: POST with `web_search` and reasoning enabled
|
| 80 |
|
| 81 |
```json
|
|
|
|
| 90 |
|
| 91 |
The server will perform a web search for the prompt, aggregate the top 3 results, and inject those into the prompt, then run the model with reasoning enabled — all using the same model instead of an external reasoning or search model.
|
| 92 |
|
| 93 |
+
Universal tool and model-initiated tool calls:
|
| 94 |
+
- The `universal` tool returns a structured JSON/dict with the following fields: `action` (calc/web_search), `result` (string), and `metadata` (dict with `confidence`, query/expression, etc.).
|
| 95 |
+
- Example `universal` result:
|
| 96 |
+
|
| 97 |
+
```json
|
| 98 |
+
{
|
| 99 |
+
"action": "calc",
|
| 100 |
+
"result": "14",
|
| 101 |
+
"metadata": {"expression": "2+3*4", "confidence": 0.98}
|
| 102 |
+
}
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
- The model can also request tools mid-generation by emitting a sentinel tag, e.g.:
|
| 106 |
+
|
| 107 |
+
```
|
| 108 |
+
<tool-call>{"name":"calc","args":{"expression":"40+2"}}</tool-call>
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
When the model emits such a sentinel, the server will execute the requested tool, inject the results into the prompt, and continue streaming output. The server will also emit a metadata-only streaming chunk so the client is aware a tool was executed mid-stream.
|
| 112 |
+
|
| 113 |
Streaming behavior:
|
| 114 |
- The API streams responses token-by-token by default (`stream: true`) and persists a `state_name` for the generation if requested (or will generate one). Provide `state_name` to resume continuation from where the previous stream stopped. The server stores model state in memory under `(model, state_name)` so subsequent requests with the same `state_name` can continue generation from that exact point.
|
api_types.py
CHANGED
|
@@ -36,6 +36,33 @@ class ChatCompletionMessage(BaseModel):
|
|
| 36 |
)
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
class PromptTokensDetails(BaseModel):
|
| 40 |
cached_tokens: int
|
| 41 |
|
|
@@ -80,3 +107,17 @@ class ChatCompletionChunk(BaseModel):
|
|
| 80 |
model: str
|
| 81 |
choices: List[ChatCompletionChoice]
|
| 82 |
usage: Optional[Usage]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
)
|
| 37 |
|
| 38 |
|
| 39 |
+
class SamplerConfig(BaseModel):
    """Sampler configuration used in API requests and model defaults.

    This mirrors the server-side `SamplerConfig` and exposes an optional
    `ALLOW_*` set of fields that can be used to override the model/global
    allow flags per-request (when present)."""

    # Core generation/sampling parameters; defaults apply when a request omits them.
    max_tokens: Optional[int] = Field(512)
    temperature: Optional[float] = Field(1.0)
    top_p: Optional[float] = Field(0.3)
    presence_penalty: Optional[float] = Field(0.5)
    count_penalty: Optional[float] = Field(0.5)
    penalty_decay: Optional[float] = Field(0.996)
    # Stop conditions: textual stop sequences and explicit stop token ids.
    stop: Optional[List[str]] = Field(default_factory=lambda: ["\n\n"])
    stop_tokens: Optional[List[int]] = Field(default_factory=lambda: [0])
    # Tri-state capability overrides: None means "no override, inherit the
    # model/global allow flag" (per the class docstring above).
    ALLOW_WEB_SEARCH: Optional[bool] = Field(None)
    ALLOW_TOOLS: Optional[bool] = Field(None)
    ALLOW_REASONING: Optional[bool] = Field(None)
    ALLOW_FILE_TOOL: Optional[bool] = Field(None, description="Per-sampler override for allowing file tools (uploads/file_read).")
    # UI flags so a client can show the controls for toggles
    SHOW_WEB_SEARCH_BUTTON: Optional[bool] = Field(None, description="Whether the UI should show a web-search toggle for this sampler")
    SHOW_FILE_UPLOAD_BUTTON: Optional[bool] = Field(None, description="Whether the UI should show a file upload control for this sampler")
    SHOW_REASONING_TOGGLE: Optional[bool] = Field(None, description="Whether the UI should show a reasoning toggle for this sampler")
    # UI style hints. e.g. 'whatsapp' style, compact, or 'expanded'
    UI_STYLE: Optional[str] = Field(None, description="UI style hint that clients may use to render controls (example: 'whatsapp' or 'compact')")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
class PromptTokensDetails(BaseModel):
|
| 67 |
cached_tokens: int
|
| 68 |
|
|
|
|
| 107 |
model: str
|
| 108 |
choices: List[ChatCompletionChoice]
|
| 109 |
usage: Optional[Usage]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class UploadedFile(BaseModel):
|
| 113 |
+
file_id: str
|
| 114 |
+
filename: str
|
| 115 |
+
size: int
|
| 116 |
+
mime_type: Optional[str] = None
|
| 117 |
+
path: Optional[str] = None
|
| 118 |
+
uploaded_at: Optional[int] = None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class FileUploadResponse(BaseModel):
|
| 122 |
+
success: bool = True
|
| 123 |
+
file: UploadedFile
|
app.py
CHANGED
|
@@ -15,6 +15,8 @@ from utils import (
|
|
| 15 |
remove_nested_think_tags_stack,
|
| 16 |
format_bytes,
|
| 17 |
log,
|
|
|
|
|
|
|
| 18 |
)
|
| 19 |
|
| 20 |
import copy, types, gc, sys, re, time, collections, asyncio
|
|
@@ -78,7 +80,7 @@ os.environ["RWKV_CUDA_ON"] = (
|
|
| 78 |
from rwkv.model import RWKV
|
| 79 |
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
| 80 |
|
| 81 |
-
from fastapi import FastAPI, HTTPException
|
| 82 |
from starlette.background import BackgroundTask
|
| 83 |
from fastapi.responses import StreamingResponse
|
| 84 |
from fastapi.middleware.cors import CORSMiddleware
|
|
@@ -94,6 +96,9 @@ from api_types import (
|
|
| 94 |
PromptTokensDetails,
|
| 95 |
ChatCompletionChoice,
|
| 96 |
ChatCompletionMessage,
|
|
|
|
|
|
|
|
|
|
| 97 |
)
|
| 98 |
|
| 99 |
|
|
@@ -109,25 +114,103 @@ DEFALUT_MODEL_NAME = None
|
|
| 109 |
DEFAULT_REASONING_MODEL_NAME = None
|
| 110 |
|
| 111 |
# In-memory model state store to support streaming continuation/resume per state_name.
|
| 112 |
-
# Keys: (model_name, state_name) ->
|
| 113 |
STATE_STORE: Dict[tuple, Any] = {}
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
|
| 116 |
|
| 117 |
logGPUState()
|
| 118 |
|
| 119 |
-
#
|
| 120 |
-
#
|
| 121 |
-
|
| 122 |
-
if len(filtered_models) == 0:
|
| 123 |
-
# If no explicit 0.1b model detected, fall back to the first provided model but warn.
|
| 124 |
-
logger.warning("No '0.1b' model detected in config; using the first available model. To ensure single 0.1b use, include a model name with '0.1b'.")
|
| 125 |
-
CONFIG.MODELS = [CONFIG.MODELS[0]]
|
| 126 |
-
elif len(filtered_models) > 1:
|
| 127 |
-
logger.warning("Multiple '0.1b' models detected; selecting the first one as the single model.")
|
| 128 |
-
CONFIG.MODELS = [filtered_models[0]]
|
| 129 |
-
else:
|
| 130 |
-
CONFIG.MODELS = [filtered_models[0]]
|
| 131 |
|
| 132 |
for model_config in CONFIG.MODELS:
|
| 133 |
logger.info(f"Load Model - {model_config.SERVICE_NAME}")
|
|
@@ -200,14 +283,35 @@ class ChatCompletionRequest(BaseModel):
|
|
| 200 |
presence_penalty: Optional[float] = Field(default=None)
|
| 201 |
count_penalty: Optional[float] = Field(default=None)
|
| 202 |
penalty_decay: Optional[float] = Field(default=None)
|
| 203 |
-
stream: Optional[bool] = Field(default=
|
| 204 |
state_name: Optional[str] = Field(default=None)
|
| 205 |
include_usage: Optional[bool] = Field(default=False)
|
| 206 |
stop: Optional[list[str]] = Field(["\n\n"])
|
| 207 |
stop_tokens: Optional[list[int]] = Field([0])
|
| 208 |
web_search: Optional[bool] = Field(default=False, description="Whether to perform a web search and append results to the prompt")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
search_top_k: Optional[int] = Field(default=3, description="Number of web search results to retrieve")
|
| 210 |
tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="List of tools to execute server-side (e.g., {'name':'web_search','args':{'query':'x'}})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
@model_validator(mode="before")
|
| 213 |
@classmethod
|
|
@@ -237,6 +341,26 @@ app.add_middleware(
|
|
| 237 |
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
|
| 238 |
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
async def runPrefill(
|
| 241 |
request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
|
| 242 |
):
|
|
@@ -363,12 +487,191 @@ async def chatResponse(
|
|
| 363 |
) -> ChatCompletion:
|
| 364 |
createTimestamp = time.time()
|
| 365 |
|
| 366 |
-
prompt
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
)
|
| 371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
if request.tools:
|
| 373 |
try:
|
| 374 |
for tool in request.tools:
|
|
@@ -381,27 +684,73 @@ async def chatResponse(
|
|
| 381 |
search_top_k = int(args.get('top_k') or request.search_top_k or 3)
|
| 382 |
search_str = web_search(search_q, search_top_k)
|
| 383 |
if search_str:
|
| 384 |
-
|
|
|
|
|
|
|
| 385 |
elif name == 'calc' or name == 'calculator':
|
| 386 |
from utils import calc
|
| 387 |
|
| 388 |
expr = args.get('expression')
|
| 389 |
if expr:
|
| 390 |
calc_res = calc(expr)
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
else:
|
| 393 |
# Unsupported tool - ignore or log
|
| 394 |
logger.info(f"Unsupported tool requested: {name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
except Exception as e:
|
| 396 |
logger.info(f"Tool processing error: {e}")
|
| 397 |
-
elif request.web_search:
|
| 398 |
try:
|
| 399 |
from utils import web_search
|
| 400 |
|
| 401 |
search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
|
| 402 |
search_res = web_search(search_q, int(request.search_top_k or 3))
|
| 403 |
if search_res:
|
| 404 |
-
|
|
|
|
|
|
|
| 405 |
except Exception:
|
| 406 |
pass
|
| 407 |
logger.info(f"[REQ] {completionId} - prompt - {prompt}")
|
|
@@ -411,9 +760,14 @@ async def chatResponse(
|
|
| 411 |
state_key = (request.model, request.state_name)
|
| 412 |
if state_key in STATE_STORE:
|
| 413 |
stored = STATE_STORE[state_key]
|
| 414 |
-
model_state = stored.get('state',
|
| 415 |
model_tokens = stored.get('model_tokens', [0])
|
| 416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
else:
|
| 418 |
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
|
| 419 |
else:
|
|
@@ -425,32 +779,87 @@ async def chatResponse(
|
|
| 425 |
fullResponse = " <think" if enableReasoning else ""
|
| 426 |
completionTokenCount = 0
|
| 427 |
finishReason = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
out,
|
| 432 |
-
model_tokens,
|
| 433 |
-
model_state,
|
| 434 |
-
max_tokens=(
|
| 435 |
-
64000
|
| 436 |
-
if "max_tokens" not in request.model_fields_set and enableReasoning
|
| 437 |
-
else (request.max_tokens or 2048)
|
| 438 |
-
),
|
| 439 |
-
):
|
| 440 |
-
# chunk['content'] is now expected to be a single token's decoded text
|
| 441 |
-
fullResponse += chunk["content"]
|
| 442 |
-
# Check stop sequences (multi-token) after each token
|
| 443 |
-
for stop_words in request.stop or []:
|
| 444 |
-
if stop_words in fullResponse:
|
| 445 |
-
finishReason = f"stop:words:{stop_words}"
|
| 446 |
-
break
|
| 447 |
-
completionTokenCount += 1
|
| 448 |
-
|
| 449 |
-
if chunk["finish_reason"]:
|
| 450 |
-
finishReason = chunk["finish_reason"]
|
| 451 |
-
await asyncio.sleep(0)
|
| 452 |
|
| 453 |
-
|
| 454 |
|
| 455 |
responseLog = {
|
| 456 |
"content": fullResponse,
|
|
@@ -458,7 +867,7 @@ async def chatResponse(
|
|
| 458 |
"prefill_len": promptTokenCount,
|
| 459 |
"prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
|
| 460 |
"gen_len": completionTokenCount,
|
| 461 |
-
"gen_tps": round(completionTokenCount / (
|
| 462 |
}
|
| 463 |
logger.info(f"[RES] {completionId} - {responseLog}")
|
| 464 |
|
|
@@ -481,7 +890,7 @@ async def chatResponse(
|
|
| 481 |
role="Assistant",
|
| 482 |
content=content,
|
| 483 |
reasoning_content=reasoning_content if reasoning_content else None,
|
| 484 |
-
tool_calls=None,
|
| 485 |
),
|
| 486 |
logprobs=None,
|
| 487 |
finish_reason=finishReason,
|
|
@@ -496,6 +905,11 @@ async def chatResponse(
|
|
| 496 |
'state': model_state,
|
| 497 |
'model_tokens': model_tokens,
|
| 498 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
except Exception:
|
| 500 |
pass
|
| 501 |
|
|
@@ -510,12 +924,68 @@ async def chatResponseStream(
|
|
| 510 |
):
|
| 511 |
createTimestamp = int(time.time())
|
| 512 |
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
)
|
| 518 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
if request.tools:
|
| 520 |
try:
|
| 521 |
for tool in request.tools:
|
|
@@ -528,26 +998,43 @@ async def chatResponseStream(
|
|
| 528 |
search_top_k = int(args.get('top_k') or request.search_top_k or 3)
|
| 529 |
search_str = web_search(search_q, search_top_k)
|
| 530 |
if search_str:
|
| 531 |
-
|
|
|
|
|
|
|
| 532 |
elif name == 'calc' or name == 'calculator':
|
| 533 |
from utils import calc
|
| 534 |
|
| 535 |
expr = args.get('expression')
|
| 536 |
if expr:
|
| 537 |
calc_res = calc(expr)
|
| 538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
else:
|
| 540 |
logger.info(f"Unsupported tool requested: {name}")
|
| 541 |
except Exception as e:
|
| 542 |
logger.info(f"Tool processing error: {e}")
|
| 543 |
-
elif request.web_search:
|
| 544 |
try:
|
| 545 |
from utils import web_search
|
| 546 |
|
| 547 |
search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
|
| 548 |
search_res = web_search(search_q, int(request.search_top_k or 3))
|
| 549 |
if search_res:
|
| 550 |
-
|
|
|
|
|
|
|
| 551 |
except Exception:
|
| 552 |
pass
|
| 553 |
|
|
@@ -558,9 +1045,13 @@ async def chatResponseStream(
|
|
| 558 |
state_key = (request.model, request.state_name)
|
| 559 |
if state_key in STATE_STORE:
|
| 560 |
stored = STATE_STORE[state_key]
|
| 561 |
-
model_state = stored.get('state',
|
| 562 |
model_tokens = stored.get('model_tokens', [0])
|
| 563 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
else:
|
| 565 |
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
|
| 566 |
else:
|
|
@@ -571,6 +1062,9 @@ async def chatResponseStream(
|
|
| 571 |
|
| 572 |
completionTokenCount = 0
|
| 573 |
finishReason = None
|
|
|
|
|
|
|
|
|
|
| 574 |
|
| 575 |
response = ChatCompletionChunk(
|
| 576 |
id=completionId,
|
|
@@ -605,6 +1099,14 @@ async def chatResponseStream(
|
|
| 605 |
# Attach state_name in the initial chunk so client can save it to continue later
|
| 606 |
r_dict = response.model_dump()
|
| 607 |
r_dict['state_name'] = request.state_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
yield f"data: {r_dict}\n\n"
|
| 609 |
|
| 610 |
buffer = []
|
|
@@ -771,15 +1273,73 @@ async def chatResponseStream(
|
|
| 771 |
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
| 772 |
response.choices[0].delta = delta
|
| 773 |
if delta.content != None or delta.reasoning_content != None:
|
| 774 |
-
|
| 775 |
try:
|
| 776 |
if request.state_name:
|
| 777 |
STATE_STORE[(request.model, request.state_name)] = {
|
| 778 |
'state': model_state,
|
| 779 |
'model_tokens': model_tokens,
|
| 780 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
except Exception:
|
| 782 |
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 783 |
yield f"data: {response.model_dump_json()}\n\n"
|
| 784 |
# check stop sequences and stop streaming if we see them
|
| 785 |
for stop_words in request.stop or []:
|
|
@@ -791,39 +1351,139 @@ async def chatResponseStream(
|
|
| 791 |
|
| 792 |
del streamConfig
|
| 793 |
else:
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 797 |
|
| 798 |
-
|
| 799 |
-
|
| 800 |
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
|
| 809 |
-
|
| 810 |
-
|
| 811 |
-
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
|
| 815 |
-
choices=[
|
| 816 |
-
ChatCompletionChoice(
|
| 817 |
-
index=0,
|
| 818 |
-
delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None),
|
| 819 |
-
logprobs=None,
|
| 820 |
-
finish_reason=finishReason,
|
| 821 |
-
)
|
| 822 |
-
],
|
| 823 |
-
)
|
| 824 |
|
| 825 |
-
|
| 826 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 827 |
|
| 828 |
genenrateTime = time.time()
|
| 829 |
|
|
@@ -858,7 +1518,13 @@ async def chat_completions(request: ChatCompletionRequest):
|
|
| 858 |
completionId = str(next(CompletionIdGenerator))
|
| 859 |
logger.info(f"[REQ] {completionId} - {request.model_dump()}")
|
| 860 |
|
|
|
|
|
|
|
| 861 |
modelName = request.model.split(":")[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 862 |
enableReasoning = ":thinking" in request.model
|
| 863 |
|
| 864 |
if "rwkv-latest" in request.model:
|
|
@@ -899,14 +1565,24 @@ async def chat_completions(request: ChatCompletionRequest):
|
|
| 899 |
model_tokens_for_resume = stored.get('model_tokens', [0])
|
| 900 |
request_dict = request.model_dump()
|
| 901 |
|
|
|
|
|
|
|
|
|
|
| 902 |
for k, v in defaultSamplerConfig.model_dump().items():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 903 |
if k in request_dict and request_dict[k] is None:
|
| 904 |
request_dict[k] = v
|
| 905 |
realRequest = ChatCompletionRequest(**request_dict)
|
|
|
|
|
|
|
|
|
|
| 906 |
|
| 907 |
logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}")
|
| 908 |
|
| 909 |
-
if
|
| 910 |
r = StreamingResponse(
|
| 911 |
chatResponseStream(realRequest, model_state, completionId, enableReasoning),
|
| 912 |
media_type="text/event-stream",
|
|
@@ -928,10 +1604,109 @@ async def chat_completions(request: ChatCompletionRequest):
|
|
| 928 |
return r
|
| 929 |
|
| 930 |
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 935 |
|
| 936 |
if __name__ == "__main__":
|
| 937 |
import uvicorn
|
|
|
|
| 15 |
remove_nested_think_tags_stack,
|
| 16 |
format_bytes,
|
| 17 |
log,
|
| 18 |
+
detect_tools_and_reasoning,
|
| 19 |
+
universal_tool,
|
| 20 |
)
|
| 21 |
|
| 22 |
import copy, types, gc, sys, re, time, collections, asyncio
|
|
|
|
| 80 |
from rwkv.model import RWKV
|
| 81 |
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
| 82 |
|
| 83 |
+
from fastapi import FastAPI, HTTPException, UploadFile, File
|
| 84 |
from starlette.background import BackgroundTask
|
| 85 |
from fastapi.responses import StreamingResponse
|
| 86 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 96 |
PromptTokensDetails,
|
| 97 |
ChatCompletionChoice,
|
| 98 |
ChatCompletionMessage,
|
| 99 |
+
SamplerConfig,
|
| 100 |
+
UploadedFile,
|
| 101 |
+
FileUploadResponse,
|
| 102 |
)
|
| 103 |
|
| 104 |
|
|
|
|
| 114 |
DEFAULT_REASONING_MODEL_NAME = None
|
| 115 |
|
| 116 |
# In-memory model state store to support streaming continuation/resume per state_name.
|
| 117 |
+
# Keys: (model_name, state_name) -> dict with 'state' and 'model_tokens'
|
| 118 |
STATE_STORE: Dict[tuple, Any] = {}
|
| 119 |
|
| 120 |
+
# Serialized state store file path and flush interval defined in CONFIG
|
| 121 |
+
_STATE_STORE_PATH = getattr(CONFIG, 'STATE_STORE_PATH', './state_store.json')
|
| 122 |
+
_LAST_STATE_STORE_WRITE = 0
|
| 123 |
+
|
| 124 |
+
# sentinel for model-initiated tool calls: <tool-call>{json}</tool-call>
|
| 125 |
+
TOOL_CALL_RE = re.compile(r"<tool-call>\s*(\{.*?\})\s*</tool-call>", re.S)
|
| 126 |
+
|
| 127 |
+
# File uploads: simple in-memory index (persisted on disk via the files themselves)
|
| 128 |
+
UPLOADED_FILES: Dict[str, dict] = {}
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _serialize_state_store() -> dict:
    """Build a JSON-serializable snapshot of STATE_STORE.

    Only the `model_tokens` lists are captured; live model states (torch
    objects) cannot be serialized and are intentionally dropped. Entries
    that are not dicts (raw model states) are skipped entirely.
    """
    snapshot = {}
    for key, entry in STATE_STORE.items():
        model_name, state_name = key
        try:
            if not isinstance(entry, dict):
                # Raw model_state entry: nothing serializable here.
                continue
            tokens = entry.get('model_tokens')
            if tokens is None:
                continue
            snapshot[f"{model_name}|{state_name}"] = {
                'model': model_name,
                'state_name': state_name,
                'model_tokens': tokens,
            }
        except Exception:
            # Best-effort: a single malformed entry must not abort the snapshot.
            continue
    return snapshot
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def _load_state_store_from_disk():
    """Restore persisted generation states from `_STATE_STORE_PATH` into STATE_STORE.

    Only `model_tokens` are persisted on disk (see `_serialize_state_store`),
    so each restored entry is created with `state: None`; the live model
    state must be recomputed from the tokens before it can be used.
    """
    global STATE_STORE
    try:
        if os.path.exists(_STATE_STORE_PATH):
            import json

            with open(_STATE_STORE_PATH, 'r', encoding='utf-8') as f:
                data = json.load(f)
            for k, v in data.items():
                model = v.get('model')
                state_name = v.get('state_name')
                model_tokens = v.get('model_tokens')
                # Restore only well-typed records; skip malformed entries silently.
                if model and state_name and isinstance(model_tokens, list):
                    STATE_STORE[(model, state_name)] = {
                        'state': None,
                        'model_tokens': model_tokens,
                    }
            logger.info(f"Loaded {len(STATE_STORE)} entries from state store file {_STATE_STORE_PATH}")
    except Exception as e:
        # Best-effort restore: a corrupt/unreadable store file must not block startup.
        logger.info(f"Failed to load state store from disk: {e}")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _save_state_store_to_disk(force=False):
    """Persist the serializable part of STATE_STORE to disk (atomic write).

    Writes are throttled to at most one per `CONFIG.STATE_STORE_FLUSH_INTERVAL`
    seconds (default 5) unless `force` is True. Failures are logged, never raised.
    """
    global _LAST_STATE_STORE_WRITE
    now = time.time()
    # Throttle: skip if the last successful write was too recent (unless forced).
    if not force and now - _LAST_STATE_STORE_WRITE < getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5):
        return
    try:
        serial = _serialize_state_store()
        if not serial:
            # Nothing serializable to write. NOTE(review): this also means an
            # emptied store never clears a previously written file -- confirm intended.
            return
        import json
        tmp = _STATE_STORE_PATH + ".tmp"
        # Write to a temp file then os.replace() so readers never see a partial file.
        with open(tmp, 'w', encoding='utf-8') as f:
            json.dump(serial, f)
        os.replace(tmp, _STATE_STORE_PATH)
        _LAST_STATE_STORE_WRITE = now
    except Exception as e:
        # Best-effort persistence: log and continue.
        logger.info(f"Write state store to disk failed: {e}")
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _recompute_out_and_state_from_tokens(model_name: str, model_tokens: List[int]):
    """
    Recompute the `out` logits and `model_state` by forwarding through tokens in chunks.

    Args:
        model_name: key into MODEL_STORAGE identifying the loaded model.
        model_tokens: token ids to replay; a non-list value falls back to [0].

    Returns a tuple (out, model_state); (None, None) when the model is not
    loaded, or when `model_tokens` is an empty list (no forward pass runs).
    """
    ms = MODEL_STORAGE.get(model_name)
    if not ms or not ms.model:
        return None, None
    model_state = None
    out = None
    tokens = list(model_tokens) if isinstance(model_tokens, list) else [0]
    # Step by index instead of repeatedly re-slicing the tail: the original
    # `tokens = tokens[CHUNK_LEN:]` pattern copies the remainder every
    # iteration (accidentally O(n^2) for long token histories).
    chunk_len = CONFIG.CHUNK_LEN
    for start in range(0, len(tokens), chunk_len):
        out, model_state = ms.model.forward(tokens[start : start + chunk_len], model_state)
    return out, model_state
|
| 206 |
+
|
| 207 |
logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
|
| 208 |
|
| 209 |
logGPUState()
|
| 210 |
|
| 211 |
+
# Keep any configured models intact; do not force selection by name/size.
|
| 212 |
+
# The previous policy enforced a single '0.1b' model which hid additional configs; use the full list.
|
| 213 |
+
logger.info(f"Configured {len(CONFIG.MODELS)} model(s) in ROOT config")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
for model_config in CONFIG.MODELS:
|
| 216 |
logger.info(f"Load Model - {model_config.SERVICE_NAME}")
|
|
|
|
| 283 |
presence_penalty: Optional[float] = Field(default=None)
|
| 284 |
count_penalty: Optional[float] = Field(default=None)
|
| 285 |
penalty_decay: Optional[float] = Field(default=None)
|
| 286 |
+
stream: Optional[bool] = Field(default=None, description="Whether to stream token-by-token responses. If None, uses CONFIG.DEFAULT_STREAM")
|
| 287 |
state_name: Optional[str] = Field(default=None)
|
| 288 |
include_usage: Optional[bool] = Field(default=False)
|
| 289 |
stop: Optional[list[str]] = Field(["\n\n"])
|
| 290 |
stop_tokens: Optional[list[int]] = Field([0])
|
| 291 |
web_search: Optional[bool] = Field(default=False, description="Whether to perform a web search and append results to the prompt")
|
| 292 |
+
enable_web_search: Optional[bool] = Field(default=None, description="Explicitly enable web search (overrides auto/web_search) if set")
|
| 293 |
+
auto_web_search: Optional[bool] = Field(default=None, description="Whether to enable web_search based on auto-detected intent")
|
| 294 |
+
enable_tools: Optional[bool] = Field(default=None, description="Explicitly enable tools (overrides auto detection)")
|
| 295 |
+
auto_tools: Optional[bool] = Field(default=None, description="Whether to enable tools based on auto-detected intent")
|
| 296 |
+
enable_reasoning: Optional[bool] = Field(default=None, description="Explicitly override reasoning enablement")
|
| 297 |
+
auto_reasoning: Optional[bool] = Field(default=None, description="Whether to enable reasoning based on auto detection")
|
| 298 |
+
enable_universal: Optional[bool] = Field(default=None, description="Explicitly enable the universal tool execution")
|
| 299 |
+
auto_universal: Optional[bool] = Field(default=None, description="Whether to auto enable universal tool execution")
|
| 300 |
search_top_k: Optional[int] = Field(default=3, description="Number of web search results to retrieve")
|
| 301 |
tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="List of tools to execute server-side (e.g., {'name':'web_search','args':{'query':'x'}})")
|
| 302 |
+
# Per-request sampler overrides for ALLOW_* flags. These let the user
|
| 303 |
+
# disable server-side features for this particular request if needed.
|
| 304 |
+
sampler_allow_web_search: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing web_search")
|
| 305 |
+
sampler_allow_tools: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing tools")
|
| 306 |
+
sampler_allow_reasoning: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing reasoning")
|
| 307 |
+
# Per-request sampler config object; if provided, these settings will
|
| 308 |
+
# override the model defaults for this request.
|
| 309 |
+
sampler: Optional[SamplerConfig] = Field(default=None, description="Per-request sampler settings (overrides model default)")
|
| 310 |
+
# File uploads: allow referencing uploaded files in the request
|
| 311 |
+
file_ids: Optional[List[str]] = Field(default=None, description="List of uploaded file IDs that the model may use for this request")
|
| 312 |
+
enable_file_tool: Optional[bool] = Field(default=None, description="Explicitly enable file-based tools for this request")
|
| 313 |
+
auto_file_tool: Optional[bool] = Field(default=None, description="Auto-detect whether file-based tools are needed")
|
| 314 |
+
sampler_allow_file_tool: Optional[bool] = Field(default=None, description="Per-request sampler override allowing file tools")
|
| 315 |
|
| 316 |
@model_validator(mode="before")
|
| 317 |
@classmethod
|
|
|
|
| 341 |
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
|
| 342 |
|
| 343 |
|
| 344 |
+
@app.on_event("startup")
async def _startup_state_load_and_persist_loop():
    """Restore persisted state at startup and start a periodic flush loop.

    On startup this loads the previously persisted state store (tokens
    only) from disk, then spawns a background task that flushes the
    in-memory state store back to disk every
    ``CONFIG.STATE_STORE_FLUSH_INTERVAL`` seconds (default 5).

    The spawned task is best-effort: flush errors are swallowed so a
    transient disk problem never kills the loop.
    """
    # Load previous persisted state (tokens only) at startup.
    _load_state_store_from_disk()

    async def _persist_loop():
        # Periodic best-effort flush of STATE_STORE to disk.
        while True:
            try:
                _save_state_store_to_disk(force=False)
            except Exception:
                # Best-effort persistence: never let a flush error
                # terminate the background loop.
                pass
            await asyncio.sleep(getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5))

    # Spawn the background flush task. Keep a strong reference on
    # app.state: the event loop holds only a weak reference to tasks
    # created via asyncio.create_task, so an unreferenced task may be
    # garbage-collected mid-execution and silently stop flushing.
    try:
        app.state._state_persist_task = asyncio.create_task(_persist_loop())
    except Exception:
        # e.g. called outside a running event loop; startup proceeds
        # without periodic persistence rather than failing hard.
        pass
|
| 362 |
+
|
| 363 |
+
|
| 364 |
async def runPrefill(
|
| 365 |
request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
|
| 366 |
):
|
|
|
|
| 487 |
) -> ChatCompletion:
|
| 488 |
createTimestamp = time.time()
|
| 489 |
|
| 490 |
+
# Build raw prompt for detection (prefer explicit request.prompt, else messages)
|
| 491 |
+
raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [])
|
| 492 |
+
# Intent detection: analyze raw_prompt or messages to auto-activate tools/web-search/reasoning
|
| 493 |
+
detection = detect_tools_and_reasoning(raw_prompt)
|
| 494 |
+
# After computing auto flags, build the actual prompt string to include <think> if needed
|
| 495 |
+
prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [])}\n\nAssistant:{' <think' if enableReasoning else ''}"
|
| 496 |
+
|
| 497 |
+
# Decide whether web_search should be used based on explicit flags, auto flags, and config defaults
|
| 498 |
+
# Base computed web_search flag
|
| 499 |
+
web_search_enabled = (
|
| 500 |
+
True
|
| 501 |
+
if (request.enable_web_search is not None and request.enable_web_search)
|
| 502 |
+
else (
|
| 503 |
+
request.web_search
|
| 504 |
+
or (request.auto_web_search if request.auto_web_search is not None else CONFIG.AUTO_ENABLE_WEB_SEARCH and detection.get('need_web_search'))
|
| 505 |
+
)
|
| 506 |
)
|
| 507 |
+
if not getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True) and request.enable_web_search is None and not request.web_search:
|
| 508 |
+
web_search_enabled = False
|
| 509 |
+
# If the root config says web search is disabled by default, honor it
|
| 510 |
+
if not getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True) and request.enable_web_search is None and not request.web_search:
|
| 511 |
+
web_search_enabled = False
|
| 512 |
+
# Next: respect per-request sampler override (sampler_allow_web_search) or request.sampler.ALLOW_WEB_SEARCH,
|
| 513 |
+
# then per-model/per-sampler ALLOW_* settings.
|
| 514 |
+
try:
|
| 515 |
+
# 1) per-request `sampler` object ALLOW_* if present, then
|
| 516 |
+
# 2) explicit per-request sampler_allow_* booleans (backwards compatible), else
|
| 517 |
+
# 3) model.DEFAULT_SAMPLER.ALLOW_* if set, else model.ALLOW_*.
|
| 518 |
+
if request.sampler and getattr(request.sampler, 'ALLOW_WEB_SEARCH', None) is not None:
|
| 519 |
+
web_search_enabled = bool(request.sampler.ALLOW_WEB_SEARCH)
|
| 520 |
+
elif hasattr(request, 'sampler_allow_web_search') and request.sampler_allow_web_search is not None:
|
| 521 |
+
web_search_enabled = bool(request.sampler_allow_web_search)
|
| 522 |
+
else:
|
| 523 |
+
ms = MODEL_STORAGE.get(request.model)
|
| 524 |
+
if ms and ms.MODEL_CONFIG:
|
| 525 |
+
if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_WEB_SEARCH', None) is not None:
|
| 526 |
+
web_search_enabled = bool(ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_WEB_SEARCH)
|
| 527 |
+
elif hasattr(ms.MODEL_CONFIG, 'ALLOW_WEB_SEARCH') and not ms.MODEL_CONFIG.ALLOW_WEB_SEARCH:
|
| 528 |
+
web_search_enabled = False
|
| 529 |
+
except Exception:
|
| 530 |
+
pass
|
| 531 |
+
|
| 532 |
+
# Decide whether file tools should be used
|
| 533 |
+
if request.enable_file_tool is not None:
|
| 534 |
+
file_tool_enabled = bool(request.enable_file_tool)
|
| 535 |
+
else:
|
| 536 |
+
auto_file_flag = request.auto_file_tool if request.auto_file_tool is not None else CONFIG.AUTO_ENABLE_TOOLS
|
| 537 |
+
# Default to enabled when files are provided and global setting allows
|
| 538 |
+
file_tool_enabled = bool((request.file_ids and len(request.file_ids) > 0) or (auto_file_flag and request.file_ids))
|
| 539 |
+
# Respect root-level defaults
|
| 540 |
+
if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
|
| 541 |
+
file_tool_enabled = False
|
| 542 |
+
# Per-request sampler overrides
|
| 543 |
+
try:
|
| 544 |
+
if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
|
| 545 |
+
file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
|
| 546 |
+
elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
|
| 547 |
+
file_tool_enabled = bool(request.sampler_allow_file_tool)
|
| 548 |
+
else:
|
| 549 |
+
ms = MODEL_STORAGE.get(request.model)
|
| 550 |
+
if ms and ms.MODEL_CONFIG:
|
| 551 |
+
if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
|
| 552 |
+
file_tool_enabled = bool(ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
|
| 553 |
+
elif hasattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms.MODEL_CONFIG.ALLOW_FILE_TOOL:
|
| 554 |
+
file_tool_enabled = False
|
| 555 |
+
except Exception:
|
| 556 |
+
pass
|
| 557 |
+
|
| 558 |
+
# Decide whether tools should be used
|
| 559 |
+
if request.enable_tools is not None:
|
| 560 |
+
tools_enabled = bool(request.enable_tools)
|
| 561 |
+
else:
|
| 562 |
+
# if explicit tools provided, or enable by default config, or auto detection suggests
|
| 563 |
+
auto_tools_flag = request.auto_tools if request.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
|
| 564 |
+
tools_enabled = bool(request.tools) or CONFIG.ENABLE_TOOLS_BY_DEFAULT or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
|
| 565 |
+
# Respect sampler-level override (request.sampler.ALLOW_TOOLS), then
|
| 566 |
+
# request.sampler_allow_tools, then sampler default and finally model-level allow
|
| 567 |
+
try:
|
| 568 |
+
if request.sampler and getattr(request.sampler, 'ALLOW_TOOLS', None) is not None:
|
| 569 |
+
tools_enabled = bool(request.sampler.ALLOW_TOOLS)
|
| 570 |
+
elif hasattr(request, 'sampler_allow_tools') and request.sampler_allow_tools is not None:
|
| 571 |
+
tools_enabled = bool(request.sampler_allow_tools)
|
| 572 |
+
else:
|
| 573 |
+
ms = MODEL_STORAGE.get(request.model)
|
| 574 |
+
if ms and ms.MODEL_CONFIG:
|
| 575 |
+
if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_TOOLS', None) is not None:
|
| 576 |
+
if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_TOOLS:
|
| 577 |
+
tools_enabled = False
|
| 578 |
+
elif hasattr(ms.MODEL_CONFIG, 'ALLOW_TOOLS') and not ms.MODEL_CONFIG.ALLOW_TOOLS:
|
| 579 |
+
tools_enabled = False
|
| 580 |
+
except Exception:
|
| 581 |
+
pass
|
| 582 |
+
|
| 583 |
+
# Decide whether reasoning should be enabled (in addition to :thinking or explicit)
|
| 584 |
+
reasoning_enabled = bool(
|
| 585 |
+
True
|
| 586 |
+
if (request.enable_reasoning is not None and request.enable_reasoning)
|
| 587 |
+
else (
|
| 588 |
+
bool(enableReasoning) or bool(request.auto_reasoning if request.auto_reasoning is not None else (CONFIG.AUTO_ENABLE_REASONING and bool(detection.get('need_reasoning'))))
|
| 589 |
+
)
|
| 590 |
+
)
|
| 591 |
+
# If the root config sets reasoning to disabled by default and no explicit request to enable, disable it
|
| 592 |
+
if not getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True) and request.enable_reasoning is None:
|
| 593 |
+
reasoning_enabled = False
|
| 594 |
+
# Respect sampler-level override for reasoning: request.sampler.ALLOW_REASONING -> sampler_allow_reasoning -> sampler.default -> model
|
| 595 |
+
try:
|
| 596 |
+
if request.sampler and getattr(request.sampler, 'ALLOW_REASONING', None) is not None:
|
| 597 |
+
reasoning_enabled = bool(request.sampler.ALLOW_REASONING)
|
| 598 |
+
elif hasattr(request, 'sampler_allow_reasoning') and request.sampler_allow_reasoning is not None:
|
| 599 |
+
reasoning_enabled = bool(request.sampler_allow_reasoning)
|
| 600 |
+
else:
|
| 601 |
+
ms = MODEL_STORAGE.get(request.model)
|
| 602 |
+
if ms and ms.MODEL_CONFIG:
|
| 603 |
+
if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_REASONING', None) is not None:
|
| 604 |
+
if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_REASONING:
|
| 605 |
+
reasoning_enabled = False
|
| 606 |
+
elif hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
|
| 607 |
+
reasoning_enabled = False
|
| 608 |
+
except Exception:
|
| 609 |
+
pass
|
| 610 |
+
|
| 611 |
+
# Keep the local boolean for generating content
|
| 612 |
+
enableReasoning = reasoning_enabled
|
| 613 |
+
try:
|
| 614 |
+
ms = MODEL_STORAGE.get(request.model)
|
| 615 |
+
if ms and ms.MODEL_CONFIG and hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
|
| 616 |
+
enableReasoning = False
|
| 617 |
+
except Exception:
|
| 618 |
+
pass
|
| 619 |
+
|
| 620 |
+
# Ensure web_search property mirrors computed web_search_enabled if not explicitly provided
|
| 621 |
+
if request.enable_web_search is None:
|
| 622 |
+
request.web_search = web_search_enabled
|
| 623 |
+
# If tools should be automatically enabled, add detected ones
|
| 624 |
+
if tools_enabled and not request.tools:
|
| 625 |
+
if detection.get('detected_tools'):
|
| 626 |
+
request.tools = detection.get('detected_tools')
|
| 627 |
+
# If universal is needed and not explicitly requested, add universal tool
|
| 628 |
+
if (request.enable_universal is True) or (
|
| 629 |
+
request.enable_universal is None and (request.auto_universal if request.auto_universal is not None else CONFIG.AUTO_ENABLE_TOOLS and detection.get('need_universal'))
|
| 630 |
+
):
|
| 631 |
+
if not request.tools:
|
| 632 |
+
request.tools = [{"name": "universal", "args": {"query": raw_prompt}}]
|
| 633 |
+
|
| 634 |
+
executed_tool_calls = []
|
| 635 |
+
# If file tools are enabled and files are attached, inject them into the prompt (for streaming)
|
| 636 |
+
if file_tool_enabled and request.file_ids:
|
| 637 |
+
for fid in request.file_ids:
|
| 638 |
+
try:
|
| 639 |
+
if fid not in UPLOADED_FILES:
|
| 640 |
+
continue
|
| 641 |
+
meta = UPLOADED_FILES.get(fid)
|
| 642 |
+
if not meta:
|
| 643 |
+
continue
|
| 644 |
+
from utils import file_read_from_path
|
| 645 |
+
fpath = meta.get('path')
|
| 646 |
+
if not fpath or not os.path.exists(fpath):
|
| 647 |
+
continue
|
| 648 |
+
file_content = file_read_from_path(fpath, 200000)
|
| 649 |
+
if file_content:
|
| 650 |
+
exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
|
| 651 |
+
executed_tool_calls.append(exec_entry)
|
| 652 |
+
prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt)
|
| 653 |
+
except Exception as e:
|
| 654 |
+
logger.info(f"File injection error: {e}")
|
| 655 |
+
# If file tools are enabled and files are attached, inject them into the prompt
|
| 656 |
+
if file_tool_enabled and request.file_ids:
|
| 657 |
+
for fid in request.file_ids:
|
| 658 |
+
try:
|
| 659 |
+
if fid not in UPLOADED_FILES:
|
| 660 |
+
continue
|
| 661 |
+
meta = UPLOADED_FILES.get(fid)
|
| 662 |
+
if not meta:
|
| 663 |
+
continue
|
| 664 |
+
from utils import file_read_from_path
|
| 665 |
+
fpath = meta.get('path')
|
| 666 |
+
if not fpath or not os.path.exists(fpath):
|
| 667 |
+
continue
|
| 668 |
+
file_content = file_read_from_path(fpath, 200000)
|
| 669 |
+
if file_content:
|
| 670 |
+
exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
|
| 671 |
+
executed_tool_calls.append(exec_entry)
|
| 672 |
+
prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt)
|
| 673 |
+
except Exception as e:
|
| 674 |
+
logger.info(f"File injection error: {e}")
|
| 675 |
if request.tools:
|
| 676 |
try:
|
| 677 |
for tool in request.tools:
|
|
|
|
| 684 |
search_top_k = int(args.get('top_k') or request.search_top_k or 3)
|
| 685 |
search_str = web_search(search_q, search_top_k)
|
| 686 |
if search_str:
|
| 687 |
+
search_res_struct = {"action": "web_search", "result": str(search_str), "metadata": {"query": search_q, "top_k": search_top_k, "confidence": 0.9}}
|
| 688 |
+
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": search_top_k}, "result": search_res_struct})
|
| 689 |
+
prompt = (f"ToolResults:\n{search_res_struct.get('result')}\n\nUse these results to answer the prompt.\n\n" + prompt)
|
| 690 |
elif name == 'calc' or name == 'calculator':
|
| 691 |
from utils import calc
|
| 692 |
|
| 693 |
expr = args.get('expression')
|
| 694 |
if expr:
|
| 695 |
calc_res = calc(expr)
|
| 696 |
+
# Wrap result into a structured dict
|
| 697 |
+
calc_res_struct = {"action": "calc", "result": str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}}
|
| 698 |
+
executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct})
|
| 699 |
+
prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res_struct.get('result')}\n\nUse this result to answer the prompt.\n\n" + prompt)
|
| 700 |
+
elif name == 'universal':
|
| 701 |
+
try:
|
| 702 |
+
res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
| 703 |
+
# If universal_tool returns a dict, extract text result for prompt injection
|
| 704 |
+
if isinstance(res, dict):
|
| 705 |
+
result_text = res.get('result') if res.get('result') is not None else ''
|
| 706 |
+
else:
|
| 707 |
+
result_text = str(res)
|
| 708 |
+
executed_tool_calls.append({"name": "universal", "args": args, "result": res})
|
| 709 |
+
prompt = (f"ToolResults:\n{result_text}\n\nUse this result to answer the prompt.\n\n" + prompt)
|
| 710 |
+
except Exception as e:
|
| 711 |
+
logger.info(f"Universal tool execution error: {e}")
|
| 712 |
else:
|
| 713 |
# Unsupported tool - ignore or log
|
| 714 |
logger.info(f"Unsupported tool requested: {name}")
|
| 715 |
+
if name == 'file_read':
|
| 716 |
+
# read an uploaded file by id/path
|
| 717 |
+
try:
|
| 718 |
+
fid = args.get('file_id') or args.get('id') or (request.file_ids[0] if request.file_ids else None)
|
| 719 |
+
if not fid:
|
| 720 |
+
continue
|
| 721 |
+
if fid not in UPLOADED_FILES:
|
| 722 |
+
continue
|
| 723 |
+
meta = UPLOADED_FILES.get(fid)
|
| 724 |
+
if not meta:
|
| 725 |
+
continue
|
| 726 |
+
from utils import file_read_from_path
|
| 727 |
+
fpath = meta.get('path')
|
| 728 |
+
if not fpath or not os.path.exists(fpath):
|
| 729 |
+
continue
|
| 730 |
+
file_content = file_read_from_path(fpath, int(args.get('max_bytes') or 100000))
|
| 731 |
+
exec_entry = {"name": "file_read", "args": {"file_id": fid, "max_bytes": int(args.get('max_bytes') or 100000)}, "result": {"action": "file_read", "result": file_content, "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
|
| 732 |
+
executed_tool_calls.append(exec_entry)
|
| 733 |
+
_res = exec_entry.get('result') if isinstance(exec_entry, dict) else None
|
| 734 |
+
_res_text = ''
|
| 735 |
+
if isinstance(_res, dict):
|
| 736 |
+
_res_text = _res.get('result') or ''
|
| 737 |
+
elif _res is not None:
|
| 738 |
+
_res_text = str(_res)
|
| 739 |
+
prompt = (f"ToolResults:\n{_res_text}\n\nUse these file contents to answer the prompt.\n\n" + prompt)
|
| 740 |
+
except Exception as e:
|
| 741 |
+
logger.info(f"file_read tool error: {e}")
|
| 742 |
except Exception as e:
|
| 743 |
logger.info(f"Tool processing error: {e}")
|
| 744 |
+
elif request.web_search or web_search_enabled:
|
| 745 |
try:
|
| 746 |
from utils import web_search
|
| 747 |
|
| 748 |
search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
|
| 749 |
search_res = web_search(search_q, int(request.search_top_k or 3))
|
| 750 |
if search_res:
|
| 751 |
+
search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}}
|
| 752 |
+
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct})
|
| 753 |
+
prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt
|
| 754 |
except Exception:
|
| 755 |
pass
|
| 756 |
logger.info(f"[REQ] {completionId} - prompt - {prompt}")
|
|
|
|
| 760 |
state_key = (request.model, request.state_name)
|
| 761 |
if state_key in STATE_STORE:
|
| 762 |
stored = STATE_STORE[state_key]
|
| 763 |
+
model_state = stored.get('state', None)
|
| 764 |
model_tokens = stored.get('model_tokens', [0])
|
| 765 |
+
if model_state is None:
|
| 766 |
+
# Recompute out and model_state from tokens since we did not persist the torch state
|
| 767 |
+
out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
|
| 768 |
+
else:
|
| 769 |
+
# If we have a model_state, we still need out logits. Compute from last window of tokens
|
| 770 |
+
out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :])
|
| 771 |
else:
|
| 772 |
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
|
| 773 |
else:
|
|
|
|
| 779 |
fullResponse = " <think" if enableReasoning else ""
|
| 780 |
completionTokenCount = 0
|
| 781 |
finishReason = None
|
| 782 |
+
# Limit model-initiated tool calls per request to avoid loops
|
| 783 |
+
model_initiated_tool_calls = 0
|
| 784 |
+
MODEL_MAX_TOOL_CALLS = 3
|
| 785 |
+
should_restart = True
|
| 786 |
+
while should_restart:
|
| 787 |
+
should_restart = False
|
| 788 |
+
gen = generate(
|
| 789 |
+
request,
|
| 790 |
+
out,
|
| 791 |
+
model_tokens,
|
| 792 |
+
model_state,
|
| 793 |
+
max_tokens=(
|
| 794 |
+
64000
|
| 795 |
+
if "max_tokens" not in request.model_fields_set and enableReasoning
|
| 796 |
+
else (request.max_tokens or 2048)
|
| 797 |
+
),
|
| 798 |
+
)
|
| 799 |
+
for chunk in gen:
|
| 800 |
+
# chunk['content'] is now expected to be a single token's decoded text
|
| 801 |
+
fullResponse += chunk["content"]
|
| 802 |
+
# Detect model-issued tool call markers within the output
|
| 803 |
+
if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
|
| 804 |
+
m = TOOL_CALL_RE.search(fullResponse)
|
| 805 |
+
if m:
|
| 806 |
+
try:
|
| 807 |
+
payload_raw = m.group(1)
|
| 808 |
+
import json
|
| 809 |
+
|
| 810 |
+
payload = json.loads(payload_raw)
|
| 811 |
+
tool_name = payload.get('name')
|
| 812 |
+
tool_args = payload.get('args', {})
|
| 813 |
+
tool_res = None
|
| 814 |
+
if tool_name == 'web_search':
|
| 815 |
+
from utils import web_search
|
| 816 |
+
|
| 817 |
+
q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
|
| 818 |
+
k = int(tool_args.get('top_k') or request.search_top_k or 3)
|
| 819 |
+
tool_res = web_search(q, k)
|
| 820 |
+
elif tool_name in ('calc', 'calculator'):
|
| 821 |
+
from utils import calc
|
| 822 |
+
|
| 823 |
+
expr = tool_args.get('expression')
|
| 824 |
+
if expr:
|
| 825 |
+
tool_res = calc(expr)
|
| 826 |
+
else:
|
| 827 |
+
try:
|
| 828 |
+
tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
| 829 |
+
except Exception:
|
| 830 |
+
tool_res = None
|
| 831 |
+
|
| 832 |
+
if tool_res:
|
| 833 |
+
if not isinstance(tool_res, dict):
|
| 834 |
+
if tool_name in ('calc', 'calculator'):
|
| 835 |
+
tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
|
| 836 |
+
elif tool_name == 'web_search':
|
| 837 |
+
tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
|
| 838 |
+
else:
|
| 839 |
+
tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
|
| 840 |
+
else:
|
| 841 |
+
tool_res_struct = tool_res
|
| 842 |
+
exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
|
| 843 |
+
executed_tool_calls.append(exec_entry)
|
| 844 |
+
delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
|
| 845 |
+
prompt = delta_text + prompt
|
| 846 |
+
fullResponse = TOOL_CALL_RE.sub('', fullResponse)
|
| 847 |
+
buffer = [fullResponse]
|
| 848 |
+
out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
|
| 849 |
+
model_initiated_tool_calls += 1
|
| 850 |
+
except Exception as e:
|
| 851 |
+
logger.info(f"Model-initiated tool handling error: {e}")
|
| 852 |
+
# Check stop sequences (multi-token) after each token
|
| 853 |
+
for stop_words in request.stop or []:
|
| 854 |
+
if stop_words in fullResponse:
|
| 855 |
+
finishReason = f"stop:words:{stop_words}"
|
| 856 |
+
break
|
| 857 |
+
completionTokenCount += 1
|
| 858 |
|
| 859 |
+
if chunk["finish_reason"]:
|
| 860 |
+
finishReason = chunk["finish_reason"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 861 |
|
| 862 |
+
generateTime = time.time()
|
| 863 |
|
| 864 |
responseLog = {
|
| 865 |
"content": fullResponse,
|
|
|
|
| 867 |
"prefill_len": promptTokenCount,
|
| 868 |
"prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
|
| 869 |
"gen_len": completionTokenCount,
|
| 870 |
+
"gen_tps": round(completionTokenCount / (generateTime - prefillTime) if generateTime!=prefillTime else 0, 2),
|
| 871 |
}
|
| 872 |
logger.info(f"[RES] {completionId} - {responseLog}")
|
| 873 |
|
|
|
|
| 890 |
role="Assistant",
|
| 891 |
content=content,
|
| 892 |
reasoning_content=reasoning_content if reasoning_content else None,
|
| 893 |
+
tool_calls=executed_tool_calls if executed_tool_calls else None,
|
| 894 |
),
|
| 895 |
logprobs=None,
|
| 896 |
finish_reason=finishReason,
|
|
|
|
| 905 |
'state': model_state,
|
| 906 |
'model_tokens': model_tokens,
|
| 907 |
}
|
| 908 |
+
if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
|
| 909 |
+
try:
|
| 910 |
+
_save_state_store_to_disk(force=True)
|
| 911 |
+
except Exception:
|
| 912 |
+
pass
|
| 913 |
except Exception:
|
| 914 |
pass
|
| 915 |
|
|
|
|
| 924 |
):
|
| 925 |
createTimestamp = int(time.time())
|
| 926 |
|
| 927 |
+
raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [], False)
|
| 928 |
+
# Intent detection and defaults: check whether to auto-enable tools, web_search, reasoning
|
| 929 |
+
detection = detect_tools_and_reasoning(raw_prompt)
|
| 930 |
+
|
| 931 |
+
web_search_enabled = (
|
| 932 |
+
True
|
| 933 |
+
if (request.enable_web_search is not None and request.enable_web_search)
|
| 934 |
+
else (
|
| 935 |
+
request.web_search
|
| 936 |
+
or (request.auto_web_search if request.auto_web_search is not None else CONFIG.AUTO_ENABLE_WEB_SEARCH and detection.get('need_web_search'))
|
| 937 |
+
)
|
| 938 |
+
)
|
| 939 |
+
|
| 940 |
+
if request.enable_tools is not None:
|
| 941 |
+
tools_enabled = bool(request.enable_tools)
|
| 942 |
+
else:
|
| 943 |
+
auto_tools_flag = request.auto_tools if request.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
|
| 944 |
+
tools_enabled = bool(request.tools) or CONFIG.ENABLE_TOOLS_BY_DEFAULT or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
|
| 945 |
+
|
| 946 |
+
reasoning_enabled = bool(
|
| 947 |
+
True
|
| 948 |
+
if (request.enable_reasoning is not None and request.enable_reasoning)
|
| 949 |
+
else (
|
| 950 |
+
bool(enableReasoning) or bool(request.auto_reasoning if request.auto_reasoning is not None else (CONFIG.AUTO_ENABLE_REASONING and bool(detection.get('need_reasoning'))))
|
| 951 |
+
)
|
| 952 |
)
|
| 953 |
+
enableReasoning = reasoning_enabled
|
| 954 |
+
try:
|
| 955 |
+
ms_cfg = MODEL_STORAGE.get(request.model)
|
| 956 |
+
if ms_cfg and ms_cfg.MODEL_CONFIG and hasattr(ms_cfg.MODEL_CONFIG, 'ALLOW_REASONING') and not ms_cfg.MODEL_CONFIG.ALLOW_REASONING:
|
| 957 |
+
enableReasoning = False
|
| 958 |
+
except Exception:
|
| 959 |
+
pass
|
| 960 |
+
# Decide whether file tools should be used for streaming variant
|
| 961 |
+
if request.enable_file_tool is not None:
|
| 962 |
+
file_tool_enabled = bool(request.enable_file_tool)
|
| 963 |
+
else:
|
| 964 |
+
auto_file_flag = request.auto_file_tool if request.auto_file_tool is not None else CONFIG.AUTO_ENABLE_TOOLS
|
| 965 |
+
file_tool_enabled = bool((request.file_ids and len(request.file_ids) > 0) or (auto_file_flag and request.file_ids))
|
| 966 |
+
if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
|
| 967 |
+
file_tool_enabled = False
|
| 968 |
+
try:
|
| 969 |
+
if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
|
| 970 |
+
file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
|
| 971 |
+
elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
|
| 972 |
+
file_tool_enabled = bool(request.sampler_allow_file_tool)
|
| 973 |
+
else:
|
| 974 |
+
ms2 = MODEL_STORAGE.get(request.model)
|
| 975 |
+
if ms2 and ms2.MODEL_CONFIG:
|
| 976 |
+
if hasattr(ms2.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms2.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
|
| 977 |
+
file_tool_enabled = bool(ms2.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
|
| 978 |
+
elif hasattr(ms2.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms2.MODEL_CONFIG.ALLOW_FILE_TOOL:
|
| 979 |
+
file_tool_enabled = False
|
| 980 |
+
except Exception:
|
| 981 |
+
pass
|
| 982 |
+
# Build final prompt after deciding enableReasoning
|
| 983 |
+
prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
|
| 984 |
+
|
| 985 |
+
if tools_enabled and not request.tools:
|
| 986 |
+
if detection.get('detected_tools'):
|
| 987 |
+
request.tools = detection.get('detected_tools')
|
| 988 |
+
executed_tool_calls = []
|
| 989 |
if request.tools:
|
| 990 |
try:
|
| 991 |
for tool in request.tools:
|
|
|
|
| 998 |
search_top_k = int(args.get('top_k') or request.search_top_k or 3)
|
| 999 |
search_str = web_search(search_q, search_top_k)
|
| 1000 |
if search_str:
|
| 1001 |
+
search_res_struct = {"action": "web_search", "result": str(search_str), "metadata": {"query": search_q, "top_k": search_top_k, "confidence": 0.9}}
|
| 1002 |
+
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": search_top_k}, "result": search_res_struct})
|
| 1003 |
+
prompt = (f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt)
|
| 1004 |
elif name == 'calc' or name == 'calculator':
|
| 1005 |
from utils import calc
|
| 1006 |
|
| 1007 |
expr = args.get('expression')
|
| 1008 |
if expr:
|
| 1009 |
calc_res = calc(expr)
|
| 1010 |
+
calc_res_struct = {"action": "calc", "result": str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}}
|
| 1011 |
+
executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct})
|
| 1012 |
+
prompt = (f"CalcResult:{expr} = {calc_res_struct.get('result')}\n\n" + prompt)
|
| 1013 |
+
elif name == 'universal':
|
| 1014 |
+
try:
|
| 1015 |
+
res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
| 1016 |
+
if isinstance(res, dict):
|
| 1017 |
+
result_text = res.get('result') if res.get('result') is not None else ''
|
| 1018 |
+
else:
|
| 1019 |
+
result_text = str(res)
|
| 1020 |
+
executed_tool_calls.append({"name": "universal", "args": args, "result": res})
|
| 1021 |
+
prompt = (f"ToolResults:\n{result_text}\n\n" + prompt)
|
| 1022 |
+
except Exception as e:
|
| 1023 |
+
logger.info(f"Universal tool execution error: {e}")
|
| 1024 |
else:
|
| 1025 |
logger.info(f"Unsupported tool requested: {name}")
|
| 1026 |
except Exception as e:
|
| 1027 |
logger.info(f"Tool processing error: {e}")
|
| 1028 |
+
elif request.web_search or web_search_enabled:
|
| 1029 |
try:
|
| 1030 |
from utils import web_search
|
| 1031 |
|
| 1032 |
search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
|
| 1033 |
search_res = web_search(search_q, int(request.search_top_k or 3))
|
| 1034 |
if search_res:
|
| 1035 |
+
search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}}
|
| 1036 |
+
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct})
|
| 1037 |
+
prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt
|
| 1038 |
except Exception:
|
| 1039 |
pass
|
| 1040 |
|
|
|
|
| 1045 |
state_key = (request.model, request.state_name)
|
| 1046 |
if state_key in STATE_STORE:
|
| 1047 |
stored = STATE_STORE[state_key]
|
| 1048 |
+
model_state = stored.get('state', None)
|
| 1049 |
model_tokens = stored.get('model_tokens', [0])
|
| 1050 |
+
if model_state is None:
|
| 1051 |
+
# Recompute out and model_state from tokens since we did not persist the torch state
|
| 1052 |
+
out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
|
| 1053 |
+
else:
|
| 1054 |
+
out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :])
|
| 1055 |
else:
|
| 1056 |
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
|
| 1057 |
else:
|
|
|
|
| 1062 |
|
| 1063 |
completionTokenCount = 0
|
| 1064 |
finishReason = None
|
| 1065 |
+
# Limit how many tool calls the model can initiate during a single stream
|
| 1066 |
+
model_initiated_tool_calls = 0
|
| 1067 |
+
MODEL_MAX_TOOL_CALLS = 3
|
| 1068 |
|
| 1069 |
response = ChatCompletionChunk(
|
| 1070 |
id=completionId,
|
|
|
|
| 1099 |
# Attach state_name in the initial chunk so client can save it to continue later
|
| 1100 |
r_dict = response.model_dump()
|
| 1101 |
r_dict['state_name'] = request.state_name
|
| 1102 |
+
# Attach executed tool_calls both at root for easy client metadata, and within the assistant message delta
|
| 1103 |
+
if executed_tool_calls:
|
| 1104 |
+
r_dict['tool_calls'] = executed_tool_calls
|
| 1105 |
+
try:
|
| 1106 |
+
if r_dict.get('choices') and len(r_dict['choices']) > 0 and r_dict['choices'][0].get('delta') is not None:
|
| 1107 |
+
r_dict['choices'][0]['delta']['tool_calls'] = executed_tool_calls
|
| 1108 |
+
except Exception:
|
| 1109 |
+
pass
|
| 1110 |
yield f"data: {r_dict}\n\n"
|
| 1111 |
|
| 1112 |
buffer = []
|
|
|
|
| 1273 |
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
| 1274 |
response.choices[0].delta = delta
|
| 1275 |
if delta.content != None or delta.reasoning_content != None:
|
| 1276 |
+
# Save model state frequently (after each token) to allow resuming
|
| 1277 |
try:
|
| 1278 |
if request.state_name:
|
| 1279 |
STATE_STORE[(request.model, request.state_name)] = {
|
| 1280 |
'state': model_state,
|
| 1281 |
'model_tokens': model_tokens,
|
| 1282 |
}
|
| 1283 |
+
if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
|
| 1284 |
+
try:
|
| 1285 |
+
_save_state_store_to_disk(force=True)
|
| 1286 |
+
except Exception:
|
| 1287 |
+
pass
|
| 1288 |
except Exception:
|
| 1289 |
pass
|
| 1290 |
+
# model-initiated tool call detection
|
| 1291 |
+
if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
|
| 1292 |
+
m = TOOL_CALL_RE.search(fullText)
|
| 1293 |
+
if m:
|
| 1294 |
+
try:
|
| 1295 |
+
payload_raw = m.group(1)
|
| 1296 |
+
import json
|
| 1297 |
+
|
| 1298 |
+
payload = json.loads(payload_raw)
|
| 1299 |
+
tool_name = payload.get('name')
|
| 1300 |
+
tool_args = payload.get('args', {})
|
| 1301 |
+
tool_res = None
|
| 1302 |
+
if tool_name == 'web_search':
|
| 1303 |
+
from utils import web_search
|
| 1304 |
+
|
| 1305 |
+
q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
|
| 1306 |
+
k = int(tool_args.get('top_k') or request.search_top_k or 3)
|
| 1307 |
+
tool_res = web_search(q, k)
|
| 1308 |
+
elif tool_name in ('calc', 'calculator'):
|
| 1309 |
+
from utils import calc
|
| 1310 |
+
|
| 1311 |
+
expr = tool_args.get('expression')
|
| 1312 |
+
if expr:
|
| 1313 |
+
tool_res = calc(expr)
|
| 1314 |
+
else:
|
| 1315 |
+
try:
|
| 1316 |
+
tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
| 1317 |
+
except Exception:
|
| 1318 |
+
tool_res = None
|
| 1319 |
+
|
| 1320 |
+
if tool_res:
|
| 1321 |
+
# Normalize tool_res into a structured dict if needed
|
| 1322 |
+
if not isinstance(tool_res, dict):
|
| 1323 |
+
if tool_name in ('calc', 'calculator'):
|
| 1324 |
+
tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
|
| 1325 |
+
elif tool_name == 'web_search':
|
| 1326 |
+
tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
|
| 1327 |
+
else:
|
| 1328 |
+
tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
|
| 1329 |
+
else:
|
| 1330 |
+
tool_res_struct = tool_res
|
| 1331 |
+
exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
|
| 1332 |
+
executed_tool_calls.append(exec_entry)
|
| 1333 |
+
delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
|
| 1334 |
+
prompt = delta_text + prompt
|
| 1335 |
+
fullText = TOOL_CALL_RE.sub('', fullText)
|
| 1336 |
+
buffer = [fullText]
|
| 1337 |
+
out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
|
| 1338 |
+
model_initiated_tool_calls += 1
|
| 1339 |
+
should_restart = True
|
| 1340 |
+
break
|
| 1341 |
+
except Exception as e:
|
| 1342 |
+
logger.info(f"Model-initiated tool handling error: {e}")
|
| 1343 |
yield f"data: {response.model_dump_json()}\n\n"
|
| 1344 |
# check stop sequences and stop streaming if we see them
|
| 1345 |
for stop_words in request.stop or []:
|
|
|
|
| 1351 |
|
| 1352 |
del streamConfig
|
| 1353 |
else:
|
| 1354 |
+
should_restart = True
|
| 1355 |
+
while should_restart:
|
| 1356 |
+
should_restart = False
|
| 1357 |
+
gen = generate(request, out, model_tokens, model_state)
|
| 1358 |
+
for chunk in gen:
|
| 1359 |
+
completionTokenCount += 1
|
| 1360 |
+
buffer.append(chunk["content"])
|
| 1361 |
|
| 1362 |
+
if chunk["finish_reason"]:
|
| 1363 |
+
finishReason = chunk["finish_reason"]
|
| 1364 |
|
| 1365 |
+
# Save model state frequently (after each token) to allow resuming
|
| 1366 |
+
try:
|
| 1367 |
+
if request.state_name:
|
| 1368 |
+
STATE_STORE[(request.model, request.state_name)] = {
|
| 1369 |
+
'state': model_state,
|
| 1370 |
+
'model_tokens': model_tokens,
|
| 1371 |
+
}
|
| 1372 |
+
if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
|
| 1373 |
+
try:
|
| 1374 |
+
_save_state_store_to_disk(force=True)
|
| 1375 |
+
except Exception:
|
| 1376 |
+
pass
|
| 1377 |
+
except Exception:
|
| 1378 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1379 |
|
| 1380 |
+
# Detect model-initiated tool calls
|
| 1381 |
+
if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
|
| 1382 |
+
fullText = ''.join(buffer)
|
| 1383 |
+
m = TOOL_CALL_RE.search(fullText)
|
| 1384 |
+
if m:
|
| 1385 |
+
try:
|
| 1386 |
+
payload_raw = m.group(1)
|
| 1387 |
+
import json
|
| 1388 |
+
|
| 1389 |
+
payload = json.loads(payload_raw)
|
| 1390 |
+
tool_name = payload.get('name')
|
| 1391 |
+
tool_args = payload.get('args', {})
|
| 1392 |
+
tool_res = None
|
| 1393 |
+
if tool_name == 'web_search':
|
| 1394 |
+
from utils import web_search
|
| 1395 |
+
|
| 1396 |
+
q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
|
| 1397 |
+
k = int(tool_args.get('top_k') or request.search_top_k or 3)
|
| 1398 |
+
tool_res = web_search(q, k)
|
| 1399 |
+
elif tool_name in ('calc', 'calculator'):
|
| 1400 |
+
from utils import calc
|
| 1401 |
+
|
| 1402 |
+
expr = tool_args.get('expression')
|
| 1403 |
+
if expr:
|
| 1404 |
+
tool_res = calc(expr)
|
| 1405 |
+
else:
|
| 1406 |
+
try:
|
| 1407 |
+
tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
| 1408 |
+
except Exception:
|
| 1409 |
+
tool_res = None
|
| 1410 |
+
|
| 1411 |
+
if tool_res:
|
| 1412 |
+
if not isinstance(tool_res, dict):
|
| 1413 |
+
if tool_name in ('calc', 'calculator'):
|
| 1414 |
+
tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
|
| 1415 |
+
elif tool_name == 'web_search':
|
| 1416 |
+
tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
|
| 1417 |
+
else:
|
| 1418 |
+
tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
|
| 1419 |
+
else:
|
| 1420 |
+
tool_res_struct = tool_res
|
| 1421 |
+
exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
|
| 1422 |
+
executed_tool_calls.append(exec_entry)
|
| 1423 |
+
delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
|
| 1424 |
+
prompt = delta_text + prompt
|
| 1425 |
+
fullText = TOOL_CALL_RE.sub('', fullText)
|
| 1426 |
+
buffer = [fullText]
|
| 1427 |
+
out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
|
| 1428 |
+
# Notify client that a tool was called mid-stream (metadata-only chunk)
|
| 1429 |
+
try:
|
| 1430 |
+
meta_resp = ChatCompletionChunk(
|
| 1431 |
+
id=completionId,
|
| 1432 |
+
created=createTimestamp,
|
| 1433 |
+
model=request.model,
|
| 1434 |
+
usage=(
|
| 1435 |
+
Usage(
|
| 1436 |
+
prompt_tokens=promptTokenCount,
|
| 1437 |
+
completion_tokens=completionTokenCount,
|
| 1438 |
+
total_tokens=promptTokenCount + completionTokenCount,
|
| 1439 |
+
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
|
| 1440 |
+
)
|
| 1441 |
+
if request.include_usage
|
| 1442 |
+
else None
|
| 1443 |
+
),
|
| 1444 |
+
choices=[
|
| 1445 |
+
ChatCompletionChoice(
|
| 1446 |
+
index=0,
|
| 1447 |
+
delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls),
|
| 1448 |
+
logprobs=None,
|
| 1449 |
+
finish_reason=None,
|
| 1450 |
+
)
|
| 1451 |
+
],
|
| 1452 |
+
)
|
| 1453 |
+
yield f"data: {meta_resp.model_dump_json()}\n\n"
|
| 1454 |
+
except Exception:
|
| 1455 |
+
pass
|
| 1456 |
+
model_initiated_tool_calls += 1
|
| 1457 |
+
should_restart = True
|
| 1458 |
+
break
|
| 1459 |
+
except Exception as e:
|
| 1460 |
+
logger.info(f"Model-initiated tool handling error: {e}")
|
| 1461 |
+
|
| 1462 |
+
response = ChatCompletionChunk(
|
| 1463 |
+
id=completionId,
|
| 1464 |
+
created=createTimestamp,
|
| 1465 |
+
model=request.model,
|
| 1466 |
+
usage=(
|
| 1467 |
+
Usage(
|
| 1468 |
+
prompt_tokens=promptTokenCount,
|
| 1469 |
+
completion_tokens=completionTokenCount,
|
| 1470 |
+
total_tokens=promptTokenCount + completionTokenCount,
|
| 1471 |
+
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
|
| 1472 |
+
)
|
| 1473 |
+
if request.include_usage
|
| 1474 |
+
else None
|
| 1475 |
+
),
|
| 1476 |
+
choices=[
|
| 1477 |
+
ChatCompletionChoice(
|
| 1478 |
+
index=0,
|
| 1479 |
+
delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None),
|
| 1480 |
+
logprobs=None,
|
| 1481 |
+
finish_reason=finishReason,
|
| 1482 |
+
)
|
| 1483 |
+
],
|
| 1484 |
+
)
|
| 1485 |
+
yield f"data: {response.model_dump_json()}\n\n"
|
| 1486 |
+
await asyncio.sleep(0)
|
| 1487 |
|
| 1488 |
genenrateTime = time.time()
|
| 1489 |
|
|
|
|
| 1518 |
completionId = str(next(CompletionIdGenerator))
|
| 1519 |
logger.info(f"[REQ] {completionId} - {request.model_dump()}")
|
| 1520 |
|
| 1521 |
+
# Support model suffixes like ':thinking' for reasoning or ':web' to request
|
| 1522 |
+
# web search by default for this request. E.g., 'rwkv-latest:web' will enable web_search.
|
| 1523 |
modelName = request.model.split(":")[0]
|
| 1524 |
+
if ":web" in request.model:
|
| 1525 |
+
request.enable_web_search = True
|
| 1526 |
+
if ":file" in request.model:
|
| 1527 |
+
request.enable_file_tool = True
|
| 1528 |
enableReasoning = ":thinking" in request.model
|
| 1529 |
|
| 1530 |
if "rwkv-latest" in request.model:
|
|
|
|
| 1565 |
model_tokens_for_resume = stored.get('model_tokens', [0])
|
| 1566 |
request_dict = request.model_dump()
|
| 1567 |
|
| 1568 |
+
# Apply defaults from model's DEFAULT_SAMPLER, optionally overridden by the
|
| 1569 |
+
# per-request `sampler` object (or legacy sampler_allow_* booleans).
|
| 1570 |
+
sampler_overrides = request_dict.get('sampler') or {}
|
| 1571 |
for k, v in defaultSamplerConfig.model_dump().items():
|
| 1572 |
+
# If the request provided a sampler override for this field, use it
|
| 1573 |
+
if sampler_overrides and k in sampler_overrides and sampler_overrides.get(k) is not None:
|
| 1574 |
+
request_dict[k] = sampler_overrides.get(k)
|
| 1575 |
+
continue
|
| 1576 |
if k in request_dict and request_dict[k] is None:
|
| 1577 |
request_dict[k] = v
|
| 1578 |
realRequest = ChatCompletionRequest(**request_dict)
|
| 1579 |
+
# Ensure stream defaults to configuration value when not explicitly provided
|
| 1580 |
+
if realRequest.stream is None:
|
| 1581 |
+
realRequest.stream = CONFIG.DEFAULT_STREAM
|
| 1582 |
|
| 1583 |
logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}")
|
| 1584 |
|
| 1585 |
+
if realRequest.stream:
|
| 1586 |
r = StreamingResponse(
|
| 1587 |
chatResponseStream(realRequest, model_state, completionId, enableReasoning),
|
| 1588 |
media_type="text/event-stream",
|
|
|
|
| 1604 |
return r
|
| 1605 |
|
| 1606 |
|
| 1607 |
+
# We keep the service API-only; remove static mount for demo frontend to
|
| 1608 |
+
# avoid serving HTML files by default and keep the repository Python-only.
|
| 1609 |
+
logger.info("Static frontend mount removed for Python-only deploy; use API endpoints for integration")
|
| 1610 |
+
|
| 1611 |
+
|
| 1612 |
+
@app.get('/api/v1/models')
|
| 1613 |
+
def list_models():
|
| 1614 |
+
"""Return model configuration summary for clients/UI.
|
| 1615 |
+
|
| 1616 |
+
This endpoint returns configured models, their default sampler values, and
|
| 1617 |
+
ALLOW_* flags so UI clients can build a controls surface based on server
|
| 1618 |
+
capabilities (web search, tools, reasoning).
|
| 1619 |
+
"""
|
| 1620 |
+
out = []
|
| 1621 |
+
root_defaults = {
|
| 1622 |
+
'ALLOW_FILE_TOOL_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True),
|
| 1623 |
+
'ENABLE_WEB_SEARCH_BY_DEFAULT': getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True),
|
| 1624 |
+
'ENABLE_REASONING_BY_DEFAULT': getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True),
|
| 1625 |
+
'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT', True),
|
| 1626 |
+
'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT', True),
|
| 1627 |
+
'SHOW_REASONING_TOGGLE_BY_DEFAULT': getattr(CONFIG, 'SHOW_REASONING_TOGGLE_BY_DEFAULT', True),
|
| 1628 |
+
'UPLOAD_URL': '/api/v1/files',
|
| 1629 |
+
}
|
| 1630 |
+
for m in CONFIG.MODELS:
|
| 1631 |
+
out.append(
|
| 1632 |
+
{
|
| 1633 |
+
'SERVICE_NAME': m.SERVICE_NAME,
|
| 1634 |
+
'DEFAULT_CHAT': m.DEFAULT_CHAT,
|
| 1635 |
+
'DEFAULT_REASONING': m.DEFAULT_REASONING,
|
| 1636 |
+
'ALLOW_WEB_SEARCH': getattr(m, 'ALLOW_WEB_SEARCH', True),
|
| 1637 |
+
'ALLOW_TOOLS': getattr(m, 'ALLOW_TOOLS', True),
|
| 1638 |
+
'ALLOW_REASONING': getattr(m, 'ALLOW_REASONING', True),
|
| 1639 |
+
'ALLOW_FILE_TOOL': getattr(m, 'ALLOW_FILE_TOOL', True),
|
| 1640 |
+
'SHOW_WEB_SEARCH_BUTTON': getattr(m, 'SHOW_WEB_SEARCH_BUTTON', True),
|
| 1641 |
+
'SHOW_FILE_UPLOAD_BUTTON': getattr(m, 'SHOW_FILE_UPLOAD_BUTTON', True),
|
| 1642 |
+
'SHOW_REASONING_TOGGLE': getattr(m, 'SHOW_REASONING_TOGGLE', True),
|
| 1643 |
+
'DEFAULT_SAMPLER': m.DEFAULT_SAMPLER.model_dump() if hasattr(m, 'DEFAULT_SAMPLER') else None,
|
| 1644 |
+
# Convenience info for clients: upload endpoint and root defaults
|
| 1645 |
+
'UPLOAD_URL': '/api/v1/files',
|
| 1646 |
+
'UPLOAD_ALLOWED_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True),
|
| 1647 |
+
}
|
| 1648 |
+
)
|
| 1649 |
+
return {'root_defaults': root_defaults, 'models': out}
|
| 1650 |
+
|
| 1651 |
+
|
| 1652 |
+
@app.post('/api/v1/files', response_model=FileUploadResponse)
|
| 1653 |
+
async def upload_file(file: UploadFile = File(...), model: Optional[str] = None):
|
| 1654 |
+
"""Save uploaded file to CONFIG.UPLOAD_DIR and return metadata."""
|
| 1655 |
+
try:
|
| 1656 |
+
# Respect root-level upload toggle
|
| 1657 |
+
if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True):
|
| 1658 |
+
raise HTTPException(403, 'File uploads are disabled by server configuration')
|
| 1659 |
+
# If a model is provided, verify the model allows file tools
|
| 1660 |
+
if model:
|
| 1661 |
+
if model not in MODEL_STORAGE:
|
| 1662 |
+
raise HTTPException(404, f"Model {model} not found")
|
| 1663 |
+
ms = MODEL_STORAGE[model]
|
| 1664 |
+
if ms and ms.MODEL_CONFIG and not getattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL', True):
|
| 1665 |
+
raise HTTPException(403, f"Model {model} does not allow file uploads")
|
| 1666 |
+
from utils import save_bytes_to_upload
|
| 1667 |
+
|
| 1668 |
+
content = await file.read()
|
| 1669 |
+
fname = file.filename if getattr(file, 'filename', None) else 'uploaded_file'
|
| 1670 |
+
meta = save_bytes_to_upload(fname, content)
|
| 1671 |
+
if meta.get('error'):
|
| 1672 |
+
raise HTTPException(500, f"Could not save file: {meta.get('error')}")
|
| 1673 |
+
UPLOADED_FILES[meta['file_id']] = meta
|
| 1674 |
+
return FileUploadResponse(success=True, file=UploadedFile(**meta))
|
| 1675 |
+
except Exception as e:
|
| 1676 |
+
raise HTTPException(500, str(e))
|
| 1677 |
+
|
| 1678 |
+
|
| 1679 |
+
@app.get('/api/v1/files')
|
| 1680 |
+
def list_files():
|
| 1681 |
+
return [UploadedFile(**v).model_dump() for v in UPLOADED_FILES.values()]
|
| 1682 |
+
|
| 1683 |
+
|
| 1684 |
+
@app.get('/api/v1/files/{file_id}')
|
| 1685 |
+
def get_file(file_id: str, download: bool = False):
|
| 1686 |
+
if file_id not in UPLOADED_FILES:
|
| 1687 |
+
raise HTTPException(404, 'File not found')
|
| 1688 |
+
meta = UPLOADED_FILES[file_id]
|
| 1689 |
+
if download:
|
| 1690 |
+
# return file contents
|
| 1691 |
+
try:
|
| 1692 |
+
with open(meta['path'], 'rb') as f:
|
| 1693 |
+
return StreamingResponse(f, media_type='application/octet-stream')
|
| 1694 |
+
except Exception as e:
|
| 1695 |
+
raise HTTPException(500, str(e))
|
| 1696 |
+
return UploadedFile(**meta)
|
| 1697 |
+
|
| 1698 |
+
|
| 1699 |
+
@app.delete('/api/v1/files/{file_id}')
|
| 1700 |
+
def delete_file(file_id: str):
|
| 1701 |
+
if file_id not in UPLOADED_FILES:
|
| 1702 |
+
raise HTTPException(404, 'File not found')
|
| 1703 |
+
meta = UPLOADED_FILES.pop(file_id)
|
| 1704 |
+
try:
|
| 1705 |
+
if os.path.exists(meta['path']):
|
| 1706 |
+
os.remove(meta['path'])
|
| 1707 |
+
except Exception:
|
| 1708 |
+
pass
|
| 1709 |
+
return {'success': True}
|
| 1710 |
|
| 1711 |
if __name__ == "__main__":
|
| 1712 |
import uvicorn
|
config.local.yaml
CHANGED
|
@@ -22,3 +22,6 @@ MODELS:
|
|
| 22 |
- "\n\n"
|
| 23 |
stop_tokens:
|
| 24 |
- 0
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
- "\n\n"
|
| 23 |
stop_tokens:
|
| 24 |
- 0
|
| 25 |
+
ALLOW_WEB_SEARCH: True
|
| 26 |
+
ALLOW_TOOLS: True
|
| 27 |
+
ALLOW_REASONING: True
|
config.production-modelscope.yaml
CHANGED
|
@@ -3,6 +3,11 @@ PORT: 7860
|
|
| 3 |
STRATEGY: "cuda fp16"
|
| 4 |
RWKV_CUDA_ON: True
|
| 5 |
CHUNK_LEN: 256
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
MODELS:
|
| 7 |
- SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
|
| 8 |
DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
|
|
@@ -22,3 +27,6 @@ MODELS:
|
|
| 22 |
- "\n\n"
|
| 23 |
stop_tokens:
|
| 24 |
- 0
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
STRATEGY: "cuda fp16"
|
| 4 |
RWKV_CUDA_ON: True
|
| 5 |
CHUNK_LEN: 256
|
| 6 |
+
DEFAULT_STREAM: True
|
| 7 |
+
AUTO_ENABLE_TOOLS: True
|
| 8 |
+
AUTO_ENABLE_REASONING: True
|
| 9 |
+
AUTO_ENABLE_WEB_SEARCH: True
|
| 10 |
+
ENABLE_TOOLS_BY_DEFAULT: False
|
| 11 |
MODELS:
|
| 12 |
- SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
|
| 13 |
DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
|
|
|
|
| 27 |
- "\n\n"
|
| 28 |
stop_tokens:
|
| 29 |
- 0
|
| 30 |
+
ALLOW_WEB_SEARCH: True
|
| 31 |
+
ALLOW_TOOLS: True
|
| 32 |
+
ALLOW_REASONING: True
|
config.production.yaml
CHANGED
|
@@ -3,6 +3,11 @@ PORT: 7860
|
|
| 3 |
STRATEGY: "cuda fp16"
|
| 4 |
RWKV_CUDA_ON: True
|
| 5 |
CHUNK_LEN: 256
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
MODELS:
|
| 7 |
- SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
|
| 8 |
DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
|
|
@@ -11,6 +16,9 @@ MODELS:
|
|
| 11 |
REASONING: True
|
| 12 |
DEFAULT_CHAT: True
|
| 13 |
DEFAULT_REASONING: True
|
|
|
|
|
|
|
|
|
|
| 14 |
DEFAULT_SAMPLER:
|
| 15 |
max_tokens: 4096
|
| 16 |
temperature: 1.0
|
|
@@ -21,4 +29,10 @@ MODELS:
|
|
| 21 |
stop:
|
| 22 |
- "\n\n"
|
| 23 |
stop_tokens:
|
| 24 |
-
- 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
STRATEGY: "cuda fp16"
|
| 4 |
RWKV_CUDA_ON: True
|
| 5 |
CHUNK_LEN: 256
|
| 6 |
+
DEFAULT_STREAM: True
|
| 7 |
+
AUTO_ENABLE_TOOLS: True
|
| 8 |
+
AUTO_ENABLE_REASONING: True
|
| 9 |
+
AUTO_ENABLE_WEB_SEARCH: True
|
| 10 |
+
ENABLE_TOOLS_BY_DEFAULT: False
|
| 11 |
MODELS:
|
| 12 |
- SERVICE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096"
|
| 13 |
DOWNLOAD_MODEL_FILE_NAME: "rwkv7-g1a-0.1b-20250728-ctx4096.pth"
|
|
|
|
| 16 |
REASONING: True
|
| 17 |
DEFAULT_CHAT: True
|
| 18 |
DEFAULT_REASONING: True
|
| 19 |
+
ALLOW_WEB_SEARCH: True
|
| 20 |
+
ALLOW_TOOLS: True
|
| 21 |
+
ALLOW_REASONING: True
|
| 22 |
DEFAULT_SAMPLER:
|
| 23 |
max_tokens: 4096
|
| 24 |
temperature: 1.0
|
|
|
|
| 29 |
stop:
|
| 30 |
- "\n\n"
|
| 31 |
stop_tokens:
|
| 32 |
+
- 0
|
| 33 |
+
ALLOW_WEB_SEARCH: True
|
| 34 |
+
ALLOW_TOOLS: True
|
| 35 |
+
ALLOW_REASONING: True
|
| 36 |
+
STATE_STORE_PATH: "./state_store.json"
|
| 37 |
+
STATE_STORE_FLUSH_INTERVAL: 5
|
| 38 |
+
STATE_STORE_SAVE_ON_UPDATE: True
|
config.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
from pydantic import BaseModel, Field
|
| 2 |
-
from typing import List, Optional
|
| 3 |
from typing import List, Optional, Union, Any
|
| 4 |
|
| 5 |
import sys
|
|
@@ -12,7 +11,7 @@ class CliConfig(BaseSettings, cli_parse_args=True, cli_use_class_docs_for_groups
|
|
| 12 |
CONFIG_FILE: str = Field("./config.local.yaml", description="Config file path")
|
| 13 |
|
| 14 |
|
| 15 |
-
CLI_CONFIG = CliConfig()
|
| 16 |
|
| 17 |
|
| 18 |
class SamplerConfig(BaseModel):
|
|
@@ -26,6 +25,15 @@ class SamplerConfig(BaseModel):
|
|
| 26 |
penalty_decay: float = Field(0.996, description="Penalty decay factor.")
|
| 27 |
stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
|
| 28 |
stop_tokens: List[int] = Field([0], description="List of stop tokens.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
class ModelConfig(BaseModel):
|
|
@@ -52,26 +60,66 @@ class ModelConfig(BaseModel):
|
|
| 52 |
DEFAULT_CHAT: bool = Field(False, description="Whether this model is the default chat model.")
|
| 53 |
DEFAULT_REASONING: bool = Field(False, description="Whether this model is the default reasoning model.")
|
| 54 |
DEFAULT_SAMPLER: SamplerConfig = Field(
|
| 55 |
-
SamplerConfig(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
)
|
| 57 |
VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
class RootConfig(BaseModel):
|
| 61 |
"""Root configuration for the RWKV service."""
|
| 62 |
|
| 63 |
-
HOST: Optional[str] = Field(
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
PORT: Optional[int] = Field(
|
| 67 |
-
8000, description="Port number to listen on."
|
| 68 |
-
) # 因为YAML示例中被注释掉了
|
| 69 |
-
STRATEGY: str = Field(
|
| 70 |
-
"cpu", description="Strategy for model execution (e.g., 'cuda fp16')."
|
| 71 |
-
)
|
| 72 |
RWKV_CUDA_ON: bool = Field(False, description="Whether to enable RWKV CUDA kernel.")
|
| 73 |
CHUNK_LEN: int = Field(256, description="Chunk length for processing.")
|
| 74 |
MODELS: List[ModelConfig] = Field(..., description="List of model configurations.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
import yaml
|
|
@@ -81,4 +129,5 @@ try:
|
|
| 81 |
CONFIG = RootConfig.model_validate(yaml.safe_load(f.read()))
|
| 82 |
except Exception as e:
|
| 83 |
print(f"Pydantic Model Validation Failed: {e}")
|
| 84 |
-
|
|
|
|
|
|
| 1 |
from pydantic import BaseModel, Field
|
|
|
|
| 2 |
from typing import List, Optional, Union, Any
|
| 3 |
|
| 4 |
import sys
|
|
|
|
| 11 |
CONFIG_FILE: str = Field("./config.local.yaml", description="Config file path")
|
| 12 |
|
| 13 |
|
| 14 |
+
CLI_CONFIG = CliConfig(CONFIG_FILE="./config.local.yaml")
|
| 15 |
|
| 16 |
|
| 17 |
class SamplerConfig(BaseModel):
|
|
|
|
| 25 |
penalty_decay: float = Field(0.996, description="Penalty decay factor.")
|
| 26 |
stop: List[str] = Field(["\n\n"], description="List of stop sequences.")
|
| 27 |
stop_tokens: List[int] = Field([0], description="List of stop tokens.")
|
| 28 |
+
ALLOW_WEB_SEARCH: Optional[bool] = Field(None, description="Per-sampler override for allowing web search. If None, falls back to model/global.")
|
| 29 |
+
ALLOW_FILE_TOOL: Optional[bool] = Field(None, description="Per-sampler override for allowing file tools (e.g., file_read). If None, falls back to model/global.")
|
| 30 |
+
ALLOW_TOOLS: Optional[bool] = Field(None, description="Per-sampler override for allowing server-side tools. If None, falls back to model/global.")
|
| 31 |
+
ALLOW_REASONING: Optional[bool] = Field(None, description="Per-sampler override for allowing built-in reasoning. If None, falls back to model/global.")
|
| 32 |
+
# UI flags (non-functional in server, included so UI clients can show controls)
|
| 33 |
+
SHOW_WEB_SEARCH_BUTTON: Optional[bool] = Field(None, description="Whether to show the web-search toggle in the client UI for this sampler")
|
| 34 |
+
SHOW_FILE_UPLOAD_BUTTON: Optional[bool] = Field(None, description="Whether to show the file-upload control in the client UI for this sampler")
|
| 35 |
+
SHOW_REASONING_TOGGLE: Optional[bool] = Field(None, description="Whether to show the reasoning (think) toggle in the client UI for this sampler")
|
| 36 |
+
UI_STYLE: Optional[str] = Field(None, description="UI style hint that clients may use to render controls (example: 'whatsapp' or 'compact')")
|
| 37 |
|
| 38 |
|
| 39 |
class ModelConfig(BaseModel):
|
|
|
|
| 60 |
DEFAULT_CHAT: bool = Field(False, description="Whether this model is the default chat model.")
|
| 61 |
DEFAULT_REASONING: bool = Field(False, description="Whether this model is the default reasoning model.")
|
| 62 |
DEFAULT_SAMPLER: SamplerConfig = Field(
|
| 63 |
+
SamplerConfig(
|
| 64 |
+
max_tokens=512,
|
| 65 |
+
temperature=1.0,
|
| 66 |
+
top_p=0.3,
|
| 67 |
+
presence_penalty=0.5,
|
| 68 |
+
count_penalty=0.5,
|
| 69 |
+
penalty_decay=0.996,
|
| 70 |
+
stop=["\n\n"],
|
| 71 |
+
stop_tokens=[0],
|
| 72 |
+
ALLOW_WEB_SEARCH=None,
|
| 73 |
+
ALLOW_TOOLS=None,
|
| 74 |
+
ALLOW_REASONING=None,
|
| 75 |
+
ALLOW_FILE_TOOL=None,
|
| 76 |
+
SHOW_WEB_SEARCH_BUTTON=None,
|
| 77 |
+
SHOW_FILE_UPLOAD_BUTTON=None,
|
| 78 |
+
SHOW_REASONING_TOGGLE=None,
|
| 79 |
+
UI_STYLE=None,
|
| 80 |
+
),
|
| 81 |
+
description="Default sampler configuration for this model."
|
| 82 |
)
|
| 83 |
VOCAB: str = Field("rwkv_vocab_v20230424", description="Vocab Name")
|
| 84 |
+
# Allow or disallow server-side features on a per-model basis
|
| 85 |
+
ALLOW_WEB_SEARCH: bool = Field(True, description="Whether this model supports web search injection")
|
| 86 |
+
ALLOW_TOOLS: bool = Field(True, description="Whether this model supports server-side tools execution")
|
| 87 |
+
ALLOW_REASONING: bool = Field(True, description="Whether this model supports built-in reasoning (in-process)")
|
| 88 |
+
ALLOW_FILE_TOOL: bool = Field(True, description="Whether this model supports file-based tools (file_upload/file_read)")
|
| 89 |
+
# UI flags for the model that the client may use to show/hide controls
|
| 90 |
+
SHOW_WEB_SEARCH_BUTTON: bool = Field(True, description="Whether to show the web search toggle for this model in client UIs")
|
| 91 |
+
SHOW_FILE_UPLOAD_BUTTON: bool = Field(True, description="Whether to show a file upload button for this model in client UIs")
|
| 92 |
+
SHOW_REASONING_TOGGLE: bool = Field(True, description="Whether to show the reasoning toggle for this model in client UIs")
|
| 93 |
|
| 94 |
|
| 95 |
class RootConfig(BaseModel):
|
| 96 |
"""Root configuration for the RWKV service."""
|
| 97 |
|
| 98 |
+
HOST: Optional[str] = Field("127.0.0.1", description="Host IP address to bind to.")
|
| 99 |
+
PORT: Optional[int] = Field(8000, description="Port number to listen on.")
|
| 100 |
+
STRATEGY: str = Field("cpu", description="Strategy for model execution (e.g., 'cuda fp16').")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
RWKV_CUDA_ON: bool = Field(False, description="Whether to enable RWKV CUDA kernel.")
|
| 102 |
CHUNK_LEN: int = Field(256, description="Chunk length for processing.")
|
| 103 |
MODELS: List[ModelConfig] = Field(..., description="List of model configurations.")
|
| 104 |
+
# Additional defaults for auto behavior
|
| 105 |
+
DEFAULT_STREAM: bool = Field(True, description="Whether streaming is enabled by default")
|
| 106 |
+
AUTO_ENABLE_TOOLS: bool = Field(True, description="Whether to try auto-enabling tools based on intent")
|
| 107 |
+
AUTO_ENABLE_REASONING: bool = Field(True, description="Whether to auto-enable reasoning when needed")
|
| 108 |
+
AUTO_ENABLE_WEB_SEARCH: bool = Field(True, description="Whether to auto-enable web search based on intent")
|
| 109 |
+
ENABLE_TOOLS_BY_DEFAULT: bool = Field(False, description="Whether tools are enabled by default (without explicit request)")
|
| 110 |
+
ENABLE_WEB_SEARCH_BY_DEFAULT: bool = Field(True, description="Whether web search is enabled by default")
|
| 111 |
+
ENABLE_REASONING_BY_DEFAULT: bool = Field(True, description="Whether model reasoning is enabled by default when requested/supported")
|
| 112 |
+
# State store persistence
|
| 113 |
+
STATE_STORE_PATH: str = Field("./state_store.json", description="Path to persist streaming/resume state store")
|
| 114 |
+
STATE_STORE_FLUSH_INTERVAL: int = Field(5, description="Seconds between background flushes to the state store file")
|
| 115 |
+
STATE_STORE_SAVE_ON_UPDATE: bool = Field(True, description="Whether to save the state store to disk immediately when updated")
|
| 116 |
+
# File uploads / tools
|
| 117 |
+
UPLOAD_DIR: str = Field("./uploads", description="Directory to store uploaded files")
|
| 118 |
+
ALLOW_FILE_TOOL_BY_DEFAULT: bool = Field(True, description="Whether file-based tools are enabled by default")
|
| 119 |
+
# UI flags for the root server. These flags are advisory only and do not enable functionality.
|
| 120 |
+
SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT: bool = Field(True, description="Whether to show web search toggle by default in clients")
|
| 121 |
+
SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT: bool = Field(True, description="Whether to show file-upload control by default in clients")
|
| 122 |
+
SHOW_REASONING_TOGGLE_BY_DEFAULT: bool = Field(True, description="Whether to show reasoning toggle by default in clients")
|
| 123 |
|
| 124 |
|
| 125 |
import yaml
|
|
|
|
| 129 |
CONFIG = RootConfig.model_validate(yaml.safe_load(f.read()))
|
| 130 |
except Exception as e:
|
| 131 |
print(f"Pydantic Model Validation Failed: {e}")
|
| 132 |
+
# Exit with non-zero to indicate error when config is invalid
|
| 133 |
+
sys.exit(1)
|
models/.cache/huggingface/download/rwkv7-g1a-0.1b-20250728-ctx4096.pth.metadata
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
8c8cdf8c605dc7dfdccb676b9d0c482ba002f710
|
| 2 |
964f01cc4673273bbcf1b9c3cdc243d58af97bffeab51cb20c752eeaf048a3c6
|
| 3 |
-
|
|
|
|
| 1 |
8c8cdf8c605dc7dfdccb676b9d0c482ba002f710
|
| 2 |
964f01cc4673273bbcf1b9c3cdc243d58af97bffeab51cb20c752eeaf048a3c6
|
| 3 |
+
1763950644.4308126
|
tests/api_test.py
CHANGED
|
@@ -83,3 +83,14 @@ except Exception as e:
|
|
| 83 |
print('Error in streaming request:', e)
|
| 84 |
|
| 85 |
print('\nDone tests')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
print('Error in streaming request:', e)
|
| 84 |
|
| 85 |
print('\nDone tests')
|
| 86 |
+
|
| 87 |
+
print('\nChecking model listing endpoint')
|
| 88 |
+
try:
|
| 89 |
+
r = requests.get('http://127.0.0.1:7860/api/v1/models', timeout=10)
|
| 90 |
+
print('Models endpoint status', r.status_code)
|
| 91 |
+
try:
|
| 92 |
+
print(json.dumps(r.json(), indent=2))
|
| 93 |
+
except Exception:
|
| 94 |
+
print('Models endpoint returned non-JSON:', r.text[:200])
|
| 95 |
+
except Exception as e:
|
| 96 |
+
print('Error calling models endpoint:', e)
|
tests/run_api_single_request.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
sys.path.append(os.getcwd())
|
| 3 |
+
from app import chat_completions, ChatCompletionRequest
|
| 4 |
+
import asyncio
|
| 5 |
+
|
| 6 |
+
async def main():
|
| 7 |
+
req = ChatCompletionRequest(model='rwkv-latest', prompt='Who is the current president of France?', stream=False, max_tokens=32, temperature=0.2, include_usage=True, web_search=None, auto_web_search=True)
|
| 8 |
+
res = await chat_completions(req)
|
| 9 |
+
print(res)
|
| 10 |
+
|
| 11 |
+
if __name__ == '__main__':
|
| 12 |
+
asyncio.run(main())
|
tests/run_autodetect_flags.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
sys.path.append(os.getcwd())
|
| 3 |
+
from app import ChatCompletionRequest
|
| 4 |
+
from utils import detect_tools_and_reasoning
|
| 5 |
+
from config import CONFIG
|
| 6 |
+
|
| 7 |
+
# convenience function to compute flags, basically copying the logic we used in chatResponse

def compute_flags(req: ChatCompletionRequest):
    """Mirror chatResponse's resolution of web-search / tools / reasoning flags.

    Returns a dict with the heuristic detection plus the three resolved booleans.
    NOTE(review): this duplicates server-side logic on purpose, to observe it.
    """
    if req.prompt:
        prompt = req.prompt
    else:
        prompt = '\n\n'.join(m.content for m in req.messages) if req.messages else ''
    detection = detect_tools_and_reasoning(prompt)

    # Explicit enable wins; otherwise the request flag, then config + heuristics.
    if req.enable_web_search:
        web_search_enabled = True
    elif req.auto_web_search is not None:
        web_search_enabled = req.web_search or req.auto_web_search
    else:
        web_search_enabled = req.web_search or (
            CONFIG.AUTO_ENABLE_WEB_SEARCH and detection.get('need_web_search')
        )

    if req.enable_tools is not None:
        tools_enabled = bool(req.enable_tools)
    else:
        auto_tools_flag = req.auto_tools if req.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
        tools_enabled = (
            bool(req.tools)
            or CONFIG.ENABLE_TOOLS_BY_DEFAULT
            or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
        )

    # Reasoning is only on when explicitly requested.
    reasoning_enabled = bool(req.enable_reasoning) if req.enable_reasoning is not None else False

    return {
        'detection': detection,
        'web_search_enabled': web_search_enabled,
        'tools_enabled': tools_enabled,
        'reasoning_enabled': reasoning_enabled,
    }
|
| 39 |
+
|
| 40 |
+
# test cases: a web-search prompt, a calc prompt, and a pure-reasoning prompt
cases = [
    ChatCompletionRequest(model='rwkv-latest', prompt='Who is the current president of France?', stream=None, auto_web_search=True, auto_tools=None, auto_reasoning=None),
    ChatCompletionRequest(model='rwkv-latest', prompt='Calculate 2+3*4 for me', stream=None, auto_web_search=True, auto_tools=True, auto_reasoning=None),
    ChatCompletionRequest(model='rwkv-latest', prompt='Explain why the sky is blue', stream=None, auto_web_search=False, auto_tools=None, auto_reasoning=True),
]

for case in cases:
    print('---')
    print(case.prompt)
    print(compute_flags(case))
|
tests/run_chat_response.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, asyncio
sys.path.append(os.getcwd())
from app import chatResponse, ChatCompletionRequest


async def test():
    """Drive chatResponse once (non-streaming, 2 tokens) and dump the payload."""
    request = ChatCompletionRequest(
        model='rwkv-latest',
        prompt='Who is the president of France today?',
        stream=False,
        max_tokens=2,
        temperature=0.2,
        include_usage=True,
        auto_web_search=True,
    )
    result = await chatResponse(request, model_state=None, completionId='test123', enableReasoning=False)
    print(result.model_dump())


if __name__ == '__main__':
    asyncio.run(test())
|
tests/run_chat_response_out.txt
ADDED
|
Binary file (7.31 kB). View file
|
|
|
tests/run_detect.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys, os
sys.path.append(os.getcwd())
from utils import detect_tools_and_reasoning

# Exercise the heuristic detector on a search-style, a calc-style and a
# reasoning-style prompt; one result dict is printed per prompt.
for sample in (
    'Who is the president of France today?',
    'Calculate 2+3*4 for me',
    'Explain why the sky is blue',
):
    print(detect_tools_and_reasoning(sample))
|
tests/run_injected_tools.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
sys.path.append(os.getcwd())
|
| 3 |
+
from app import ChatCompletionRequest
|
| 4 |
+
from utils import detect_tools_and_reasoning
|
| 5 |
+
from config import CONFIG
|
| 6 |
+
from pprint import pprint
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def emulate_injection(req: ChatCompletionRequest):
    """Replicate server-side flag resolution and tool injection for one request.

    Mutates `req` (tools / web_search) the same way the app would, then returns
    a snapshot dict with the detection output, the resolved flags, and the
    final request dump.
    """
    if req.prompt is not None:
        raw_prompt = req.prompt.strip()
    elif req.messages:
        raw_prompt = '\n\n'.join(m.content for m in req.messages)
    else:
        raw_prompt = ''
    detection = detect_tools_and_reasoning(raw_prompt)

    # Web search: explicit enable wins, then per-request auto flag, then config + heuristics.
    if req.enable_web_search:
        web_search_enabled = True
    elif req.auto_web_search is not None:
        web_search_enabled = req.web_search or req.auto_web_search
    else:
        web_search_enabled = req.web_search or (
            CONFIG.AUTO_ENABLE_WEB_SEARCH and detection.get('need_web_search')
        )

    # Tools: explicit enable wins, then sampler override, then config + heuristics.
    if req.enable_tools is not None:
        tools_enabled = bool(req.enable_tools)
    elif req.sampler and isinstance(req.sampler, dict) and req.sampler.get('ALLOW_TOOLS') is not None:
        tools_enabled = bool(req.sampler.get('ALLOW_TOOLS'))
    else:
        auto_tools_flag = req.auto_tools if req.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
        tools_enabled = (
            bool(req.tools)
            or CONFIG.ENABLE_TOOLS_BY_DEFAULT
            or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
        )

    # Reasoning only turns on when explicitly requested.
    reasoning_enabled = bool(req.enable_reasoning) if req.enable_reasoning is not None else False

    # Inject detected tools when tools are on and the caller supplied none.
    if tools_enabled and not req.tools and detection.get('detected_tools'):
        req.tools = detection.get('detected_tools')

    # Promote the computed web-search decision onto the request itself.
    if web_search_enabled and not req.web_search:
        req.web_search = True

    return {
        'raw_prompt': raw_prompt,
        'detection': detection,
        'web_search_enabled': web_search_enabled,
        'tools_enabled': tools_enabled,
        'reasoning_enabled': reasoning_enabled,
        'req': req.model_dump(),
    }
|
| 54 |
+
|
| 55 |
+
# test cases: heuristic auto-detection plus two sampler-override scenarios
cases = [
    ChatCompletionRequest(model='rwkv-latest', prompt='Who is the current president of France?', stream=None, auto_web_search=True, auto_tools=None, auto_reasoning=None),
    ChatCompletionRequest(model='rwkv-latest', prompt='Calculate 2+3*4 for me', stream=None, auto_web_search=True, auto_tools=True, auto_reasoning=None),
    ChatCompletionRequest(model='rwkv-latest', prompt='Explain why the sky is blue', stream=None, auto_web_search=False, auto_tools=None, auto_reasoning=True),
    # Sampler override should disable web_search even though auto_web_search is True
    ChatCompletionRequest(model='rwkv-latest', prompt='Who is the current president of France?', stream=None, auto_web_search=True, auto_tools=None, auto_reasoning=None, sampler_allow_web_search=False),
    # Per-request sampler object also should disable tools
    ChatCompletionRequest(model='rwkv-latest', prompt='Calculate 2+3*4 for me', stream=None, auto_web_search=True, auto_tools=None, auto_reasoning=None, sampler={'ALLOW_TOOLS': False}),
]

for case in cases:
    print('---')
    pprint(emulate_injection(case))
|
tests/test_client_api.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi.testclient import TestClient
from app import app
import json

client = TestClient(app)


def _post_and_print(label, payload):
    """POST `payload` to the chat endpoint, printing `label`, the status, and
    the pretty-printed JSON body (or a fallback message when parsing fails).

    Extracted because the original script repeated this block three times.
    """
    print(label)
    res = client.post('/api/v1/chat/completions', json=payload)
    print('Status', res.status_code)
    try:
        print(json.dumps(res.json(), indent=2))
    except Exception as e:
        print('Response not JSON or parse failed', e)


# Plain non-streaming completion.
_post_and_print('Non-streaming test', {
    'model': 'rwkv-latest',
    'prompt': 'Who is the president of France today?',
    'stream': False,
    'max_tokens': 64,
    'temperature': 0.2,
    'include_usage': True,
})

# Explicit calc tool supplied by the caller.
_post_and_print('\nTools calc test', {
    'model': 'rwkv-latest',
    'prompt': 'Calculate 2+3*4 and explain the result.',
    'stream': False,
    'tools': [{'name': 'calc', 'args': {'expression': '2+3*4'}}],
})

# Server-side web search enabled per-request.
_post_and_print('\nTools web_search test', {
    'model': 'rwkv-latest',
    'prompt': 'Who is the current president of France?',
    'stream': False,
    'web_search': True,
    'search_top_k': 2,
})
|
tests/test_universal_and_detect.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
sys.path.append(os.getcwd())
from utils import universal_tool, detect_tools_and_reasoning


def test_universal_calc():
    """Arithmetic queries must dispatch to calc and evaluate correctly."""
    res = universal_tool({"query": "2+3*4"})
    assert isinstance(res, dict)
    assert res.get('action') == 'calc'
    assert 'result' in res
    assert str(res.get('result')) == '14'


def test_universal_web_search():
    """Fact-style queries must dispatch to a known action and carry a result."""
    res = universal_tool({"query": "Who is the president of France?"})
    assert isinstance(res, dict)
    # Fix: the original `... in (...) or True` was vacuously true due to
    # operator precedence, so this assertion never checked anything.
    assert res.get('action') in ('web_search', 'calc', 'unknown')
    assert 'result' in res


def test_detect_calc():
    """Calc heuristics must set the flag, emit a calc tool, and score > 0.5."""
    d = detect_tools_and_reasoning('Calculate 2+3*4 for me')
    assert d.get('need_calc')
    assert any(t.get('name') == 'calc' for t in d.get('detected_tools', []))
    conf = d.get('confidence') or {}
    assert conf.get('calc_confidence', 0) > 0.5


def test_detect_web_search():
    """Search heuristics must set the flag, emit a web_search tool, and score > 0.5."""
    d = detect_tools_and_reasoning('Who is the president of France?')
    assert d.get('need_web_search')
    assert any(t.get('name') == 'web_search' for t in d.get('detected_tools', []))
    conf = d.get('confidence') or {}
    assert conf.get('web_search_confidence', 0) > 0.5


if __name__ == '__main__':
    test_universal_calc()
    test_universal_web_search()
    test_detect_calc()
    test_detect_web_search()
    print('All tests passed')
|
utils.py
CHANGED
|
@@ -78,11 +78,13 @@ def logger():
|
|
| 78 |
while True:
|
| 79 |
item = LOGGER_QUEUE.get()
|
| 80 |
try:
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
except Exception:
|
| 87 |
pass
|
| 88 |
|
|
@@ -175,3 +177,184 @@ def calc(expr: str) -> str:
|
|
| 175 |
return str(result)
|
| 176 |
except Exception as e:
|
| 177 |
return f"ERROR: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
while True:
|
| 79 |
item = LOGGER_QUEUE.get()
|
| 80 |
try:
|
| 81 |
+
LOG_PORT = os.environ.get("LOG_PORT")
|
| 82 |
+
if LOG_PORT:
|
| 83 |
+
requests.post(
|
| 84 |
+
LOG_PORT,
|
| 85 |
+
headers={"Content-Type": "application/json"},
|
| 86 |
+
json=item,
|
| 87 |
+
)
|
| 88 |
except Exception:
|
| 89 |
pass
|
| 90 |
|
|
|
|
| 177 |
return str(result)
|
| 178 |
except Exception as e:
|
| 179 |
return f"ERROR: {e}"
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def detect_tools_and_reasoning(text_or_messages) -> dict:
    """Detects whether web_search, calc, or reasoning are likely needed based on heuristics.

    Accepts either a single string prompt or a list of ChatMessage (objects or
    dicts with a 'content' key). Returns a dict with boolean need_* flags, a
    list of detected tool specs, and a per-tool confidence summary.
    """
    if isinstance(text_or_messages, list):
        try:
            text = "\n\n".join([m.get('content', '') if isinstance(m, dict) else (getattr(m, 'content', '') or '') for m in text_or_messages if m])
        except Exception:
            text = ""
    else:
        text = str(text_or_messages or "")

    t = text.lower()
    # Simple heuristics
    need_calc = False
    need_web_search = False
    need_reasoning = False
    need_universal = False
    detected_tools = []

    # Heuristic for calc: presence of operators AND numbers OR keywords 'calculate/compute' plus numeric tokens
    if (re.search(r"\d+\s*[-+*/%]\s*\d+", t) or (re.search(r"\b(calculate|compute|solve|evaluate|sum|add|subtract|multiply|divide)\b", t) and re.search(r"\d", t))):
        need_calc = True
        # Extract the most-likely arithmetic expression. Fix: the original
        # pattern could match a lone whitespace run (its class included \s),
        # yielding an empty-string expression and missing the real one; anchor
        # the match on a digit or opening parenthesis instead.
        m = re.search(r"[\d(][\d()\s+\-*/%^.]*", text)
        expr = m.group(0).strip() if m else None
        # Only keep candidates that actually contain an operator and a digit;
        # normalize empty strings to None.
        if expr and not (re.search(r"[-+*/%]", expr) and re.search(r"\d", expr)):
            expr = None
        if not expr:
            expr = None
        detected_tools.append({"name": "calc", "args": {"expression": expr, "confidence": 0.95 if expr else 0.5}})

    # Heuristic for web search: question words + facts or 'current/latest' signals; avoid math queries.
    # Fix: 'GDP of' was uppercase but `t` is lowercased, so it could never match.
    if (
        re.search(r"\b(who is|who's|what is|what's|when is|where is|current|latest|news|is the president|president of|population of|capital of|how many|gdp of)\b", t)
        and not re.search(r"\d+\s*[-+*/%]\s*\d+", t)
    ):
        need_web_search = True
        detected_tools.append({"name": "web_search", "args": {"query": text, "confidence": 0.9}})

    # Heuristic for reasoning: words like 'explain', 'why', 'reason', 'prove', 'derive', 'compare'
    if re.search(r"\b(explain|why|because|reason|prove|derive|compare|analysis|analysis:|evaluate|argue|consequence|trade-offs)\b", t):
        need_reasoning = True

    # Heuristic for universal tool: requests to "use tool", "execute tool", or generic function-call language
    if re.search(r"\b(use (a )?tool|execute (a )?tool|call (a )?tool|function call|run tool|do this via a tool|invoke tool|call tool)\b", t):
        need_universal = True

    # Confidence summary: fixed scores per category when the flag fired.
    confs = {
        "calc_confidence": 0.95 if need_calc else 0.0,
        "web_search_confidence": 0.9 if need_web_search else 0.0,
        "reasoning_confidence": 0.85 if need_reasoning else 0.0,
        "universal_confidence": 0.65 if need_universal else 0.0,
    }
    return {
        "need_calc": need_calc,
        "need_web_search": need_web_search,
        "need_reasoning": need_reasoning,
        "need_universal": need_universal,
        "detected_tools": detected_tools,
        "confidence": confs,
    }
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def ensure_upload_dir():
    """Create CONFIG.UPLOAD_DIR (including parents) if it does not exist.

    Best-effort: any failure is swallowed, matching the callers' tolerance
    for a missing upload directory.
    """
    from config import CONFIG
    try:
        os.makedirs(CONFIG.UPLOAD_DIR, exist_ok=True)
    except Exception:
        # Deliberately ignored; save_bytes_to_upload reports its own errors.
        pass
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
from typing import Optional
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def save_bytes_to_upload(filename: Optional[str], data: bytes) -> dict:
    """Persist raw bytes under CONFIG.UPLOAD_DIR with a UUID-prefixed safe name.

    The stored name is '<uuid>_<basename(filename)>' so uploads can never
    escape the upload directory via path components in `filename`.

    Returns a metadata dict (file_id, filename, path, mime_type, size,
    uploaded_at) on success, or {'error': str} on failure.
    """
    from config import CONFIG
    # Fix: dropped unused `hashlib` import from the original.
    import mimetypes
    import time
    import uuid

    ensure_upload_dir()
    _id = str(uuid.uuid4())
    safe_name = f"{_id}_{os.path.basename(str(filename or 'uploaded_file'))}"
    path = os.path.join(CONFIG.UPLOAD_DIR, safe_name)
    try:
        with open(path, 'wb') as f:
            f.write(data)
        return {
            'file_id': _id,
            'filename': filename,
            'path': path,
            'mime_type': mimetypes.guess_type(path)[0],
            'size': os.path.getsize(path),
            'uploaded_at': int(time.time()),
        }
    except Exception as e:
        return {'error': str(e)}
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def file_read_from_path(path: str, max_bytes: int = 100000) -> str:
    """Read up to `max_bytes` from `path` and return the content as text.

    Missing/empty paths and I/O errors yield ''; bytes that are not valid
    UTF-8 are decoded with replacement characters.
    """
    try:
        if not path or not os.path.exists(path):
            return ""
        with open(path, 'rb') as fh:
            raw = fh.read(max_bytes)
    except Exception:
        return ""
    try:
        return raw.decode('utf-8', errors='replace')
    except Exception:
        # Practically unreachable with errors='replace'; kept as a last resort.
        return str(raw)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def universal_tool(args: dict, allow_web_search: bool = True, allow_tools: bool = True, allow_file_tool: bool = True) -> dict:
    """Universal tool dispatcher.

    If args['action'] is one of 'calc', 'web_search' or 'file_read', the
    matching tool runs directly; otherwise the tool is auto-detected from
    args['query'] (arithmetic -> calc, anything else -> web_search).

    Fix: the original docstring claimed a string return; every path actually
    returns a dict of the form {'action', 'result', 'metadata'} suitable for
    prompt injection, or {'error': str} when dispatch fails. The allow_*
    flags gate each tool; a disabled tool yields a 'disabled_by_policy'
    metadata error instead of running.
    """
    if not isinstance(args, dict):
        return {"error": "ERROR: invalid args for universal tool"}

    action = args.get("action")
    query = args.get("query")
    # top_k only applies to web searches; default to 3 results (hoisted so it
    # is computed once instead of per-branch).
    top_k = int(args.get("top_k") or 3)

    if action == "calc":
        if not allow_tools:
            return {"action": "calc", "result": None, "metadata": {"error": "disabled_by_policy", "confidence": 0.0}}
        expr = args.get("expression") or query
        if not expr:
            return {"action": "calc", "result": None, "metadata": {"error": "no expression provided", "confidence": 0.0}}
        return {"action": "calc", "result": str(calc(str(expr))), "metadata": {"expression": expr, "confidence": 0.98}}

    if action == "web_search":
        if not allow_web_search:
            return {"action": "web_search", "result": "", "metadata": {"error": "disabled_by_policy", "confidence": 0.0}}
        if not query:
            return {"action": "web_search", "result": "", "metadata": {"confidence": 0.0}}
        return {"action": "web_search", "result": str(web_search(str(query), top_k)), "metadata": {"query": query, "top_k": top_k, "confidence": 0.9}}

    if action == 'file_read':
        if not allow_file_tool:
            return {"action": "file_read", "result": None, "metadata": {"error": "disabled_by_policy", "confidence": 0.0}}
        fpath = args.get('path') or args.get('file_path')
        # Fallback: resolve a previously uploaded file by id. basename() keeps
        # the lookup inside CONFIG.UPLOAD_DIR. (The original re-checked the id
        # in a dead `else: candidate = None` branch; removed.)
        if not fpath and args.get('file_id'):
            from config import CONFIG
            candidate = os.path.join(CONFIG.UPLOAD_DIR, os.path.basename(str(args.get('file_id'))))
            if os.path.exists(candidate):
                fpath = candidate
        if not fpath:
            return {"action": "file_read", "result": None, "metadata": {"error": "no_path_or_id", "confidence": 0.0}}
        content = file_read_from_path(fpath, int(args.get('max_bytes') or 100000))
        return {"action": "file_read", "result": str(content), "metadata": {"path": fpath, "confidence": 0.9}}

    # No explicit action: auto-detect from the query text.
    if query:
        # Digit-operator-digit looks like arithmetic -> calc.
        if re.search(r"\d+\s*[-+*/%]\s*\d+", str(query)):
            if not allow_tools:
                return {"action": "calc", "result": None, "metadata": {"error": "disabled_by_policy", "confidence": 0.0}}
            return {"action": "calc", "result": str(calc(str(query))), "metadata": {"expression": str(query), "confidence": 0.95}}
        # Everything else falls through to web search.
        if not allow_web_search:
            return {"action": "web_search", "result": "", "metadata": {"error": "disabled_by_policy", "confidence": 0.0}}
        return {"action": "web_search", "result": str(web_search(str(query), top_k)), "metadata": {"query": str(query), "top_k": top_k, "confidence": 0.9}}

    return {"error": "ERROR: could not determine action for universal tool"}
|