import asyncio
import time
import uuid
import os
import sys
import json
import contextvars
from typing import Callable, Optional, Union, Any, List
from dataclasses import dataclass
from datetime import timedelta

from langchain.agents import create_agent
from langchain.agents.middleware import wrap_tool_call, wrap_model_call
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.interceptors import MCPToolCallRequest, MCPToolCallResult
from langchain_mcp_adapters.callbacks import Callbacks, CallbackContext
from langchain_openai import ChatOpenAI
from langgraph.types import Command

# Add the src directory to Python's module search path
ROOT_DIR = os.path.dirname(__file__)
SRC_DIR = os.path.join(ROOT_DIR, "src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

from open_storyline.skills.skills_io import load_skills
from open_storyline.utils.prompts import get_prompt
from open_storyline.mcp.sampling_handler import make_sampling_callback
from open_storyline.config import load_settings, default_config_path, Settings
from open_storyline.storage.agent_memory import ArtifactStore
from open_storyline.nodes.node_manager import NodeManager
from open_storyline.mcp.hooks.node_interceptors import ToolInterceptor


@dataclass
class ClientContext:
    session_id: str
    assets_dir: str
    bgm_dir: str
    outputs_dir: str
    node_manager: NodeManager
    chat_model_key: str  # model used for the current chat
    tts_config: Optional[dict] = None  # runtime TTS configuration


# GUI log output channel
_MCP_LOG_SINK = contextvars.ContextVar("mcp_log_sink", default=None)
_MCP_ACTIVE_TOOL_CALL_ID = contextvars.ContextVar("mcp_active_tool_call_id", default=None)


def set_mcp_log_sink(sink: Optional[Callable[[dict], None]]):
    return _MCP_LOG_SINK.set(sink)


def reset_mcp_log_sink(token):
    _MCP_LOG_SINK.reset(token)
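
# Typical usage (sketch; `ui_queue` is a hypothetical frontend queue, not part
# of this module):
#   token = set_mcp_log_sink(lambda event: ui_queue.put(event))
#   try:
#       ...  # run the agent; interceptors emit tool_start/tool_end events
#   finally:
#       reset_mcp_log_sink(token)  # restore the previous sink via the token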


# ---- Key masking + TTS credential injection. Parked here for now; ----
# ---- revisit placement once the framework refactor is done. ----
CUSTOM_MODEL_KEY = "__custom__"

_SENSITIVE_KEYS = {
    "api_key",
    "access_token",
    "authorization",
    "token",
    "password",
    "secret",
    "x-api-key",
    "apikey",
}


def _mask_secrets(obj: Any) -> Any:
    """Recursively mask secrets so keys/tokens never leak into the console,
    logs, tool traces, ToolMessages, or anywhere else they get printed."""
    try:
        if isinstance(obj, dict):
            out = {}
            for k, v in obj.items():
                if str(k).lower() in _SENSITIVE_KEYS:
                    out[k] = "***"
                else:
                    out[k] = _mask_secrets(v)
            return out
        if isinstance(obj, list):
            return [_mask_secrets(x) for x in obj]
        if isinstance(obj, tuple):
            return tuple(_mask_secrets(x) for x in obj)
        return obj
    except Exception:
        return "***"


async def inject_tts_config(request: MCPToolCallRequest, handler):
    """Interceptor: inject credentials before calling a TTS tool."""
    try:
        runtime = getattr(request, "runtime", None)
        ctx = getattr(runtime, "context", None) if runtime else None
        tts_cfg = getattr(ctx, "tts_config", None) if ctx else None
        if not (isinstance(tts_cfg, dict) and isinstance(getattr(request, "args", None), dict)):
            return await handler(request)
        tool_name = str(getattr(request, "name", "") or "")
        if "voiceover" not in tool_name:
            return await handler(request)
        vendor = str(tts_cfg.get("vendor") or "").lower().strip()
        if vendor != "bytedance":
            return await handler(request)
        bd = tts_cfg.get("bytedance") or {}
        if not isinstance(bd, dict):
            return await handler(request)
        uid = bd.get("uid")
        appid = bd.get("appid")
        access_token = bd.get("access_token")
        if uid and appid and access_token:
            request.args.setdefault("uid", uid)
            request.args.setdefault("appid", appid)
            request.args.setdefault("access_token", access_token)
    except Exception:
        pass
    return await handler(request)
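
# Expected shape of ctx.tts_config (inferred from the reads above):
#   {"vendor": "bytedance",
#    "bytedance": {"uid": ..., "appid": ..., "access_token": ...}}
# setdefault() preserves any values the model already put in request.args.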


async def log_mcp_request(request: MCPToolCallRequest, handler):
    sink = _MCP_LOG_SINK.get()

    def emit_event(x: str | dict):
        if sink:
            sink(x)

    runtime = request.runtime
    context = runtime.context
    meta_collector = context.node_manager
    exclude = set(meta_collector.kind_to_node_ids.keys()) | {
        "inputs", "artifacts_dir", "artifact_id", "blobs_dir", "meta_path",
        "assets_dir", "bgm_dir", "outputs_dir", "debug_dir",
    }
    extracted_args = {}
    for arg in (request.args.keys() if isinstance(request.args, dict) else []):
        if arg not in exclude:
            extracted_args[arg] = request.args[arg]
    extracted_args = _mask_secrets(extracted_args)
    tool_call_id = getattr(runtime, "tool_call_id", None)
    if not tool_call_id:
        tool_call_id = f"mcp_{uuid.uuid4().hex[:8]}"
        if runtime is not None:
            setattr(runtime, "tool_call_id", tool_call_id)
    active_tok = _MCP_ACTIVE_TOOL_CALL_ID.set(tool_call_id)
    try:
        emit_event({
            "type": "tool_start",
            "tool_call_id": tool_call_id,
            "server": request.server_name,
            "name": request.name,
            "args": extracted_args,
        })
        print(f"[MCP tool start] {request.server_name}.{request.name} args={extracted_args}\n")
        out = await handler(request)
        try:
            is_error = json.loads(out.content[0].text).get("isError", False)
        except Exception:
            is_error = True
    finally:
        _MCP_ACTIVE_TOOL_CALL_ID.reset(active_tok)
    if out.isError or is_error:
        print(f"[MCP tool error] result: {out.content}\n\n")
        emit_event({
            "type": "tool_end",
            "tool_call_id": tool_call_id,
            "server": request.server_name,
            "name": request.name,
            "is_error": True,
            "summary": _mask_secrets(out.content),
        })
    else:
        payload = json.loads(out.content[0].text)
        summary = payload.get("summary")
        print(f"[MCP tool done] result: {summary}\n\n")
        emit_event({
            "type": "tool_end",
            "tool_call_id": tool_call_id,
            "server": request.server_name,
            "name": request.name,
            "is_error": False,
            "summary": _mask_secrets(summary),
        })
    return out


async def handle_tool_errors(request, handler):
    try:
        out = await handler(request)
        if isinstance(out, Command):
            return out.update.get("messages")[0]
        elif isinstance(out, MCPToolCallResult) and not isinstance(out.content, str):
            # Workaround for a deepseek-chat compatibility issue
            return ToolMessage(content=out.content[0].get("text", ""), tool_call_id=out.tool_call_id, name=out.name)
        return out
    except Exception as e:
        tc = request.tool_call
        safe_args = _mask_secrets(tc.get("args") or {})
        return ToolMessage(
            content=(
                "Tool call failed\n"
                f"Tool name: {tc.get('name')}\n"
                f"Tool args: {safe_args}\n"
                f"Error: {type(e).__name__}: {e}\n"
                "If the arguments are wrong, fix them and call again; if a prerequisite is "
                "missing, call the prerequisite node first; if you think this was a transient "
                "error, retry; if you believe you cannot continue, explain why to the user."
            ),
            tool_call_id=tc["id"],
        )


async def on_progress(progress: float, total: float | None, message: str | None, context: CallbackContext):
    sink = _MCP_LOG_SINK.get()
    if sink:
        sink({
            "type": "tool_progress",
            "tool_call_id": _MCP_ACTIVE_TOOL_CALL_ID.get(),
            "server": context.server_name,
            "name": context.tool_name,
            "progress": progress,
            "total": total,
            "message": message,
        })


async def build_agent(
    cfg: Settings,
    session_id: str,
    store: ArtifactStore,
    tool_interceptors=None,
    *,
    llm_override: Optional[dict] = None,
    vlm_override: Optional[dict] = None,
):
    def _get(override: Optional[dict], key: str, default: Any) -> Any:
        return override.get(key) if isinstance(override, dict) and key in override else default

    def _norm_url(u: str) -> str:
        u = (u or "").strip()
        return u.rstrip("/") if u else u

    # 1) LLM: prefer what the user typed in the input box, then fall back to config.toml
    llm_model = _get(llm_override, "model", cfg.llm.model)
    llm_base_url = _norm_url(_get(llm_override, "base_url", cfg.llm.base_url))
    llm_api_key = _get(llm_override, "api_key", cfg.llm.api_key)
    llm_timeout = _get(llm_override, "timeout", cfg.llm.timeout)
    llm_temperature = _get(llm_override, "temperature", cfg.llm.temperature)
    llm_max_retries = _get(llm_override, "max_retries", cfg.llm.max_retries)
    llm = ChatOpenAI(
        model=llm_model,
        base_url=llm_base_url,
        api_key=llm_api_key,
        default_headers={
            "api-key": llm_api_key,
            "Content-Type": "application/json",
        },
        timeout=llm_timeout,
        temperature=llm_temperature,
        streaming=True,
        max_retries=llm_max_retries,
    )

    # 2) VLM: same precedence as above
    vlm_model = _get(vlm_override, "model", cfg.vlm.model)
    vlm_base_url = _norm_url(_get(vlm_override, "base_url", cfg.vlm.base_url))
    vlm_api_key = _get(vlm_override, "api_key", cfg.vlm.api_key)
    vlm_timeout = _get(vlm_override, "timeout", cfg.vlm.timeout)
    vlm_temperature = _get(vlm_override, "temperature", cfg.vlm.temperature)
    vlm_max_retries = _get(vlm_override, "max_retries", cfg.vlm.max_retries)
    vlm = ChatOpenAI(
        model=vlm_model,
        base_url=vlm_base_url,
        api_key=vlm_api_key,
        default_headers={
            "api-key": vlm_api_key,
            "Content-Type": "application/json",
        },
        timeout=vlm_timeout,
        temperature=vlm_temperature,
        max_retries=vlm_max_retries,
    )

    sampling_callback = make_sampling_callback(llm, vlm)

    # ---- Developer model allowlist; only used for switching between site-preset models ----
    allowed = set(cfg.developer.chat_models or [])
    llm_pool: dict[tuple[str, bool], ChatOpenAI] = {}

    def _make_chat_llm(model_name: str, streaming: bool) -> ChatOpenAI:
        model_config = cfg.developer.chat_models_config.get(model_name) or {}
        base_url = _norm_url(model_config.get("base_url") or "")
        api_key = model_config.get("api_key")
        return ChatOpenAI(
            model=model_name,
            base_url=base_url,
            api_key=api_key,
            default_headers={
                "api-key": api_key,
                "Content-Type": "application/json",
            },
            timeout=cfg.llm.timeout,
            temperature=model_config.get("temperature", cfg.llm.temperature),
            streaming=streaming,
        )

    def _get_llm(model_name: str, streaming: bool) -> ChatOpenAI:
        # Cache one client per (model, streaming) pair so retries reuse instances
        hit = llm_pool.get((model_name, streaming))
        if hit:
            return hit
        new_llm = _make_chat_llm(model_name, streaming=streaming)
        llm_pool[(model_name, streaming)] = new_llm
        return new_llm

    async def switch_chat_model(request, handler):
        runtime = getattr(request, "runtime", None)
        ctx = getattr(runtime, "context", None) if runtime else None
        model_key = getattr(ctx, "chat_model_key", None) if ctx else None
        model_key = model_key.strip() if isinstance(model_key, str) else ""
        if not model_key:
            return await handler(request)
        if model_key == CUSTOM_MODEL_KEY:
            return await handler(request)
        if model_key not in allowed:
            return await handler(request)
        # Prefer streaming; fall back to non-streaming on failure
        for _ in range(2):
            try:
                llm_stream = _get_llm(model_key, streaming=True)
                if cfg.developer.print_context:
                    print(request)
                return await handler(request.override(model=llm_stream))
            except Exception as e1:
                print(f"Streaming LLM failed: {e1}")
                await asyncio.sleep(2)
        for _ in range(2):
            try:
                llm_nostream = _get_llm(model_key, streaming=False)
                if cfg.developer.print_context:
                    print(request)
                return await handler(request.override(model=llm_nostream))
            except Exception as e2:
                print(f"Non-streaming LLM failed: {e2}")
                await asyncio.sleep(2)
        fresh = _make_chat_llm(model_key, streaming=False)
        llm_pool[(model_key, False)] = fresh
        if cfg.developer.print_context:
            print(request)
        return await handler(request.override(model=fresh))
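
    # Retry ladder, as implemented above: two streaming attempts, then two
    # pooled non-streaming attempts, then one final attempt with a freshly
    # built non-streaming client that also replaces the pooled instance.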

    connections = {
        cfg.local_mcp_server.server_name: {
            "transport": cfg.local_mcp_server.server_transport,
            "url": cfg.local_mcp_server.url,
            "timeout": timedelta(seconds=cfg.local_mcp_server.timeout),
            "headers": {"X-Storyline-Session-Id": session_id},
            "session_kwargs": {"sampling_callback": sampling_callback},
        },
    }
    client = MultiServerMCPClient(
        connections=connections,
        tool_interceptors=tool_interceptors or [log_mcp_request],
        callbacks=Callbacks(on_progress=on_progress),
        tool_name_prefix=True,
    )
    tools = await client.get_tools()
    skills = await load_skills()  # load skills
    node_manager = NodeManager(tools)

    # 4) Let LangChain's agent runtime own the multi-turn tool-calling loop
    agent = create_agent(
        model=llm,
        tools=tools + skills,
        middleware=[switch_chat_model, handle_tool_errors],
        store=store,
        context_schema=ClientContext,
    )
    return agent, node_manager
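
# Example call with user-supplied model settings (sketch; the override values
# are illustrative placeholders, not real defaults):
#   agent, node_manager = await build_agent(
#       cfg, session_id, store,
#       llm_override={"model": "...", "base_url": "https://.../v1", "api_key": "..."},
#   )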


async def main():
    session_id = f"run_{int(time.time())}_{uuid.uuid4().hex[:8]}"
    cfg = load_settings(default_config_path())
    artifact_store = ArtifactStore(cfg.project.outputs_dir, session_id=session_id)
    agent, node_manager = await build_agent(
        cfg=cfg,
        session_id=session_id,
        store=artifact_store,
        tool_interceptors=[
            ToolInterceptor.inject_media_content_before,
            ToolInterceptor.save_media_content_after,
            log_mcp_request,
        ],
    )
    context = ClientContext(
        session_id=session_id,
        assets_dir=cfg.project.assets_dir,
        bgm_dir=cfg.project.bgm_dir,
        outputs_dir=cfg.project.outputs_dir,
        node_manager=node_manager,
        chat_model_key=cfg.llm.model,
    )
    # Use a plain messages list as the simplest short-term memory
    # (replaces the hand-rolled history_for_llm string)
    messages: List[BaseMessage] = [SystemMessage(content=get_prompt("instruction.system"))]
    print("Smart editing agent (LangChain + MCP) v0.2.0")
    print("Describe your editing request; type /exit to quit")
    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye~")
            break
        if not user_input:
            continue
        if user_input in ("/exit", "/quit"):
            print("\nBye~")
            break
        messages.append(HumanMessage(content=user_input))
        print("Calling the LLM, please wait...")
        # A single ainvoke runs the whole loop internally:
        # reason -> call tools (possibly many times) -> aggregate -> final answer
        result = await agent.ainvoke(
            {"messages": messages},
            context=context,
        )
        messages = result["messages"]
        final_text = None
        for m in reversed(messages):
            if isinstance(m, AIMessage):
                final_text = m.content
                break
        print(f"\nAssistant: {final_text or '(no final reply generated)'}\n")


if __name__ == "__main__":
    asyncio.run(main())