diff --git "a/agent_fastapi.py" "b/agent_fastapi.py" new file mode 100644--- /dev/null +++ "b/agent_fastapi.py" @@ -0,0 +1,2826 @@ +# agent_fastapi.py +from __future__ import annotations + +import asyncio +import mimetypes +import os +import sys +import json +import re +import time +import uuid +import math +import logging +import shutil +from pathlib import Path +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Set +from contextlib import asynccontextmanager +from starlette.websockets import WebSocketState, WebSocketDisconnect +try: + import tomllib # Python 3.11+ # type: ignore +except ModuleNotFoundError: + import tomli as tomllib # Python <= 3.10 +import traceback + +try: + from uvicorn.protocols.utils import ClientDisconnected +except Exception: + ClientDisconnected = None + + +logger = logging.getLogger(__name__) + +import anyio +from fastapi import FastAPI, APIRouter, UploadFile, File, Form, HTTPException, WebSocket, WebSocketDisconnect, Request +from fastapi.responses import FileResponse, JSONResponse, Response +from fastapi.staticfiles import StaticFiles + +from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage, AIMessage, ToolMessage + +# ---- 确保 src 可导入(避免环境差异导致找不到模块)---- +ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) +SRC_DIR = os.path.join(ROOT_DIR, "src") +if SRC_DIR not in sys.path: + sys.path.insert(0, SRC_DIR) + +from open_storyline.agent import build_agent, ClientContext +from open_storyline.utils.prompts import get_prompt +from open_storyline.utils.media_handler import scan_media_dir +from open_storyline.config import load_settings, default_config_path +from open_storyline.config import Settings +from open_storyline.storage.agent_memory import ArtifactStore +from open_storyline.mcp.hooks.node_interceptors import ToolInterceptor +from open_storyline.mcp.hooks.chat_middleware import set_mcp_log_sink, reset_mcp_log_sink + +WEB_DIR = os.path.join(ROOT_DIR, "web") +STATIC_DIR = os.path.join(WEB_DIR, "static") +INDEX_HTML = os.path.join(WEB_DIR, "index.html") +NODE_MAP_HTML = os.path.join(WEB_DIR, "node_map/node_map.html") +NODE_MAP_DIR = os.path.join(WEB_DIR, "node_map") + +SERVER_CACHE_DIR = os.path.join(ROOT_DIR, '.storyline' , ".server_cache") + +CHUNK_SIZE = 1024 * 1024 # 1MB + +# 是否根据session_id隔离用户 +USE_SESSION_SUBDIR = True + +CUSTOM_MODEL_KEY = "__custom__" + +# Load keys +DEFAULT_LLM_API_KEY = os.getenv("DEEPSEEK_API_KEY") +DEFAULT_LLM_API_URL = os.getenv("DEEPSEEK_API_URL") +DEFAULT_LLM_API_NAME = os.getenv("DEEPSEEK_API_NAME", "deepseek-chat") +DEFAULT_VLM_API_KEY = os.getenv("GLM_V4_6_API_KEY") +DEFAULT_VLM_API_URL = os.getenv("GLM_V4_6_API_URL") +DEFAULT_VLM_API_NAME = os.getenv("GLM_V4_6_API_NAME", "qwen3-vl-8b-instruct") +print("DEEPSEEK_API_KEY exists:", bool(os.getenv("DEEPSEEK_API_KEY"))) +print("QWEN3_VL_8B_API_KEY exists:", bool(os.getenv("QWEN3_VL_8B_API_KEY"))) +print("DEEPSEEK_API_URL:", repr(os.getenv("DEEPSEEK_API_URL"))) +print("QWEN3_VL_8B_API_URL:", repr(os.getenv("QWEN3_VL_8B_API_URL"))) + +def debug_traceback_print(cfg: Settings): + if cfg.developer.developer_mode: + traceback.print_exc() + +def _s(x: Any) -> str: + return str(x or "").strip() + +def _norm_url(u: Any) -> str: + u = _s(u) + return u.rstrip("/") if u else "" + +def _env_fallback_for_model(model_name: str) -> Tuple[str, str]: + """ + - deepseek* -> DEEPSEEK_API_URL / DEEPSEEK_API_KEY + - qwen3* -> QWEN3_VL_8B_API_URL / QWEN3_VL_8B_API_KEY + """ + m = _s(model_name).lower() + if "deepseek" in m: + return (_s(os.getenv("DEEPSEEK_API_URL")), _s(os.getenv("DEEPSEEK_API_KEY"))) + if m.startswith("qwen3-vl-8b-instruct") or "qwen3-vl-8b-instruct" in m: + return (_s(os.getenv("QWEN3_VL_8B_API_URL")), _s(os.getenv("QWEN3_VL_8B_API_KEY"))) + return ("", "") + +def _resolve_default_model_override(cfg: Settings, model_name: str) -> Tuple[Optional[Dict[str, Any]], Optional[str]]: + """ + 1. get config from [developer.chat_models_config.""] + 2. rollback to env + """ + model_name = _s(model_name) + if not model_name: + return None, "default model name is empty" + + model_cfg: Dict[str, Any] = {} + try: + model_cfg = (cfg.developer.chat_models_config.get(model_name) or {}) if getattr(cfg, "developer", None) else {} + except Exception: + model_cfg = {} + + if not isinstance(model_cfg, dict): + model_cfg = {} + + base_url = _norm_url(model_cfg.get("base_url")) + api_key = _s(model_cfg.get("api_key")) + + if not base_url or not api_key: + env_url, env_key = _env_fallback_for_model(model_name) + if not base_url: + base_url = _norm_url(env_url) + if not api_key: + api_key = _s(env_key) + + override: Dict[str, Any] = {"model": model_name} + if base_url: + override["base_url"] = base_url + if api_key: + override["api_key"] = api_key + + for k in ("timeout", "temperature", "max_retries", "top_p", "max_tokens"): + if k in model_cfg and model_cfg.get(k) not in (None, ""): + override[k] = model_cfg.get(k) + + if not override.get("base_url") or not override.get("api_key"): + return None, ( + f"cannot find base_url/api_key of default model: {model_name}. " + f"please fill in base_url/api_key of [developer.chat_models_config.\"{model_name}\" in config.toml]" + f"or set environment variables(DEEPSEEK_API_URL/DEEPSEEK_API_KEY / QWEN3_VL_8B_API_URL/QWEN3_VL_8B_API_KEY)。" + ) + + return override, None + +def _stable_dict_key(d: Optional[Dict[str, Any]]) -> str: + try: + return json.dumps(d or {}, sort_keys=True, ensure_ascii=False) + except Exception: + return str(d or {}) + +def _parse_service_config(service_cfg: Any) -> Tuple[ + Optional[Dict[str, Any]], + Optional[Dict[str, Any]], + Dict[str, Any], + Dict[str, Any], + Optional[str]]: + """ + 返回 (custom_llm, custom_vlm, tts_cfg, pexels, err) + - custom_llm/custom_vlm: {"model","base_url","api_key"} 或 None(允许只传 llm 或只传 vlm) + - tts_cfg: dict(可能为空) + """ + if not isinstance(service_cfg, dict): + return None, None, {}, {}, None + + # ---- custom models ---- + custom_llm = None + custom_vlm = None + custom_models = service_cfg.get("custom_models") + + if custom_models is not None: + if not isinstance(custom_models, dict): + return None, None, {}, {}, "service_config.custom_models 必须是对象" + + def _pick(m: Any, label: str) -> Tuple[Optional[Dict[str, str]], Optional[str]]: + if m is None: + return None, None + if not isinstance(m, dict): + return None, f"service_config.custom_models.{label} 必须是对象" + + model = _s(m.get("model")) + base_url = _norm_url(m.get("base_url")) + api_key = _s(m.get("api_key")) + + if not (model and base_url and api_key): + return None, f"自定义 {label.upper()} 配置不完整:请填写 model/base_url/api_key" + if not (base_url.startswith("http://") or base_url.startswith("https://")): + return None, f"自定义 {label.upper()} 的 base_url 必须以 http(s) 开头" + return {"model": model, "base_url": base_url, "api_key": api_key}, None + + custom_llm, err1 = _pick(custom_models.get("llm"), "llm") + if err1: + return None, None, {}, {}, err1 + + custom_vlm, err2 = _pick(custom_models.get("vlm"), "vlm") + if err2: + return None, None, {}, {}, err2 + + # ---- tts ---- + tts_cfg: Dict[str, Any] = {} + tts = service_cfg.get("tts") + if isinstance(tts, dict): + provider = (tts.get("provider") or "").strip().lower() + if provider: + provider_block = tts.get(provider) + tts_cfg = {"provider": provider, provider: provider_block} + + # ---- pexels ---- + pexels_cfg: Dict[str, Any] = {} + search_media = service_cfg.get("search_media") + if isinstance(search_media, dict): + # 支持两种格式: + # 1) {search_media:{pexels:{mode, api_key}}} + # 2) {search_media:{mode, pexel_api_key}} + p = search_media.get("pexels") or search_media.get("pexels") + if isinstance(p, dict): + mode = _s(p.get("mode")).lower() + if mode not in ("default", "custom"): + mode = "default" + api_key = _s(p.get("api_key") or p.get("pexels_api_key") or p.get("pexels_api_key")) + pexels_cfg = {"mode": mode, "api_key": api_key} + else: + mode = _s(search_media.get("mode") or search_media.get("pexels_mode") or search_media.get("pexels_mode")).lower() + if mode not in ("default", "custom"): + mode = "default" + api_key = _s(search_media.get("pexels_api_key") or search_media.get("pexels_api_key")) + pexels_cfg = {"mode": mode, "api_key": api_key} + + return custom_llm, custom_vlm, tts_cfg, pexels_cfg, None + +def is_developer_mode(cfg: Settings) -> bool: + try: + return bool(cfg.developer.developer_mode) + except Exception: + return False + +def _abs(p: str) -> str: + return os.path.abspath(os.path.expanduser(p)) + + +def resolve_media_dir(cfg_media_dir: str, session_id: str) -> str: + root = _abs(cfg_media_dir).rstrip("/\\") + if not USE_SESSION_SUBDIR: + return root + project_dir = os.path.dirname(root) + leaf = os.path.basename(root) + return os.path.join(project_dir, session_id, leaf) + + +def sanitize_filename(name: str) -> str: + name = os.path.basename(name or "") + name = name.replace("\x00", "") + return name or "unnamed" + + +def detect_media_kind(filename: str) -> str: + ext = os.path.splitext(filename)[1].lower() + if ext in {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"}: + return "image" + if ext in {".mp4", ".mov", ".avi", ".mkv", ".webm"}: + return "video" + return "unknown" + +_MEDIA_RE = re.compile(r"^media_(\d+)", re.IGNORECASE) + +def make_media_store_filename(seq: int, ext: str) -> str: + ext = (ext or "").lower() + if ext and not ext.startswith("."): + ext = "." + ext + return f"{MEDIA_PREFIX}{seq:0{MEDIA_SEQ_WIDTH}d}{ext}" + +def parse_media_seq(filename: str) -> Optional[int]: + m = _MEDIA_RE.match(os.path.basename(filename or "")) + if not m: + return None + try: + return int(m.group(1)) + except Exception: + return None + +def safe_save_path_no_overwrite(media_dir: str, filename: str) -> str: + filename = sanitize_filename(filename) + stem, ext = os.path.splitext(filename) + path = os.path.join(media_dir, filename) + if not os.path.exists(path): + return path + i = 2 + while True: + p2 = os.path.join(media_dir, f"{stem} ({i}){ext}") + if not os.path.exists(p2): + return p2 + i += 1 + + +def ensure_thumbs_dir(media_dir: str) -> str: + d = os.path.join(media_dir, ".thumbs") + os.makedirs(d, exist_ok=True) + return d + +def ensure_uploads_dir(media_dir: str) -> str: + d = os.path.join(media_dir, ".uploads") + os.makedirs(d, exist_ok=True) + return d + +def guess_media_type(path: str) -> str: + mt, _ = mimetypes.guess_type(path) + return mt or "application/octet-stream" + + +def _is_under_dir(path: str, root: str) -> bool: + try: + path = os.path.abspath(path) + root = os.path.abspath(root) + return os.path.commonpath([path, root]) == root + except Exception: + return False + + +def video_placeholder_svg_bytes() -> bytes: + svg = """ + + + + + + + + + +""" + return svg.encode("utf-8") + + +def make_image_thumbnail_sync(src_path: str, dst_path: str, max_size: Tuple[int, int] = (320, 320)) -> bool: + try: + from PIL import Image + img = Image.open(src_path).convert("RGB") + img.thumbnail(max_size) + img.save(dst_path, format="JPEG", quality=85) + return True + except Exception: + return False + +async def make_video_thumbnail_async( + src_video: str, + dst_path: str, + *, + max_size: Tuple[int, int] = (320, 320), + seek_sec: float = 0.5, + timeout_sec: float = 20.0, +) -> bool: + ffmpeg = os.environ.get("FFMPEG_BIN") or shutil.which("ffmpeg") + if not ffmpeg: + logger.warning("ffmpeg not found (PATH/FFMPEG_BIN). skip video thumbnail. src=%s", src_video) + return False + + src_video = os.path.abspath(src_video) + dst_path = os.path.abspath(dst_path) + os.makedirs(os.path.dirname(dst_path), exist_ok=True) + + tmp_path = dst_path + ".tmp.jpg" + + vf = ( + f"scale={max_size[0]}:{max_size[1]}:force_original_aspect_ratio=decrease" + f",pad={max_size[0]}:{max_size[1]}:(ow-iw)/2:(oh-ih)/2" + ) + + async def _run(args: list[str]) -> tuple[bool, str]: + proc = await asyncio.create_subprocess_exec( + *args, + stdout=asyncio.subprocess.DEVNULL, + stderr=asyncio.subprocess.PIPE, + ) + try: + _, err = await asyncio.wait_for(proc.communicate(), timeout=timeout_sec) + except asyncio.TimeoutError: + try: + proc.kill() + except Exception: + pass + await proc.wait() + return False, f"timeout after {timeout_sec}s" + err_text = (err or b"").decode("utf-8", "ignore").strip() + return (proc.returncode == 0), err_text + + # 两种策略:1) -ss 在 -i 前(快,但有些文件/关键帧会失败) + # 2) -ss 在 -i 后(慢,但更稳定) + common_tail = [ + "-an", + "-frames:v", "1", + "-vf", vf, + "-vcodec", "mjpeg", + "-q:v", "3", + "-f", "image2", + tmp_path, + ] + + attempts = [ + # fast seek + [ffmpeg, "-hide_banner", "-loglevel", "error", "-y", "-ss", f"{seek_sec}", "-i", src_video] + common_tail, + # accurate seek + [ffmpeg, "-hide_banner", "-loglevel", "error", "-y", "-i", src_video, "-ss", f"{seek_sec}"] + common_tail, + # fallback:如果 seek 太靠前导致失败,再试试 1s + [ffmpeg, "-hide_banner", "-loglevel", "error", "-y", "-ss", "1.0", "-i", src_video] + common_tail, + ] + + last_err: Optional[str] = None + try: + for args in attempts: + ok, err = await _run(args) + if ok and os.path.exists(tmp_path) and os.path.getsize(tmp_path) > 0: + os.replace(tmp_path, dst_path) + return True + last_err = err or last_err + # 清理无效临时文���,避免下次误判 + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except Exception: + pass + + logger.warning("ffmpeg thumbnail failed. src=%s dst=%s err=%s", src_video, dst_path, last_err) + return False + finally: + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except Exception: + pass + +def _env_int(name: str, default: int) -> int: + try: + return int(os.environ.get(name, str(default))) + except Exception: + return default + +def _env_float(name: str, default: float) -> float: + try: + return float(os.environ.get(name, str(default))) + except Exception: + return float(default) + +def _rpm_to_rps(rpm: float) -> float: + return float(rpm) / 60.0 + + +# 是否信任反向代理头(X-Forwarded-For / X-Real-IP) +RATE_LIMIT_TRUST_PROXY_HEADERS = os.environ.get("RATE_LIMIT_TRUST_PROXY_HEADERS", "0") == "1" + +@dataclass +class _RateBucket: + tokens: float + last_ts: float # monotonic + last_seen: float # monotonic (for TTL cleanup) + +class TokenBucketRateLimiter: + """ + 内存令牌桶 + 防爆内存: + - max_buckets: 限制内部桶表最大条目数(防止海量 IP 导致字典膨胀) + - evict_batch: 超过上限后每次驱逐多少条(按插入顺序驱逐最早创建的桶) + """ + def __init__( + self, + ttl_sec: int = 900, + cleanup_interval_sec: int = 60, + *, + max_buckets: int = 100000, + evict_batch: int = 2000, + ): + self.ttl_sec = int(ttl_sec) + self.cleanup_interval_sec = int(cleanup_interval_sec) + self.max_buckets = int(max(1, max_buckets)) + self.evict_batch = int(max(1, evict_batch)) + + self._buckets: Dict[str, _RateBucket] = {} + self._lock = asyncio.Lock() + self._last_cleanup = time.monotonic() + + async def allow( + self, + key: str, + *, + capacity: float, + refill_rate: float, + cost: float = 1.0, + ) -> Tuple[bool, float, float]: + """ + 返回: (allowed, retry_after_sec, remaining_tokens) + """ + now = time.monotonic() + capacity = float(max(0.0, capacity)) + refill_rate = float(max(0.0, refill_rate)) + cost = float(max(0.0, cost)) + + async with self._lock: + b = self._buckets.get(key) + + if b is None: + # 先做一次周期清理 + if now - self._last_cleanup > self.cleanup_interval_sec: + self._cleanup_locked(now) + self._last_cleanup = now + + # 桶表满了:先清 TTL,再做批量驱逐;仍然满 -> 不再创建新桶,直接拒绝 + if len(self._buckets) >= self.max_buckets: + self._cleanup_locked(now) + + if len(self._buckets) >= self.max_buckets: + self._evict_locked() + + if len(self._buckets) >= self.max_buckets: + # 不存任何新 key,避免内存继续涨 + # retry_after 给一个很短的值即可(客户端会重试) + return False, 1.0, 0.0 + + b = _RateBucket(tokens=capacity, last_ts=now, last_seen=now) + self._buckets[key] = b + else: + b.last_seen = now + + # refill + elapsed = max(0.0, now - b.last_ts) + if refill_rate > 0: + b.tokens = min(capacity, b.tokens + elapsed * refill_rate) + else: + b.tokens = min(capacity, b.tokens) + b.last_ts = now + + if b.tokens >= cost: + b.tokens -= cost + return True, 0.0, float(max(0.0, b.tokens)) + + # not enough + if refill_rate <= 0: + retry_after = float(self.ttl_sec) + else: + need = cost - b.tokens + retry_after = need / refill_rate + return False, float(retry_after), float(max(0.0, b.tokens)) + + def _cleanup_locked(self, now: float) -> None: + ttl = float(self.ttl_sec) + dead = [k for k, b in self._buckets.items() if (now - b.last_seen) > ttl] + for k in dead: + self._buckets.pop(k, None) + + def _evict_locked(self) -> None: + # 按 dict 插入顺序驱逐最早的一批 bucket(不排序,避免在高压下额外 CPU 开销) + n = min(self.evict_batch, len(self._buckets)) + for _ in range(n): + try: + k = next(iter(self._buckets)) + except StopIteration: + break + self._buckets.pop(k, None) + +def _headers_to_dict(scope_headers: List[Tuple[bytes, bytes]]) -> Dict[str, str]: + d: Dict[str, str] = {} + for k, v in scope_headers or []: + try: + dk = k.decode("latin1").lower() + dv = v.decode("latin1") + except Exception: + continue + d[dk] = dv + return d + +def _client_ip_from_http_scope(scope: dict, trust_proxy_headers: bool) -> str: + headers = _headers_to_dict(scope.get("headers") or []) + if trust_proxy_headers: + xff = headers.get("x-forwarded-for") + if xff: + # "client, proxy1, proxy2" -> client + return xff.split(",")[0].strip() or "unknown" + xri = headers.get("x-real-ip") + if xri: + return xri.strip() or "unknown" + + client = scope.get("client") + if client and isinstance(client, (list, tuple)) and len(client) >= 1: + return str(client[0] or "unknown") + return "unknown" + +def _client_ip_from_ws(ws: WebSocket, trust_proxy_headers: bool) -> str: + try: + if trust_proxy_headers: + xff = ws.headers.get("x-forwarded-for") + if xff: + return xff.split(",")[0].strip() or "unknown" + xri = ws.headers.get("x-real-ip") + if xri: + return xri.strip() or "unknown" + except Exception: + pass + + try: + if ws.client: + return str(ws.client.host or "unknown") + except Exception: + pass + + return "unknown" + +# 分片上传(绕开网关对单次请求体/单文件的限制) +UPLOAD_RESUMABLE_CHUNK_BYTES = _env_int("UPLOAD_RESUMABLE_CHUNK_BYTES", 8 * 1024 * 1024) + +# 未完成的分片上传状态保留多久(超时自动清理临时文件) +RESUMABLE_UPLOAD_TTL_SEC = _env_int("RESUMABLE_UPLOAD_TTL_SEC", 3600) # 1 hour + +MEDIA_SEQ_WIDTH = 4 # media_0001 +MEDIA_PREFIX = "media_" + + +# -------- 注意:在服务器上,所有用户的ip可能是相同的---- + +# 每个 IP 的总体请求速率(包括 /static、/api、/ 等) +HTTP_GLOBAL_RPM = _env_int("RATE_LIMIT_HTTP_GLOBAL_RPM", 3000) +HTTP_GLOBAL_BURST = _env_int("RATE_LIMIT_HTTP_GLOBAL_BURST", 600) + +# 创建 session:防止刷 session 导致内存爆 +HTTP_CREATE_SESSION_RPM = _env_int("RATE_LIMIT_CREATE_SESSION_RPM", 3000) +HTTP_CREATE_SESSION_BURST = _env_int("RATE_LIMIT_CREATE_SESSION_BURST", 50) + +# 上传素材:最容易被滥用(大文件 + 频率) +HTTP_UPLOAD_MEDIA_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_RPM", 12000) +HTTP_UPLOAD_MEDIA_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_BURST", 300) + +# 上传“成本”换算:content-length 每多少字节算 1 个 token(越大越费 token) +UPLOAD_COST_BYTES = _env_int("RATE_LIMIT_UPLOAD_COST_BYTES", 10 * 1024 * 1024) # 默认 10MB = 1 token + +# 素材个数控制:会话内上线+上传上限 +MAX_UPLOAD_FILES_PER_REQUEST = _env_int("MAX_UPLOAD_FILES_PER_REQUEST", 30) # 单次请求最多文件数 +MAX_MEDIA_PER_SESSION = _env_int("MAX_MEDIA_PER_SESSION", 30) # 每个 session 总素材上限(pending + 已用) +MAX_PENDING_MEDIA_PER_SESSION = _env_int("MAX_PENDING_MEDIA_PER_SESSION", 30) # 每个 session pending 素材上限(UI 友好) + +HTTP_UPLOAD_MEDIA_COUNT_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_RPM", 50000) +HTTP_UPLOAD_MEDIA_COUNT_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_BURST", 1000) + +# 下载/缩略图:适中限制(防刷资源) +HTTP_MEDIA_GET_RPM = _env_int("RATE_LIMIT_MEDIA_GET_RPM", 2400) +HTTP_MEDIA_GET_BURST = _env_int("RATE_LIMIT_MEDIA_GET_BURST", 60) + +# 清空会话:避免频繁清空扰动 +HTTP_CLEAR_RPM = _env_int("RATE_LIMIT_CLEAR_SESSION_RPM", 3000) +HTTP_CLEAR_BURST = _env_int("RATE_LIMIT_CLEAR_SESSION_BURST", 50) + +# 其它 API 默认:比 global 更细一点(可选) +HTTP_API_RPM = _env_int("RATE_LIMIT_API_RPM", 2400) +HTTP_API_BURST = _env_int("RATE_LIMIT_API_BURST", 120) + +# WebSocket:连接创建频率 +WS_CONNECT_RPM = _env_int("RATE_LIMIT_WS_CONNECT_RPM", 600) +WS_CONNECT_BURST = _env_int("RATE_LIMIT_WS_CONNECT_BURST", 50) + +# WebSocket:chat.send(真正触发 LLM 成本) +WS_CHAT_SEND_RPM = _env_int("RATE_LIMIT_WS_CHAT_SEND_RPM", 300) +WS_CHAT_SEND_BURST = _env_int("RATE_LIMIT_WS_CHAT_SEND_BURST", 20) + +# ---- 全局(所有 IP 合并)限流:抵御多 IP 同时访问 ---- +HTTP_ALL_RPM = _env_int("RATE_LIMIT_HTTP_ALL_RPM", 1200) # 全站 HTTP 总量:1200/min ~= 20 rps +HTTP_ALL_BURST = _env_int("RATE_LIMIT_HTTP_ALL_BURST", 200) + +CREATE_SESSION_ALL_RPM = _env_int("RATE_LIMIT_CREATE_SESSION_ALL_RPM", 120) +CREATE_SESSION_ALL_BURST = _env_int("RATE_LIMIT_CREATE_SESSION_ALL_BURST", 20) + +UPLOAD_MEDIA_ALL_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_ALL_RPM", 6000) +UPLOAD_MEDIA_ALL_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_ALL_BURST", 2000) + +# “素材个数”限流:默认复用 upload_media 的 rpm/burst +UPLOAD_MEDIA_COUNT_ALL_RPM = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_ALL_RPM", UPLOAD_MEDIA_ALL_RPM) +UPLOAD_MEDIA_COUNT_ALL_BURST = _env_int("RATE_LIMIT_UPLOAD_MEDIA_COUNT_ALL_BURST", UPLOAD_MEDIA_ALL_BURST) + +MEDIA_GET_ALL_RPM = _env_int("RATE_LIMIT_MEDIA_GET_ALL_RPM", 600) +MEDIA_GET_ALL_BURST = _env_int("RATE_LIMIT_MEDIA_GET_ALL_BURST", 120) + +WS_CONNECT_ALL_RPM = _env_int("RATE_LIMIT_WS_CONNECT_ALL_RPM", 60000) +WS_CONNECT_ALL_BURST = _env_int("RATE_LIMIT_WS_CONNECT_ALL_BURST", 2000) + +WS_CHAT_SEND_ALL_RPM = _env_int("RATE_LIMIT_WS_CHAT_SEND_ALL_RPM", 500) +WS_CHAT_SEND_ALL_BURST = _env_int("RATE_LIMIT_WS_CHAT_SEND_ALL_BURST", 30) + +# ---- 全局并发上限:抵御“很多 IP 同时连/同时触发 LLM/同时上传” ---- +WS_MAX_CONNECTIONS = _env_int("RATE_LIMIT_WS_MAX_CONNECTIONS", 500) # 同时在线 WS 连接数上限 +CHAT_MAX_CONCURRENCY = _env_int("RATE_LIMIT_CHAT_MAX_CONCURRENCY", 80) # 同时跑的 LLM turn 上限 +UPLOAD_MAX_CONCURRENCY = _env_int("RATE_LIMIT_UPLOAD_MAX_CONCURRENCY", 100) # 同时处理上传(含缩略图)上限 + +WS_CONN_SEM = asyncio.Semaphore(WS_MAX_CONNECTIONS) +CHAT_TURN_SEM = asyncio.Semaphore(CHAT_MAX_CONCURRENCY) +UPLOAD_SEM = asyncio.Semaphore(UPLOAD_MAX_CONCURRENCY) + +def _global_http_rule_limit(rule_name: str) -> Optional[Tuple[int, int]]: + if rule_name == "create_session": + return CREATE_SESSION_ALL_BURST, CREATE_SESSION_ALL_RPM + if rule_name == "upload_media": + return UPLOAD_MEDIA_ALL_BURST, UPLOAD_MEDIA_ALL_RPM + if rule_name == "media_get": + return MEDIA_GET_ALL_BURST, MEDIA_GET_ALL_RPM + return None + + +def _get_content_length(scope: dict) -> Optional[int]: + try: + headers = _headers_to_dict(scope.get("headers") or []) + v = headers.get("content-length") + if v is None: + return None + n = int(v) + if n < 0: + return None + return n + except Exception: + return None + +def _match_http_rule(method: str, path: str) -> Tuple[str, int, int, float]: + """ + 返回 (rule_name, burst, rpm, cost) + cost 默认为 1;上传接口会按 content-length 动态计算 cost(在 middleware 内处理)。 + """ + method = (method or "").upper() + path = path or "" + + # 精确接口优先 + if method == "POST" and path == "/api/sessions": + return ("create_session", HTTP_CREATE_SESSION_BURST, HTTP_CREATE_SESSION_RPM, 1.0) + + # 上传素材(含分片接口) + if method == "POST" and path.startswith("/api/sessions/"): + if path.endswith("/media") or path.endswith("/media/init"): + return ("upload_media", HTTP_UPLOAD_MEDIA_BURST, HTTP_UPLOAD_MEDIA_RPM, 1.0) + if "/media/" in path and (path.endswith("/chunk") or path.endswith("/complete") or path.endswith("/cancel")): + return ("upload_media", HTTP_UPLOAD_MEDIA_BURST, HTTP_UPLOAD_MEDIA_RPM, 1.0) + + if method == "GET" and path.startswith("/api/sessions/") and (path.endswith("/thumb") or path.endswith("/file")): + return ("media_get", HTTP_MEDIA_GET_BURST, HTTP_MEDIA_GET_RPM, 1.0) + + if method == "POST" and path.startswith("/api/sessions/") and path.endswith("/clear"): + return ("clear_session", HTTP_CLEAR_BURST, HTTP_CLEAR_RPM, 1.0) + + # 其它 API + if path.startswith("/api/"): + return ("api_general", HTTP_API_BURST, HTTP_API_RPM, 1.0) + + # 非 /api 的其他请求:只走 global + return ("", 0, 0, 1.0) + +class HttpRateLimitMiddleware: + """ + ASGI middleware:对 HTTP 请求做限流(WebSocket 不在这里处理)。 + """ + def __init__(self, app: Any, limiter: TokenBucketRateLimiter, trust_proxy_headers: bool = False): + self.app = app + self.limiter = limiter + self.trust_proxy_headers = bool(trust_proxy_headers) + + async def __call__(self, scope: dict, receive: Any, send: Any): + if scope.get("type") != "http": + return await self.app(scope, receive, send) + + method = scope.get("method", "GET") + path = scope.get("path", "/") + ip = _client_ip_from_http_scope(scope, self.trust_proxy_headers) + + # 0) 全局总量桶(所有 IP 合并) + ok, retry_after, _ = await self.limiter.allow( + key="http:all", + capacity=float(HTTP_ALL_BURST), + refill_rate=_rpm_to_rps(float(HTTP_ALL_RPM)), + cost=1.0, + ) + if not ok: + return await self._reject(send, retry_after) + + # 1) 单 IP 全局桶(防单点) + ok, retry_after, _ = await self.limiter.allow( + key=f"http:global:{ip}", + capacity=float(HTTP_GLOBAL_BURST), + refill_rate=_rpm_to_rps(float(HTTP_GLOBAL_RPM)), + cost=1.0, + ) + if not ok: + return await self._reject(send, retry_after) + + # 2) 规则桶 + rule_name, burst, rpm, cost = _match_http_rule(method, path) + + # 上传接口:按 content-length 动态增加 cost(越大越费 token) + if rule_name == "upload_media": + cl = _get_content_length(scope) + if cl and cl > 0 and UPLOAD_COST_BYTES > 0: + cost = max(1.0, float(math.ceil(cl / float(UPLOAD_COST_BYTES)))) + + if rule_name: + # 2.1 规则的“全局桶”(跨 IP) + g = _global_http_rule_limit(rule_name) + if g: + g_burst, g_rpm = g + okg, rag, _ = await self.limiter.allow( + key=f"http:{rule_name}:all", + capacity=float(g_burst), + refill_rate=_rpm_to_rps(float(g_rpm)), + cost=float(cost), + ) + if not okg: + return await self._reject(send, rag) + + # 2.2 规则的“单 IP 桶” + ok2, retry_after2, _ = await self.limiter.allow( + key=f"http:{rule_name}:{ip}", + capacity=float(burst), + refill_rate=_rpm_to_rps(float(rpm)), + cost=float(cost), + ) + if not ok2: + return await self._reject(send, retry_after2) + + return await self.app(scope, receive, send) + + + async def _reject(self, send: Any, retry_after: float): + ra = int(math.ceil(float(retry_after or 0.0))) + body = json.dumps( + {"detail": "Too Many Requests", "retry_after": ra}, + ensure_ascii=False + ).encode("utf-8") + + headers = [ + (b"content-type", b"application/json; charset=utf-8"), + (b"retry-after", str(ra).encode("ascii")), + ] + + await send({"type": "http.response.start", "status": 429, "headers": headers}) + await send({"type": "http.response.body", "body": body, "more_body": False}) + +RATE_LIMITER = TokenBucketRateLimiter( + ttl_sec=_env_int("RATE_LIMIT_TTL_SEC", 900), # 默认 15 分钟:多 IP 攻击时更快释放桶表 + cleanup_interval_sec=_env_int("RATE_LIMIT_CLEANUP_INTERVAL_SEC", 60), + max_buckets=_env_int("RATE_LIMIT_MAX_BUCKETS", 100000), + evict_batch=_env_int("RATE_LIMIT_EVICT_BATCH", 2000), +) + + +@dataclass +class MediaMeta: + id: str + name: str + kind: str + path: str + thumb_path: Optional[str] + ts: float + +@dataclass +class ResumableUpload: + upload_id: str + filename: str # 素材原名(用于UI展示) + store_filename: str # 落盘名 media_0001.mp4 + size: int + chunk_size: int + total_chunks: int + tmp_path: str + kind: str + created_ts: float + last_ts: float + received: Set[int] = field(default_factory=set) + closed: bool = False + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + +class MediaStore: + """ + 专注文件系统层: + - 保存上传文件(async chunk) + - 生成缩略图(图片:线程;视频:异步子进程) + - 删除文件(只删 media_dir 下的文件) + """ + def __init__(self, media_dir: str): + self.media_dir = os.path.abspath(media_dir) + os.makedirs(self.media_dir, exist_ok=True) + self.thumbs_dir = ensure_thumbs_dir(self.media_dir) + + async def save_upload(self, uf: UploadFile, *, store_filename: str, display_name: str) -> MediaMeta: + media_id = uuid.uuid4().hex[:10] + + display_name = sanitize_filename(display_name or uf.filename or "unnamed") + store_filename = sanitize_filename(store_filename) + + kind = detect_media_kind(display_name) + + save_path = os.path.join(self.media_dir, store_filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + if os.path.exists(save_path): + raise HTTPException(status_code=409, detail=f"media filename exists: {store_filename}") + + # async chunk 写盘(不一次性读入内存) + async with await anyio.open_file(save_path, "wb") as out: + while True: + chunk = await uf.read(CHUNK_SIZE) + if not chunk: + break + await out.write(chunk) + + try: + await uf.close() + except Exception: + pass + + thumb_path: Optional[str] = None + if kind in ("image", "video"): + thumb_path = os.path.join(self.thumbs_dir, f"{media_id}.jpg") + + if kind == "image": + ok = await anyio.to_thread.run_sync(make_image_thumbnail_sync, save_path, thumb_path) + else: + ok = await make_video_thumbnail_async(save_path, thumb_path) + + if not ok: + # 图片缩略图失败 -> 用原图;视频失败 -> 置空(thumb endpoint 返回占位 SVG) + thumb_path = save_path if kind == "image" else None + + return MediaMeta( + id=media_id, + name=os.path.basename(display_name), + kind=kind, + path=os.path.abspath(save_path), + thumb_path=os.path.abspath(thumb_path) if thumb_path else None, + ts=time.time(), + ) + + async def save_from_path( + self, + src_path: str, + *, + store_filename: str, + display_name: str, + ) -> MediaMeta: + """ + 将分片上传产生的临时文件移动到 media_dir 下的最终文件。 + - display_name: UI 展示名(原始文件名) + - store_filename: 落盘名(media_0001.mp4),用于记录顺序 + """ + media_id = uuid.uuid4().hex[:10] + + display_name = sanitize_filename(display_name or "unnamed") + store_filename = sanitize_filename(store_filename or "unnamed") + + kind = detect_media_kind(display_name) + + src_path = os.path.abspath(src_path) + if not os.path.exists(src_path): + raise HTTPException(status_code=400, detail="upload temp file missing") + + save_path = os.path.abspath(os.path.join(self.media_dir, store_filename)) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + if os.path.exists(save_path): + raise HTTPException(status_code=409, detail=f"media already exists: {store_filename}") + + # move tmp -> final + os.replace(src_path, save_path) + + thumb_path: Optional[str] = None + if kind in ("image", "video"): + thumb_path = os.path.join(self.thumbs_dir, f"{media_id}.jpg") + + if kind == "image": + ok = await anyio.to_thread.run_sync(make_image_thumbnail_sync, save_path, thumb_path) + else: + ok = await make_video_thumbnail_async(save_path, thumb_path) + + if not ok: + thumb_path = save_path if kind == "image" else None + + return MediaMeta( + id=media_id, + name=os.path.basename(display_name), # ★ UI 显示原文件名 + kind=kind, + path=os.path.abspath(save_path), # ★ 磁盘文件名 media_0001.ext + thumb_path=os.path.abspath(thumb_path) if thumb_path else None, + ts=time.time(), + ) + + async def delete_files(self, meta: MediaMeta) -> None: + root = self.media_dir + for p in {meta.path, meta.thumb_path}: + if not p: + continue + ap = os.path.abspath(p) + if not _is_under_dir(ap, root): + continue + if os.path.isdir(ap): + continue + if os.path.exists(ap): + try: + os.remove(ap) + except Exception: + pass + + +class ChatSession: + """ + 一个 session 的全部状态: + - agent / lc_messages(LangChain上下文) + - history(给前端回放) + - load_media / pending_media(staging) + - tool trace 索引(支持 tool 事件“就地更新”) + """ + def __init__(self, session_id: str, cfg: Settings): + self.session_id = session_id + self.cfg = cfg + self.lang = "zh" + + default_llm = _s(getattr(getattr(cfg, "developer", None), "default_llm", "")) or "deepseek-chat" + default_vlm = _s(getattr(getattr(cfg, "developer", None), "default_vlm", "")) or "qwen3-vl-8b-instruct" + + self.chat_models = [default_llm, CUSTOM_MODEL_KEY] + self.chat_model_key = default_llm + + self.vlm_models = [default_vlm, CUSTOM_MODEL_KEY] + self.vlm_model_key = default_vlm + + self.developer_mode = is_developer_mode(cfg) + + self.media_dir = resolve_media_dir(cfg.project.media_dir, session_id) + self.media_store = MediaStore(self.media_dir) + # 分片上传临时目录 + in-flight 状态 + self.uploads_dir = ensure_uploads_dir(self.media_dir) + self.resumable_uploads: Dict[str, ResumableUpload] = {} + + # 直传(multipart 多文件)时的“预占位”,避免并发竞争导致超过上限 + self._direct_upload_reservations = 0 + + self.agent: Any = None + self.node_manager = None + self.client_context = None + + # 锁分离:避免“流式输出”阻塞上传/删除 pending + self.chat_lock = asyncio.Lock() + self.media_lock = asyncio.Lock() + + self.sent_media_total: int = 0 + self._attach_stats_msg_idx = 1 + + self.lc_messages: List[BaseMessage] = [ + SystemMessage(content=get_prompt("instruction.system", lang=self.lang)), + SystemMessage(content="【User media upload status】{}"), + ] + self.history: List[Dict[str, Any]] = [] + + self.load_media: Dict[str, MediaMeta] = {} + self.pending_media_ids: List[str] = [] + + self._tool_history_index: Dict[str, int] = {} # tool_call_id -> history index + + self.cancel_event = asyncio.Event() # 打断信号 + + # 服务相关配置 + self.custom_llm_config: Optional[Dict[str, Any]] = None + self.custom_vlm_config: Optional[Dict[str, Any]] = None + self.tts_config: Dict[str, Any] = {} + self._agent_build_key: Optional[Tuple[Any, ...]] = None + + self.pexels_key_mode: str = "default" # "default" | "custom" + self.pexels_custom_key: str = "" + + self._media_seq_inited = False + self._media_seq_next = 1 + + def _ensure_system_prompt(self) -> None: + sys = (get_prompt("instruction.system", lang=self.lang) or "").strip() + if not sys: + return + + for m in self.lc_messages: + if isinstance(m, SystemMessage) and (getattr(m, "content", "") or "").strip() == sys: + return + + self.lc_messages.insert(0, SystemMessage(content=sys)) + + def _init_media_seq_locked(self) -> None: + """ + 初始化 self._media_seq_next: + - 允许 clear chat 后继续编号,不覆盖旧文件 + """ + if self._media_seq_inited: + return + + max_seq = 0 + + # 1) 已落盘文件 + try: + for fn in os.listdir(self.media_dir): + s = parse_media_seq(fn) + if s is not None: + max_seq = max(max_seq, s) + except Exception: + pass + + # 2) 内存里已有 load_media(保险) + for meta in (self.load_media or {}).values(): + s = parse_media_seq(os.path.basename(meta.path or "")) + if s is not None: + max_seq = max(max_seq, s) + + # 3) in-flight resumable(保险) + for u in (self.resumable_uploads or {}).values(): + s = parse_media_seq(getattr(u, "store_filename", "") or "") + if s is not None: + max_seq = max(max_seq, s) + + self._media_seq_next = max_seq + 1 + self._media_seq_inited = True + + + def _reserve_store_filenames_locked(self, display_filenames: List[str]) -> List[str]: + """ + 按传入顺序生成一组 store 文件名(media_0001.ext ...) + 注意:这里的“顺序”就是你要固化的上传顺序。 + """ + self._init_media_seq_locked() + + out: List[str] = [] + seq = int(self._media_seq_next) + + for disp in display_filenames: + disp = sanitize_filename(disp or "unnamed") + ext = os.path.splitext(disp)[1].lower() + + # 不复用旧号;仅在极端情况下跳过已存在文件(防撞) + while True: + store = make_media_store_filename(seq, ext) + if not os.path.exists(os.path.join(self.media_dir, store)): + break + seq += 1 + + out.append(store) + seq += 1 + + self._media_seq_next = seq + return out + + + def apply_service_config(self, service_cfg: Any) -> Tuple[bool, Optional[str]]: + llm, vlm, tts, pexels, err = _parse_service_config(service_cfg) + if err: + return False, err + + if llm is not None: + self.custom_llm_config = llm + if vlm is not None: + self.custom_vlm_config = vlm + + # tts 允许为空;非空才覆盖 + if isinstance(tts, dict) and tts: + self.tts_config = tts + + # ---- pexels ---- + if isinstance(pexels, dict) and pexels: + mode = _s(pexels.get("mode")).lower() + if mode == "custom": + self.pexels_key_mode = "custom" + self.pexels_custom_key = _s(pexels.get("api_key")) + else: + self.pexels_key_mode = "default" + self.pexels_custom_key = "" + + return True, None + + async def ensure_agent(self) -> None: + # 1) resolve LLM override + if self.chat_model_key == CUSTOM_MODEL_KEY: + if not isinstance(self.custom_llm_config, dict): + raise RuntimeError("please fill in model/base_url/api_key of custom LLM") + llm_override = self.custom_llm_config + else: + llm_override, err = _resolve_default_model_override(self.cfg, self.chat_model_key) + if err: + raise RuntimeError(err) + + # 2) resolve VLM override + if self.vlm_model_key == CUSTOM_MODEL_KEY: + if not isinstance(self.custom_vlm_config, dict): + raise RuntimeError("please fill in model/base_url/api_key of custom VLM") + vlm_override = self.custom_vlm_config + else: + vlm_override, err = _resolve_default_model_override(self.cfg, self.vlm_model_key) + if err: + raise RuntimeError(err) + + agent_build_key: Tuple[Any, ...] = ( + "models", + _stable_dict_key(llm_override), + _stable_dict_key(vlm_override), + ) + + if self.agent is None or self._agent_build_key != agent_build_key: + artifact_store = ArtifactStore(self.cfg.project.outputs_dir, session_id=self.session_id) + self.agent, self.node_manager = await build_agent( + cfg=self.cfg, + session_id=self.session_id, + store=artifact_store, + tool_interceptors=[ + ToolInterceptor.inject_media_content_before, + ToolInterceptor.save_media_content_after, + ToolInterceptor.inject_tts_config, + ToolInterceptor.inject_pexels_api_key, + ], + llm_override=llm_override, + vlm_override=vlm_override, + ) + self._agent_build_key = agent_build_key + + if self.client_context is None: + self.client_context = ClientContext( + cfg=self.cfg, + session_id=self.session_id, + media_dir=self.media_dir, + bgm_dir=self.cfg.project.bgm_dir, + outputs_dir=self.cfg.project.outputs_dir, + node_manager=self.node_manager, + chat_model_key=self.chat_model_key, + vlm_model_key=self.vlm_model_key, + tts_config=(self.tts_config or None), + pexels_api_key=None, + lang=self.lang, + ) + else: + self.client_context.chat_model_key = self.chat_model_key + self.client_context.vlm_model_key = self.vlm_model_key + self.client_context.tts_config = (self.tts_config or None) + self.client_context.lang = self.lang + + # ---- resolve pexels_api_key for runtime context ---- + pexels_api_key = "" + if (self.pexels_key_mode or "").lower() == "custom": + pexels_api_key = _s(self.pexels_custom_key) + else: + pexels_api_key = _get_default_pexels_api_key(self.cfg) # from config.toml + + self.client_context.pexels_api_key = (pexels_api_key or None) + + # ---- DTO / public mapping ---- + def public_media(self, meta: MediaMeta) -> Dict[str, Any]: + return { + "id": meta.id, + "name": meta.name, + "kind": meta.kind, + "thumb_url": f"/api/sessions/{self.session_id}/media/{meta.id}/thumb", + "file_url": f"/api/sessions/{self.session_id}/media/{meta.id}/file", + } + + def public_pending_media(self) -> List[Dict[str, Any]]: + out: List[Dict[str, Any]] = [] + for aid in self.pending_media_ids: + meta = self.load_media.get(aid) + if meta: + out.append(self.public_media(meta)) + return out + + def snapshot(self) -> Dict[str, Any]: + return { + "session_id": self.session_id, + "developer_mode": self.developer_mode, + "pending_media": self.public_pending_media(), + "history": self.history, + "limits": { + "max_upload_files_per_request": MAX_UPLOAD_FILES_PER_REQUEST, + "max_media_per_session": MAX_MEDIA_PER_SESSION, + "max_pending_media_per_session": MAX_PENDING_MEDIA_PER_SESSION, + "upload_chunk_bytes": UPLOAD_RESUMABLE_CHUNK_BYTES, + }, + "stats": { + "media_count": len(self.load_media), + "pending_count": len(self.pending_media_ids), + "inflight_uploads": len(self.resumable_uploads), + }, + "chat_model_key": self.chat_model_key, + "chat_models": self.chat_models, + "llm_model_key": self.chat_model_key, + "llm_models": self.chat_models, + "vlm_model_key": self.vlm_model_key, + "vlm_models": self.vlm_models, + "lang": self.lang, + } + + # ---- media operations ---- + def _cleanup_stale_uploads_locked(self, now: Optional[float] = None) -> None: + now = float(now or time.time()) + ttl = float(RESUMABLE_UPLOAD_TTL_SEC) + dead = [uid for uid, u in self.resumable_uploads.items() if (now - u.last_ts) > ttl] + for uid in dead: + u = self.resumable_uploads.pop(uid, None) + if not u: + continue + try: + if u.tmp_path and os.path.exists(u.tmp_path): + os.remove(u.tmp_path) + except Exception: + pass + + def _check_media_caps_locked(self, add: int = 0) -> None: + add = int(max(0, add)) + total = len(self.load_media) + len(self.resumable_uploads) + int(self._direct_upload_reservations) + pending = len(self.pending_media_ids) + len(self.resumable_uploads) + int(self._direct_upload_reservations) + + if MAX_MEDIA_PER_SESSION > 0 and (total + add) > MAX_MEDIA_PER_SESSION: + raise HTTPException( + status_code=400, + detail=f"会话素材总数已达上限:{total}/{MAX_MEDIA_PER_SESSION}", + ) + + if MAX_PENDING_MEDIA_PER_SESSION > 0 and (pending + add) > MAX_PENDING_MEDIA_PER_SESSION: + raise HTTPException( + status_code=400, + detail=f"待发送素材数量已达上限:{pending}/{MAX_PENDING_MEDIA_PER_SESSION}", + ) + + async def add_uploads(self, files: List[UploadFile], store_filenames: List[str]) -> List[MediaMeta]: + if len(store_filenames) != len(files): + raise HTTPException(status_code=500, detail="store_filenames mismatch") + + metas: List[MediaMeta] = [] + for uf, store_fn in zip(files, store_filenames): + display_name = sanitize_filename(uf.filename or "unnamed") + metas.append(await self.media_store.save_upload( + uf, + store_filename=store_fn, + display_name=display_name, + )) + + async with self.media_lock: + for m in metas: + self.load_media[m.id] = m + self.pending_media_ids.append(m.id) + + self.pending_media_ids.sort( + key=lambda aid: os.path.basename(self.load_media[aid].path or "") + if aid in self.load_media else "" + ) + + return metas + + async def delete_pending_media(self, media_id: str) -> None: + async with self.media_lock: + if media_id not in self.pending_media_ids: + raise HTTPException(status_code=400, detail="media is not pending (refuse physical delete)") + self.pending_media_ids = [x for x in self.pending_media_ids if x != media_id] + meta = self.load_media.pop(media_id, None) + + if meta: + await self.media_store.delete_files(meta) + + async def take_pending_media_for_message(self, attachment_ids: Optional[List[str]]) -> List[MediaMeta]: + async with self.media_lock: + if attachment_ids: + pick = [aid for aid in attachment_ids if aid in self.pending_media_ids] + else: + pick = list(self.pending_media_ids) + + pick_set = set(pick) + self.pending_media_ids = [aid for aid in self.pending_media_ids if aid not in pick_set] + metas = [self.load_media[aid] for aid in pick if aid in self.load_media] + return metas + + # ---- tool trace handling ---- + def _ensure_tool_record(self, tcid: str, server: str, name: str, args: Any) -> Dict[str, Any]: + idx = self._tool_history_index.get(tcid) + if idx is None: + rec = { + "id": f"tool_{tcid}", + "role": "tool", + "tool_call_id": tcid, + "server": server, + "name": name, + "args": args, + "state": "running", + "progress": 0.0, + "message": "", + "summary": None, + "ts": time.time(), + } + self.history.append(rec) + self._tool_history_index[tcid] = len(self.history) - 1 + return rec + return self.history[idx] + + def apply_tool_event(self, raw: Dict[str, Any]) -> Optional[Dict[str, Any]]: + et = raw.get("type") + tcid = raw.get("tool_call_id") + if et not in ("tool_start", "tool_progress", "tool_end") or not tcid: + return None + + server = raw.get("server") or "" + name = raw.get("name") or "" + args = raw.get("args") or {} + + rec = self._ensure_tool_record(tcid, server, name, args) + + if et == "tool_start": + rec.update({ + "server": server, + "name": name, + "args": args, + "state": "running", + "progress": 0.0, + "message": "Starting...", + "summary": None, + }) + + elif et == "tool_progress": + progress = float(raw.get("progress", 0.0)) + total = raw.get("total") + if total and float(total) > 0: + p = progress / float(total) + else: + p = progress / 100.0 if progress > 1 else progress + p = max(0.0, min(1.0, p)) + rec.update({ + "state": "running", + "progress": p, + "message": raw.get("message") or "", + }) + + elif et == "tool_end": + is_error = bool(raw.get("is_error")) + + summary = raw.get("summary") + try: + json.dumps(summary, ensure_ascii=False) + except Exception: + summary = str(summary) if summary is not None else None + rec.update({ + "state": "error" if is_error else "complete", + "progress": 1.0, + "summary": summary, + "message": raw.get("message") or rec.get("message") or "", + }) + + return rec + + +class SessionStore: + def __init__(self, cfg: Settings): + self.cfg = cfg + self._lock = asyncio.Lock() + self._sessions: Dict[str, ChatSession] = {} + + async def create(self) -> ChatSession: + sid = uuid.uuid4().hex + sess = ChatSession(sid, self.cfg) + async with self._lock: + self._sessions[sid] = sess + return sess + + async def get(self, sid: str) -> Optional[ChatSession]: + async with self._lock: + return self._sessions.get(sid) + + async def get_or_404(self, sid: str) -> ChatSession: + sess = await self.get(sid) + if not sess: + raise HTTPException(status_code=404, detail="session not found") + return sess + + +@asynccontextmanager +async def lifespan(app: FastAPI): + cfg = load_settings(default_config_path()) + app.state.cfg = cfg + app.state.developer_mode = is_developer_mode(cfg) + app.state.sessions = SessionStore(cfg) + yield + + +app = FastAPI(title="OpenStoryline Web", version="1.0.0", lifespan=lifespan) + +app.add_middleware( + HttpRateLimitMiddleware, + limiter=RATE_LIMITER, + trust_proxy_headers=RATE_LIMIT_TRUST_PROXY_HEADERS, +) + +if os.path.isdir(STATIC_DIR): + app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") + +if os.path.isdir(NODE_MAP_DIR): + app.mount("/node_map", StaticFiles(directory=NODE_MAP_DIR), name="node_map") + +api = APIRouter(prefix="/api") + +def _rate_limit_reject_json(retry_after: float) -> JSONResponse: + ra = int(math.ceil(float(retry_after or 0.0))) + return JSONResponse( + {"detail": "Too Many Requests", "retry_after": ra}, + status_code=429, + headers={"Retry-After": str(ra)}, + ) + +async def _enforce_upload_media_count_limit(request: Request, cost: float) -> Optional[JSONResponse]: + ip = _client_ip_from_http_scope(request.scope, RATE_LIMIT_TRUST_PROXY_HEADERS) + cost = float(max(0.0, cost)) + + ok, ra, _ = await RATE_LIMITER.allow( + key="http:upload_media_count:all", + capacity=float(UPLOAD_MEDIA_COUNT_ALL_BURST), + refill_rate=_rpm_to_rps(float(UPLOAD_MEDIA_COUNT_ALL_RPM)), + cost=cost, + ) + if not ok: + return _rate_limit_reject_json(ra) + + ok2, ra2, _ = await RATE_LIMITER.allow( + key=f"http:upload_media_count:{ip}", + capacity=float(HTTP_UPLOAD_MEDIA_COUNT_BURST), + refill_rate=_rpm_to_rps(float(HTTP_UPLOAD_MEDIA_COUNT_RPM)), + cost=cost, + ) + if not ok2: + return _rate_limit_reject_json(ra2) + + return None + +_TTS_UI_SECRET_KEYS = { + "api_key", + "access_token", + "authorization", + "token", + "password", + "secret", + "x-api-key", + "apikey", + "access_key", + "accesskey", +} + +def _is_secret_field_name(k: str) -> bool: + if str(k or "").strip().lower() in _TTS_UI_SECRET_KEYS: + return True + return False + +def _read_config_toml(path: str) -> dict: + if tomllib is None: + return {} + try: + p = Path(path) + with p.open("rb") as f: + return tomllib.load(f) or {} + except Exception: + return {} + +def _get_default_pexels_api_key(cfg: Settings) -> str: + # 1) try Settings.search_media.pexels_api_key + try: + search_media = getattr(cfg, "search_media", None) + pexels_api_key = _s(getattr(search_media, "pexels_api_key", None) if search_media else None) + if pexels_api_key: + return pexels_api_key + else: + return "" + except Exception: + return "" + +def _normalize_field_item(item) -> dict | None: + """ + item 支持: + - "uid" + - { key="uid", label="UID", required=true, secret=false, placeholder="..." } + """ + if isinstance(item, str): + key = item.strip() + if not key: + return None + return { + "key": key, + "secret": _is_secret_field_name(key), + } + return None + +def _build_provider_schema(provider: str, label: str | None, fields: list[dict]) -> dict: + seen = set() + out = [] + for f in fields: + k = str(f.get("key") or "").strip() + if not k or k in seen: + continue + seen.add(k) + out.append({ + "key": k, + "label": f.get("label") or k, + "placeholder": f.get("placeholder") or f.get("label") or k, + "required": bool(f.get("required", False)), + "secret": bool(f.get("secret", False)), + }) + return {"provider": provider, "label": label or provider, "fields": out} + +def _build_tts_ui_schema_from_config(config_path: str) -> dict: + """ + 返回: + { + "providers": [ + {"provider":"bytedance","label":"字节跳动","fields":[{"key":"uid",...}, ...]}, + ... + ] + } + """ + cfg = _read_config_toml(config_path) + tts = cfg.get("generate_voiceover", {}) + + providers_out: list[dict] = [] + + # 格式:[tts.providers.] + providers = tts.get("providers") + if isinstance(providers, dict): + for provider, provider_cfg in providers.items(): + fields: list[dict] = [] + label = str(provider_cfg.get("label") or provider_cfg.get("name") or provider) + for key in provider_cfg.keys(): + f = _normalize_field_item(str(key)) + if f: + fields.append(f) + + providers_out.append(_build_provider_schema(provider, label, fields)) + + return {"providers": providers_out} + +@app.get("/") +async def index(): + if not os.path.exists(INDEX_HTML): + return Response("index.html not found. Put it under ./web/index.html", media_type="text/plain", status_code=404) + return FileResponse(INDEX_HTML, media_type="text/html") + +@app.get("/node-map") +async def node_map(): + if not os.path.exists(NODE_MAP_HTML): + return Response( + "node_map.html not found. Put it under ./web/node_map/node_map.html", + media_type="text/plain", + status_code=404, + ) + return FileResponse(NODE_MAP_HTML, media_type="text/html") + +@api.get("/meta/tts") +async def get_tts_ui_schema(): + schema = _build_tts_ui_schema_from_config(default_config_path()) + return JSONResponse(schema) + +# ------------------------- +# Sessions (REST) +# ------------------------- +@api.post("/sessions") +async def create_session(): + store: SessionStore = app.state.sessions + sess = await store.create() + return JSONResponse(sess.snapshot()) + + +@api.get("/sessions/{session_id}") +async def get_session(session_id: str): + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + return JSONResponse(sess.snapshot()) + + +@api.post("/sessions/{session_id}/clear") +async def clear_session_chat(session_id: str): + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + async with sess.chat_lock: + sess.sent_media_total = 0 + sess._attach_stats_msg_idx = 1 + sess.lc_messages = [ + SystemMessage(content=get_prompt("instruction.system", lang=sess.lang)), + SystemMessage(content="【User media upload status】{}"), + ] + sess._attach_stats_msg_idx = 1 + + sess.history = [] + sess._tool_history_index = {} + return JSONResponse({"ok": True}) + +@api.post("/sessions/{session_id}/cancel") +async def cancel_session_turn(session_id: str): + """ + 打断当前正在进行的 LLM turn(流式回复/工具调用)。 + - 不清空 history / lc_messages + - 仅设置 cancel_event,由 WS 侧在流式循环中感知并安全收尾 + """ + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + sess.cancel_event.set() + return JSONResponse({"ok": True}) + +# ------------------------- +# media (REST, session-scoped) +# ------------------------- +@api.post("/sessions/{session_id}/media") +async def upload_media(session_id: str, request: Request, files: List[UploadFile] = File(...)): + if not isinstance(files, list) or not files: + raise HTTPException(status_code=400, detail="no files") + + if MAX_UPLOAD_FILES_PER_REQUEST > 0 and len(files) > MAX_UPLOAD_FILES_PER_REQUEST: + raise HTTPException(status_code=400, detail=f"单次上传最多 {MAX_UPLOAD_FILES_PER_REQUEST} 个文件") + + # 按素材个数限流(cost = 文件数) + rej = await _enforce_upload_media_count_limit(request, cost=float(len(files))) + if rej: + return rej + + if UPLOAD_SEM.locked(): + raise HTTPException(status_code=429, detail="上传并发过高,请稍后重试") + await UPLOAD_SEM.acquire() + + n = len(files) + try: + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + + # session cap 检查 + 预占位(避免并发竞争) + async with sess.media_lock: + sess._cleanup_stale_uploads_locked() + sess._check_media_caps_locked(add=n) + sess._direct_upload_reservations += n + + display_names = [sanitize_filename(uf.filename or "unnamed") for uf in files] + store_filenames = sess._reserve_store_filenames_locked(display_names) + + try: + metas = await sess.add_uploads(files, store_filenames=store_filenames) + + finally: + async with sess.media_lock: + sess._direct_upload_reservations = max(0, sess._direct_upload_reservations - n) + + return JSONResponse({ + "media": [sess.public_media(m) for m in metas], + "pending_media": sess.public_pending_media(), + }) + finally: + try: + UPLOAD_SEM.release() + except Exception: + pass + +@api.post("/sessions/{session_id}/media/init") +async def init_resumable_media_upload(session_id: str, request: Request): + try: + data = await request.json() + if not isinstance(data, dict): + data = {} + except Exception: + data = {} + + filename = sanitize_filename((data.get("filename") or data.get("name") or "unnamed")) + size = int(data.get("size") or 0) + if size <= 0: + raise HTTPException(status_code=400, detail="invalid size") + + # 按素材个数限流:init 视为“新增 1 个素材” + rej = await _enforce_upload_media_count_limit(request, cost=1.0) + if rej: + return rej + + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + + async with sess.media_lock: + sess._cleanup_stale_uploads_locked() + sess._check_media_caps_locked(add=1) + + store_filename = sess._reserve_store_filenames_locked([filename])[0] + + upload_id = uuid.uuid4().hex + chunk_size = int(max(1, UPLOAD_RESUMABLE_CHUNK_BYTES)) + total_chunks = int(math.ceil(size / float(chunk_size))) + + tmp_path = os.path.join(sess.uploads_dir, f"{upload_id}.part") + os.makedirs(os.path.dirname(tmp_path), exist_ok=True) + try: + with open(tmp_path, "wb"): + pass + except Exception as e: + raise HTTPException(status_code=500, detail=f"cannot create temp file: {e}") + + u = ResumableUpload( + upload_id=upload_id, + filename=filename, + store_filename=store_filename, + size=size, + chunk_size=chunk_size, + total_chunks=total_chunks, + tmp_path=os.path.abspath(tmp_path), + kind=detect_media_kind(filename), + created_ts=time.time(), + last_ts=time.time(), + ) + sess.resumable_uploads[upload_id] = u + + return JSONResponse({ + "upload_id": upload_id, + "chunk_size": chunk_size, + "total_chunks": total_chunks, + "filename": filename, + }) + + +@api.post("/sessions/{session_id}/media/{upload_id}/chunk") +async def upload_resumable_media_chunk( + session_id: str, + upload_id: str, + index: int = Form(...), + chunk: UploadFile = File(...), +): + if UPLOAD_SEM.locked(): + raise HTTPException(status_code=429, detail="上传并发过高,请稍后重试") + await UPLOAD_SEM.acquire() + try: + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + + async with sess.media_lock: + sess._cleanup_stale_uploads_locked() + u = sess.resumable_uploads.get(upload_id) + + if not u: + raise HTTPException(status_code=404, detail="upload_id not found or expired") + + idx = int(index) + if idx < 0 or idx >= u.total_chunks: + raise HTTPException(status_code=400, detail="invalid chunk index") + + # 期望长度(最后一片可能小于 chunk_size) + expected_len = u.size - idx * u.chunk_size + if expected_len <= 0: + raise HTTPException(status_code=400, detail="invalid chunk index") + expected_len = min(u.chunk_size, expected_len) + + written = 0 + async with u.lock: + if u.closed: + raise HTTPException(status_code=400, detail="upload already closed") + + async with await anyio.open_file(u.tmp_path, "r+b") as out: + await out.seek(idx * u.chunk_size) + while True: + buf = await chunk.read(CHUNK_SIZE) + if not buf: + break + written += len(buf) + if written > expected_len: + raise HTTPException(status_code=400, detail="chunk too large") + await out.write(buf) + + try: + await chunk.close() + except Exception: + pass + + if written != expected_len: + raise HTTPException(status_code=400, detail=f"chunk size mismatch: {written} != {expected_len}") + + u.received.add(idx) + u.last_ts = time.time() + + return JSONResponse({ + "ok": True, + "received_chunks": len(u.received), + "total_chunks": u.total_chunks, + }) + finally: + try: + UPLOAD_SEM.release() + except Exception: + pass + + +@api.post("/sessions/{session_id}/media/{upload_id}/complete") +async def complete_resumable_media_upload(session_id: str, upload_id: str): + if UPLOAD_SEM.locked(): + raise HTTPException(status_code=429, detail="上传并发过高,请稍后重试") + await UPLOAD_SEM.acquire() + try: + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + + async with sess.media_lock: + sess._cleanup_stale_uploads_locked() + u = sess.resumable_uploads.get(upload_id) + + if not u: + raise HTTPException(status_code=404, detail="upload_id not found or expired") + + # 锁住此 upload,防止 chunk 并发写 + async with u.lock: + u.closed = True + if len(u.received) != u.total_chunks: + missing = u.total_chunks - len(u.received) + raise HTTPException(status_code=400, detail=f"chunks missing: {missing}") + + # 从索引移除(释放会话额度) + async with sess.media_lock: + u2 = sess.resumable_uploads.pop(upload_id, None) + + if not u2: + raise HTTPException(status_code=404, detail="upload_id not found") + + meta = await sess.media_store.save_from_path( + u2.tmp_path, + store_filename=u2.store_filename, + display_name=u2.filename, + ) + + async with sess.media_lock: + sess.load_media[meta.id] = meta + sess.pending_media_ids.append(meta.id) + + return JSONResponse({ + "media": sess.public_media(meta), + "pending_media": sess.public_pending_media(), + }) + finally: + try: + UPLOAD_SEM.release() + except Exception: + pass + + +@api.post("/sessions/{session_id}/media/{upload_id}/cancel") +async def cancel_resumable_media_upload(session_id: str, upload_id: str): + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + + async with sess.media_lock: + u = sess.resumable_uploads.pop(upload_id, None) + + if not u: + return JSONResponse({"ok": True}) + + async with u.lock: + u.closed = True + try: + if u.tmp_path and os.path.exists(u.tmp_path): + os.remove(u.tmp_path) + except Exception: + pass + + return JSONResponse({"ok": True}) + +@api.get("/sessions/{session_id}/media/pending") +async def get_pending_media(session_id: str): + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + return JSONResponse({"pending_media": sess.public_pending_media()}) + + +@api.delete("/sessions/{session_id}/media/pending/{media_id}") +async def delete_pending_media(session_id: str, media_id: str): + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + await sess.delete_pending_media(media_id) + return JSONResponse({"ok": True, "pending_media": sess.public_pending_media()}) + + +@api.get("/sessions/{session_id}/media/{media_id}/thumb") +async def get_media_thumb(session_id: str, media_id: str): + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + + meta = sess.load_media.get(media_id) + if not meta: + raise HTTPException(status_code=404, detail="media not found") + + # thumb 存在优先 + if meta.thumb_path and os.path.exists(meta.thumb_path): + return FileResponse(meta.thumb_path, media_type="image/jpeg") + + # video 无 thumb => placeholder + if meta.kind == "video": + return Response(content=video_placeholder_svg_bytes(), media_type="image/svg+xml") + + # image thumb 失败 => 用原图 + if meta.path and os.path.exists(meta.path): + return FileResponse(meta.path, media_type=guess_media_type(meta.path)) + + raise HTTPException(status_code=404, detail="thumb not available") + + +@api.get("/sessions/{session_id}/media/{media_id}/file") +async def get_media_file(session_id: str, media_id: str): + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + + meta = sess.load_media.get(media_id) + if not meta: + raise HTTPException(status_code=404, detail="media not found") + if not meta.path or (not os.path.exists(meta.path)): + raise HTTPException(status_code=404, detail="file not found") + + # 安全:只允许 media_dir 下 + if not _is_under_dir(meta.path, sess.media_store.media_dir): + raise HTTPException(status_code=403, detail="forbidden") + + return FileResponse( + meta.path, + media_type=guess_media_type(meta.path), + filename=meta.name, + ) + +@api.get("/sessions/{session_id}/preview") +async def preview_local_file(session_id: str, path: str): + """ + 把 summary.preview_urls 里的“服务器本地路径”安全地转成可访问 URL。 + 只允许访问:media_dir / outputs_dir / outputs_dir / bgm_dir / .server_cache 这些根目录下的文件。 + """ + store: SessionStore = app.state.sessions + sess = await store.get_or_404(session_id) + + p = (path or "").strip() + if not p: + raise HTTPException(status_code=400, detail="empty path") + if "\x00" in p: + raise HTTPException(status_code=400, detail="bad path") + + # 兼容 file:// 前缀(如果未来有) + if p.startswith("file://"): + p = p[len("file://"):] + + # 相对路径:默认相对 ROOT_DIR + if os.path.isabs(p): + ap = os.path.abspath(p) + else: + ap = os.path.abspath(os.path.join(ROOT_DIR, p)) + + allowed_roots = [ + os.path.abspath(sess.media_dir), + os.path.abspath(app.state.cfg.project.outputs_dir), + os.path.abspath(app.state.cfg.project.outputs_dir), + os.path.abspath(app.state.cfg.project.bgm_dir), + os.path.abspath(SERVER_CACHE_DIR), + ] + + if not any(_is_under_dir(ap, r) for r in allowed_roots): + raise HTTPException(status_code=403, detail="forbidden") + + if (not os.path.exists(ap)) or os.path.isdir(ap): + raise HTTPException(status_code=404, detail="file not found") + + # 对 cache 文件强缓存 + headers = {"Cache-Control": "public, max-age=31536000, immutable"} if _is_under_dir(ap, SERVER_CACHE_DIR) else None + + return FileResponse( + ap, + media_type=guess_media_type(ap), + filename=os.path.basename(ap), + headers=headers, + ) + +app.include_router(api) + + +# ------------------------- +# WebSocket: session-scoped chat stream +# ------------------------- +def extract_text_delta(msg_chunk: Any) -> str: + # 兼容 content_blocks (qwen3 常见) + blocks = getattr(msg_chunk, "content_blocks", None) or [] + if blocks: + out = "" + for b in blocks: + if isinstance(b, dict) and b.get("type") == "text": + out += b.get("text", "") + return out + c = getattr(msg_chunk, "content", "") + return c if isinstance(c, str) else "" + + +async def ws_send(ws: WebSocket, type_: str, data: Any = None): + if getattr(ws, "client_state", None) != WebSocketState.CONNECTED: + return False + try: + await ws.send_json({"type": type_, "data": data}) + return True + except WebSocketDisconnect: + return False + except RuntimeError: + return False + except Exception as e: + if ClientDisconnected is not None and isinstance(e, ClientDisconnected): + return False + logger.exception("ws_send failed: type=%s err=%r", type_, e) + return False + +@asynccontextmanager +async def mcp_sink_context(sink_func): + token = set_mcp_log_sink(sink_func) + try: + yield + finally: + reset_mcp_log_sink(token) + + +@app.websocket("/ws/sessions/{session_id}/chat") +async def ws_chat(ws: WebSocket, session_id: str): + client_ip = _client_ip_from_ws(ws, RATE_LIMIT_TRUST_PROXY_HEADERS) + + ok, retry_after, _ = await RATE_LIMITER.allow( + key=f"ws:connect:{client_ip}", + capacity=float(WS_CONNECT_BURST), + refill_rate=_rpm_to_rps(float(WS_CONNECT_RPM)), + cost=1.0, + ) + if not ok: + try: + await ws.close(code=1013, reason=f"rate_limited, retry after {int(math.ceil(retry_after))}s") + except Exception: + debug_traceback_print(app.state.cfg) + pass + return + + if WS_CONN_SEM.locked(): + try: + await ws.close(code=1013, reason="Server busy (websocket connections limit)") + except Exception: + debug_traceback_print(app.state.cfg) + pass + return + + await WS_CONN_SEM.acquire() + + try: + await ws.accept() + + store: SessionStore = app.state.sessions + sess = await store.get(session_id) + if not sess: + await ws.close(code=4404, reason="session not found") + return + sess = await store.get_or_404(session_id) + + await ws_send(ws, "session.snapshot", sess.snapshot()) + + try: + while True: + req = await ws.receive_json() + if not isinstance(req, dict): + continue + + t = req.get("type") + if t == "ping": + await ws_send(ws, "pong", {"ts": time.time()}) + continue + + if t == "session.set_lang": + data = (req.get("data") or {}) + lang = (data.get("lang") or "").strip().lower() + if lang not in ("zh", "en"): + lang = "zh" + + sess.lang = lang + if sess.client_context: + sess.client_context.lang = lang + + await ws_send(ws, "session.lang", {"lang": lang}) + continue + + if t == "chat.clear": + async with sess.chat_lock: + sess.sent_media_total = 0 + sess._attach_stats_msg_idx = 1 + sess.lc_messages = [ + SystemMessage(content=get_prompt("instruction.system", lang=sess.lang)), + SystemMessage(content="【User media upload status】{}"), + ] + sess._attach_stats_msg_idx = 1 + sess.history = [] + sess._tool_history_index = {} + await ws_send(ws, "chat.cleared", {"ok": True}) + continue + + if t != "chat.send": + await ws_send(ws, "error", {"message": f"unknown type: {t}"}) + continue + + # ---- WebSocket message rate limit: only limit expensive "chat.send" ---- + if sess.chat_lock.locked(): + await ws_send(ws, "error", {"message": "上一条消息尚未完成,请稍后再发送"}) + continue + + ok, retry_after, _ = await RATE_LIMITER.allow( + key="ws:chat_send:all", + capacity=float(WS_CHAT_SEND_ALL_BURST), + refill_rate=_rpm_to_rps(float(WS_CHAT_SEND_ALL_RPM)), + cost=1.0, + ) + if not ok: + await ws_send(ws, "error", { + "message": f"触发全局限流:请 {int(math.ceil(retry_after))} 秒后再试", + "retry_after": int(math.ceil(retry_after)), + }) + continue + + ok, retry_after, _ = await RATE_LIMITER.allow( + key=f"ws:chat_send:{client_ip}", + capacity=float(WS_CHAT_SEND_BURST), + refill_rate=_rpm_to_rps(float(WS_CHAT_SEND_RPM)), + cost=1.0, + ) + if not ok: + await ws_send(ws, "error", { + "message": f"触发限流:请 {int(math.ceil(retry_after))} 秒后再试", + "retry_after": int(math.ceil(retry_after)), + }) + continue + + if CHAT_TURN_SEM.locked(): + await ws_send(ws, "error", {"message": "服务器繁忙(模型并发已满),请稍后再试"}) + continue + + await CHAT_TURN_SEM.acquire() + try: + # 再次确认(期间有 await,锁状态可能变化) + if sess.chat_lock.locked(): + await ws_send(ws, "error", {"message": "上一条消息尚未完成,请稍后再发送"}) + continue + + data = (req.get("data", {}) or {}) + + prompt = data.get("text", "") + prompt = (prompt or "").strip() + if not prompt: + continue + + requested_llm = data.get("llm_model") + requested_vlm = data.get("vlm_model") + + attachment_ids = data.get("attachment_ids") + if not isinstance(attachment_ids, list): + attachment_ids = None + + async with sess.chat_lock: + # 新 turn 开始:清掉上一次残留的 cancel 信号 + sess.cancel_event.clear() + # 0.0) 应用 service_config(自定义模型 / TTS) + ok_cfg, err_cfg = sess.apply_service_config(data.get("service_config")) + if not ok_cfg: + await ws_send(ws, "error", {"message": err_cfg or "service_config invalid"}) + continue + + # 0) 如果前端传了 model,则更新会话当前对话模型 + if isinstance(requested_llm, str): + m = requested_llm.strip() + if m: + sess.chat_model_key = m + if sess.client_context: + sess.client_context.chat_model_key = m + + if isinstance(requested_vlm, str): + m2 = requested_vlm.strip() + if m2: + sess.vlm_model_key = m2 + if sess.client_context: + sess.client_context.vlm_model_key = m2 + + requested_lang = data.get("lang") + if isinstance(requested_lang, str): + lang = requested_lang.strip().lower() + if lang in ("zh", "en"): + sess.lang = lang + # 0.1) 可能需要重建 agent(比如切换到 __custom__ 或者自定义配置变化) + try: + await sess.ensure_agent() + except Exception as e: + await ws_send(ws, "error", {"message": f"{type(e).__name__}: {e}"}) + continue + + sess._ensure_system_prompt() + + if sess.client_context: + sess.client_context.lang = sess.lang + + # 1) 从 pending 里拿本次要发送的附件 + attachments = await sess.take_pending_media_for_message(attachment_ids) + attachments_public = [sess.public_media(m) for m in attachments] + + # 统计本轮和累计发送了几个素材 + turn_attached_count = len(attachments) + sess.sent_media_total = int(getattr(sess, "sent_media_total", 0)) + turn_attached_count + + stats = { + "Number of media carried in this message sent by the user": turn_attached_count, + "Total number of media sent by the user in all conversations": sess.sent_media_total, + "Total number of media in user's media library": scan_media_dir(resolve_media_dir(app.state.cfg.project.media_dir, session_id=session_id)), + } + + idx = int(getattr(sess, "_attach_stats_msg_idx", 1)) + if len(sess.lc_messages) <= idx: + while len(sess.lc_messages) <= idx: + sess.lc_messages.append(SystemMessage(content="")) + + sess.lc_messages[idx] = SystemMessage( + content="【User media upload status】The following fields are used to determine the nature of the media provided by the user: \n" + + json.dumps(stats, ensure_ascii=False) + ) + + + # 2.1 写入 history + lc context + user_msg = { + "id": uuid.uuid4().hex[:12], + "role": "user", + "content": prompt, + "attachments": attachments_public, + "ts": time.time(), + } + sess.history.append(user_msg) + sess.lc_messages.append(HumanMessage(content=prompt)) + + # if app.state.cfg.developer.developer_mode: + # print("[LLM_CTX]", session_id, sess.lc_messages) + + # 2.2 ack:让前端更新 pending + 插入 user 消息(前端也可本地先插入) + await ws_send(ws, "chat.user", { + "text": prompt, + "attachments": attachments_public, + "pending_media": sess.public_pending_media(), + "llm_model_key": sess.chat_model_key, + "vlm_model_key": sess.vlm_model_key, + }) + + # 2.3 建立“单通道事件队列”,确保 ws.send_json 不会并发冲突 + loop = asyncio.get_running_loop() + out_q: asyncio.Queue[Tuple[str, Any]] = asyncio.Queue() + + def sink(ev: Any): + # MCP interceptor 可能 emit 非 dict;这里只收 dict + if isinstance(ev, dict): + loop.call_soon_threadsafe(out_q.put_nowait, ("mcp", ev)) + + new_messages: List[BaseMessage] = [] + + async def pump_agent(): + nonlocal new_messages + try: + stream = sess.agent.astream( + {"messages": sess.lc_messages}, + context=sess.client_context, + stream_mode=["messages", "updates"], + ) + async for mode, chunk in stream: + if mode == "messages": + msg_chunk, meta = chunk + if meta.get("langgraph_node") == "model": + delta = extract_text_delta(msg_chunk) + if delta: + await out_q.put(("assistant.delta", delta)) + + elif mode == "updates": + if isinstance(chunk, dict): + for _step, data in chunk.items(): + msgs = (data or {}).get("messages") or [] + new_messages.extend(msgs) + + await out_q.put(("agent.done", None)) + except asyncio.CancelledError: + # 被用户打断 / 连接关闭导致的取消,不属于“真正异常” + # 不要发 agent.error;给主循环一个 cancelled 信号即可 + try: + out_q.put_nowait(("agent.cancelled", None)) + except Exception: + debug_traceback_print(app.state.cfg) + pass + raise # 让任务保持 cancelled 状态,finally 里 await 时会抛 CancelledError + + except Exception as e: + # 关键:异常也要让主循环“可结束”,否则 UI 卡死 + await out_q.put(("agent.error", f"{type(e).__name__}: {e}")) + + + async def safe_send(type_: str, data: Any = None) -> bool: + try: + await ws_send(ws, type_, data) + return True + except WebSocketDisconnect: + return False + except RuntimeError as e: + # starlette: Cannot call "send" once a close message has been sent. + if 'Cannot call "send" once a close message has been sent.' in str(e): + return False + raise + except Exception as e: + # uvicorn: ClientDisconnected(不同版本类路径不稳定,用类名兜底) + if e.__class__.__name__ == "ClientDisconnected": + return False + raise + # turn 开始(前端可禁用发送按钮/显示占位) + if not await ws_send(ws, "assistant.start", {}): + return + + # 当前 assistant 分段缓冲:用于在 tool_start 到来前“封口” + seg_text = "" + seg_ts: Optional[float] = None + + async def flush_segment(send_flush_event: bool): + """ + - send_flush_event=True:告诉前端立刻结束当前 assistant 气泡(不结束整个 turn) + - 若 seg_text 有内容:写入 history(用于刷新/回放) + """ + nonlocal seg_text, seg_ts + + if send_flush_event: + if not await ws_send(ws, "assistant.flush", {}): + return + + text = (seg_text or "").strip() + if text: + sess.history.append({ + "id": uuid.uuid4().hex[:12], + "role": "assistant", + "content": text, + "ts": seg_ts or time.time(), + }) + + seg_text = "" + seg_ts = None + + pump_task: Optional[asyncio.Task] = None + + # helper: 从 AIMessage 提取 tool_call_id(兼容不同 provider 的结构) + def _tool_call_ids_from_ai_message(m: BaseMessage) -> set[str]: + ids: set[str] = set() + + tc = getattr(m, "tool_calls", None) or [] + for c in tc: + _id = None + if isinstance(c, dict): + _id = c.get("id") or c.get("tool_call_id") + else: + _id = getattr(c, "id", None) or getattr(c, "tool_call_id", None) + if _id: + ids.add(str(_id)) + + ak = getattr(m, "additional_kwargs", None) or {} + tc2 = ak.get("tool_calls") or [] + for c in tc2: + if isinstance(c, dict): + _id = c.get("id") or c.get("tool_call_id") + if _id: + ids.add(str(_id)) + + return ids + + # helper: new_messages 里有哪些 tool_call_id + def _tool_call_ids_in_msgs(msgs: List[BaseMessage]) -> set[str]: + ids: set[str] = set() + for m in msgs: + if isinstance(m, AIMessage): + ids |= _tool_call_ids_from_ai_message(m) + return ids + + # helper: new_messages 里哪些 tool_call_id 已经有 ToolMessage 结果了 + def _tool_result_ids_in_msgs(msgs: List[BaseMessage]) -> set[str]: + ids: set[str] = set() + for m in msgs: + if isinstance(m, ToolMessage): + tcid = getattr(m, "tool_call_id", None) + if tcid: + ids.add(str(tcid)) + return ids + + # helper: 把“已存在的 ToolMessage”强制替换成 cancelled(避免工具其实返回了但用户打断没看到,导致上下文和 UI 不一致) + def _force_cancelled_tool_results(msgs: List[BaseMessage], cancel_ids: set[str]) -> List[BaseMessage]: + if not cancel_ids: + return msgs + cancelled_content = json.dumps({"cancelled": True}, ensure_ascii=False) + out: List[BaseMessage] = [] + for m in msgs: + if isinstance(m, ToolMessage): + tcid = getattr(m, "tool_call_id", None) + if tcid and str(tcid) in cancel_ids: + out.append(ToolMessage(content=cancelled_content, tool_call_id=str(tcid))) + continue + out.append(m) + return out + + def _inject_cancelled_tool_messages(msgs: List[BaseMessage], tool_call_ids: List[str]) -> List[BaseMessage]: + if not tool_call_ids: + return msgs + + out = list(msgs) + + existing = set() + for m in out: + if isinstance(m, ToolMessage): + tcid = getattr(m, "tool_call_id", None) + if tcid: + existing.add(str(tcid)) + + cancelled_content = json.dumps({"cancelled": True}, ensure_ascii=False) + + for tcid in tool_call_ids: + tcid = str(tcid) + if tcid in existing: + continue + + insert_at = None + for i in range(len(out) - 1, -1, -1): + m = out[i] + if isinstance(m, AIMessage) and (tcid in _tool_call_ids_from_ai_message(m)): + insert_at = i + 1 + break + + if insert_at is None: + continue + + out.insert(insert_at, ToolMessage(content=cancelled_content, tool_call_id=tcid)) + existing.add(tcid) + + return out + + def _sanitize_new_messages_on_cancel( + new_messages: List[BaseMessage], + *, + interrupted_text: str, + cancelled_tool_ids_from_ui: List[str], + ) -> List[BaseMessage]: + """ + 返回:应该写回 sess.lc_messages 的消息序列(只包含“用户可见/认可”的那部分) + - 工具:对未返回的 tool_call 补 ToolMessage({"cancelled": true}) + - 回复:用 interrupted_text 替换末尾 final AIMessage,避免把完整回复泄漏进上下文 + """ + msgs = list(new_messages or []) + interrupted_text = (interrupted_text or "").strip() + + # 1) 工具:找出“AI 发起了 tool_call 但没有 ToolMessage 结果”的那些 id + ai_tool_ids = _tool_call_ids_in_msgs(msgs) + tool_result_ids = _tool_result_ids_in_msgs(msgs) + pending_tool_ids = ai_tool_ids - tool_result_ids + + # UI 认为被取消的 tool(running -> cancelled) + ui_cancel_ids = {str(x) for x in (cancelled_tool_ids_from_ui or [])} + + # 统一要取消的集合: + # - UI 侧 running 的(用户按下打断时看见的) + # - 以及 messages 里缺结果的(防止漏标) + cancel_ids = set(ui_cancel_ids) | set(pending_tool_ids) + + # 2) 如果 new_messages 里已经有 ToolMessage(真实结果) 但用户打断了, + # 为了“UI/上下文一致”,强制替换成 cancelled + msgs = _force_cancelled_tool_results(msgs, cancel_ids) + + # 3) 注入缺失的 ToolMessage(cancelled) + msgs = _inject_cancelled_tool_messages(msgs, list(cancel_ids)) + + # 4) 处理 assistant 最终文本(避免把完整 answer 写回) + # - 如果 interrupted_text 非空:用它替换最后一个“非 tool_call 的 AIMessage” + # - 如果 interrupted_text 为空:只在“末尾存在一个 non-toolcall AIMessage(且它后面没有 tool_call)”时移除它 + def _is_toolcall_ai(m: BaseMessage) -> bool: + return isinstance(m, AIMessage) and bool(_tool_call_ids_from_ai_message(m)) + + def _is_text_ai(m: BaseMessage) -> bool: + if not isinstance(m, AIMessage): + return False + if _tool_call_ids_from_ai_message(m): + return False + c = getattr(m, "content", None) + return isinstance(c, str) and bool(c.strip()) + + # 找最后一个“文本 AIMessage(非 tool_call)” + last_text_ai_idx = None + for i in range(len(msgs) - 1, -1, -1): + if _is_text_ai(msgs[i]): + last_text_ai_idx = i + break + + if interrupted_text: + if last_text_ai_idx is None: + msgs.append(AIMessage(content=interrupted_text)) + else: + # 用用户看见的部分替换,且丢弃后面所有消息(防止泄漏) + msgs = msgs[:last_text_ai_idx] + [AIMessage(content=interrupted_text)] + return msgs + + # interrupted_text 为空:用户没看见任何本段 token + # 只移除“末尾的 final answer AIMessage”,避免把 unseen answer 写进上下文; + # 但如果该 AIMessage 后面还有 tool_call(说明它是 pre-tool 文本),就不要删 + if last_text_ai_idx is not None: + has_toolcall_after = any(_is_toolcall_ai(m) for m in msgs[last_text_ai_idx + 1 :]) + if not has_toolcall_after: + msgs = msgs[:last_text_ai_idx] + + return msgs + + pump_task: Optional[asyncio.Task] = None + cancel_wait_task: Optional[asyncio.Task] = None + + was_interrupted = False # 本 turn 是否已经走了“打断收尾” + + try: + async with mcp_sink_context(sink): + pump_task = asyncio.create_task(pump_agent()) + cancel_wait_task = asyncio.create_task(sess.cancel_event.wait()) + + while True: + # 同时等:queue 出事件 或 cancel_event + get_task = asyncio.create_task(out_q.get()) + done, _ = await asyncio.wait( + {get_task, cancel_wait_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + + # 优先处理队列事件(避免 done/flush 已经在队列里时被 cancel 抢占) + if get_task in done: + kind, payload = get_task.result() + else: + # cancel_event 触发:不再等 queue + try: + get_task.cancel() + await get_task + except asyncio.CancelledError: + debug_traceback_print(app.state.cfg) + pass + except Exception: + debug_traceback_print(app.state.cfg) + pass + + kind, payload = ("agent.cancelled", None) + + # ------------------------ + # 1) 处理打断 + # ------------------------ + if kind == "agent.cancelled": + # 防止重复触发(cancel_event + pump_agent cancelled 都可能来一次) + if was_interrupted: + break + was_interrupted = True + # 1.1 cancel agent 流(停止继续产出 token/工具) + if pump_task and (not pump_task.done()): + pump_task.cancel() + + # 1.2 将所有 running 的工具卡片标记为 error + cancelled_tool_recs: List[Dict[str, Any]] = [] + for tcid, idx in list(sess._tool_history_index.items()): + rec = sess.history[idx] + if rec.get("role") == "tool" and rec.get("state") == "running": + rec.update({ + "state": "error", + "progress": 1.0, + "message": "Cancelled by user", + "summary": {"cancelled": True}, + }) + cancelled_tool_recs.append(rec) + + # 推送 tool.end,确保前端停止 spinner + for rec in cancelled_tool_recs: + await ws_send(ws, "tool.end", { + "tool_call_id": rec["tool_call_id"], + "server": rec["server"], + "name": rec["name"], + "is_error": True, + "summary": rec.get("summary"), + }) + # 1.3 把已输出的 seg_text 写入 history(UI 看到的内容) + interrupted_text = (seg_text or "").strip() + if interrupted_text: + sess.history.append({ + "id": uuid.uuid4().hex[:12], + "role": "assistant", + "content": interrupted_text, + "ts": seg_ts or time.time(), + }) + + # 1.4 上下文:只写回“用户真实看到/认可”的消息序列 + cancelled_tool_ids = [rec["tool_call_id"] for rec in cancelled_tool_recs] + + commit_msgs = _sanitize_new_messages_on_cancel( + new_messages, + interrupted_text=interrupted_text, + cancelled_tool_ids_from_ui=cancelled_tool_ids, + ) + + if commit_msgs: + sess.lc_messages.extend(commit_msgs) + elif interrupted_text: + # 极端情况:updates 没来得及给任何消息,但用户已看到 token + sess.lc_messages.append(AIMessage(content=interrupted_text)) + + + # ★打断:只发 assistant.end,带 interrupted=true + await ws_send(ws, "assistant.end", {"text": interrupted_text, "interrupted": True}) + + sess.cancel_event.clear() + break + + # ------------------------ + # 2) 事件处理 + # ------------------------ + if kind == "assistant.delta": + delta = payload or "" + if delta: + if seg_ts is None: + seg_ts = time.time() + seg_text += delta + if not await ws_send(ws, "assistant.delta", {"delta": delta}): + raise WebSocketDisconnect() + continue + + if kind == "mcp": + raw = payload + + if raw.get("type") == "tool_start": + await flush_segment(send_flush_event=True) + + rec = sess.apply_tool_event(raw) + if rec: + if raw["type"] == "tool_start": + await ws_send(ws, "tool.start", { + "tool_call_id": rec["tool_call_id"], + "server": rec["server"], + "name": rec["name"], + "args": rec["args"], + }) + elif raw["type"] == "tool_progress": + await ws_send(ws, "tool.progress", { + "tool_call_id": rec["tool_call_id"], + "server": rec["server"], + "name": rec["name"], + "progress": rec["progress"], + "message": rec["message"], + }) + elif raw["type"] == "tool_end": + await ws_send(ws, "tool.end", { + "tool_call_id": rec["tool_call_id"], + "server": rec["server"], + "name": rec["name"], + "is_error": rec["state"] == "error", + "summary": rec["summary"], + }) + continue + + if kind == "agent.done": + final_text = (seg_text or "").strip() + + if final_text: + sess.history.append({ + "id": uuid.uuid4().hex[:12], + "role": "assistant", + "content": final_text, + "ts": seg_ts or time.time(), + }) + + if new_messages: + sess.lc_messages.extend(new_messages) + + if not await ws_send(ws, "assistant.end", {"text": final_text}): + return + break + + if kind == "agent.error": + err_text = str(payload or "unknown error") + partial = (seg_text or "").strip() + + # 把已输出部分落盘/落上下文(避免丢上下文) + if partial: + sess.history.append({ + "id": uuid.uuid4().hex[:12], + "role": "assistant", + "content": partial, + "ts": seg_ts or time.time(), + }) + sess.lc_messages.append(AIMessage(content=partial)) + + if new_messages: + sess.lc_messages.extend(new_messages) + + # ★ 真异常:只发 error(并带 partial_text 让前端结束当前气泡) + await ws_send(ws, "error", {"message": err_text, "partial_text": partial}) + break + + except WebSocketDisconnect: + return + except asyncio.CancelledError: + # 连接关闭/任务取消:不当作 error + return + except Exception as e: + # 如果已经走了打断收尾,别再发 error(避免“打断=报错”) + if was_interrupted: + return + await ws_send(ws, "error", {"message": f"{type(e).__name__}: {e}", "partial_text": (seg_text or "").strip()}) + return + finally: + # 结束 cancel_wait_task + if cancel_wait_task and (not cancel_wait_task.done()): + cancel_wait_task.cancel() + + # pump_task 取消/收尾:避免 await 卡死,加一个短超时保护 + if pump_task and (not pump_task.done()): + pump_task.cancel() + if pump_task: + try: + await asyncio.wait_for(pump_task, timeout=2.0) + except asyncio.TimeoutError: + debug_traceback_print(app.state.cfg) + pass + except asyncio.CancelledError: + debug_traceback_print(app.state.cfg) + pass + except Exception: + debug_traceback_print(app.state.cfg) + pass + finally: + try: + CHAT_TURN_SEM.release() + except Exception: + debug_traceback_print(app.state.cfg) + pass + + except WebSocketDisconnect: + return + finally: + try: + WS_CONN_SEM.release() + except: + pass