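"""Agent construction for open_storyline.

Builds the chat (LLM) and vision (VLM) model clients from config plus optional
form overrides, connects to the local MCP server, loads skills, and assembles
a LangChain agent around them. See the usage sketch at the end of this module.
"""
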
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Optional

from langchain.agents import create_agent
from langchain_mcp_adapters.callbacks import Callbacks
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import ChatOpenAI

from open_storyline.config import Settings
from open_storyline.mcp.hooks.chat_middleware import (
    handle_tool_errors,
    log_tool_request,
    on_progress,
)
from open_storyline.mcp.sampling_handler import make_sampling_callback
from open_storyline.nodes.node_manager import NodeManager
from open_storyline.skills.skills_io import load_skills
from open_storyline.storage.agent_memory import ArtifactStore


@dataclass
class ClientContext:
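    """Per-session runtime context handed to the agent via ``context_schema``."""
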
    cfg: Settings
    session_id: str
    media_dir: str
    bgm_dir: str
    outputs_dir: str
    node_manager: NodeManager
    chat_model_key: str  # Chat model key
    vlm_model_key: str = ""  # VLM model key
    pexels_api_key: Optional[str] = None
    tts_config: Optional[dict] = None  # TTS config resolved at runtime
    llm_pool: dict[tuple[str, bool], ChatOpenAI] = field(default_factory=dict)
    lang: str = "zh"  # Default language: Chinese


async def build_agent(
    cfg: Settings,
    session_id: str,
    store: ArtifactStore,
    tool_interceptors=None,
    *,
    llm_override: Optional[dict] = None,
    vlm_override: Optional[dict] = None,
):
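    """Build the agent for one session.

    Model settings supplied via ``llm_override``/``vlm_override`` (e.g. from a
    user form) take precedence over the defaults in config.toml. Returns an
    ``(agent, node_manager)`` tuple.
    """
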
    def _get(override: Optional[dict], key: str, default: Any) -> Any:
        # Prefer an explicitly overridden value; otherwise use the config default.
        return override.get(key) if isinstance(override, dict) and key in override else default

    def _norm_url(u: str) -> str:
        # Strip whitespace and any trailing slash so base URLs join cleanly.
        u = (u or "").strip()
        return u.rstrip("/") if u else u

    # 1) LLM: use user input from the form first, fall back to config.toml
    llm_model = _get(llm_override, "model", cfg.llm.model)
    llm_base_url = _norm_url(_get(llm_override, "base_url", cfg.llm.base_url))
    llm_api_key = _get(llm_override, "api_key", cfg.llm.api_key)
    llm_timeout = _get(llm_override, "timeout", cfg.llm.timeout)
    llm_temperature = _get(llm_override, "temperature", cfg.llm.temperature)
    llm_max_retries = _get(llm_override, "max_retries", cfg.llm.max_retries)
    llm = ChatOpenAI(
        model=llm_model,
        base_url=llm_base_url,
        api_key=llm_api_key,
        # Some OpenAI-compatible gateways (e.g. Azure-style endpoints) read
        # the key from an "api-key" header, so it is sent alongside the
        # standard bearer auth.
        default_headers={
            "api-key": llm_api_key,
            "Content-Type": "application/json",
        },
        timeout=llm_timeout,
        temperature=llm_temperature,
        streaming=True,
        max_retries=llm_max_retries,
    )

    # 2) VLM: same priority as above
    vlm_model = _get(vlm_override, "model", cfg.vlm.model)
    vlm_base_url = _norm_url(_get(vlm_override, "base_url", cfg.vlm.base_url))
    vlm_api_key = _get(vlm_override, "api_key", cfg.vlm.api_key)
    vlm_timeout = _get(vlm_override, "timeout", cfg.vlm.timeout)
    vlm_temperature = _get(vlm_override, "temperature", cfg.vlm.temperature)
    vlm_max_retries = _get(vlm_override, "max_retries", cfg.vlm.max_retries)
    vlm = ChatOpenAI(
        model=vlm_model,
        base_url=vlm_base_url,
        api_key=vlm_api_key,
        default_headers={
            "api-key": vlm_api_key,
            "Content-Type": "application/json",
        },
        timeout=vlm_timeout,
        temperature=vlm_temperature,
        max_retries=vlm_max_retries,
    )

    # Let the MCP server request completions from our models (MCP sampling).
    sampling_callback = make_sampling_callback(llm, vlm)

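    # 3) Local MCP server connection; the session id travels in a request
    # header and the sampling callback is attached to the MCP session.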
    connections = {
        cfg.local_mcp_server.server_name: {
            "transport": cfg.local_mcp_server.server_transport,
            "url": cfg.local_mcp_server.url,
            "timeout": timedelta(seconds=cfg.local_mcp_server.timeout),
            "sse_read_timeout": timedelta(minutes=30),
            "headers": {"X-Storyline-Session-Id": session_id},
            "session_kwargs": {"sampling_callback": sampling_callback},
        },
    }
    client = MultiServerMCPClient(
        connections=connections,
        tool_interceptors=tool_interceptors,
        callbacks=Callbacks(on_progress=on_progress),
        tool_name_prefix=True,
    )

    tools = await client.get_tools()
    skills = await load_skills(cfg.skills.skill_dir)  # Load skills
    node_manager = NodeManager(tools)

    # 4) Use LangChain's agent runtime to handle the multi-turn tool calling loop
    agent = create_agent(
        model=llm,
        tools=tools + skills,
        middleware=[log_tool_request, handle_tool_errors],
        store=store,
        context_schema=ClientContext,
    )
    return agent, node_manager
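
# Usage sketch (hypothetical, for illustration only; ``cfg`` and ``store`` come
# from the host application, and passing ``context=`` at invoke time assumes
# the LangGraph runtime-context API that ``create_agent`` builds on):
#
#     agent, node_manager = await build_agent(cfg, session_id="demo", store=store)
#     ctx = ClientContext(
#         cfg=cfg,
#         session_id="demo",
#         media_dir="media",
#         bgm_dir="bgm",
#         outputs_dir="outputs",
#         node_manager=node_manager,
#         chat_model_key="default",
#     )
#     result = await agent.ainvoke({"messages": [("user", "...")]}, context=ctx)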