| | from dataclasses import dataclass, field |
| | from datetime import timedelta |
| | from typing import Optional, Any |
| |
|
| |
|
| | from langchain.agents import create_agent |
| | from langchain_openai import ChatOpenAI |
| |
|
| | from langchain_mcp_adapters.callbacks import Callbacks |
| | from langchain_mcp_adapters.client import MultiServerMCPClient |
| |
|
| | from open_storyline.config import Settings |
| | from open_storyline.storage.agent_memory import ArtifactStore |
| | from open_storyline.nodes.node_manager import NodeManager |
| | from open_storyline.mcp.hooks.chat_middleware import handle_tool_errors, on_progress, log_tool_request |
| | from open_storyline.mcp.sampling_handler import make_sampling_callback |
| | from open_storyline.skills.skills_io import load_skills |
| |
|
@dataclass
class ClientContext:
    """Per-session runtime context object.

    ``build_agent`` registers this class as the agent's ``context_schema``.
    The required fields come first; field order defines the positional
    constructor signature, so it must not be rearranged.
    """

    # --- required: no defaults ---------------------------------------------
    cfg: Settings
    session_id: str
    media_dir: str
    bgm_dir: str
    outputs_dir: str
    node_manager: NodeManager
    chat_model_key: str

    # --- optional: defaulted -----------------------------------------------
    vlm_model_key: str = field(default="")
    pexels_api_key: Optional[str] = field(default=None)
    tts_config: Optional[dict] = field(default=None)
    # Chat-model clients keyed by (str, bool). The meaning of the key parts is
    # not visible in this file -- TODO confirm. default_factory gives each
    # instance its own dict instead of one shared mutable default.
    llm_pool: dict[tuple[str, bool], ChatOpenAI] = field(default_factory=dict)
    # Presumably a UI/content language code -- TODO confirm against callers.
    lang: str = field(default="zh")
| |
|
| |
|
def _override_or_default(override: Optional[dict], key: str, default: Any) -> Any:
    """Return ``override[key]`` if *override* is a dict containing *key*, else *default*.

    Key presence is what counts: an explicitly supplied ``None`` or empty value
    still takes precedence over *default*.
    """
    if isinstance(override, dict) and key in override:
        return override[key]
    return default


def _normalize_base_url(url: Optional[str]) -> str:
    """Trim surrounding whitespace and trailing slashes; ``None``/blank becomes ``""``."""
    return (url or "").strip().rstrip("/")


def _make_chat_model(section: Any, override: Optional[dict], **extra: Any) -> ChatOpenAI:
    """Build a ``ChatOpenAI`` client from a model config section plus optional overrides.

    Args:
        section: Config object exposing ``model``, ``base_url``, ``api_key``,
            ``timeout``, ``temperature`` and ``max_retries`` attributes
            (``cfg.llm`` or ``cfg.vlm``).
        override: Optional dict. Any of the above keys it contains replaces the
            value read from *section*.
        **extra: Extra keyword arguments forwarded verbatim to ``ChatOpenAI``
            (e.g. ``streaming=True``).

    Returns:
        A configured ``ChatOpenAI`` instance.
    """

    def pick(key: str) -> Any:
        return _override_or_default(override, key, getattr(section, key))

    api_key = pick("api_key")

    headers = {"Content-Type": "application/json"}
    # The key is duplicated into an ``api-key`` header, presumably for
    # OpenAI-compatible gateways that authenticate that way (Azure OpenAI
    # uses this header name). Only send it when a key is actually set: HTTP
    # header values must be str/bytes, so a None here would break every
    # request, and an empty header carries no credential anyway.
    if api_key:
        headers["api-key"] = api_key

    return ChatOpenAI(
        model=pick("model"),
        base_url=_normalize_base_url(pick("base_url")),
        api_key=api_key,
        default_headers=headers,
        timeout=pick("timeout"),
        temperature=pick("temperature"),
        max_retries=pick("max_retries"),
        **extra,
    )


async def build_agent(
    cfg: Settings,
    session_id: str,
    store: ArtifactStore,
    tool_interceptors=None,
    *,
    llm_override: Optional[dict] = None,
    vlm_override: Optional[dict] = None,
):
    """Create the chat agent and its node manager for one session.

    Args:
        cfg: Application settings. Reads ``cfg.llm``, ``cfg.vlm``,
            ``cfg.local_mcp_server`` and ``cfg.skills``.
        session_id: Sent to the MCP server as the ``X-Storyline-Session-Id``
            header on the local connection.
        store: Artifact store passed to ``create_agent`` as its ``store``.
        tool_interceptors: Forwarded unchanged to ``MultiServerMCPClient``.
        llm_override: Optional dict. Present keys among ``model``,
            ``base_url``, ``api_key``, ``timeout``, ``temperature`` and
            ``max_retries`` replace the matching ``cfg.llm`` values.
        vlm_override: Same as *llm_override*, applied to ``cfg.vlm``.

    Returns:
        A ``(agent, node_manager)`` tuple. ``agent`` is the result of
        ``create_agent``. ``node_manager`` is a ``NodeManager`` built from the
        MCP tools only; skills are not passed to it.
    """
    # The primary model drives the agent loop and is streamed. The secondary
    # (cfg.vlm) model is used only by the MCP sampling callback and keeps
    # ChatOpenAI's default streaming setting.
    llm = _make_chat_model(cfg.llm, llm_override, streaming=True)
    vlm = _make_chat_model(cfg.vlm, vlm_override)

    sampling_callback = make_sampling_callback(llm, vlm)

    server = cfg.local_mcp_server
    connections = {
        server.server_name: {
            "transport": server.server_transport,
            "url": server.url,
            # Assumes server.timeout is expressed in seconds -- TODO confirm.
            "timeout": timedelta(seconds=server.timeout),
            # Long fixed SSE read window. The reason is not visible here,
            # presumably to tolerate slow server-side steps.
            "sse_read_timeout": timedelta(minutes=30),
            "headers": {"X-Storyline-Session-Id": session_id},
            "session_kwargs": {"sampling_callback": sampling_callback},
        },
    }

    client = MultiServerMCPClient(
        connections=connections,
        tool_interceptors=tool_interceptors,
        callbacks=Callbacks(on_progress=on_progress),
        tool_name_prefix=True,
    )

    tools = await client.get_tools()
    skills = await load_skills(cfg.skills.skill_dir)
    node_manager = NodeManager(tools)

    agent = create_agent(
        model=llm,
        tools=tools + skills,
        middleware=[log_tool_request, handle_tool_errors],
        store=store,
        context_schema=ClientContext,
    )
    return agent, node_manager