import os

# ModelScope Studio deployment: route Hugging Face Hub downloads through the
# ModelScope mirror before anything else touches huggingface_hub.
if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio":
    from modelscope import patch_hub

    patch_hub()

# Reduce CUDA allocator fragmentation for long-running generation workloads.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256"

from config import CONFIG, ModelConfig
from utils import (
    cleanMessages,
    parse_think_response,
    remove_nested_think_tags_stack,
    format_bytes,
    log,
    detect_tools_and_reasoning,
    universal_tool,
)
import copy, types, gc, sys, re, time, collections, asyncio
from huggingface_hub import hf_hub_download
from loguru import logger
from rich import print
from snowflake import SnowflakeGenerator

# Snowflake IDs for chat-completion responses (machine id 42, custom epoch).
CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)

from typing import List, Optional, Union, Any, Dict
import uuid
from pydantic import BaseModel, Field, model_validator
from pydantic_settings import BaseSettings
import numpy as np
import torch

# Fall back to CPU when a CUDA strategy was configured but no GPU is available.
if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available():
    logger.info("CUDA not found, fall back to cpu")
    CONFIG.STRATEGY = "cpu fp16"

# Normalize STRATEGY to include precision if missing (e.g., 'cpu' -> 'cpu fp16')
_s = CONFIG.STRATEGY.lower()
if ("cpu" in _s or "cuda" in _s) and not ("fp16" in _s or "fp32" in _s):
    logger.info(f"STRATEGY missing precision, appending 'fp16' to `{CONFIG.STRATEGY}`")
    CONFIG.STRATEGY = CONFIG.STRATEGY + " fp16"

# pynvml is optional; keep the names defined (as None) so logGPUState can
# guard on them instead of raising NameError.
try:
    from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
except Exception:
    nvmlInit = None
    nvmlDeviceGetHandleByIndex = None
    nvmlDeviceGetMemoryInfo = None

if "cuda" in CONFIG.STRATEGY.lower() and nvmlInit is not None and nvmlDeviceGetHandleByIndex is not None:
    nvmlInit()
    gpu_h = nvmlDeviceGetHandleByIndex(0)


def logGPUState():
    """Log torch-allocated VRAM and NVML-reported VRAM usage.

    No-op when the strategy is not CUDA or pynvml is unavailable.
    """
    # Fix: compare against the lower-cased strategy, consistent with the setup
    # guard above. The original case-sensitive check silently skipped logging
    # for strategies spelled e.g. "CUDA fp16" even though gpu_h was set.
    if "cuda" in CONFIG.STRATEGY.lower() and nvmlDeviceGetMemoryInfo is not None:
        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        logger.info(
            f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}"
        )


torch.backends.cudnn.benchmark = True
# Allow TF32 matmul/conv paths: a precision-for-speed trade-off acceptable here.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

os.environ["RWKV_V7_ON"] = "1"  # enable this for rwkv-7 models
os.environ["RWKV_JIT_ON"] = "1"
# Only enable the custom CUDA kernel when the config asks for it AND the
# strategy actually targets CUDA.
os.environ["RWKV_CUDA_ON"] = (
    "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0"
)

# NOTE: the RWKV_* env vars above must be set before importing rwkv, so these
# imports intentionally sit below the os.environ assignments.
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from fastapi import FastAPI, HTTPException, UploadFile, File
from starlette.background import BackgroundTask
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.gzip import GZipMiddleware
from api_types import (
    ChatMessage,
    ChatCompletion,
    ChatCompletionChunk,
    Usage,
    PromptTokensDetails,
    ChatCompletionChoice,
    ChatCompletionMessage,
    SamplerConfig,
    UploadedFile,
    FileUploadResponse,
)


class ModelStorage:
    """Per-model container: config, loaded RWKV weights, tokenizer pipeline."""

    # Config entry this model was loaded from
    MODEL_CONFIG: Optional[ModelConfig] = None
    # Loaded RWKV model instance
    model: Optional[RWKV] = None
    # Tokenizer/sampling pipeline bound to `model`
    pipeline: Optional[PIPELINE] = None


# Registry of loaded models keyed by SERVICE_NAME.
MODEL_STORAGE: Dict[str, ModelStorage] = {}
# NOTE(review): "DEFALUT" is a typo, but the name is used consistently across
# the module; kept as-is to avoid touching every reference.
DEFALUT_MODEL_NAME = None
DEFAULT_REASONING_MODEL_NAME = None
# In-memory model state store to support streaming continuation/resume per state_name.
# Keys: (model_name, state_name) -> dict with 'state' and 'model_tokens' STATE_STORE: Dict[tuple, Any] = {} # Serialized state store file path and flush interval defined in CONFIG _STATE_STORE_PATH = getattr(CONFIG, 'STATE_STORE_PATH', './state_store.json') _LAST_STATE_STORE_WRITE = 0 # sentinel for model-initiated tool calls: {json} TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.S) # File uploads: simple in-memory index (persisted on disk via the files themselves) UPLOADED_FILES: Dict[str, dict] = {} def _serialize_state_store() -> dict: # Save only model_tokens to disk; model_state (torch objects) are not serializable serial = {} for (model_name, state_name), entry in STATE_STORE.items(): try: mt = entry.get('model_tokens') if isinstance(entry, dict) else None if mt is None: # if entry is a raw model_state, skip continue serial[f"{model_name}|{state_name}"] = { 'model': model_name, 'state_name': state_name, 'model_tokens': mt, } except Exception: continue return serial def _load_state_store_from_disk(): global STATE_STORE try: if os.path.exists(_STATE_STORE_PATH): import json with open(_STATE_STORE_PATH, 'r', encoding='utf-8') as f: data = json.load(f) for k, v in data.items(): model = v.get('model') state_name = v.get('state_name') model_tokens = v.get('model_tokens') if model and state_name and isinstance(model_tokens, list): STATE_STORE[(model, state_name)] = { 'state': None, 'model_tokens': model_tokens, } logger.info(f"Loaded {len(STATE_STORE)} entries from state store file {_STATE_STORE_PATH}") except Exception as e: logger.info(f"Failed to load state store from disk: {e}") def _save_state_store_to_disk(force=False): global _LAST_STATE_STORE_WRITE now = time.time() if not force and now - _LAST_STATE_STORE_WRITE < getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5): return try: serial = _serialize_state_store() if not serial: return import json tmp = _STATE_STORE_PATH + ".tmp" with open(tmp, 'w', encoding='utf-8') as f: json.dump(serial, f) os.replace(tmp, 
def _recompute_out_and_state_from_tokens(model_name: str, model_tokens: List[int]):
    """
    Recompute the `out` logits and `model_state` by forwarding through tokens in chunks.
    Returns a tuple (out, model_state).
    """
    ms = MODEL_STORAGE.get(model_name)
    if not ms or not ms.model:
        return None, None
    model_state = None
    out = None
    # Defensive default: a non-list input is replaced by a single pad token.
    tokens = list(model_tokens) if isinstance(model_tokens, list) else [0]
    while len(tokens) > 0:
        out, model_state = ms.model.forward(tokens[: CONFIG.CHUNK_LEN], model_state)
        tokens = tokens[CONFIG.CHUNK_LEN :]
    return out, model_state


