diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,1716 +1,1599 @@
-import os
-
-if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
- from modelscope import patch_hub
-
- patch_hub()
-
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
-
-
-from config import CONFIG, ModelConfig
-from utils import (
- cleanMessages,
- parse_think_response,
- remove_nested_think_tags_stack,
- format_bytes,
- log,
- detect_tools_and_reasoning,
- universal_tool,
-)
-
-import copy, types, gc, sys, re, time, collections, asyncio
-from huggingface_hub import hf_hub_download
-from loguru import logger
-from rich import print
-
-from snowflake import SnowflakeGenerator
-
-CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
-
-from typing import List, Optional, Union, Any, Dict
-import uuid
-from pydantic import BaseModel, Field, model_validator
-from pydantic_settings import BaseSettings
-
-
-import numpy as np
-import torch
-
-
-if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
- logger.info(f"CUDA not found, fall back to cpu")
- CONFIG.STRATEGY = "cpu fp16"
-# Normalize STRATEGY to include precision if missing (e.g., 'cpu' -> 'cpu fp16')
-_s = CONFIG.STRATEGY.lower()
-if ("cpu" in _s or "cuda" in _s) and not ("fp16" in _s or "fp32" in _s):
- logger.info(f"STRATEGY missing precision, appending 'fp16' to `{CONFIG.STRATEGY}`")
- CONFIG.STRATEGY = CONFIG.STRATEGY + " fp16"
-
-
-try:
- from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
-except Exception:
- nvmlInit = None
- nvmlDeviceGetHandleByIndex = None
- nvmlDeviceGetMemoryInfo = None
-
-if "cuda" in CONFIG.STRATEGY.lower() and nvmlInit is not None and nvmlDeviceGetHandleByIndex is not None:
- nvmlInit()
- gpu_h = nvmlDeviceGetHandleByIndex(0)
-
-
-def logGPUState():
- if "cuda" in CONFIG.STRATEGY and nvmlDeviceGetMemoryInfo is not None:
- gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
- logger.info(
- f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
- )
-
-
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True
-os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models
-os.environ["RWKV_JIT_ON"] = "1"
-os.environ["RWKV_CUDA_ON"] = (
- "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
-)
-
-from rwkv.model import RWKV
-from rwkv.utils import PIPELINE, PIPELINE_ARGS
-
-from fastapi import FastAPI, HTTPException, UploadFile, File
-from starlette.background import BackgroundTask
-from fastapi.responses import StreamingResponse
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.staticfiles import StaticFiles
-from fastapi.middleware.gzip import GZipMiddleware
-
-
-from api_types import (
- ChatMessage,
- ChatCompletion,
- ChatCompletionChunk,
- Usage,
- PromptTokensDetails,
- ChatCompletionChoice,
- ChatCompletionMessage,
- SamplerConfig,
- UploadedFile,
- FileUploadResponse,
-)
-
-
-class ModelStorage:
- MODEL_CONFIG: Optional[ModelConfig] = None
- model: Optional[RWKV] = None
- pipeline: Optional[PIPELINE] = None
-
-
-MODEL_STORAGE: Dict[str, ModelStorage] = {}
-
-DEFAULT_MODEL_NAME = None
-DEFAULT_REASONING_MODEL_NAME = None
-
-# In-memory model state store to support streaming continuation/resume per state_name.
-# Keys: (model_name, state_name) -> dict with 'state' and 'model_tokens'
-STATE_STORE: Dict[tuple, Any] = {}
-
-# Serialized state store file path and flush interval defined in CONFIG
-_STATE_STORE_PATH = getattr(CONFIG, 'STATE_STORE_PATH', './state_store.json')
-_LAST_STATE_STORE_WRITE = 0
-
-# sentinel for model-initiated tool calls: <tool_call>{json}</tool_call>
-TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.S)
-
-# File uploads: simple in-memory index (persisted on disk via the files themselves)
-UPLOADED_FILES: Dict[str, dict] = {}
-
-
-def _serialize_state_store() -> dict:
- # Save only model_tokens to disk; model_state (torch objects) are not serializable
- serial = {}
- for (model_name, state_name), entry in STATE_STORE.items():
- try:
- mt = entry.get('model_tokens') if isinstance(entry, dict) else None
- if mt is None:
- # if entry is a raw model_state, skip
- continue
- serial[f"{model_name}|{state_name}"] = {
- 'model': model_name,
- 'state_name': state_name,
- 'model_tokens': mt,
- }
- except Exception:
- continue
- return serial
-
-
-def _load_state_store_from_disk():
- global STATE_STORE
- try:
- if os.path.exists(_STATE_STORE_PATH):
- import json
-
- with open(_STATE_STORE_PATH, 'r', encoding='utf-8') as f:
- data = json.load(f)
- for k, v in data.items():
- model = v.get('model')
- state_name = v.get('state_name')
- model_tokens = v.get('model_tokens')
- if model and state_name and isinstance(model_tokens, list):
- STATE_STORE[(model, state_name)] = {
- 'state': None,
- 'model_tokens': model_tokens,
- }
- logger.info(f"Loaded {len(STATE_STORE)} entries from state store file {_STATE_STORE_PATH}")
- except Exception as e:
- logger.info(f"Failed to load state store from disk: {e}")
-
-
-def _save_state_store_to_disk(force=False):
- global _LAST_STATE_STORE_WRITE
- now = time.time()
- if not force and now - _LAST_STATE_STORE_WRITE < getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5):
- return
- try:
- serial = _serialize_state_store()
- if not serial:
- return
- import json
- tmp = _STATE_STORE_PATH + ".tmp"
- with open(tmp, 'w', encoding='utf-8') as f:
- json.dump(serial, f)
- os.replace(tmp, _STATE_STORE_PATH)
- _LAST_STATE_STORE_WRITE = now
- except Exception as e:
- logger.info(f"Write state store to disk failed: {e}")
-
-
-def _recompute_out_and_state_from_tokens(model_name: str, model_tokens: List[int]):
- """
- Recompute the `out` logits and `model_state` by forwarding through tokens in chunks.
- Returns a tuple (out, model_state).
- """
- ms = MODEL_STORAGE.get(model_name)
- if not ms or not ms.model:
- return None, None
- model_state = None
- out = None
- tokens = list(model_tokens) if isinstance(model_tokens, list) else [0]
- while len(tokens) > 0:
- out, model_state = ms.model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
- tokens = tokens[CONFIG.CHUNK_LEN :]
- return out, model_state
-
-logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
-
-logGPUState()
-
-# Keep any configured models intact; do not force selection by name/size.
-# The previous policy enforced a single '0.1b' model which hid additional configs; use the full list.
-logger.info(f"Configured {len(CONFIG.MODELS)} model(s) in ROOT config")
-
-for model_config in CONFIG.MODELS:
- logger.info(f"Load Model - {model_config.SERVICE_NAME}")
-
- if model_config.MODEL_FILE_PATH == None:
- model_config.MODEL_FILE_PATH = hf_hub_download(
- repo_id=str(model_config.DOWNLOAD_MODEL_REPO_ID),
- filename=str(model_config.DOWNLOAD_MODEL_FILE_NAME),
- local_dir=str(model_config.DOWNLOAD_MODEL_DIR),
- )
- logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
-
- if model_config.DEFAULT_CHAT:
-        if DEFAULT_MODEL_NAME is not None:
-            logger.info(
-                f"Load Model - Replace `DEFAULT_MODEL_NAME` from `{DEFAULT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
-            )
-        DEFAULT_MODEL_NAME = model_config.SERVICE_NAME
-
- if model_config.DEFAULT_REASONING:
- if DEFAULT_REASONING_MODEL_NAME != None:
- logger.info(
- f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
- )
- DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
-
- logger.info(f"Load Model - Loading `{model_config.SERVICE_NAME}`")
- print(model_config.DEFAULT_SAMPLER)
-
- MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
- MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
- MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
- model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
- strategy=CONFIG.STRATEGY,
- )
- MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
- MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
- )
- if "cuda" in CONFIG.STRATEGY:
- torch.cuda.empty_cache()
- gc.collect()
- logGPUState()
-
-
-logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
-logger.info(
- f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`"
-)
-
-if len(MODEL_STORAGE) == 1:
- single_name = list(MODEL_STORAGE.keys())[0]
-    if DEFAULT_MODEL_NAME != single_name:
-        DEFAULT_MODEL_NAME = single_name
-        logger.info(f"Load Model - Only one model present; DEFAULT_MODEL_NAME set to `{DEFAULT_MODEL_NAME}`")
- if DEFAULT_REASONING_MODEL_NAME != single_name:
- DEFAULT_REASONING_MODEL_NAME = single_name
- logger.info(f"Load Model - Only one model present; DEFAULT_REASONING_MODEL_NAME set to `{DEFAULT_REASONING_MODEL_NAME}`")
-
-
-class ChatCompletionRequest(BaseModel):
- model: str = Field(
- default="rwkv-latest",
- description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`",
- )
- messages: Optional[List[ChatMessage]] = Field(default=None)
- prompt: Optional[str] = Field(default=None)
- max_tokens: Optional[int] = Field(default=None)
- temperature: Optional[float] = Field(default=None)
- top_p: Optional[float] = Field(default=None)
- presence_penalty: Optional[float] = Field(default=None)
- count_penalty: Optional[float] = Field(default=None)
- penalty_decay: Optional[float] = Field(default=None)
-    stream: Optional[bool] = Field(default=None, description="Whether to stream token-by-token responses. If None, uses CONFIG.DEFAULT_STREAM")
- state_name: Optional[str] = Field(default=None)
- include_usage: Optional[bool] = Field(default=False)
- stop: Optional[list[str]] = Field(["\n\n"])
- stop_tokens: Optional[list[int]] = Field([0])
- web_search: Optional[bool] = Field(default=True, description="Whether to perform a web search and append results to the prompt")
-    enable_web_search: Optional[bool] = Field(default=None, description="Explicitly enable web search (overrides auto/web_search) if set")
- auto_web_search: Optional[bool] = Field(default=True, description="Whether to enable web_search based on auto-detected intent")
- enable_tools: Optional[bool] = Field(default=None, description="Explicitly enable tools (overrides auto detection)")
- auto_tools: Optional[bool] = Field(default=True, description="Whether to enable tools based on auto-detected intent")
-    enable_reasoning: Optional[bool] = Field(default=None, description="Explicitly override reasoning enablement if set")
- auto_reasoning: Optional[bool] = Field(default=True, description="Whether to enable reasoning based on auto detection")
- enable_universal: Optional[bool] = Field(default=None, description="Explicitly enable the universal tool execution")
- auto_universal: Optional[bool] = Field(default=True, description="Whether to auto enable universal tool execution")
- search_top_k: Optional[int] = Field(default=3, description="Number of web search results to retrieve")
- tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="List of tools to execute server-side (e.g., {'name':'web_search','args':{'query':'x'}})")
- # Per-request sampler overrides for ALLOW_* flags. These let the user
- # disable server-side features for this particular request if needed.
- sampler_allow_web_search: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing web_search")
- sampler_allow_tools: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing tools")
- sampler_allow_reasoning: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing reasoning")
- # Per-request sampler config object; if provided, these settings will
- # override the model defaults for this request.
- sampler: Optional[SamplerConfig] = Field(default=None, description="Per-request sampler settings (overrides model default)")
- # File uploads: allow referencing uploaded files in the request
- file_ids: Optional[List[str]] = Field(default=None, description="List of uploaded file IDs that the model may use for this request")
- enable_file_tool: Optional[bool] = Field(default=None, description="Explicitly enable file-based tools for this request")
- auto_file_tool: Optional[bool] = Field(default=None, description="Auto-detect whether file-based tools are needed")
- sampler_allow_file_tool: Optional[bool] = Field(default=None, description="Per-request sampler override allowing file tools")
-
- @model_validator(mode="before")
- @classmethod
- def validate_mutual_exclusivity(cls, data: Any) -> Any:
- if not isinstance(data, dict):
- return data
-
-        messages_provided = "messages" in data and data["messages"] is not None
-        prompt_provided = "prompt" in data and data["prompt"] is not None
-
- if messages_provided and prompt_provided:
- raise ValueError("messages and prompt cannot coexist. Choose one.")
- if not messages_provided and not prompt_provided:
- raise ValueError("Either messages or prompt must be provided.")
- return data
-
-
-app = FastAPI(title="RWKV OpenAI-Compatible API")
-
-app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
-)
-app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
-
-
-@app.on_event("startup")
-async def _startup_state_load_and_persist_loop():
- # Load previous persisted state (tokens only) at startup
- _load_state_store_from_disk()
-
- async def _persist_loop():
- while True:
- try:
- _save_state_store_to_disk(force=False)
- except Exception:
- pass
- await asyncio.sleep(getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5))
-
- # Spawn background flush task
- try:
- asyncio.create_task(_persist_loop())
- except Exception:
- pass
-
-
-async def runPrefill(
- request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
-):
- ctx = ctx.replace("\r\n", "\n")
- out = None
-
- ms = MODEL_STORAGE.get(request.model)
- if not ms or not ms.pipeline or not ms.model:
- raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")
- tokens = ms.pipeline.encode(ctx)
- tokens = [int(x) for x in tokens]
- model_tokens += tokens
-
- while len(tokens) > 0:
- out, model_state = ms.model.forward(
- tokens[: CONFIG.CHUNK_LEN], model_state
- )
- tokens = tokens[CONFIG.CHUNK_LEN :]
- await asyncio.sleep(0)
-
- return out, model_tokens, model_state
-
-
-def generate(
- request: ChatCompletionRequest,
- out,
- model_tokens: List[int],
- model_state,
- max_tokens=2048,
-):
- ms = MODEL_STORAGE.get(request.model)
- if not ms or not ms.pipeline or not ms.model:
- raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")
-
- temperature = request.temperature if request.temperature is not None else 0.2
- top_p = request.top_p if request.top_p is not None else 0.9
- alpha_frequency = request.count_penalty if request.count_penalty is not None else 0.0
- alpha_presence = request.presence_penalty if request.presence_penalty is not None else 0.0
- penalty_decay = request.penalty_decay if request.penalty_decay is not None else 0.5
-
- args = PIPELINE_ARGS(
- temperature=max(0.2, temperature),
- top_p=top_p,
- alpha_frequency=alpha_frequency,
- alpha_presence=alpha_presence,
- token_ban=[], # ban the generation of some tokens
- token_stop=[0],
- ) # stop generation whenever you see any token here
-
- occurrence = {}
- out_tokens: List[int] = []
- out_last = 0
-
- # Stream token-by-token; each chunk contains a single decoded token string.
-
- for i in range(max_tokens):
- for n in occurrence:
- out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
- # out[0] -= 1e10 # disable END_OF_TEXT
-
- token = ms.pipeline.sample_logits(
- out, temperature=args.temperature, top_p=args.top_p
- )
-
- if token == 0 and request.stop_tokens and token in request.stop_tokens:
- yield {
- "content": "",
- "tokens": out_tokens[out_last:],
- "finish_reason": "stop:token:0",
- "state": model_state,
- }
-
- del out
- gc.collect()
- return
-
- out, model_state = ms.model.forward([token], model_state)
- model_tokens.append(token)
- out_tokens.append(token)
-
- if request.stop_tokens and token in request.stop_tokens:
- yield {
- "content": "",
- "tokens": out_tokens[out_last:],
- "finish_reason": f"stop:token:{token}",
- "state": model_state,
- }
-
- del out
- gc.collect()
- return
-
- for xxx in list(occurrence.keys()):
- occurrence[xxx] *= penalty_decay
- occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
-
- # Decode token to text and yield it as a single-token chunk
- decoded = ms.pipeline.decode([token])
- # filter out replacement characters
- if "\ufffd" in decoded:
- continue
-
- yield {
- "content": decoded,
- "tokens": [token],
- "finish_reason": None,
- "state": model_state,
- }
- out_last = i + 1
-
- else:
- yield {
- "content": "",
- "tokens": [],
- "finish_reason": "length",
- }
-
-
-async def chatResponse(
- request: ChatCompletionRequest,
- model_state: Any,
- completionId: str,
- enableReasoning: bool,
-) -> ChatCompletion:
- createTimestamp = time.time()
-
- # Build raw prompt for detection (prefer explicit request.prompt, else messages)
- raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [])
- # Intent detection: analyze raw_prompt or messages to auto-activate tools/web-search/reasoning
- detection = detect_tools_and_reasoning(raw_prompt)
-    # Build the base prompt; auto-activated tool results may be prepended below
-    prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [])}\n\nAssistant:{'<think' if enableReasoning else ''}"
-
-    # Decide whether web search should be used
-    if request.enable_web_search is not None:
-        web_search_enabled = bool(request.enable_web_search)
-    else:
-        auto_web_flag = request.auto_web_search if request.auto_web_search is not None else getattr(CONFIG, 'AUTO_ENABLE_WEB_SEARCH', True)
-        web_search_enabled = bool(auto_web_flag and detection.get('need_web_search'))
-
-    # Decide whether file tools should be used
-    if request.enable_file_tool is not None:
-        file_tool_enabled = bool(request.enable_file_tool)
-    else:
-        auto_file_flag = request.auto_file_tool if request.auto_file_tool is not None else True
-        file_tool_enabled = bool((len(request.file_ids or []) > 0) or (auto_file_flag and request.file_ids))
- # Respect root-level defaults
- if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
- file_tool_enabled = False
- # Per-request sampler overrides
- try:
- if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
- file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
- elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
- file_tool_enabled = bool(request.sampler_allow_file_tool)
- else:
- ms = MODEL_STORAGE.get(request.model)
- if ms and ms.MODEL_CONFIG:
- if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
- file_tool_enabled = bool(ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
- elif hasattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms.MODEL_CONFIG.ALLOW_FILE_TOOL:
- file_tool_enabled = False
- except Exception:
- pass
-
- # Decide whether tools should be used
- if request.enable_tools is not None:
- tools_enabled = bool(request.enable_tools)
- else:
- # if explicit tools provided, or enable by default config, or auto detection suggests
- auto_tools_flag = request.auto_tools if request.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
- tools_enabled = bool(request.tools) or CONFIG.ENABLE_TOOLS_BY_DEFAULT or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
- # Respect sampler-level override (request.sampler.ALLOW_TOOLS), then
- # request.sampler_allow_tools, then sampler default and finally model-level allow
- try:
- if request.sampler and getattr(request.sampler, 'ALLOW_TOOLS', None) is not None:
- tools_enabled = bool(request.sampler.ALLOW_TOOLS)
- elif hasattr(request, 'sampler_allow_tools') and request.sampler_allow_tools is not None:
- tools_enabled = bool(request.sampler_allow_tools)
- else:
- ms = MODEL_STORAGE.get(request.model)
- if ms and ms.MODEL_CONFIG:
- if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_TOOLS', None) is not None:
- if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_TOOLS:
- tools_enabled = False
- elif hasattr(ms.MODEL_CONFIG, 'ALLOW_TOOLS') and not ms.MODEL_CONFIG.ALLOW_TOOLS:
- tools_enabled = False
- except Exception:
- pass
-
-    # Decide whether reasoning should be enabled (in addition to :thinking or explicit)
-    if request.enable_reasoning is not None:
-        reasoning_enabled = bool(request.enable_reasoning)
-    else:
-        auto_reasoning_flag = request.auto_reasoning if request.auto_reasoning is not None else CONFIG.AUTO_ENABLE_REASONING
-        reasoning_enabled = bool(enableReasoning) or bool(auto_reasoning_flag and detection.get('need_reasoning'))
- # If the root config sets reasoning to disabled by default and no explicit request to enable, disable it
- if not getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True) and request.enable_reasoning is None:
- reasoning_enabled = False
- # Respect sampler-level override for reasoning: request.sampler.ALLOW_REASONING -> sampler_allow_reasoning -> sampler.default -> model
- try:
- if request.sampler and getattr(request.sampler, 'ALLOW_REASONING', None) is not None:
- reasoning_enabled = bool(request.sampler.ALLOW_REASONING)
- elif hasattr(request, 'sampler_allow_reasoning') and request.sampler_allow_reasoning is not None:
- reasoning_enabled = bool(request.sampler_allow_reasoning)
- else:
- ms = MODEL_STORAGE.get(request.model)
- if ms and ms.MODEL_CONFIG:
- if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_REASONING', None) is not None:
- if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_REASONING:
- reasoning_enabled = False
- elif hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
- reasoning_enabled = False
- except Exception:
- pass
-
- # Keep the local boolean for generating content
- enableReasoning = reasoning_enabled
- try:
- ms = MODEL_STORAGE.get(request.model)
- if ms and ms.MODEL_CONFIG and hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
- enableReasoning = False
- except Exception:
- pass
-
- # Ensure web_search property mirrors computed web_search_enabled if not explicitly provided
- if request.enable_web_search is None:
- request.web_search = web_search_enabled
- # If tools should be automatically enabled, add detected ones
- if tools_enabled and not request.tools:
- if detection.get('detected_tools'):
- request.tools = detection.get('detected_tools')
-    # If universal is needed and not explicitly requested, add universal tool
-    auto_universal_flag = request.auto_universal if request.auto_universal is not None else CONFIG.AUTO_ENABLE_TOOLS
-    if (request.enable_universal is True) or (
-        request.enable_universal is None and auto_universal_flag and detection.get('need_universal')
-    ):
- if not request.tools:
- request.tools = [{"name": "universal", "args": {"query": raw_prompt}}]
-
- executed_tool_calls = []
-    # If file tools are enabled and files are attached, inject them into the prompt
- if file_tool_enabled and request.file_ids:
- for fid in request.file_ids:
- try:
- if fid not in UPLOADED_FILES:
- continue
- meta = UPLOADED_FILES.get(fid)
- if not meta:
- continue
- from utils import file_read_from_path
- fpath = meta.get('path')
- if not fpath or not os.path.exists(fpath):
- continue
- file_content = file_read_from_path(fpath, 200000)
- if file_content:
- exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
- executed_tool_calls.append(exec_entry)
- prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt)
- except Exception as e:
- logger.info(f"File injection error: {e}")
- if request.tools:
- try:
- for tool in request.tools:
- name = tool.get('name')
- args = tool.get('args', {})
- if name == 'web_search':
- from utils import web_search
-
- search_q = args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
- search_top_k = int(args.get('top_k') or request.search_top_k or 3)
- search_str = web_search(search_q, search_top_k)
- if search_str:
- search_res_struct = {"action": "web_search", "result": str(search_str), "metadata": {"query": search_q, "top_k": search_top_k, "confidence": 0.9}}
- executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": search_top_k}, "result": search_res_struct})
- prompt = (f"ToolResults:\n{search_res_struct.get('result')}\n\nUse these results to answer the prompt.\n\n" + prompt)
- elif name == 'calc' or name == 'calculator':
- from utils import calc
-
- expr = args.get('expression')
- if expr:
- calc_res = calc(expr)
- # Wrap result into a structured dict
- calc_res_struct = {"action": "calc", "result": str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}}
- executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct})
- prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res_struct.get('result')}\n\nUse this result to answer the prompt.\n\n" + prompt)
- elif name == 'universal':
- try:
- res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
- # If universal_tool returns a dict, extract text result for prompt injection
- if isinstance(res, dict):
- result_text = res.get('result') if res.get('result') is not None else ''
- else:
- result_text = str(res)
- executed_tool_calls.append({"name": "universal", "args": args, "result": res})
- prompt = (f"ToolResults:\n{result_text}\n\nUse this result to answer the prompt.\n\n" + prompt)
- except Exception as e:
- logger.info(f"Universal tool execution error: {e}")
-                elif name == 'file_read':
-                    # read an uploaded file by id/path
-                    try:
-                        fid = args.get('file_id') or args.get('id') or (request.file_ids[0] if request.file_ids else None)
-                        if not fid or fid not in UPLOADED_FILES:
-                            continue
-                        meta = UPLOADED_FILES.get(fid)
-                        if not meta:
-                            continue
-                        from utils import file_read_from_path
-                        fpath = meta.get('path')
-                        if not fpath or not os.path.exists(fpath):
-                            continue
-                        file_content = file_read_from_path(fpath, int(args.get('max_bytes') or 100000))
-                        exec_entry = {"name": "file_read", "args": {"file_id": fid, "max_bytes": int(args.get('max_bytes') or 100000)}, "result": {"action": "file_read", "result": file_content, "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
-                        executed_tool_calls.append(exec_entry)
-                        _res = exec_entry["result"]
-                        _res_text = _res.get('result') or ''
-                        prompt = (f"ToolResults:\n{_res_text}\n\nUse these file contents to answer the prompt.\n\n" + prompt)
-                    except Exception as e:
-                        logger.info(f"file_read tool error: {e}")
-                else:
-                    # Unsupported tool - ignore and log
-                    logger.info(f"Unsupported tool requested: {name}")
- except Exception as e:
- logger.info(f"Tool processing error: {e}")
- elif request.web_search or web_search_enabled:
- try:
- from utils import web_search
-
- search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
- search_res = web_search(search_q, int(request.search_top_k or 3))
- if search_res:
- search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}}
- executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct})
- prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt
- except Exception:
- pass
- logger.info(f"[REQ] {completionId} - prompt - {prompt}")
-
- # Resume or prefill tokens/state
- if request.state_name:
- state_key = (request.model, request.state_name)
- if state_key in STATE_STORE:
- stored = STATE_STORE[state_key]
- model_state = stored.get('state', None)
- model_tokens = stored.get('model_tokens', [0])
- if model_state is None:
- # Recompute out and model_state from tokens since we did not persist the torch state
- out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
- else:
- # If we have a model_state, we still need out logits. Compute from last window of tokens
- out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :])
- else:
- out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
- else:
- out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
-
- prefillTime = time.time()
- promptTokenCount = len(model_tokens)
-
- fullResponse = " 0) or (auto_file_flag and request.file_ids))
- if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
- file_tool_enabled = False
- try:
- if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
- file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
- elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
- file_tool_enabled = bool(request.sampler_allow_file_tool)
- else:
- ms2 = MODEL_STORAGE.get(request.model)
- if ms2 and ms2.MODEL_CONFIG:
- if hasattr(ms2.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms2.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
- file_tool_enabled = bool(ms2.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
- elif hasattr(ms2.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms2.MODEL_CONFIG.ALLOW_FILE_TOOL:
- file_tool_enabled = False
- except Exception:
- pass
- # Build final prompt after deciding enableReasoning
-    prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{'<think' if enableReasoning else ''}"
-
-    executed_tool_calls = []
-    # Tool execution and file injection mirror chatResponse; results are prepended to `prompt`
-
-    logger.info(f"[REQ] {completionId} - prompt - {prompt}")
-
-    # Resume from a stored state when possible, otherwise prefill from scratch
-    if request.state_name and (request.model, request.state_name) in STATE_STORE:
-        stored = STATE_STORE[(request.model, request.state_name)]
-        model_state = stored.get('state', None)
-        model_tokens = stored.get('model_tokens', [0])
-        if model_state is None:
-            out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
-        else:
-            out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :])
-    else:
-        out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
-
-    prefillTime = time.time()
-    promptTokenCount = len(model_tokens)
-    completionTokenCount = 0
-    finishReason = None
-    model_initiated_tool_calls = 0
-    MODEL_MAX_TOOL_CALLS = 3  # cap on model-initiated tool calls per request
-
-    # First chunk carries the assistant role and any tool calls already executed
-    response = ChatCompletionChunk(
-        id=completionId,
-        created=createTimestamp,
-        model=request.model,
-        usage=None,
-        choices=[
-            ChatCompletionChoice(
-                index=0,
-                delta=ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None),
-                logprobs=None,
-                finish_reason=None,
-            )
-        ],
-    )
-    r_dict = response.model_dump()
-    try:
-        if executed_tool_calls and len(r_dict['choices']) > 0 and r_dict['choices'][0].get('delta') is not None:
- r_dict['choices'][0]['delta']['tool_calls'] = executed_tool_calls
- except Exception:
- pass
- yield f"data: {r_dict}\n\n"
-
- buffer = []
-
- if enableReasoning:
- buffer.append(" tag
- "fullTextCursor": 0,
- "in_think": False,
- "cacheStr": "",
- }
-
- for chunk in generate(
- request,
- out,
- model_tokens,
- model_state,
- max_tokens=(
- 64000
- if "max_tokens" not in request.model_fields_set and enableReasoning
- else (request.max_tokens or 2048)
- ),
- ):
- completionTokenCount += 1
- # Each token stream is delivered as a decoded character/bytes (maybe 1 or more chars)
- chunkContent: str = chunk["content"]
- buffer.append(chunkContent)
-
- fullText = "".join(buffer)
-
- if chunk["finish_reason"]:
- finishReason = chunk["finish_reason"]
-
- response = ChatCompletionChunk(
- id=completionId,
- created=createTimestamp,
- model=request.model,
- usage=(
- Usage(
- prompt_tokens=promptTokenCount,
- completion_tokens=completionTokenCount,
- total_tokens=promptTokenCount + completionTokenCount,
- prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
- )
- if request.include_usage
- else None
- ),
- choices=[
- ChatCompletionChoice(
- index=0,
- delta=ChatCompletionMessage(
- role="Assistant",
- content=None,
- reasoning_content=None,
- tool_calls=None,
- ),
- logprobs=None,
- finish_reason=finishReason,
- )
- ],
- )
- if response.choices and response.choices[0].delta is None:
- response.choices[0].delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
-
- markStart = fullText.find("<", streamConfig["fullTextCursor"])
- if not streamConfig["isChecking"] and markStart != -1:
- streamConfig["isChecking"] = True
-
- if streamConfig["in_think"]:
- delta = response.choices[0].delta
- if delta is None:
- delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
- response.choices[0].delta = delta
- delta.reasoning_content = fullText[streamConfig["fullTextCursor"] : markStart]
- else:
- delta = response.choices[0].delta
- if delta is None:
- delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
- response.choices[0].delta = delta
- delta.content = fullText[streamConfig["fullTextCursor"] : markStart]
-
- streamConfig["cacheStr"] = ""
- streamConfig["fullTextCursor"] = markStart
-
- if streamConfig["isChecking"]:
- streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
- else:
- if streamConfig["in_think"]:
- delta = response.choices[0].delta
- if delta is None:
- delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
- response.choices[0].delta = delta
- delta.reasoning_content = chunkContent
- else:
- delta = response.choices[0].delta
- if delta is None:
- delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
- response.choices[0].delta = delta
- delta.content = chunkContent
- streamConfig["fullTextCursor"] = len(fullText)
-
- markEnd = fullText.find(">", streamConfig["fullTextCursor"])
- if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
- streamConfig["isChecking"] = False
-
- if (
- not streamConfig["in_think"]
- and streamConfig["cacheStr"].find("") != -1
- ):
- streamConfig["in_think"] = True
-
- delta = response.choices[0].delta
- if delta is None:
- delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
- response.choices[0].delta = delta
- delta.reasoning_content = (
- delta.reasoning_content
- if delta.reasoning_content != None
- else "" + streamConfig["cacheStr"].replace("", "")
- )
-
- elif (
- streamConfig["in_think"]
- and streamConfig["cacheStr"].find("") != -1
- ):
- streamConfig["in_think"] = False
-
- delta = response.choices[0].delta
- if delta is None:
- delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
- response.choices[0].delta = delta
- delta.content = (
- delta.content
- if delta.content != None
- else "" + streamConfig["cacheStr"].replace("", "")
- )
- else:
- if streamConfig["in_think"]:
- delta = response.choices[0].delta
- if delta is None:
- delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
- response.choices[0].delta = delta
- delta.reasoning_content = (
- delta.reasoning_content
- if delta.reasoning_content != None
- else "" + streamConfig["cacheStr"]
- )
- else:
- delta = response.choices[0].delta
- if delta is None:
- delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
- response.choices[0].delta = delta
- delta.content = (
- delta.content
- if delta.content != None
- else "" + streamConfig["cacheStr"]
- )
- streamConfig["fullTextCursor"] = len(fullText)
-
- delta = response.choices[0].delta
- if delta is None:
- delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
- response.choices[0].delta = delta
- if delta.content != None or delta.reasoning_content != None:
- # Save model state frequently (after each token) to allow resuming
- try:
- if request.state_name:
- STATE_STORE[(request.model, request.state_name)] = {
- 'state': model_state,
- 'model_tokens': model_tokens,
- }
- if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
- try:
- _save_state_store_to_disk(force=True)
- except Exception:
- pass
- except Exception:
- pass
- # model-initiated tool call detection
- if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
- m = TOOL_CALL_RE.search(fullText)
- if m:
- try:
- payload_raw = m.group(1)
- import json
-
- payload = json.loads(payload_raw)
- tool_name = payload.get('name')
- tool_args = payload.get('args', {})
- tool_res = None
- if tool_name == 'web_search':
- from utils import web_search
-
- q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
- k = int(tool_args.get('top_k') or request.search_top_k or 3)
- tool_res = web_search(q, k)
- elif tool_name in ('calc', 'calculator'):
- from utils import calc
-
- expr = tool_args.get('expression')
- if expr:
- tool_res = calc(expr)
- else:
- try:
- tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
- except Exception:
- tool_res = None
-
- if tool_res:
- # Normalize tool_res into a structured dict if needed
- if not isinstance(tool_res, dict):
- if tool_name in ('calc', 'calculator'):
- tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
- elif tool_name == 'web_search':
- tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
- else:
- tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
- else:
- tool_res_struct = tool_res
- exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
- executed_tool_calls.append(exec_entry)
- delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
- prompt = delta_text + prompt
- fullText = TOOL_CALL_RE.sub('', fullText)
- buffer = [fullText]
- out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
- model_initiated_tool_calls += 1
- should_restart = True
- break
- except Exception as e:
- logger.info(f"Model-initiated tool handling error: {e}")
- yield f"data: {response.model_dump_json()}\n\n"
- # check stop sequences and stop streaming if we see them
-            should_stop = False
-            for stop_words in request.stop or []:
-                if stop_words in ''.join(buffer):
-                    finishReason = f"stop:words:{stop_words}"
-                    should_stop = True
-                    break
-            if should_stop:
-                break
-
- await asyncio.sleep(0)
-
- del streamConfig
- else:
- should_restart = True
- while should_restart:
- should_restart = False
- gen = generate(request, out, model_tokens, model_state)
- for chunk in gen:
- completionTokenCount += 1
- buffer.append(chunk["content"])
-
- if chunk["finish_reason"]:
- finishReason = chunk["finish_reason"]
-
- # Save model state frequently (after each token) to allow resuming
- try:
- if request.state_name:
- STATE_STORE[(request.model, request.state_name)] = {
- 'state': model_state,
- 'model_tokens': model_tokens,
- }
- if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
- try:
- _save_state_store_to_disk(force=True)
- except Exception:
- pass
- except Exception:
- pass
-
- # Detect model-initiated tool calls
- if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
- fullText = ''.join(buffer)
- m = TOOL_CALL_RE.search(fullText)
- if m:
- try:
- payload_raw = m.group(1)
- import json
-
- payload = json.loads(payload_raw)
- tool_name = payload.get('name')
- tool_args = payload.get('args', {})
- tool_res = None
- if tool_name == 'web_search':
- from utils import web_search
-
- q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
- k = int(tool_args.get('top_k') or request.search_top_k or 3)
- tool_res = web_search(q, k)
- elif tool_name in ('calc', 'calculator'):
- from utils import calc
-
- expr = tool_args.get('expression')
- if expr:
- tool_res = calc(expr)
- else:
- try:
- tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
- except Exception:
- tool_res = None
-
- if tool_res:
- if not isinstance(tool_res, dict):
- if tool_name in ('calc', 'calculator'):
- tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
- elif tool_name == 'web_search':
- tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
- else:
- tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
- else:
- tool_res_struct = tool_res
- exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
- executed_tool_calls.append(exec_entry)
- delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
- prompt = delta_text + prompt
- fullText = TOOL_CALL_RE.sub('', fullText)
- buffer = [fullText]
- out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
- # Notify client that a tool was called mid-stream (metadata-only chunk)
- try:
- meta_resp = ChatCompletionChunk(
- id=completionId,
- created=createTimestamp,
- model=request.model,
- usage=(
- Usage(
- prompt_tokens=promptTokenCount,
- completion_tokens=completionTokenCount,
- total_tokens=promptTokenCount + completionTokenCount,
- prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
- )
- if request.include_usage
- else None
- ),
- choices=[
- ChatCompletionChoice(
- index=0,
- delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls),
- logprobs=None,
- finish_reason=None,
- )
- ],
- )
- yield f"data: {meta_resp.model_dump_json()}\n\n"
- except Exception:
- pass
- model_initiated_tool_calls += 1
- should_restart = True
- break
- except Exception as e:
- logger.info(f"Model-initiated tool handling error: {e}")
-
- response = ChatCompletionChunk(
- id=completionId,
- created=createTimestamp,
- model=request.model,
- usage=(
- Usage(
- prompt_tokens=promptTokenCount,
- completion_tokens=completionTokenCount,
- total_tokens=promptTokenCount + completionTokenCount,
- prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
- )
- if request.include_usage
- else None
- ),
- choices=[
- ChatCompletionChoice(
- index=0,
- delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None),
- logprobs=None,
- finish_reason=finishReason,
- )
- ],
- )
- yield f"data: {response.model_dump_json()}\n\n"
- await asyncio.sleep(0)
-
-    generateTime = time.time()
-
- responseLog = {
- "content": "".join(buffer),
- "finish": finishReason,
- "prefill_len": promptTokenCount,
- "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
- "gen_len": completionTokenCount,
- "gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2),
- }
- logger.info(f"[RES] {completionId} - {responseLog}")
- if request.messages is None:
- request.messages = []
- request.messages.append(ChatMessage(role="Assistant", content=responseLog["content"]))
- log(
- {
- **request.model_dump(),
- **responseLog,
- "completionId": completionId,
- "machineLabel": os.environ.get("MACHINE_LABEL"),
- }
- )
-
- del buffer
-
- yield "data: [DONE]\n\n"
-
-
-@app.post("/api/v1/chat/completions")
-async def chat_completions(request: ChatCompletionRequest):
- completionId = str(next(CompletionIdGenerator))
- logger.info(f"[REQ] {completionId} - {request.model_dump()}")
-
- # Support model suffixes like ':thinking' for reasoning or ':web' to request
- # web search by default for this request. E.g., 'rwkv-latest:web' will enable web_search.
- modelName = request.model.split(":")[0]
- if ":web" in request.model:
- request.enable_web_search = True
- if ":file" in request.model:
- request.enable_file_tool = True
- enableReasoning = ":thinking" in request.model
-
- if "rwkv-latest" in request.model:
- # Map to the default chat model in all cases. Do not redirect to a separate
- # reasoning model when ':thinking' is used. The same model will be used
- # and reasoning handled in-process by setting enableReasoning=True.
-        if DEFAULT_MODEL_NAME is None:
-            raise HTTPException(404, "DEFAULT_MODEL_NAME not set")
-        ms_def = MODEL_STORAGE.get(DEFAULT_MODEL_NAME)
-        if not ms_def or not ms_def.MODEL_CONFIG:
-            raise HTTPException(500, "Default sampler config missing for default model")
-        defaultSamplerConfig = ms_def.MODEL_CONFIG.DEFAULT_SAMPLER
-        request.model = DEFAULT_MODEL_NAME
-
- elif modelName in MODEL_STORAGE:
- ms_sel = MODEL_STORAGE.get(modelName)
- if not ms_sel or not ms_sel.MODEL_CONFIG:
- raise HTTPException(500, f"Default sampler config missing for model {modelName}")
- defaultSamplerConfig = ms_sel.MODEL_CONFIG.DEFAULT_SAMPLER
- request.model = modelName
- else:
- raise HTTPException(404, f"Can not find `{modelName}`")
-
- async def chatResponseStreamDisconnect():
- logGPUState()
-
- # Load or initialize model_state and tokens based on state_name
- model_state = None
- model_tokens_for_resume = [0]
- state_name = request.state_name
- if state_name is None:
- state_name = str(uuid.uuid4())
- request.state_name = state_name
- state_key = (request.model, state_name)
- if state_key in STATE_STORE:
- stored = STATE_STORE[state_key]
- model_state = stored.get('state', None)
- model_tokens_for_resume = stored.get('model_tokens', [0])
- request_dict = request.model_dump()
-
- # Apply defaults from model's DEFAULT_SAMPLER, optionally overridden by the
- # per-request `sampler` object (or legacy sampler_allow_* booleans).
- sampler_overrides = request_dict.get('sampler') or {}
- for k, v in defaultSamplerConfig.model_dump().items():
- # If the request provided a sampler override for this field, use it
- if sampler_overrides and k in sampler_overrides and sampler_overrides.get(k) is not None:
- request_dict[k] = sampler_overrides.get(k)
- continue
- if k in request_dict and request_dict[k] is None:
- request_dict[k] = v
- realRequest = ChatCompletionRequest(**request_dict)
- # Ensure stream defaults to configuration value when not explicitly provided
- if realRequest.stream is None:
- realRequest.stream = CONFIG.DEFAULT_STREAM
-
- logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}")
-
- if realRequest.stream:
- r = StreamingResponse(
- chatResponseStream(realRequest, model_state, completionId, enableReasoning),
- media_type="text/event-stream",
- background=BackgroundTask(chatResponseStreamDisconnect),
- )
- else:
- r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
- # Attach state_name to non-streaming response as additional metadata
- try:
- if isinstance(r, ChatCompletion):
- d = r.model_dump()
- d['state_name'] = state_name
- return d
- except Exception:
- pass
-
- return r
-
-
-# We keep the service API-only; remove static mount for demo frontend to
-# avoid serving HTML files by default and keep the repository Python-only.
-logger.info("Static frontend mount removed for Python-only deploy; use API endpoints for integration")
-
-
-@app.get('/api/v1/models')
-def list_models():
- """Return model configuration summary for clients/UI.
-
- This endpoint returns configured models, their default sampler values, and
- ALLOW_* flags so UI clients can build a controls surface based on server
- capabilities (web search, tools, reasoning).
- """
- out = []
- root_defaults = {
- 'ALLOW_FILE_TOOL_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True),
- 'ENABLE_WEB_SEARCH_BY_DEFAULT': getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True),
- 'ENABLE_REASONING_BY_DEFAULT': getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True),
- 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT', True),
- 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT', True),
- 'SHOW_REASONING_TOGGLE_BY_DEFAULT': getattr(CONFIG, 'SHOW_REASONING_TOGGLE_BY_DEFAULT', True),
- 'UPLOAD_URL': '/api/v1/files',
- }
- for m in CONFIG.MODELS:
- out.append(
- {
- 'SERVICE_NAME': m.SERVICE_NAME,
- 'DEFAULT_CHAT': m.DEFAULT_CHAT,
- 'DEFAULT_REASONING': m.DEFAULT_REASONING,
- 'ALLOW_WEB_SEARCH': getattr(m, 'ALLOW_WEB_SEARCH', True),
- 'ALLOW_TOOLS': getattr(m, 'ALLOW_TOOLS', True),
- 'ALLOW_REASONING': getattr(m, 'ALLOW_REASONING', True),
- 'ALLOW_FILE_TOOL': getattr(m, 'ALLOW_FILE_TOOL', True),
- 'SHOW_WEB_SEARCH_BUTTON': getattr(m, 'SHOW_WEB_SEARCH_BUTTON', True),
- 'SHOW_FILE_UPLOAD_BUTTON': getattr(m, 'SHOW_FILE_UPLOAD_BUTTON', True),
- 'SHOW_REASONING_TOGGLE': getattr(m, 'SHOW_REASONING_TOGGLE', True),
- 'DEFAULT_SAMPLER': m.DEFAULT_SAMPLER.model_dump() if hasattr(m, 'DEFAULT_SAMPLER') else None,
- # Convenience info for clients: upload endpoint and root defaults
- 'UPLOAD_URL': '/api/v1/files',
- 'UPLOAD_ALLOWED_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True),
- }
- )
- return {'root_defaults': root_defaults, 'models': out}
-
-
-@app.post('/api/v1/files', response_model=FileUploadResponse)
-async def upload_file(file: UploadFile = File(...), model: Optional[str] = None):
- """Save uploaded file to CONFIG.UPLOAD_DIR and return metadata."""
- try:
- # Respect root-level upload toggle
- if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True):
- raise HTTPException(403, 'File uploads are disabled by server configuration')
- # If a model is provided, verify the model allows file tools
- if model:
- if model not in MODEL_STORAGE:
- raise HTTPException(404, f"Model {model} not found")
- ms = MODEL_STORAGE[model]
- if ms and ms.MODEL_CONFIG and not getattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL', True):
- raise HTTPException(403, f"Model {model} does not allow file uploads")
- from utils import save_bytes_to_upload
-
- content = await file.read()
- fname = file.filename if getattr(file, 'filename', None) else 'uploaded_file'
- meta = save_bytes_to_upload(fname, content)
- if meta.get('error'):
- raise HTTPException(500, f"Could not save file: {meta.get('error')}")
- UPLOADED_FILES[meta['file_id']] = meta
- return FileUploadResponse(success=True, file=UploadedFile(**meta))
-    except HTTPException:
-        raise
-    except Exception as e:
-        raise HTTPException(500, str(e))
-
-
-@app.get('/api/v1/files')
-def list_files():
- return [UploadedFile(**v).model_dump() for v in UPLOADED_FILES.values()]
-
-
-@app.get('/api/v1/files/{file_id}')
-def get_file(file_id: str, download: bool = False):
- if file_id not in UPLOADED_FILES:
- raise HTTPException(404, 'File not found')
- meta = UPLOADED_FILES[file_id]
- if download:
- # return file contents
- try:
- with open(meta['path'], 'rb') as f:
- return StreamingResponse(f, media_type='application/octet-stream')
- except Exception as e:
- raise HTTPException(500, str(e))
- return UploadedFile(**meta)
-
-
-@app.delete('/api/v1/files/{file_id}')
-def delete_file(file_id: str):
- if file_id not in UPLOADED_FILES:
- raise HTTPException(404, 'File not found')
- meta = UPLOADED_FILES.pop(file_id)
- try:
- if os.path.exists(meta['path']):
- os.remove(meta['path'])
- except Exception:
- pass
- return {'success': True}
-
-if __name__ == "__main__":
- import uvicorn
-
- host = CONFIG.HOST or "127.0.0.1"
- port = CONFIG.PORT or 7860
- uvicorn.run(app, host=host, port=port)
+import os
+
+if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
+ from modelscope import patch_hub
+ patch_hub()
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
+
+from config import CONFIG, ModelConfig
+from utils import (
+ cleanMessages,
+ parse_think_response,
+ remove_nested_think_tags_stack,
+ format_bytes,
+ log,
+ detect_tools_and_reasoning,
+ universal_tool,
+)
+
+import copy, types, gc, sys, re, time, collections, asyncio
+from huggingface_hub import hf_hub_download
+from loguru import logger
+from rich import print
+
+from snowflake import SnowflakeGenerator
+
+CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
+
+from typing import List, Optional, Union, Any, Dict
+import uuid
+from pydantic import BaseModel, Field, model_validator
+from pydantic_settings import BaseSettings
+
+import numpy as np
+import torch
+
+if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
+ logger.info(f"CUDA not found, fall back to cpu")
+ CONFIG.STRATEGY = "cpu fp16"
+
+_s = CONFIG.STRATEGY.lower()
+if ("cpu" in _s or "cuda" in _s) and not ("fp16" in _s or "fp32" in _s):
+ logger.info(f"STRATEGY missing precision, appending 'fp16' to `{CONFIG.STRATEGY}`")
+ CONFIG.STRATEGY = CONFIG.STRATEGY + " fp16"
+
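+# pynvml is optional; GPU memory logging is skipped when it is unavailable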
+try:
+ from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
+except Exception:
+ nvmlInit = None
+ nvmlDeviceGetHandleByIndex = None
+ nvmlDeviceGetMemoryInfo = None
+
+if "cuda" in CONFIG.STRATEGY.lower() and nvmlInit is not None and nvmlDeviceGetHandleByIndex is not None:
+ nvmlInit()
+ gpu_h = nvmlDeviceGetHandleByIndex(0)
+
+def logGPUState():
+ if "cuda" in CONFIG.STRATEGY and nvmlDeviceGetMemoryInfo is not None:
+ gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+ logger.info(
+ f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
+ )
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cudnn.allow_tf32 = True
+torch.backends.cuda.matmul.allow_tf32 = True
+os.environ["RWKV_V7_ON"] = "1"
+os.environ["RWKV_JIT_ON"] = "1"
+os.environ["RWKV_CUDA_ON"] = (
+ "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
+)
+
+from rwkv.model import RWKV
+from rwkv.utils import PIPELINE, PIPELINE_ARGS
+
+from fastapi import FastAPI, HTTPException, UploadFile, File
+from starlette.background import BackgroundTask
+from fastapi.responses import StreamingResponse
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.gzip import GZipMiddleware
+
+from api_types import (
+ ChatMessage,
+ ChatCompletion,
+ ChatCompletionChunk,
+ Usage,
+ PromptTokensDetails,
+ ChatCompletionChoice,
+ ChatCompletionMessage,
+ SamplerConfig,
+ UploadedFile,
+ FileUploadResponse,
+)
+
+class ModelStorage:
+ MODEL_CONFIG: Optional[ModelConfig] = None
+ model: Optional[RWKV] = None
+ pipeline: Optional[PIPELINE] = None
+
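+# Loaded models and pipelines keyed by SERVICE_NAME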
+MODEL_STORAGE: Dict[str, ModelStorage] = {}
+
+DEFAULT_MODEL_NAME = None
+DEFAULT_REASONING_MODEL_NAME = None
+
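+# In-memory model state store to support streaming continuation/resume per state_name.
+# Keys: (model_name, state_name) -> dict with 'state' and 'model_tokens'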
+STATE_STORE: Dict[tuple, Any] = {}
+
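+# Serialized state store file path and flush interval defined in CONFIG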
+_STATE_STORE_PATH = getattr(CONFIG, 'STATE_STORE_PATH', './state_store.json')
+_LAST_STATE_STORE_WRITE = 0
+
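+# sentinel for model-initiated tool calls: <tool_call>{json}</tool_call>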
+TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.S)
+
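+# File uploads: simple in-memory index (the files themselves persist on disk)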
+UPLOADED_FILES: Dict[str, dict] = {}
+
+def _serialize_state_store() -> dict:
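+    # Save only model_tokens to disk; model state (torch objects) is not JSON-serializable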
+ serial = {}
+ for (model_name, state_name), entry in STATE_STORE.items():
+ try:
+ mt = entry.get('model_tokens') if isinstance(entry, dict) else None
+ if mt is None:
+ continue
+ serial[f"{model_name}|{state_name}"] = {
+ 'model': model_name,
+ 'state_name': state_name,
+ 'model_tokens': mt,
+ }
+ except Exception:
+ continue
+ return serial
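+
+# On-disk layout written by _save_state_store_to_disk, e.g. (token values illustrative):
+#   {"my-model|chat-1": {"model": "my-model", "state_name": "chat-1", "model_tokens": [0, 12, 345]}}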
+
+def _load_state_store_from_disk():
+ global STATE_STORE
+ try:
+ if os.path.exists(_STATE_STORE_PATH):
+ import json
+
+ with open(_STATE_STORE_PATH, 'r', encoding='utf-8') as f:
+ data = json.load(f)
+ for k, v in data.items():
+ model = v.get('model')
+ state_name = v.get('state_name')
+ model_tokens = v.get('model_tokens')
+ if model and state_name and isinstance(model_tokens, list):
+ STATE_STORE[(model, state_name)] = {
+ 'state': None,
+ 'model_tokens': model_tokens,
+ }
+ logger.info(f"Loaded {len(STATE_STORE)} entries from state store file {_STATE_STORE_PATH}")
+ except Exception as e:
+ logger.info(f"Failed to load state store from disk: {e}")
+
+def _save_state_store_to_disk(force=False):
+ global _LAST_STATE_STORE_WRITE
+ now = time.time()
+ if not force and now - _LAST_STATE_STORE_WRITE < getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5):
+ return
+ try:
+ serial = _serialize_state_store()
+ if not serial:
+ return
+ import json
+ tmp = _STATE_STORE_PATH + ".tmp"
+ with open(tmp, 'w', encoding='utf-8') as f:
+ json.dump(serial, f)
+ os.replace(tmp, _STATE_STORE_PATH)
+ _LAST_STATE_STORE_WRITE = now
+ except Exception as e:
+ logger.info(f"Write state store to disk failed: {e}")
+
+def _recompute_out_and_state_from_tokens(model_name: str, model_tokens: List[int]):
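+    """
+    Recompute the `out` logits and `model_state` by forwarding through tokens in chunks.
+    Returns a tuple (out, model_state).
+    """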
+ ms = MODEL_STORAGE.get(model_name)
+ if not ms or not ms.model:
+ return None, None
+ model_state = None
+ out = None
+ tokens = list(model_tokens) if isinstance(model_tokens, list) else [0]
+ while len(tokens) > 0:
+ out, model_state = ms.model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
+ tokens = tokens[CONFIG.CHUNK_LEN :]
+ return out, model_state
+
+logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
+
+logGPUState()
+
+logger.info(f"Configured {len(CONFIG.MODELS)} model(s) in ROOT config")
+
+for model_config in CONFIG.MODELS:
+ logger.info(f"Load Model - {model_config.SERVICE_NAME}")
+
+    if model_config.MODEL_FILE_PATH is None:
+ model_config.MODEL_FILE_PATH = hf_hub_download(
+ repo_id=str(model_config.DOWNLOAD_MODEL_REPO_ID),
+ filename=str(model_config.DOWNLOAD_MODEL_FILE_NAME),
+ local_dir=str(model_config.DOWNLOAD_MODEL_DIR),
+ )
+ logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
+
+ if model_config.DEFAULT_CHAT:
+        if DEFAULT_MODEL_NAME is not None:
+            logger.info(
+                f"Load Model - Replace `DEFAULT_MODEL_NAME` from `{DEFAULT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
+            )
+        DEFAULT_MODEL_NAME = model_config.SERVICE_NAME
+
+ if model_config.DEFAULT_REASONING:
+        if DEFAULT_REASONING_MODEL_NAME is not None:
+ logger.info(
+ f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
+ )
+ DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
+
+ logger.info(f"Load Model - Loading `{model_config.SERVICE_NAME}`")
+ print(model_config.DEFAULT_SAMPLER)
+
+ MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage()
+ MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config
+ MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV(
+ model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
+ strategy=CONFIG.STRATEGY,
+ )
+ MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE(
+ MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB
+ )
+ if "cuda" in CONFIG.STRATEGY:
+ torch.cuda.empty_cache()
+ gc.collect()
+ logGPUState()
+
+logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
+logger.info(
+ f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`"
+)
+
+if len(MODEL_STORAGE) == 1:
+    single_name = list(MODEL_STORAGE.keys())[0]
+    if DEFAULT_MODEL_NAME != single_name:
+        DEFAULT_MODEL_NAME = single_name
+        logger.info(f"Load Model - Only one model present; DEFAULT_MODEL_NAME set to `{DEFAULT_MODEL_NAME}`")
+    if DEFAULT_REASONING_MODEL_NAME != single_name:
+        DEFAULT_REASONING_MODEL_NAME = single_name
+        logger.info(f"Load Model - Only one model present; DEFAULT_REASONING_MODEL_NAME set to `{DEFAULT_REASONING_MODEL_NAME}`")
+
+class ChatCompletionRequest(BaseModel):
+ model: str = Field(default="rwkv-latest")
+ messages: Optional[List[ChatMessage]] = Field(default=None)
+ prompt: Optional[str] = Field(default=None)
+ max_tokens: Optional[int] = Field(default=None)
+ temperature: Optional[float] = Field(default=None)
+ top_p: Optional[float] = Field(default=None)
+ presence_penalty: Optional[float] = Field(default=None)
+ count_penalty: Optional[float] = Field(default=None)
+ penalty_decay: Optional[float] = Field(default=None)
+ stream: Optional[bool] = Field(default=True)
+ state_name: Optional[str] = Field(default=None)
+ include_usage: Optional[bool] = Field(default=False)
+ stop: Optional[list[str]] = Field(["\n\n"])
+ stop_tokens: Optional[list[int]] = Field([0])
+ web_search: Optional[bool] = Field(default=True)
+ enable_web_search: Optional[bool] = Field(default=True)
+ auto_web_search: Optional[bool] = Field(default=True)
+ enable_tools: Optional[bool] = Field(default=None)
+ auto_tools: Optional[bool] = Field(default=True)
+ enable_reasoning: Optional[bool] = Field(default=True)
+ auto_reasoning: Optional[bool] = Field(default=True)
+ enable_universal: Optional[bool] = Field(default=None)
+ auto_universal: Optional[bool] = Field(default=True)
+ search_top_k: Optional[int] = Field(default=3)
+ tools: Optional[List[Dict[str, Any]]] = Field(default=None)
+ sampler_allow_web_search: Optional[bool] = Field(default=None)
+ sampler_allow_tools: Optional[bool] = Field(default=None)
+ sampler_allow_reasoning: Optional[bool] = Field(default=None)
+ sampler: Optional[SamplerConfig] = Field(default=None)
+ file_ids: Optional[List[str]] = Field(default=None)
+ enable_file_tool: Optional[bool] = Field(default=None)
+ auto_file_tool: Optional[bool] = Field(default=None)
+ sampler_allow_file_tool: Optional[bool] = Field(default=None)
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_mutual_exclusivity(cls, data: Any) -> Any:
+ if not isinstance(data, dict):
+ return data
+
+        messages_provided = "messages" in data and data["messages"] is not None
+        prompt_provided = "prompt" in data and data["prompt"] is not None
+
+ if messages_provided and prompt_provided:
+ raise ValueError("messages and prompt cannot coexist. Choose one.")
+ if not messages_provided and not prompt_provided:
+ raise ValueError("Either messages or prompt must be provided.")
+ return data
+
+app = FastAPI(title="RWKV OpenAI-Compatible API")
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
+
+@app.on_event("startup")
+async def _startup_state_load_and_persist_loop():
+ _load_state_store_from_disk()
+
+ async def _persist_loop():
+ while True:
+ try:
+ _save_state_store_to_disk(force=False)
+ except Exception:
+ pass
+ await asyncio.sleep(getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5))
+
+ try:
+ asyncio.create_task(_persist_loop())
+ except Exception:
+ pass
+
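+# Feed a prompt through the model in CHUNK_LEN slices, yielding to the event
+# loop between slices so one long prefill cannot starve other requests.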
+async def runPrefill(
+ request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
+):
+ ctx = ctx.replace("\r\n", "\n")
+ out = None
+
+ ms = MODEL_STORAGE.get(request.model)
+ if not ms or not ms.pipeline or not ms.model:
+ raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")
+ tokens = ms.pipeline.encode(ctx)
+ tokens = [int(x) for x in tokens]
+ model_tokens += tokens
+
+ while len(tokens) > 0:
+ out, model_state = ms.model.forward(
+ tokens[: CONFIG.CHUNK_LEN], model_state
+ )
+ tokens = tokens[CONFIG.CHUNK_LEN :]
+ await asyncio.sleep(0)
+
+ return out, model_tokens, model_state
+
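+# Token-by-token sampling loop. Yields dicts carrying the decoded text, the raw
+# tokens, an optional finish_reason, and the updated model state.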
+def generate(
+ request: ChatCompletionRequest,
+ out,
+ model_tokens: List[int],
+ model_state,
+ max_tokens=2048,
+):
+ ms = MODEL_STORAGE.get(request.model)
+ if not ms or not ms.pipeline or not ms.model:
+ raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")
+
+ temperature = request.temperature if request.temperature is not None else 0.2
+ top_p = request.top_p if request.top_p is not None else 0.9
+ alpha_frequency = request.count_penalty if request.count_penalty is not None else 0.0
+ alpha_presence = request.presence_penalty if request.presence_penalty is not None else 0.0
+ penalty_decay = request.penalty_decay if request.penalty_decay is not None else 0.5
+
+ args = PIPELINE_ARGS(
+ temperature=max(0.2, temperature),
+ top_p=top_p,
+ alpha_frequency=alpha_frequency,
+ alpha_presence=alpha_presence,
+ token_ban=[],
+ token_stop=[0],
+ )
+
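+    # occurrence counts how often each token has been sampled; it drives the
+    # presence/frequency penalties subtracted from the logits every step.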
+ occurrence = {}
+ out_tokens: List[int] = []
+ out_last = 0
+
+ for i in range(max_tokens):
+ for n in occurrence:
+ out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
+
+ token = ms.pipeline.sample_logits(
+ out, temperature=args.temperature, top_p=args.top_p
+ )
+
+ if token == 0 and request.stop_tokens and token in request.stop_tokens:
+ yield {
+ "content": "",
+ "tokens": out_tokens[out_last:],
+ "finish_reason": "stop:token:0",
+ "state": model_state,
+ }
+
+ del out
+ gc.collect()
+ return
+
+ out, model_state = ms.model.forward([token], model_state)
+ model_tokens.append(token)
+ out_tokens.append(token)
+
+ if request.stop_tokens and token in request.stop_tokens:
+ yield {
+ "content": "",
+ "tokens": out_tokens[out_last:],
+ "finish_reason": f"stop:token:{token}",
+ "state": model_state,
+ }
+
+ del out
+ gc.collect()
+ return
+
+ for xxx in list(occurrence.keys()):
+ occurrence[xxx] *= penalty_decay
+ occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
+
+ decoded = ms.pipeline.decode([token])
+ if "\ufffd" in decoded:
+ continue
+
+ yield {
+ "content": decoded,
+ "tokens": [token],
+ "finish_reason": None,
+ "state": model_state,
+ }
+ out_last = i + 1
+
+    else:
+        yield {
+            "content": "",
+            "tokens": [],
+            "finish_reason": "length",
+            "state": model_state,
+        }
+
+async def chatResponse(
+ request: ChatCompletionRequest,
+ model_state: Any,
+ completionId: str,
+ enableReasoning: bool,
+) -> ChatCompletion:
+ createTimestamp = time.time()
+
+ raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [])
+ detection = detect_tools_and_reasoning(raw_prompt)
+    prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [])}\n\nAssistant:{' <think' if enableReasoning else ''}"
+
+    if request.enable_web_search is not None:
+        web_search_enabled = bool(request.enable_web_search)
+    else:
+        auto_ws_flag = request.auto_web_search if request.auto_web_search is not None else True
+        web_search_enabled = bool(getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', False)) or bool(auto_ws_flag and detection.get('need_web_search'))
+
+    if request.enable_file_tool is not None:
+        file_tool_enabled = bool(request.enable_file_tool)
+    else:
+        auto_file_flag = request.auto_file_tool if request.auto_file_tool is not None else True
+        file_tool_enabled = bool((request.file_ids and len(request.file_ids) > 0) or (auto_file_flag and request.file_ids))
+ if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
+ file_tool_enabled = False
+ try:
+ if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
+ file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
+ elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
+ file_tool_enabled = bool(request.sampler_allow_file_tool)
+ else:
+ ms = MODEL_STORAGE.get(request.model)
+ if ms and ms.MODEL_CONFIG:
+ if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
+ file_tool_enabled = bool(ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
+ elif hasattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms.MODEL_CONFIG.ALLOW_FILE_TOOL:
+ file_tool_enabled = False
+ except Exception:
+ pass
+
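+    # A sampler-level ALLOW_TOOLS override takes priority over the request flags;
+    # otherwise the per-model config can still veto tool use.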
+ if request.enable_tools is not None:
+ tools_enabled = bool(request.enable_tools)
+ else:
+ auto_tools_flag = request.auto_tools if request.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
+ tools_enabled = bool(request.tools) or CONFIG.ENABLE_TOOLS_BY_DEFAULT or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
+ try:
+ if request.sampler and getattr(request.sampler, 'ALLOW_TOOLS', None) is not None:
+ tools_enabled = bool(request.sampler.ALLOW_TOOLS)
+ elif hasattr(request, 'sampler_allow_tools') and request.sampler_allow_tools is not None:
+ tools_enabled = bool(request.sampler_allow_tools)
+ else:
+ ms = MODEL_STORAGE.get(request.model)
+ if ms and ms.MODEL_CONFIG:
+ if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_TOOLS', None) is not None:
+ if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_TOOLS:
+ tools_enabled = False
+ elif hasattr(ms.MODEL_CONFIG, 'ALLOW_TOOLS') and not ms.MODEL_CONFIG.ALLOW_TOOLS:
+ tools_enabled = False
+ except Exception:
+ pass
+
+    if request.enable_reasoning:
+        reasoning_enabled = True
+    else:
+        auto_reasoning_flag = (
+            request.auto_reasoning
+            if request.auto_reasoning is not None
+            else (CONFIG.AUTO_ENABLE_REASONING and bool(detection.get('need_reasoning')))
+        )
+        reasoning_enabled = bool(enableReasoning) or bool(auto_reasoning_flag)
+ if not getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True) and request.enable_reasoning is None:
+ reasoning_enabled = False
+ try:
+ if request.sampler and getattr(request.sampler, 'ALLOW_REASONING', None) is not None:
+ reasoning_enabled = bool(request.sampler.ALLOW_REASONING)
+ elif hasattr(request, 'sampler_allow_reasoning') and request.sampler_allow_reasoning is not None:
+ reasoning_enabled = bool(request.sampler_allow_reasoning)
+ else:
+ ms = MODEL_STORAGE.get(request.model)
+ if ms and ms.MODEL_CONFIG:
+ if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_REASONING', None) is not None:
+ if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_REASONING:
+ reasoning_enabled = False
+ elif hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
+ reasoning_enabled = False
+ except Exception:
+ pass
+
+ enableReasoning = reasoning_enabled
+ try:
+ ms = MODEL_STORAGE.get(request.model)
+ if ms and ms.MODEL_CONFIG and hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
+ enableReasoning = False
+ except Exception:
+ pass
+
+ if request.enable_web_search is None:
+ request.web_search = web_search_enabled
+ if tools_enabled and not request.tools:
+ if detection.get('detected_tools'):
+ request.tools = detection.get('detected_tools')
+ if (request.enable_universal is True) or (
+ request.enable_universal is None and (request.auto_universal if request.auto_universal is not None else CONFIG.AUTO_ENABLE_TOOLS and detection.get('need_universal'))
+ ):
+ if not request.tools:
+ request.tools = [{"name": "universal", "args": {"query": raw_prompt}}]
+
+ executed_tool_calls = []
+ if file_tool_enabled and request.file_ids:
+ for fid in request.file_ids:
+ try:
+ if fid not in UPLOADED_FILES:
+ continue
+ meta = UPLOADED_FILES.get(fid)
+ if not meta:
+ continue
+ from utils import file_read_from_path
+ fpath = meta.get('path')
+ if not fpath or not os.path.exists(fpath):
+ continue
+ file_content = file_read_from_path(fpath, 200000)
+ if file_content:
+ exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
+ executed_tool_calls.append(exec_entry)
+ prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt)
+ except Exception as e:
+ logger.info(f"File injection error: {e}")
+ if request.tools:
+ try:
+ for tool in request.tools:
+ name = tool.get('name')
+ args = tool.get('args', {})
+ if name == 'web_search':
+ from utils import web_search
+
+ search_q = args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
+ search_top_k = int(args.get('top_k') or request.search_top_k or 3)
+ search_str = web_search(search_q, search_top_k)
+ if search_str:
+ search_res_struct = {"action": "web_search", "result": str(search_str), "metadata": {"query": search_q, "top_k": search_top_k, "confidence": 0.9}}
+ executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": search_top_k}, "result": search_res_struct})
+ prompt = (f"ToolResults:\n{search_res_struct.get('result')}\n\nUse these results to answer the prompt.\n\n" + prompt)
+ elif name == 'calc' or name == 'calculator':
+ from utils import calc
+
+ expr = args.get('expression')
+ if expr:
+ calc_res = calc(expr)
+ calc_res_struct = {"action": "calc", "result": str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}}
+ executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct})
+ prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res_struct.get('result')}\n\nUse this result to answer the prompt.\n\n" + prompt)
+ elif name == 'universal':
+ try:
+ res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
+ if isinstance(res, dict):
+ result_text = res.get('result') if res.get('result') is not None else ''
+ else:
+ result_text = str(res)
+ executed_tool_calls.append({"name": "universal", "args": args, "result": res})
+ prompt = (f"ToolResults:\n{result_text}\n\nUse this result to answer the prompt.\n\n" + prompt)
+ except Exception as e:
+ logger.info(f"Universal tool execution error: {e}")
+                elif name == 'file_read':
+ try:
+ fid = args.get('file_id') or args.get('id') or (request.file_ids[0] if request.file_ids else None)
+ if not fid:
+ continue
+ if fid not in UPLOADED_FILES:
+ continue
+ meta = UPLOADED_FILES.get(fid)
+ if not meta:
+ continue
+ from utils import file_read_from_path
+ fpath = meta.get('path')
+ if not fpath or not os.path.exists(fpath):
+ continue
+ file_content = file_read_from_path(fpath, int(args.get('max_bytes') or 100000))
+ exec_entry = {"name": "file_read", "args": {"file_id": fid, "max_bytes": int(args.get('max_bytes') or 100000)}, "result": {"action": "file_read", "result": file_content, "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
+ executed_tool_calls.append(exec_entry)
+ _res = exec_entry.get('result') if isinstance(exec_entry, dict) else None
+ _res_text = ''
+ if isinstance(_res, dict):
+ _res_text = _res.get('result') or ''
+ elif _res is not None:
+ _res_text = str(_res)
+ prompt = (f"ToolResults:\n{_res_text}\n\nUse these file contents to answer the prompt.\n\n" + prompt)
+ except Exception as e:
+                        logger.info(f"file_read tool error: {e}")
+                else:
+                    logger.info(f"Unsupported tool requested: {name}")
+ except Exception as e:
+ logger.info(f"Tool processing error: {e}")
+ elif request.web_search or web_search_enabled:
+ try:
+ from utils import web_search
+
+ search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
+ search_res = web_search(search_q, int(request.search_top_k or 3))
+ if search_res:
+ search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}}
+ executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct})
+ prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt
+ except Exception:
+ pass
+ logger.info(f"[REQ] {completionId} - prompt - {prompt}")
+
+ if request.state_name:
+ state_key = (request.model, request.state_name)
+ if state_key in STATE_STORE:
+ stored = STATE_STORE[state_key]
+ model_state = stored.get('state', None)
+ model_tokens = stored.get('model_tokens', [0])
+ if model_state is None:
+ out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
+ else:
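+                # The stored state tensor is kept; the logits are approximated by
+                # replaying only the most recent CHUNK_LEN tokens.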
+ out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :])
+ else:
+ out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
+ else:
+ out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
+
+ prefillTime = time.time()
+ promptTokenCount = len(model_tokens)
+
+    fullResponse = " <think" if enableReasoning else ""
+    completionTokenCount = 0
+    finishReason = None
+
+    for chunk in generate(request, out, model_tokens, model_state):
+        fullResponse += chunk["content"]
+        completionTokenCount += 1
+        model_state = chunk.get("state", model_state)
+        if chunk["finish_reason"]:
+            finishReason = chunk["finish_reason"]
+        await asyncio.sleep(0)
+
+    if request.state_name:
+        STATE_STORE[(request.model, request.state_name)] = {
+            'state': model_state,
+            'model_tokens': model_tokens,
+        }
+        _save_state_store_to_disk(force=False)
+
+    if enableReasoning:
+        reasoning_content, content = parse_think_response(fullResponse)
+    else:
+        reasoning_content, content = None, fullResponse
+
+    return ChatCompletion(
+        id=completionId,
+        created=createTimestamp,
+        model=request.model,
+        usage=Usage(
+            prompt_tokens=promptTokenCount,
+            completion_tokens=completionTokenCount,
+            total_tokens=promptTokenCount + completionTokenCount,
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
+        ),
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                message=ChatCompletionMessage(
+                    role="Assistant",
+                    content=content,
+                    reasoning_content=reasoning_content,
+                    tool_calls=executed_tool_calls or None,
+                ),
+                logprobs=None,
+                finish_reason=finishReason,
+            )
+        ],
+    )
+
+
+async def chatResponseStream(
+    request: ChatCompletionRequest,
+    model_state: Any,
+    completionId: str,
+    enableReasoning: bool,
+):
+    createTimestamp = time.time()
+
+    raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [])
+    detection = detect_tools_and_reasoning(raw_prompt)
+
+    if request.enable_web_search is not None:
+        web_search_enabled = bool(request.enable_web_search)
+    else:
+        auto_ws_flag = request.auto_web_search if request.auto_web_search is not None else True
+        web_search_enabled = bool(getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', False)) or bool(auto_ws_flag and detection.get('need_web_search'))
+
+    if request.enable_file_tool is not None:
+        file_tool_enabled = bool(request.enable_file_tool)
+    else:
+        auto_file_flag = request.auto_file_tool if request.auto_file_tool is not None else True
+        file_tool_enabled = bool((request.file_ids and len(request.file_ids) > 0) or (auto_file_flag and request.file_ids))
+ if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
+ file_tool_enabled = False
+ try:
+ if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
+ file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
+ elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
+ file_tool_enabled = bool(request.sampler_allow_file_tool)
+ else:
+ ms2 = MODEL_STORAGE.get(request.model)
+ if ms2 and ms2.MODEL_CONFIG:
+ if hasattr(ms2.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms2.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
+ file_tool_enabled = bool(ms2.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
+ elif hasattr(ms2.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms2.MODEL_CONFIG.ALLOW_FILE_TOOL:
+ file_tool_enabled = False
+ except Exception:
+ pass
+    if request.enable_tools is not None:
+        tools_enabled = bool(request.enable_tools)
+    else:
+        auto_tools_flag = request.auto_tools if request.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
+        tools_enabled = bool(request.tools) or CONFIG.ENABLE_TOOLS_BY_DEFAULT or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
+
+    prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
+    logger.info(f"[REQ] {completionId} - prompt - {prompt}")
+
+    # Tool calls executed before generation are surfaced on the first streamed chunk.
+    executed_tool_calls = []
+
+    if request.state_name and (request.model, request.state_name) in STATE_STORE:
+        stored = STATE_STORE[(request.model, request.state_name)]
+        model_state = stored.get('state', None)
+        model_tokens = stored.get('model_tokens', [0])
+        if model_state is None:
+            out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
+        else:
+            out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :])
+    else:
+        out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
+
+    prefillTime = time.time()
+    promptTokenCount = len(model_tokens)
+    completionTokenCount = 0
+    finishReason = None
+    fullText = ""
+    model_initiated_tool_calls = 0
+    # Cap on model-initiated tool calls (config override assumed; defaults to 3).
+    MODEL_MAX_TOOL_CALLS = getattr(CONFIG, 'MODEL_MAX_TOOL_CALLS', 3)
+
+    streamConfig = {
+        "isChecking": False,
+        "fullTextCursor": 0,
+        "in_think": False,
+        "cacheStr": "",
+    }
+
+    response = ChatCompletionChunk(
+        id=completionId,
+        created=createTimestamp,
+        model=request.model,
+        usage=None,
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                delta=ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None),
+                logprobs=None,
+                finish_reason=None,
+            )
+        ],
+    )
+    r_dict = response.model_dump()
+    try:
+        if executed_tool_calls and len(r_dict.get('choices') or []) > 0 and r_dict['choices'][0].get('delta') is not None:
+ r_dict['choices'][0]['delta']['tool_calls'] = executed_tool_calls
+ except Exception:
+ pass
+ yield f"data: {r_dict}\n\n"
+
+ buffer = []
+
+ if enableReasoning:
+        buffer.append("<think")
+
+        should_restart = True
+        while should_restart:
+            should_restart = False
+            for chunk in generate(request, out, model_tokens, model_state):
+                completionTokenCount += 1
+                buffer.append(chunk["content"])
+                model_state = chunk.get("state", model_state)
+                if chunk["finish_reason"]:
+                    finishReason = chunk["finish_reason"]
+                fullText = "".join(buffer)
+
+                response = ChatCompletionChunk(
+                    id=completionId,
+                    created=createTimestamp,
+                    model=request.model,
+                    usage=(
+                        Usage(
+                            prompt_tokens=promptTokenCount,
+                            completion_tokens=completionTokenCount,
+                            total_tokens=promptTokenCount + completionTokenCount,
+                            prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
+                        )
+                        if request.include_usage
+                        else None
+                    ),
+                    choices=[
+                        ChatCompletionChoice(
+                            index=0,
+                            delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=None),
+                            logprobs=None,
+                            finish_reason=finishReason,
+                        )
+                    ],
+                )
+
+                # Cache pending text while a potential <think> tag is open.
+                if not streamConfig["isChecking"] and "<" in chunk["content"]:
+                    streamConfig["isChecking"] = True
+                if streamConfig["isChecking"]:
+                    streamConfig["cacheStr"] += chunk["content"]
+                else:
+                    streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"]:]
+                markEnd = fullText.find("</think>", streamConfig["fullTextCursor"])
+ if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
+ streamConfig["isChecking"] = False
+
+                if (
+                    not streamConfig["in_think"]
+                    and streamConfig["cacheStr"].find("<think>") != -1
+                ):
+                    streamConfig["in_think"] = True
+
+ delta = response.choices[0].delta
+ if delta is None:
+ delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
+ response.choices[0].delta = delta
+                    delta.reasoning_content = (
+                        delta.reasoning_content
+                        if delta.reasoning_content is not None
+                        else streamConfig["cacheStr"].replace("<think>", "")
+                    )
+
+                elif (
+                    streamConfig["in_think"]
+                    and streamConfig["cacheStr"].find("</think>") != -1
+                ):
+                    streamConfig["in_think"] = False
+
+ delta = response.choices[0].delta
+ if delta is None:
+ delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
+ response.choices[0].delta = delta
+                    delta.content = (
+                        delta.content
+                        if delta.content is not None
+                        else streamConfig["cacheStr"].replace("</think>", "")
+                    )
+ else:
+ if streamConfig["in_think"]:
+ delta = response.choices[0].delta
+ if delta is None:
+ delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
+ response.choices[0].delta = delta
+                        delta.reasoning_content = (
+                            delta.reasoning_content
+                            if delta.reasoning_content is not None
+                            else streamConfig["cacheStr"]
+                        )
+ else:
+ delta = response.choices[0].delta
+ if delta is None:
+ delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
+ response.choices[0].delta = delta
+                        delta.content = (
+                            delta.content
+                            if delta.content is not None
+                            else streamConfig["cacheStr"]
+                        )
+ streamConfig["fullTextCursor"] = len(fullText)
+
+ delta = response.choices[0].delta
+ if delta is None:
+ delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
+ response.choices[0].delta = delta
+                if delta.content is not None or delta.reasoning_content is not None:
+ try:
+ if request.state_name:
+ STATE_STORE[(request.model, request.state_name)] = {
+ 'state': model_state,
+ 'model_tokens': model_tokens,
+ }
+ if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
+ try:
+ _save_state_store_to_disk(force=True)
+ except Exception:
+ pass
+ except Exception:
+ pass
+ if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
+ m = TOOL_CALL_RE.search(fullText)
+ if m:
+ try:
+ payload_raw = m.group(1)
+ import json
+
+ payload = json.loads(payload_raw)
+ tool_name = payload.get('name')
+ tool_args = payload.get('args', {})
+ tool_res = None
+ if tool_name == 'web_search':
+ from utils import web_search
+
+ q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
+ k = int(tool_args.get('top_k') or request.search_top_k or 3)
+ tool_res = web_search(q, k)
+ elif tool_name in ('calc', 'calculator'):
+ from utils import calc
+
+ expr = tool_args.get('expression')
+ if expr:
+ tool_res = calc(expr)
+ else:
+ try:
+ tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
+ except Exception:
+ tool_res = None
+
+ if tool_res:
+ if not isinstance(tool_res, dict):
+ if tool_name in ('calc', 'calculator'):
+ tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
+ elif tool_name == 'web_search':
+ tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
+ else:
+ tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
+ else:
+ tool_res_struct = tool_res
+ exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
+ executed_tool_calls.append(exec_entry)
+ delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
+ prompt = delta_text + prompt
+ fullText = TOOL_CALL_RE.sub('', fullText)
+ buffer = [fullText]
+ out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
+ try:
+ meta_resp = ChatCompletionChunk(
+ id=completionId,
+ created=createTimestamp,
+ model=request.model,
+ usage=(
+ Usage(
+ prompt_tokens=promptTokenCount,
+ completion_tokens=completionTokenCount,
+ total_tokens=promptTokenCount + completionTokenCount,
+ prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
+ )
+ if request.include_usage
+ else None
+ ),
+ choices=[
+ ChatCompletionChoice(
+ index=0,
+ delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls),
+ logprobs=None,
+ finish_reason=None,
+ )
+ ],
+ )
+ yield f"data: {meta_resp.model_dump_json()}\n\n"
+ except Exception:
+ pass
+ model_initiated_tool_calls += 1
+ should_restart = True
+ break
+ except Exception as e:
+ logger.info(f"Model-initiated tool handling error: {e}")
+ yield f"data: {response.model_dump_json()}\n\n"
+                stopped = False
+                for stop_words in request.stop or []:
+                    if stop_words in ''.join(buffer):
+                        finishReason = f"stop:words:{stop_words}"
+                        stopped = True
+                        break
+                if stopped:
+                    break
+
+ await asyncio.sleep(0)
+
+ del streamConfig
+ else:
+ should_restart = True
+ while should_restart:
+ should_restart = False
+ gen = generate(request, out, model_tokens, model_state)
+ for chunk in gen:
+                completionTokenCount += 1
+                buffer.append(chunk["content"])
+                model_state = chunk.get("state", model_state)
+
+ if chunk["finish_reason"]:
+ finishReason = chunk["finish_reason"]
+
+ try:
+ if request.state_name:
+ STATE_STORE[(request.model, request.state_name)] = {
+ 'state': model_state,
+ 'model_tokens': model_tokens,
+ }
+ if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
+ try:
+ _save_state_store_to_disk(force=True)
+ except Exception:
+ pass
+ except Exception:
+ pass
+
+ if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
+ fullText = ''.join(buffer)
+ m = TOOL_CALL_RE.search(fullText)
+ if m:
+ try:
+ payload_raw = m.group(1)
+ import json
+
+ payload = json.loads(payload_raw)
+ tool_name = payload.get('name')
+ tool_args = payload.get('args', {})
+ tool_res = None
+ if tool_name == 'web_search':
+ from utils import web_search
+
+ q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
+ k = int(tool_args.get('top_k') or request.search_top_k or 3)
+ tool_res = web_search(q, k)
+ elif tool_name in ('calc', 'calculator'):
+ from utils import calc
+
+ expr = tool_args.get('expression')
+ if expr:
+ tool_res = calc(expr)
+ else:
+ try:
+ tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
+ except Exception:
+ tool_res = None
+
+ if tool_res:
+ if not isinstance(tool_res, dict):
+ if tool_name in ('calc', 'calculator'):
+ tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
+ elif tool_name == 'web_search':
+ tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
+ else:
+ tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
+ else:
+ tool_res_struct = tool_res
+ exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
+ executed_tool_calls.append(exec_entry)
+ delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
+ prompt = delta_text + prompt
+ fullText = TOOL_CALL_RE.sub('', fullText)
+ buffer = [fullText]
+ out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
+ try:
+ meta_resp = ChatCompletionChunk(
+ id=completionId,
+ created=createTimestamp,
+ model=request.model,
+ usage=(
+ Usage(
+ prompt_tokens=promptTokenCount,
+ completion_tokens=completionTokenCount,
+ total_tokens=promptTokenCount + completionTokenCount,
+ prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
+ )
+ if request.include_usage
+ else None
+ ),
+ choices=[
+ ChatCompletionChoice(
+ index=0,
+ delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls),
+ logprobs=None,
+ finish_reason=None,
+ )
+ ],
+ )
+ yield f"data: {meta_resp.model_dump_json()}\n\n"
+ except Exception:
+ pass
+ model_initiated_tool_calls += 1
+ should_restart = True
+ break
+ except Exception as e:
+ logger.info(f"Model-initiated tool handling error: {e}")
+
+ response = ChatCompletionChunk(
+ id=completionId,
+ created=createTimestamp,
+ model=request.model,
+ usage=(
+ Usage(
+ prompt_tokens=promptTokenCount,
+ completion_tokens=completionTokenCount,
+ total_tokens=promptTokenCount + completionTokenCount,
+ prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
+ )
+ if request.include_usage
+ else None
+ ),
+ choices=[
+ ChatCompletionChoice(
+ index=0,
+ delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None),
+ logprobs=None,
+ finish_reason=finishReason,
+ )
+ ],
+ )
+ yield f"data: {response.model_dump_json()}\n\n"
+ await asyncio.sleep(0)
+
+    generateTime = time.time()
+
+ responseLog = {
+ "content": "".join(buffer),
+ "finish": finishReason,
+ "prefill_len": promptTokenCount,
+ "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
+ "gen_len": completionTokenCount,
+        "gen_tps": round(completionTokenCount / (generateTime - prefillTime), 2),
+ }
+ logger.info(f"[RES] {completionId} - {responseLog}")
+ if request.messages is None:
+ request.messages = []
+ request.messages.append(ChatMessage(role="Assistant", content=responseLog["content"]))
+ log(
+ {
+ **request.model_dump(),
+ **responseLog,
+ "completionId": completionId,
+ "machineLabel": os.environ.get("MACHINE_LABEL"),
+ }
+ )
+
+ del buffer
+
+ yield "data: [DONE]\n\n"
+
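+# Example request (assumes the OpenAI-style "user" role and the default
+# host/port from CONFIG; adjust to your deployment):
+#   curl -N http://127.0.0.1:7860/api/v1/chat/completions \
+#     -H "Content-Type: application/json" \
+#     -d '{"model": "rwkv-latest", "messages": [{"role": "user", "content": "Hi"}], "stream": true}'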
+@app.post("/api/v1/chat/completions")
+async def chat_completions(request: ChatCompletionRequest):
+ completionId = str(next(CompletionIdGenerator))
+ logger.info(f"[REQ] {completionId} - {request.model_dump()}")
+
+ if "rwkv-latest" in request.model:
+        if DEFAULT_MODEL_NAME is None:
+            raise HTTPException(404, "DEFAULT_MODEL_NAME not set")
+        ms_def = MODEL_STORAGE.get(DEFAULT_MODEL_NAME)
+        if not ms_def or not ms_def.MODEL_CONFIG:
+            raise HTTPException(500, "Default sampler config missing for default model")
+        defaultSamplerConfig = ms_def.MODEL_CONFIG.DEFAULT_SAMPLER
+        request.model = DEFAULT_MODEL_NAME
+ elif request.model in MODEL_STORAGE:
+ ms_sel = MODEL_STORAGE.get(request.model)
+ if not ms_sel or not ms_sel.MODEL_CONFIG:
+ raise HTTPException(500, f"Default sampler config missing for model {request.model}")
+ defaultSamplerConfig = ms_sel.MODEL_CONFIG.DEFAULT_SAMPLER
+ else:
+ raise HTTPException(404, f"Can not find `{request.model}`")
+
+ enableReasoning = request.enable_reasoning if request.enable_reasoning is not None else False
+
+ async def chatResponseStreamDisconnect():
+ logGPUState()
+
+ model_state = None
+ model_tokens_for_resume = [0]
+ state_name = request.state_name
+ if state_name is None:
+ state_name = str(uuid.uuid4())
+ request.state_name = state_name
+ state_key = (request.model, state_name)
+ if state_key in STATE_STORE:
+ stored = STATE_STORE[state_key]
+ model_state = stored.get('state', None)
+ model_tokens_for_resume = stored.get('model_tokens', [0])
+ request_dict = request.model_dump()
+
+ sampler_overrides = request_dict.get('sampler') or {}
+ for k, v in defaultSamplerConfig.model_dump().items():
+ if sampler_overrides and k in sampler_overrides and sampler_overrides.get(k) is not None:
+ request_dict[k] = sampler_overrides.get(k)
+ continue
+ if k in request_dict and request_dict[k] is None:
+ request_dict[k] = v
+ realRequest = ChatCompletionRequest(**request_dict)
+ if realRequest.stream is None:
+ realRequest.stream = CONFIG.DEFAULT_STREAM
+
+    logger.info(f"[REQ] {completionId} - Real - {realRequest.model_dump()}")
+
+ if realRequest.stream:
+ r = StreamingResponse(
+ chatResponseStream(realRequest, model_state, completionId, enableReasoning),
+ media_type="text/event-stream",
+ background=BackgroundTask(chatResponseStreamDisconnect),
+ )
+ else:
+ r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
+        try:
+ if isinstance(r, ChatCompletion):
+ d = r.model_dump()
+ d['state_name'] = state_name
+ return d
+ except Exception:
+ pass
+
+ return r
+
+logger.info("Static frontend mount removed for Python-only deploy; use API endpoints for integration")
+
+@app.get('/api/v1/models')
+def list_models():
+ out = []
+ root_defaults = {
+ 'ALLOW_FILE_TOOL_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True),
+ 'ENABLE_WEB_SEARCH_BY_DEFAULT': getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True),
+ 'ENABLE_REASONING_BY_DEFAULT': getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True),
+ 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT', True),
+ 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT', True),
+ 'SHOW_REASONING_TOGGLE_BY_DEFAULT': getattr(CONFIG, 'SHOW_REASONING_TOGGLE_BY_DEFAULT', True),
+ 'UPLOAD_URL': '/api/v1/files',
+ }
+ for m in CONFIG.MODELS:
+ out.append(
+ {
+ 'SERVICE_NAME': m.SERVICE_NAME,
+ 'DEFAULT_CHAT': m.DEFAULT_CHAT,
+ 'DEFAULT_REASONING': m.DEFAULT_REASONING,
+ 'ALLOW_WEB_SEARCH': getattr(m, 'ALLOW_WEB_SEARCH', True),
+ 'ALLOW_TOOLS': getattr(m, 'ALLOW_TOOLS', True),
+ 'ALLOW_REASONING': getattr(m, 'ALLOW_REASONING', True),
+ 'ALLOW_FILE_TOOL': getattr(m, 'ALLOW_FILE_TOOL', True),
+ 'SHOW_WEB_SEARCH_BUTTON': getattr(m, 'SHOW_WEB_SEARCH_BUTTON', True),
+ 'SHOW_FILE_UPLOAD_BUTTON': getattr(m, 'SHOW_FILE_UPLOAD_BUTTON', True),
+ 'SHOW_REASONING_TOGGLE': getattr(m, 'SHOW_REASONING_TOGGLE', True),
+ 'DEFAULT_SAMPLER': m.DEFAULT_SAMPLER.model_dump() if hasattr(m, 'DEFAULT_SAMPLER') else None,
+ 'UPLOAD_URL': '/api/v1/files',
+ 'UPLOAD_ALLOWED_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True),
+ }
+ )
+ return {'root_defaults': root_defaults, 'models': out}
+
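+# Example upload (multipart form field must be named "file"):
+#   curl -F "file=@notes.txt" http://127.0.0.1:7860/api/v1/files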
+@app.post('/api/v1/files', response_model=FileUploadResponse)
+async def upload_file(file: UploadFile = File(...), model: Optional[str] = None):
+ try:
+ if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True):
+ raise HTTPException(403, 'File uploads are disabled by server configuration')
+ if model:
+ if model not in MODEL_STORAGE:
+ raise HTTPException(404, f"Model {model} not found")
+ ms = MODEL_STORAGE[model]
+ if ms and ms.MODEL_CONFIG and not getattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL', True):
+ raise HTTPException(403, f"Model {model} does not allow file uploads")
+ from utils import save_bytes_to_upload
+
+ content = await file.read()
+ fname = file.filename if getattr(file, 'filename', None) else 'uploaded_file'
+ meta = save_bytes_to_upload(fname, content)
+ if meta.get('error'):
+ raise HTTPException(500, f"Could not save file: {meta.get('error')}")
+ UPLOADED_FILES[meta['file_id']] = meta
+ return FileUploadResponse(success=True, file=UploadedFile(**meta))
+    except HTTPException:
+        raise
+    except Exception as e:
+        raise HTTPException(500, str(e))
+
+@app.get('/api/v1/files')
+def list_files():
+ return [UploadedFile(**v).model_dump() for v in UPLOADED_FILES.values()]
+
+@app.get('/api/v1/files/{file_id}')
+def get_file(file_id: str, download: bool = False):
+ if file_id not in UPLOADED_FILES:
+ raise HTTPException(404, 'File not found')
+ meta = UPLOADED_FILES[file_id]
+    if download:
+        try:
+            # FileResponse keeps the file handle open for the duration of the
+            # response; a `with open(...)` block would close it before streaming.
+            return FileResponse(meta['path'], media_type='application/octet-stream', filename=meta.get('filename'))
+        except Exception as e:
+            raise HTTPException(500, str(e))
+ return UploadedFile(**meta)
+
+@app.delete('/api/v1/files/{file_id}')
+def delete_file(file_id: str):
+ if file_id not in UPLOADED_FILES:
+ raise HTTPException(404, 'File not found')
+ meta = UPLOADED_FILES.pop(file_id)
+ try:
+ if os.path.exists(meta['path']):
+ os.remove(meta['path'])
+ except Exception:
+ pass
+ return {'success': True}
+
+if __name__ == "__main__":
+ import uvicorn
+
+ host = CONFIG.HOST or "127.0.0.1"
+ port = CONFIG.PORT or 7860
+ uvicorn.run(app, host=host, port=port)
\ No newline at end of file