import os
# On ModelScope Studio, apply modelscope's hub patch before any model download
# happens (what patch_hub changes is defined by modelscope — confirm there).
if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
    from modelscope import patch_hub
    patch_hub()
# Configure PyTorch's CUDA caching allocator. This is set here, before `torch`
# is imported further down, so the setting is in place when CUDA is first used.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"
from config import CONFIG, ModelConfig
from utils import (
cleanMessages,
parse_think_response,
remove_nested_think_tags_stack,
format_bytes,
log,
detect_tools_and_reasoning,
universal_tool,
)
import copy, types, gc, sys, re, time, collections, asyncio
from huggingface_hub import hf_hub_download
from loguru import logger
from rich import print
from snowflake import SnowflakeGenerator
# Module-wide snowflake ID generator (by its name, presumably used to mint
# completion ids; callers are not visible here). The first argument is
# presumably the instance/worker id and `timestamp` an epoch-millisecond
# reference time — confirm against the `snowflake` package documentation.
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
from typing import List, Optional, Union, Any, Dict
import uuid
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings
import numpy as np
import torch
# Normalise CONFIG.STRATEGY before anything else reads it:
#  1. fall back to CPU when a CUDA strategy is configured but torch sees no GPU;
#  2. append a default precision when a cpu/cuda strategy names none.
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    logger.info("CUDA not found, fall back to cpu")
    CONFIG.STRATEGY = "cpu fp16"
_s = CONFIG.STRATEGY.lower()
# "bf16" is a precision too: previously "cuda bf16" was turned into the invalid
# "cuda bf16 fp16". Substring checks also cover quantised forms like "fp16i8".
if ("cpu" in _s or "cuda" in _s) and not any(p in _s for p in ("fp16", "fp32", "bf16")):
    logger.info(f"STRATEGY missing precision, appending 'fp16' to `{CONFIG.STRATEGY}`")
    CONFIG.STRATEGY = CONFIG.STRATEGY + " fp16"
# NVML (via pynvml) is optional; in this module it is only read by logGPUState
# to report VRAM usage. All three names are None when pynvml is not importable.
try:
    from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
except Exception:
    nvmlInit = None
    nvmlDeviceGetHandleByIndex = None
    nvmlDeviceGetMemoryInfo = None
# NVML handle for GPU 0; stays None when NVML is unavailable or not initialised.
gpu_h = None
if "cuda" in CONFIG.STRATEGY.lower() and nvmlInit is not None and nvmlDeviceGetHandleByIndex is not None:
    try:
        nvmlInit()
        gpu_h = nvmlDeviceGetHandleByIndex(0)
    except Exception as e:
        # pynvml can be importable while the NVML library itself is missing or
        # unusable. VRAM logging is optional, so degrade instead of aborting start-up.
        logger.info(f"NVML init failed, GPU memory logging disabled: {e}")
        # logGPUState only queries NVML while this is not None.
        nvmlDeviceGetMemoryInfo = None
def logGPUState():
    """Log torch's allocated CUDA memory plus NVML VRAM totals for the tracked GPU.

    Does nothing unless the strategy targets CUDA and the NVML query function is
    available. Logging is best-effort: a failed NVML query is logged, not raised.
    """
    # Lower-cased like the other strategy checks, so e.g. "CUDA fp16" still logs.
    if "cuda" not in CONFIG.STRATEGY.lower() or nvmlDeviceGetMemoryInfo is None:
        return
    try:
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    except Exception as e:
        logger.info(f"[STATUS] NVML memory query failed: {e}")
        return
    logger.info(
        f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
    )
# Let cuDNN autotune its kernels and permit TF32 math in matmuls and convolutions.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Flags for the rwkv package, set before `from rwkv.model import RWKV` below
# (presumably read at import time). The custom CUDA kernel is only requested
# for CUDA strategies with CONFIG.RWKV_CUDA_ON enabled.
os.environ.update(
    RWKV_V7_ON="1",
    RWKV_JIT_ON="1",
    RWKV_CUDA_ON="1" if ("cuda" in CONFIG.STRATEGY.lower() and CONFIG.RWKV_CUDA_ON) else "0",
)
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from fastapi import FastAPI, HTTPException, UploadFile, File
from starlette.background import BackgroundTask
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware
from api_types import (
ChatMessage,
ChatCompletion,
ChatCompletionChunk,
Usage,
PromptTokensDetails,
ChatCompletionChoice,
ChatCompletionMessage,
SamplerConfig,
UploadedFile,
FileUploadResponse,
)
class ModelStorage:
    """Bundle of one served model: its config entry, loaded weights and pipeline.

    Every slot defaults to None at class level. The start-up loader assigns them
    on a fresh instance before the instance is registered in MODEL_STORAGE.
    """

    # Config entry the model was loaded from.
    MODEL_CONFIG: Union[ModelConfig, None] = None
    # Loaded RWKV model.
    model: Union[RWKV, None] = None
    # rwkv PIPELINE (encode/decode/sampling) built around `model`.
    pipeline: Union[PIPELINE, None] = None
# Loaded models keyed by SERVICE_NAME.
MODEL_STORAGE: Dict[str, ModelStorage] = dict()
# Default model names, resolved while models load. The "DEFALUT" spelling is
# kept because other code refers to this exact name.
DEFALUT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None
# Conversation states keyed by (model name, state name). Each value is a dict
# with 'state' and 'model_tokens'.
STATE_STORE: Dict[tuple, Any] = dict()
# On-disk location of the persisted state store (CONFIG.STATE_STORE_PATH if set).
_STATE_STORE_PATH = getattr(CONFIG, "STATE_STORE_PATH", "./state_store.json")
# time.time() of the last successful state-store write; 0 means never written.
_LAST_STATE_STORE_WRITE = 0
# Finds a model-initiated tool call in generated text, e.g.
#   {"name": "calc", "args": {"expression": "2+2"}}
# Group 1 is the JSON object. Surrounding whitespace is part of the match so that
# `.sub('', text)` removes the call cleanly.
# The previous lazy `\{.*?\}` stopped at the first "}", so every payload with a
# nested "args" object was cut into invalid JSON. This pattern accepts one level
# of nested objects, which is enough for name/args payloads. Braces inside JSON
# string values are not handled.
# NOTE(review): if the model is prompted to wrap calls in delimiter tags,
# anchoring on those tags would be more robust than matching any brace span.
TOOL_CALL_RE = re.compile(r"\s*(\{(?:[^{}]|\{[^{}]*\})*\})\s*", re.S)
# Uploaded-file metadata keyed by file id; values are dicts read for 'path' and 'filename'.
UPLOADED_FILES: Dict[str, dict] = {}
def _serialize_state_store() -> dict:
    """Return the JSON-ready snapshot of STATE_STORE that gets written to disk.

    Only token histories are kept; the model state object is not serialised.
    Entries that are not dicts, or that lack 'model_tokens', are skipped.

    Returns:
        dict mapping "model|state_name" to
        {'model': ..., 'state_name': ..., 'model_tokens': [...]}.
    """
    snapshot = {}
    for key, record in STATE_STORE.items():
        model_name, state_name = key
        try:
            tokens = record.get('model_tokens') if isinstance(record, dict) else None
            if tokens is not None:
                snapshot[f"{model_name}|{state_name}"] = {
                    'model': model_name,
                    'state_name': state_name,
                    'model_tokens': tokens,
                }
        except Exception:
            continue
    return snapshot