def resolve_request_flags(request, detection):
    """Resolve effective booleans for web_search, file_tool, tools, and reasoning
    based on request flags (explicit), sampler overrides, model defaults and detection.
    Returns dict with keys: web_search_enabled, file_tool_enabled, tools_enabled, reasoning_enabled.

    Precedence for each flag (highest wins):
      request.sampler.ALLOW_* > request.sampler_allow_* > model DEFAULT_SAMPLER
      > model-level ALLOW_* > explicit request enable_* > auto_* / detection.
    All override lookups are wrapped in try/except so a malformed sampler or
    model config degrades to the detected default instead of failing the request.
    """
    # Web search
    web_search_enabled = (
        True
        if (request.enable_web_search is not None and request.enable_web_search)
        else (
            request.web_search
            or (request.auto_web_search if request.auto_web_search is not None else (getattr(CONFIG, 'AUTO_ENABLE_WEB_SEARCH', True) and detection.get('need_web_search')))
        )
    )
    # Global opt-out: unless explicitly requested, honor the server default.
    if not getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True) and request.enable_web_search is None and not (request.web_search or False):
        web_search_enabled = False
    try:
        if request.sampler and getattr(request.sampler, 'ALLOW_WEB_SEARCH', None) is not None:
            web_search_enabled = bool(request.sampler.ALLOW_WEB_SEARCH)
        elif hasattr(request, 'sampler_allow_web_search') and request.sampler_allow_web_search is not None:
            web_search_enabled = bool(request.sampler_allow_web_search)
        else:
            ms = MODEL_STORAGE.get(request.model)
            if ms and ms.MODEL_CONFIG:
                if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_WEB_SEARCH', None) is not None:
                    web_search_enabled = bool(ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_WEB_SEARCH)
                elif hasattr(ms.MODEL_CONFIG, 'ALLOW_WEB_SEARCH') and not ms.MODEL_CONFIG.ALLOW_WEB_SEARCH:
                    web_search_enabled = False
    except Exception:
        pass
    # File tool decision
    if request.enable_file_tool is not None:
        file_tool_enabled = bool(request.enable_file_tool)
    else:
        auto_file_flag = request.auto_file_tool if request.auto_file_tool is not None else getattr(CONFIG, 'AUTO_ENABLE_TOOLS', True)
        file_tool_enabled = bool((request.file_ids and len(request.file_ids) > 0) or (auto_file_flag and request.file_ids))
    if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
        file_tool_enabled = False
    try:
        if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
            file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
        elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
            file_tool_enabled = bool(request.sampler_allow_file_tool)
        else:
            ms = MODEL_STORAGE.get(request.model)
            if ms and ms.MODEL_CONFIG:
                if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
                    file_tool_enabled = bool(ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
                elif hasattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms.MODEL_CONFIG.ALLOW_FILE_TOOL:
                    file_tool_enabled = False
    except Exception:
        pass
    # Tools decision
    if request.enable_tools is not None:
        tools_enabled = bool(request.enable_tools)
    else:
        auto_tools_flag = request.auto_tools if request.auto_tools is not None else getattr(CONFIG, 'AUTO_ENABLE_TOOLS', True)
        tools_enabled = bool(request.tools) or getattr(CONFIG, 'ENABLE_TOOLS_BY_DEFAULT', False) or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
    try:
        if request.sampler and getattr(request.sampler, 'ALLOW_TOOLS', None) is not None:
            tools_enabled = bool(request.sampler.ALLOW_TOOLS)
        elif hasattr(request, 'sampler_allow_tools') and request.sampler_allow_tools is not None:
            tools_enabled = bool(request.sampler_allow_tools)
        else:
            ms = MODEL_STORAGE.get(request.model)
            if ms and ms.MODEL_CONFIG:
                # Model defaults can only *disable* tools here, never force-enable.
                if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_TOOLS', None) is not None:
                    if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_TOOLS:
                        tools_enabled = False
                elif hasattr(ms.MODEL_CONFIG, 'ALLOW_TOOLS') and not ms.MODEL_CONFIG.ALLOW_TOOLS:
                    tools_enabled = False
    except Exception:
        pass
    # Reasoning decision
    reasoning_enabled = bool(
        True
        if (request.enable_reasoning is not None and request.enable_reasoning)
        else (
            bool(False)
            or bool(request.auto_reasoning if request.auto_reasoning is not None else (getattr(CONFIG, 'AUTO_ENABLE_REASONING', True) and bool(detection.get('need_reasoning'))))
        )
    )
    if not getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True) and request.enable_reasoning is None:
        reasoning_enabled = False
    try:
        if request.sampler and getattr(request.sampler, 'ALLOW_REASONING', None) is not None:
            reasoning_enabled = bool(request.sampler.ALLOW_REASONING)
        elif hasattr(request, 'sampler_allow_reasoning') and request.sampler_allow_reasoning is not None:
            reasoning_enabled = bool(request.sampler_allow_reasoning)
        else:
            ms = MODEL_STORAGE.get(request.model)
            if ms and ms.MODEL_CONFIG:
                if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_REASONING', None) is not None:
                    if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_REASONING:
                        reasoning_enabled = False
                elif hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
                    reasoning_enabled = False
    except Exception:
        pass
    # Also apply model-level disable
    try:
        ms = MODEL_STORAGE.get(request.model)
        if ms and ms.MODEL_CONFIG and hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
            reasoning_enabled = False
    except Exception:
        pass
    return {
        'web_search_enabled': web_search_enabled,
        'file_tool_enabled': file_tool_enabled,
        'tools_enabled': tools_enabled,
        'reasoning_enabled': reasoning_enabled,
    }


# Move ChatCompletionRequest definition above fallback apply_model_tags_to_request
try:
    from model_tags import apply_model_tags_to_request_obj as apply_model_tags_to_request
