diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,1716 +1,1599 @@ -import os - -if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio": - from modelscope import patch_hub - - patch_hub() - -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256" - - -from config import CONFIG, ModelConfig -from utils import ( - cleanMessages, - parse_think_response, - remove_nested_think_tags_stack, - format_bytes, - log, - detect_tools_and_reasoning, - universal_tool, -) - -import copy, types, gc, sys, re, time, collections, asyncio -from huggingface_hub import hf_hub_download -from loguru import logger -from rich import print - -from snowflake import SnowflakeGenerator - -CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595) - -from typing import List, Optional, Union, Any, Dict -import uuid -from pydantic import BaseModel, Field, model_validator -from pydantic_settings import BaseSettings - - -import numpy as np -import torch - - -if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available(): - logger.info(f"CUDA not found, fall back to cpu") - CONFIG.STRATEGY = "cpu fp16" -# Normalize STRATEGY to include precision if missing (e.g., 'cpu' -> 'cpu fp16') -_s = CONFIG.STRATEGY.lower() -if ("cpu" in _s or "cuda" in _s) and not ("fp16" in _s or "fp32" in _s): - logger.info(f"STRATEGY missing precision, appending 'fp16' to `{CONFIG.STRATEGY}`") - CONFIG.STRATEGY = CONFIG.STRATEGY + " fp16" - - -try: - from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo -except Exception: - nvmlInit = None - nvmlDeviceGetHandleByIndex = None - nvmlDeviceGetMemoryInfo = None - -if "cuda" in CONFIG.STRATEGY.lower() and nvmlInit is not None and nvmlDeviceGetHandleByIndex is not None: - nvmlInit() - gpu_h = nvmlDeviceGetHandleByIndex(0) - - -def logGPUState(): - if "cuda" in CONFIG.STRATEGY and nvmlDeviceGetMemoryInfo is not None: - gpu_info = nvmlDeviceGetMemoryInfo(gpu_h) - logger.info( - f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}" - ) - - -torch.backends.cudnn.benchmark = True -torch.backends.cudnn.allow_tf32 = True -torch.backends.cuda.matmul.allow_tf32 = True -os.environ["RWKV_V7_ON"] = "1" # enable this for rwkv-7 models -os.environ["RWKV_JIT_ON"] = "1" -os.environ["RWKV_CUDA_ON"] = ( - "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0" -) - -from rwkv.model import RWKV -from rwkv.utils import PIPELINE, PIPELINE_ARGS - -from fastapi import FastAPI, HTTPException, UploadFile, File -from starlette.background import BackgroundTask -from fastapi.responses import StreamingResponse -from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles -from fastapi.middleware.gzip import GZipMiddleware - - -from api_types import ( - ChatMessage, - ChatCompletion, - ChatCompletionChunk, - Usage, - PromptTokensDetails, - ChatCompletionChoice, - ChatCompletionMessage, - SamplerConfig, - UploadedFile, - FileUploadResponse, -) - - -class ModelStorage: - MODEL_CONFIG: Optional[ModelConfig] = None - model: Optional[RWKV] = None - pipeline: Optional[PIPELINE] = None - - -MODEL_STORAGE: Dict[str, ModelStorage] = {} - -DEFALUT_MODEL_NAME = None -DEFAULT_REASONING_MODEL_NAME = None - -# In-memory model state store to support streaming continuation/resume per state_name. 
-# Keys: (model_name, state_name) -> dict with 'state' and 'model_tokens' -STATE_STORE: Dict[tuple, Any] = {} - -# Serialized state store file path and flush interval defined in CONFIG -_STATE_STORE_PATH = getattr(CONFIG, 'STATE_STORE_PATH', './state_store.json') -_LAST_STATE_STORE_WRITE = 0 - -# sentinel for model-initiated tool calls: {json} -TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.S) - -# File uploads: simple in-memory index (persisted on disk via the files themselves) -UPLOADED_FILES: Dict[str, dict] = {} - - -def _serialize_state_store() -> dict: - # Save only model_tokens to disk; model_state (torch objects) are not serializable - serial = {} - for (model_name, state_name), entry in STATE_STORE.items(): - try: - mt = entry.get('model_tokens') if isinstance(entry, dict) else None - if mt is None: - # if entry is a raw model_state, skip - continue - serial[f"{model_name}|{state_name}"] = { - 'model': model_name, - 'state_name': state_name, - 'model_tokens': mt, - } - except Exception: - continue - return serial - - -def _load_state_store_from_disk(): - global STATE_STORE - try: - if os.path.exists(_STATE_STORE_PATH): - import json - - with open(_STATE_STORE_PATH, 'r', encoding='utf-8') as f: - data = json.load(f) - for k, v in data.items(): - model = v.get('model') - state_name = v.get('state_name') - model_tokens = v.get('model_tokens') - if model and state_name and isinstance(model_tokens, list): - STATE_STORE[(model, state_name)] = { - 'state': None, - 'model_tokens': model_tokens, - } - logger.info(f"Loaded {len(STATE_STORE)} entries from state store file {_STATE_STORE_PATH}") - except Exception as e: - logger.info(f"Failed to load state store from disk: {e}") - - -def _save_state_store_to_disk(force=False): - global _LAST_STATE_STORE_WRITE - now = time.time() - if not force and now - _LAST_STATE_STORE_WRITE < getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5): - return - try: - serial = _serialize_state_store() - if not serial: - return - import json - tmp = _STATE_STORE_PATH + ".tmp" - with open(tmp, 'w', encoding='utf-8') as f: - json.dump(serial, f) - os.replace(tmp, _STATE_STORE_PATH) - _LAST_STATE_STORE_WRITE = now - except Exception as e: - logger.info(f"Write state store to disk failed: {e}") - - -def _recompute_out_and_state_from_tokens(model_name: str, model_tokens: List[int]): - """ - Recompute the `out` logits and `model_state` by forwarding through tokens in chunks. - Returns a tuple (out, model_state). - """ - ms = MODEL_STORAGE.get(model_name) - if not ms or not ms.model: - return None, None - model_state = None - out = None - tokens = list(model_tokens) if isinstance(model_tokens, list) else [0] - while len(tokens) > 0: - out, model_state = ms.model.forward(tokens[: CONFIG.CHUNK_LEN], model_state) - tokens = tokens[CONFIG.CHUNK_LEN :] - return out, model_state - -logger.info(f"STRATEGY - {CONFIG.STRATEGY}") - -logGPUState() - -# Keep any configured models intact; do not force selection by name/size. -# The previous policy enforced a single '0.1b' model which hid additional configs; use the full list. 
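# A minimal sketch of how the state-store helpers above combine on resume,
# assuming a model is already loaded in MODEL_STORAGE; `service` and
# `state_name` are example parameter names. Only `model_tokens` survives
# serialization, so after a restart the torch state is rebuilt by replaying
# tokens in CONFIG.CHUNK_LEN chunks. The same logic appears inline in
# chatResponse below.
def resume_state_sketch(service: str, state_name: str):
    entry = STATE_STORE.get((service, state_name))
    if entry is None:
        return None, None, [0]  # nothing persisted: caller prefills from scratch
    model_tokens = entry.get('model_tokens') or [0]
    model_state = entry.get('state')
    if model_state is None:
        # Loaded from disk: the torch state was not serialized, replay all tokens.
        out, model_state = _recompute_out_and_state_from_tokens(service, model_tokens)
        entry['state'] = model_state  # cache so the next resume skips the replay
    else:
        # Still live in memory: only the logits need recomputing, from the last
        # window of tokens.
        out, _ = _recompute_out_and_state_from_tokens(service, model_tokens[-CONFIG.CHUNK_LEN:])
    return out, model_state, model_tokens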
-logger.info(f"Configured {len(CONFIG.MODELS)} model(s) in ROOT config") - -for model_config in CONFIG.MODELS: - logger.info(f"Load Model - {model_config.SERVICE_NAME}") - - if model_config.MODEL_FILE_PATH == None: - model_config.MODEL_FILE_PATH = hf_hub_download( - repo_id=str(model_config.DOWNLOAD_MODEL_REPO_ID), - filename=str(model_config.DOWNLOAD_MODEL_FILE_NAME), - local_dir=str(model_config.DOWNLOAD_MODEL_DIR), - ) - logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}") - - if model_config.DEFAULT_CHAT: - if DEFALUT_MODEL_NAME != None: - logger.info( - f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`" - ) - DEFALUT_MODEL_NAME = model_config.SERVICE_NAME - - if model_config.DEFAULT_REASONING: - if DEFAULT_REASONING_MODEL_NAME != None: - logger.info( - f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`" - ) - DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME - - logger.info(f"Load Model - Loading `{model_config.SERVICE_NAME}`") - print(model_config.DEFAULT_SAMPLER) - - MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage() - MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config - MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV( - model=model_config.MODEL_FILE_PATH.replace(".pth", ""), - strategy=CONFIG.STRATEGY, - ) - MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE( - MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB - ) - if "cuda" in CONFIG.STRATEGY: - torch.cuda.empty_cache() - gc.collect() - logGPUState() - - -logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`") -logger.info( - f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`" -) - -if len(MODEL_STORAGE) == 1: - single_name = list(MODEL_STORAGE.keys())[0] - if DEFALUT_MODEL_NAME != single_name: - DEFALUT_MODEL_NAME = single_name - logger.info(f"Load Model - Only one model present; DEFALUT_MODEL_NAME set to `{DEFALUT_MODEL_NAME}`") - if DEFAULT_REASONING_MODEL_NAME != single_name: - DEFAULT_REASONING_MODEL_NAME = single_name - logger.info(f"Load Model - Only one model present; DEFAULT_REASONING_MODEL_NAME set to `{DEFAULT_REASONING_MODEL_NAME}`") - - -class ChatCompletionRequest(BaseModel): - model: str = Field( - default="rwkv-latest", - description="Add `:thinking` suffix to the model name to enable reasoning. Example: `rwkv-latest:thinking`", - ) - messages: Optional[List[ChatMessage]] = Field(default=None) - prompt: Optional[str] = Field(default=None) - max_tokens: Optional[int] = Field(default=None) - temperature: Optional[float] = Field(default=None) - top_p: Optional[float] = Field(default=None) - presence_penalty: Optional[float] = Field(default=None) - count_penalty: Optional[float] = Field(default=None) - penalty_decay: Optional[float] = Field(default=None) - stream: Optional[bool] = Field(default=True, description="Whether to stream token-by-token responses. 
If None, uses CONFIG.DEFAULT_STREAM") - state_name: Optional[str] = Field(default=None) - include_usage: Optional[bool] = Field(default=False) - stop: Optional[list[str]] = Field(["\n\n"]) - stop_tokens: Optional[list[int]] = Field([0]) - web_search: Optional[bool] = Field(default=True, description="Whether to perform a web search and append results to the prompt") - enable_web_search: Optional[bool] = Field(default=True, description="Explicitly enable web search (overrides auto/web_search) if set") - auto_web_search: Optional[bool] = Field(default=True, description="Whether to enable web_search based on auto-detected intent") - enable_tools: Optional[bool] = Field(default=None, description="Explicitly enable tools (overrides auto detection)") - auto_tools: Optional[bool] = Field(default=True, description="Whether to enable tools based on auto-detected intent") - enable_reasoning: Optional[bool] = Field(default=True, description="Explicitly override reasoning enablement") - auto_reasoning: Optional[bool] = Field(default=True, description="Whether to enable reasoning based on auto detection") - enable_universal: Optional[bool] = Field(default=None, description="Explicitly enable the universal tool execution") - auto_universal: Optional[bool] = Field(default=True, description="Whether to auto enable universal tool execution") - search_top_k: Optional[int] = Field(default=3, description="Number of web search results to retrieve") - tools: Optional[List[Dict[str, Any]]] = Field(default=None, description="List of tools to execute server-side (e.g., {'name':'web_search','args':{'query':'x'}})") - # Per-request sampler overrides for ALLOW_* flags. These let the user - # disable server-side features for this particular request if needed. - sampler_allow_web_search: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing web_search") - sampler_allow_tools: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing tools") - sampler_allow_reasoning: Optional[bool] = Field(default=None, description="Per-request (sampler) override allowing reasoning") - # Per-request sampler config object; if provided, these settings will - # override the model defaults for this request. - sampler: Optional[SamplerConfig] = Field(default=None, description="Per-request sampler settings (overrides model default)") - # File uploads: allow referencing uploaded files in the request - file_ids: Optional[List[str]] = Field(default=None, description="List of uploaded file IDs that the model may use for this request") - enable_file_tool: Optional[bool] = Field(default=None, description="Explicitly enable file-based tools for this request") - auto_file_tool: Optional[bool] = Field(default=None, description="Auto-detect whether file-based tools are needed") - sampler_allow_file_tool: Optional[bool] = Field(default=None, description="Per-request sampler override allowing file tools") - - @model_validator(mode="before") - @classmethod - def validate_mutual_exclusivity(cls, data: Any) -> Any: - if not isinstance(data, dict): - return data - - messages_provided = "messages" in data and data["messages"] != None - prompt_provided = "prompt" in data and data["prompt"] != None - - if messages_provided and prompt_provided: - raise ValueError("messages and prompt cannot coexist. 
Choose one.") - if not messages_provided and not prompt_provided: - raise ValueError("Either messages or prompt must be provided.") - return data - - -app = FastAPI(title="RWKV OpenAI-Compatible API") - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5) - - -@app.on_event("startup") -async def _startup_state_load_and_persist_loop(): - # Load previous persisted state (tokens only) at startup - _load_state_store_from_disk() - - async def _persist_loop(): - while True: - try: - _save_state_store_to_disk(force=False) - except Exception: - pass - await asyncio.sleep(getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5)) - - # Spawn background flush task - try: - asyncio.create_task(_persist_loop()) - except Exception: - pass - - -async def runPrefill( - request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state -): - ctx = ctx.replace("\r\n", "\n") - out = None - - ms = MODEL_STORAGE.get(request.model) - if not ms or not ms.pipeline or not ms.model: - raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing") - tokens = ms.pipeline.encode(ctx) - tokens = [int(x) for x in tokens] - model_tokens += tokens - - while len(tokens) > 0: - out, model_state = ms.model.forward( - tokens[: CONFIG.CHUNK_LEN], model_state - ) - tokens = tokens[CONFIG.CHUNK_LEN :] - await asyncio.sleep(0) - - return out, model_tokens, model_state - - -def generate( - request: ChatCompletionRequest, - out, - model_tokens: List[int], - model_state, - max_tokens=2048, -): - ms = MODEL_STORAGE.get(request.model) - if not ms or not ms.pipeline or not ms.model: - raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing") - - temperature = request.temperature if request.temperature is not None else 0.2 - top_p = request.top_p if request.top_p is not None else 0.9 - alpha_frequency = request.count_penalty if request.count_penalty is not None else 0.0 - alpha_presence = request.presence_penalty if request.presence_penalty is not None else 0.0 - penalty_decay = request.penalty_decay if request.penalty_decay is not None else 0.5 - - args = PIPELINE_ARGS( - temperature=max(0.2, temperature), - top_p=top_p, - alpha_frequency=alpha_frequency, - alpha_presence=alpha_presence, - token_ban=[], # ban the generation of some tokens - token_stop=[0], - ) # stop generation whenever you see any token here - - occurrence = {} - out_tokens: List[int] = [] - out_last = 0 - - # Stream token-by-token; each chunk contains a single decoded token string. 
- - for i in range(max_tokens): - for n in occurrence: - out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency - # out[0] -= 1e10 # disable END_OF_TEXT - - token = ms.pipeline.sample_logits( - out, temperature=args.temperature, top_p=args.top_p - ) - - if token == 0 and request.stop_tokens and token in request.stop_tokens: - yield { - "content": "", - "tokens": out_tokens[out_last:], - "finish_reason": "stop:token:0", - "state": model_state, - } - - del out - gc.collect() - return - - out, model_state = ms.model.forward([token], model_state) - model_tokens.append(token) - out_tokens.append(token) - - if request.stop_tokens and token in request.stop_tokens: - yield { - "content": "", - "tokens": out_tokens[out_last:], - "finish_reason": f"stop:token:{token}", - "state": model_state, - } - - del out - gc.collect() - return - - for xxx in list(occurrence.keys()): - occurrence[xxx] *= penalty_decay - occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0) - - # Decode token to text and yield it as a single-token chunk - decoded = ms.pipeline.decode([token]) - # filter out replacement characters - if "\ufffd" in decoded: - continue - - yield { - "content": decoded, - "tokens": [token], - "finish_reason": None, - "state": model_state, - } - out_last = i + 1 - - else: - yield { - "content": "", - "tokens": [], - "finish_reason": "length", - } - - -async def chatResponse( - request: ChatCompletionRequest, - model_state: Any, - completionId: str, - enableReasoning: bool, -) -> ChatCompletion: - createTimestamp = time.time() - - # Build raw prompt for detection (prefer explicit request.prompt, else messages) - raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or []) - # Intent detection: analyze raw_prompt or messages to auto-activate tools/web-search/reasoning - detection = detect_tools_and_reasoning(raw_prompt) - # After computing auto flags, build the actual prompt string to include if needed - prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [])}\n\nAssistant:{' 0) or (auto_file_flag and request.file_ids)) - # Respect root-level defaults - if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None: - file_tool_enabled = False - # Per-request sampler overrides - try: - if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None: - file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL) - elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None: - file_tool_enabled = bool(request.sampler_allow_file_tool) - else: - ms = MODEL_STORAGE.get(request.model) - if ms and ms.MODEL_CONFIG: - if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None: - file_tool_enabled = bool(ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL) - elif hasattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms.MODEL_CONFIG.ALLOW_FILE_TOOL: - file_tool_enabled = False - except Exception: - pass - - # Decide whether tools should be used - if request.enable_tools is not None: - tools_enabled = bool(request.enable_tools) - else: - # if explicit tools provided, or enable by default config, or auto detection suggests - auto_tools_flag = request.auto_tools if request.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS - tools_enabled = bool(request.tools) or CONFIG.ENABLE_TOOLS_BY_DEFAULT or (auto_tools_flag and (detection.get('need_calc') 
or detection.get('need_web_search'))) - # Respect sampler-level override (request.sampler.ALLOW_TOOLS), then - # request.sampler_allow_tools, then sampler default and finally model-level allow - try: - if request.sampler and getattr(request.sampler, 'ALLOW_TOOLS', None) is not None: - tools_enabled = bool(request.sampler.ALLOW_TOOLS) - elif hasattr(request, 'sampler_allow_tools') and request.sampler_allow_tools is not None: - tools_enabled = bool(request.sampler_allow_tools) - else: - ms = MODEL_STORAGE.get(request.model) - if ms and ms.MODEL_CONFIG: - if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_TOOLS', None) is not None: - if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_TOOLS: - tools_enabled = False - elif hasattr(ms.MODEL_CONFIG, 'ALLOW_TOOLS') and not ms.MODEL_CONFIG.ALLOW_TOOLS: - tools_enabled = False - except Exception: - pass - - # Decide whether reasoning should be enabled (in addition to :thinking or explicit) - reasoning_enabled = bool( - True - if (request.enable_reasoning is not None and request.enable_reasoning) - else ( - bool(enableReasoning) or bool(request.auto_reasoning if request.auto_reasoning is not None else (CONFIG.AUTO_ENABLE_REASONING and bool(detection.get('need_reasoning')))) - ) - ) - # If the root config sets reasoning to disabled by default and no explicit request to enable, disable it - if not getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True) and request.enable_reasoning is None: - reasoning_enabled = False - # Respect sampler-level override for reasoning: request.sampler.ALLOW_REASONING -> sampler_allow_reasoning -> sampler.default -> model - try: - if request.sampler and getattr(request.sampler, 'ALLOW_REASONING', None) is not None: - reasoning_enabled = bool(request.sampler.ALLOW_REASONING) - elif hasattr(request, 'sampler_allow_reasoning') and request.sampler_allow_reasoning is not None: - reasoning_enabled = bool(request.sampler_allow_reasoning) - else: - ms = MODEL_STORAGE.get(request.model) - if ms and ms.MODEL_CONFIG: - if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_REASONING', None) is not None: - if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_REASONING: - reasoning_enabled = False - elif hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING: - reasoning_enabled = False - except Exception: - pass - - # Keep the local boolean for generating content - enableReasoning = reasoning_enabled - try: - ms = MODEL_STORAGE.get(request.model) - if ms and ms.MODEL_CONFIG and hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING: - enableReasoning = False - except Exception: - pass - - # Ensure web_search property mirrors computed web_search_enabled if not explicitly provided - if request.enable_web_search is None: - request.web_search = web_search_enabled - # If tools should be automatically enabled, add detected ones - if tools_enabled and not request.tools: - if detection.get('detected_tools'): - request.tools = detection.get('detected_tools') - # If universal is needed and not explicitly requested, add universal tool - if (request.enable_universal is True) or ( - request.enable_universal is None and (request.auto_universal if request.auto_universal is not None else CONFIG.AUTO_ENABLE_TOOLS and detection.get('need_universal')) - ): - if not request.tools: - request.tools = [{"name": "universal", "args": {"query": raw_prompt}}] - - executed_tool_calls = [] - # If file tools are enabled and files are 
attached, inject them into the prompt - if file_tool_enabled and request.file_ids: - for fid in request.file_ids: - try: - if fid not in UPLOADED_FILES: - continue - meta = UPLOADED_FILES.get(fid) - if not meta: - continue - from utils import file_read_from_path - fpath = meta.get('path') - if not fpath or not os.path.exists(fpath): - continue - file_content = file_read_from_path(fpath, 200000) - if file_content: - exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}} - executed_tool_calls.append(exec_entry) - prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt) - except Exception as e: - logger.info(f"File injection error: {e}") - if request.tools: - try: - for tool in request.tools: - name = tool.get('name') - args = tool.get('args', {}) - if name == 'web_search': - from utils import web_search - - search_q = args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or [])) - search_top_k = int(args.get('top_k') or request.search_top_k or 3) - search_str = web_search(search_q, search_top_k) - if search_str: - search_res_struct = {"action": "web_search", "result": str(search_str), "metadata": {"query": search_q, "top_k": search_top_k, "confidence": 0.9}} - executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": search_top_k}, "result": search_res_struct}) - prompt = (f"ToolResults:\n{search_res_struct.get('result')}\n\nUse these results to answer the prompt.\n\n" + prompt) - elif name == 'calc' or name == 'calculator': - from utils import calc - - expr = args.get('expression') - if expr: - calc_res = calc(expr) - # Wrap result into a structured dict - calc_res_struct = {"action": "calc", "result": str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}} - executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct}) - prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res_struct.get('result')}\n\nUse this result to answer the prompt.\n\n" + prompt) - elif name == 'universal': - try: - res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled)) - # If universal_tool returns a dict, extract text result for prompt injection - if isinstance(res, dict): - result_text = res.get('result') if res.get('result') is not None else '' - else: - result_text = str(res) - executed_tool_calls.append({"name": "universal", "args": args, "result": res}) - prompt = 
(f"ToolResults:\n{result_text}\n\nUse this result to answer the prompt.\n\n" + prompt) - except Exception as e: - logger.info(f"Universal tool execution error: {e}") - else: - # Unsupported tool - ignore or log - logger.info(f"Unsupported tool requested: {name}") - if name == 'file_read': - # read an uploaded file by id/path - try: - fid = args.get('file_id') or args.get('id') or (request.file_ids[0] if request.file_ids else None) - if not fid: - continue - if fid not in UPLOADED_FILES: - continue - meta = UPLOADED_FILES.get(fid) - if not meta: - continue - from utils import file_read_from_path - fpath = meta.get('path') - if not fpath or not os.path.exists(fpath): - continue - file_content = file_read_from_path(fpath, int(args.get('max_bytes') or 100000)) - exec_entry = {"name": "file_read", "args": {"file_id": fid, "max_bytes": int(args.get('max_bytes') or 100000)}, "result": {"action": "file_read", "result": file_content, "metadata": {"file_id": fid, "filename": meta.get('filename')}}} - executed_tool_calls.append(exec_entry) - _res = exec_entry.get('result') if isinstance(exec_entry, dict) else None - _res_text = '' - if isinstance(_res, dict): - _res_text = _res.get('result') or '' - elif _res is not None: - _res_text = str(_res) - prompt = (f"ToolResults:\n{_res_text}\n\nUse these file contents to answer the prompt.\n\n" + prompt) - except Exception as e: - logger.info(f"file_read tool error: {e}") - except Exception as e: - logger.info(f"Tool processing error: {e}") - elif request.web_search or web_search_enabled: - try: - from utils import web_search - - search_q = request.prompt if request.prompt else cleanMessages(request.messages or []) - search_res = web_search(search_q, int(request.search_top_k or 3)) - if search_res: - search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}} - executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct}) - prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt - except Exception: - pass - logger.info(f"[REQ] {completionId} - prompt - {prompt}") - - # Resume or prefill tokens/state - if request.state_name: - state_key = (request.model, request.state_name) - if state_key in STATE_STORE: - stored = STATE_STORE[state_key] - model_state = stored.get('state', None) - model_tokens = stored.get('model_tokens', [0]) - if model_state is None: - # Recompute out and model_state from tokens since we did not persist the torch state - out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens) - else: - # If we have a model_state, we still need out logits. 
Compute from last window of tokens - out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :]) - else: - out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state) - else: - out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state) - - prefillTime = time.time() - promptTokenCount = len(model_tokens) - - fullResponse = " 0) or (auto_file_flag and request.file_ids)) - if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None: - file_tool_enabled = False - try: - if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None: - file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL) - elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None: - file_tool_enabled = bool(request.sampler_allow_file_tool) - else: - ms2 = MODEL_STORAGE.get(request.model) - if ms2 and ms2.MODEL_CONFIG: - if hasattr(ms2.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms2.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None: - file_tool_enabled = bool(ms2.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL) - elif hasattr(ms2.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms2.MODEL_CONFIG.ALLOW_FILE_TOOL: - file_tool_enabled = False - except Exception: - pass - # Build final prompt after deciding enableReasoning - prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{' 0 and r_dict['choices'][0].get('delta') is not None: - r_dict['choices'][0]['delta']['tool_calls'] = executed_tool_calls - except Exception: - pass - yield f"data: {r_dict}\n\n" - - buffer = [] - - if enableReasoning: - buffer.append(" tag - "fullTextCursor": 0, - "in_think": False, - "cacheStr": "", - } - - for chunk in generate( - request, - out, - model_tokens, - model_state, - max_tokens=( - 64000 - if "max_tokens" not in request.model_fields_set and enableReasoning - else (request.max_tokens or 2048) - ), - ): - completionTokenCount += 1 - # Each token stream is delivered as a decoded character/bytes (maybe 1 or more chars) - chunkContent: str = chunk["content"] - buffer.append(chunkContent) - - fullText = "".join(buffer) - - if chunk["finish_reason"]: - finishReason = chunk["finish_reason"] - - response = ChatCompletionChunk( - id=completionId, - created=createTimestamp, - model=request.model, - usage=( - Usage( - prompt_tokens=promptTokenCount, - completion_tokens=completionTokenCount, - total_tokens=promptTokenCount + completionTokenCount, - prompt_tokens_details=PromptTokensDetails(cached_tokens=0), - ) - if request.include_usage - else None - ), - choices=[ - ChatCompletionChoice( - index=0, - delta=ChatCompletionMessage( - role="Assistant", - content=None, - reasoning_content=None, - tool_calls=None, - ), - logprobs=None, - finish_reason=finishReason, - ) - ], - ) - if response.choices and response.choices[0].delta is None: - response.choices[0].delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - - markStart = fullText.find("<", streamConfig["fullTextCursor"]) - if not streamConfig["isChecking"] and markStart != -1: - streamConfig["isChecking"] = True - - if streamConfig["in_think"]: - delta = response.choices[0].delta - if delta is None: - delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - response.choices[0].delta = delta - delta.reasoning_content = 
fullText[streamConfig["fullTextCursor"] : markStart] - else: - delta = response.choices[0].delta - if delta is None: - delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - response.choices[0].delta = delta - delta.content = fullText[streamConfig["fullTextCursor"] : markStart] - - streamConfig["cacheStr"] = "" - streamConfig["fullTextCursor"] = markStart - - if streamConfig["isChecking"]: - streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"] :] - else: - if streamConfig["in_think"]: - delta = response.choices[0].delta - if delta is None: - delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - response.choices[0].delta = delta - delta.reasoning_content = chunkContent - else: - delta = response.choices[0].delta - if delta is None: - delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - response.choices[0].delta = delta - delta.content = chunkContent - streamConfig["fullTextCursor"] = len(fullText) - - markEnd = fullText.find(">", streamConfig["fullTextCursor"]) - if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None: - streamConfig["isChecking"] = False - - if ( - not streamConfig["in_think"] - and streamConfig["cacheStr"].find("") != -1 - ): - streamConfig["in_think"] = True - - delta = response.choices[0].delta - if delta is None: - delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - response.choices[0].delta = delta - delta.reasoning_content = ( - delta.reasoning_content - if delta.reasoning_content != None - else "" + streamConfig["cacheStr"].replace("", "") - ) - - elif ( - streamConfig["in_think"] - and streamConfig["cacheStr"].find("") != -1 - ): - streamConfig["in_think"] = False - - delta = response.choices[0].delta - if delta is None: - delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - response.choices[0].delta = delta - delta.content = ( - delta.content - if delta.content != None - else "" + streamConfig["cacheStr"].replace("", "") - ) - else: - if streamConfig["in_think"]: - delta = response.choices[0].delta - if delta is None: - delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - response.choices[0].delta = delta - delta.reasoning_content = ( - delta.reasoning_content - if delta.reasoning_content != None - else "" + streamConfig["cacheStr"] - ) - else: - delta = response.choices[0].delta - if delta is None: - delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - response.choices[0].delta = delta - delta.content = ( - delta.content - if delta.content != None - else "" + streamConfig["cacheStr"] - ) - streamConfig["fullTextCursor"] = len(fullText) - - delta = response.choices[0].delta - if delta is None: - delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None) - response.choices[0].delta = delta - if delta.content != None or delta.reasoning_content != None: - # Save model state frequently (after each token) to allow resuming - try: - if request.state_name: - STATE_STORE[(request.model, request.state_name)] = { - 'state': model_state, - 'model_tokens': model_tokens, - } - if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False): - try: - _save_state_store_to_disk(force=True) - except Exception: - pass - except Exception: - pass - # model-initiated tool call 
detection - if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS: - m = TOOL_CALL_RE.search(fullText) - if m: - try: - payload_raw = m.group(1) - import json - - payload = json.loads(payload_raw) - tool_name = payload.get('name') - tool_args = payload.get('args', {}) - tool_res = None - if tool_name == 'web_search': - from utils import web_search - - q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or [])) - k = int(tool_args.get('top_k') or request.search_top_k or 3) - tool_res = web_search(q, k) - elif tool_name in ('calc', 'calculator'): - from utils import calc - - expr = tool_args.get('expression') - if expr: - tool_res = calc(expr) - else: - try: - tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled)) - except Exception: - tool_res = None - - if tool_res: - # Normalize tool_res into a structured dict if needed - if not isinstance(tool_res, dict): - if tool_name in ('calc', 'calculator'): - tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}} - elif tool_name == 'web_search': - tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}} - else: - tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}} - else: - tool_res_struct = tool_res - exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True} - executed_tool_calls.append(exec_entry) - delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n" - prompt = delta_text + prompt - fullText = TOOL_CALL_RE.sub('', fullText) - buffer = [fullText] - out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state) - model_initiated_tool_calls += 1 - should_restart = True - break - except Exception as e: - logger.info(f"Model-initiated tool handling error: {e}") - yield f"data: {response.model_dump_json()}\n\n" - # check stop sequences and stop streaming if we see them - for stop_words in request.stop or []: - if stop_words in ''.join(buffer): - finishReason = f"stop:words:{stop_words}" - return - - await asyncio.sleep(0) - - del streamConfig - else: - should_restart = True - while should_restart: - should_restart = False - gen = generate(request, out, model_tokens, model_state) - for chunk in gen: - completionTokenCount += 1 - buffer.append(chunk["content"]) - - if chunk["finish_reason"]: - finishReason = chunk["finish_reason"] - - # Save model state frequently (after each token) to allow resuming - try: - if request.state_name: - STATE_STORE[(request.model, request.state_name)] = { - 'state': model_state, - 'model_tokens': model_tokens, - } - if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False): - try: - _save_state_store_to_disk(force=True) - except Exception: - pass - except Exception: - pass - - # Detect model-initiated tool calls - if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS: - fullText = ''.join(buffer) - m = TOOL_CALL_RE.search(fullText) - if m: - try: - payload_raw = m.group(1) - import json - - payload = json.loads(payload_raw) - tool_name = payload.get('name') - tool_args = payload.get('args', {}) - tool_res = None - if tool_name == 'web_search': - from utils import 
web_search - - q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or [])) - k = int(tool_args.get('top_k') or request.search_top_k or 3) - tool_res = web_search(q, k) - elif tool_name in ('calc', 'calculator'): - from utils import calc - - expr = tool_args.get('expression') - if expr: - tool_res = calc(expr) - else: - try: - tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled)) - except Exception: - tool_res = None - - if tool_res: - if not isinstance(tool_res, dict): - if tool_name in ('calc', 'calculator'): - tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}} - elif tool_name == 'web_search': - tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}} - else: - tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}} - else: - tool_res_struct = tool_res - exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True} - executed_tool_calls.append(exec_entry) - delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n" - prompt = delta_text + prompt - fullText = TOOL_CALL_RE.sub('', fullText) - buffer = [fullText] - out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state) - # Notify client that a tool was called mid-stream (metadata-only chunk) - try: - meta_resp = ChatCompletionChunk( - id=completionId, - created=createTimestamp, - model=request.model, - usage=( - Usage( - prompt_tokens=promptTokenCount, - completion_tokens=completionTokenCount, - total_tokens=promptTokenCount + completionTokenCount, - prompt_tokens_details=PromptTokensDetails(cached_tokens=0), - ) - if request.include_usage - else None - ), - choices=[ - ChatCompletionChoice( - index=0, - delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls), - logprobs=None, - finish_reason=None, - ) - ], - ) - yield f"data: {meta_resp.model_dump_json()}\n\n" - except Exception: - pass - model_initiated_tool_calls += 1 - should_restart = True - break - except Exception as e: - logger.info(f"Model-initiated tool handling error: {e}") - - response = ChatCompletionChunk( - id=completionId, - created=createTimestamp, - model=request.model, - usage=( - Usage( - prompt_tokens=promptTokenCount, - completion_tokens=completionTokenCount, - total_tokens=promptTokenCount + completionTokenCount, - prompt_tokens_details=PromptTokensDetails(cached_tokens=0), - ) - if request.include_usage - else None - ), - choices=[ - ChatCompletionChoice( - index=0, - delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None), - logprobs=None, - finish_reason=finishReason, - ) - ], - ) - yield f"data: {response.model_dump_json()}\n\n" - await asyncio.sleep(0) - - genenrateTime = time.time() - - responseLog = { - "content": "".join(buffer), - "finish": finishReason, - "prefill_len": promptTokenCount, - "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2), - "gen_len": completionTokenCount, - "gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2), - } - 
logger.info(f"[RES] {completionId} - {responseLog}") - if request.messages is None: - request.messages = [] - request.messages.append(ChatMessage(role="Assistant", content=responseLog["content"])) - log( - { - **request.model_dump(), - **responseLog, - "completionId": completionId, - "machineLabel": os.environ.get("MACHINE_LABEL"), - } - ) - - del buffer - - yield "data: [DONE]\n\n" - - -@app.post("/api/v1/chat/completions") -async def chat_completions(request: ChatCompletionRequest): - completionId = str(next(CompletionIdGenerator)) - logger.info(f"[REQ] {completionId} - {request.model_dump()}") - - # Support model suffixes like ':thinking' for reasoning or ':web' to request - # web search by default for this request. E.g., 'rwkv-latest:web' will enable web_search. - modelName = request.model.split(":")[0] - if ":web" in request.model: - request.enable_web_search = True - if ":file" in request.model: - request.enable_file_tool = True - enableReasoning = ":thinking" in request.model - - if "rwkv-latest" in request.model: - # Map to the default chat model in all cases. Do not redirect to a separate - # reasoning model when ':thinking' is used. The same model will be used - # and reasoning handled in-process by setting enableReasoning=True. - if DEFALUT_MODEL_NAME == None: - raise HTTPException(404, "DEFALUT_MODEL_NAME not set") - ms_def = MODEL_STORAGE.get(DEFALUT_MODEL_NAME) - if not ms_def or not ms_def.MODEL_CONFIG: - raise HTTPException(500, "Default sampler config missing for default model") - defaultSamplerConfig = ms_def.MODEL_CONFIG.DEFAULT_SAMPLER - request.model = DEFALUT_MODEL_NAME - - elif modelName in MODEL_STORAGE: - ms_sel = MODEL_STORAGE.get(modelName) - if not ms_sel or not ms_sel.MODEL_CONFIG: - raise HTTPException(500, f"Default sampler config missing for model {modelName}") - defaultSamplerConfig = ms_sel.MODEL_CONFIG.DEFAULT_SAMPLER - request.model = modelName - else: - raise HTTPException(404, f"Can not find `{modelName}`") - - async def chatResponseStreamDisconnect(): - logGPUState() - - # Load or initialize model_state and tokens based on state_name - model_state = None - model_tokens_for_resume = [0] - state_name = request.state_name - if state_name is None: - state_name = str(uuid.uuid4()) - request.state_name = state_name - state_key = (request.model, state_name) - if state_key in STATE_STORE: - stored = STATE_STORE[state_key] - model_state = stored.get('state', None) - model_tokens_for_resume = stored.get('model_tokens', [0]) - request_dict = request.model_dump() - - # Apply defaults from model's DEFAULT_SAMPLER, optionally overridden by the - # per-request `sampler` object (or legacy sampler_allow_* booleans). 
- sampler_overrides = request_dict.get('sampler') or {} - for k, v in defaultSamplerConfig.model_dump().items(): - # If the request provided a sampler override for this field, use it - if sampler_overrides and k in sampler_overrides and sampler_overrides.get(k) is not None: - request_dict[k] = sampler_overrides.get(k) - continue - if k in request_dict and request_dict[k] is None: - request_dict[k] = v - realRequest = ChatCompletionRequest(**request_dict) - # Ensure stream defaults to configuration value when not explicitly provided - if realRequest.stream is None: - realRequest.stream = CONFIG.DEFAULT_STREAM - - logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}") - - if realRequest.stream: - r = StreamingResponse( - chatResponseStream(realRequest, model_state, completionId, enableReasoning), - media_type="text/event-stream", - background=BackgroundTask(chatResponseStreamDisconnect), - ) - else: - r = await chatResponse(realRequest, model_state, completionId, enableReasoning) - # Attach state_name to non-streaming response as additional metadata - try: - import json - - if isinstance(r, ChatCompletion): - d = r.model_dump() - d['state_name'] = state_name - return d - except Exception: - pass - - return r - - -# We keep the service API-only; remove static mount for demo frontend to -# avoid serving HTML files by default and keep the repository Python-only. -logger.info("Static frontend mount removed for Python-only deploy; use API endpoints for integration") - - -@app.get('/api/v1/models') -def list_models(): - """Return model configuration summary for clients/UI. - - This endpoint returns configured models, their default sampler values, and - ALLOW_* flags so UI clients can build a controls surface based on server - capabilities (web search, tools, reasoning). 
- """ - out = [] - root_defaults = { - 'ALLOW_FILE_TOOL_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True), - 'ENABLE_WEB_SEARCH_BY_DEFAULT': getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True), - 'ENABLE_REASONING_BY_DEFAULT': getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True), - 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT', True), - 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT', True), - 'SHOW_REASONING_TOGGLE_BY_DEFAULT': getattr(CONFIG, 'SHOW_REASONING_TOGGLE_BY_DEFAULT', True), - 'UPLOAD_URL': '/api/v1/files', - } - for m in CONFIG.MODELS: - out.append( - { - 'SERVICE_NAME': m.SERVICE_NAME, - 'DEFAULT_CHAT': m.DEFAULT_CHAT, - 'DEFAULT_REASONING': m.DEFAULT_REASONING, - 'ALLOW_WEB_SEARCH': getattr(m, 'ALLOW_WEB_SEARCH', True), - 'ALLOW_TOOLS': getattr(m, 'ALLOW_TOOLS', True), - 'ALLOW_REASONING': getattr(m, 'ALLOW_REASONING', True), - 'ALLOW_FILE_TOOL': getattr(m, 'ALLOW_FILE_TOOL', True), - 'SHOW_WEB_SEARCH_BUTTON': getattr(m, 'SHOW_WEB_SEARCH_BUTTON', True), - 'SHOW_FILE_UPLOAD_BUTTON': getattr(m, 'SHOW_FILE_UPLOAD_BUTTON', True), - 'SHOW_REASONING_TOGGLE': getattr(m, 'SHOW_REASONING_TOGGLE', True), - 'DEFAULT_SAMPLER': m.DEFAULT_SAMPLER.model_dump() if hasattr(m, 'DEFAULT_SAMPLER') else None, - # Convenience info for clients: upload endpoint and root defaults - 'UPLOAD_URL': '/api/v1/files', - 'UPLOAD_ALLOWED_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True), - } - ) - return {'root_defaults': root_defaults, 'models': out} - - -@app.post('/api/v1/files', response_model=FileUploadResponse) -async def upload_file(file: UploadFile = File(...), model: Optional[str] = None): - """Save uploaded file to CONFIG.UPLOAD_DIR and return metadata.""" - try: - # Respect root-level upload toggle - if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True): - raise HTTPException(403, 'File uploads are disabled by server configuration') - # If a model is provided, verify the model allows file tools - if model: - if model not in MODEL_STORAGE: - raise HTTPException(404, f"Model {model} not found") - ms = MODEL_STORAGE[model] - if ms and ms.MODEL_CONFIG and not getattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL', True): - raise HTTPException(403, f"Model {model} does not allow file uploads") - from utils import save_bytes_to_upload - - content = await file.read() - fname = file.filename if getattr(file, 'filename', None) else 'uploaded_file' - meta = save_bytes_to_upload(fname, content) - if meta.get('error'): - raise HTTPException(500, f"Could not save file: {meta.get('error')}") - UPLOADED_FILES[meta['file_id']] = meta - return FileUploadResponse(success=True, file=UploadedFile(**meta)) - except Exception as e: - raise HTTPException(500, str(e)) - - -@app.get('/api/v1/files') -def list_files(): - return [UploadedFile(**v).model_dump() for v in UPLOADED_FILES.values()] - - -@app.get('/api/v1/files/{file_id}') -def get_file(file_id: str, download: bool = False): - if file_id not in UPLOADED_FILES: - raise HTTPException(404, 'File not found') - meta = UPLOADED_FILES[file_id] - if download: - # return file contents - try: - with open(meta['path'], 'rb') as f: - return StreamingResponse(f, media_type='application/octet-stream') - except Exception as e: - raise HTTPException(500, str(e)) - return UploadedFile(**meta) - - -@app.delete('/api/v1/files/{file_id}') -def delete_file(file_id: str): - if file_id not in UPLOADED_FILES: - raise HTTPException(404, 'File not found') - meta = 
UPLOADED_FILES.pop(file_id) - try: - if os.path.exists(meta['path']): - os.remove(meta['path']) - except Exception: - pass - return {'success': True} - -if __name__ == "__main__": - import uvicorn - - host = CONFIG.HOST or "127.0.0.1" - port = CONFIG.PORT or 7860 - uvicorn.run(app, host=host, port=port) +import os + +if os.environ.get("MODELSCOPE_ENVIRONMENT") == "studio": + from modelscope import patch_hub + patch_hub() + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256" + +from config import CONFIG, ModelConfig +from utils import ( + cleanMessages, + parse_think_response, + remove_nested_think_tags_stack, + format_bytes, + log, + detect_tools_and_reasoning, + universal_tool, +) + +import copy, types, gc, sys, re, time, collections, asyncio +from huggingface_hub import hf_hub_download +from loguru import logger +from rich import print + +from snowflake import SnowflakeGenerator + +CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595) + +from typing import List, Optional, Union, Any, Dict +import uuid +from pydantic import BaseModel, Field, model_validator +from pydantic_settings import BaseSettings + +import numpy as np +import torch + +if "cuda" in CONFIG.STRATEGY.lower() and not torch.cuda.is_available(): + logger.info(f"CUDA not found, fall back to cpu") + CONFIG.STRATEGY = "cpu fp16" + +_s = CONFIG.STRATEGY.lower() +if ("cpu" in _s or "cuda" in _s) and not ("fp16" in _s or "fp32" in _s): + logger.info(f"STRATEGY missing precision, appending 'fp16' to `{CONFIG.STRATEGY}`") + CONFIG.STRATEGY = CONFIG.STRATEGY + " fp16" + +try: + from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo +except Exception: + nvmlInit = None + nvmlDeviceGetHandleByIndex = None + nvmlDeviceGetMemoryInfo = None + +if "cuda" in CONFIG.STRATEGY.lower() and nvmlInit is not None and nvmlDeviceGetHandleByIndex is not None: + nvmlInit() + gpu_h = nvmlDeviceGetHandleByIndex(0) + +def logGPUState(): + if "cuda" in CONFIG.STRATEGY and nvmlDeviceGetMemoryInfo is not None: + gpu_info = nvmlDeviceGetMemoryInfo(gpu_h) + logger.info( + f"[STATUS] Torch - {format_bytes(torch.cuda.memory_allocated())} - NVML - vram {format_bytes(gpu_info.total)} used {format_bytes(gpu_info.used)} free {format_bytes(gpu_info.free)}" + ) + +torch.backends.cudnn.benchmark = True +torch.backends.cudnn.allow_tf32 = True +torch.backends.cuda.matmul.allow_tf32 = True +os.environ["RWKV_V7_ON"] = "1" +os.environ["RWKV_JIT_ON"] = "1" +os.environ["RWKV_CUDA_ON"] = ( + "1" if CONFIG.RWKV_CUDA_ON and "cuda" in CONFIG.STRATEGY.lower() else "0" +) + +from rwkv.model import RWKV +from rwkv.utils import PIPELINE, PIPELINE_ARGS + +from fastapi import FastAPI, HTTPException, UploadFile, File +from starlette.background import BackgroundTask +from fastapi.responses import StreamingResponse +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles +from fastapi.middleware.gzip import GZipMiddleware + +from api_types import ( + ChatMessage, + ChatCompletion, + ChatCompletionChunk, + Usage, + PromptTokensDetails, + ChatCompletionChoice, + ChatCompletionMessage, + SamplerConfig, + UploadedFile, + FileUploadResponse, +) + +class ModelStorage: + MODEL_CONFIG: Optional[ModelConfig] = None + model: Optional[RWKV] = None + pipeline: Optional[PIPELINE] = None + +MODEL_STORAGE: Dict[str, ModelStorage] = {} + +DEFALUT_MODEL_NAME = None +DEFAULT_REASONING_MODEL_NAME = None + +STATE_STORE: Dict[tuple, Any] = {} + +_STATE_STORE_PATH = getattr(CONFIG, 'STATE_STORE_PATH', 
'./state_store.json') +_LAST_STATE_STORE_WRITE = 0 + +TOOL_CALL_RE = re.compile(r"\s*(\{.*?\})\s*", re.S) + +UPLOADED_FILES: Dict[str, dict] = {} + +def _serialize_state_store() -> dict: + serial = {} + for (model_name, state_name), entry in STATE_STORE.items(): + try: + mt = entry.get('model_tokens') if isinstance(entry, dict) else None + if mt is None: + continue + serial[f"{model_name}|{state_name}"] = { + 'model': model_name, + 'state_name': state_name, + 'model_tokens': mt, + } + except Exception: + continue + return serial + +def _load_state_store_from_disk(): + global STATE_STORE + try: + if os.path.exists(_STATE_STORE_PATH): + import json + + with open(_STATE_STORE_PATH, 'r', encoding='utf-8') as f: + data = json.load(f) + for k, v in data.items(): + model = v.get('model') + state_name = v.get('state_name') + model_tokens = v.get('model_tokens') + if model and state_name and isinstance(model_tokens, list): + STATE_STORE[(model, state_name)] = { + 'state': None, + 'model_tokens': model_tokens, + } + logger.info(f"Loaded {len(STATE_STORE)} entries from state store file {_STATE_STORE_PATH}") + except Exception as e: + logger.info(f"Failed to load state store from disk: {e}") + +def _save_state_store_to_disk(force=False): + global _LAST_STATE_STORE_WRITE + now = time.time() + if not force and now - _LAST_STATE_STORE_WRITE < getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5): + return + try: + serial = _serialize_state_store() + if not serial: + return + import json + tmp = _STATE_STORE_PATH + ".tmp" + with open(tmp, 'w', encoding='utf-8') as f: + json.dump(serial, f) + os.replace(tmp, _STATE_STORE_PATH) + _LAST_STATE_STORE_WRITE = now + except Exception as e: + logger.info(f"Write state store to disk failed: {e}") + +def _recompute_out_and_state_from_tokens(model_name: str, model_tokens: List[int]): + ms = MODEL_STORAGE.get(model_name) + if not ms or not ms.model: + return None, None + model_state = None + out = None + tokens = list(model_tokens) if isinstance(model_tokens, list) else [0] + while len(tokens) > 0: + out, model_state = ms.model.forward(tokens[: CONFIG.CHUNK_LEN], model_state) + tokens = tokens[CONFIG.CHUNK_LEN :] + return out, model_state + +logger.info(f"STRATEGY - {CONFIG.STRATEGY}") + +logGPUState() + +logger.info(f"Configured {len(CONFIG.MODELS)} model(s) in ROOT config") + +for model_config in CONFIG.MODELS: + logger.info(f"Load Model - {model_config.SERVICE_NAME}") + + if model_config.MODEL_FILE_PATH == None: + model_config.MODEL_FILE_PATH = hf_hub_download( + repo_id=str(model_config.DOWNLOAD_MODEL_REPO_ID), + filename=str(model_config.DOWNLOAD_MODEL_FILE_NAME), + local_dir=str(model_config.DOWNLOAD_MODEL_DIR), + ) + logger.info(f"Load Model - Path - {model_config.MODEL_FILE_PATH}") + + if model_config.DEFAULT_CHAT: + if DEFALUT_MODEL_NAME != None: + logger.info( + f"Load Model - Replace `DEFALUT_MODEL_NAME` from `{DEFALUT_MODEL_NAME}` to `{model_config.SERVICE_NAME}`" + ) + DEFALUT_MODEL_NAME = model_config.SERVICE_NAME + + if model_config.DEFAULT_REASONING: + if DEFAULT_REASONING_MODEL_NAME != None: + logger.info( + f"Load Model - Replace `DEFAULT_REASONING_MODEL_NAME` from `{DEFAULT_REASONING_MODEL_NAME}` to `{model_config.SERVICE_NAME}`" + ) + DEFAULT_REASONING_MODEL_NAME = model_config.SERVICE_NAME + + logger.info(f"Load Model - Loading `{model_config.SERVICE_NAME}`") + print(model_config.DEFAULT_SAMPLER) + + MODEL_STORAGE[model_config.SERVICE_NAME] = ModelStorage() + MODEL_STORAGE[model_config.SERVICE_NAME].MODEL_CONFIG = model_config + 
MODEL_STORAGE[model_config.SERVICE_NAME].model = RWKV( + model=model_config.MODEL_FILE_PATH.replace(".pth", ""), + strategy=CONFIG.STRATEGY, + ) + MODEL_STORAGE[model_config.SERVICE_NAME].pipeline = PIPELINE( + MODEL_STORAGE[model_config.SERVICE_NAME].model, model_config.VOCAB + ) + if "cuda" in CONFIG.STRATEGY: + torch.cuda.empty_cache() + gc.collect() + logGPUState() + +logger.info(f"Load Model - DEFALUT_MODEL_NAME is `{DEFALUT_MODEL_NAME}`") +logger.info( + f"Load Model - DEFAULT_REASONING_MODEL_NAME is `{DEFAULT_REASONING_MODEL_NAME}`" +) + +if len(MODEL_STORAGE) == 1: + single_name = list(MODEL_STORAGE.keys())[0] + if DEFALUT_MODEL_NAME != single_name: + DEFALUT_MODEL_NAME = single_name + logger.info(f"Load Model - Only one model present; DEFALUT_MODEL_NAME set to `{DEFALUT_MODEL_NAME}`") + if DEFAULT_REASONING_MODEL_NAME != single_name: + DEFAULT_REASONING_MODEL_NAME = single_name + logger.info(f"Load Model - Only one model present; DEFAULT_REASONING_MODEL_NAME set to `{DEFAULT_REASONING_MODEL_NAME}`") + +class ChatCompletionRequest(BaseModel): + model: str = Field(default="rwkv-latest") + messages: Optional[List[ChatMessage]] = Field(default=None) + prompt: Optional[str] = Field(default=None) + max_tokens: Optional[int] = Field(default=None) + temperature: Optional[float] = Field(default=None) + top_p: Optional[float] = Field(default=None) + presence_penalty: Optional[float] = Field(default=None) + count_penalty: Optional[float] = Field(default=None) + penalty_decay: Optional[float] = Field(default=None) + stream: Optional[bool] = Field(default=True) + state_name: Optional[str] = Field(default=None) + include_usage: Optional[bool] = Field(default=False) + stop: Optional[list[str]] = Field(["\n\n"]) + stop_tokens: Optional[list[int]] = Field([0]) + web_search: Optional[bool] = Field(default=True) + enable_web_search: Optional[bool] = Field(default=True) + auto_web_search: Optional[bool] = Field(default=True) + enable_tools: Optional[bool] = Field(default=None) + auto_tools: Optional[bool] = Field(default=True) + enable_reasoning: Optional[bool] = Field(default=True) + auto_reasoning: Optional[bool] = Field(default=True) + enable_universal: Optional[bool] = Field(default=None) + auto_universal: Optional[bool] = Field(default=True) + search_top_k: Optional[int] = Field(default=3) + tools: Optional[List[Dict[str, Any]]] = Field(default=None) + sampler_allow_web_search: Optional[bool] = Field(default=None) + sampler_allow_tools: Optional[bool] = Field(default=None) + sampler_allow_reasoning: Optional[bool] = Field(default=None) + sampler: Optional[SamplerConfig] = Field(default=None) + file_ids: Optional[List[str]] = Field(default=None) + enable_file_tool: Optional[bool] = Field(default=None) + auto_file_tool: Optional[bool] = Field(default=None) + sampler_allow_file_tool: Optional[bool] = Field(default=None) + + @model_validator(mode="before") + @classmethod + def validate_mutual_exclusivity(cls, data: Any) -> Any: + if not isinstance(data, dict): + return data + + messages_provided = "messages" in data and data["messages"] != None + prompt_provided = "prompt" in data and data["prompt"] != None + + if messages_provided and prompt_provided: + raise ValueError("messages and prompt cannot coexist. 
Choose one.") + if not messages_provided and not prompt_provided: + raise ValueError("Either messages or prompt must be provided.") + return data + +app = FastAPI(title="RWKV OpenAI-Compatible API") + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=5) + +@app.on_event("startup") +async def _startup_state_load_and_persist_loop(): + _load_state_store_from_disk() + + async def _persist_loop(): + while True: + try: + _save_state_store_to_disk(force=False) + except Exception: + pass + await asyncio.sleep(getattr(CONFIG, 'STATE_STORE_FLUSH_INTERVAL', 5)) + + try: + asyncio.create_task(_persist_loop()) + except Exception: + pass + +async def runPrefill( + request: ChatCompletionRequest, ctx: str, model_tokens: List[int], model_state +): + ctx = ctx.replace("\r\n", "\n") + out = None + + ms = MODEL_STORAGE.get(request.model) + if not ms or not ms.pipeline or not ms.model: + raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing") + tokens = ms.pipeline.encode(ctx) + tokens = [int(x) for x in tokens] + model_tokens += tokens + + while len(tokens) > 0: + out, model_state = ms.model.forward( + tokens[: CONFIG.CHUNK_LEN], model_state + ) + tokens = tokens[CONFIG.CHUNK_LEN :] + await asyncio.sleep(0) + + return out, model_tokens, model_state + +def generate( + request: ChatCompletionRequest, + out, + model_tokens: List[int], + model_state, + max_tokens=2048, +): + ms = MODEL_STORAGE.get(request.model) + if not ms or not ms.pipeline or not ms.model: + raise HTTPException(500, f"Model {request.model} not loaded or pipeline missing") + + temperature = request.temperature if request.temperature is not None else 0.2 + top_p = request.top_p if request.top_p is not None else 0.9 + alpha_frequency = request.count_penalty if request.count_penalty is not None else 0.0 + alpha_presence = request.presence_penalty if request.presence_penalty is not None else 0.0 + penalty_decay = request.penalty_decay if request.penalty_decay is not None else 0.5 + + args = PIPELINE_ARGS( + temperature=max(0.2, temperature), + top_p=top_p, + alpha_frequency=alpha_frequency, + alpha_presence=alpha_presence, + token_ban=[], + token_stop=[0], + ) + + occurrence = {} + out_tokens: List[int] = [] + out_last = 0 + + for i in range(max_tokens): + for n in occurrence: + out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency + + token = ms.pipeline.sample_logits( + out, temperature=args.temperature, top_p=args.top_p + ) + + if token == 0 and request.stop_tokens and token in request.stop_tokens: + yield { + "content": "", + "tokens": out_tokens[out_last:], + "finish_reason": "stop:token:0", + "state": model_state, + } + + del out + gc.collect() + return + + out, model_state = ms.model.forward([token], model_state) + model_tokens.append(token) + out_tokens.append(token) + + if request.stop_tokens and token in request.stop_tokens: + yield { + "content": "", + "tokens": out_tokens[out_last:], + "finish_reason": f"stop:token:{token}", + "state": model_state, + } + + del out + gc.collect() + return + + for xxx in list(occurrence.keys()): + occurrence[xxx] *= penalty_decay + occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0) + + decoded = ms.pipeline.decode([token]) + if "\ufffd" in decoded: + continue + + yield { + "content": decoded, + "tokens": [token], + "finish_reason": None, + "state": model_state, + } + out_last = i + 1 + + 
+    else:
+        yield {
+            "content": "",
+            "tokens": [],
+            "finish_reason": "length",
+        }
+
+async def chatResponse(
+    request: ChatCompletionRequest,
+    model_state: Any,
+    completionId: str,
+    enableReasoning: bool,
+) -> ChatCompletion:
+    createTimestamp = time.time()
+
+    raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [])
+    detection = detect_tools_and_reasoning(raw_prompt)
+    # Seed the assistant turn with an unclosed "<think" marker when reasoning
+    # is on, so the model continues inside a think block.
+    prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [])}\n\nAssistant:{'<think' if enableReasoning else ''}"
+
+    # Request-level flags take precedence; otherwise fall back to config
+    # defaults plus prompt auto-detection.
+    if request.enable_web_search is not None:
+        web_search_enabled = bool(request.enable_web_search)
+    else:
+        auto_ws_flag = request.auto_web_search if request.auto_web_search is not None else getattr(CONFIG, 'AUTO_ENABLE_WEB_SEARCH', True)
+        web_search_enabled = getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', False) or bool(auto_ws_flag and detection.get('need_web_search'))
+
+    if request.enable_file_tool is not None:
+        file_tool_enabled = bool(request.enable_file_tool)
+    else:
+        auto_file_flag = request.auto_file_tool if request.auto_file_tool is not None else True
+        file_tool_enabled = bool((request.file_ids is not None and len(request.file_ids) > 0) or (auto_file_flag and request.file_ids))
+    if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
+        file_tool_enabled = False
+    try:
+        if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
+            file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
+        elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
+            file_tool_enabled = bool(request.sampler_allow_file_tool)
+        else:
+            ms = MODEL_STORAGE.get(request.model)
+            if ms and ms.MODEL_CONFIG:
+                if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
+                    file_tool_enabled = bool(ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
+                elif hasattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms.MODEL_CONFIG.ALLOW_FILE_TOOL:
+                    file_tool_enabled = False
+    except Exception:
+        pass
+
+    if request.enable_tools is not None:
+        tools_enabled = bool(request.enable_tools)
+    else:
+        auto_tools_flag = request.auto_tools if request.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
+        tools_enabled = bool(request.tools) or CONFIG.ENABLE_TOOLS_BY_DEFAULT or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
+    try:
+        if request.sampler and getattr(request.sampler, 'ALLOW_TOOLS', None) is not None:
+            tools_enabled = bool(request.sampler.ALLOW_TOOLS)
+        elif hasattr(request, 'sampler_allow_tools') and request.sampler_allow_tools is not None:
+            tools_enabled = bool(request.sampler_allow_tools)
+        else:
+            ms = MODEL_STORAGE.get(request.model)
+            if ms and ms.MODEL_CONFIG:
+                if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_TOOLS', None) is not None:
+                    if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_TOOLS:
+                        tools_enabled = False
+                elif hasattr(ms.MODEL_CONFIG, 'ALLOW_TOOLS') and not ms.MODEL_CONFIG.ALLOW_TOOLS:
+                    tools_enabled = False
+    except Exception:
+        pass
+
+    if request.enable_reasoning is not None and request.enable_reasoning:
+        reasoning_enabled = True
+    else:
+        reasoning_enabled = bool(enableReasoning) or bool(
+            request.auto_reasoning if request.auto_reasoning is not None else (CONFIG.AUTO_ENABLE_REASONING and bool(detection.get('need_reasoning')))
+        )
+    if not getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True) and request.enable_reasoning is None:
+        reasoning_enabled = False
+    try:
+        if request.sampler and getattr(request.sampler, 'ALLOW_REASONING', None) is not None:
+            reasoning_enabled = bool(request.sampler.ALLOW_REASONING)
+        elif hasattr(request, 'sampler_allow_reasoning') and request.sampler_allow_reasoning is not None:
+            reasoning_enabled = bool(request.sampler_allow_reasoning)
+        else:
+            ms = MODEL_STORAGE.get(request.model)
+            if ms and ms.MODEL_CONFIG:
+                if hasattr(ms.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_REASONING', None) is not None:
+                    if not ms.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_REASONING:
+                        reasoning_enabled = False
+                elif hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
+                    reasoning_enabled = False
+    except Exception:
+        pass
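+    # Example of request-level overrides (these take precedence over config
+    # defaults and prompt auto-detection); a client can send, e.g.:
+    #   {"model": "rwkv-latest", "prompt": "...",
+    #    "enable_reasoning": false, "sampler_allow_tools": false}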
+    enableReasoning = reasoning_enabled
+    try:
+        ms = MODEL_STORAGE.get(request.model)
+        if ms and ms.MODEL_CONFIG and hasattr(ms.MODEL_CONFIG, 'ALLOW_REASONING') and not ms.MODEL_CONFIG.ALLOW_REASONING:
+            enableReasoning = False
+    except Exception:
+        pass
+
+    if request.enable_web_search is None:
+        request.web_search = web_search_enabled
+    if tools_enabled and not request.tools:
+        if detection.get('detected_tools'):
+            request.tools = detection.get('detected_tools')
+    if (request.enable_universal is True) or (
+        request.enable_universal is None and (request.auto_universal if request.auto_universal is not None else CONFIG.AUTO_ENABLE_TOOLS and detection.get('need_universal'))
+    ):
+        if not request.tools:
+            request.tools = [{"name": "universal", "args": {"query": raw_prompt}}]
+
+    executed_tool_calls = []
+    if file_tool_enabled and request.file_ids:
+        for fid in request.file_ids:
+            try:
+                if fid not in UPLOADED_FILES:
+                    continue
+                meta = UPLOADED_FILES.get(fid)
+                if not meta:
+                    continue
+                from utils import file_read_from_path
+                fpath = meta.get('path')
+                if not fpath or not os.path.exists(fpath):
+                    continue
+                file_content = file_read_from_path(fpath, 200000)
+                if file_content:
+                    exec_entry = {"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}}
+                    executed_tool_calls.append(exec_entry)
+                    prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt)
+            except Exception as e:
+                logger.info(f"File injection error: {e}")
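+    # The tool loop below consumes entries of this shape, e.g.:
+    #   [{"name": "web_search", "args": {"query": "rwkv v7", "top_k": 3}},
+    #    {"name": "calc", "args": {"expression": "2**10"}}]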
str(calc_res), "metadata": {"expression": expr, "confidence": 0.98}} + executed_tool_calls.append({"name": "calc", "args": {"expression": expr}, "result": calc_res_struct}) + prompt = (f"ToolResults:\nCalcResult:{expr} = {calc_res_struct.get('result')}\n\nUse this result to answer the prompt.\n\n" + prompt) + elif name == 'universal': + try: + res = universal_tool(args or {"query": raw_prompt}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled)) + if isinstance(res, dict): + result_text = res.get('result') if res.get('result') is not None else '' + else: + result_text = str(res) + executed_tool_calls.append({"name": "universal", "args": args, "result": res}) + prompt = (f"ToolResults:\n{result_text}\n\nUse this result to answer the prompt.\n\n" + prompt) + except Exception as e: + logger.info(f"Universal tool execution error: {e}") + else: + logger.info(f"Unsupported tool requested: {name}") + if name == 'file_read': + try: + fid = args.get('file_id') or args.get('id') or (request.file_ids[0] if request.file_ids else None) + if not fid: + continue + if fid not in UPLOADED_FILES: + continue + meta = UPLOADED_FILES.get(fid) + if not meta: + continue + from utils import file_read_from_path + fpath = meta.get('path') + if not fpath or not os.path.exists(fpath): + continue + file_content = file_read_from_path(fpath, int(args.get('max_bytes') or 100000)) + exec_entry = {"name": "file_read", "args": {"file_id": fid, "max_bytes": int(args.get('max_bytes') or 100000)}, "result": {"action": "file_read", "result": file_content, "metadata": {"file_id": fid, "filename": meta.get('filename')}}} + executed_tool_calls.append(exec_entry) + _res = exec_entry.get('result') if isinstance(exec_entry, dict) else None + _res_text = '' + if isinstance(_res, dict): + _res_text = _res.get('result') or '' + elif _res is not None: + _res_text = str(_res) + prompt = (f"ToolResults:\n{_res_text}\n\nUse these file contents to answer the prompt.\n\n" + prompt) + except Exception as e: + logger.info(f"file_read tool error: {e}") + except Exception as e: + logger.info(f"Tool processing error: {e}") + elif request.web_search or web_search_enabled: + try: + from utils import web_search + + search_q = request.prompt if request.prompt else cleanMessages(request.messages or []) + search_res = web_search(search_q, int(request.search_top_k or 3)) + if search_res: + search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}} + executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct}) + prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt + except Exception: + pass + logger.info(f"[REQ] {completionId} - prompt - {prompt}") + + if request.state_name: + state_key = (request.model, request.state_name) + if state_key in STATE_STORE: + stored = STATE_STORE[state_key] + model_state = stored.get('state', None) + model_tokens = stored.get('model_tokens', [0]) + if model_state is None: + out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens) + else: + out, _ = _recompute_out_and_state_from_tokens(request.model, model_tokens[-CONFIG.CHUNK_LEN :]) + else: + out, model_tokens, model_state = await runPrefill(request, prompt, [0], model_state) + else: + out, model_tokens, model_state = await runPrefill(request, prompt, [0], 
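+    # Example: a follow-up request carrying the same "state_name" resumes from
+    # the cached RWKV state above instead of re-prefilling the conversation:
+    #   {"model": "rwkv-latest", "prompt": "And then?", "state_name": "sess-42"}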
+
+    prefillTime = time.time()
+    promptTokenCount = len(model_tokens)
+
+    # Non-streaming generation; the response text is seeded with the unclosed
+    # "<think" marker so parse_think_response() can split reasoning from the
+    # final answer afterwards.
+    fullResponse = "<think" if enableReasoning else ""
+    completionTokenCount = 0
+    finishReason = None
+    last_state = None
+    for chunk in generate(request, out, model_tokens, model_state, max_tokens=request.max_tokens or 2048):
+        completionTokenCount += 1
+        fullResponse += chunk["content"]
+        if "state" in chunk:
+            last_state = chunk["state"]
+        if chunk["finish_reason"]:
+            finishReason = chunk["finish_reason"]
+
+    generateTime = time.time()
+    if request.state_name:
+        STATE_STORE[(request.model, request.state_name)] = {
+            'state': last_state,
+            'model_tokens': model_tokens,
+        }
+
+    reasoning_content, content = parse_think_response(fullResponse)
+
+    logger.info(
+        f"[RES] {completionId} - prefill {promptTokenCount} tok @ {round(promptTokenCount / max(prefillTime - createTimestamp, 1e-6), 2)} tps - gen {completionTokenCount} tok @ {round(completionTokenCount / max(generateTime - prefillTime, 1e-6), 2)} tps"
+    )
+
+    return ChatCompletion(
+        id=completionId,
+        created=createTimestamp,
+        model=request.model,
+        usage=Usage(
+            prompt_tokens=promptTokenCount,
+            completion_tokens=completionTokenCount,
+            total_tokens=promptTokenCount + completionTokenCount,
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
+        ),
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                message=ChatCompletionMessage(
+                    role="Assistant",
+                    content=content,
+                    reasoning_content=reasoning_content if enableReasoning else None,
+                    tool_calls=executed_tool_calls or None,
+                ),
+                logprobs=None,
+                finish_reason=finishReason,
+            )
+        ],
+    )
+
+async def chatResponseStream(
+    request: ChatCompletionRequest,
+    model_state: Any,
+    completionId: str,
+    enableReasoning: bool,
+):
+    createTimestamp = time.time()
+
+    raw_prompt = request.prompt.strip() if request.prompt is not None else cleanMessages(request.messages or [])
+    detection = detect_tools_and_reasoning(raw_prompt)
+
+    # Flag resolution mirrors chatResponse() above.
+    if request.enable_web_search is not None:
+        web_search_enabled = bool(request.enable_web_search)
+    else:
+        auto_ws_flag = request.auto_web_search if request.auto_web_search is not None else getattr(CONFIG, 'AUTO_ENABLE_WEB_SEARCH', True)
+        web_search_enabled = getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', False) or bool(auto_ws_flag and detection.get('need_web_search'))
+
+    if request.enable_tools is not None:
+        tools_enabled = bool(request.enable_tools)
+    else:
+        auto_tools_flag = request.auto_tools if request.auto_tools is not None else CONFIG.AUTO_ENABLE_TOOLS
+        tools_enabled = bool(request.tools) or CONFIG.ENABLE_TOOLS_BY_DEFAULT or (auto_tools_flag and (detection.get('need_calc') or detection.get('need_web_search')))
+
+    if request.enable_file_tool is not None:
+        file_tool_enabled = bool(request.enable_file_tool)
+    else:
+        auto_file_flag = request.auto_file_tool if request.auto_file_tool is not None else True
+        file_tool_enabled = bool((request.file_ids is not None and len(request.file_ids) > 0) or (auto_file_flag and request.file_ids))
+    if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True) and request.enable_file_tool is None:
+        file_tool_enabled = False
+    try:
+        if request.sampler and getattr(request.sampler, 'ALLOW_FILE_TOOL', None) is not None:
+            file_tool_enabled = bool(request.sampler.ALLOW_FILE_TOOL)
+        elif hasattr(request, 'sampler_allow_file_tool') and request.sampler_allow_file_tool is not None:
+            file_tool_enabled = bool(request.sampler_allow_file_tool)
+        else:
+            ms2 = MODEL_STORAGE.get(request.model)
+            if ms2 and ms2.MODEL_CONFIG:
+                if hasattr(ms2.MODEL_CONFIG, 'DEFAULT_SAMPLER') and getattr(ms2.MODEL_CONFIG.DEFAULT_SAMPLER, 'ALLOW_FILE_TOOL', None) is not None:
+                    file_tool_enabled = bool(ms2.MODEL_CONFIG.DEFAULT_SAMPLER.ALLOW_FILE_TOOL)
+                elif hasattr(ms2.MODEL_CONFIG, 'ALLOW_FILE_TOOL') and not ms2.MODEL_CONFIG.ALLOW_FILE_TOOL:
+                    file_tool_enabled = False
+    except Exception:
+        pass
+
+    prompt = raw_prompt if request.prompt is not None else f"{cleanMessages(request.messages or [], enableReasoning)}\n\nAssistant:{'<think' if enableReasoning else ''}"
+
+    # Prompt-phase tool execution mirrors chatResponse(): file contents and
+    # search results are prepended to the prompt before prefill.
+    executed_tool_calls = []
+    if file_tool_enabled and request.file_ids:
+        for fid in request.file_ids:
+            try:
+                meta = UPLOADED_FILES.get(fid)
+                if not meta:
+                    continue
+                from utils import file_read_from_path
+                fpath = meta.get('path')
+                if not fpath or not os.path.exists(fpath):
+                    continue
+                file_content = file_read_from_path(fpath, 200000)
+                if file_content:
+                    executed_tool_calls.append({"name": "file_inject", "args": {"file_id": fid}, "result": {"action": "file_inject", "result": "injected", "metadata": {"file_id": fid, "filename": meta.get('filename')}}})
+                    prompt = (f"AttachedFile: {meta.get('filename')} (id:{fid})\n{file_content}\n\n" + prompt)
+            except Exception as e:
+                logger.info(f"File injection error: {e}")
+    if request.web_search or web_search_enabled:
+        try:
+            from utils import web_search
+
+            search_q = request.prompt if request.prompt else cleanMessages(request.messages or [])
+            search_res = web_search(search_q, int(request.search_top_k or 3))
+            if search_res:
+                search_res_struct = {"action": "web_search", "result": str(search_res), "metadata": {"query": search_q, "top_k": int(request.search_top_k or 3), "confidence": 0.9}}
+                executed_tool_calls.append({"name": "web_search", "args": {"query": search_q, "top_k": int(request.search_top_k or 3)}, "result": search_res_struct})
+                prompt = f"WebSearchResults:\n{search_res_struct.get('result')}\n\n" + prompt
+        except Exception:
+            pass
+
+    logger.info(f"[REQ] {completionId} - prompt - {prompt}")
+
+    model_tokens = [0]
+    if request.state_name:
+        state_key = (request.model, request.state_name)
+        if state_key in STATE_STORE:
+            stored = STATE_STORE[state_key]
+            model_state = stored.get('state', None)
+            model_tokens = stored.get('model_tokens', [0])
+            if model_state is None:
+                out, model_state = _recompute_out_and_state_from_tokens(request.model, model_tokens)
+    out, model_tokens, model_state = await runPrefill(request, prompt, model_tokens, model_state)
+
+    prefillTime = time.time()
+    promptTokenCount = len(model_tokens)
+    completionTokenCount = 0
+    finishReason = None
+    # Cap on how many <tool_call> rounds the model itself may trigger mid-stream.
+    MODEL_MAX_TOOL_CALLS = int(getattr(CONFIG, 'MODEL_MAX_TOOL_CALLS', 2))
+    model_initiated_tool_calls = 0
+
+    # The first SSE chunk carries the assistant role; prompt-phase tool calls,
+    # if any, are attached to its delta.
+    response = ChatCompletionChunk(
+        id=completionId,
+        created=createTimestamp,
+        model=request.model,
+        usage=None,
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                delta=ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None),
+                logprobs=None,
+                finish_reason=None,
+            )
+        ],
+    )
+    if executed_tool_calls:
+        try:
+            import json
+
+            r_dict = response.model_dump()
+            if len(r_dict.get('choices') or []) > 0 and r_dict['choices'][0].get('delta') is not None:
+                r_dict['choices'][0]['delta']['tool_calls'] = executed_tool_calls
+            yield f"data: {json.dumps(r_dict, default=str)}\n\n"
+        except Exception:
+            pass
+    else:
+        yield f"data: {response.model_dump_json()}\n\n"
+
+    buffer = []
+
+    if enableReasoning:
+        buffer.append("<think")
+        streamConfig = {
+            "isChecking": False,
+            "in_think": False,
+            "cacheStr": "",
+            "fullTextCursor": 0,
+        }
+        should_restart = True
+        while should_restart:
+            should_restart = False
+            gen = generate(request, out, model_tokens, model_state)
+            for chunk in gen:
+                completionTokenCount += 1
+                buffer.append(chunk["content"])
+                fullText = "".join(buffer)
+
+                if chunk["finish_reason"]:
+                    finishReason = chunk["finish_reason"]
+
+                response = ChatCompletionChunk(
+                    id=completionId,
+                    created=createTimestamp,
+                    model=request.model,
+                    usage=(
+                        Usage(
+                            prompt_tokens=promptTokenCount,
+                            completion_tokens=completionTokenCount,
+                            total_tokens=promptTokenCount + completionTokenCount,
+                            prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
+                        )
+                        if request.include_usage
+                        else None
+                    ),
+                    choices=[
+                        ChatCompletionChoice(
+                            index=0,
+                            delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=None),
+                            logprobs=None,
+                            finish_reason=finishReason,
+                        )
+                    ],
+                )
+
+                # Route text between <think> ... </think> to reasoning_content
+                # and everything after to content. "isChecking" buffers from a
+                # "<" until the matching ">" so tags are never split across
+                # deltas.
+                markStart = fullText.find("<", streamConfig["fullTextCursor"])
+                if not streamConfig["isChecking"] and markStart != -1:
+                    streamConfig["isChecking"] = True
+                markEnd = fullText.find(">", streamConfig["fullTextCursor"])
+                if (streamConfig["isChecking"] and markEnd != -1) or finishReason != None:
+                    streamConfig["isChecking"] = False
+
+                if not streamConfig["isChecking"]:
+                    streamConfig["cacheStr"] = fullText[streamConfig["fullTextCursor"]:]
+                    if (
+                        not streamConfig["in_think"]
+                        and streamConfig["cacheStr"].find("<think>") != -1
+                    ):
+                        streamConfig["in_think"] = True
+
+                        delta = response.choices[0].delta
+                        if delta is None:
+                            delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
+                            response.choices[0].delta = delta
+                        delta.reasoning_content = (
+                            (delta.reasoning_content if delta.reasoning_content is not None else "")
+                            + streamConfig["cacheStr"].replace("<think>", "")
+                        )
+                    elif (
+                        streamConfig["in_think"]
+                        and streamConfig["cacheStr"].find("</think>") != -1
+                    ):
+                        streamConfig["in_think"] = False
+
+                        delta = response.choices[0].delta
+                        if delta is None:
+                            delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
+                            response.choices[0].delta = delta
+                        delta.content = (
+                            (delta.content if delta.content is not None else "")
+                            + streamConfig["cacheStr"].replace("</think>", "")
+                        )
+                    else:
+                        delta = response.choices[0].delta
+                        if delta is None:
+                            delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
+                            response.choices[0].delta = delta
+                        if streamConfig["in_think"]:
+                            delta.reasoning_content = (
+                                (delta.reasoning_content if delta.reasoning_content is not None else "")
+                                + streamConfig["cacheStr"]
+                            )
+                        else:
+                            delta.content = (
+                                (delta.content if delta.content is not None else "")
+                                + streamConfig["cacheStr"]
+                            )
+                    streamConfig["fullTextCursor"] = len(fullText)
+
+                delta = response.choices[0].delta
+                if delta is None:
+                    delta = ChatCompletionMessage(role="Assistant", content="", reasoning_content=None, tool_calls=None)
+                    response.choices[0].delta = delta
+                if delta.content is not None or delta.reasoning_content is not None:
+                    try:
+                        if request.state_name:
+                            STATE_STORE[(request.model, request.state_name)] = {
+                                'state': model_state,
+ 'model_tokens': model_tokens, + } + if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False): + try: + _save_state_store_to_disk(force=True) + except Exception: + pass + except Exception: + pass + if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS: + m = TOOL_CALL_RE.search(fullText) + if m: + try: + payload_raw = m.group(1) + import json + + payload = json.loads(payload_raw) + tool_name = payload.get('name') + tool_args = payload.get('args', {}) + tool_res = None + if tool_name == 'web_search': + from utils import web_search + + q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or [])) + k = int(tool_args.get('top_k') or request.search_top_k or 3) + tool_res = web_search(q, k) + elif tool_name in ('calc', 'calculator'): + from utils import calc + + expr = tool_args.get('expression') + if expr: + tool_res = calc(expr) + else: + try: + tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled)) + except Exception: + tool_res = None + + if tool_res: + if not isinstance(tool_res, dict): + if tool_name in ('calc', 'calculator'): + tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}} + elif tool_name == 'web_search': + tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}} + else: + tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}} + else: + tool_res_struct = tool_res + exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True} + executed_tool_calls.append(exec_entry) + delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n" + prompt = delta_text + prompt + fullText = TOOL_CALL_RE.sub('', fullText) + buffer = [fullText] + out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state) + try: + meta_resp = ChatCompletionChunk( + id=completionId, + created=createTimestamp, + model=request.model, + usage=( + Usage( + prompt_tokens=promptTokenCount, + completion_tokens=completionTokenCount, + total_tokens=promptTokenCount + completionTokenCount, + prompt_tokens_details=PromptTokensDetails(cached_tokens=0), + ) + if request.include_usage + else None + ), + choices=[ + ChatCompletionChoice( + index=0, + delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls), + logprobs=None, + finish_reason=None, + ) + ], + ) + yield f"data: {meta_resp.model_dump_json()}\n\n" + except Exception: + pass + model_initiated_tool_calls += 1 + should_restart = True + break + except Exception as e: + logger.info(f"Model-initiated tool handling error: {e}") + yield f"data: {response.model_dump_json()}\n\n" + for stop_words in request.stop or []: + if stop_words in ''.join(buffer): + finishReason = f"stop:words:{stop_words}" + return + + await asyncio.sleep(0) + + del streamConfig + else: + should_restart = True + while should_restart: + should_restart = False + gen = generate(request, out, model_tokens, model_state) + for chunk in gen: + completionTokenCount += 1 + buffer.append(chunk["content"]) + + if chunk["finish_reason"]: + finishReason = chunk["finish_reason"] + + try: + if request.state_name: + 
STATE_STORE[(request.model, request.state_name)] = { + 'state': model_state, + 'model_tokens': model_tokens, + } + if getattr(CONFIG, 'STATE_STORE_SAVE_ON_UPDATE', False): + try: + _save_state_store_to_disk(force=True) + except Exception: + pass + except Exception: + pass + + if model_initiated_tool_calls < MODEL_MAX_TOOL_CALLS: + fullText = ''.join(buffer) + m = TOOL_CALL_RE.search(fullText) + if m: + try: + payload_raw = m.group(1) + import json + + payload = json.loads(payload_raw) + tool_name = payload.get('name') + tool_args = payload.get('args', {}) + tool_res = None + if tool_name == 'web_search': + from utils import web_search + + q = tool_args.get('query') or (request.prompt if request.prompt else cleanMessages(request.messages or [])) + k = int(tool_args.get('top_k') or request.search_top_k or 3) + tool_res = web_search(q, k) + elif tool_name in ('calc', 'calculator'): + from utils import calc + + expr = tool_args.get('expression') + if expr: + tool_res = calc(expr) + else: + try: + tool_res = universal_tool({'query': tool_args.get('query') or payload.get('query') or ''}, allow_web_search=bool(web_search_enabled), allow_tools=bool(tools_enabled), allow_file_tool=bool(file_tool_enabled)) + except Exception: + tool_res = None + + if tool_res: + if not isinstance(tool_res, dict): + if tool_name in ('calc', 'calculator'): + tool_res_struct = {"action": "calc", "result": str(tool_res), "metadata": {"expression": tool_args.get('expression'), "confidence": 0.98}} + elif tool_name == 'web_search': + tool_res_struct = {"action": "web_search", "result": str(tool_res), "metadata": {"query": tool_args.get('query'), "top_k": tool_args.get('top_k') or request.search_top_k or 3, "confidence": 0.9}} + else: + tool_res_struct = {"action": tool_name, "result": str(tool_res), "metadata": {"confidence": 0.6}} + else: + tool_res_struct = tool_res + exec_entry = {"name": tool_name, "args": tool_args, "result": tool_res_struct, 'initiated_by_model': True} + executed_tool_calls.append(exec_entry) + delta_text = f"ToolResults:\n{tool_res_struct.get('result')}\n\n" + prompt = delta_text + prompt + fullText = TOOL_CALL_RE.sub('', fullText) + buffer = [fullText] + out, model_tokens, model_state = await runPrefill(request, delta_text, model_tokens, model_state) + try: + meta_resp = ChatCompletionChunk( + id=completionId, + created=createTimestamp, + model=request.model, + usage=( + Usage( + prompt_tokens=promptTokenCount, + completion_tokens=completionTokenCount, + total_tokens=promptTokenCount + completionTokenCount, + prompt_tokens_details=PromptTokensDetails(cached_tokens=0), + ) + if request.include_usage + else None + ), + choices=[ + ChatCompletionChoice( + index=0, + delta=ChatCompletionMessage(role="Assistant", content=None, reasoning_content=None, tool_calls=executed_tool_calls), + logprobs=None, + finish_reason=None, + ) + ], + ) + yield f"data: {meta_resp.model_dump_json()}\n\n" + except Exception: + pass + model_initiated_tool_calls += 1 + should_restart = True + break + except Exception as e: + logger.info(f"Model-initiated tool handling error: {e}") + + response = ChatCompletionChunk( + id=completionId, + created=createTimestamp, + model=request.model, + usage=( + Usage( + prompt_tokens=promptTokenCount, + completion_tokens=completionTokenCount, + total_tokens=promptTokenCount + completionTokenCount, + prompt_tokens_details=PromptTokensDetails(cached_tokens=0), + ) + if request.include_usage + else None + ), + choices=[ + ChatCompletionChoice( + index=0, + 
delta=ChatCompletionMessage(role="Assistant", content=chunk["content"], reasoning_content=None, tool_calls=None), + logprobs=None, + finish_reason=finishReason, + ) + ], + ) + yield f"data: {response.model_dump_json()}\n\n" + await asyncio.sleep(0) + + genenrateTime = time.time() + + responseLog = { + "content": "".join(buffer), + "finish": finishReason, + "prefill_len": promptTokenCount, + "prefill_tps": round(promptTokenCount / (prefillTime - createTimestamp), 2), + "gen_len": completionTokenCount, + "gen_tps": round(completionTokenCount / (genenrateTime - prefillTime), 2), + } + logger.info(f"[RES] {completionId} - {responseLog}") + if request.messages is None: + request.messages = [] + request.messages.append(ChatMessage(role="Assistant", content=responseLog["content"])) + log( + { + **request.model_dump(), + **responseLog, + "completionId": completionId, + "machineLabel": os.environ.get("MACHINE_LABEL"), + } + ) + + del buffer + + yield "data: [DONE]\n\n" + +@app.post("/api/v1/chat/completions") +async def chat_completions(request: ChatCompletionRequest): + completionId = str(next(CompletionIdGenerator)) + logger.info(f"[REQ] {completionId} - {request.model_dump()}") + + if "rwkv-latest" in request.model: + if DEFALUT_MODEL_NAME == None: + raise HTTPException(404, "DEFALUT_MODEL_NAME not set") + ms_def = MODEL_STORAGE.get(DEFALUT_MODEL_NAME) + if not ms_def or not ms_def.MODEL_CONFIG: + raise HTTPException(500, "Default sampler config missing for default model") + defaultSamplerConfig = ms_def.MODEL_CONFIG.DEFAULT_SAMPLER + request.model = DEFALUT_MODEL_NAME + elif request.model in MODEL_STORAGE: + ms_sel = MODEL_STORAGE.get(request.model) + if not ms_sel or not ms_sel.MODEL_CONFIG: + raise HTTPException(500, f"Default sampler config missing for model {request.model}") + defaultSamplerConfig = ms_sel.MODEL_CONFIG.DEFAULT_SAMPLER + else: + raise HTTPException(404, f"Can not find `{request.model}`") + + enableReasoning = request.enable_reasoning if request.enable_reasoning is not None else False + + async def chatResponseStreamDisconnect(): + logGPUState() + + model_state = None + model_tokens_for_resume = [0] + state_name = request.state_name + if state_name is None: + state_name = str(uuid.uuid4()) + request.state_name = state_name + state_key = (request.model, state_name) + if state_key in STATE_STORE: + stored = STATE_STORE[state_key] + model_state = stored.get('state', None) + model_tokens_for_resume = stored.get('model_tokens', [0]) + request_dict = request.model_dump() + + sampler_overrides = request_dict.get('sampler') or {} + for k, v in defaultSamplerConfig.model_dump().items(): + if sampler_overrides and k in sampler_overrides and sampler_overrides.get(k) is not None: + request_dict[k] = sampler_overrides.get(k) + continue + if k in request_dict and request_dict[k] is None: + request_dict[k] = v + realRequest = ChatCompletionRequest(**request_dict) + if realRequest.stream is None: + realRequest.stream = CONFIG.DEFAULT_STREAM + + logger.info(f"[REQ] {completionId} - Real - {request.model_dump()}") + + if realRequest.stream: + r = StreamingResponse( + chatResponseStream(realRequest, model_state, completionId, enableReasoning), + media_type="text/event-stream", + background=BackgroundTask(chatResponseStreamDisconnect), + ) + else: + r = await chatResponse(realRequest, model_state, completionId, enableReasoning) + try: + import json + + if isinstance(r, ChatCompletion): + d = r.model_dump() + d['state_name'] = state_name + return d + except Exception: + pass + + return r + 
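+# Example call (host/port come from CONFIG; the defaults below are 127.0.0.1:7860):
+#   curl -N http://127.0.0.1:7860/api/v1/chat/completions \
+#     -H 'Content-Type: application/json' \
+#     -d '{"model": "rwkv-latest", "messages": [{"role": "user", "content": "Hello"}], "stream": true}'
+# Streamed output arrives as "data: {...}" SSE lines and ends with "data: [DONE]".
+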
+logger.info("Static frontend mount removed for Python-only deploy; use API endpoints for integration") + +@app.get('/api/v1/models') +def list_models(): + out = [] + root_defaults = { + 'ALLOW_FILE_TOOL_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True), + 'ENABLE_WEB_SEARCH_BY_DEFAULT': getattr(CONFIG, 'ENABLE_WEB_SEARCH_BY_DEFAULT', True), + 'ENABLE_REASONING_BY_DEFAULT': getattr(CONFIG, 'ENABLE_REASONING_BY_DEFAULT', True), + 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_WEB_SEARCH_BUTTON_BY_DEFAULT', True), + 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT': getattr(CONFIG, 'SHOW_FILE_UPLOAD_BUTTON_BY_DEFAULT', True), + 'SHOW_REASONING_TOGGLE_BY_DEFAULT': getattr(CONFIG, 'SHOW_REASONING_TOGGLE_BY_DEFAULT', True), + 'UPLOAD_URL': '/api/v1/files', + } + for m in CONFIG.MODELS: + out.append( + { + 'SERVICE_NAME': m.SERVICE_NAME, + 'DEFAULT_CHAT': m.DEFAULT_CHAT, + 'DEFAULT_REASONING': m.DEFAULT_REASONING, + 'ALLOW_WEB_SEARCH': getattr(m, 'ALLOW_WEB_SEARCH', True), + 'ALLOW_TOOLS': getattr(m, 'ALLOW_TOOLS', True), + 'ALLOW_REASONING': getattr(m, 'ALLOW_REASONING', True), + 'ALLOW_FILE_TOOL': getattr(m, 'ALLOW_FILE_TOOL', True), + 'SHOW_WEB_SEARCH_BUTTON': getattr(m, 'SHOW_WEB_SEARCH_BUTTON', True), + 'SHOW_FILE_UPLOAD_BUTTON': getattr(m, 'SHOW_FILE_UPLOAD_BUTTON', True), + 'SHOW_REASONING_TOGGLE': getattr(m, 'SHOW_REASONING_TOGGLE', True), + 'DEFAULT_SAMPLER': m.DEFAULT_SAMPLER.model_dump() if hasattr(m, 'DEFAULT_SAMPLER') else None, + 'UPLOAD_URL': '/api/v1/files', + 'UPLOAD_ALLOWED_BY_DEFAULT': getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True), + } + ) + return {'root_defaults': root_defaults, 'models': out} + +@app.post('/api/v1/files', response_model=FileUploadResponse) +async def upload_file(file: UploadFile = File(...), model: Optional[str] = None): + try: + if not getattr(CONFIG, 'ALLOW_FILE_TOOL_BY_DEFAULT', True): + raise HTTPException(403, 'File uploads are disabled by server configuration') + if model: + if model not in MODEL_STORAGE: + raise HTTPException(404, f"Model {model} not found") + ms = MODEL_STORAGE[model] + if ms and ms.MODEL_CONFIG and not getattr(ms.MODEL_CONFIG, 'ALLOW_FILE_TOOL', True): + raise HTTPException(403, f"Model {model} does not allow file uploads") + from utils import save_bytes_to_upload + + content = await file.read() + fname = file.filename if getattr(file, 'filename', None) else 'uploaded_file' + meta = save_bytes_to_upload(fname, content) + if meta.get('error'): + raise HTTPException(500, f"Could not save file: {meta.get('error')}") + UPLOADED_FILES[meta['file_id']] = meta + return FileUploadResponse(success=True, file=UploadedFile(**meta)) + except Exception as e: + raise HTTPException(500, str(e)) + +@app.get('/api/v1/files') +def list_files(): + return [UploadedFile(**v).model_dump() for v in UPLOADED_FILES.values()] + +@app.get('/api/v1/files/{file_id}') +def get_file(file_id: str, download: bool = False): + if file_id not in UPLOADED_FILES: + raise HTTPException(404, 'File not found') + meta = UPLOADED_FILES[file_id] + if download: + try: + with open(meta['path'], 'rb') as f: + return StreamingResponse(f, media_type='application/octet-stream') + except Exception as e: + raise HTTPException(500, str(e)) + return UploadedFile(**meta) + +@app.delete('/api/v1/files/{file_id}') +def delete_file(file_id: str): + if file_id not in UPLOADED_FILES: + raise HTTPException(404, 'File not found') + meta = UPLOADED_FILES.pop(file_id) + try: + if os.path.exists(meta['path']): + os.remove(meta['path']) + except Exception: + pass + return 
{'success': True} + +if __name__ == "__main__": + import uvicorn + + host = CONFIG.HOST or "127.0.0.1" + port = CONFIG.PORT or 7860 + uvicorn.run(app, host=host, port=port) \ No newline at end of file