|
|
import os
|
|
|
|
|
|
# ModelScope "studio" deployments patch the hub client before anything else is
# imported (presumably so later downloads go through ModelScope — TODO confirm).
if os.getenv("MODELSCOPE_ENVIRONMENT") == "studio":
    from modelscope import patch_hub

    patch_hub()

# Exported ahead of the torch import further down in this module.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:256")
|
|
|
|
|
|
|
|
|
from config import CONFIG, ModelConfig
|
|
|
from utils import (
|
|
|
cleanMessages,
|
|
|
parse_think_response,
|
|
|
remove_nested_think_tags_stack,
|
|
|
format_bytes,
|
|
|
log,
|
|
|
detect_tools_and_reasoning,
|
|
|
universal_tool,
|
|
|
)
|
|
|
|
|
|
import copy, types, gc, sys, re, time, collections, asyncio
|
|
|
from huggingface_hub import hf_hub_download
|
|
|
from loguru import logger
|
|
|
from rich import print
|
|
|
|
|
|
from snowflake import SnowflakeGenerator
|
|
|
|
|
|
# Process-wide Snowflake ID generator (instance id 42); presumably used to
# mint completion ids — TODO confirm against callers.
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
|
|
|
|
|
|
from typing import List, Optional, Union, Any, Dict
|
|
|
import uuid
|
|
|
from pydantic import BaseModel, Field, model_validator
|
|
|
from pydantic_settings import BaseSettings
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
|
|
|
|
|
|
# Normalise CONFIG.STRATEGY before any model is constructed.

# 1) A CUDA strategy on a machine without CUDA is downgraded to CPU.
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    logger.info("CUDA not found, fall back to cpu")
    CONFIG.STRATEGY = "cpu fp16"

# 2) A device without an explicit precision gets "fp16" appended.
_s = CONFIG.STRATEGY.lower()
if any(device in _s for device in ("cpu", "cuda")) and not any(
    precision in _s for precision in ("fp16", "fp32")
):
    logger.info(f"STRATEGY missing precision, appending 'fp16' to `{CONFIG.STRATEGY}`")
    CONFIG.STRATEGY += " fp16"
|
|
|
|
|
|
|
|
|
# NVML is optional: when pynvml cannot be imported the three entry points are
# bound to None and GPU memory logging is skipped.
try:
    from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
except Exception:
    nvmlInit = None
    nvmlDeviceGetHandleByIndex = None
    nvmlDeviceGetMemoryInfo = None

# Always bind gpu_h so later references can never raise NameError.
gpu_h = None

if "cuda" in CONFIG.STRATEGY.lower() and nvmlInit is not None and nvmlDeviceGetHandleByIndex is not None:
    try:
        nvmlInit()
        gpu_h = nvmlDeviceGetHandleByIndex(0)
    except Exception as e:
        # A missing driver/GPU must not abort module import. Clearing
        # nvmlDeviceGetMemoryInfo turns logGPUState() into a no-op instead of
        # letting it query an invalid handle.
        logger.info(f"NVML initialisation failed, GPU memory logging disabled: {e}")
        gpu_h = None
        nvmlDeviceGetMemoryInfo = None
|
|
|
|
|
|
|
|
|
def logGPUState():
    """Log torch-allocated and NVML-reported GPU memory.

    Does nothing unless the configured strategy targets CUDA and the NVML
    memory query function is available. The strategy check is
    case-insensitive, matching the other strategy checks in this module.
    """
    if "cuda" not in CONFIG.STRATEGY.lower() or nvmlDeviceGetMemoryInfo is None:
        return

    gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
    logger.info(
        f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
    )
|
|
|
|
|
|
|
|
|
# Enable cuDNN autotuning and TF32 for cuDNN and CUDA matmul.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# RWKV feature switches, exported ahead of the `rwkv` imports that follow
# (presumably the package reads them at import time — TODO confirm).
os.environ["RWKV_V7_ON"] = "1"
os.environ["RWKV_JIT_ON"] = "1"
# RWKV_CUDA_ON is "1" only when it is configured AND the strategy targets CUDA.
os.environ["RWKV_CUDA_ON"] = (
    "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
)
|
|
|
|
|
|
from rwkv.model import RWKV
|
|
|
from rwkv.utils import PIPELINE, PIPELINE_ARGS
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File
|
|
|
from starlette.background import BackgroundTask
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
from fastapi.middleware.gzip import GZipMiddleware
|
|
|
|
|
|
|
|
|
from api_types import (
|
|
|
ChatMessage,
|
|
|
ChatCompletion,
|
|
|
ChatCompletionChunk,
|
|
|
Usage,
|
|
|
PromptTokensDetails,
|
|
|
ChatCompletionChoice,
|
|
|
ChatCompletionMessage,
|
|
|
SamplerConfig,
|
|
|
UploadedFile,
|
|
|
FileUploadResponse,
|
|
|
)
|
|
|
|
|
|
|
|
|
class ModelStorage:
    """Bundle of everything kept for one loaded model.

    Instances start with all three attributes at their class-level default of
    None; `load_models_once` assigns them per instance.
    """

    # Configuration entry this model was loaded from.
    MODEL_CONFIG: Optional[ModelConfig] = None
    # The loaded RWKV model.
    model: Optional[RWKV] = None
    # PIPELINE built around `model`; used for encode/decode/sampling.
    pipeline: Optional[PIPELINE] = None
|
|
|
|
|
|
|
|
|
# Loaded models keyed by service name (populated by load_models_once).
MODEL_STORAGE: Dict[str, ModelStorage] = {}

# NOTE: "DEFALUT" is a typo for "DEFAULT"; the name is referenced elsewhere in
# this module, so it is kept as-is here.
DEFALUT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None

# (model name, state name) -> {'state': ..., 'model_tokens': [...]}.
STATE_STORE: Dict[tuple, Any] = {}

# JSON file the state store is persisted to, and the time.time() of the last
# successful write (0 means "never written").
_STATE_STORE_PATH = getattr(CONFIG, 'STATE_STORE_PATH', './state_store.json')
_LAST_STATE_STORE_WRITE = 0

# Matches a JSON object wrapped in <tool-call>...</tool-call>; re.S lets the
# object span multiple lines.
TOOL_CALL_RE = re.compile(r"<tool-call>\s*(\{.*?\})\s*</tool-call>", re.S)

# Uploaded file id -> metadata dict (readers use the 'path' and 'filename' keys).
UPLOADED_FILES: Dict[str, dict] = {}
|
|
|
|
|
|
|
|
|
def _serialize_state_store() -> dict:
    """Return a JSON-friendly snapshot of ``STATE_STORE``.

    Only the token history of each entry is exported (the in-memory model
    state is not). Entries are skipped, never raised on, when:

    * the key is not a ``(model_name, state_name)`` pair,
    * the value is not a dict, or
    * the value has no ``'model_tokens'`` (missing or None).

    Returns:
        Mapping of ``"<model_name>|<state_name>"`` to a dict with the keys
        ``'model'``, ``'state_name'`` and ``'model_tokens'``.
    """
    serial = {}
    for key, entry in STATE_STORE.items():
        # Validate the key shape up front: a single malformed key must not
        # abort serialisation of every other entry.
        if not (isinstance(key, tuple) and len(key) == 2):
            continue
        if not isinstance(entry, dict):
            continue
        model_tokens = entry.get('model_tokens')
        if model_tokens is None:
            continue
        model_name, state_name = key
        serial[f"{model_name}|{state_name}"] = {
            'model': model_name,
            'state_name': state_name,
            'model_tokens': model_tokens,
        }
    return serial
|
|
|
|
|
|
|
|
|
def _load_state_store_from_disk():
    """Populate ``STATE_STORE`` from the JSON file at ``_STATE_STORE_PATH``.

    Each stored entry must be a dict with a non-empty ``'model'``, a non-empty
    ``'state_name'`` and a list under ``'model_tokens'``; anything else is
    skipped individually. Loaded entries get ``'state': None`` because only
    the token history is persisted.

    A missing file is not an error. Any other failure (unreadable file,
    invalid JSON, wrong top-level type) is logged and otherwise ignored.
    """
    import json

    try:
        if not os.path.exists(_STATE_STORE_PATH):
            return

        with open(_STATE_STORE_PATH, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")

        # The JSON keys are redundant with the 'model'/'state_name' values,
        # so only the values are read.
        for entry in data.values():
            if not isinstance(entry, dict):
                # One corrupt entry should not discard the rest of the file.
                continue
            model = entry.get('model')
            state_name = entry.get('state_name')
            model_tokens = entry.get('model_tokens')
            if model and state_name and isinstance(model_tokens, list):
                STATE_STORE[(model, state_name)] = {
                    'state': None,
                    'model_tokens': model_tokens,
                }

        logger.info(f"Loaded {len(STATE_STORE)} entries from state store file {_STATE_STORE_PATH}")
    except Exception as e:
        logger.info(f"Failed to load state store from disk: {e}")
|
|
|
|
|
|
|
|
|
def _save_state_store_to_disk(force=False):
    """Write the serialised state store to ``_STATE_STORE_PATH``.

    Unless ``force`` is true, the call is a no-op when the previous successful
    write happened less than ``CONFIG.STATE_STORE_FLUSH_INTERVAL`` seconds ago
    (5 when the setting is absent). Nothing is written when there is nothing
    to serialise. The data goes to ``<path>.tmp`` first and is then moved over
    the real file with ``os.replace``. Failures are logged, never raised.
    """
    global _LAST_STATE_STORE_WRITE

    started = time.time()
    if not force:
        flush_interval = getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5)
        if started - _LAST_STATE_STORE_WRITE < flush_interval:
            return

    try:
        payload = _serialize_state_store()
        if payload:
            import json

            scratch_path = _STATE_STORE_PATH + ".tmp"
            with open(scratch_path, 'w', encoding='utf-8') as handle:
                json.dump(payload, handle)
            os.replace(scratch_path, _STATE_STORE_PATH)
            # Only a completed write moves the throttle timestamp forward.
            _LAST_STATE_STORE_WRITE = started
    except Exception as err:
        logger.info(f"Write state store to disk failed: {err}")
|
|
|
|
|
|
|
|
|
def _recompute_out_and_state_from_tokens(model_name: str, model_tokens: List[int]):
    """Rebuild ``(out, model_state)`` by forwarding ``model_tokens`` from scratch.

    The tokens are fed to the model in slices of ``CONFIG.CHUNK_LEN``, starting
    from an empty (None) state.

    Args:
        model_name: Key into ``MODEL_STORAGE``.
        model_tokens: Token history to replay. Anything that is not a list is
            replaced by the single token ``[0]``.

    Returns:
        ``(out, model_state)`` from the last forward call, or ``(None, None)``
        when the model is not loaded or the token list is empty.
    """
    ms = MODEL_STORAGE.get(model_name)
    if not ms or not ms.model:
        return None, None

    tokens = model_tokens if isinstance(model_tokens, list) else [0]
    # A non-positive chunk length would never make progress; clamp it to 1.
    chunk_len = max(1, int(CONFIG.CHUNK_LEN))

    out = None
    model_state = None
    # Step an offset through the list instead of re-slicing the tail on every
    # iteration, which copied the remaining tokens each time.
    for start in range(0, len(tokens), chunk_len):
        out, model_state = ms.model.forward(tokens[start : start + chunk_len], model_state)
    return out, model_state
|
|
|
|
|
|
|
|
|
def _allow_override_from_request(request, capability):
    """Return the request-level ``ALLOW_<capability>`` override as a bool, or None if unset."""
    attr = f"ALLOW_{capability}"
    if request.sampler and getattr(request.sampler, attr, None) is not None:
        return bool(getattr(request.sampler, attr))
    flat = getattr(request, f"sampler_allow_{capability.lower()}", None)
    return None if flat is None else bool(flat)