def _load_state_store_from_disk():
    """Seed STATE_STORE from the JSON file written by _save_state_store_to_disk.

    Only token histories are persisted, so each restored entry gets
    ``state=None``. The chat handler recomputes the state from 'model_tokens'
    when it finds ``state`` is None.

    Bad input is handled at two levels:
    - A missing file is silently ignored. An unreadable file, invalid JSON, or a
      non-object top level is logged and ignored.
    - Malformed individual entries are skipped one by one instead of aborting
      the whole load.
    """
    import json

    if not os.path.exists(_STATE_STORE_PATH):
        return
    try:
        with open(_STATE_STORE_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        logger.info(f"Failed to load state store from disk: {e}")
        return
    if not isinstance(data, dict):
        logger.info(f"Failed to load state store from disk: expected a JSON object in {_STATE_STORE_PATH}")
        return
    for entry in data.values():
        if not isinstance(entry, dict):
            continue
        model = entry.get('model')
        state_name = entry.get('state_name')
        model_tokens = entry.get('model_tokens')
        # Names form a dict key (so they must be hashable strings). The token ids
        # are later replayed through model.forward, so require a list of ints.
        if (
            isinstance(model, str) and model
            and isinstance(state_name, str) and state_name
            and isinstance(model_tokens, list)
            and all(isinstance(t, int) for t in model_tokens)
        ):
            STATE_STORE[(model, state_name)] = {
                'state': None,
                'model_tokens': model_tokens,
            }
    logger.info(f"Loaded {len(STATE_STORE)} entries from state store file {_STATE_STORE_PATH}")
def _save_state_store_to_disk(force=False):
    """Persist the token histories in STATE_STORE to _STATE_STORE_PATH.

    Writes are throttled to at most one per CONFIG.STATE_STORE_FLUSH_INTERVAL
    seconds (default 5) unless ``force`` is true. An empty snapshot is not written.

    Data goes to a sibling ".tmp" file first and is swapped in with os.replace,
    so the target file is never partially written. A failed write is logged and
    its temp file removed; nothing is raised.

    Args:
        force: bypass the throttle and write immediately.
    """
    global _LAST_STATE_STORE_WRITE
    now = time.time()
    if not force and now - _LAST_STATE_STORE_WRITE < getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5):
        return
    tmp = None
    try:
        serial = _serialize_state_store()
        if not serial:
            return
        import json

        tmp = _STATE_STORE_PATH + ".tmp"
        with open(tmp, 'w', encoding='utf-8') as f:
            json.dump(serial, f)
        os.replace(tmp, _STATE_STORE_PATH)
        _LAST_STATE_STORE_WRITE = now
    except Exception as e:
        logger.info(f"Write state store to disk failed: {e}")
        # Don't leave a partially written temp file behind.
        if tmp is not None:
            try:
                os.remove(tmp)
            except OSError:
                pass
def _recompute_out_and_state_from_tokens(model_name: str, model_tokens: List[int]):
    """Rebuild an RWKV state (and the last logits) by replaying a token history.

    Tokens are forwarded in CONFIG.CHUNK_LEN slices. A non-list or empty history
    is replayed as [0], the same seed runPrefill callers start from. This way a
    loaded model always yields logits to sample from.

    Returns:
        (out, state) from the final forward call, or (None, None) if the model
        is not loaded.
    """
    ms = MODEL_STORAGE.get(model_name)
    if not ms or not ms.model:
        return None, None
    tokens = list(model_tokens) if isinstance(model_tokens, list) else []
    if not tokens:
        # An empty replay would never call forward() and would return out=None,
        # which generate() cannot sample from.
        tokens = [0]
    out, model_state = None, None
    while tokens:
        out, model_state = ms.model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
        tokens = tokens[CONFIG.CHUNK_LEN :]
    return out, model_state
logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
logGPUState()
logger.info(f"Configured {len(CONFIG.MODELS)} model(s) in ROOT config")
# Load every configured model at import time. Weights are downloaded from the
# Hugging Face Hub when MODEL_FILE_PATH is unset. For the default chat and
# default reasoning roles, the last model flagged for a role wins.
for model_config in CONFIG.MODELS:
    logger.info(f"Load Model - {model_config.SERVICE_NAME}")
    if model_config.MODEL_FILE_PATH is None:
        model_config.MODEL_FILE_PATH = hf_hub_download(
            repo_id=str(model_config.DOWNLOAD_MODEL_REPO_ID),
            filename=str(model_config.DOWNLOAD_MODEL_FILE_NAME),
            local_dir=str(model_config.DOWNLOAD_MODEL_DIR),
        )
    logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
    if model_config.DEFAULT_CHAT:
        if DEFALUT_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
    if model_config.DEFAULT_REASONING:
        if DEFAULT_REASONING_MODEL_NAME is not None:
            logger.info(
                f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
            )
        DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
    logger.info(f"Load Model - Loading `{model_config.SERVICE_NAME}`")
    print(model_config.DEFAULT_SAMPLER)
    _storage = ModelStorage()
    _storage.MODEL_CONFIG = model_config
    # RWKV() is given the weights path without its ".pth" suffix, as before.
    # Strip only the trailing suffix: str.replace removed every occurrence and
    # mangled paths such as "/models/v7.pth.d/model.pth".
    _storage.model = RWKV(
        model=model_config.MODEL_FILE_PATH.removesuffix(".pth"),
        strategy=CONFIG.STRATEGY,
    )
    _storage.pipeline = PIPELINE(_storage.model, model_config.VOCAB)
    # Register only once fully constructed.
    MODEL_STORAGE[model_config.SERVICE_NAME] = _storage
    # Case-insensitive, like the other strategy checks in this module.
    if "cuda" in CONFIG.STRATEGY.lower():
        torch.cuda.empty_cache()
    gc.collect()
    logGPUState()
logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
logger.info(
    f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`"
)
# With exactly one model loaded, it serves both roles regardless of its flags.
if len(MODEL_STORAGE) == 1:
    single_name = next(iter(MODEL_STORAGE))
    if DEFALUT_MODEL_NAME != single_name:
        DEFALUT_MODEL_NAME = single_name
        logger.info(f"Load Model - Only one model present; DEFALUT_MODEL_NAME set to `{DEFALUT_MODEL_NAME}`")
    if DEFAULT_REASONING_MODEL_NAME != single_name:
        DEFAULT_REASONING_MODEL_NAME = single_name
        logger.info(f"Load Model - Only one model present; DEFAULT_REASONING_MODEL_NAME set to `{DEFAULT_REASONING_MODEL_NAME}`")
# Request body for the OpenAI-compatible chat completion endpoint.
# Exactly one of `messages` or `prompt` must be given (enforced by the validator).
# Many switches default to None, which the handler treats as "not specified by
# the client" and resolves from server configuration.
class ChatCompletionRequest(BaseModel):
    model: str = Field(default="rwkv-latest")
    # Input: chat history or a raw prompt (mutually exclusive).
    messages: Optional[List[ChatMessage]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    # Sampling parameters.
    max_tokens: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    count_penalty: Optional[float] = Field(default=None)
    penalty_decay: Optional[float] = Field(default=None)
    # Response / session control.
    stream: Optional[bool] = Field(default=True)
    state_name: Optional[str] = Field(default=None)
    include_usage: Optional[bool] = Field(default=False)
    stop: Optional[list[str]] = Field(["\n\n"])
    stop_tokens: Optional[list[int]] = Field([0])
    # Web search / tools / reasoning / universal-tool switches.
    web_search: Optional[bool] = Field(default=True)
    enable_web_search: Optional[bool] = Field(default=True)
    auto_web_search: Optional[bool] = Field(default=True)
    enable_tools: Optional[bool] = Field(default=None)
    auto_tools: Optional[bool] = Field(default=True)
    enable_reasoning: Optional[bool] = Field(default=True)
    auto_reasoning: Optional[bool] = Field(default=True)
    enable_universal: Optional[bool] = Field(default=None)
    auto_universal: Optional[bool] = Field(default=True)
    search_top_k: Optional[int] = Field(default=3)
    tools: Optional[List[Dict[str, Any]]] = Field(default=None)
    # Per-request sampler overrides of the allow-lists.
    sampler_allow_web_search: Optional[bool] = Field(default=None)
    sampler_allow_tools: Optional[bool] = Field(default=None)
    sampler_allow_reasoning: Optional[bool] = Field(default=None)
    sampler: Optional[SamplerConfig] = Field(default=None)
    # Uploaded-file attachments.
    file_ids: Optional[List[str]] = Field(default=None)
    enable_file_tool: Optional[bool] = Field(default=None)
    auto_file_tool: Optional[bool] = Field(default=None)
    sampler_allow_file_tool: Optional[bool] = Field(default=None)

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        """Require exactly one non-null of `messages` / `prompt` in a raw dict payload.

        Non-dict input is passed through untouched for pydantic to handle.
        """
        if isinstance(data, dict):
            has_messages = data.get("messages") is not None
            has_prompt = data.get("prompt") is not None
            if has_messages and has_prompt:
                raise ValueError("messages and prompt cannot coexist. Choose one.")
            if not (has_messages or has_prompt):
                raise ValueError("Either messages or prompt must be provided.")
        return data
app = FastAPI(title="RWKV OpenAI-Compatible API")
# Fully permissive CORS: any origin, method and header, with credentials.
# NOTE(review): combined with allow_credentials=True, Starlette reflects the
# caller's Origin instead of sending "*" — confirm this openness is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Gzip responses of at least 1000 bytes (for clients that accept gzip), level 5.
app.add_middleware(GZipMiddleware, compresslevel=5, minimum_size=1000)
# Strong reference to the background flush task. The event loop keeps only weak
# references to tasks, so an unreferenced task can disappear mid-execution.
_STATE_STORE_PERSIST_TASK: Optional[asyncio.Task] = None


@app.on_event("startup")
async def _startup_state_load_and_persist_loop():
    """Restore persisted state-store entries and start the periodic flush loop."""
    global _STATE_STORE_PERSIST_TASK
    _load_state_store_from_disk()

    async def _persist_loop():
        # _save_state_store_to_disk throttles itself; ticking once per interval
        # is enough. It logs its own failures.
        while True:
            try:
                _save_state_store_to_disk(force=False)
            except Exception:
                pass
            await asyncio.sleep(getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5))

    try:
        _STATE_STORE_PERSIST_TASK = asyncio.create_task(_persist_loop())
    except Exception:
        pass


@app.on_event("shutdown")
async def _shutdown_flush_state_store():
    """Stop the flush loop and write any state changed since its last tick."""
    task = _STATE_STORE_PERSIST_TASK
    if task is not None:
        task.cancel()
    _save_state_store_to_disk(force=True)
async def runPrefill(
    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
):
    """Encode ``ctx`` and run it through the request's model.

    CRLF line endings are normalised to LF first. The encoded ids are appended
    to ``model_tokens`` in place, and that same list is returned. Ids are fed
    forward in CONFIG.CHUNK_LEN slices, yielding to the event loop after each.

    Returns:
        (out, model_tokens, model_state), where ``out`` comes from the last
        forward call, or is None when ``ctx`` encodes to no tokens.

    Raises:
        HTTPException(500) when the model or its pipeline is not loaded.
    """
    storage = MODEL_STORAGE.get(request.model)
    if not (storage and storage.pipeline and storage.model):
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")
    normalized = ctx.replace("\r\n", "\n")
    pending = [int(tok) for tok in storage.pipeline.encode(normalized)]
    model_tokens += pending
    out = None
    while pending:
        out, model_state = storage.model.forward(pending[: CONFIG.CHUNK_LEN], model_state)
        pending = pending[CONFIG.CHUNK_LEN :]
        await asyncio.sleep(0)
    return out, model_tokens, model_state
def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens: List[int],
    model_state,
    max_tokens=None,
):
    """Sample tokens from the request's model, yielding decoded text as it forms.

    ``out`` holds the logits the first token is sampled from. Every token fed
    back into the model is appended to ``model_tokens`` in place. Before each
    sample, presence/frequency penalties (decayed per step) are subtracted from
    the logits.

    Tokens whose decoded text is incomplete (contains U+FFFD, e.g. one byte of a
    multi-byte UTF-8 character) are held back. They are decoded again together
    with the following tokens, so characters split across tokens are emitted
    instead of silently dropped.

    Args:
        max_tokens: sampling budget. When None, ``request.max_tokens`` is used,
            falling back to 2048.

    Yields:
        dicts with keys ``content``, ``tokens``, ``finish_reason`` and ``state``.
        Content chunks carry ``finish_reason=None``. The last chunk carries
        ``"stop:token:<id>"`` or ``"length"``.

    Raises:
        HTTPException(500) when the model or its pipeline is not loaded. Because
        this is a generator, the error is raised on the first iteration.
    """
    ms = MODEL_STORAGE.get(request.model)
    if not ms or not ms.pipeline or not ms.model:
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")
    if max_tokens is None:
        # Callers in this module don't pass max_tokens, so honour the request's
        # value here instead of silently ignoring it.
        max_tokens = request.max_tokens if request.max_tokens is not None else 2048
    temperature = request.temperature if request.temperature is not None else 0.2
    top_p = request.top_p if request.top_p is not None else 0.9
    alpha_frequency = request.count_penalty if request.count_penalty is not None else 0.0
    alpha_presence = request.presence_penalty if request.presence_penalty is not None else 0.0
    penalty_decay = request.penalty_decay if request.penalty_decay is not None else 0.5
    args = PIPELINE_ARGS(
        temperature=max(0.2, temperature),
        top_p=top_p,
        alpha_frequency=alpha_frequency,
        alpha_presence=alpha_presence,
        token_ban=[],
        token_stop=[0],
    )
    # A UTF-8 character spans at most 4 bytes. If this many held-back tokens
    # still don't decode cleanly, the bytes are invalid. Emit them (as U+FFFD)
    # rather than withholding all further output.
    max_pending_tokens = 8
    occurrence = {}
    out_tokens: List[int] = []
    out_last = 0  # index in out_tokens of the first token not yet emitted as text
    for _ in range(max_tokens):
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        token = ms.pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )
        # Token 0 as a stop token ends generation without being fed to the model.
        if token == 0 and request.stop_tokens and token in request.stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": "stop:token:0",
                "state": model_state,
            }
            del out
            gc.collect()
            return
        out, model_state = ms.model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)
        if request.stop_tokens and token in request.stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": f"stop:token:{token}",
                "state": model_state,
            }
            del out
            gc.collect()
            return
        for xxx in list(occurrence.keys()):
            occurrence[xxx] *= penalty_decay
        occurrence[token] = 1 + occurrence.get(token, 0)
        pending = out_tokens[out_last:]
        decoded = ms.pipeline.decode(pending)
        if "\ufffd" in decoded and len(pending) < max_pending_tokens:
            continue
        yield {
            "content": decoded,
            "tokens": pending,
            "finish_reason": None,
            "state": model_state,
        }
        out_last = len(out_tokens)
    else:
        yield {
            "content": "",
            "tokens": [],
            "finish_reason": "length",
            "state": model_state,
        }
async def chatResponse(
request: ChatCompletionRequest,
model_state: Any,
completionId: str,
enableReasoning: bool,
) -> ChatCompletion:
createTimestamp = time.time()
raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [])
detection = detect_tools_and_reasoning(raw_prompt)
prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [])}\n\nAssistant:{' 0) or (auto_file_flag and request.file_ids))
if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
file_tool_enabled = False
try:
if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
file_tool_enabled = bool(request.sampler_allow_file_tool)
else:
ms = MODEL_STORAGE.get(request.model)
if ms and ms.MODEL_CONFIG:
if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
file_tool_enabled = bool(ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
elif hasattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms.MODEL_CONFIG.ALLOW_FILE_TOOL:
file_tool_enabled = False
except Exception:
pass
if request.enable_tools is not None:
tools_enabled = bool(request.enable_tools)
else:
auto_tools_flag = request.auto_tools if request.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
tools_enabled = bool(request.tools) or CONFIG.ENABLE_TOOLS_BY_DEFAULT or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
try:
if request.sampler and getattr(request.sampler, 'ALLOW_TOOLS', None) is not None:
tools_enabled = bool(request.sampler.ALLOW_TOOLS)
elif hasattr(request, 'sampler_allow_tools') and request.sampler_allow_tools is not None:
tools_enabled = bool(request.sampler_allow_tools)
else:
ms = MODEL_STORAGE.get(request.model)
if ms and ms.MODEL_CONFIG:
if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_TOOLS', None) is not None:
if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_TOOLS:
tools_enabled = False
elif hasattr(ms.MODEL_CONFIG, 'ALLOW_TOOLS') and not ms.MODEL_CONFIG.ALLOW_TOOLS:
tools_enabled = False
except Exception:
pass
reasoning_enabled = bool(
True
if (request.enable_reasoning is not None and request.enable_reasoning)
else (
bool(enableReasoning) or bool(request.auto_reasoning if request.auto_reasoning is not None else (CONFIG.AUTO_ENABLE_REASONING and bool(detection.get('need_reasoning'))))
)
)
if not getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True) and request.enable_reasoning is None:
reasoning_enabled = False
try:
if request.sampler and getattr(request.sampler, 'ALLOW_REASONING', None) is not None:
reasoning_enabled = bool(request.sampler.ALLOW_REASONING)
elif hasattr(request, 'sampler_allow_reasoning') and request.sampler_allow_reasoning is not None:
reasoning_enabled = bool(request.sampler_allow_reasoning)
else:
ms = MODEL_STORAGE.get(request.model)
if ms and ms.MODEL_CONFIG:
if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_REASONING', None) is not None:
if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_REASONING:
reasoning_enabled = False
elif hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
reasoning_enabled = False
except Exception:
pass
enableReasoning = reasoning_enabled
try:
ms = MODEL_STORAGE.get(request.model)
if ms and ms.MODEL_CONFIG and hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
enableReasoning = False
except Exception:
pass
if request.enable_web_search is None:
request.web_search = web_search_enabled
if tools_enabled and not request.tools:
if detection.get('detected_tools'):
request.tools = detection.get('detected_tools')
if (request.enable_universal is True) or (
request.enable_universal is None and (request.auto_universal if request.auto_universal is not None else CONFIG.AUTO_ENABLE_TOOLS and detection.get('need_universal'))
):
if not request.tools:
request.tools = [{"name": "universal", "args": {"query": raw_prompt}}]
executed_tool_calls = []
if file_tool_enabled and request.file_ids:
for fid in request.file_ids:
try:
if fid not in UPLOADED_FILES:
continue
meta = UPLOADED_FILES.get(fid)
if not meta:
continue
from utils import file_read_from_path
fpath = meta.get('path')
if not fpath or not os.path.exists(fpath):
continue
file_content = file_read_from_path(fpath, 200000)
if file_content:
exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
executed_tool_calls.append(exec_entry)
prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt)
except Exception as e:
logger.info(f"File injection error: {e}")
if file_tool_enabled and request.file_ids:
for fid in request.file_ids:
try:
if fid not in UPLOADED_FILES:
continue
meta = UPLOADED_FILES.get(fid)
if not meta:
continue
from utils import file_read_from_path
fpath = meta.get('path')
if not fpath or not os.path.exists(fpath):
continue
file_content = file_read_from_path(fpath, 200000)
if file_content:
exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
executed_tool_calls.append(exec_entry)
prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt)
except Exception as e:
logger.info(f"File injection error: {e}")
if request.tools:
try:
for tool in request.tools:
name = tool.get('name')
args = tool.get('args', {})
if name == 'web_search':
from utils import web_search
search_q = args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
search_top_k = int(args.get('top_k') or request.search_top_k or 3)
search_str = web_search(search_q, search_top_k)
if search_str:
search_res_struct = {"action": "web_search", "result": str(search_str), "metadata": {"query": search_q, "top_k": search_top_k, "confidence": 0.9}}
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": search_top_k}, "result": search_res_struct})
prompt = (f"ToolResults:\n{search_res_struct.get('result')}\n\nUse these results to answer the prompt.\n\n" + prompt)
elif name == 'calc' or name == 'calculator':
from utils import calc
expr = args.get('expression')
if expr:
calc_res = calc(expr)
calc_res_struct = {"action": "calc", "result": str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}}
executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct})
prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res_struct.get('result')}\n\nUse this result to answer the prompt.\n\n" + prompt)
elif name == 'universal':
try:
res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
if isinstance(res, dict):
result_text = res.get('result') if res.get('result') is not None else ''
else:
result_text = str(res)
executed_tool_calls.append({"name": "universal", "args": args, "result": res})
prompt = (f"ToolResults:\n{result_text}\n\nUse this result to answer the prompt.\n\n" + prompt)
except Exception as e:
logger.info(f"Universal tool execution error: {e}")
else:
logger.info(f"Unsupported tool requested: {name}")
if name == 'file_read':
try:
fid = args.get('file_id') or args.get('id') or (request.file_ids[0] if request.file_ids else None)
if not fid:
continue
if fid not in UPLOADED_FILES:
continue
meta = UPLOADED_FILES.get(fid)
if not meta:
continue
from utils import file_read_from_path
fpath = meta.get('path')
if not fpath or not os.path.exists(fpath):
continue
file_content = file_read_from_path(fpath, int(args.get('max_bytes') or 100000))
exec_entry = {"name": "file_read", "args": {"file_id": fid, "max_bytes": int(args.get('max_bytes') or 100000)}, "result": {"action": "file_read", "result": file_content, "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
executed_tool_calls.append(exec_entry)
_res = exec_entry.get('result') if isinstance(exec_entry, dict) else None
_res_text = ''
if isinstance(_res, dict):
_res_text = _res.get('result') or ''
elif _res is not None:
_res_text = str(_res)
prompt = (f"ToolResults:\n{_res_text}\n\nUse these file contents to answer the prompt.\n\n" + prompt)
except Exception as e:
logger.info(f"file_read tool error: {e}")
except Exception as e:
logger.info(f"Tool processing error: {e}")
elif request.web_search or web_search_enabled:
try:
from utils import web_search
search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
search_res = web_search(search_q, int(request.search_top_k or 3))
if search_res:
search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}}
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct})
prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt
except Exception:
pass
logger.info(f"[REQ] {completionId} - prompt - {prompt}")
if request.state_name:
state_key = (request.model, request.state_name)
if state_key in STATE_STORE:
stored = STATE_STORE[state_key]
model_state = stored.get('state', None)
model_tokens = stored.get('model_tokens', [0])
if model_state is None:
out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
else:
out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :])
else:
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
else:
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
prefillTime = time.time()
promptTokenCount = len(model_tokens)
fullResponse = " 0) or (auto_file_flag and request.file_ids))
if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
file_tool_enabled = False
try:
if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
file_tool_enabled = bool(request.sampler_allow_file_tool)
else:
ms2 = MODEL_STORAGE.get(request.model)
if ms2 and ms2.MODEL_CONFIG:
if hasattr(ms2.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms2.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
file_tool_enabled = bool(ms2.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
elif hasattr(ms2.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms2.MODEL_CONFIG.ALLOW_FILE_TOOL:
file_tool_enabled = False
except Exception:
pass
prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{' 0 and r_dict['choices'][0].get('delta') is not None:
r_dict['choices'][0]['delta']['tool_calls'] = executed_tool_calls
except Exception:
pass
yield f"data: {r_dict}\n\n"
buffer = []
if enableReasoning:
buffer.append("", streamConfig["fullTextCursor"])
if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
streamConfig["isChecking"] = False
if (
not streamConfig["in_think"]
and streamConfig["cacheStr"].find("") != -1
):
streamConfig["in_think"] = True
delta = response.choices[0].delta
if delta is None:
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
response.choices[0].delta = delta
delta.reasoning_content = (
delta.reasoning_content
if delta.reasoning_content != None
else "" + streamConfig["cacheStr"].replace("", "")
)
elif (
streamConfig["in_think"]
and streamConfig["cacheStr"].find("") != -1
):
streamConfig["in_think"] = False
delta = response.choices[0].delta
if delta is None:
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
response.choices[0].delta = delta
delta.content = (
delta.content
if delta.content != None
else "" + streamConfig["cacheStr"].replace("", "")
)
else:
if streamConfig["in_think"]:
delta = response.choices[0].delta
if delta is None:
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
response.choices[0].delta = delta
delta.reasoning_content = (
delta.reasoning_content
if delta.reasoning_content != None
else "" + streamConfig["cacheStr"]
)
else:
delta = response.choices[0].delta
if delta is None:
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
response.choices[0].delta = delta
delta.content = (
delta.content
if delta.content != None
else "" + streamConfig["cacheStr"]
)
streamConfig["fullTextCursor"] = len(fullText)
delta = response.choices[0].delta
if delta is None:
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
response.choices[0].delta = delta
if delta.content != None or delta.reasoning_content != None:
try:
if request.state_name:
STATE_STORE[(request.model, request.state_name)] = {
'state': model_state,
'model_tokens': model_tokens,
}
if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
try:
_save_state_store_to_disk(force=True)
except Exception:
pass
except Exception:
pass
if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
m = TOOL_CALL_RE.search(fullText)
if m:
try:
payload_raw = m.group(1)
import json
payload = json.loads(payload_raw)
tool_name = payload.get('name')
tool_args = payload.get('args', {})
tool_res = None
if tool_name == 'web_search':
from utils import web_search
q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
k = int(tool_args.get('top_k') or request.search_top_k or 3)
tool_res = web_search(q, k)
elif tool_name in ('calc', 'calculator'):
from utils import calc
expr = tool_args.get('expression')
if expr:
tool_res = calc(expr)
else:
try:
tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
except Exception:
tool_res = None
if tool_res:
if not isinstance(tool_res, dict):
if tool_name in ('calc', 'calculator'):
tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
elif tool_name == 'web_search':
tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
else:
tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
else:
tool_res_struct = tool_res
exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
executed_tool_calls.append(exec_entry)
delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
prompt = delta_text + prompt
fullText = TOOL_CALL_RE.sub('', fullText)
buffer = [fullText]
out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
try:
meta_resp = ChatCompletionChunk(
id=completionId,
created=createTimestamp,
model=request.model,
usage=(
Usage(
prompt_tokens=promptTokenCount,
completion_tokens=completionTokenCount,
total_tokens=promptTokenCount + completionTokenCount,
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
)
if request.include_usage
else None
),
choices=[
ChatCompletionChoice(
index=0,
delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls),
logprobs=None,
finish_reason=None,
)
],
)
yield f"data: {meta_resp.model_dump_json()}\n\n"
except Exception:
pass
model_initiated_tool_calls += 1
should_restart = True
break
except Exception as e:
logger.info(f"Model-initiated tool handling error: {e}")
yield f"data: {response.model_dump_json()}\n\n"
for stop_words in request.stop or []:
if stop_words in ''.join(buffer):
finishReason = f"stop:words:{stop_words}"
return
await asyncio.sleep(0)
del streamConfig
else:
should_restart = True
while should_restart:
should_restart = False
gen = generate(request, out, model_tokens, model_state)
for chunk in gen:
completionTokenCount += 1
buffer.append(chunk["content"])
if chunk["finish_reason"]:
finishReason = chunk["finish_reason"]
try:
if request.state_name:
STATE_STORE[(request.model, request.state_name)] = {
'state': model_state,
'model_tokens': model_tokens,
}
if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
try:
_save_state_store_to_disk(force=True)
except Exception:
pass
except Exception:
pass
if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
fullText = ''.join(buffer)
m = TOOL_CALL_RE.search(fullText)
if m:
try:
payload_raw = m.group(1)
import json
payload = json.loads(payload_raw)
tool_name = payload.get('name')
tool_args = payload.get('args', {})
tool_res = None
if tool_name == 'web_search':
from utils import web_search
q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
k = int(tool_args.get('top_k') or request.search_top_k or 3)
tool_res = web_search(q, k)
elif tool_name in ('calc', 'calculator'):
from utils import calc
expr = tool_args.get('expression')
if expr:
tool_res = calc(expr)
else:
try:
tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
except Exception:
tool_res = None
if tool_res:
if not isinstance(tool_res, dict):
if tool_name in ('calc', 'calculator'):
tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
elif tool_name == 'web_search':
tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
else:
tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
else:
tool_res_struct = tool_res
exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
executed_tool_calls.append(exec_entry)
delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
prompt = delta_text + prompt
fullText = TOOL_CALL_RE.sub('', fullText)
buffer = [fullText]
out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
try:
meta_resp = ChatCompletionChunk(
id=completionId,
created=createTimestamp,
model=request.model,
usage=(
Usage(
prompt_tokens=promptTokenCount,
completion_tokens=completionTokenCount,
total_tokens=promptTokenCount + completionTokenCount,
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
)
if request.include_usage
else None
),
choices=[
ChatCompletionChoice(
index=0,
delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls),
logprobs=None,
finish_reason=None,
)
],
)
yield f"data: {meta_resp.model_dump_json()}\n\n"
except Exception:
pass
model_initiated_tool_calls += 1
should_restart = True
break
except Exception as e:
logger.info(f"Model-initiated tool handling error: {e}")
response = ChatCompletionChunk(
id=completionId,
created=createTimestamp,
model=request.model,
usage=(
Usage(
prompt_tokens=promptTokenCount,
completion_tokens=completionTokenCount,
total_tokens=promptTokenCount + completionTokenCount,
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
)
if request.include_usage
else None
),
choices=[
ChatCompletionChoice(
index=0,
delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None),
logprobs=None,
finish_reason=finishReason,
)
],
)
yield f"data: {response.model_dump_json()}\n\n"
await asyncio.sleep(0)
genenrateTime = time.time()
responseLog = {
"content": "".join(buffer),
"finish": finishReason,
"prefill_len": promptTokenCount,
"prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
"gen_len": completionTokenCount,
"gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2),
}
logger.info(f"[RES] {completionId} - {responseLog}")
if request.messages is None:
request.messages = []
request.messages.append(ChatMessage(role="Assistant", content=responseLog["content"]))
log(
{
**request.model_dump(),
**responseLog,
"completionId": completionId,
"machineLabel": os.environ.get("MACHINE_LABEL"),
}
)
del buffer
yield "data: [DONE]\n\n"
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle an OpenAI-style chat completion request.

    Resolves the target model (any name containing ``rwkv-latest`` is mapped
    to ``DEFALUT_MODEL_NAME``), restores a previously stored state for
    ``(model, state_name)`` if one exists, fills unset sampler fields from the
    model's ``DEFAULT_SAMPLER`` (non-None entries of the request's ``sampler``
    dict take precedence), then dispatches to the streaming or non-streaming
    generator.

    Args:
        request: Incoming request. Mutated in place: ``model`` may be
            rewritten to the default model name, and ``state_name`` is set to
            a fresh UUID string when the caller did not supply one.

    Returns:
        A ``StreamingResponse`` with ``text/event-stream`` when streaming.
        Otherwise, if ``chatResponse`` returns a ``ChatCompletion``, its dict
        form with an added ``state_name`` key; else the raw result.

    Raises:
        HTTPException: 404 if the default model name is unset or the model
            is unknown; 500 if the model's config is missing.
    """
    completionId = str(next(CompletionIdGenerator))
    logger.info(f"[REQ] {completionId} - {request.model_dump()}")

    if "rwkv-latest" in request.model:
        if DEFALUT_MODEL_NAME is None:
            raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
        ms_def = MODEL_STORAGE.get(DEFALUT_MODEL_NAME)
        if not ms_def or not ms_def.MODEL_CONFIG:
            raise HTTPException(500, "Default sampler config missing for default model")
        defaultSamplerConfig = ms_def.MODEL_CONFIG.DEFAULT_SAMPLER
        request.model = DEFALUT_MODEL_NAME
    elif request.model in MODEL_STORAGE:
        ms_sel = MODEL_STORAGE.get(request.model)
        if not ms_sel or not ms_sel.MODEL_CONFIG:
            raise HTTPException(500, f"Default sampler config missing for model {request.model}")
        defaultSamplerConfig = ms_sel.MODEL_CONFIG.DEFAULT_SAMPLER
    else:
        raise HTTPException(404, f"Can not find `{request.model}`")

    enableReasoning = request.enable_reasoning if request.enable_reasoning is not None else False

    async def chatResponseStreamDisconnect():
        # Scheduled as the StreamingResponse background task below.
        logGPUState()

    # Give every request a state name so its resulting state can be stored;
    # non-streaming responses echo it back to the caller.
    state_name = request.state_name
    if state_name is None:
        state_name = str(uuid.uuid4())
        request.state_name = state_name

    model_state = None
    state_key = (request.model, state_name)
    if state_key in STATE_STORE:
        # NOTE(review): the stored 'model_tokens' entry is not forwarded to
        # chatResponse/chatResponseStream (they only receive the state) —
        # confirm whether token history should be resumed as well.
        model_state = STATE_STORE[state_key].get('state', None)

    request_dict = request.model_dump()
    sampler_overrides = request_dict.get('sampler') or {}
    for key, default_value in defaultSamplerConfig.model_dump().items():
        override = sampler_overrides.get(key)
        if override is not None:
            # An explicit per-request sampler override always wins.
            request_dict[key] = override
        elif key in request_dict and request_dict[key] is None:
            # Otherwise only fill fields the caller left unset.
            request_dict[key] = default_value

    realRequest = ChatCompletionRequest(**request_dict)
    if realRequest.stream is None:
        realRequest.stream = CONFIG.DEFAULT_STREAM
    # Log the effective request (defaults/overrides applied), not the raw one.
    logger.info(f"[REQ] {completionId} - Real - {realRequest.model_dump()}")

    if realRequest.stream:
        return StreamingResponse(
            chatResponseStream(realRequest, model_state, completionId, enableReasoning),
            media_type="text/event-stream",
            background=BackgroundTask(chatResponseStreamDisconnect),
        )

    r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
    if isinstance(r, ChatCompletion):
        try:
            d = r.model_dump()
        except Exception as e:
            # Best effort: fall back to returning the completion object as-is.
            logger.warning(f"[RES] {completionId} - could not attach state_name: {e}")
        else:
            d['state_name'] = state_name
            return d
    return r
# Announce at import time that no static frontend is mounted here;
# integrations are expected to talk to the API routes directly.
logger.info(
    "Static frontend mount removed for Python-only deploy; "
    "use API endpoints for integration"
)
@app.get('/api/v1/models')
def list_models():
    """Describe the configured models and the UI feature flags for each.

    Returns:
        A dict with two keys:

        * ``root_defaults`` — server-wide flags read from ``CONFIG`` (each
          ``True`` when the attribute is absent) plus the upload URL.
        * ``models`` — one dict per entry of ``CONFIG.MODELS`` holding its
          name, ``DEFAULT_CHAT``/``DEFAULT_REASONING`` values, capability and
          button flags (``True`` when absent), the serialized
          ``DEFAULT_SAMPLER`` (``None`` when missing or unset), and upload
          info.
    """
    upload_url = '/api/v1/files'
    upload_allowed_by_default = getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True)

    root_defaults = {
        'ALLOW_FILE_TOOL_BY_DEFAULT': upload_allowed_by_default,
        'ENABLE_WEB_SEARCH_BY_DEFAULT': getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True),
        'ENABLE_REASONING_BY_DEFAULT': getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True),
        'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT', True),
        'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT', True),
        'SHOW_REASONING_TOGGLE_BY_DEFAULT': getattr(CONFIG, 'SHOW_REASONING_TOGGLE_BY_DEFAULT', True),
        'UPLOAD_URL': upload_url,
    }

    models = []
    for m in CONFIG.MODELS:
        # getattr(..., None) instead of hasattr(): an attribute that exists
        # but holds None must not crash the whole listing on .model_dump().
        sampler = getattr(m, 'DEFAULT_SAMPLER', None)
        models.append(
            {
                'SERVICE_NAME': m.SERVICE_NAME,
                'DEFAULT_CHAT': m.DEFAULT_CHAT,
                'DEFAULT_REASONING': m.DEFAULT_REASONING,
                'ALLOW_WEB_SEARCH': getattr(m, 'ALLOW_WEB_SEARCH', True),
                'ALLOW_TOOLS': getattr(m, 'ALLOW_TOOLS', True),
                'ALLOW_REASONING': getattr(m, 'ALLOW_REASONING', True),
                'ALLOW_FILE_TOOL': getattr(m, 'ALLOW_FILE_TOOL', True),
                'SHOW_WEB_SEARCH_BUTTON': getattr(m, 'SHOW_WEB_SEARCH_BUTTON', True),
                'SHOW_FILE_UPLOAD_BUTTON': getattr(m, 'SHOW_FILE_UPLOAD_BUTTON', True),
                'SHOW_REASONING_TOGGLE': getattr(m, 'SHOW_REASONING_TOGGLE', True),
                'DEFAULT_SAMPLER': sampler.model_dump() if sampler is not None else None,
                'UPLOAD_URL': upload_url,
                'UPLOAD_ALLOWED_BY_DEFAULT': upload_allowed_by_default,
            }
        )
    return {'root_defaults': root_defaults, 'models': models}
@app.post('/api/v1/files', response_model=FileUploadResponse)
async def upload_file(file: UploadFile = File(...), model: Optional[str] = None):
    """Save an uploaded file and register its metadata in ``UPLOADED_FILES``.

    Args:
        file: The multipart upload.
        model: Optional model name. When given, it must exist in
            ``MODEL_STORAGE`` and must not have ``ALLOW_FILE_TOOL`` disabled.

    Returns:
        ``FileUploadResponse`` wrapping the metadata returned by
        ``save_bytes_to_upload``.

    Raises:
        HTTPException: 403 if uploads are disabled server-wide or for the
            model; 404 if ``model`` is unknown; 500 if saving fails or any
            other unexpected error occurs.
    """
    try:
        if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True):
            raise HTTPException(403, 'File uploads are disabled by server configuration')
        if model:
            if model not in MODEL_STORAGE:
                raise HTTPException(404, f"Model {model} not found")
            ms = MODEL_STORAGE[model]
            if ms and ms.MODEL_CONFIG and not getattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL', True):
                raise HTTPException(403, f"Model {model} does not allow file uploads")
        from utils import save_bytes_to_upload
        content = await file.read()
        # The filename is client-controlled. Keep only its final path component
        # in case save_bytes_to_upload joins it onto a directory (not visible here).
        fname = os.path.basename(getattr(file, 'filename', None) or '') or 'uploaded_file'
        meta = save_bytes_to_upload(fname, content)
        if meta.get('error'):
            raise HTTPException(500, f"Could not save file: {meta.get('error')}")
        UPLOADED_FILES[meta['file_id']] = meta
        return FileUploadResponse(success=True, file=UploadedFile(**meta))
    except HTTPException:
        # Deliberate 403/404/500 responses above must keep their status code
        # instead of being rewrapped as a generic 500 by the handler below.
        raise
    except Exception as e:
        raise HTTPException(500, str(e)) from e
@app.get('/api/v1/files')
def list_files():
    """Return the metadata of every registered upload as plain dicts."""
    listing = []
    for meta in UPLOADED_FILES.values():
        # Round-trip through the schema so the output matches UploadedFile.
        listing.append(UploadedFile(**meta).model_dump())
    return listing
@app.get('/api/v1/files/{file_id}')
def get_file(file_id: str, download: bool = False):
    """Return an upload's metadata, or stream its bytes when ``download`` is set.

    Args:
        file_id: Key into ``UPLOADED_FILES``.
        download: When true, respond with the file contents as
            ``application/octet-stream`` instead of its metadata.

    Returns:
        ``UploadedFile`` metadata, or a ``StreamingResponse`` of the bytes.

    Raises:
        HTTPException: 404 if the id is unknown; 500 if the stored path is
            missing or cannot be opened.
    """
    if file_id not in UPLOADED_FILES:
        raise HTTPException(404, 'File not found')
    meta = UPLOADED_FILES[file_id]
    if not download:
        return UploadedFile(**meta)

    # Open eagerly so a missing/unreadable file still yields a clean 500
    # before any response bytes are sent.
    try:
        fh = open(meta['path'], 'rb')
    except (KeyError, OSError) as e:
        raise HTTPException(500, str(e)) from e

    def _iter_chunks(chunk_size=64 * 1024):
        # The generator owns the handle. The previous `with open(...)` closed
        # the file as soon as get_file returned, i.e. before StreamingResponse
        # had read anything. Closing here happens after the last chunk, or
        # when the generator is closed/discarded.
        try:
            while True:
                chunk = fh.read(chunk_size)
                if not chunk:
                    break
                yield chunk
        finally:
            fh.close()

    return StreamingResponse(_iter_chunks(), media_type='application/octet-stream')
@app.delete('/api/v1/files/{file_id}')
def delete_file(file_id: str):
    """Unregister an upload and remove its file from disk (best effort).

    Args:
        file_id: Key into ``UPLOADED_FILES``.

    Returns:
        ``{'success': True}`` once the entry is unregistered, even if the
        on-disk file could not be removed (that failure is only logged).

    Raises:
        HTTPException: 404 if the id is unknown.
    """
    # A single pop() with a default replaces "check membership, then pop",
    # so there is no window between the test and the removal.
    meta = UPLOADED_FILES.pop(file_id, None)
    if meta is None:
        raise HTTPException(404, 'File not found')
    path = meta.get('path')
    if path:
        try:
            os.remove(path)
        except FileNotFoundError:
            # Already gone: the desired end state.
            pass
        except OSError as e:
            logger.warning(f"Could not remove uploaded file {path}: {e}")
    return {'success': True}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn, falling back to
    # 127.0.0.1:7860 when CONFIG.HOST / CONFIG.PORT are falsy.
    import uvicorn

    uvicorn.run(
        app,
        host=CONFIG.HOST or "127.0.0.1",
        port=CONFIG.PORT or 7860,
    )