except Exception:
    def apply_model_tags_to_request(req: Any):
        # Fallback implementation if the module import fails; keep behavior robust
        # Parses "model:tag1:tag2" suffixes and maps each tag onto the request's
        # explicit enable_*/auto_* flags, then strips the tags from req.model.
        if not req or not getattr(req, 'model', None) or ':' not in req.model:
            return
        original = req.model
        parts = [p.strip() for p in original.split(":") if p is not None and p != ""]
        if len(parts) <= 1:
            return
        base = parts[0]
        tags = parts[1:]
        req.model = base
        for tag in tags:
            t = tag.lower()
            if t in ("thinking", "think", "reasoning", "reason"):
                req.enable_reasoning = True
                req.auto_reasoning = False
            elif t in ("web", "web_search", "search"):
                req.enable_web_search = True
                req.web_search = True
                req.auto_web_search = False
            elif t in ("no-web", "disable-web", "no-web-search"):
                req.enable_web_search = False
                req.web_search = False
            elif t in ("tools", "enable-tools"):
                req.enable_tools = True
                req.auto_tools = False
            elif t in ("no-tools", "disable-tools"):
                req.enable_tools = False
            elif t in ("file", "file_tool", "filetool"):
                req.enable_file_tool = True
                req.auto_file_tool = False
            elif t in ("no-file", "disable-file"):
                req.enable_file_tool = False
            elif t in ("universal", "univ"):
                req.enable_universal = True
                req.auto_universal = False
            elif t in ("stream",):
                req.stream = True


