Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass, field | |
| from datetime import timedelta | |
| from typing import Optional, Any | |
| from langchain.agents import create_agent | |
| from langchain_openai import ChatOpenAI | |
| from langchain_mcp_adapters.callbacks import Callbacks | |
| from langchain_mcp_adapters.client import MultiServerMCPClient | |
| from open_storyline.config import Settings | |
| from open_storyline.storage.agent_memory import ArtifactStore | |
| from open_storyline.nodes.node_manager import NodeManager | |
| from open_storyline.mcp.hooks.chat_middleware import handle_tool_errors, on_progress, log_tool_request | |
| from open_storyline.mcp.sampling_handler import make_sampling_callback | |
| from open_storyline.skills.skills_io import load_skills | |
@dataclass
class ClientContext:
    """Per-session runtime context, registered with the agent as its ``context_schema``.

    The class must be a dataclass: ``field(default_factory=dict)`` below only
    yields a fresh per-instance dict (and a keyword ``__init__`` is only
    generated) when ``@dataclass`` processes the class body.
    """

    cfg: Settings
    session_id: str
    media_dir: str
    bgm_dir: str
    outputs_dir: str
    node_manager: NodeManager
    chat_model_key: str  # Chat model key
    vlm_model_key: str = ""  # VLM model key
    pexels_api_key: Optional[str] = None
    tts_config: Optional[dict] = None  # TTS config at runtime
    # Chat-model instances keyed by a (str, bool) tuple; starts empty for every
    # context instance. NOTE(review): presumably a cache of clients per model
    # key -- confirm the meaning of the bool with the code that fills it.
    llm_pool: dict[tuple[str, bool], ChatOpenAI] = field(default_factory=dict)
    lang: str = "zh"  # Default language: Chinese
async def build_agent(
    cfg: Settings,
    session_id: str,
    store: ArtifactStore,
    tool_interceptors=None,
    *,
    llm_override: Optional[dict] = None,
    vlm_override: Optional[dict] = None,
):
    """Build the tool-calling agent and its node manager for one session.

    Creates the chat (LLM) and vision (VLM) ``ChatOpenAI`` clients, connects to
    the local MCP server, fetches its tools, loads the skills, and wires
    everything into a LangChain agent.

    Args:
        cfg: Application settings; the ``llm``, ``vlm``, ``local_mcp_server``
            and ``skills`` sections are read here.
        session_id: Sent to the MCP server in the ``X-Storyline-Session-Id``
            header.
        store: Passed to ``create_agent`` as ``store``.
        tool_interceptors: Forwarded unchanged to ``MultiServerMCPClient``.
        llm_override: Optional per-request values (``model``, ``base_url``,
            ``api_key``, ``timeout``, ``temperature``, ``max_retries``) that
            take priority over ``cfg.llm``.
        vlm_override: Same as ``llm_override`` but applied over ``cfg.vlm``.

    Returns:
        A ``(agent, node_manager)`` tuple.
    """

    def _get(override: Optional[dict], key: str, default: Any) -> Any:
        """Return the override value for *key*, or *default* when not provided."""
        if not isinstance(override, dict):
            return default
        value = override.get(key)
        # An override that is present but None/blank (e.g. a form field left
        # empty) counts as "not provided", so it cannot clobber the configured
        # value. Falsy-but-meaningful values such as temperature=0 are kept.
        if value is None or (isinstance(value, str) and not value.strip()):
            return default
        return value

    def _norm_url(u: str) -> str:
        """Trim surrounding whitespace and trailing slashes; None becomes ""."""
        return (u or "").strip().rstrip("/")

    def _build_chat_model(section: Any, override: Optional[dict], **extra: Any) -> ChatOpenAI:
        """Create a ChatOpenAI client from a config section plus optional overrides."""
        api_key = _get(override, "api_key", section.api_key)
        return ChatOpenAI(
            model=_get(override, "model", section.model),
            base_url=_norm_url(_get(override, "base_url", section.base_url)),
            api_key=api_key,
            # The key is additionally sent as an explicit "api-key" header.
            # NOTE(review): presumably for gateways that authenticate via that
            # header rather than "Authorization" -- confirm.
            default_headers={
                "api-key": api_key,
                "Content-Type": "application/json",
            },
            timeout=_get(override, "timeout", section.timeout),
            temperature=_get(override, "temperature", section.temperature),
            max_retries=_get(override, "max_retries", section.max_retries),
            **extra,
        )

    # 1) LLM: use user input from form first, fall back to config.toml
    llm = _build_chat_model(cfg.llm, llm_override, streaming=True)

    # 2) VLM: same priority as above (streaming is left at the library default)
    vlm = _build_chat_model(cfg.vlm, vlm_override)

    # 3) MCP client: connect to the local server and fetch its tools
    sampling_callback = make_sampling_callback(llm, vlm)
    connections = {
        cfg.local_mcp_server.server_name: {
            "transport": cfg.local_mcp_server.server_transport,
            "url": cfg.local_mcp_server.url,
            "timeout": timedelta(seconds=cfg.local_mcp_server.timeout),
            "sse_read_timeout": timedelta(minutes=30),
            "headers": {"X-Storyline-Session-Id": session_id},
            "session_kwargs": {"sampling_callback": sampling_callback},
        },
    }
    client = MultiServerMCPClient(
        connections=connections,
        tool_interceptors=tool_interceptors,
        callbacks=Callbacks(on_progress=on_progress),
        tool_name_prefix=True,
    )
    tools = await client.get_tools()
    skills = await load_skills(cfg.skills.skill_dir)  # Load skills
    # The node manager is built from the MCP tools only (skills are excluded).
    node_manager = NodeManager(tools)

    # 4) Use LangChain's agent runtime to handle the multi-turn tool calling loop
    agent = create_agent(
        model=llm,
        tools=tools + skills,
        middleware=[log_tool_request, handle_tool_errors],
        store=store,
        context_schema=ClientContext,
    )
    return agent, node_manager