def _apply_allow_overrides(request, capability, enabled, default_sampler_may_enable):
    """Layer sampler/model ``ALLOW_<capability>`` settings on top of ``enabled``.

    Precedence, first match wins:
      1. ``request.sampler.ALLOW_<capability>``
      2. ``request.sampler_allow_<capability>``
      3. the model's ``DEFAULT_SAMPLER.ALLOW_<capability>`` — replaces the
         value when ``default_sampler_may_enable`` is true, otherwise it can
         only switch the capability off
      4. a falsy ``MODEL_CONFIG.ALLOW_<capability>`` switches it off

    Lookup problems are logged at debug level and leave ``enabled`` unchanged.
    """
    attr = f"ALLOW_{capability}"
    try:
        override = _allow_override_from_request(request, capability)
        if override is not None:
            return override

        ms = MODEL_STORAGE.get(request.model)
        if not (ms and ms.MODEL_CONFIG):
            return enabled

        sampler_default = getattr(getattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER', None), attr, None)
        if sampler_default is not None:
            if default_sampler_may_enable:
                return bool(sampler_default)
            return enabled and bool(sampler_default)

        if hasattr(ms.MODEL_CONFIG, attr) and not getattr(ms.MODEL_CONFIG, attr):
            return False
    except Exception as exc:
        # Best effort: a malformed config must not fail the request.
        logger.debug(f"Ignoring {attr} override lookup error: {exc}")
    return enabled


def resolve_request_flags(request, detection):
    """Resolve effective booleans for web_search, file_tool, tools, and reasoning
    based on request flags (explicit), sampler overrides, model defaults and detection.

    For every capability an explicit ``enable_*`` value (True or False) is the
    starting point; when it is None the value is derived from the request's
    other fields, ``CONFIG`` defaults and ``detection``. ``auto_*`` flags only
    decide whether ``detection`` may switch a capability on; they never enable
    it on their own. Sampler and model ``ALLOW_*`` settings are applied
    afterwards (see ``_apply_allow_overrides``).

    Args:
        request: The chat completion request (attribute access only).
        detection: Dict of detected needs; the keys read are
            ``need_web_search``, ``need_calc`` and ``need_reasoning``.

    Returns dict with keys: web_search_enabled, file_tool_enabled, tools_enabled, reasoning_enabled.
    """
    # --- web search ------------------------------------------------------
    if request.enable_web_search is not None:
        web_search_enabled = bool(request.enable_web_search)
    else:
        auto_web_flag = (
            request.auto_web_search
            if request.auto_web_search is not None
            else getattr(CONFIG, 'AUTO_ENABLE_WEB_SEARCH', True)
        )
        web_search_enabled = bool(request.web_search) or bool(
            auto_web_flag and detection.get('need_web_search')
        )
        if not getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True) and not request.web_search:
            web_search_enabled = False
    web_search_enabled = _apply_allow_overrides(request, 'WEB_SEARCH', web_search_enabled, True)

    # --- file tool -------------------------------------------------------
    if request.enable_file_tool is not None:
        file_tool_enabled = bool(request.enable_file_tool)
    else:
        # Attached files switch the tool on; as before, `auto_file_tool` has
        # no influence on this decision.
        file_tool_enabled = bool(request.file_ids) and bool(
            getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True)
        )
    file_tool_enabled = _apply_allow_overrides(request, 'FILE_TOOL', file_tool_enabled, True)

    # --- tools -----------------------------------------------------------
    if request.enable_tools is not None:
        tools_enabled = bool(request.enable_tools)
    else:
        auto_tools_flag = (
            request.auto_tools
            if request.auto_tools is not None
            else getattr(CONFIG, 'AUTO_ENABLE_TOOLS', True)
        )
        tools_enabled = bool(
            request.tools
            or getattr(CONFIG, 'ENABLE_TOOLS_BY_DEFAULT', False)
            or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
        )
    tools_enabled = _apply_allow_overrides(request, 'TOOLS', tools_enabled, False)

    # --- reasoning -------------------------------------------------------
    if request.enable_reasoning is not None:
        reasoning_enabled = bool(request.enable_reasoning)
    else:
        auto_reasoning_flag = (
            request.auto_reasoning
            if request.auto_reasoning is not None
            else getattr(CONFIG, 'AUTO_ENABLE_REASONING', True)
        )
        reasoning_enabled = bool(auto_reasoning_flag and detection.get('need_reasoning'))
        if not getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True):
            reasoning_enabled = False
    reasoning_enabled = _apply_allow_overrides(request, 'REASONING', reasoning_enabled, False)

    # A model that disallows reasoning wins over every override above.
    try:
        ms = MODEL_STORAGE.get(request.model)
        if ms and ms.MODEL_CONFIG and hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
            reasoning_enabled = False
    except Exception as exc:
        logger.debug(f"Ignoring ALLOW_REASONING lookup error: {exc}")

    return {
        'web_search_enabled': web_search_enabled,
        'file_tool_enabled': file_tool_enabled,
        'tools_enabled': tools_enabled,
        'reasoning_enabled': reasoning_enabled,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try:
    from model_tags import apply_model_tags_to_request_obj as apply_model_tags_to_request
except Exception:
    # Fallback used when `model_tags` cannot be imported.

    # Attribute assignments performed for each group of tag aliases.
    _MODEL_TAG_GROUPS = (
        (("thinking", "think", "reasoning", "reason"), {"enable_reasoning": True, "auto_reasoning": False}),
        (("web", "web_search", "search"), {"enable_web_search": True, "web_search": True, "auto_web_search": False}),
        # Auto mode is switched off as well; otherwise the default
        # auto_web_search=True keeps web search on despite the tag.
        (("no-web", "disable-web", "no-web-search"), {"enable_web_search": False, "web_search": False, "auto_web_search": False}),
        (("tools", "enable-tools"), {"enable_tools": True, "auto_tools": False}),
        (("no-tools", "disable-tools"), {"enable_tools": False}),
        (("file", "file_tool", "filetool"), {"enable_file_tool": True, "auto_file_tool": False}),
        (("no-file", "disable-file"), {"enable_file_tool": False}),
        (("universal", "univ"), {"enable_universal": True, "auto_universal": False}),
        (("stream",), {"stream": True}),
    )
    # Flattened lookup: lower-cased tag -> attribute assignments.
    _MODEL_TAG_UPDATES = {
        alias: updates for aliases, updates in _MODEL_TAG_GROUPS for alias in aliases
    }

    def apply_model_tags_to_request(req: Any):
        """Translate ``model:tag1:tag2`` suffixes into request flags, in place.

        ``req.model`` is replaced by the part before the first colon and each
        following tag (case-insensitive) sets the matching request attributes.
        Unknown tags are ignored. Nothing happens when ``req`` is falsy, has no
        model, or the model name carries no tag. Returns None.
        """
        if not req or not getattr(req, 'model', None) or ':' not in req.model:
            return

        # Drop empty and whitespace-only segments so that " :tag" cannot
        # produce an empty model name.
        parts = [segment.strip() for segment in req.model.split(":") if segment.strip()]
        if len(parts) <= 1:
            return

        req.model = parts[0]
        for tag in parts[1:]:
            for attr, value in _MODEL_TAG_UPDATES.get(tag.lower(), {}).items():
                setattr(req, attr, value)
|
|
|
|
|
|
|
|
|
# Report the effective strategy and the initial GPU memory state at import time.
logger.info(f"STRATEGY - {CONFIG.STRATEGY}")

logGPUState()
|
|
|
|
|
|
|
|
|
def load_models_once():
    """Load and initialize configured models into `MODEL_STORAGE`. This is executed once at server startup.

    For every entry of ``CONFIG.MODELS``:

    * download the weights through ``hf_hub_download`` when no local
      ``MODEL_FILE_PATH`` is configured,
    * record the entry as default chat / default reasoning model when flagged
      (a later flagged entry replaces an earlier one),
    * build the ``RWKV`` model and its ``PIPELINE`` and register them under
      the entry's ``SERVICE_NAME``.

    When exactly one model ends up loaded it becomes both defaults.
    Exceptions from downloading or model construction propagate to the caller.
    """
    global DEFALUT_MODEL_NAME, DEFAULT_REASONING_MODEL_NAME

    logger.info(f"Configured {len(CONFIG.MODELS)} model(s) in ROOT config")

    # Case-insensitive, like the other strategy checks in this module.
    on_cuda = "cuda" in CONFIG.STRATEGY.lower()

    for model_config in CONFIG.MODELS:
        service_name = model_config.SERVICE_NAME
        logger.info(f"Load Model - {service_name}")

        if model_config.MODEL_FILE_PATH is None:
            model_config.MODEL_FILE_PATH = hf_hub_download(
                repo_id=str(model_config.DOWNLOAD_MODEL_REPO_ID),
                filename=str(model_config.DOWNLOAD_MODEL_FILE_NAME),
                local_dir=str(model_config.DOWNLOAD_MODEL_DIR),
            )
        logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")

        if model_config.DEFAULT_CHAT:
            if DEFALUT_MODEL_NAME is not None:
                logger.info(
                    f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{service_name}`"
                )
            DEFALUT_MODEL_NAME = service_name

        if model_config.DEFAULT_REASONING:
            if DEFAULT_REASONING_MODEL_NAME is not None:
                logger.info(
                    f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{service_name}`"
                )
            DEFAULT_REASONING_MODEL_NAME = service_name

        logger.info(f"Load Model - Loading `{service_name}`")
        print(model_config.DEFAULT_SAMPLER)

        storage = ModelStorage()
        storage.MODEL_CONFIG = model_config
        storage.model = RWKV(
            # Strip only a trailing ".pth"; str.replace would also remove the
            # text from directory names earlier in the path.
            model=model_config.MODEL_FILE_PATH.removesuffix(".pth"),
            strategy=CONFIG.STRATEGY,
        )
        storage.pipeline = PIPELINE(storage.model, model_config.VOCAB)
        # Register only after construction succeeded, so MODEL_STORAGE never
        # holds an entry without a model.
        MODEL_STORAGE[service_name] = storage

        if on_cuda:
            torch.cuda.empty_cache()
            gc.collect()
        logGPUState()

    logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
    logger.info(f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`")

    if len(MODEL_STORAGE) == 1:
        single_name = next(iter(MODEL_STORAGE))
        if DEFALUT_MODEL_NAME != single_name:
            DEFALUT_MODEL_NAME = single_name
            logger.info(f"Load Model - Only one model present; DEFALUT_MODEL_NAME set to `{DEFALUT_MODEL_NAME}`")
        if DEFAULT_REASONING_MODEL_NAME != single_name:
            DEFAULT_REASONING_MODEL_NAME = single_name
            logger.info(f"Load Model - Only one model present; DEFAULT_REASONING_MODEL_NAME set to `{DEFAULT_REASONING_MODEL_NAME}`")
|
|
|
|
|
|
|
|
|
# Request body for a chat/completion call. Exactly one of `messages` and
# `prompt` must be supplied (enforced by `validate_mutual_exclusivity`).
class ChatCompletionRequest(BaseModel):
    model: str = Field(
        default="rwkv-latest",
        description="Specify the model name. Model tags/suffixes (e.g., ':thinking' or ':web') are not supported — set the corresponding request flags (enable_reasoning, web_search, enable_file_tool) instead.",
    )
    messages: Optional[List[ChatMessage]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)

    # Sampling parameters; None means "use the server-side default".
    max_tokens: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    count_penalty: Optional[float] = Field(default=None)
    penalty_decay: Optional[float] = Field(default=None)

    # NOTE(review): the default is True, so the "If None" case in the
    # description only applies when a client sends null explicitly.
    stream: Optional[bool] = Field(default=True, description="Whether to stream token-by-token responses. If None, uses CONFIG.DEFAULT_STREAM")
    state_name: Optional[str] = Field(default=None)
    include_usage: Optional[bool] = Field(default=False)
    stop: Optional[list[str]] = Field(["\n\n"])
    stop_tokens: Optional[list[int]] = Field([0])

    # Capability switches: `enable_*` is the explicit choice, `auto_*`
    # controls detection-based enabling.
    web_search: Optional[bool] = Field(default=None, description="Whether to perform a web search and append results to the prompt; if None, auto-detection is used")
    enable_web_search: Optional[bool] = Field(default=None, description="Explicitly enable web search (overrides auto/web_search) if set; if None, auto-detection controls it")
    auto_web_search: Optional[bool] = Field(default=True, description="Whether to enable web_search based on auto-detected intent")
    enable_tools: Optional[bool] = Field(default=None, description="Explicitly enable tools (overrides auto detection)")
    auto_tools: Optional[bool] = Field(default=True, description="Whether to enable tools based on auto-detected intent")
    enable_reasoning: Optional[bool] = Field(default=None, description="Explicitly override reasoning enablement; if None, auto-detection controls it")
    auto_reasoning: Optional[bool] = Field(default=True, description="Whether to enable reasoning based on auto detection")
    enable_universal: Optional[bool] = Field(default=None, description="Explicitly enable the universal tool execution")
    auto_universal: Optional[bool] = Field(default=True, description="Whether to auto enable universal tool execution")
    search_top_k: Optional[int] = Field(default=3, description="Number of web search results to retrieve")
    tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="List of tools to execute server-side (e.g., {'name':'web_search','args':{'query':'x'}})")

    # Flat per-request overrides of the sampler's ALLOW_* settings.
    sampler_allow_web_search: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing web_search")
    sampler_allow_tools: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing tools")
    sampler_allow_reasoning: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing reasoning")

    sampler: Optional[SamplerConfig] = Field(default=None, description="Per-request sampler settings (overrides model default)")

    # File attachments and the file tool.
    file_ids: Optional[List[str]] = Field(default=None, description="List of uploaded file IDs that the model may use for this request")
    enable_file_tool: Optional[bool] = Field(default=None, description="Explicitly enable file-based tools for this request")
    auto_file_tool: Optional[bool] = Field(default=None, description="Auto-detect whether file-based tools are needed")
    sampler_allow_file_tool: Optional[bool] = Field(default=None, description="Per-request sampler override allowing file tools")

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        """Require exactly one of ``messages`` / ``prompt`` in raw dict input.

        Non-dict input is passed through untouched. A key that is present
        with a None value counts as "not provided".

        Raises:
            ValueError: when both or neither of the two fields are provided.
        """
        if not isinstance(data, dict):
            return data

        messages_provided = data.get("messages") is not None
        prompt_provided = data.get("prompt") is not None

        if messages_provided and prompt_provided:
            raise ValueError("messages and prompt cannot coexist. Choose one.")
        if not messages_provided and not prompt_provided:
            raise ValueError("Either messages or prompt must be provided.")
        return data
|
|
|
|
|
|
|
|
|
# The ASGI application object; routes and event hooks below attach to it.
app = FastAPI(title="RWKV OpenAI-Compatible API")

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Compress responses of at least 1000 bytes at gzip level 5.
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def _startup_state_load_and_persist_loop():
    """Startup hook: restore persisted state, load models, start periodic flushing.

    1. Reads the state store file into ``STATE_STORE``.
    2. Loads all configured models; a failure is logged and startup continues
       without them.
    3. Starts a background task that flushes the state store to disk every
       ``CONFIG.STATE_STORE_FLUSH_INTERVAL`` seconds (5 when unset).
    """
    _load_state_store_from_disk()

    try:
        load_models_once()
    except Exception as e:
        logger.info(f"Model loading at startup failed: {e}")

    async def _persist_loop():
        while True:
            try:
                _save_state_store_to_disk(force=False)
            except Exception as e:
                # Keep the loop alive, but leave a trace of the failure.
                logger.info(f"Periodic state store flush failed: {e}")
            await asyncio.sleep(getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5))

    try:
        # The event loop holds only a weak reference to tasks; keep a strong
        # one on the app so the loop is not garbage-collected mid-run.
        app.state.state_store_persist_task = asyncio.create_task(_persist_loop())
    except Exception as e:
        logger.info(f"Could not start state store persist loop: {e}")
|
|
|
|
|
|
|
|
|
async def runPrefill(
    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
):
    """Encode ``ctx`` and feed it through the model to advance ``model_state``.

    Args:
        request: Only ``request.model`` is read, to pick the model.
        ctx: Prompt text; ``\\r\\n`` is normalised to ``\\n`` before encoding.
        model_tokens: Running token history. It is extended IN PLACE with the
            encoded prompt and also returned.
        model_state: State to continue from (None for a fresh context).

    Returns:
        ``(out, model_tokens, model_state)``; ``out`` is the result of the last
        forward call and stays None when ``ctx`` encodes to no tokens.

    Raises:
        HTTPException: 500 when the model or its pipeline is not loaded.
    """
    ctx = ctx.replace("\r\n", "\n")

    ms = MODEL_STORAGE.get(request.model)
    if not ms or not ms.pipeline or not ms.model:
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")

    tokens = [int(x) for x in ms.pipeline.encode(ctx)]
    model_tokens += tokens

    # A non-positive chunk length would never make progress; clamp it to 1.
    chunk_len = max(1, int(CONFIG.CHUNK_LEN))

    out = None
    # Step an offset through the prompt instead of re-slicing the remaining
    # tail after every chunk, which copied the rest of the list each time.
    for start in range(0, len(tokens), chunk_len):
        out, model_state = ms.model.forward(
            tokens[start : start + chunk_len], model_state
        )
        # Yield to the event loop between chunks so other requests can run.
        await asyncio.sleep(0)

    return out, model_tokens, model_state
|
|
|
|
|
|
|
|
|
def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens: List[int],
    model_state,
    max_tokens=2048,
):
    """Sample up to ``max_tokens`` tokens and yield the decoded text piece by piece.

    Args:
        request: Supplies the model name, sampling parameters and ``stop_tokens``.
        out: Logits from the preceding prefill; penalties are applied to it in place.
        model_tokens: Running token history; every sampled token that is fed
            to the model is appended IN PLACE.
        model_state: Model state to continue from.
        max_tokens: Upper bound on the number of sampled tokens.

    Yields:
        Dicts with the keys ``content`` (decoded text, possibly empty),
        ``tokens`` (token ids covered by this chunk), ``finish_reason``
        (None while running, ``"stop:token:<id>"`` or ``"length"`` on the last
        chunk) and ``state`` (the current model state).

    Raises:
        HTTPException: 500 when the model or its pipeline is not loaded.
    """
    ms = MODEL_STORAGE.get(request.model)
    if not ms or not ms.pipeline or not ms.model:
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")

    from config import CONFIG as _CFG

    # Request values are clamped into the configured bounds.
    temperature = request.temperature if request.temperature is not None else 0.2
    temperature = min(max(temperature, getattr(_CFG, 'MIN_TEMPERATURE', 0.0)), getattr(_CFG, 'MAX_TEMPERATURE', 2.0))
    top_p = request.top_p if request.top_p is not None else 0.9
    top_p = min(max(top_p, getattr(_CFG, 'MIN_TOP_P', 0.0)), getattr(_CFG, 'MAX_TOP_P', 1.0))
    alpha_frequency = request.count_penalty if request.count_penalty is not None else 0.0
    alpha_presence = request.presence_penalty if request.presence_penalty is not None else 0.0
    penalty_decay = request.penalty_decay if request.penalty_decay is not None else 0.5

    args = PIPELINE_ARGS(
        temperature=max(0.2, temperature),  # sampling temperature is floored at 0.2
        top_p=top_p,
        alpha_frequency=alpha_frequency,
        alpha_presence=alpha_presence,
        token_ban=[],
        token_stop=[0],
    )

    occurrence = {}  # token id -> decayed occurrence count used for the penalties
    out_tokens: List[int] = []
    out_last = 0  # index into out_tokens of the first token not yet emitted as text

    for i in range(max_tokens):
        # Presence/frequency penalties on the logits of already-seen tokens.
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency

        token = ms.pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )

        # Token 0 listed as a stop token ends generation before it is fed to the model.
        if token == 0 and request.stop_tokens and token in request.stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": "stop:token:0",
                "state": model_state,
            }
            del out
            gc.collect()
            return

        out, model_state = ms.model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)

        if request.stop_tokens and token in request.stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": f"stop:token:{token}",
                "state": model_state,
            }
            del out
            gc.collect()
            return

        for seen in occurrence:
            occurrence[seen] *= penalty_decay
        occurrence[token] = 1 + occurrence.get(token, 0)

        # Decode everything not emitted yet, not just the newest token: a
        # character split across several tokens decodes to U+FFFD until its
        # last token arrives, and decoding tokens one by one would drop it.
        pending = out_tokens[out_last:]
        decoded = ms.pipeline.decode(pending)
        if "\ufffd" in decoded:
            continue

        yield {
            "content": decoded,
            "tokens": pending,
            "finish_reason": None,
            "state": model_state,
        }
        out_last = i + 1
    else:
        # Token budget exhausted without hitting a stop token. "state" is
        # included so every chunk has the same keys.
        yield {
            "content": "",
            "tokens": [],
            "finish_reason": "length",
            "state": model_state,
        }
|
|
|
|
|
|
|
|
|
async def chatResponse(
|
|
|
request: ChatCompletionRequest,
|
|
|
model_state: Any,
|
|
|
completionId: str,
|
|
|
enableReasoning: bool,
|
|
|
) -> ChatCompletion:
|
|
|
|
|
|
createTimestamp = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
def decide_file_tool_enabled(request):
|
|
|
return resolve_request_flags(request, detection)['file_tool_enabled']
|
|
|
|
|
|
def decide_tools_enabled(request, detection):
|
|
|
return resolve_request_flags(request, detection)['tools_enabled']
|
|
|
|
|
|
def decide_reasoning_enabled(request, detection, enableReasoning):
|
|
|
|
|
|
flags = resolve_request_flags(request, detection)
|
|
|
return flags['reasoning_enabled']
|
|
|
|
|
|
def execute_tools(request, detection, prompt, executed_tool_calls, web_search_enabled, tools_enabled, file_tool_enabled, raw_prompt):
|
|
|
"""Helper to execute tools and update prompt and executed_tool_calls."""
|
|
|
|
|
|
if file_tool_enabled and request.file_ids:
|
|
|
for fid in request.file_ids:
|
|
|
try:
|
|
|
if fid not in UPLOADED_FILES:
|
|
|
continue
|
|
|
meta = UPLOADED_FILES.get(fid)
|
|
|
if not meta:
|
|
|
continue
|
|
|
from utils import file_read_from_path
|
|
|
fpath = meta.get('path')
|
|
|
if not fpath or not os.path.exists(fpath):
|
|
|
continue
|
|
|
file_content = file_read_from_path(fpath, 200000)
|
|
|
if file_content:
|
|
|
exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
|
|
|
executed_tool_calls.append(exec_entry)
|
|
|
prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt)
|
|
|
except Exception as e:
|
|
|
logger.info(f"File injection error: {e}")
|
|
|
|
|
|
if request.tools:
|
|
|
try:
|
|
|
for tool in request.tools:
|
|
|
name = tool.get('name')
|
|
|
args = tool.get('args', {})
|
|
|
if name == 'web_search':
|
|
|
from utils import web_search
|
|
|
search_q = args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
|
|
|
search_top_k = int(args.get('top_k') or request.search_top_k or 3)
|
|
|
search_str = web_search(search_q, search_top_k)
|
|
|
if search_str:
|
|
|
search_res_struct = {"action": "web_search", "result": str(search_str), "metadata": {"query": search_q, "top_k": search_top_k, "confidence": 0.9}}
|
|
|
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": search_top_k}, "result": search_res_struct})
|
|
|
prompt = (f"ToolResults:\n{search_res_struct.get('result')}\n\nUse these results to answer the prompt.\n\n" + prompt)
|
|
|
elif name == 'calc' or name == 'calculator':
|
|
|
from utils import calc
|
|
|
expr = args.get('expression')
|
|
|
if expr:
|
|
|
calc_res = calc(expr)
|
|
|
calc_res_struct = {"action": "calc", "result": str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}}
|
|
|
executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct})
|
|
|
prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res_struct.get('result')}\n\nUse this result to answer the prompt.\n\n" + prompt)
|
|
|
elif name == 'universal':
|
|
|
try:
|
|
|
res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
|
|
if isinstance(res, dict):
|
|
|
result_text = res.get('result') if res.get('result') is not None else ''
|
|
|
else:
|
|
|
result_text = str(res)
|
|
|
executed_tool_calls.append({"name": "universal", "args": args, "result": res})
|
|
|
prompt = (f"ToolResults:\n{result_text}\n\nUse this result to answer the prompt.\n\n" + prompt)
|
|
|
except Exception as e:
|
|
|
logger.info(f"Universal tool execution error: {e}")
|
|
|
else:
|
|
|
logger.info(f"Unsupported tool requested: {name}")
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.info(f"Tool processing error: {e}")
|
|
|
elif request.web_search or web_search_enabled:
|
|
|
try:
|
|
|
from utils import web_search
|
|
|
search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
|
|
|
search_res = web_search(search_q, int(request.search_top_k or 3))
|
|
|
if search_res:
|
|
|
search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}}
|
|
|
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct})
|
|
|
prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt
|
|
|
except Exception:
|
|
|
pass
|
|
|
return prompt, executed_tool_calls
|
|
|
|
|
|
raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [])
|
|
|
detection = detect_tools_and_reasoning(raw_prompt)
|
|
|
prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [])}\n\nAssistant:{' <think' if enableReasoning else ''}"
|
|
|
|
|
|
def decide_web_search_enabled(request, detection):
    """Decide whether web search is enabled for this request.

    Resolution happens in three stages; each later stage may override the
    earlier result:

    1. Request flags. ``request.enable_web_search`` is tri-state: ``True``
       turns web search on, ``False`` turns it off, and ``None`` defers to
       ``request.web_search``, then ``request.auto_web_search``, then
       ``CONFIG.AUTO_ENABLE_WEB_SEARCH`` (default ``True``) combined with
       ``detection['need_web_search']``.
    2. Server default. When ``CONFIG.ENABLE_WEB_SEARCH_BY_DEFAULT`` is falsy,
       web search stays off unless the request asked for it through
       ``enable_web_search`` or ``web_search``.
    3. Policy overrides, first match wins: ``request.sampler.ALLOW_WEB_SEARCH``,
       ``request.sampler_allow_web_search``, the model's
       ``DEFAULT_SAMPLER.ALLOW_WEB_SEARCH`` (each forces the result to its own
       truthiness when it is not ``None``), and finally a falsy model-level
       ``ALLOW_WEB_SEARCH``, which forces the result to ``False``.

    Args:
        request: The completion request; the attributes read are
            ``enable_web_search``, ``web_search``, ``auto_web_search``,
            ``sampler``, ``sampler_allow_web_search`` and ``model``.
        detection: Mapping produced by the prompt detector; only the
            ``'need_web_search'`` key is consulted.

    Returns:
        bool: ``True`` when web search should be used for this request.
    """
    explicit = request.enable_web_search

    # Stage 1: request flags. An explicit False is an opt-out; previously it
    # was treated like "unset" here yet still bypassed the server default
    # below, so it could leave search enabled where None would disable it.
    if explicit is not None:
        enabled = bool(explicit)
    elif request.web_search:
        enabled = True
    elif request.auto_web_search is not None:
        enabled = bool(request.auto_web_search)
    else:
        enabled = bool(
            getattr(CONFIG, 'AUTO_ENABLE_WEB_SEARCH', True)
            and detection.get('need_web_search')
        )

    # Stage 2: server-wide default. Auto-detection alone is not enough to
    # enable web search when the server disables it by default.
    if (
        explicit is None
        and not request.web_search
        and not getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True)
    ):
        enabled = False

    # Stage 3: sampler / model policy. Best effort: a failure while reading
    # the policy keeps the value computed so far, but is logged rather than
    # silently discarded.
    try:
        if request.sampler and getattr(request.sampler, 'ALLOW_WEB_SEARCH', None) is not None:
            enabled = bool(request.sampler.ALLOW_WEB_SEARCH)
        elif getattr(request, 'sampler_allow_web_search', None) is not None:
            enabled = bool(request.sampler_allow_web_search)
        else:
            ms = MODEL_STORAGE.get(request.model)
            if ms and ms.MODEL_CONFIG:
                model_config = ms.MODEL_CONFIG
                default_sampler = getattr(model_config, 'DEFAULT_SAMPLER', None)
                if getattr(default_sampler, 'ALLOW_WEB_SEARCH', None) is not None:
                    enabled = bool(default_sampler.ALLOW_WEB_SEARCH)
                elif not getattr(model_config, 'ALLOW_WEB_SEARCH', True):
                    # The model-level switch can only veto, never enable.
                    enabled = False
    except Exception as e:
        logger.info(f"Web search policy lookup error: {e}")

    return enabled
|
|
|
|
|
|
web_search_enabled = decide_web_search_enabled(request, detection)
|
|
|
file_tool_enabled = decide_file_tool_enabled(request)
|
|
|
tools_enabled = decide_tools_enabled(request, detection)
|
|
|
reasoning_enabled = decide_reasoning_enabled(request, detection, enableReasoning)
|
|
|
enableReasoning = reasoning_enabled
|
|
|
|
|
|
if request.enable_web_search is None:
|
|
|
request.web_search = web_search_enabled
|
|
|
if tools_enabled and not request.tools:
|
|
|
if detection.get('detected_tools'):
|
|
|
detected = detection.get('detected_tools') or []
|
|
|
filtered = []
|
|
|
for t in detected:
|
|
|
tname = t.get('name')
|
|
|
args = t.get('args', {})
|
|
|
ms_check = MODEL_STORAGE.get(request.model)
|
|
|
allowed = True
|
|
|
try:
|
|
|
if ms_check and ms_check.MODEL_CONFIG:
|
|
|
if tname in ('web_search', 'fetch_url') and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_FETCH_URL') and not ms_check.MODEL_CONFIG.ALLOW_FETCH_URL:
|
|
|
allowed = False
|
|
|
if tname == 'summarize' and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_SUMMARIZE') and not ms_check.MODEL_CONFIG.ALLOW_SUMMARIZE:
|
|
|
allowed = False
|
|
|
if tname in ('keywords',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_KEYWORDS') and not ms_check.MODEL_CONFIG.ALLOW_KEYWORDS:
|
|
|
allowed = False
|
|
|
if tname in ('sentiment',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_SENTIMENT') and not ms_check.MODEL_CONFIG.ALLOW_SENTIMENT:
|
|
|
allowed = False
|
|
|
if tname in ('translate',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_TRANSLATE') and not ms_check.MODEL_CONFIG.ALLOW_TRANSLATE:
|
|
|
allowed = False
|
|
|
if tname in ('spell_check',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_SPELL_CHECK') and not ms_check.MODEL_CONFIG.ALLOW_SPELL_CHECK:
|
|
|
allowed = False
|
|
|
if tname in ('format_code',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_FORMAT_CODE') and not ms_check.MODEL_CONFIG.ALLOW_FORMAT_CODE:
|
|
|
allowed = False
|
|
|
if tname in ('explain_code',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_EXPLAIN_CODE') and not ms_check.MODEL_CONFIG.ALLOW_EXPLAIN_CODE:
|
|
|
allowed = False
|
|
|
except Exception:
|
|
|
pass
|
|
|
if allowed:
|
|
|
filtered.append(t)
|
|
|
request.tools = filtered if filtered else None
|
|
|
if (request.enable_universal is True) or (
|
|
|
request.enable_universal is None and (request.auto_universal if request.auto_universal is not None else CONFIG.AUTO_ENABLE_TOOLS and detection.get('need_universal'))
|
|
|
):
|
|
|
if not request.tools:
|
|
|
request.tools = [{"name": "universal", "args": {"query": raw_prompt}}]
|
|
|
|
|
|
|
|
|
executed_tool_calls = []
|
|
|
prompt, executed_tool_calls = execute_tools(request, detection, prompt, executed_tool_calls, web_search_enabled, tools_enabled, file_tool_enabled, raw_prompt)
|
|
|
logger.info(f"[REQ] {completionId} - prompt - {prompt}")
|
|
|
|
|
|
|
|
|
if request.state_name:
|
|
|
state_key = (request.model, request.state_name)
|
|
|
if state_key in STATE_STORE:
|
|
|
stored = STATE_STORE[state_key]
|
|
|
model_state = stored.get('state', None)
|
|
|
model_tokens = stored.get('model_tokens', [0])
|
|
|
if model_state is None:
|
|
|
|
|
|
out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
|
|
|
else:
|
|
|
|
|
|
out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :])
|
|
|
else:
|
|
|
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
|
|
|
else:
|
|
|
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
|
|
|
|
|
|
prefillTime = time.time()
|
|
|
promptTokenCount = len(model_tokens)
|
|
|
|
|
|
fullResponse = " <think" if enableReasoning else ""
|
|
|
completionTokenCount = 0
|
|
|
finishReason = None
|
|
|
|
|
|
model_initiated_tool_calls = 0
|
|
|
MODEL_MAX_TOOL_CALLS = 3
|
|
|
should_restart = True
|
|
|
while should_restart:
|
|
|
should_restart = False
|
|
|
|
|
|
max_gen_tokens = (
|
|
|
getattr(CONFIG, 'MAX_GENERATION_TOKENS_LIMIT', 64000)
|
|
|
if "max_tokens" not in request.model_fields_set and enableReasoning
|
|
|
else (request.max_tokens or 2048)
|
|
|
)
|
|
|
max_tokens_limit = getattr(CONFIG, 'MAX_TOKENS_PER_REQUEST', None)
|
|
|
if max_tokens_limit:
|
|
|
max_gen_tokens = min(max_gen_tokens, max_tokens_limit)
|
|
|
gen = generate(request, out, model_tokens, model_state, max_tokens=max_gen_tokens)
|
|
|
for chunk in gen:
|
|
|
|
|
|
fullResponse += chunk["content"]
|
|
|
|
|
|
if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
|
|
|
m = TOOL_CALL_RE.search(fullResponse)
|
|
|
if m:
|
|
|
try:
|
|
|
payload_raw = m.group(1)
|
|
|
import json
|
|
|
|
|
|
payload = json.loads(payload_raw)
|
|
|
tool_name = payload.get('name')
|
|
|
tool_args = payload.get('args', {})
|
|
|
tool_res = None
|
|
|
if tool_name == 'web_search':
|
|
|
from utils import web_search
|
|
|
|
|
|
q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
|
|
|
k = int(tool_args.get('top_k') or request.search_top_k or 3)
|
|
|
tool_res = web_search(q, k)
|
|
|
elif tool_name in ('calc', 'calculator'):
|
|
|
from utils import calc
|
|
|
|
|
|
expr = tool_args.get('expression')
|
|
|
if expr:
|
|
|
tool_res = calc(expr)
|
|
|
else:
|
|
|
try:
|
|
|
tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
|
|
except Exception:
|
|
|
tool_res = None
|
|
|
|
|
|
if tool_res:
|
|
|
if not isinstance(tool_res, dict):
|
|
|
if tool_name in ('calc', 'calculator'):
|
|
|
tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
|
|
|
elif tool_name == 'web_search':
|
|
|
tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
|
|
|
else:
|
|
|
tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
|
|
|
else:
|
|
|
tool_res_struct = tool_res
|
|
|
exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
|
|
|
executed_tool_calls.append(exec_entry)
|
|
|
delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
|
|
|
prompt = delta_text + prompt
|
|
|
fullResponse = TOOL_CALL_RE.sub('', fullResponse)
|
|
|
buffer = [fullResponse]
|
|
|
out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
|
|
|
model_initiated_tool_calls += 1
|
|
|
except Exception as e:
|
|
|
logger.info(f"Model-initiated tool handling error: {e}")
|
|
|
|
|
|
for stop_words in request.stop or []:
|
|
|
if stop_words in fullResponse:
|
|
|
finishReason = f"stop:words:{stop_words}"
|
|
|
break
|
|
|
completionTokenCount += 1
|
|
|
|
|
|
if chunk["finish_reason"]:
|
|
|
finishReason = chunk["finish_reason"]
|
|
|
|
|
|
generateTime = time.time()
|
|
|
|
|
|
responseLog = {
|
|
|
"content": fullResponse,
|
|
|
"finish": finishReason,
|
|
|
"prefill_len": promptTokenCount,
|
|
|
"prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
|
|
|
"gen_len": completionTokenCount,
|
|
|
"gen_tps": round(completionTokenCount / (generateTime - prefillTime) if generateTime!=prefillTime else 0, 2),
|
|
|
}
|
|
|
logger.info(f"[RES] {completionId} - {responseLog}")
|
|
|
|
|
|
reasoning_content, content = parse_think_response(fullResponse)
|
|
|
|
|
|
try:
|
|
|
from utils import bias_mitigation
|
|
|
|
|
|
mitigation = bias_mitigation(content)
|
|
|
if mitigation and isinstance(mitigation, dict):
|
|
|
if mitigation.get('suppressed'):
|
|
|
executed_tool_calls.append({"name": "safety_mitigation", "args": {}, "result": {"action": "safety", "result": mitigation.get('sanitized'), "metadata": {"reason": mitigation.get('reason')}}})
|
|
|
content = mitigation.get('sanitized')
|
|
|
except Exception:
|
|
|
pass
|
|
|
|
|
|
response = ChatCompletion(
|
|
|
id=completionId,
|
|
|
created=int(createTimestamp),
|
|
|
model=request.model,
|
|
|
usage=Usage(
|
|
|
prompt_tokens=promptTokenCount,
|
|
|
completion_tokens=completionTokenCount,
|
|
|
total_tokens=promptTokenCount + completionTokenCount,
|
|
|
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
|
|
|
),
|
|
|
choices=[
|
|
|
ChatCompletionChoice(
|
|
|
index=0,
|
|
|
message=ChatCompletionMessage(
|
|
|
role="Assistant",
|
|
|
content=content,
|
|
|
reasoning_content=reasoning_content if reasoning_content else None,
|
|
|
tool_calls=executed_tool_calls if executed_tool_calls else None,
|
|
|
),
|
|
|
logprobs=None,
|
|
|
finish_reason=finishReason,
|
|
|
)
|
|
|
],
|
|
|
)
|
|
|
|
|
|
|
|
|
try:
|
|
|
if request.state_name:
|
|
|
STATE_STORE[(request.model, request.state_name)] = {
|
|
|
'state': model_state,
|
|
|
'model_tokens': model_tokens,
|
|
|
}
|
|
|
if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
|
|
|
try:
|
|
|
_save_state_store_to_disk(force=True)
|
|
|
except Exception:
|
|
|
pass
|
|
|
except Exception:
|
|
|
pass
|
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
async def chatResponseStream(
|
|
|
request: ChatCompletionRequest,
|
|
|
model_state: Any,
|
|
|
completionId: str,
|
|
|
enableReasoning: bool,
|
|
|
):
|
|
|
createTimestamp = int(time.time())
|
|
|
|
|
|
raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [], False)
|
|
|
|
|
|
detection = detect_tools_and_reasoning(raw_prompt)
|
|
|
|
|
|
|
|
|
flags = resolve_request_flags(request, detection)
|
|
|
web_search_enabled = flags['web_search_enabled']
|
|
|
tools_enabled = flags['tools_enabled']
|
|
|
file_tool_enabled = flags['file_tool_enabled']
|
|
|
reasoning_enabled = flags['reasoning_enabled']
|
|
|
enableReasoning = reasoning_enabled
|
|
|
try:
|
|
|
ms_cfg = MODEL_STORAGE.get(request.model)
|
|
|
if ms_cfg and ms_cfg.MODEL_CONFIG and hasattr(ms_cfg.MODEL_CONFIG, 'ALLOW_REASONING') and not ms_cfg.MODEL_CONFIG.ALLOW_REASONING:
|
|
|
enableReasoning = False
|
|
|
except Exception:
|
|
|
pass
|
|
|
|
|
|
|
|
|
prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{' <think' if enableReasoning else ''}"
|
|
|
|
|
|
if tools_enabled and not request.tools:
|
|
|
if detection.get('detected_tools'):
|
|
|
|
|
|
detected = detection.get('detected_tools') or []
|
|
|
filtered = []
|
|
|
for t in detected:
|
|
|
tname = t.get('name')
|
|
|
args = t.get('args', {})
|
|
|
ms_check = MODEL_STORAGE.get(request.model)
|
|
|
allowed = True
|
|
|
try:
|
|
|
if ms_check and ms_check.MODEL_CONFIG:
|
|
|
if tname in ('web_search', 'fetch_url') and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_FETCH_URL') and not ms_check.MODEL_CONFIG.ALLOW_FETCH_URL:
|
|
|
allowed = False
|
|
|
if tname == 'summarize' and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_SUMMARIZE') and not ms_check.MODEL_CONFIG.ALLOW_SUMMARIZE:
|
|
|
allowed = False
|
|
|
if tname in ('keywords',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_KEYWORDS') and not ms_check.MODEL_CONFIG.ALLOW_KEYWORDS:
|
|
|
allowed = False
|
|
|
if tname in ('sentiment',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_SENTIMENT') and not ms_check.MODEL_CONFIG.ALLOW_SENTIMENT:
|
|
|
allowed = False
|
|
|
if tname in ('translate',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_TRANSLATE') and not ms_check.MODEL_CONFIG.ALLOW_TRANSLATE:
|
|
|
allowed = False
|
|
|
if tname in ('spell_check',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_SPELL_CHECK') and not ms_check.MODEL_CONFIG.ALLOW_SPELL_CHECK:
|
|
|
allowed = False
|
|
|
if tname in ('format_code',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_FORMAT_CODE') and not ms_check.MODEL_CONFIG.ALLOW_FORMAT_CODE:
|
|
|
allowed = False
|
|
|
if tname in ('explain_code',) and hasattr(ms_check.MODEL_CONFIG, 'ALLOW_EXPLAIN_CODE') and not ms_check.MODEL_CONFIG.ALLOW_EXPLAIN_CODE:
|
|
|
allowed = False
|
|
|
except Exception:
|
|
|
pass
|
|
|
if allowed:
|
|
|
filtered.append(t)
|
|
|
request.tools = filtered if filtered else None
|
|
|
executed_tool_calls = []
|
|
|
if request.tools:
|
|
|
try:
|
|
|
for tool in request.tools:
|
|
|
name = tool.get('name')
|
|
|
args = tool.get('args', {})
|
|
|
if name == 'web_search':
|
|
|
from utils import web_search
|
|
|
|
|
|
search_q = args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
|
|
|
search_top_k = int(args.get('top_k') or request.search_top_k or 3)
|
|
|
search_str = web_search(search_q, search_top_k)
|
|
|
if search_str:
|
|
|
search_res_struct = {"action": "web_search", "result": str(search_str), "metadata": {"query": search_q, "top_k": search_top_k, "confidence": 0.9}}
|
|
|
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": search_top_k}, "result": search_res_struct})
|
|
|
prompt = (f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt)
|
|
|
elif name == 'calc' or name == 'calculator':
|
|
|
from utils import calc
|
|
|
|
|
|
expr = args.get('expression')
|
|
|
if expr:
|
|
|
calc_res = calc(expr)
|
|
|
calc_res_struct = {"action": "calc", "result": str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}}
|
|
|
executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct})
|
|
|
prompt = (f"CalcResult:{expr} = {calc_res_struct.get('result')}\n\n" + prompt)
|
|
|
elif name == 'universal':
|
|
|
try:
|
|
|
res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
|
|
if isinstance(res, dict):
|
|
|
result_text = res.get('result') if res.get('result') is not None else ''
|
|
|
else:
|
|
|
result_text = str(res)
|
|
|
executed_tool_calls.append({"name": "universal", "args": args, "result": res})
|
|
|
prompt = (f"ToolResults:\n{result_text}\n\n" + prompt)
|
|
|
except Exception as e:
|
|
|
logger.info(f"Universal tool execution error: {e}")
|
|
|
else:
|
|
|
logger.info(f"Unsupported tool requested: {name}")
|
|
|
if name == 'fetch_url' or name == 'get_url':
|
|
|
try:
|
|
|
from utils import fetch_url
|
|
|
|
|
|
url = args.get('url') or args.get('uri') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
|
|
|
if not url:
|
|
|
continue
|
|
|
page = fetch_url(url, int(args.get('max_chars') or 20000))
|
|
|
executed_tool_calls.append({"name": "fetch_url", "args": {"url": url}, "result": {"action": "fetch_url", "result": page, "metadata": {"url": url}}})
|
|
|
prompt = (f"ToolResults:\n{page}\n\n" + prompt)
|
|
|
except Exception as e:
|
|
|
logger.info(f"fetch_url tool error: {e}")
|
|
|
if name == 'summarize' or name == 'summary':
|
|
|
try:
|
|
|
from utils import summarize_text, fetch_url
|
|
|
|
|
|
txt = args.get('text') or ''
|
|
|
if not txt and args.get('url'):
|
|
|
txt = fetch_url(args.get('url'))
|
|
|
if not txt and request.prompt:
|
|
|
txt = request.prompt
|
|
|
if not txt and request.messages:
|
|
|
txt = cleanMessages(request.messages or [])
|
|
|
if txt:
|
|
|
s = summarize_text(txt, int(args.get('max_sentences') or 3))
|
|
|
executed_tool_calls.append({"name": "summarize", "args": args, "result": {"action": "summarize", "result": s, "metadata": {"confidence": 0.85}}})
|
|
|
prompt = (f"ToolResults:\n{s}\n\n" + prompt)
|
|
|
except Exception as e:
|
|
|
logger.info(f"summarize tool error: {e}")
|
|
|
if name == 'keywords' or name == 'keyword_extraction':
|
|
|
try:
|
|
|
from utils import extract_keywords, fetch_url
|
|
|
|
|
|
txt = args.get('text') or ''
|
|
|
if not txt and args.get('url'):
|
|
|
txt = fetch_url(args.get('url'))
|
|
|
if not txt and request.prompt:
|
|
|
txt = request.prompt
|
|
|
kws = extract_keywords(txt, int(args.get('top_k') or 5))
|
|
|
executed_tool_calls.append({"name": "keywords", "args": args, "result": {"action": "keywords", "result": kws, "metadata": {"top_k": int(args.get('top_k') or 5), "confidence": 0.78}}})
|
|
|
prompt = (f"ToolResults:\nKeywords:{','.join(kws)}\n\n" + prompt)
|
|
|
except Exception as e:
|
|
|
logger.info(f"keywords tool error: {e}")
|
|
|
if name == 'sentiment' or name == 'tone':
|
|
|
try:
|
|
|
from utils import sentiment_analysis, fetch_url
|
|
|
|
|
|
txt = args.get('text') or ''
|
|
|
if not txt and args.get('url'):
|
|
|
txt = fetch_url(args.get('url'))
|
|
|
if not txt and request.prompt:
|
|
|
txt = request.prompt
|
|
|
res = sentiment_analysis(txt)
|
|
|
executed_tool_calls.append({"name": "sentiment", "args": args, "result": {"action": "sentiment", "result": res, "metadata": {"confidence": res.get('score', 0)}}})
|
|
|
prompt = (f"ToolResults:\nSentiment: {res.get('sentiment')} (score={res.get('score')})\n\n" + prompt)
|
|
|
except Exception as e:
|
|
|
logger.info(f"sentiment tool error: {e}")
|
|
|
except Exception as e:
|
|
|
logger.info(f"Tool processing error: {e}")
|
|
|
elif request.web_search or web_search_enabled:
|
|
|
try:
|
|
|
from utils import web_search
|
|
|
|
|
|
search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
|
|
|
search_res = web_search(search_q, int(request.search_top_k or 3))
|
|
|
if search_res:
|
|
|
search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}}
|
|
|
executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct})
|
|
|
prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt
|
|
|
except Exception:
|
|
|
pass
|
|
|
|
|
|
logger.info(f"[REQ] {completionId} - context\n```{prompt}```")
|
|
|
|
|
|
|
|
|
if request.state_name:
|
|
|
state_key = (request.model, request.state_name)
|
|
|
if state_key in STATE_STORE:
|
|
|
stored = STATE_STORE[state_key]
|
|
|
model_state = stored.get('state', None)
|
|
|
model_tokens = stored.get('model_tokens', [0])
|
|
|
if model_state is None:
|
|
|
|
|
|
out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
|
|
|
else:
|
|
|
out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :])
|
|
|
else:
|
|
|
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
|
|
|
else:
|
|
|
out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state)
|
|
|
|
|
|
prefillTime = time.time()
|
|
|
promptTokenCount = len(model_tokens)
|
|
|
|
|
|
completionTokenCount = 0
|
|
|
finishReason = None
|
|
|
|
|
|
model_initiated_tool_calls = 0
|
|
|
MODEL_MAX_TOOL_CALLS = 3
|
|
|
|
|
|
response = ChatCompletionChunk(
|
|
|
id=completionId,
|
|
|
created=createTimestamp,
|
|
|
model=request.model,
|
|
|
usage=(
|
|
|
Usage(
|
|
|
prompt_tokens=promptTokenCount,
|
|
|
completion_tokens=completionTokenCount,
|
|
|
total_tokens=promptTokenCount + completionTokenCount,
|
|
|
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
|
|
|
)
|
|
|
if request.include_usage
|
|
|
else None
|
|
|
),
|
|
|
choices=[
|
|
|
ChatCompletionChoice(
|
|
|
index=0,
|
|
|
delta=ChatCompletionMessage(
|
|
|
role="Assistant",
|
|
|
content="",
|
|
|
reasoning_content="" if enableReasoning else None,
|
|
|
tool_calls=None,
|
|
|
),
|
|
|
logprobs=None,
|
|
|
finish_reason=finishReason,
|
|
|
)
|
|
|
],
|
|
|
)
|
|
|
if response.choices and response.choices[0].delta is None:
|
|
|
response.choices[0].delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
|
|
|
r_dict = response.model_dump()
|
|
|
r_dict['state_name'] = request.state_name
|
|
|
|
|
|
if executed_tool_calls:
|
|
|
r_dict['tool_calls'] = executed_tool_calls
|
|
|
try:
|
|
|
if r_dict.get('choices') and len(r_dict['choices']) > 0 and r_dict['choices'][0].get('delta') is not None:
|
|
|
r_dict['choices'][0]['delta']['tool_calls'] = executed_tool_calls
|
|
|
except Exception:
|
|
|
pass
|
|
|
yield f"data: {r_dict}\n\n"
|
|
|
|
|
|
buffer = []
|
|
|
|
|
|
if enableReasoning:
|
|
|
buffer.append("<think")
|
|
|
|
|
|
streamConfig = {
|
|
|
"isChecking": False,
|
|
|
"fullTextCursor": 0,
|
|
|
"in_think": False,
|
|
|
"cacheStr": "",
|
|
|
}
|
|
|
|
|
|
max_gen_tokens = (
|
|
|
getattr(CONFIG, 'MAX_GENERATION_TOKENS_LIMIT', 64000)
|
|
|
if "max_tokens" not in request.model_fields_set and enableReasoning
|
|
|
else (request.max_tokens or 2048)
|
|
|
)
|
|
|
max_tokens_limit = getattr(CONFIG, 'MAX_TOKENS_PER_REQUEST', None)
|
|
|
if max_tokens_limit:
|
|
|
max_gen_tokens = min(max_gen_tokens, max_tokens_limit)
|
|
|
for chunk in generate(request, out, model_tokens, model_state, max_tokens=max_gen_tokens):
|
|
|
completionTokenCount += 1
|
|
|
|
|
|
chunkContent: str = chunk["content"]
|
|
|
buffer.append(chunkContent)
|
|
|
|
|
|
fullText = "".join(buffer)
|
|
|
|
|
|
if chunk["finish_reason"]:
|
|
|
finishReason = chunk["finish_reason"]
|
|
|
|
|
|
response = ChatCompletionChunk(
|
|
|
id=completionId,
|
|
|
created=createTimestamp,
|
|
|
model=request.model,
|
|
|
usage=(
|
|
|
Usage(
|
|
|
prompt_tokens=promptTokenCount,
|
|
|
completion_tokens=completionTokenCount,
|
|
|
total_tokens=promptTokenCount + completionTokenCount,
|
|
|
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
|
|
|
)
|
|
|
if request.include_usage
|
|
|
else None
|
|
|
),
|
|
|
choices=[
|
|
|
ChatCompletionChoice(
|
|
|
index=0,
|
|
|
delta=ChatCompletionMessage(
|
|
|
role="Assistant",
|
|
|
content=None,
|
|
|
reasoning_content=None,
|
|
|
tool_calls=None,
|
|
|
),
|
|
|
logprobs=None,
|
|
|
finish_reason=finishReason,
|
|
|
)
|
|
|
],
|
|
|
)
|
|
|
if response.choices and response.choices[0].delta is None:
|
|
|
response.choices[0].delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
|
|
|
markStart = fullText.find("<", streamConfig["fullTextCursor"])
|
|
|
if not streamConfig["isChecking"] and markStart != -1:
|
|
|
streamConfig["isChecking"] = True
|
|
|
|
|
|
if streamConfig["in_think"]:
|
|
|
delta = response.choices[0].delta
|
|
|
if delta is None:
|
|
|
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
response.choices[0].delta = delta
|
|
|
delta.reasoning_content = fullText[streamConfig["fullTextCursor"] : markStart]
|
|
|
else:
|
|
|
delta = response.choices[0].delta
|
|
|
if delta is None:
|
|
|
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
response.choices[0].delta = delta
|
|
|
delta.content = fullText[streamConfig["fullTextCursor"] : markStart]
|
|
|
|
|
|
streamConfig["cacheStr"] = ""
|
|
|
streamConfig["fullTextCursor"] = markStart
|
|
|
|
|
|
if streamConfig["isChecking"]:
|
|
|
streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :]
|
|
|
else:
|
|
|
if streamConfig["in_think"]:
|
|
|
delta = response.choices[0].delta
|
|
|
if delta is None:
|
|
|
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
response.choices[0].delta = delta
|
|
|
delta.reasoning_content = chunkContent
|
|
|
else:
|
|
|
delta = response.choices[0].delta
|
|
|
if delta is None:
|
|
|
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
response.choices[0].delta = delta
|
|
|
delta.content = chunkContent
|
|
|
streamConfig["fullTextCursor"] = len(fullText)
|
|
|
|
|
|
markEnd = fullText.find(">", streamConfig["fullTextCursor"])
|
|
|
if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
|
|
|
streamConfig["isChecking"] = False
|
|
|
|
|
|
if (
|
|
|
not streamConfig["in_think"]
|
|
|
and streamConfig["cacheStr"].find("<think>") != -1
|
|
|
):
|
|
|
streamConfig["in_think"] = True
|
|
|
|
|
|
delta = response.choices[0].delta
|
|
|
if delta is None:
|
|
|
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
response.choices[0].delta = delta
|
|
|
delta.reasoning_content = (
|
|
|
delta.reasoning_content
|
|
|
if delta.reasoning_content != None
|
|
|
else "" + streamConfig["cacheStr"].replace("<think>", "")
|
|
|
)
|
|
|
|
|
|
elif (
|
|
|
streamConfig["in_think"]
|
|
|
and streamConfig["cacheStr"].find("</think>") != -1
|
|
|
):
|
|
|
streamConfig["in_think"] = False
|
|
|
|
|
|
delta = response.choices[0].delta
|
|
|
if delta is None:
|
|
|
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
response.choices[0].delta = delta
|
|
|
delta.content = (
|
|
|
delta.content
|
|
|
if delta.content != None
|
|
|
else "" + streamConfig["cacheStr"].replace("</think>", "")
|
|
|
)
|
|
|
else:
|
|
|
if streamConfig["in_think"]:
|
|
|
delta = response.choices[0].delta
|
|
|
if delta is None:
|
|
|
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
response.choices[0].delta = delta
|
|
|
delta.reasoning_content = (
|
|
|
delta.reasoning_content
|
|
|
if delta.reasoning_content != None
|
|
|
else "" + streamConfig["cacheStr"]
|
|
|
)
|
|
|
else:
|
|
|
delta = response.choices[0].delta
|
|
|
if delta is None:
|
|
|
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
response.choices[0].delta = delta
|
|
|
delta.content = (
|
|
|
delta.content
|
|
|
if delta.content != None
|
|
|
else "" + streamConfig["cacheStr"]
|
|
|
)
|
|
|
streamConfig["fullTextCursor"] = len(fullText)
|
|
|
|
|
|
delta = response.choices[0].delta
|
|
|
if delta is None:
|
|
|
delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
|
|
|
response.choices[0].delta = delta
|
|
|
if delta.content != None or delta.reasoning_content != None:
|
|
|
|
|
|
try:
|
|
|
if request.state_name:
|
|
|
STATE_STORE[(request.model, request.state_name)] = {
|
|
|
'state': model_state,
|
|
|
'model_tokens': model_tokens,
|
|
|
}
|
|
|
if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
|
|
|
try:
|
|
|
_save_state_store_to_disk(force=True)
|
|
|
except Exception:
|
|
|
pass
|
|
|
except Exception:
|
|
|
pass
|
|
|
|
|
|
if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
|
|
|
m = TOOL_CALL_RE.search(fullText)
|
|
|
if m:
|
|
|
try:
|
|
|
payload_raw = m.group(1)
|
|
|
import json
|
|
|
|
|
|
payload = json.loads(payload_raw)
|
|
|
tool_name = payload.get('name')
|
|
|
tool_args = payload.get('args', {})
|
|
|
tool_res = None
|
|
|
if tool_name == 'web_search':
|
|
|
from utils import web_search
|
|
|
|
|
|
q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
|
|
|
k = int(tool_args.get('top_k') or request.search_top_k or 3)
|
|
|
tool_res = web_search(q, k)
|
|
|
elif tool_name in ('calc', 'calculator'):
|
|
|
from utils import calc
|
|
|
|
|
|
expr = tool_args.get('expression')
|
|
|
if expr:
|
|
|
tool_res = calc(expr)
|
|
|
else:
|
|
|
try:
|
|
|
tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
|
|
except Exception:
|
|
|
tool_res = None
|
|
|
|
|
|
if tool_res:
|
|
|
|
|
|
if not isinstance(tool_res, dict):
|
|
|
if tool_name in ('calc', 'calculator'):
|
|
|
tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
|
|
|
elif tool_name == 'web_search':
|
|
|
tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
|
|
|
else:
|
|
|
tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
|
|
|
else:
|
|
|
tool_res_struct = tool_res
|
|
|
exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
|
|
|
executed_tool_calls.append(exec_entry)
|
|
|
delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
|
|
|
prompt = delta_text + prompt
|
|
|
fullText = TOOL_CALL_RE.sub('', fullText)
|
|
|
buffer = [fullText]
|
|
|
out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
|
|
|
model_initiated_tool_calls += 1
|
|
|
should_restart = True
|
|
|
break
|
|
|
except Exception as e:
|
|
|
logger.info(f"Model-initiated tool handling error: {e}")
|
|
|
yield f"data: {response.model_dump_json()}\n\n"
|
|
|
|
|
|
for stop_words in request.stop or []:
|
|
|
if stop_words in ''.join(buffer):
|
|
|
finishReason = f"stop:words:{stop_words}"
|
|
|
return
|
|
|
|
|
|
await asyncio.sleep(0)
|
|
|
|
|
|
del streamConfig
|
|
|
else:
|
|
|
should_restart = True
|
|
|
while should_restart:
|
|
|
should_restart = False
|
|
|
gen = generate(request, out, model_tokens, model_state)
|
|
|
for chunk in gen:
|
|
|
completionTokenCount += 1
|
|
|
buffer.append(chunk["content"])
|
|
|
|
|
|
if chunk["finish_reason"]:
|
|
|
finishReason = chunk["finish_reason"]
|
|
|
|
|
|
|
|
|
try:
|
|
|
if request.state_name:
|
|
|
STATE_STORE[(request.model, request.state_name)] = {
|
|
|
'state': model_state,
|
|
|
'model_tokens': model_tokens,
|
|
|
}
|
|
|
if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False):
|
|
|
try:
|
|
|
_save_state_store_to_disk(force=True)
|
|
|
except Exception:
|
|
|
pass
|
|
|
except Exception:
|
|
|
pass
|
|
|
|
|
|
|
|
|
if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS:
|
|
|
fullText = ''.join(buffer)
|
|
|
m = TOOL_CALL_RE.search(fullText)
|
|
|
if m:
|
|
|
try:
|
|
|
payload_raw = m.group(1)
|
|
|
import json
|
|
|
|
|
|
payload = json.loads(payload_raw)
|
|
|
tool_name = payload.get('name')
|
|
|
tool_args = payload.get('args', {})
|
|
|
tool_res = None
|
|
|
if tool_name == 'web_search':
|
|
|
from utils import web_search
|
|
|
|
|
|
q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or []))
|
|
|
k = int(tool_args.get('top_k') or request.search_top_k or 3)
|
|
|
tool_res = web_search(q, k)
|
|
|
elif tool_name in ('calc', 'calculator'):
|
|
|
from utils import calc
|
|
|
|
|
|
expr = tool_args.get('expression')
|
|
|
if expr:
|
|
|
tool_res = calc(expr)
|
|
|
else:
|
|
|
try:
|
|
|
tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled))
|
|
|
except Exception:
|
|
|
tool_res = None
|
|
|
|
|
|
if tool_res:
|
|
|
if not isinstance(tool_res, dict):
|
|
|
if tool_name in ('calc', 'calculator'):
|
|
|
tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}}
|
|
|
elif tool_name == 'web_search':
|
|
|
tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}}
|
|
|
else:
|
|
|
tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}}
|
|
|
else:
|
|
|
tool_res_struct = tool_res
|
|
|
exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True}
|
|
|
executed_tool_calls.append(exec_entry)
|
|
|
delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n"
|
|
|
prompt = delta_text + prompt
|
|
|
fullText = TOOL_CALL_RE.sub('', fullText)
|
|
|
buffer = [fullText]
|
|
|
out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state)
|
|
|
|
|
|
try:
|
|
|
meta_resp = ChatCompletionChunk(
|
|
|
id=completionId,
|
|
|
created=createTimestamp,
|
|
|
model=request.model,
|
|
|
usage=(
|
|
|
Usage(
|
|
|
prompt_tokens=promptTokenCount,
|
|
|
completion_tokens=completionTokenCount,
|
|
|
total_tokens=promptTokenCount + completionTokenCount,
|
|
|
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
|
|
|
)
|
|
|
if request.include_usage
|
|
|
else None
|
|
|
),
|
|
|
choices=[
|
|
|
ChatCompletionChoice(
|
|
|
index=0,
|
|
|
delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls),
|
|
|
logprobs=None,
|
|
|
finish_reason=None,
|
|
|
)
|
|
|
],
|
|
|
)
|
|
|
yield f"data: {meta_resp.model_dump_json()}\n\n"
|
|
|
except Exception:
|
|
|
pass
|
|
|
model_initiated_tool_calls += 1
|
|
|
should_restart = True
|
|
|
break
|
|
|
except Exception as e:
|
|
|
logger.info(f"Model-initiated tool handling error: {e}")
|
|
|
|
|
|
response = ChatCompletionChunk(
|
|
|
id=completionId,
|
|
|
created=createTimestamp,
|
|
|
model=request.model,
|
|
|
usage=(
|
|
|
Usage(
|
|
|
prompt_tokens=promptTokenCount,
|
|
|
completion_tokens=completionTokenCount,
|
|
|
total_tokens=promptTokenCount + completionTokenCount,
|
|
|
prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
|
|
|
)
|
|
|
if request.include_usage
|
|
|
else None
|
|
|
),
|
|
|
choices=[
|
|
|
ChatCompletionChoice(
|
|
|
index=0,
|
|
|
delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None),
|
|
|
logprobs=None,
|
|
|
finish_reason=finishReason,
|
|
|
)
|
|
|
],
|
|
|
)
|
|
|
yield f"data: {response.model_dump_json()}\n\n"
|
|
|
await asyncio.sleep(0)
|
|
|
|
|
|
genenrateTime = time.time()
|
|
|
|
|
|
responseLog = {
|
|
|
"content": "".join(buffer),
|
|
|
"finish": finishReason,
|
|
|
"prefill_len": promptTokenCount,
|
|
|
"prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2),
|
|
|
"gen_len": completionTokenCount,
|
|
|
"gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2),
|
|
|
}
|
|
|
logger.info(f"[RES] {completionId} - {responseLog}")
|
|
|
|
|
|
try:
|
|
|
from utils import bias_mitigation
|
|
|
|
|
|
content_for_mitigation = responseLog.get('content')
|
|
|
if content_for_mitigation is None:
|
|
|
content_for_mitigation = ""
|
|
|
mitigation = bias_mitigation(content_for_mitigation)
|
|
|
if mitigation and isinstance(mitigation, dict) and mitigation.get('suppressed'):
|
|
|
executed_tool_calls.append({"name": "safety_mitigation", "args": {}, "result": {"action": "safety", "result": mitigation.get('sanitized'), "metadata": {"reason": mitigation.get('reason')}}})
|
|
|
responseLog['content'] = mitigation.get('sanitized')
|
|
|
except Exception:
|
|
|
pass
|
|
|
if request.messages is None:
|
|
|
request.messages = []
|
|
|
|
|
|
content_str = responseLog["content"] if responseLog["content"] is not None else ""
|
|
|
request.messages.append(ChatMessage(role="Assistant", content=content_str))
|
|
|
log(
|
|
|
{
|
|
|
**request.model_dump(),
|
|
|
**responseLog,
|
|
|
"completionId": completionId,
|
|
|
"machineLabel": os.environ.get("MACHINE_LABEL"),
|
|
|
}
|
|
|
)
|
|
|
|
|
|
del buffer
|
|
|
|
|
|
yield "data: [DONE]\n\n"
|
|
|
|
|
|
|
|
|
@app.post("/api/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle a chat-completion request, either streamed or as a single reply.

    Steps, in order:

    1. Resolve the target model (``"rwkv-latest"`` maps to
       ``DEFALUT_MODEL_NAME``) and pick up that model's default sampler.
    2. Resolve the conversation state: reuse the entry stored under
       ``(model, state_name)`` in ``STATE_STORE`` if present; a missing
       ``state_name`` is replaced by a fresh UUID4 string.
    3. Merge sampler settings (explicit ``sampler`` overrides first, then
       model defaults for fields left as ``None``) into a new request object.
    4. Clamp ``max_tokens``, ``temperature`` and ``top_p`` to the limits read
       from ``CONFIG`` (best effort).
    5. Dispatch to the streaming or non-streaming implementation.

    Returns:
        A ``StreamingResponse`` when streaming is enabled. Otherwise the value
        returned by ``chatResponse``; when that is a ``ChatCompletion`` it is
        returned as a dict with an added ``state_name`` key.

    Raises:
        HTTPException: 404 when the model is unknown or no default model is
            set; 500 when the selected model has no configuration.
    """
    completionId = str(next(CompletionIdGenerator))
    logger.info(f"[REQ] {completionId} - {request.model_dump()}")

    apply_model_tags_to_request(request)
    modelName = request.model

    # --- model resolution -------------------------------------------------
    if request.model == "rwkv-latest":
        if DEFALUT_MODEL_NAME is None:
            raise HTTPException(404, "DEFALUT_MODEL_NAME not set")
        ms_def = MODEL_STORAGE.get(DEFALUT_MODEL_NAME)
        if not ms_def or not ms_def.MODEL_CONFIG:
            raise HTTPException(500, "Default sampler config missing for default model")
        defaultSamplerConfig = ms_def.MODEL_CONFIG.DEFAULT_SAMPLER
        request.model = DEFALUT_MODEL_NAME
    elif modelName in MODEL_STORAGE:
        ms_sel = MODEL_STORAGE.get(modelName)
        if not ms_sel or not ms_sel.MODEL_CONFIG:
            raise HTTPException(500, f"Default sampler config missing for model {modelName}")
        defaultSamplerConfig = ms_sel.MODEL_CONFIG.DEFAULT_SAMPLER
        request.model = modelName
    else:
        raise HTTPException(404, f"Can not find `{modelName}`")

    enableReasoning = bool(request.enable_reasoning) if request.enable_reasoning is not None else False

    async def chatResponseStreamDisconnect():
        # Runs as the streaming response's background task.
        logGPUState()

    # --- conversation state -----------------------------------------------
    # Only the stored 'state' is forwarded to the response functions below;
    # the stored 'model_tokens' entry is not consumed in this handler.
    model_state = None
    state_name = request.state_name
    if state_name is None:
        state_name = str(uuid.uuid4())
        request.state_name = state_name
    state_key = (request.model, state_name)
    if state_key in STATE_STORE:
        model_state = STATE_STORE[state_key].get('state', None)

    # --- sampler merge ----------------------------------------------------
    request_dict = request.model_dump()
    sampler_overrides = request_dict.get('sampler') or {}
    for k, v in defaultSamplerConfig.model_dump().items():
        override = sampler_overrides.get(k)
        if override is not None:
            # An explicit per-request sampler override always wins.
            request_dict[k] = override
            continue
        if k in request_dict and request_dict[k] is None:
            request_dict[k] = v
    realRequest = ChatCompletionRequest(**request_dict)

    if realRequest.stream is None:
        realRequest.stream = CONFIG.DEFAULT_STREAM

    # --- server-side limits (best effort: never fail the request) ----------
    try:
        max_tokens_limit = getattr(CONFIG, 'MAX_TOKENS_PER_REQUEST', None)
        if realRequest.max_tokens is not None and max_tokens_limit is not None:
            realRequest.max_tokens = min(realRequest.max_tokens, max_tokens_limit)
        if realRequest.temperature is not None:
            realRequest.temperature = min(
                max(realRequest.temperature, getattr(CONFIG, 'MIN_TEMPERATURE', 0.0)),
                getattr(CONFIG, 'MAX_TEMPERATURE', 2.0),
            )
        if realRequest.top_p is not None:
            realRequest.top_p = min(
                max(realRequest.top_p, getattr(CONFIG, 'MIN_TOP_P', 0.0)),
                getattr(CONFIG, 'MAX_TOP_P', 1.0),
            )
    except Exception as exc:
        # Previously swallowed silently; keep going but leave a trace.
        logger.warning(f"[REQ] {completionId} - could not apply sampler limits: {exc}")

    # Log the resolved request (the original logged the unresolved one again).
    logger.info(f"[REQ] {completionId} - Real - {realRequest.model_dump()}")

    if realRequest.stream:
        return StreamingResponse(
            chatResponseStream(realRequest, model_state, completionId, enableReasoning),
            media_type="text/event-stream",
            background=BackgroundTask(chatResponseStreamDisconnect),
        )

    r = await chatResponse(realRequest, model_state, completionId, enableReasoning)
    if isinstance(r, ChatCompletion):
        # Expose the (possibly generated) state name so clients can resume.
        d = r.model_dump()
        d['state_name'] = state_name
        return d
    return r
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Serve the pre-built web UI from ./dist-frontend when that directory exists.
# NOTE(review): this mount is registered at "/" before the routes declared
# further down in this file; routes added after a root mount may be shadowed
# by it -- confirm the intended registration order.
if not os.path.isdir("dist-frontend"):
    logger.info("Static frontend mount not enabled; `dist-frontend` directory not found")
else:
    logger.info("Static frontend mount enabled: serving dist-frontend at /")
    app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
|
|
|
|
|
|
|
|
|
@app.get('/api/v1/models')
def list_models():
    """Describe the configured models and server-wide UI defaults.

    The payload has two keys: ``root_defaults`` (global ``*_BY_DEFAULT``
    switches plus upload settings read from ``CONFIG``) and ``models`` (one
    entry per item in ``CONFIG.MODELS`` with its names, ``ALLOW_*`` /
    ``SHOW_*`` flags and default sampler). Any flag missing from the
    configuration object is reported as ``True``.
    """
    # Global switches; each falls back to True when absent from CONFIG.
    root_flag_names = (
        'ALLOW_FILE_TOOL_BY_DEFAULT',
        'ENABLE_WEB_SEARCH_BY_DEFAULT',
        'ENABLE_REASONING_BY_DEFAULT',
        'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT',
        'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT',
        'SHOW_REASONING_TOGGLE_BY_DEFAULT',
    )
    # Per-model switches; order here is the order of keys in the response.
    model_flag_names = (
        'ALLOW_WEB_SEARCH',
        'ALLOW_TOOLS',
        'ALLOW_REASONING',
        'ALLOW_FILE_TOOL',
        'ALLOW_FETCH_URL',
        'ALLOW_SUMMARIZE',
        'ALLOW_KEYWORDS',
        'ALLOW_SENTIMENT',
        'ALLOW_TRANSLATE',
        'ALLOW_SPELL_CHECK',
        'ALLOW_FORMAT_CODE',
        'ALLOW_EXPLAIN_CODE',
        'SHOW_WEB_SEARCH_BUTTON',
        'SHOW_FILE_UPLOAD_BUTTON',
        'SHOW_REASONING_TOGGLE',
        'SHOW_FETCH_URL_BUTTON',
        'SHOW_SUMMARIZE_BUTTON',
        'SHOW_KEYWORDS_BUTTON',
        'SHOW_SENTIMENT_BUTTON',
        'SHOW_TRANSLATE_BUTTON',
        'SHOW_SPELL_CHECK_BUTTON',
        'SHOW_FORMAT_CODE_BUTTON',
        'SHOW_EXPLAIN_CODE_BUTTON',
    )

    public_uploads = getattr(CONFIG, 'ALLOW_PUBLIC_UPLOADS', False)
    upload_url = '/api/v1/files' if public_uploads else None

    root_defaults = {flag: getattr(CONFIG, flag, True) for flag in root_flag_names}
    root_defaults['UPLOAD_URL'] = upload_url
    root_defaults['ALLOW_PUBLIC_UPLOADS'] = public_uploads

    models = []
    for cfg in CONFIG.MODELS:
        entry = {
            'SERVICE_NAME': cfg.SERVICE_NAME,
            'DEFAULT_CHAT': cfg.DEFAULT_CHAT,
            'DEFAULT_REASONING': cfg.DEFAULT_REASONING,
        }
        entry.update((flag, getattr(cfg, flag, True)) for flag in model_flag_names)
        if hasattr(cfg, 'DEFAULT_SAMPLER'):
            entry['DEFAULT_SAMPLER'] = cfg.DEFAULT_SAMPLER.model_dump()
        else:
            entry['DEFAULT_SAMPLER'] = None
        entry['UPLOAD_URL'] = upload_url
        entry['UPLOAD_ALLOWED_BY_DEFAULT'] = getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True)
        models.append(entry)

    return {'root_defaults': root_defaults, 'models': models}
|
|
|
|
|
|
|
|
|
def upload_file_internal(bytes_data: bytes, filename: Optional[str] = None, model: Optional[str] = None) -> dict:
    """Persist raw upload bytes and record them in ``UPLOADED_FILES``.

    Args:
        bytes_data: File content to store.
        filename: Name to store under; ``'uploaded_file'`` when empty/None.
        model: Optional model name; when given, the model must exist in
            ``MODEL_STORAGE`` and must not disable ``ALLOW_FILE_TOOL``.

    Returns:
        The metadata dict produced by ``utils.save_bytes_to_upload``.

    Raises:
        HTTPException: 403 when uploads are disabled globally or for the
            model; 404 when the model is unknown.
        Exception: when the saved metadata carries an ``error`` entry.
    """
    uploads_enabled = getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True)
    if not uploads_enabled:
        raise HTTPException(403, 'File uploads are disabled by server configuration')

    if model:
        if model not in MODEL_STORAGE:
            raise HTTPException(404, f"Model {model} not found")
        storage = MODEL_STORAGE[model]
        if storage and storage.MODEL_CONFIG:
            if not getattr(storage.MODEL_CONFIG, 'ALLOW_FILE_TOOL', True):
                raise HTTPException(403, f"Model {model} does not allow file uploads")

    from utils import save_bytes_to_upload

    saved = save_bytes_to_upload(filename or 'uploaded_file', bytes_data)
    if saved.get('error'):
        raise Exception(saved.get('error'))

    # Index the new file by its id so the other file helpers can find it.
    UPLOADED_FILES[saved['file_id']] = saved
    return saved
|
|
|
|
|
|
|
|
|
def list_files_internal():
    """Return every entry of ``UPLOADED_FILES`` as a dumped ``UploadedFile`` dict."""
    listing = []
    for record in UPLOADED_FILES.values():
        listing.append(UploadedFile(**record).model_dump())
    return listing
|
|
|
|
|
|
|
|
|
def get_file_internal(file_id: str, download: bool = False):
    """Look up a registered upload and return its metadata or its content.

    Args:
        file_id: Key of the file in ``UPLOADED_FILES``.
        download: When true, stream the file bytes instead of the metadata.

    Returns:
        ``UploadedFile`` built from the stored metadata, or a
        ``StreamingResponse`` (``application/octet-stream``) when
        ``download`` is true.

    Raises:
        HTTPException: 404 when ``file_id`` is unknown; 500 when the stored
            file cannot be opened.
    """
    if file_id not in UPLOADED_FILES:
        raise HTTPException(404, 'File not found')
    meta = UPLOADED_FILES[file_id]

    if not download:
        return UploadedFile(**meta)

    # Open eagerly so a missing/unreadable file still yields a 500 here,
    # before any response headers are sent.
    try:
        handle = open(meta['path'], 'rb')
    except Exception as e:
        raise HTTPException(500, str(e))

    def _iter_chunks(chunk_size: int = 64 * 1024):
        # The handle must outlive this function: the response body is consumed
        # after we return, so a ``with`` block around the return statement
        # would hand the framework an already-closed file.
        try:
            for chunk in iter(lambda: handle.read(chunk_size), b''):
                yield chunk
        finally:
            handle.close()

    return StreamingResponse(_iter_chunks(), media_type='application/octet-stream')
|
|
|
|
|
|
|
|
|
def delete_file_internal(file_id: str):
    """Unregister an upload and remove its file from disk (best effort).

    Args:
        file_id: Key of the file in ``UPLOADED_FILES``.

    Returns:
        ``{'success': True}`` once the entry has been unregistered. Failure to
        delete the file on disk does not change the result; it is logged.

    Raises:
        HTTPException: 404 when ``file_id`` is unknown.
    """
    if file_id not in UPLOADED_FILES:
        raise HTTPException(404, 'File not found')
    meta = UPLOADED_FILES.pop(file_id)

    path = meta.get('path')
    if path:
        try:
            # Remove directly instead of exists()+remove(): no race window.
            os.remove(path)
        except FileNotFoundError:
            # Already gone -- nothing left to clean up.
            pass
        except OSError as exc:
            # Keep the best-effort contract, but do not hide the failure.
            logger.warning(f"Could not delete uploaded file {path!r}: {exc}")
    return {'success': True}
|
|
|
|
|
|
|
|
|
|
|
|
# Public file endpoints are registered only when CONFIG.ALLOW_PUBLIC_UPLOADS
# is truthy; otherwise the *_internal helpers are reachable from code only.
if getattr(CONFIG, 'ALLOW_PUBLIC_UPLOADS', False):

    @app.post('/api/v1/files', response_model=FileUploadResponse)
    async def upload_file_public(file: UploadFile = File(...), model: Optional[str] = None):
        """Accept a multipart upload and register it via ``upload_file_internal``.

        Responds 413 when the content exceeds ``CONFIG.MAX_UPLOAD_SIZE_BYTES``
        (if set). ``HTTPException``s from the helper pass through unchanged;
        any other error becomes a 500 carrying the error text.
        """
        try:
            max_upload_size = getattr(CONFIG, 'MAX_UPLOAD_SIZE_BYTES', None)
            if max_upload_size is None:
                content = await file.read()
            else:
                # Read one byte past the limit at most: enough to detect an
                # oversized upload without loading all of it into memory.
                content = await file.read(int(max_upload_size) + 1)
                if len(content) > max_upload_size:
                    raise HTTPException(413, 'File too large')

            fname = getattr(file, 'filename', None) or 'uploaded_file'
            meta = upload_file_internal(content, filename=fname, model=model)
            return FileUploadResponse(success=True, file=UploadedFile(**meta))
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(500, str(e))

    @app.get('/api/v1/files')
    def list_files_public():
        """Return the result of ``list_files_internal``."""
        return list_files_internal()

    @app.get('/api/v1/files/{file_id}')
    def get_file_public(file_id: str, download: bool = False):
        """Return the result of ``get_file_internal`` for ``file_id``."""
        return get_file_internal(file_id, download=download)

    @app.delete('/api/v1/files/{file_id}')
    def delete_file_public(file_id: str):
        """Return the result of ``delete_file_internal`` for ``file_id``."""
        return delete_file_internal(file_id)
|
|
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Fall back to loopback:7860 when CONFIG leaves host/port empty.
    uvicorn.run(
        app,
        host=CONFIG.HOST or "127.0.0.1",
        port=CONFIG.PORT or 7860,
    )
|
|
|
|