logger.info(f"STRATEGY - {CONFIG.STRATEGY}")
logGPUState()
def load_models_once():
    """Load and initialize configured models into `MODEL_STORAGE`.

    This is executed once at server startup. Downloads weights from the Hub
    when no local path is configured, builds the RWKV model and its PIPELINE,
    and resolves the module-level default chat / reasoning model names (the
    last configured default wins; with a single loaded model it always
    becomes both defaults).
    """
    global DEFALUT_MODEL_NAME, DEFAULT_REASONING_MODEL_NAME
    logger.info(f"Configured {len(CONFIG.MODELS)} model(s) in ROOT config")
    for model_config in CONFIG.MODELS:
        logger.info(f"Load Model - {model_config.SERVICE_NAME}")
        # Fetch weights when no local path was configured.
        if model_config.MODEL_FILE_PATH is None:  # fix: `is None`, not `== None`
            model_config.MODEL_FILE_PATH = hf_hub_download(
                repo_id=str(model_config.DOWNLOAD_MODEL_REPO_ID),
                filename=str(model_config.DOWNLOAD_MODEL_FILE_NAME),
                local_dir=str(model_config.DOWNLOAD_MODEL_DIR),
            )
        logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}")
        if model_config.DEFAULT_CHAT:
            if DEFALUT_MODEL_NAME is not None:
                logger.info(
                    f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
                )
            DEFALUT_MODEL_NAME = model_config.SERVICE_NAME
        if model_config.DEFAULT_REASONING:
            if DEFAULT_REASONING_MODEL_NAME is not None:
                logger.info(
                    f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`"
                )
            DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME
        logger.info(f"Load Model - Loading `{model_config.SERVICE_NAME}`")
        print(model_config.DEFAULT_SAMPLER)
        # Build storage fully before publishing it in the registry.
        storage = ModelStorage()
        storage.MODEL_CONFIG = model_config
        # rwkv expects the checkpoint path without the ".pth" suffix.
        storage.model = RWKV(
            model=model_config.MODEL_FILE_PATH.replace(".pth", ""),
            strategy=CONFIG.STRATEGY,
        )
        storage.pipeline = PIPELINE(storage.model, model_config.VOCAB)
        MODEL_STORAGE[model_config.SERVICE_NAME] = storage
        # fix: use the lower-cased strategy, consistent with the rest of the
        # module (original case-sensitive check skipped empty_cache for
        # e.g. "CUDA fp16").
        if "cuda" in CONFIG.STRATEGY.lower():
            torch.cuda.empty_cache()
        gc.collect()
        logGPUState()
    logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`")
    logger.info(f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`")
    # With exactly one model loaded, it always serves as both defaults.
    if len(MODEL_STORAGE) == 1:
        single_name = next(iter(MODEL_STORAGE))
        if DEFALUT_MODEL_NAME != single_name:
            DEFALUT_MODEL_NAME = single_name
            logger.info(f"Load Model - Only one model present; DEFALUT_MODEL_NAME set to `{DEFALUT_MODEL_NAME}`")
        if DEFAULT_REASONING_MODEL_NAME != single_name:
            DEFAULT_REASONING_MODEL_NAME = single_name
            logger.info(f"Load Model - Only one model present; DEFAULT_REASONING_MODEL_NAME set to `{DEFAULT_REASONING_MODEL_NAME}`")
class ChatCompletionRequest(BaseModel):
    """OpenAI-style /chat/completions request body with RWKV-specific
    extensions: state resume (`state_name`), server-side tool execution,
    web-search / reasoning / file-tool toggles, and per-request sampler
    overrides. Exactly one of `messages` or `prompt` must be provided."""

    model: str = Field(
        default="rwkv-latest",
        description="Specify the model name. Model tags/suffixes (e.g., ':thinking' or ':web') are not supported — set the corresponding request flags (enable_reasoning, web_search, enable_file_tool) instead.",
    )
    messages: Optional[List[ChatMessage]] = Field(default=None)
    prompt: Optional[str] = Field(default=None)
    max_tokens: Optional[int] = Field(default=None)
    temperature: Optional[float] = Field(default=None)
    top_p: Optional[float] = Field(default=None)
    presence_penalty: Optional[float] = Field(default=None)
    count_penalty: Optional[float] = Field(default=None)
    penalty_decay: Optional[float] = Field(default=None)
    stream: Optional[bool] = Field(default=True, description="Whether to stream token-by-token responses. If None, uses CONFIG.DEFAULT_STREAM")
    state_name: Optional[str] = Field(default=None)
    include_usage: Optional[bool] = Field(default=False)
    stop: Optional[list[str]] = Field(["\n\n"])
    stop_tokens: Optional[list[int]] = Field([0])
    # Note: these defaults are intentionally None so the model may decide
    # autonomously whether to use web search based on prompt detection unless
    # the client explicitly sets flags. `auto_web_search` will be consulted
    # when `enable_web_search` and `web_search` are None.
    web_search: Optional[bool] = Field(default=None, description="Whether to perform a web search and append results to the prompt; if None, auto-detection is used")
    enable_web_search: Optional[bool] = Field(default=None, description="Explicitly enable web search (overrides auto/web_search) if set; if None, auto-detection controls it")
    auto_web_search: Optional[bool] = Field(default=True, description="Whether to enable web_search based on auto-detected intent")
    enable_tools: Optional[bool] = Field(default=None, description="Explicitly enable tools (overrides auto detection)")
    auto_tools: Optional[bool] = Field(default=True, description="Whether to enable tools based on auto-detected intent")
    enable_reasoning: Optional[bool] = Field(default=None, description="Explicitly override reasoning enablement; if None, auto-detection controls it")
    auto_reasoning: Optional[bool] = Field(default=True, description="Whether to enable reasoning based on auto detection")
    enable_universal: Optional[bool] = Field(default=None, description="Explicitly enable the universal tool execution")
    auto_universal: Optional[bool] = Field(default=True, description="Whether to auto enable universal tool execution")
    search_top_k: Optional[int] = Field(default=3, description="Number of web search results to retrieve")
    tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="List of tools to execute server-side (e.g., {'name':'web_search','args':{'query':'x'}})")
    # Per-request sampler overrides for ALLOW_* flags. These let the user
    # disable server-side features for this particular request if needed.
    sampler_allow_web_search: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing web_search")
    sampler_allow_tools: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing tools")
    sampler_allow_reasoning: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing reasoning")
    # Per-request sampler config object; if provided, these settings will
    # override the model defaults for this request.
    sampler: Optional[SamplerConfig] = Field(default=None, description="Per-request sampler settings (overrides model default)")
    # File uploads: allow referencing uploaded files in the request
    file_ids: Optional[List[str]] = Field(default=None, description="List of uploaded file IDs that the model may use for this request")
    enable_file_tool: Optional[bool] = Field(default=None, description="Explicitly enable file-based tools for this request")
    auto_file_tool: Optional[bool] = Field(default=None, description="Auto-detect whether file-based tools are needed")
    sampler_allow_file_tool: Optional[bool] = Field(default=None, description="Per-request sampler override allowing file tools")

    @model_validator(mode="before")
    @classmethod
    def validate_mutual_exclusivity(cls, data: Any) -> Any:
        """Reject payloads that supply both `messages` and `prompt`, or neither."""
        if not isinstance(data, dict):
            return data
        # fix: `is not None` instead of `!= None` (correct identity idiom;
        # same behavior for these values)
        messages_provided = "messages" in data and data["messages"] is not None
        prompt_provided = "prompt" in data and data["prompt"] is not None
        if messages_provided and prompt_provided:
            raise ValueError("messages and prompt cannot coexist. Choose one.")
        if not messages_provided and not prompt_provided:
            raise ValueError("Either messages or prompt must be provided.")
        return data
app = FastAPI(title="RWKV OpenAI-Compatible API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5)


@app.on_event("startup")
async def _startup_state_load_and_persist_loop():
    # Load previous persisted state (tokens only) at startup
    _load_state_store_from_disk()
    # Load configured models once on startup
    try:
        load_models_once()
    except Exception as e:
        logger.info(f"Model loading at startup failed: {e}")

    async def _persist_loop():
        # Periodically flush STATE_STORE to disk; errors are swallowed so the
        # loop never dies.
        while True:
            try:
                _save_state_store_to_disk(force=False)
            except Exception:
                pass
            await asyncio.sleep(getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5))

    # Spawn background flush task
    try:
        asyncio.create_task(_persist_loop())
    except Exception:
        pass


async def runPrefill(
    request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state
):
    """Feed `ctx` through the model in CHUNK_LEN chunks, extending
    `model_tokens` in place. Returns (out_logits, model_tokens, model_state).

    Raises HTTPException(500) when the requested model is not loaded.
    """
    ctx = ctx.replace("\r\n", "\n")
    out = None
    ms = MODEL_STORAGE.get(request.model)
    if not ms or not ms.pipeline or not ms.model:
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")
    tokens = ms.pipeline.encode(ctx)
    tokens = [int(x) for x in tokens]
    model_tokens += tokens
    while len(tokens) > 0:
        out, model_state = ms.model.forward(
            tokens[: CONFIG.CHUNK_LEN], model_state
        )
        tokens = tokens[CONFIG.CHUNK_LEN :]
        # Yield to the event loop between chunks so prefill does not block it.
        await asyncio.sleep(0)
    return out, model_tokens, model_state


def generate(
    request: ChatCompletionRequest,
    out,
    model_tokens: List[int],
    model_state,
    max_tokens=2048,
):
    """Generator producing per-token chunks: dicts with 'content' (decoded
    text), 'tokens', 'finish_reason' (None while running; 'stop:token:N' or
    'length' at the end) and 'state' (current model state).

    Mutates `model_tokens` in place. Raises HTTPException(500) if the model
    is not loaded.
    """
    ms = MODEL_STORAGE.get(request.model)
    if not ms or not ms.pipeline or not ms.model:
        raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing")
    # Enforce sampling configuration limits from CONFIG
    from config import CONFIG as _CFG
    temperature = request.temperature if request.temperature is not None else 0.2
    temperature = min(max(temperature, getattr(_CFG, 'MIN_TEMPERATURE', 0.0)), getattr(_CFG, 'MAX_TEMPERATURE', 2.0))
    top_p = request.top_p if request.top_p is not None else 0.9
    top_p = min(max(top_p, getattr(_CFG, 'MIN_TOP_P', 0.0)), getattr(_CFG, 'MAX_TOP_P', 1.0))
    alpha_frequency = request.count_penalty if request.count_penalty is not None else 0.0
    alpha_presence = request.presence_penalty if request.presence_penalty is not None else 0.0
    penalty_decay = request.penalty_decay if request.penalty_decay is not None else 0.5
    args = PIPELINE_ARGS(
        # NOTE(review): this hard 0.2 floor overrides any smaller
        # CONFIG.MIN_TEMPERATURE applied above — confirm this is intended.
        temperature=max(0.2, temperature),
        top_p=top_p,
        alpha_frequency=alpha_frequency,
        alpha_presence=alpha_presence,
        token_ban=[],  # ban the generation of some tokens
        token_stop=[0],  # stop generation whenever you see any token here
    )
    # token -> repetition weight, used for presence/frequency penalties
    occurrence = {}
    out_tokens: List[int] = []
    out_last = 0
    # Stream token-by-token; each chunk contains a single decoded token string.
    for i in range(max_tokens):
        # Apply presence/frequency penalties to the logits before sampling.
        for n in occurrence:
            out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
        # out[0] -= 1e10  # disable END_OF_TEXT
        token = ms.pipeline.sample_logits(
            out, temperature=args.temperature, top_p=args.top_p
        )
        # Fast path: END_OF_TEXT stops before the extra forward pass below.
        if token == 0 and request.stop_tokens and token in request.stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": "stop:token:0",
                "state": model_state,
            }
            del out
            gc.collect()
            return
        out, model_state = ms.model.forward([token], model_state)
        model_tokens.append(token)
        out_tokens.append(token)
        if request.stop_tokens and token in request.stop_tokens:
            yield {
                "content": "",
                "tokens": out_tokens[out_last:],
                "finish_reason": f"stop:token:{token}",
                "state": model_state,
            }
            del out
            gc.collect()
            return
        # Decay old penalties, then bump the weight for the sampled token.
        for xxx in list(occurrence.keys()):
            occurrence[xxx] *= penalty_decay
        occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
        # Decode token to text and yield it as a single-token chunk
        decoded = ms.pipeline.decode([token])
        # filter out replacement characters (incomplete UTF-8 sequences)
        if "\ufffd" in decoded:
            continue
        yield {
            "content": decoded,
            "tokens": [token],
            "finish_reason": None,
            "state": model_state,
        }
        out_last = i + 1
    else:
        # Loop exhausted max_tokens without hitting a stop token.
        yield {
            "content": "",
            "tokens": [],
            "finish_reason": "length",
        }
yield { "content": decoded, "tokens": [token], "finish_reason": None, "state": model_state, } out_last = i + 1 else: yield { "content": "", "tokens": [], "finish_reason": "length", } async def chatResponse( request: ChatCompletionRequest, model_state: Any, completionId: str, enableReasoning: bool, ) -> ChatCompletion: createTimestamp = time.time() # use global resolve_request_flags def decide_file_tool_enabled(request): return resolve_request_flags(request, detection)['file_tool_enabled'] def decide_tools_enabled(request, detection): return resolve_request_flags(request, detection)['tools_enabled'] def decide_reasoning_enabled(request, detection, enableReasoning): # resolve_request_flags ignores the enableReasoning baseline, so compute flags then flags = resolve_request_flags(request, detection) return flags['reasoning_enabled'] def execute_tools(request, detection, prompt, executed_tool_calls, web_search_enabled, tools_enabled, file_tool_enabled, raw_prompt): """Helper to execute tools and update prompt and executed_tool_calls.""" # If file tools are enabled and files are attached, inject them into the prompt (for streaming) if file_tool_enabled and request.file_ids: for fid in request.file_ids: try: if fid not in UPLOADED_FILES: continue meta = UPLOADED_FILES.get(fid) if not meta: continue from utils import file_read_from_path fpath = meta.get('path') if not fpath or not os.path.exists(fpath): continue file_content = file_read_from_path(fpath, 200000) if file_content: exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}} executed_tool_calls.append(exec_entry) prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt) except Exception as e: logger.info(f"File injection error: {e}") # Only a single injection loop above is needed; duplicate removed. 
if request.tools: try: for tool in request.tools: name = tool.get('name') args = tool.get('args', {}) if name == 'web_search': from utils import web_search search_q = args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or [])) search_top_k = int(args.get('top_k') or request.search_top_k or 3) search_str = web_search(search_q, search_top_k) if search_str: search_res_struct = {"action": "web_search", "result": str(search_str), "metadata": {"query": search_q, "top_k": search_top_k, "confidence": 0.9}} executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": search_top_k}, "result": search_res_struct}) prompt = (f"ToolResults:\n{search_res_struct.get('result')}\n\nUse these results to answer the prompt.\n\n" + prompt) elif name == 'calc' or name == 'calculator': from utils import calc expr = args.get('expression') if expr: calc_res = calc(expr) calc_res_struct = {"action": "calc", "result": str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}} executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct}) prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res_struct.get('result')}\n\nUse this result to answer the prompt.\n\n" + prompt) elif name == 'universal': try: res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled)) if isinstance(res, dict): result_text = res.get('result') if res.get('result') is not None else '' else: result_text = str(res) executed_tool_calls.append({"name": "universal", "args": args, "result": res}) prompt = (f"ToolResults:\n{result_text}\n\nUse this result to answer the prompt.\n\n" + prompt) except Exception as e: logger.info(f"Universal tool execution error: {e}") else: logger.info(f"Unsupported tool requested: {name}") # ...existing code for other tools (fetch_url, summarize, keywords, etc.)... 
# For brevity, keep the rest as is, or further refactor if needed. except Exception as e: logger.info(f"Tool processing error: {e}") elif request.web_search or web_search_enabled: try: from utils import web_search search_q = request.prompt if request.prompt else cleanMessages(request.messages or []) search_res = web_search(search_q, int(request.search_top_k or 3)) if search_res: search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}} executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct}) prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt except Exception: pass return prompt, executed_tool_calls # Build raw prompt for detection (prefer explicit request.prompt, else messages) raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or []) detection = detect_tools_and_reasoning(raw_prompt) prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [])}\n\nAssistant:{' auto-detect) flags = resolve_request_flags(request, detection) web_search_enabled = flags['web_search_enabled'] tools_enabled = flags['tools_enabled'] file_tool_enabled = flags['file_tool_enabled'] reasoning_enabled = flags['reasoning_enabled'] enableReasoning = reasoning_enabled try: ms_cfg = MODEL_STORAGE.get(request.model) if ms_cfg and ms_cfg.MODEL_CONFIG and hasattr(ms_cfg.MODEL_CONFIG, 'ALLOW_REASONING') and not ms_cfg.MODEL_CONFIG.ALLOW_REASONING: enableReasoning = False except Exception: pass # file_tool_enabled is derived from resolve_request_flags as well # Build final prompt after deciding enableReasoning prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{' 0 and r_dict['choices'][0].get('delta') is not None: 
r_dict['choices'][0]['delta']['tool_calls'] = executed_tool_calls except Exception: pass yield f"data: {r_dict}\n\n" buffer = [] if enableReasoning: buffer.append(" tag "fullTextCursor": 0, "in_think": False, "cacheStr": "", } max_gen_tokens = ( getattr(CONFIG, 'MAX_GENERATION_TOKENS_LIMIT', 64000) if "max_tokens" not in request.model_fields_set and enableReasoning else (request.max_tokens or 2048) ) max_tokens_limit = getattr(CONFIG, 'MAX_TOKENS_PER_REQUEST', None) if max_tokens_limit: max_gen_tokens = min(max_gen_tokens, max_tokens_limit) for chunk in generate(request, out, model_tokens, model_state, max_tokens=max_gen_tokens): completionTokenCount += 1 # Each token stream is delivered as a decoded character/bytes (maybe 1 or more chars) chunkContent: str = chunk["content"] buffer.append(chunkContent) fullText = "".join(buffer) if chunk["finish_reason"]: finishReason = chunk["finish_reason"] response = ChatCompletionChunk( id=completionId, created=createTimestamp, model=request.model, usage=( Usage( prompt_tokens=promptTokenCount, completion_tokens=completionTokenCount, total_tokens=promptTokenCount + completionTokenCount, prompt_tokens_details=PromptTokensDetails(cached_tokens=0), ) if request.include_usage else None ), choices=[ ChatCompletionChoice( index=0, delta=ChatCompletionMessage( role="Assistant", content=None, reasoning_content=None, tool_calls=None, ), logprobs=None, finish_reason=finishReason, ) ], ) if response.choices and response.choices[0].delta is None: response.choices[0].delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) markStart = fullText.find("<", streamConfig["fullTextCursor"]) if not streamConfig["isChecking"] and markStart != -1: streamConfig["isChecking"] = True if streamConfig["in_think"]: delta = response.choices[0].delta if delta is None: delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) response.choices[0].delta = delta 
delta.reasoning_content = fullText[streamConfig["fullTextCursor"] : markStart] else: delta = response.choices[0].delta if delta is None: delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) response.choices[0].delta = delta delta.content = fullText[streamConfig["fullTextCursor"] : markStart] streamConfig["cacheStr"] = "" streamConfig["fullTextCursor"] = markStart if streamConfig["isChecking"]: streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :] else: if streamConfig["in_think"]: delta = response.choices[0].delta if delta is None: delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) response.choices[0].delta = delta delta.reasoning_content = chunkContent else: delta = response.choices[0].delta if delta is None: delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) response.choices[0].delta = delta delta.content = chunkContent streamConfig["fullTextCursor"] = len(fullText) markEnd = fullText.find(">", streamConfig["fullTextCursor"]) if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None: streamConfig["isChecking"] = False if ( not streamConfig["in_think"] and streamConfig["cacheStr"].find("") != -1 ): streamConfig["in_think"] = True delta = response.choices[0].delta if delta is None: delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) response.choices[0].delta = delta delta.reasoning_content = ( delta.reasoning_content if delta.reasoning_content != None else "" + streamConfig["cacheStr"].replace("", "") ) elif ( streamConfig["in_think"] and streamConfig["cacheStr"].find("") != -1 ): streamConfig["in_think"] = False delta = response.choices[0].delta if delta is None: delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) response.choices[0].delta = delta delta.content = ( delta.content if 
delta.content != None else "" + streamConfig["cacheStr"].replace("", "") ) else: if streamConfig["in_think"]: delta = response.choices[0].delta if delta is None: delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) response.choices[0].delta = delta delta.reasoning_content = ( delta.reasoning_content if delta.reasoning_content != None else "" + streamConfig["cacheStr"] ) else: delta = response.choices[0].delta if delta is None: delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) response.choices[0].delta = delta delta.content = ( delta.content if delta.content != None else "" + streamConfig["cacheStr"] ) streamConfig["fullTextCursor"] = len(fullText) delta = response.choices[0].delta if delta is None: delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) response.choices[0].delta = delta if delta.content != None or delta.reasoning_content != None: # Save model state frequently (after each token) to allow resuming try: if request.state_name: STATE_STORE[(request.model, request.state_name)] = { 'state': model_state, 'model_tokens': model_tokens, } if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False): try: _save_state_store_to_disk(force=True) except Exception: pass except Exception: pass # model-initiated tool call detection if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS: m = TOOL_CALL_RE.search(fullText) if m: try: payload_raw = m.group(1) import json payload = json.loads(payload_raw) tool_name = payload.get('name') tool_args = payload.get('args', {}) tool_res = None if tool_name == 'web_search': from utils import web_search q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or [])) k = int(tool_args.get('top_k') or request.search_top_k or 3) tool_res = web_search(q, k) elif tool_name in ('calc', 'calculator'): from utils import calc expr = tool_args.get('expression') 
if expr: tool_res = calc(expr) else: try: tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled)) except Exception: tool_res = None if tool_res: # Normalize tool_res into a structured dict if needed if not isinstance(tool_res, dict): if tool_name in ('calc', 'calculator'): tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}} elif tool_name == 'web_search': tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}} else: tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}} else: tool_res_struct = tool_res exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True} executed_tool_calls.append(exec_entry) delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n" prompt = delta_text + prompt fullText = TOOL_CALL_RE.sub('', fullText) buffer = [fullText] out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state) model_initiated_tool_calls += 1 should_restart = True break except Exception as e: logger.info(f"Model-initiated tool handling error: {e}") yield f"data: {response.model_dump_json()}\n\n" # check stop sequences and stop streaming if we see them for stop_words in request.stop or []: if stop_words in ''.join(buffer): finishReason = f"stop:words:{stop_words}" return await asyncio.sleep(0) del streamConfig else: should_restart = True while should_restart: should_restart = False gen = generate(request, out, model_tokens, model_state) for chunk in gen: completionTokenCount += 1 buffer.append(chunk["content"]) if chunk["finish_reason"]: finishReason = 
chunk["finish_reason"] # Save model state frequently (after each token) to allow resuming try: if request.state_name: STATE_STORE[(request.model, request.state_name)] = { 'state': model_state, 'model_tokens': model_tokens, } if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False): try: _save_state_store_to_disk(force=True) except Exception: pass except Exception: pass # Detect model-initiated tool calls if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS: fullText = ''.join(buffer) m = TOOL_CALL_RE.search(fullText) if m: try: payload_raw = m.group(1) import json payload = json.loads(payload_raw) tool_name = payload.get('name') tool_args = payload.get('args', {}) tool_res = None if tool_name == 'web_search': from utils import web_search q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or [])) k = int(tool_args.get('top_k') or request.search_top_k or 3) tool_res = web_search(q, k) elif tool_name in ('calc', 'calculator'): from utils import calc expr = tool_args.get('expression') if expr: tool_res = calc(expr) else: try: tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled)) except Exception: tool_res = None if tool_res: if not isinstance(tool_res, dict): if tool_name in ('calc', 'calculator'): tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}} elif tool_name == 'web_search': tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}} else: tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}} else: tool_res_struct = tool_res exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 
'initiated_by_model': True} executed_tool_calls.append(exec_entry) delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n" prompt = delta_text + prompt fullText = TOOL_CALL_RE.sub('', fullText) buffer = [fullText] out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state) # Notify client that a tool was called mid-stream (metadata-only chunk) try: meta_resp = ChatCompletionChunk( id=completionId, created=createTimestamp, model=request.model, usage=( Usage( prompt_tokens=promptTokenCount, completion_tokens=completionTokenCount, total_tokens=promptTokenCount + completionTokenCount, prompt_tokens_details=PromptTokensDetails(cached_tokens=0), ) if request.include_usage else None ), choices=[ ChatCompletionChoice( index=0, delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls), logprobs=None, finish_reason=None, ) ], ) yield f"data: {meta_resp.model_dump_json()}\n\n" except Exception: pass model_initiated_tool_calls += 1 should_restart = True break except Exception as e: logger.info(f"Model-initiated tool handling error: {e}") response = ChatCompletionChunk( id=completionId, created=createTimestamp, model=request.model, usage=( Usage( prompt_tokens=promptTokenCount, completion_tokens=completionTokenCount, total_tokens=promptTokenCount + completionTokenCount, prompt_tokens_details=PromptTokensDetails(cached_tokens=0), ) if request.include_usage else None ), choices=[ ChatCompletionChoice( index=0, delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None), logprobs=None, finish_reason=finishReason, ) ], ) yield f"data: {response.model_dump_json()}\n\n" await asyncio.sleep(0) genenrateTime = time.time() responseLog = { "content": "".join(buffer), "finish": finishReason, "prefill_len": promptTokenCount, "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2), "gen_len": completionTokenCount, 
"gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2), } logger.info(f"[RES] {completionId} - {responseLog}") # Apply bias mitigation to the final content for streaming mode try: from utils import bias_mitigation content_for_mitigation = responseLog.get('content') if content_for_mitigation is None: content_for_mitigation = "" mitigation = bias_mitigation(content_for_mitigation) if mitigation and isinstance(mitigation, dict) and mitigation.get('suppressed'): executed_tool_calls.append({"name": "safety_mitigation", "args": {}, "result": {"action": "safety", "result": mitigation.get('sanitized'), "metadata": {"reason": mitigation.get('reason')}}}) responseLog['content'] = mitigation.get('sanitized') except Exception: pass if request.messages is None: request.messages = [] # Ensure responseLog['content'] is a string content_str = responseLog["content"] if responseLog["content"] is not None else "" request.messages.append(ChatMessage(role="Assistant", content=content_str)) log( { **request.model_dump(), **responseLog, "completionId": completionId, "machineLabel": os.environ.get("MACHINE_LABEL"), } ) del buffer yield "data: [DONE]\n\n" @app.post("/api/v1/chat/completions") async def chat_completions(request: ChatCompletionRequest): completionId = str(next(CompletionIdGenerator)) logger.info(f"[REQ] {completionId} - {request.model_dump()}") # Apply any legacy model suffix tags (e.g., 'rwkv-latest:thinking' -> enable_reasoning) # This helper is defined at module level so it can be unit-tested and reused. # Apply legacy tags (if present) to request and proceed normally apply_model_tags_to_request(request) modelName = request.model if request.model == "rwkv-latest": # Map to the default chat model in all cases. Do not redirect to a separate # reasoning model; the same model will be used and reasoning is handled in-process. 
if DEFALUT_MODEL_NAME == None: raise HTTPException(404, "DEFALUT_MODEL_NAME not set") ms_def = MODEL_STORAGE.get(DEFALUT_MODEL_NAME) if not ms_def or not ms_def.MODEL_CONFIG: raise HTTPException(500, "Default sampler config missing for default model") defaultSamplerConfig = ms_def.MODEL_CONFIG.DEFAULT_SAMPLER request.model = DEFALUT_MODEL_NAME elif modelName in MODEL_STORAGE: ms_sel = MODEL_STORAGE.get(modelName) if not ms_sel or not ms_sel.MODEL_CONFIG: raise HTTPException(500, f"Default sampler config missing for model {modelName}") defaultSamplerConfig = ms_sel.MODEL_CONFIG.DEFAULT_SAMPLER request.model = modelName else: raise HTTPException(404, f"Can not find `{modelName}`") # Baseline enableReasoning: prefer explicit request flag if present, else False. chatResponse will recompute with auto-detection. enableReasoning = bool(request.enable_reasoning) if request.enable_reasoning is not None else False async def chatResponseStreamDisconnect(): logGPUState() # Load or initialize model_state and tokens based on state_name model_state = None model_tokens_for_resume = [0] state_name = request.state_name if state_name is None: state_name = str(uuid.uuid4()) request.state_name = state_name state_key = (request.model, state_name) if state_key in STATE_STORE: stored = STATE_STORE[state_key] model_state = stored.get('state', None) model_tokens_for_resume = stored.get('model_tokens', [0]) request_dict = request.model_dump() # Apply defaults from model's DEFAULT_SAMPLER, optionally overridden by the # per-request `sampler` object (or legacy sampler_allow_* booleans). 
sampler_overrides = request_dict.get('sampler') or {} for k, v in defaultSamplerConfig.model_dump().items(): # If the request provided a sampler override for this field, use it if sampler_overrides and k in sampler_overrides and sampler_overrides.get(k) is not None: request_dict[k] = sampler_overrides.get(k) continue if k in request_dict and request_dict[k] is None: request_dict[k] = v realRequest = ChatCompletionRequest(**request_dict) # Ensure stream defaults to configuration value when not explicitly provided if realRequest.stream is None: realRequest.stream = CONFIG.DEFAULT_STREAM # Enforce top-level numeric limits on the realRequest values try: max_tokens_limit = getattr(CONFIG, 'MAX_TOKENS_PER_REQUEST', None) if realRequest.max_tokens is not None and max_tokens_limit is not None: realRequest.max_tokens = min(realRequest.max_tokens, max_tokens_limit) if realRequest.temperature is not None: realRequest.temperature = min(max(realRequest.temperature, getattr(CONFIG, 'MIN_TEMPERATURE', 0.0)), getattr(CONFIG, 'MAX_TEMPERATURE', 2.0)) if realRequest.top_p is not None: realRequest.top_p = min(max(realRequest.top_p, getattr(CONFIG, 'MIN_TOP_P', 0.0)), getattr(CONFIG, 'MAX_TOP_P', 1.0)) except Exception: pass logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}") if realRequest.stream: r = StreamingResponse( chatResponseStream(realRequest, model_state, completionId, enableReasoning), media_type="text/event-stream", background=BackgroundTask(chatResponseStreamDisconnect), ) else: r = await chatResponse(realRequest, model_state, completionId, enableReasoning) # Attach state_name to non-streaming response as additional metadata try: import json if isinstance(r, ChatCompletion): d = r.model_dump() d['state_name'] = state_name return d except Exception: pass return r # We keep the service API-only by default. If a local `dist-frontend` directory # exists (a built frontend), mount it at `/` so the app can serve a static UI. 
if os.path.isdir("dist-frontend"):
    logger.info("Static frontend mount enabled: serving dist-frontend at /")
    app.mount("/", StaticFiles(directory="dist-frontend", html=True), name="static")
else:
    logger.info("Static frontend mount not enabled; `dist-frontend` directory not found")


# Per-model capability / UI flags surfaced to clients; each defaults to True
# when the model config does not declare it. Order matters only for payload
# readability and matches the original response layout.
_MODEL_FLAG_FIELDS = (
    'ALLOW_WEB_SEARCH',
    'ALLOW_TOOLS',
    'ALLOW_REASONING',
    'ALLOW_FILE_TOOL',
    'ALLOW_FETCH_URL',
    'ALLOW_SUMMARIZE',
    'ALLOW_KEYWORDS',
    'ALLOW_SENTIMENT',
    'ALLOW_TRANSLATE',
    'ALLOW_SPELL_CHECK',
    'ALLOW_FORMAT_CODE',
    'ALLOW_EXPLAIN_CODE',
    'SHOW_WEB_SEARCH_BUTTON',
    'SHOW_FILE_UPLOAD_BUTTON',
    'SHOW_REASONING_TOGGLE',
    'SHOW_FETCH_URL_BUTTON',
    'SHOW_SUMMARIZE_BUTTON',
    'SHOW_KEYWORDS_BUTTON',
    'SHOW_SENTIMENT_BUTTON',
    'SHOW_TRANSLATE_BUTTON',
    'SHOW_SPELL_CHECK_BUTTON',
    'SHOW_FORMAT_CODE_BUTTON',
    'SHOW_EXPLAIN_CODE_BUTTON',
)


@app.get('/api/v1/models')
def list_models():
    """Return model configuration summary for clients/UI.

    This endpoint returns configured models, their default sampler values, and
    ALLOW_* flags so UI clients can build a controls surface based on server
    capabilities (web search, tools, reasoning).
    """
    upload_url = '/api/v1/files' if getattr(CONFIG, 'ALLOW_PUBLIC_UPLOADS', False) else None
    root_defaults = {
        'ALLOW_FILE_TOOL_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True),
        'ENABLE_WEB_SEARCH_BY_DEFAULT': getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True),
        'ENABLE_REASONING_BY_DEFAULT': getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True),
        'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT', True),
        'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT', True),
        'SHOW_REASONING_TOGGLE_BY_DEFAULT': getattr(CONFIG, 'SHOW_REASONING_TOGGLE_BY_DEFAULT', True),
        'UPLOAD_URL': upload_url,
        'ALLOW_PUBLIC_UPLOADS': getattr(CONFIG, 'ALLOW_PUBLIC_UPLOADS', False),
    }
    out = []
    for m in CONFIG.MODELS:
        entry = {
            'SERVICE_NAME': m.SERVICE_NAME,
            'DEFAULT_CHAT': m.DEFAULT_CHAT,
            'DEFAULT_REASONING': m.DEFAULT_REASONING,
        }
        # Every ALLOW_*/SHOW_* flag follows the same getattr(..., True) pattern;
        # one comprehension replaces 23 hand-written entries.
        entry.update({flag: getattr(m, flag, True) for flag in _MODEL_FLAG_FIELDS})
        entry['DEFAULT_SAMPLER'] = m.DEFAULT_SAMPLER.model_dump() if hasattr(m, 'DEFAULT_SAMPLER') else None
        # Convenience info for clients: upload endpoint and root defaults
        entry['UPLOAD_URL'] = upload_url
        entry['UPLOAD_ALLOWED_BY_DEFAULT'] = getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True)
        out.append(entry)
    return {'root_defaults': root_defaults, 'models': out}


def upload_file_internal(bytes_data: bytes, filename: Optional[str] = None, model: Optional[str] = None) -> dict:
    """Internal function to register uploaded file into the service memory.

    This is NOT exposed as a public HTTP endpoint by default; it's intended to
    be used by the server's `universal_tool` or admin-only utilities.

    Returns saved metadata or raises Exception.
    """
    if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True):
        raise HTTPException(403, 'File uploads are disabled by server configuration')
    if model:
        if model not in MODEL_STORAGE:
            raise HTTPException(404, f"Model {model} not found")
        ms = MODEL_STORAGE[model]
        if ms and ms.MODEL_CONFIG and not getattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL', True):
            raise HTTPException(403, f"Model {model} does not allow file uploads")
    from utils import save_bytes_to_upload

    meta = save_bytes_to_upload(filename or 'uploaded_file', bytes_data)
    if meta.get('error'):
        raise Exception(meta.get('error'))
    UPLOADED_FILES[meta['file_id']] = meta
    return meta


def list_files_internal():
    """Return metadata for every registered upload as plain dicts."""
    return [UploadedFile(**v).model_dump() for v in UPLOADED_FILES.values()]


def get_file_internal(file_id: str, download: bool = False):
    """Return an uploaded file's metadata, or stream its raw bytes.

    Args:
        file_id: Key into the UPLOADED_FILES registry.
        download: When True, stream the file content instead of metadata.

    Raises:
        HTTPException: 404 for an unknown file_id, 500 when the stored file
            cannot be opened.
    """
    if file_id not in UPLOADED_FILES:
        raise HTTPException(404, 'File not found')
    meta = UPLOADED_FILES[file_id]
    if download:
        # BUGFIX: the original used `with open(...) as f: return StreamingResponse(f, ...)`,
        # which exits the context manager (closing the file) before the response
        # body is ever read. Open eagerly so open() errors still surface as a
        # 500, and close the handle only after streaming completes.
        try:
            f = open(meta['path'], 'rb')
        except Exception as e:
            raise HTTPException(500, str(e))
        return StreamingResponse(
            f,
            media_type='application/octet-stream',
            background=BackgroundTask(f.close),
        )
    return UploadedFile(**meta)


def delete_file_internal(file_id: str):
    """Drop an upload from the registry and best-effort remove it from disk."""
    if file_id not in UPLOADED_FILES:
        raise HTTPException(404, 'File not found')
    meta = UPLOADED_FILES.pop(file_id)
    try:
        if os.path.exists(meta['path']):
            os.remove(meta['path'])
    except Exception:
        # Registry entry is already gone; a stale file on disk is tolerable.
        pass
    return {'success': True}


# If public uploads are allowed, add the public endpoints as wrappers to internal helpers
if getattr(CONFIG, 'ALLOW_PUBLIC_UPLOADS', False):

    @app.post('/api/v1/files', response_model=FileUploadResponse)
    async def upload_file_public(file: UploadFile = File(...), model: Optional[str] = None):
        """Public upload wrapper: reads the body, enforces size, delegates to upload_file_internal."""
        try:
            content = await file.read()
            # enforce size limit
            max_upload_size = getattr(CONFIG, 'MAX_UPLOAD_SIZE_BYTES', None)
            if max_upload_size is not None and len(content) > max_upload_size:
                raise HTTPException(413, 'File too large')
            fname = file.filename if getattr(file, 'filename', None) else 'uploaded_file'
            meta = upload_file_internal(content, filename=fname, model=model)
            return FileUploadResponse(success=True, file=UploadedFile(**meta))
        except HTTPException:
            # Preserve deliberate status codes (403/404/413) from the helpers.
            raise
        except Exception as e:
            raise HTTPException(500, str(e))

    @app.get('/api/v1/files')
    def list_files_public():
        return list_files_internal()

    @app.get('/api/v1/files/{file_id}')
    def get_file_public(file_id: str, download: bool = False):
        return get_file_internal(file_id, download=download)

    @app.delete('/api/v1/files/{file_id}')
    def delete_file_public(file_id: str):
        return delete_file_internal(file_id)


if __name__ == "__main__":
    import uvicorn

    host = CONFIG.HOST or "127.0.0.1"
    port = CONFIG.PORT or 7860
    uvicorn.run(app, host=host, port=port)