| """ | |
MACPRunner — executor of the Multi-Agent Communication Protocol (MACP).
| Supports both simple sequential execution and adaptive mode | |
| with conditional edges, pruning, fallback, and parallel execution. | |
| Also supports: | |
| - Stratified memory (working/long-term) | |
| - Hidden channels (hidden_state, embeddings) | |
| - Exchange protocol separating visible/hidden data | |
| - Typed errors and handling policies | |
| - Token/request budgets at graph and node levels | |
| - Structured execution event logging | |
| - **Multi-model**: each agent can use its own LLM | |
| Multi-model support: | |
| MACPRunner supports three ways to specify LLMs: | |
| 1. A single llm_caller for all agents (legacy) | |
| 2. A dict of llm_callers: dict[agent_id, Callable] for different models | |
| 3. LLMCallerFactory that creates callers based on AgentLLMConfig | |
| Priority: agent-specific caller > factory > default caller | |
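Illustrative sketch of the priority rule (caller names are placeholders):

    runner = MACPRunner(
        llm_caller=default_caller,                             # lowest priority: shared default
        llm_factory=LLMCallerFactory.create_openai_factory(),  # used when no explicit caller is set
        llm_callers={"reviewer": reviewer_caller},             # highest priority: per-agent override
    )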
| """ | |
| import asyncio | |
| import time | |
| import uuid | |
| from collections import deque | |
| from collections.abc import AsyncIterator, Awaitable, Callable, Iterator | |
| from datetime import UTC, datetime | |
| from typing import Any, NamedTuple | |
| import torch | |
| from pydantic import BaseModel, ConfigDict, Field | |
| from callbacks import ( | |
| CallbackManager, | |
| Handler, | |
| get_callback_manager, | |
| ) | |
| from core.agent import AgentLLMConfig | |
| from utils.memory import AgentMemory, MemoryConfig, SharedMemoryPool | |
| from .budget import BudgetConfig, BudgetTracker | |
| from .errors import ( | |
| ErrorPolicy, | |
| ExecutionError, | |
| ExecutionMetrics, | |
| ) | |
| from .scheduler import ( | |
| AdaptiveScheduler, | |
| ConditionContext, | |
| ExecutionPlan, | |
| PruningConfig, | |
| RoutingPolicy, | |
| StepResult, | |
| build_execution_order, | |
| extract_agent_adjacency, | |
| filter_reachable_agents, | |
| get_incoming_agents, | |
| get_parallel_groups, | |
| ) | |
| from .streaming import ( | |
| AgentErrorEvent, | |
| AgentOutputEvent, | |
| AgentStartEvent, | |
| AnyStreamEvent, | |
| AsyncStreamCallback, | |
| FallbackEvent, | |
| ParallelEndEvent, | |
| ParallelStartEvent, | |
| PruneEvent, | |
| RunEndEvent, | |
| RunStartEvent, | |
| StreamCallback, | |
| StreamEvent, | |
| StreamEventType, | |
| TokenEvent, | |
| TopologyChangedEvent, | |
| ) | |
| # Tools support (optional import) | |
| try: | |
| from tools import ToolRegistry | |
| TOOLS_AVAILABLE = True | |
| except ImportError: | |
| TOOLS_AVAILABLE = False | |
| # --------------------------------------------------------------------------- | |
| # Module-level constants | |
| # --------------------------------------------------------------------------- | |
| #: Minimum edge weight below which the edge is considered absent (BFS). | |
| _MIN_EDGE_WEIGHT: float = 1e-6 | |
| __all__ = [ | |
| "AgentMemory", | |
| "AsyncStreamCallback", | |
| "AsyncStructuredLLMCallerProtocol", | |
| "AsyncTopologyHook", | |
| "BudgetConfig", | |
| "EarlyStopCondition", | |
| "ErrorPolicy", | |
| "ExecutionMetrics", | |
| "HiddenState", | |
| # Multi-model support | |
| "LLMCallerFactory", | |
| "LLMCallerProtocol", | |
| "MACPResult", | |
| "MACPRunner", | |
| "MemoryConfig", | |
| "RunnerConfig", | |
| "SharedMemoryPool", | |
| # Dynamic topology | |
| "StepContext", | |
| "StreamCallback", | |
| # Streaming types | |
| "StreamEvent", | |
| "StreamEventType", | |
| # Structured prompt support | |
| "StructuredLLMCallerProtocol", | |
| "StructuredPrompt", | |
| "ToolRegistry", | |
| "TopologyAction", | |
| "TopologyHook", | |
| "create_openai_async_structured_caller", | |
| "create_openai_caller", | |
| "create_openai_structured_caller", | |
| ] | |
| # Type aliases for LLM callers | |
| LLMCallerProtocol = Callable[[str], str] | |
| AsyncLLMCallerProtocol = Callable[[str], Awaitable[str]] | |
| # Structured prompt: allows callers to receive proper system/user roles | |
| # instead of a flat string. This is the modern way to call chat LLMs. | |
| StructuredLLMCallerProtocol = Callable[[list[dict[str, str]]], str] | |
| AsyncStructuredLLMCallerProtocol = Callable[[list[dict[str, str]]], Awaitable[str]] | |
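# Illustrative: any callable matching these aliases can be plugged in, e.g. a trivial
# stub for tests (no real LLM involved; the stub below is hypothetical, not part of the API):
#
#     def echo_caller(messages: list[dict[str, str]]) -> str:
#         return messages[-1]["content"]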
| class StructuredPrompt: | |
| """ | |
| Prompt that carries both a flat string (legacy) and structured messages. | |
| When a ``structured_llm_caller`` is available the runner sends | |
| ``messages`` (a list of ``{"role": ..., "content": ...}`` dicts) | |
| so the LLM receives proper system / user roles. Otherwise the | |
| runner falls back to the flat ``text`` representation — full | |
| backward compatibility with ``Callable[[str], str]`` callers. | |
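Illustrative construction (contents are placeholders):

    prompt = StructuredPrompt(
        text="Review this draft.",
        messages=[
            {"role": "system", "content": "You are a reviewer."},
            {"role": "user", "content": "Review this draft."},
        ],
    )
    structured_caller(prompt.messages)   # role-aware path
    legacy_caller(str(prompt))           # flat-text fallback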
| """ | |
| __slots__ = ("messages", "text") | |
| def __init__(self, text: str, messages: list[dict[str, str]]) -> None: | |
| self.text = text | |
| self.messages = messages | |
| # Allow `len()`, logging, and anywhere a plain str was expected | |
| def __str__(self) -> str: | |
| return self.text | |
| def __repr__(self) -> str: | |
| return f"StructuredPrompt(text='{self.text[:80]}…', messages={len(self.messages)})" | |
| class LLMCallerFactory: | |
| """ | |
| Factory for creating LLM callers based on agent configuration. | |
| Supports: | |
| - OpenAI-compatible APIs (OpenAI, Azure, Ollama, vLLM, LiteLLM) | |
| - Caching created callers by (base_url, api_key, model_name) | |
| - Fallback to default caller if configuration is not set | |
| Example: | |
| factory = LLMCallerFactory( | |
| default_caller=my_default_caller, | |
| default_config=LLMConfig(model_name="gpt-4", base_url="...") | |
| ) | |
| # Automatically creates a caller for the agent | |
| caller = factory.get_caller(agent.get_llm_config()) | |
| response = caller(prompt) | |
| # Or with OpenAI | |
| factory = LLMCallerFactory.create_openai_factory( | |
| default_api_key="sk-...", | |
| default_model="gpt-4" | |
| ) | |
| """ | |
| def __init__( | |
| self, | |
| default_caller: LLMCallerProtocol | None = None, | |
| default_async_caller: AsyncLLMCallerProtocol | None = None, | |
| default_config: AgentLLMConfig | None = None, | |
| caller_builder: Callable[[AgentLLMConfig], LLMCallerProtocol] | None = None, | |
| async_caller_builder: Callable[[AgentLLMConfig], AsyncLLMCallerProtocol] | None = None, | |
| ): | |
| """ | |
| Create an LLM caller factory. | |
| Args: | |
| default_caller: Default caller (for agents without custom configuration). | |
| default_async_caller: Default async caller. | |
| default_config: Default configuration (merged with agent config). | |
| caller_builder: Function for creating a sync caller from configuration. | |
| async_caller_builder: Function for creating an async caller from configuration. | |
| """ | |
| self.default_caller = default_caller | |
| self.default_async_caller = default_async_caller | |
| self.default_config = default_config | |
| self.caller_builder = caller_builder | |
| self.async_caller_builder = async_caller_builder | |
| # Cache callers by config hash | |
| self._caller_cache: dict[str, LLMCallerProtocol] = {} | |
| self._async_caller_cache: dict[str, AsyncLLMCallerProtocol] = {} | |
| def _config_key(self, config: AgentLLMConfig) -> str: | |
| """Create a cache key for the configuration.""" | |
| return f"{config.base_url}|{config.model_name}|{config.api_key}" | |
| def _merge_config(self, config: AgentLLMConfig) -> AgentLLMConfig: | |
| """Merge the agent configuration with the default configuration.""" | |
| if not self.default_config: | |
| return config | |
| return AgentLLMConfig( | |
| model_name=config.model_name or self.default_config.model_name, | |
| base_url=config.base_url or self.default_config.base_url, | |
| api_key=config.api_key or self.default_config.api_key, | |
| max_tokens=config.max_tokens if config.max_tokens is not None else self.default_config.max_tokens, | |
| temperature=config.temperature if config.temperature is not None else self.default_config.temperature, | |
| timeout=config.timeout if config.timeout is not None else self.default_config.timeout, | |
| top_p=config.top_p if config.top_p is not None else self.default_config.top_p, | |
| stop_sequences=config.stop_sequences or self.default_config.stop_sequences, | |
| extra_params={**self.default_config.extra_params, **config.extra_params}, | |
| ) | |
| def get_caller( | |
| self, | |
| config: AgentLLMConfig | None = None, | |
| _agent_id: str | None = None, | |
| ) -> LLMCallerProtocol | None: | |
| """ | |
| Get sync caller for the given configuration. | |
| Args: | |
| config: Agent LLM configuration. | |
| _agent_id: Agent ID (reserved for future use). | |
| Returns: | |
| LLM caller or None if creation failed. | |
| """ | |
| if config is None or not config.is_configured(): | |
| return self.default_caller | |
| config = self._merge_config(config) | |
| cache_key = self._config_key(config) | |
| if cache_key in self._caller_cache: | |
| return self._caller_cache[cache_key] | |
| if self.caller_builder: | |
| caller = self.caller_builder(config) | |
| self._caller_cache[cache_key] = caller | |
| return caller | |
| return self.default_caller | |
| def get_async_caller( | |
| self, | |
| config: AgentLLMConfig | None = None, | |
| _agent_id: str | None = None, | |
| ) -> AsyncLLMCallerProtocol | None: | |
| """Get async caller for the given configuration.""" | |
| if config is None or not config.is_configured(): | |
| return self.default_async_caller | |
| config = self._merge_config(config) | |
| cache_key = self._config_key(config) | |
| if cache_key in self._async_caller_cache: | |
| return self._async_caller_cache[cache_key] | |
| if self.async_caller_builder: | |
| caller = self.async_caller_builder(config) | |
| self._async_caller_cache[cache_key] = caller | |
| return caller | |
| return self.default_async_caller | |
@classmethod
def create_openai_factory(
| cls, | |
| default_api_key: str | None = None, | |
| default_model: str = "gpt-4", | |
| default_base_url: str = "https://api.openai.com/v1", | |
| default_temperature: float = 0.7, | |
| default_max_tokens: int = 2000, | |
| ) -> "LLMCallerFactory": | |
| """ | |
| Create a factory for OpenAI-compatible APIs. | |
| Automatically creates callers using the openai library. | |
| Args: | |
| default_api_key: Default API key (or $OPENAI_API_KEY). | |
| default_model: Default model name. | |
| default_base_url: Default API URL. | |
| default_temperature: Default temperature. | |
| default_max_tokens: Default maximum tokens. | |
| Example: | |
| factory = LLMCallerFactory.create_openai_factory( | |
| default_api_key="$OPENAI_API_KEY" | |
| ) | |
| runner = MACPRunner(llm_factory=factory) | |
| """ | |
| import os | |
| # Resolve default API key | |
| if default_api_key and default_api_key.startswith("$"): | |
| default_api_key = os.environ.get(default_api_key[1:]) | |
| default_config = AgentLLMConfig( | |
| model_name=default_model, | |
| base_url=default_base_url, | |
| api_key=default_api_key, | |
| temperature=default_temperature, | |
| max_tokens=default_max_tokens, | |
| ) | |
| def build_sync_caller(config: AgentLLMConfig) -> LLMCallerProtocol: | |
| return _create_openai_caller_from_config(config) | |
| def build_async_caller(config: AgentLLMConfig) -> AsyncLLMCallerProtocol: | |
| return _create_async_openai_caller_from_config(config) | |
| return cls( | |
| default_config=default_config, | |
| caller_builder=build_sync_caller, | |
| async_caller_builder=build_async_caller, | |
| ) | |
| def _create_openai_caller_from_config(config: AgentLLMConfig) -> LLMCallerProtocol: | |
| """Create an OpenAI-compatible sync caller from configuration.""" | |
| try: | |
| from openai import OpenAI | |
| except ImportError as e: | |
| msg = "openai package required. Install with: pip install openai" | |
| raise ImportError(msg) from e | |
| api_key = config.resolve_api_key() | |
| client = OpenAI( | |
| api_key=api_key, | |
| base_url=config.base_url, | |
| timeout=config.timeout or 60.0, | |
| ) | |
| gen_params = config.to_generation_params() | |
| model = config.model_name or "gpt-4" | |
| def caller(prompt: str) -> str: | |
| response = client.chat.completions.create( | |
| model=model, | |
| messages=[{"role": "user", "content": prompt}], | |
| **gen_params, | |
| ) | |
| return response.choices[0].message.content or "" | |
| return caller | |
| def _create_async_openai_caller_from_config(config: AgentLLMConfig) -> AsyncLLMCallerProtocol: | |
| """Create an OpenAI-compatible async caller from configuration.""" | |
| try: | |
| from openai import AsyncOpenAI | |
| except ImportError as e: | |
| msg = "openai package required. Install with: pip install openai" | |
| raise ImportError(msg) from e | |
| api_key = config.resolve_api_key() | |
| client = AsyncOpenAI( | |
| api_key=api_key, | |
| base_url=config.base_url, | |
| timeout=config.timeout or 60.0, | |
| ) | |
| gen_params = config.to_generation_params() | |
| model = config.model_name or "gpt-4" | |
| async def caller(prompt: str) -> str: | |
| response = await client.chat.completions.create( | |
| model=model, | |
| messages=[{"role": "user", "content": prompt}], | |
| **gen_params, | |
| ) | |
| return response.choices[0].message.content or "" | |
| return caller | |
| def create_openai_caller( | |
| api_key: str | None = None, | |
| model: str = "gpt-4", | |
| base_url: str = "https://api.openai.com/v1", | |
| temperature: float = 0.7, | |
| max_tokens: int = 2000, | |
| ) -> LLMCallerProtocol: | |
| """ | |
| Create a simple OpenAI caller (convenience function). | |
| Returns a ``llm_caller``-compatible callable: ``(str) -> str``. | |
| Example: | |
| caller = create_openai_caller(api_key="sk-...", model="gpt-4") | |
| runner = MACPRunner(llm_caller=caller) | |
| """ | |
| config = AgentLLMConfig( | |
| model_name=model, | |
| base_url=base_url, | |
| api_key=api_key, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| ) | |
| return _create_openai_caller_from_config(config) | |
| def create_openai_structured_caller( | |
| api_key: str | None = None, | |
| model: str = "gpt-4", | |
| base_url: str = "https://api.openai.com/v1", | |
| temperature: float = 0.7, | |
| max_tokens: int = 2000, | |
| ) -> StructuredLLMCallerProtocol: | |
| """ | |
| Create an OpenAI structured caller (recommended for chat LLMs). | |
| Returns a ``structured_llm_caller``-compatible callable that receives | |
| a list of ``{"role": ..., "content": ...}`` dicts and returns a string. | |
| This gives the LLM proper system/user role separation — shorter, | |
| more focused responses and fewer tokens in long chains. | |
| Example: | |
| caller = create_openai_structured_caller(api_key="sk-...", model="gpt-4o") | |
| runner = MACPRunner(structured_llm_caller=caller) | |
| """ | |
| try: | |
| from openai import OpenAI | |
| except ImportError as e: | |
| msg = "openai package required. Install with: pip install openai" | |
| raise ImportError(msg) from e | |
| client = OpenAI(api_key=api_key, base_url=base_url) | |
| def caller(messages: list[dict[str, str]]) -> str: | |
| response = client.chat.completions.create( | |
| model=model, | |
| messages=messages, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| ) | |
| return response.choices[0].message.content or "" | |
| return caller | |
| def create_openai_async_structured_caller( | |
| api_key: str | None = None, | |
| model: str = "gpt-4", | |
| base_url: str = "https://api.openai.com/v1", | |
| temperature: float = 0.7, | |
| max_tokens: int = 2000, | |
| ) -> AsyncStructuredLLMCallerProtocol: | |
| """ | |
| Create an async OpenAI structured caller. | |
| Returns an ``async_structured_llm_caller``-compatible callable. | |
| Required for parallel execution via ``astream()`` with | |
| ``enable_parallel=True``. | |
| Example: | |
| sync_caller = create_openai_structured_caller(api_key="sk-...") | |
| async_caller = create_openai_async_structured_caller(api_key="sk-...") | |
| runner = MACPRunner( | |
| structured_llm_caller=sync_caller, | |
| async_structured_llm_caller=async_caller, | |
| config=RunnerConfig(enable_parallel=True), | |
| ) | |
| # Sequential topologies — use stream() | |
| for event in runner.stream(graph): | |
| ... | |
| # Parallel topologies — use astream() | |
| async for event in runner.astream(graph): | |
| ... | |
| """ | |
| try: | |
| from openai import AsyncOpenAI | |
| except ImportError as e: | |
| msg = "openai package required. Install with: pip install openai" | |
| raise ImportError(msg) from e | |
| client = AsyncOpenAI(api_key=api_key, base_url=base_url) | |
| async def caller(messages: list[dict[str, str]]) -> str: | |
| response = await client.chat.completions.create( | |
| model=model, | |
| messages=messages, | |
| temperature=temperature, | |
| max_tokens=max_tokens, | |
| ) | |
| return response.choices[0].message.content or "" | |
| return caller | |
| class HiddenState(BaseModel): | |
| """Agent hidden state/embeddings passed via hidden channels.""" | |
| model_config = ConfigDict(arbitrary_types_allowed=True) | |
| tensor: torch.Tensor | None = None | |
| embedding: torch.Tensor | None = None | |
| metadata: dict[str, Any] = Field(default_factory=dict) | |
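# Illustrative: an upstream agent could attach an embedding for downstream consumers
# (the tensor shape and metadata keys below are arbitrary, not a fixed contract):
#
#     state = HiddenState(embedding=torch.randn(768), metadata={"source": "solver"})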
| # ============================================================================= | |
| # DYNAMIC TOPOLOGY (runtime graph modification) | |
| # ============================================================================= | |
| class StepContext(BaseModel): | |
| """ | |
| Context of the current execution step for making graph modification decisions. | |
| Passed to hooks for dynamic topology modification at runtime. | |
| Attributes: | |
| agent_id: ID of the current agent. | |
| response: Agent response (if already received). | |
| messages: All agent responses so far. | |
| step_result: Result of the current step. | |
| execution_order: Execution order up to the current moment. | |
| remaining_agents: Agents that have not yet been executed. | |
| query: Original query. | |
| total_tokens: Tokens used. | |
| metadata: Arbitrary metadata. | |
| """ | |
| model_config = ConfigDict(arbitrary_types_allowed=True) | |
| agent_id: str | |
| response: str | None = None | |
| messages: dict[str, str] = Field(default_factory=dict) | |
| step_result: StepResult | None = None | |
| execution_order: list[str] = Field(default_factory=list) | |
| remaining_agents: list[str] = Field(default_factory=list) | |
| query: str = "" | |
| total_tokens: int = 0 | |
| metadata: dict[str, Any] = Field(default_factory=dict) | |
| class TopologyAction(BaseModel): | |
| """ | |
| Action for modifying the graph/plan topology at runtime. | |
| Returned from hooks for dynamic graph and ExecutionPlan modification. | |
| Supports two levels of modification: | |
| - Graph (add_edges, remove_edges) — for non-adaptive mode. | |
| - Plan (skip_agents, force_agents, condition_skip/unskip) — for adaptive mode. | |
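Illustrative hook returning an action (agent ids are placeholders):

    def my_hook(ctx: StepContext, graph) -> TopologyAction | None:
        if "needs review" in (ctx.response or "").lower():
            return TopologyAction(force_agents=["reviewer"])
        return None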
| """ | |
| # Early stopping | |
| early_stop: bool = False | |
| early_stop_reason: str | None = None | |
| # Adding/removing edges (graph modification) | |
| add_edges: list[tuple[str, str, float]] = Field(default_factory=list) # (src, tgt, weight) | |
| remove_edges: list[tuple[str, str]] = Field(default_factory=list) | |
| # Skip agents (skipped — permanently excluded from the plan) | |
| skip_agents: list[str] = Field(default_factory=list) | |
| # Force agent execution (even if not in the plan — added to the plan) | |
| force_agents: list[str] = Field(default_factory=list) | |
| # Conditional skip (condition_skipped — can be restored when conditions change) | |
| condition_skip_agents: list[str] = Field(default_factory=list) | |
| condition_unskip_agents: list[str] = Field(default_factory=list) | |
| # Add agents to the plan with their unconditional chains | |
| insert_chains: list[tuple[str, str]] = Field(default_factory=list) | |
| # Change the end agent | |
| new_end_agent: str | None = None | |
| # Reserved for future use | |
| trigger_rebuild: bool = False | |
| # Type aliases for hooks | |
| TopologyHook = Callable[[StepContext, Any], TopologyAction | None] | |
| AsyncTopologyHook = Callable[[StepContext, Any], Awaitable[TopologyAction | None]] | |
| class EarlyStopCondition: | |
| """ | |
| Condition for early stopping of execution. | |
| Allows stopping graph execution based on an arbitrary condition. | |
| The condition is any function (ctx: StepContext) -> bool. | |
| Attributes: | |
| condition: Condition evaluation function. | |
| reason: Stop reason (for logging/debugging). | |
| after_agents: List of agents after which to check the condition. | |
| min_agents_executed: Minimum number of agents that must execute before checking. | |
| Example: | |
| # Arbitrary condition — any logic | |
| stop_condition = EarlyStopCondition( | |
| condition=lambda ctx: my_complex_check(ctx.messages, ctx.metadata), | |
| reason="Custom condition met" | |
| ) | |
| # Condition based on metrics | |
| stop_condition = EarlyStopCondition( | |
| condition=lambda ctx: ctx.metadata.get("quality_score", 0) > 0.9, | |
| reason="Quality threshold reached" | |
| ) | |
| # Stop after reaching a token limit | |
| stop_condition = EarlyStopCondition( | |
| condition=lambda ctx: ctx.total_tokens > 5000, | |
| reason="Token limit reached" | |
| ) | |
| runner = MACPRunner( | |
| llm_caller=my_llm, | |
| config=RunnerConfig(early_stop_conditions=[stop_condition]) | |
| ) | |
| """ | |
| def __init__( | |
| self, | |
| condition: Callable[[StepContext], bool], | |
| reason: str = "Early stop condition met", | |
| after_agents: list[str] | None = None, | |
| min_agents_executed: int = 0, | |
| ): | |
| """ | |
| Create an early stop condition. | |
| Args: | |
| condition: Arbitrary condition function (ctx -> bool). | |
| reason: Stop reason (for logging). | |
| after_agents: List of agents after which to check the condition | |
| (if None — check after every agent). | |
| min_agents_executed: Minimum number of agents that must execute | |
| before checking the condition. | |
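Example (illustrative agent ids):
    stop = EarlyStopCondition(
        condition=lambda ctx: "DONE" in (ctx.response or ""),
        after_agents=["reviewer"],     # only evaluated after the reviewer runs
        min_agents_executed=2,
    )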
| """ | |
| self.condition = condition | |
| self.reason = reason | |
| self.after_agents = after_agents | |
| self.min_agents_executed = min_agents_executed | |
| def should_stop(self, ctx: StepContext) -> tuple[bool, str]: | |
| """Check whether execution should stop.""" | |
| # Check minimum number of executed agents | |
| if len(ctx.execution_order) < self.min_agents_executed: | |
| return False, "" | |
| # Check whether the current agent matches | |
| if self.after_agents and ctx.agent_id not in self.after_agents: | |
| return False, "" | |
| try: | |
| if self.condition(ctx): | |
| return True, self.reason | |
| except (ValueError, TypeError, KeyError, AttributeError, RuntimeError): | |
| # Condition evaluation failed, treat as not met | |
| return False, "" | |
| return False, "" | |
| # ========================================================================= | |
| # FACTORY METHODS for common conditions | |
| # ========================================================================= | |
@classmethod
def on_keyword(
| cls, | |
| keyword: str, | |
| reason: str | None = None, | |
| *, | |
| case_sensitive: bool = False, | |
| in_last_response: bool = True, | |
| ) -> "EarlyStopCondition": | |
| """ | |
| Stop if the response contains a keyword. | |
| Args: | |
| keyword: Keyword to search for. | |
| reason: Stop reason. | |
| case_sensitive: Whether to perform a case-sensitive match. | |
| in_last_response: Search only in the last response (otherwise in all). | |
| Example: | |
| stop = EarlyStopCondition.on_keyword("FINAL ANSWER") | |
| """ | |
| def check(ctx: StepContext) -> bool: | |
text = (ctx.response or "") if in_last_response else " ".join(ctx.messages.values())
| if case_sensitive: | |
| return keyword in text | |
| return keyword.lower() in text.lower() | |
| return cls(condition=check, reason=reason or f"Keyword '{keyword}' found") | |
@classmethod
def on_token_limit(
| cls, | |
| max_tokens: int, | |
| reason: str | None = None, | |
| ) -> "EarlyStopCondition": | |
| """ | |
| Stop when the token limit is reached. | |
| Args: | |
| max_tokens: Maximum number of tokens. | |
| reason: Stop reason. | |
| Example: | |
| stop = EarlyStopCondition.on_token_limit(5000) | |
| """ | |
| return cls( | |
| condition=lambda ctx: ctx.total_tokens >= max_tokens, | |
| reason=reason or f"Token limit {max_tokens} reached", | |
| ) | |
@classmethod
def on_agent_count(
| cls, | |
| max_agents: int, | |
| reason: str | None = None, | |
| ) -> "EarlyStopCondition": | |
| """ | |
| Stop after N agents have been executed. | |
| Args: | |
| max_agents: Maximum number of agents. | |
| reason: Stop reason. | |
| Example: | |
| stop = EarlyStopCondition.on_agent_count(3) | |
| """ | |
| return cls( | |
| condition=lambda ctx: len(ctx.execution_order) >= max_agents, | |
| reason=reason or f"Agent count limit {max_agents} reached", | |
| ) | |
@classmethod
def on_metadata(
| cls, | |
| key: str, | |
| value: Any = None, | |
| comparator: Callable[[Any, Any], bool] | None = None, | |
| reason: str | None = None, | |
| ) -> "EarlyStopCondition": | |
| """ | |
| Stop based on a value in metadata. | |
| Args: | |
| key: Key in metadata. | |
| value: Expected value (if None — only presence is checked). | |
| comparator: Comparison function (default: ==). | |
| reason: Stop reason. | |
| Example: | |
| # Stop if quality > 0.9 | |
| stop = EarlyStopCondition.on_metadata( | |
| "quality", 0.9, | |
| comparator=lambda v, threshold: v > threshold | |
| ) | |
| """ | |
| def check(ctx: StepContext) -> bool: | |
| if key not in ctx.metadata: | |
| return False | |
| actual = ctx.metadata[key] | |
| if value is None: | |
| return True | |
| if comparator: | |
| return comparator(actual, value) | |
| return actual == value | |
| return cls(condition=check, reason=reason or f"Metadata condition met: {key}") | |
@classmethod
def on_custom(
| cls, | |
| condition: Callable[[StepContext], bool], | |
| reason: str = "Custom condition met", | |
| **kwargs, | |
| ) -> "EarlyStopCondition": | |
| """ | |
| Create a condition with an arbitrary function (constructor alias). | |
| Args: | |
| condition: Arbitrary evaluation function. | |
| reason: Stop reason. | |
| **kwargs: Additional parameters (after_agents, min_agents_executed). | |
| Example: | |
| stop = EarlyStopCondition.on_custom( | |
| lambda ctx: my_rl_agent.should_stop(ctx.messages), | |
| reason="RL agent decided to stop" | |
| ) | |
| """ | |
| return cls(condition=condition, reason=reason, **kwargs) | |
@classmethod
def combine_any(
| cls, | |
| conditions: list["EarlyStopCondition"], | |
| reason: str = "One of conditions met", | |
| ) -> "EarlyStopCondition": | |
| """ | |
| Combine conditions with OR (stop if at least one is met). | |
| Args: | |
| conditions: List of conditions. | |
| reason: Stop reason. | |
| Example: | |
| stop = EarlyStopCondition.combine_any([ | |
| EarlyStopCondition.on_keyword("DONE"), | |
| EarlyStopCondition.on_token_limit(10000), | |
| ]) | |
| """ | |
| def check(ctx: StepContext) -> bool: | |
| for cond in conditions: | |
| should_stop, _ = cond.should_stop(ctx) | |
| if should_stop: | |
| return True | |
| return False | |
| return cls(condition=check, reason=reason) | |
@classmethod
def combine_all(
| cls, | |
| conditions: list["EarlyStopCondition"], | |
| reason: str = "All conditions met", | |
| ) -> "EarlyStopCondition": | |
| """ | |
| Combine conditions with AND (stop if all are met). | |
| Args: | |
| conditions: List of conditions. | |
| reason: Stop reason. | |
| Example: | |
| stop = EarlyStopCondition.combine_all([ | |
| EarlyStopCondition.on_keyword("answer"), | |
| EarlyStopCondition.on_metadata("confidence", 0.8, lambda v, t: v > t), | |
| ]) | |
| """ | |
| def check(ctx: StepContext) -> bool: | |
| for cond in conditions: | |
| should_stop, _ = cond.should_stop(ctx) | |
| if not should_stop: | |
| return False | |
| return True | |
| return cls(condition=check, reason=reason) | |
| class MACPResult(NamedTuple): | |
| """MACP execution result with messages, metrics, and states.""" | |
| messages: dict[str, str] | |
| final_answer: str | |
| final_agent_id: str | |
| execution_order: list[str] | |
| agent_states: dict[str, list[dict[str, Any]]] | None = None | |
| step_results: dict[str, StepResult] | None = None | |
| total_tokens: int = 0 | |
| total_time: float = 0.0 | |
| topology_changed_count: int = 0 | |
| fallback_count: int = 0 | |
| pruned_agents: list[str] | None = None | |
| errors: list[ExecutionError] | None = None | |
| hidden_states: dict[str, HiddenState] | None = None | |
| metrics: ExecutionMetrics | None = None | |
| budget_summary: dict[str, Any] | None = None | |
| # Dynamic topology | |
| early_stopped: bool = False | |
| early_stop_reason: str | None = None | |
| topology_modifications: int = 0 # Number of topology modifications | |
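# Illustrative: inspecting a result after a run (``graph`` is a prepared RoleGraph):
#
#     result = runner.run_round(graph)
#     print(result.final_answer, result.execution_order, result.total_tokens)
#     if result.early_stopped:
#         print("stopped early:", result.early_stop_reason)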
| class ExecutionContext(NamedTuple): | |
| """ | |
| Common initialization context for all execution methods. | |
| Assembled by ``MACPRunner._prepare_base_context`` and contains | |
| parsed graph data: indices, agent identifiers, and lookups. | |
| """ | |
| task_idx: int | |
| a_agents: Any # torch.Tensor adjacency matrix | |
| agent_ids: list[str] | |
| query: str | |
| agent_lookup: dict[str, Any] | |
| agent_names: dict[str, str] | |
| class RunnerConfig(BaseModel): | |
| """Runner configuration: timeouts, adaptivity, parallelism, budgets, and logging.""" | |
| model_config = ConfigDict(arbitrary_types_allowed=True) | |
| timeout: float = 60.0 | |
| adaptive: bool = False | |
| enable_parallel: bool = True | |
| max_parallel_size: int = 5 | |
| max_retries: int = 2 | |
| retry_delay: float = 1.0 | |
| retry_backoff: float = 2.0 | |
| update_states: bool = True | |
| routing_policy: RoutingPolicy = RoutingPolicy.TOPOLOGICAL | |
| pruning_config: PruningConfig | None = None | |
| enable_hidden_channels: bool = False | |
| hidden_combine_strategy: str = "mean" | |
| pass_embeddings: bool = True | |
| error_policy: ErrorPolicy = Field(default_factory=ErrorPolicy) | |
| budget_config: BudgetConfig | None = None | |
| callbacks: list[Handler] = Field(default_factory=list) | |
| # Memory integration | |
| enable_memory: bool = False | |
| memory_config: MemoryConfig | None = None | |
| memory_context_limit: int = 5 # number of memory entries to include in the prompt | |
| # Streaming configuration | |
| enable_token_streaming: bool = False # Enable token-level streaming if LLM supports it | |
| stream_callbacks: list[StreamCallback] = Field(default_factory=list) | |
| async_stream_callbacks: list[AsyncStreamCallback] = Field(default_factory=list) | |
| prompt_preview_length: int = 100 # How many chars of prompt to include in events | |
| # Task query broadcast mode | |
| broadcast_task_to_all: bool = True # True: task query is broadcast to all agents | |
| # False: only to agents directly connected to the task node | |
| # Dynamic topology (graph modification at runtime) | |
| enable_dynamic_topology: bool = False # Enable dynamic topology support | |
| topology_hooks: list[Any] = Field(default_factory=list) # TopologyHook callbacks | |
| async_topology_hooks: list[Any] = Field(default_factory=list) # AsyncTopologyHook callbacks | |
| early_stop_conditions: list[Any] = Field(default_factory=list) # EarlyStopCondition list | |
| # Tools support - if an agent has tools, they are used AUTOMATICALLY | |
| max_tool_iterations: int = 3 # Maximum tool calling iterations per agent | |
| tool_registry: Any | None = None # ToolRegistry (optional, can use global registry) | |
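# Illustrative configuration (all values are arbitrary; ``my_hook`` and ``my_llm``
# are placeholders):
#
#     config = RunnerConfig(
#         adaptive=True,
#         enable_memory=True,
#         enable_dynamic_topology=True,
#         topology_hooks=[my_hook],
#         early_stop_conditions=[EarlyStopCondition.on_token_limit(5000)],
#     )
#     runner = MACPRunner(llm_caller=my_llm, config=config)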
| class MACPRunner: | |
| """ | |
| MACP protocol executor for RoleGraph with sync/async and adaptive mode. | |
Supports the following execution modes:
| - Batch: run_round() / arun_round() - returns complete result | |
| - Streaming: stream() / astream() - yields events during execution | |
| Multi-model support: | |
| Each agent can use its own LLM. Three ways to specify models: | |
| 1. A single llm_caller for all agents (legacy) | |
| 2. llm_callers: dict[agent_id, Callable] — different models for different agents | |
| 3. llm_factory: LLMCallerFactory — dynamic caller creation | |
| Priority: llm_callers[agent_id] > factory > default llm_caller | |
| Streaming Example: | |
| runner = MACPRunner(llm_caller=my_llm) | |
| # Sync streaming | |
| for event in runner.stream(graph): | |
| if event.event_type == StreamEventType.AGENT_OUTPUT: | |
| print(f"{event.agent_id}: {event.content}") | |
| # Async streaming | |
| async for event in runner.astream(graph): | |
| print(event) | |
| Multi-model Example: | |
| # Option 1: Different callers for different agents | |
| runner = MACPRunner( | |
| llm_caller=default_caller, | |
| llm_callers={ | |
| "solver": gpt4_caller, | |
| "reviewer": claude_caller, | |
| "analyzer": local_llama_caller, | |
| } | |
| ) | |
| # Option 2: Factory (automatically creates callers from agent LLM configurations) | |
| factory = LLMCallerFactory.create_openai_factory(default_api_key="...") | |
| runner = MACPRunner(llm_factory=factory) | |
| # Agents with llm_config will automatically receive their callers: | |
| builder.add_agent("solver", llm_backbone="gpt-4", temperature=0.7) | |
| builder.add_agent("analyzer", llm_backbone="gpt-4o-mini", temperature=0.0) | |
| Structured Prompt (recommended for modern chat LLMs): | |
| Instead of the legacy ``llm_caller(str) -> str`` interface, use | |
| ``structured_llm_caller`` to send proper system/user roles to the LLM. | |
| This avoids the "prompt as one string" problem where the model has to | |
| re-parse a flat blob, and typically results in shorter, more focused | |
| responses with fewer tokens — especially in long agent chains. | |
| from openai import OpenAI | |
| client = OpenAI(api_key="sk-...") | |
| def my_structured_caller(messages: list[dict[str, str]]) -> str: | |
| resp = client.chat.completions.create( | |
| model="gpt-4o", | |
| messages=messages, | |
| max_tokens=1024, | |
| ) | |
| return resp.choices[0].message.content or "" | |
| runner = MACPRunner(structured_llm_caller=my_structured_caller) | |
| result = runner.run_round(graph) | |
| # The runner automatically builds StructuredPrompt with: | |
| # {"role": "system", "content": "<persona + description + tools hint + output_schema>"} | |
| # {"role": "assistant", "content": "<state[0] assistant turn>"} # replayed from agent.state | |
| # {"role": "user", "content": "<state[1] user turn>"} # replayed from agent.state | |
| # ... | |
| # {"role": "user", "content": "<task + input_schema hint + memory + incoming messages>"} | |
| # and passes prompt.messages to your structured_llm_caller. | |
| """ | |
| def __init__( | |
| self, | |
| llm_caller: Callable[[str], str] | None = None, | |
| async_llm_caller: Callable[[str], Awaitable[str]] | None = None, | |
| streaming_llm_caller: Callable[[str], Iterator[str]] | None = None, | |
| async_streaming_llm_caller: Callable[[str], AsyncIterator[str]] | None = None, | |
| token_counter: Callable[[str], int] | None = None, | |
| config: RunnerConfig | None = None, | |
| timeout: int = 60, | |
| memory_pool: SharedMemoryPool | None = None, | |
| # Multi-model support | |
| llm_callers: dict[str, Callable[[str], str]] | None = None, | |
| async_llm_callers: dict[str, Callable[[str], Awaitable[str]]] | None = None, | |
| llm_factory: LLMCallerFactory | None = None, | |
| # Tools support | |
| tool_registry: Any | None = None, | |
| # Structured prompt support (modern chat LLMs) | |
| structured_llm_caller: StructuredLLMCallerProtocol | None = None, | |
| async_structured_llm_caller: AsyncStructuredLLMCallerProtocol | None = None, | |
| ): | |
| """ | |
| Create a MACP runner with multi-model and tools support. | |
| Args: | |
| llm_caller: Default synchronous LLM call (returns full response). | |
| async_llm_caller: Default asynchronous LLM call (returns full response). | |
| streaming_llm_caller: Synchronous streaming LLM (yields tokens). | |
| async_streaming_llm_caller: Asynchronous streaming LLM (async yields tokens). | |
| token_counter: Function to estimate tokens in text. | |
| config: Runner configuration (otherwise created with the given timeout). | |
| timeout: Default timeout (seconds), used if not specified in config. | |
| memory_pool: External SharedMemoryPool (if None — created automatically). | |
| # Multi-model support: | |
| llm_callers: Dict agent_id -> sync caller. Has highest priority. | |
| async_llm_callers: Dict agent_id -> async caller. | |
| llm_factory: Factory for creating callers based on agent LLM configurations. | |
| Used if no explicit caller is set in llm_callers for the agent. | |
| # Tools support: | |
| tool_registry: Tool registry (ToolRegistry). Optional. | |
| If an agent has tools, they are used automatically. | |
| # Structured prompt support: | |
| structured_llm_caller: Callable that receives a list of | |
| ``{"role": "system"|"user", "content": "..."}`` dicts | |
| and returns a string. When provided the runner sends | |
| proper system/user roles to the LLM instead of a flat | |
| string — this typically produces shorter, more focused | |
| responses and saves tokens in long chains. | |
| async_structured_llm_caller: Async version of the above. | |
| Example: | |
| # Multi-model via caller dictionary | |
| runner = MACPRunner( | |
| llm_caller=default_gpt4_caller, | |
| llm_callers={ | |
| "analyzer": create_openai_caller(model="gpt-4o-mini"), | |
| "expert": create_openai_caller(model="gpt-4-turbo"), | |
| } | |
| ) | |
| # Multi-model via factory | |
| factory = LLMCallerFactory.create_openai_factory() | |
| runner = MACPRunner(llm_factory=factory) | |
| # Structured prompt (modern chat LLMs) | |
| def my_chat(messages: list[dict[str, str]]) -> str: | |
| return openai_client.chat.completions.create( | |
| model="gpt-4", messages=messages | |
| ).choices[0].message.content | |
| runner = MACPRunner(structured_llm_caller=my_chat) | |
| """ | |
| self.structured_llm_caller = structured_llm_caller | |
| self.async_structured_llm_caller = async_structured_llm_caller | |
| # When only a structured caller is provided, create a thin | |
| # str->str wrapper so all existing code paths that check | |
| # ``self.llm_caller is not None`` keep working. | |
| if llm_caller is None and structured_llm_caller is not None: | |
| _sc = structured_llm_caller # capture for closure | |
| def _str_wrapper(prompt: str) -> str: | |
| return _sc([{"role": "user", "content": prompt}]) | |
| llm_caller = _str_wrapper | |
| self.llm_caller = llm_caller | |
| self.async_llm_caller = async_llm_caller | |
| self.streaming_llm_caller = streaming_llm_caller | |
| self.async_streaming_llm_caller = async_streaming_llm_caller | |
| self.token_counter = token_counter or self._default_token_counter | |
| self.config = config or RunnerConfig(timeout=float(timeout)) | |
| # Multi-model support | |
| self.llm_callers = llm_callers or {} | |
| self.async_llm_callers = async_llm_callers or {} | |
| self.llm_factory = llm_factory | |
| # Tools support | |
| self.tool_registry = tool_registry | |
| self._scheduler = ( | |
| AdaptiveScheduler( | |
| policy=self.config.routing_policy, | |
| pruning_config=self.config.pruning_config, | |
| ) | |
| if self.config.adaptive | |
| else None | |
| ) | |
| self._callback_manager: CallbackManager | None = None | |
| self._budget_tracker: BudgetTracker | None = None | |
| self._metrics: ExecutionMetrics | None = None | |
| # Memory integration | |
| self._memory_pool: SharedMemoryPool | None = memory_pool | |
| self._agent_memories: dict[str, AgentMemory] = {} | |
| def _init_run( | |
| self, | |
| graph_name: str | None = None, # noqa: ARG002 | |
| num_agents: int = 0, | |
| query: str = "", | |
| execution_order: list[str] | None = None, | |
| callbacks: list[Handler] | None = None, | |
| ) -> uuid.UUID: | |
| """Initialize callbacks, budgets and metrics before running. Returns run_id.""" | |
| # Merge config callbacks with per-run callbacks and context callbacks | |
| all_callbacks = list(self.config.callbacks) | |
| if callbacks: | |
| all_callbacks.extend(callbacks) | |
| # Check for context callback manager | |
| context_manager = get_callback_manager() | |
| if context_manager: | |
| all_callbacks.extend(context_manager.handlers) | |
| self._callback_manager = CallbackManager.configure(handlers=all_callbacks) | |
| if self.config.budget_config: | |
| self._budget_tracker = BudgetTracker(self.config.budget_config) | |
| self._budget_tracker.start() | |
| else: | |
| self._budget_tracker = None | |
| self._metrics = ExecutionMetrics( | |
| start_time=datetime.now(tz=UTC), | |
| total_agents=num_agents, | |
| ) | |
| return self._callback_manager.on_run_start( | |
| query=query, | |
| num_agents=num_agents, | |
| execution_order=execution_order or [], | |
| ) | |
| def _prepare_base_context(self, role_graph: Any) -> "ExecutionContext | None": | |
| """ | |
| Assemble the common initialization context from role_graph. | |
| Extracts task_idx, adjacency matrix, agent_ids list, query, | |
| and lookup dictionaries. Does not call ``_init_memory`` — each execution | |
| method does so itself (accounting for possible agent filtering). | |
| Returns: | |
| ``ExecutionContext`` with graph data, or ``None`` if there are no agents. | |
| """ | |
| task_idx = self._get_task_index(role_graph) | |
| a_agents = extract_agent_adjacency(role_graph.A_com, task_idx) | |
| agent_ids, _ = self._get_agent_ids(role_graph, task_idx) | |
| if not agent_ids: | |
| return None | |
| query = role_graph.query or "" | |
| agent_lookup = {a.agent_id: a for a in role_graph.agents} | |
| agent_names = self._build_agent_names(role_graph) | |
| return ExecutionContext( | |
| task_idx=task_idx, | |
| a_agents=a_agents, | |
| agent_ids=agent_ids, | |
| query=query, | |
| agent_lookup=agent_lookup, | |
| agent_names=agent_names, | |
| ) | |
| def _init_memory(self, agent_ids: list[str]) -> None: | |
| """Initialize memory for agents before execution.""" | |
| if not self.config.enable_memory: | |
| return | |
| if self._memory_pool is None: | |
| self._memory_pool = SharedMemoryPool() | |
| mem_config = self.config.memory_config or MemoryConfig() | |
| for agent_id in agent_ids: | |
| if agent_id not in self._agent_memories: | |
| memory = AgentMemory(agent_id, mem_config) | |
| self._agent_memories[agent_id] = memory | |
| self._memory_pool.register(memory) | |
| def _get_memory_context(self, agent_id: str) -> list[dict[str, Any]]: | |
| """Get the latest entries from the agent's memory for context.""" | |
| if not self.config.enable_memory or agent_id not in self._agent_memories: | |
| return [] | |
| memory = self._agent_memories[agent_id] | |
| return memory.get_messages(limit=self.config.memory_context_limit) | |
| def _save_to_memory( | |
| self, | |
| agent_id: str, | |
| response: str, | |
| incoming_ids: list[str] | None = None, | |
| ) -> None: | |
| """Save the agent's response to its memory and share with neighbors.""" | |
| if not self.config.enable_memory or agent_id not in self._agent_memories: | |
| return | |
| memory = self._agent_memories[agent_id] | |
| entry = memory.add_message(role="assistant", content=response) | |
| # Share with incoming agents (graph neighbors) | |
| if self._memory_pool and incoming_ids: | |
| self._memory_pool.share(agent_id, entry, to_agents=incoming_ids) | |
| def get_agent_memory(self, agent_id: str) -> AgentMemory | None: | |
| """Get the agent's memory by id (for external access).""" | |
| return self._agent_memories.get(agent_id) | |
| def memory_pool(self) -> SharedMemoryPool | None: | |
| """Access to the SharedMemoryPool.""" | |
| return self._memory_pool | |
| # ========================================================================= | |
| # DYNAMIC TOPOLOGY METHODS | |
| # ========================================================================= | |
| def _check_early_stop( | |
| self, | |
| agent_id: str, | |
| response: str | None, | |
| messages: dict[str, str], | |
| execution_order: list[str], | |
| remaining_agents: list[str], | |
| query: str, | |
| total_tokens: int, | |
| ) -> tuple[bool, str]: | |
| """ | |
| Check early stop conditions. | |
| Returns: | |
| (should_stop, reason) | |
| """ | |
| if not self.config.early_stop_conditions: | |
| return False, "" | |
| ctx = StepContext( | |
| agent_id=agent_id, | |
| response=response, | |
| messages=messages, | |
| execution_order=execution_order, | |
| remaining_agents=remaining_agents, | |
| query=query, | |
| total_tokens=total_tokens, | |
| ) | |
| for condition in self.config.early_stop_conditions: | |
| if isinstance(condition, EarlyStopCondition): | |
| should_stop, reason = condition.should_stop(ctx) | |
| if should_stop: | |
| return True, reason | |
| return False, "" | |
| def _apply_topology_hooks( | |
| self, | |
| agent_id: str, | |
| response: str | None, | |
| step_result: StepResult | None, | |
| messages: dict[str, str], | |
| execution_order: list[str], | |
| remaining_agents: list[str], | |
| query: str, | |
| total_tokens: int, | |
| role_graph: Any, | |
| ) -> TopologyAction | None: | |
| """ | |
| Apply sync topology hooks and collect actions. | |
| Returns: | |
| Combined TopologyAction or None. | |
| """ | |
| if not self.config.enable_dynamic_topology or not self.config.topology_hooks: | |
| return None | |
| ctx = StepContext( | |
| agent_id=agent_id, | |
| response=response, | |
| step_result=step_result, | |
| messages=messages, | |
| execution_order=execution_order, | |
| remaining_agents=remaining_agents, | |
| query=query, | |
| total_tokens=total_tokens, | |
| ) | |
| combined_action = TopologyAction() | |
| for hook in self.config.topology_hooks: | |
| try: | |
| action = hook(ctx, role_graph) | |
| if action is not None: | |
| combined_action = self._merge_topology_actions(combined_action, action) | |
| except (ValueError, TypeError, KeyError, RuntimeError): | |
| pass # Ignore hook errors | |
| if self._has_topology_action(combined_action): | |
| return combined_action | |
| return None | |
| async def _apply_async_topology_hooks( | |
| self, | |
| agent_id: str, | |
| response: str | None, | |
| step_result: StepResult | None, | |
| messages: dict[str, str], | |
| execution_order: list[str], | |
| remaining_agents: list[str], | |
| query: str, | |
| total_tokens: int, | |
| role_graph: Any, | |
| ) -> TopologyAction | None: | |
| """Apply async topology hooks and collect actions.""" | |
| if not self.config.enable_dynamic_topology or not self.config.async_topology_hooks: | |
| return None | |
| ctx = StepContext( | |
| agent_id=agent_id, | |
| response=response, | |
| step_result=step_result, | |
| messages=messages, | |
| execution_order=execution_order, | |
| remaining_agents=remaining_agents, | |
| query=query, | |
| total_tokens=total_tokens, | |
| ) | |
| combined_action = TopologyAction() | |
| for hook in self.config.async_topology_hooks: | |
| try: | |
| action = await hook(ctx, role_graph) | |
| if action is not None: | |
| combined_action = self._merge_topology_actions(combined_action, action) | |
| except (ValueError, TypeError, KeyError, RuntimeError): | |
| pass # Ignore async hook errors | |
| if self._has_topology_action(combined_action): | |
| return combined_action | |
| return None | |
| def _merge_topology_actions( | |
| self, | |
| base: TopologyAction, | |
| new: TopologyAction, | |
| ) -> TopologyAction: | |
| """Merge two TopologyAction objects.""" | |
| return TopologyAction( | |
| early_stop=base.early_stop or new.early_stop, | |
| early_stop_reason=new.early_stop_reason or base.early_stop_reason, | |
| add_edges=base.add_edges + new.add_edges, | |
| remove_edges=base.remove_edges + new.remove_edges, | |
| skip_agents=list(set(base.skip_agents + new.skip_agents)), | |
| force_agents=list(set(base.force_agents + new.force_agents)), | |
| condition_skip_agents=list(set(base.condition_skip_agents + new.condition_skip_agents)), | |
| condition_unskip_agents=list(set(base.condition_unskip_agents + new.condition_unskip_agents)), | |
| insert_chains=base.insert_chains + new.insert_chains, | |
| new_end_agent=new.new_end_agent or base.new_end_agent, | |
| trigger_rebuild=base.trigger_rebuild or new.trigger_rebuild, | |
| ) | |
| def _apply_graph_modifications( | |
| self, | |
| role_graph: Any, | |
| action: TopologyAction, | |
| ) -> int: | |
| """Apply modifications to the graph and return the number of changes.""" | |
| modifications = 0 | |
| # Remove edges | |
| for src, tgt in action.remove_edges: | |
| if role_graph.remove_edge(src, tgt): | |
| modifications += 1 | |
| # Add edges | |
| for src, tgt, weight in action.add_edges: | |
| if role_graph.add_edge(src, tgt, weight): | |
| modifications += 1 | |
| return modifications | |
@staticmethod
def _has_topology_action(action: TopologyAction) -> bool:
| """Check whether the TopologyAction contains any actions.""" | |
| return bool( | |
| action.early_stop | |
| or action.add_edges | |
| or action.remove_edges | |
| or action.skip_agents | |
| or action.force_agents | |
| or action.condition_skip_agents | |
| or action.condition_unskip_agents | |
| or action.insert_chains | |
| or action.new_end_agent | |
| or action.trigger_rebuild | |
| ) | |
| def _build_conditional_edge_action( # noqa: PLR0912 | |
| self, | |
| last_agent: str, | |
| agent_ids: list[str], | |
| step_results: dict[str, StepResult], | |
| messages: dict[str, str], | |
| query: str, | |
| remaining_ids: set[str], | |
| ) -> TopologyAction | None: | |
| """ | |
| Built-in topology hook: evaluate conditional edges and return TopologyAction. | |
| Checks outgoing conditional edges of the agent and builds an action: | |
| - condition_skip_agents: agents whose conditions are not met | |
| (only if there are no other unevaluated incoming conditional edges). | |
| - condition_unskip_agents + insert_chains: agents whose conditions are met. | |
| With multiple incoming conditional edges (A→B, C→B): | |
| - If A→B is not met but C has not yet run — B is NOT skipped. | |
| - If C→B is met — B is unskipped. | |
| - B is skipped only if ALL incoming conditional edges have been evaluated and ALL are not met. | |
| Complexity: O(E_cond) — only conditional edges are checked. | |
| """ | |
| if self._scheduler is None: | |
| return None | |
| edge_conditions = self._scheduler._last_edge_conditions # noqa: SLF001 | |
| if not edge_conditions or not messages: | |
| return None | |
| evaluator = self._scheduler.condition_evaluator | |
| executed_agents = set(step_results.keys()) | |
| skip: list[str] = [] | |
| unskip: list[str] = [] | |
| chains: list[tuple[str, str]] = [] | |
| for (source, target), condition in edge_conditions.items(): | |
| if source != last_agent: | |
| continue | |
| if target not in agent_ids: | |
| continue | |
| ctx = ConditionContext( | |
| source_agent=source, | |
| target_agent=target, | |
| messages=messages, | |
| step_results=step_results, | |
| query=query, | |
| ) | |
| if evaluator.evaluate(condition, ctx): | |
| unskip.append(target) | |
| if target not in remaining_ids: | |
| chains.append((target, last_agent)) | |
| elif target in remaining_ids: | |
| # Skip only if: | |
| # 1. No other unevaluated incoming conditional edges (waiting for them). | |
| # 2. No other already-evaluated incoming conditional edge has passed. | |
| has_pending_incoming = False | |
| has_passed_incoming = False | |
| for (src, tgt), cond in edge_conditions.items(): | |
| if tgt != target or src == source: | |
| continue | |
| if src not in executed_agents: | |
| has_pending_incoming = True | |
| break | |
| # Source was already executed — check its condition | |
| other_ctx = ConditionContext( | |
| source_agent=src, | |
| target_agent=target, | |
| messages=messages, | |
| step_results=step_results, | |
| query=query, | |
| ) | |
| if evaluator.evaluate(cond, other_ctx): | |
| has_passed_incoming = True | |
| break | |
| if not has_pending_incoming and not has_passed_incoming: | |
| skip.append(target) | |
| if not skip and not unskip and not chains: | |
| return None | |
| return TopologyAction( | |
| condition_skip_agents=skip, | |
| condition_unskip_agents=unskip, | |
| insert_chains=chains, | |
| ) | |
| def _apply_topology_to_plan( | |
| self, | |
| plan: ExecutionPlan, | |
| action: TopologyAction, | |
| a_agents: Any, | |
| agent_ids: list[str], | |
| ) -> bool: | |
| """ | |
| Apply a TopologyAction to the ExecutionPlan. | |
| Single method for modifying the plan from any source: | |
| user hooks, built-in conditional edge hook, etc. | |
| Args: | |
| plan: Current execution plan. | |
| action: Action to apply. | |
| a_agents: Adjacency matrix (for BFS chains). | |
| agent_ids: List of agent IDs. | |
| Returns: | |
| True if the plan was modified. | |
| """ | |
| changed = False | |
| edge_conditions = self._scheduler._last_edge_conditions if self._scheduler else {} # noqa: SLF001 | |
| # 1. Condition skip + cascade skip of unconditional descendants | |
| for agent_id in action.condition_skip_agents: | |
| plan.condition_skipped.add(agent_id) | |
| changed = True | |
| # Cascading condition_skip of unconditional descendants (BFS) | |
| self._cascade_condition_skip( | |
| plan, | |
| agent_id, | |
| a_agents, | |
| agent_ids, | |
| edge_conditions, | |
| ) | |
| # 2. Condition unskip | |
| for agent_id in action.condition_unskip_agents: | |
| plan.condition_skipped.discard(agent_id) | |
| changed = True | |
| # 3. Skip (permanent skip) | |
| remaining_ids = {s.agent_id for s in plan.steps[plan.current_index :]} | |
| for agent_id in action.skip_agents: | |
| if agent_id in remaining_ids: | |
| plan.condition_skipped.add(agent_id) | |
| changed = True | |
| # 4. Force agents — add to the plan if not already there | |
| for agent_id in action.force_agents: | |
| if agent_id not in remaining_ids and agent_id in agent_ids: | |
| plan.condition_skipped.discard(agent_id) | |
| added = plan.insert_conditional_step(agent_id=agent_id, predecessors=[]) | |
| if added: | |
| remaining_ids.add(agent_id) | |
| changed = True | |
| # 5. Insert chains — add the agent + its unconditional chain | |
| for target, predecessor in action.insert_chains: | |
| plan.condition_skipped.discard(target) | |
| if target not in remaining_ids: | |
| added = plan.insert_conditional_step( | |
| agent_id=target, | |
| predecessors=[predecessor], | |
| ) | |
| if added: | |
| remaining_ids.add(target) | |
| changed = True | |
| # BFS — add the unconditional chain after target | |
| self._insert_unconditional_chain( | |
| plan, | |
| target, | |
| a_agents, | |
| agent_ids, | |
| edge_conditions, | |
| remaining_ids, | |
| ) | |
| else: | |
| changed = True | |
| return changed | |
@staticmethod
def _cascade_condition_skip(
| plan: ExecutionPlan, | |
| skipped_agent: str, | |
| a_agents: Any, | |
| agent_ids: list[str], | |
| edge_conditions: dict[tuple[str, str], Any], | |
| ) -> None: | |
| """ | |
| BFS: cascading condition_skip of unconditional descendants of skipped_agent. | |
| If an agent is condition_skipped, its unconditional descendants should also be | |
| skipped (if they have no other incoming data paths). | |
| """ | |
| queue = deque([skipped_agent]) | |
| visited = {skipped_agent} | |
| remaining_ids = {s.agent_id for s in plan.steps[plan.current_index :]} | |
| while queue: | |
| current = queue.popleft() | |
| if current not in agent_ids: | |
| continue | |
| current_idx = agent_ids.index(current) | |
| for j, aid in enumerate(agent_ids): | |
| if aid in visited or aid not in remaining_ids: | |
| continue | |
| weight = a_agents[current_idx, j] | |
| if hasattr(weight, "item"): | |
| weight = weight.item() | |
| if weight <= _MIN_EDGE_WEIGHT: | |
| continue | |
| # Skip conditional edges — they are handled separately | |
| if (current, aid) in edge_conditions: | |
| continue | |
| # Check: does aid have other incoming unconditional edges | |
| # from agents that are NOT condition_skipped? | |
| has_other_source = False | |
| for k, src in enumerate(agent_ids): | |
| if src == current or src in plan.condition_skipped: | |
| continue | |
| w = a_agents[k, j] | |
| if hasattr(w, "item"): | |
| w = w.item() | |
| if w > _MIN_EDGE_WEIGHT and (src, aid) not in edge_conditions: | |
| has_other_source = True | |
| break | |
| if not has_other_source: | |
| visited.add(aid) | |
| plan.condition_skipped.add(aid) | |
| queue.append(aid) | |
@staticmethod
def _insert_unconditional_chain(
| plan: ExecutionPlan, | |
| start_agent: str, | |
| a_agents: Any, | |
| agent_ids: list[str], | |
| edge_conditions: dict[tuple[str, str], Any], | |
| remaining_ids: set[str], | |
| ) -> None: | |
| """ | |
| BFS: add a chain of unconditionally linked agents to the plan after start_agent. | |
| Traverses edges without conditions and adds all subsequent agents to the plan. | |
| """ | |
| queue = deque([start_agent]) | |
| visited = {start_agent} | |
| while queue: | |
| current = queue.popleft() | |
| if current not in agent_ids: | |
| continue | |
| current_idx = agent_ids.index(current) | |
| for j, aid in enumerate(agent_ids): | |
| if aid in visited: | |
| continue | |
| weight = a_agents[current_idx, j] | |
| if hasattr(weight, "item"): | |
| weight = weight.item() | |
| if weight <= _MIN_EDGE_WEIGHT: | |
| continue | |
| # Skip conditional edges — they are handled separately | |
| if (current, aid) in edge_conditions: | |
| continue | |
| visited.add(aid) | |
| if aid not in remaining_ids: | |
| plan.condition_skipped.discard(aid) | |
| added = plan.insert_conditional_step(agent_id=aid, predecessors=[current]) | |
| if added: | |
| remaining_ids.add(aid) | |
| queue.append(aid) | |
| def _run_topology_pipeline( | |
| self, | |
| plan: ExecutionPlan, | |
| last_agent: str, | |
| a_agents: Any, | |
| agent_ids: list[str], | |
| step_results: dict[str, StepResult], | |
| messages: dict[str, str], | |
| query: str, | |
| execution_order: list[str], | |
| total_tokens: int, | |
| role_graph: Any, | |
| ) -> bool: | |
| """ | |
| Unified sync pipeline: built-in conditional hook + user hooks → plan. | |
| Called after each step in adaptive methods. | |
| Combines all TopologyAction objects and applies them to the plan. | |
| Returns: | |
| True if the plan was modified. | |
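| Example outcome (illustrative; keyword construction of ``TopologyAction`` | |
| is an assumption based on the fields used below): if the built-in hook | |
| yields ``TopologyAction(skip_agents={"critic"})`` and a user hook yields | |
| ``TopologyAction(force_agents={"validator"})``, the merged action marks | |
| ``critic`` as condition-skipped and appends ``validator`` to the plan. | |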
| """ | |
| remaining = [s.agent_id for s in plan.remaining_steps] | |
| # 1. Built-in hook: conditional edges | |
| remaining_ids = {s.agent_id for s in plan.steps[plan.current_index :]} | |
| cond_action = self._build_conditional_edge_action( | |
| last_agent, | |
| agent_ids, | |
| step_results, | |
| messages, | |
| query, | |
| remaining_ids, | |
| ) | |
| # 2. User topology hooks | |
| user_action = self._apply_topology_hooks( | |
| last_agent, | |
| messages.get(last_agent), | |
| step_results.get(last_agent), | |
| messages, | |
| execution_order, | |
| remaining, | |
| query, | |
| total_tokens, | |
| role_graph, | |
| ) | |
| # 3. Combine actions | |
| combined = TopologyAction() | |
| if cond_action is not None: | |
| combined = self._merge_topology_actions(combined, cond_action) | |
| if user_action is not None: | |
| combined = self._merge_topology_actions(combined, user_action) | |
| if not self._has_topology_action(combined): | |
| return False | |
| # 4. Graph modification (add_edges / remove_edges) | |
| if combined.add_edges or combined.remove_edges: | |
| self._apply_graph_modifications(role_graph, combined) | |
| # 5. Plan modification | |
| return self._apply_topology_to_plan(plan, combined, a_agents, agent_ids) | |
| async def _arun_topology_pipeline( | |
| self, | |
| plan: ExecutionPlan, | |
| last_agent: str, | |
| a_agents: Any, | |
| agent_ids: list[str], | |
| step_results: dict[str, StepResult], | |
| messages: dict[str, str], | |
| query: str, | |
| execution_order: list[str], | |
| total_tokens: int, | |
| role_graph: Any, | |
| ) -> bool: | |
| """ | |
| Unified async pipeline: built-in conditional hook + async user hooks → plan. | |
| Returns: | |
| True if the plan was modified. | |
| """ | |
| remaining = [s.agent_id for s in plan.remaining_steps] | |
| # 1. Built-in hook: conditional edges (sync — fast) | |
| remaining_ids = {s.agent_id for s in plan.steps[plan.current_index :]} | |
| cond_action = self._build_conditional_edge_action( | |
| last_agent, | |
| agent_ids, | |
| step_results, | |
| messages, | |
| query, | |
| remaining_ids, | |
| ) | |
| # 2. User async topology hooks | |
| user_action = await self._apply_async_topology_hooks( | |
| last_agent, | |
| messages.get(last_agent), | |
| step_results.get(last_agent), | |
| messages, | |
| execution_order, | |
| remaining, | |
| query, | |
| total_tokens, | |
| role_graph, | |
| ) | |
| # 3. Combine actions | |
| combined = TopologyAction() | |
| if cond_action is not None: | |
| combined = self._merge_topology_actions(combined, cond_action) | |
| if user_action is not None: | |
| combined = self._merge_topology_actions(combined, user_action) | |
| if not self._has_topology_action(combined): | |
| return False | |
| # 4. Graph modification | |
| if combined.add_edges or combined.remove_edges: | |
| self._apply_graph_modifications(role_graph, combined) | |
| # 5. Plan modification | |
| return self._apply_topology_to_plan(plan, combined, a_agents, agent_ids) | |
| def _finalize_run( | |
| self, | |
| run_id: uuid.UUID, | |
| *, | |
| success: bool, | |
| executed_agents: int, | |
| final_answer: str = "", | |
| error: BaseException | None = None, | |
| executed_agent_ids: list[str] | None = None, | |
| ) -> None: | |
| """Finalize metrics and notify callbacks after execution.""" | |
| if self._metrics: | |
| self._metrics.end_time = datetime.now(tz=UTC) | |
| self._metrics.executed_agents = executed_agents | |
| if self._callback_manager: | |
| self._callback_manager.on_run_end( | |
| run_id=run_id, | |
| output=final_answer, | |
| success=success, | |
| error=error, | |
| total_tokens=self._metrics.total_tokens if self._metrics else 0, | |
| total_time_ms=self._metrics.duration_seconds * 1000 if self._metrics else 0, | |
| executed_agents=executed_agent_ids or [], | |
| ) | |
| def _default_token_counter(self, text: str) -> int: | |
| """Simple token estimate: 4/3 of the number of words.""" | |
| return len(text.split()) * 4 // 3 | |
| def _get_caller_for_agent( | |
| self, | |
| agent_id: str, | |
| agent: Any, | |
| ) -> Callable[[str], str] | None: | |
| """ | |
| Get sync LLM caller for a specific agent. | |
| Priority: | |
| 1. llm_callers[agent_id] — explicitly specified caller for the agent | |
| 2. llm_factory.get_caller(agent.llm_config) — from the factory by configuration | |
| 3. self.llm_caller — default caller | |
| Returns: | |
| LLM caller or None if none is available. | |
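| Example (illustrative; the constructor keywords and the ``gpt4_call`` / | |
| ``claude_call`` callables are assumptions): | |
| runner = MACPRunner(llm_caller=gpt4_call, llm_callers={"researcher": claude_call}) | |
| # "researcher" then resolves to claude_call; every other agent falls back to gpt4_call. | |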
| """ | |
| # 1. Check explicit per-agent caller | |
| if agent_id in self.llm_callers: | |
| return self.llm_callers[agent_id] | |
| # 2. Try factory if agent has LLM config | |
| if self.llm_factory and hasattr(agent, "get_llm_config"): | |
| llm_config = agent.get_llm_config() | |
| if llm_config and llm_config.is_configured(): | |
| caller = self.llm_factory.get_caller(llm_config, agent_id) | |
| if caller: | |
| return caller | |
| elif self.llm_factory and hasattr(agent, "llm_config") and agent.llm_config: | |
| caller = self.llm_factory.get_caller(agent.llm_config, agent_id) | |
| if caller: | |
| return caller | |
| # 3. Fallback to default caller | |
| return self.llm_caller | |
| def _get_async_caller_for_agent( | |
| self, | |
| agent_id: str, | |
| agent: Any, | |
| ) -> Callable[[str], Awaitable[str]] | None: | |
| """ | |
| Get async LLM caller for a specific agent. | |
| Priority: | |
| 1. async_llm_callers[agent_id] — explicitly specified async caller | |
| 2. llm_factory.get_async_caller(agent.llm_config) — from the factory | |
| 3. self.async_llm_caller — default async caller | |
| Note: | |
| May return ``None`` even when ``async_structured_llm_caller`` is | |
| set. Callers must check ``self.async_structured_llm_caller`` | |
| separately before treating ``None`` as a fatal error — see | |
| :meth:`_acall_llm` which dispatches to the structured path first. | |
| """ | |
| # 1. Check explicit per-agent async caller | |
| if agent_id in self.async_llm_callers: | |
| return self.async_llm_callers[agent_id] | |
| # 2. Try factory if agent has LLM config | |
| if self.llm_factory and hasattr(agent, "get_llm_config"): | |
| llm_config = agent.get_llm_config() | |
| if llm_config and llm_config.is_configured(): | |
| caller = self.llm_factory.get_async_caller(llm_config, agent_id) | |
| if caller: | |
| return caller | |
| elif self.llm_factory and hasattr(agent, "llm_config") and agent.llm_config: | |
| caller = self.llm_factory.get_async_caller(agent.llm_config, agent_id) | |
| if caller: | |
| return caller | |
| # 3. Fallback to default async caller | |
| return self.async_llm_caller | |
| # ------------------------------------------------------------------ | |
| # Structured prompt dispatch | |
| # ------------------------------------------------------------------ | |
| def _call_llm(self, caller: Callable, prompt: StructuredPrompt) -> str: | |
| """ | |
| Call the LLM using the best available interface. | |
| If a ``structured_llm_caller`` is registered, sends | |
| ``prompt.messages`` (proper system/user roles). | |
| Otherwise falls back to ``caller(prompt.text)`` (flat string). | |
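| Example structured caller (illustrative; ``client`` is an assumed | |
| OpenAI-style client, not part of this module): | |
| def structured_caller(messages: list[dict[str, str]]) -> str: | |
| resp = client.chat.completions.create(model="gpt-4o", messages=messages) | |
| return resp.choices[0].message.content | |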
| """ | |
| if self.structured_llm_caller is not None: | |
| return self.structured_llm_caller(prompt.messages) | |
| return caller(prompt.text) | |
| async def _acall_llm(self, async_caller: Callable | None, prompt: StructuredPrompt) -> str: | |
| """Async version of :meth:`_call_llm`.""" | |
| if self.async_structured_llm_caller is not None: | |
| return await self.async_structured_llm_caller(prompt.messages) | |
| if async_caller is None: | |
| msg = "No async LLM caller available" | |
| raise ValueError(msg) | |
| return await async_caller(prompt.text) | |
| def _has_any_caller(self) -> bool: | |
| """Check whether at least one LLM caller is available.""" | |
| return bool( | |
| self.llm_caller | |
| or self.llm_callers | |
| or self.structured_llm_caller | |
| or (self.llm_factory and (self.llm_factory.default_caller or self.llm_factory.caller_builder)) | |
| ) | |
| def _has_any_async_caller(self) -> bool: | |
| """Check whether at least one async LLM caller is available.""" | |
| return bool( | |
| self.async_llm_caller | |
| or self.async_llm_callers | |
| or self.async_structured_llm_caller | |
| or (self.llm_factory and (self.llm_factory.default_async_caller or self.llm_factory.async_caller_builder)) | |
| ) | |
| def run_round( | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None = None, | |
| start_agent_id: str | None = None, | |
| *, | |
| update_states: bool | None = None, | |
| filter_unreachable: bool = False, | |
| callbacks: list[Handler] | None = None, | |
| ) -> MACPResult: | |
| """ | |
| Run a synchronous round (simple or adaptive strategy). | |
| Args: | |
| role_graph: Role graph to execute. | |
| final_agent_id: ID of final agent (overrides role_graph.end_node). | |
| start_agent_id: ID of start agent (overrides role_graph.start_node). | |
| update_states: Whether to update agent states. | |
| filter_unreachable: Exclude isolated nodes from execution. | |
| callbacks: Per-run callback handlers (merged with config.callbacks). | |
| Returns: | |
| MACPResult with execution results. | |
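| Example (illustrative; ``graph`` and ``call_llm`` are assumptions): | |
| runner = MACPRunner(llm_caller=call_llm) | |
| result = runner.run_round(graph, final_agent_id="summarizer") | |
| print(result.final_answer, result.execution_order) | |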
| """ | |
| if not self._has_any_caller(): | |
| msg = "llm_caller, llm_callers, or llm_factory is required for synchronous execution" | |
| raise ValueError(msg) | |
| # Get start/end from params or graph | |
| effective_start = start_agent_id or getattr(role_graph, "start_node", None) | |
| effective_end = final_agent_id or getattr(role_graph, "end_node", None) | |
| if self.config.adaptive: | |
| return self._run_adaptive( | |
| role_graph, | |
| effective_end, | |
| effective_start, | |
| update_states=update_states, | |
| filter_unreachable=filter_unreachable, | |
| callbacks=callbacks, | |
| ) | |
| return self._run_simple( | |
| role_graph, | |
| effective_end, | |
| effective_start, | |
| update_states=update_states, | |
| filter_unreachable=filter_unreachable, | |
| callbacks=callbacks, | |
| ) | |
| async def arun_round( | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None = None, | |
| start_agent_id: str | None = None, | |
| *, | |
| update_states: bool | None = None, | |
| filter_unreachable: bool = False, | |
| callbacks: list[Handler] | None = None, | |
| ) -> MACPResult: | |
| """ | |
| Run an async round (simple or adaptive strategy). | |
| Args: | |
| role_graph: Role graph to execute. | |
| final_agent_id: ID of final agent (overrides role_graph.end_node). | |
| start_agent_id: ID of start agent (overrides role_graph.start_node). | |
| update_states: Whether to update agent states. | |
| filter_unreachable: Exclude isolated nodes from execution. | |
| callbacks: Per-run callback handlers (merged with config.callbacks). | |
| Returns: | |
| MACPResult with execution results. | |
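| Example (illustrative; ``graph`` and ``acall_llm`` are assumptions): | |
| runner = MACPRunner(async_llm_caller=acall_llm) | |
| result = await runner.arun_round(graph) | |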
| """ | |
| if not self._has_any_async_caller(): | |
| msg = "async_llm_caller, async_llm_callers, or llm_factory is required for async execution" | |
| raise ValueError(msg) | |
| # Get start/end from params or graph | |
| effective_start = start_agent_id or getattr(role_graph, "start_node", None) | |
| effective_end = final_agent_id or getattr(role_graph, "end_node", None) | |
| if self.config.adaptive: | |
| return await self._arun_adaptive( | |
| role_graph, | |
| effective_end, | |
| effective_start, | |
| update_states=update_states, | |
| filter_unreachable=filter_unreachable, | |
| callbacks=callbacks, | |
| ) | |
| return await self._arun_simple( | |
| role_graph, | |
| effective_end, | |
| effective_start, | |
| update_states=update_states, | |
| filter_unreachable=filter_unreachable, | |
| callbacks=callbacks, | |
| ) | |
| def _run_simple( # noqa: PLR0912, PLR0915 | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None, | |
| start_agent_id: str | None, | |
| *, | |
| update_states: bool | None = None, | |
| filter_unreachable: bool = True, | |
| callbacks: list[Handler] | None = None, | |
| ) -> MACPResult: | |
| """ | |
| Sequential execution in topological order without adaptation. | |
| Supports multi-model: each agent uses its own LLM caller. | |
| Supports filtering of isolated nodes to save tokens. | |
| """ | |
| if not self._has_any_caller(): | |
| msg = "llm_caller, llm_callers, or llm_factory is required for synchronous execution" | |
| raise ValueError(msg) | |
| start_time = time.time() | |
| base = self._prepare_base_context(role_graph) | |
| if base is None: | |
| return MACPResult(messages={}, final_answer="", final_agent_id="", execution_order=[]) | |
| _task_idx, a_agents, agent_ids, query, agent_lookup, agent_names = base | |
| # Filter isolated nodes | |
| excluded_agents: list[str] = [] | |
| effective_agent_ids = agent_ids | |
| effective_a = a_agents | |
| if filter_unreachable and (start_agent_id is not None or final_agent_id is not None): | |
| relevant, excluded_agents = filter_reachable_agents(a_agents, agent_ids, start_agent_id, final_agent_id) | |
| if relevant and len(relevant) < len(agent_ids): | |
| indices = [agent_ids.index(aid) for aid in relevant] | |
| indices_t = torch.tensor(indices, dtype=torch.long) | |
| effective_a = a_agents[indices_t][:, indices_t] | |
| effective_agent_ids = relevant | |
| exec_order = build_execution_order(effective_a, effective_agent_ids, role_graph.role_sequence) | |
| # Initialize memory (with effective agents after filtering) | |
| self._init_memory(effective_agent_ids) | |
| # Initialize callbacks | |
| run_id = self._init_run( | |
| graph_name=getattr(role_graph, "name", None), | |
| num_agents=len(effective_agent_ids), | |
| query=query, | |
| execution_order=exec_order, | |
| callbacks=callbacks, | |
| ) | |
| task_connected = self._get_task_connected_agents(role_graph) | |
| messages: dict[str, str] = {} | |
| total_tokens = 0 | |
| actual_exec_order: list[str] = [] | |
| early_stopped = False | |
| early_stop_reason: str | None = None | |
| topology_modifications = 0 | |
| skipped_by_hooks: set[str] = set() | |
| run_error: BaseException | None = None | |
| # Get disabled nodes from graph | |
| disabled_nodes: set[str] = getattr(role_graph, "disabled_nodes", set()) | |
| try: | |
| for step_idx, agent_id in enumerate(exec_order): | |
| # Check if agent was skipped by hooks | |
| if agent_id in skipped_by_hooks: | |
| continue | |
| # Check if node is disabled | |
| if agent_id in disabled_nodes: | |
| if agent_id not in excluded_agents: | |
| excluded_agents.append(agent_id) | |
| continue | |
| agent = agent_lookup.get(agent_id) | |
| if agent is None: | |
| continue | |
| incoming_ids = get_incoming_agents(agent_id, effective_a, effective_agent_ids) | |
| incoming_messages = {aid: messages[aid] for aid in incoming_ids if aid in messages} | |
| include_query = self._should_include_query(agent_id, task_connected) | |
| memory_context = self._get_memory_context(agent_id) | |
| prompt = self._build_prompt( | |
| agent, query, incoming_messages, agent_names, memory_context, include_query=include_query | |
| ) | |
| prompt_text = prompt.text | |
| # Notify callbacks of agent start | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_start( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| agent_name=agent_names.get(agent_id, agent_id), | |
| step_index=step_idx, | |
| prompt=prompt_text[: self.config.prompt_preview_length], | |
| predecessors=incoming_ids, | |
| ) | |
| agent_start_time = time.time() | |
| try: | |
| # Get caller for this specific agent (multi-model support) | |
| caller = self._get_caller_for_agent(agent_id, agent) | |
| if caller is None: | |
| error_msg = f"No LLM caller available for agent {agent_id}" | |
| messages[agent_id] = f"[Error: {error_msg}]" | |
| actual_exec_order.append(agent_id) | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_error( | |
| run_id=run_id, | |
| error=ValueError(error_msg), | |
| agent_id=agent_id, | |
| error_type="NoCallerError", | |
| ) | |
| continue | |
| # Execute LLM caller with tools support | |
| response, agent_tokens = self._run_agent_with_tools( | |
| caller=caller, | |
| prompt=prompt, | |
| agent=agent, | |
| ) | |
| agent_duration_ms = (time.time() - agent_start_time) * 1000 | |
| messages[agent_id] = response | |
| total_tokens += agent_tokens | |
| self._save_to_memory(agent_id, response, incoming_ids) | |
| actual_exec_order.append(agent_id) | |
| # Notify callbacks of agent end | |
| if self._callback_manager: | |
| is_final = agent_id == final_agent_id or (final_agent_id is None and agent_id == exec_order[-1]) | |
| self._callback_manager.on_agent_end( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| output=response, | |
| agent_name=agent_names.get(agent_id, agent_id), | |
| step_index=step_idx, | |
| tokens_used=agent_tokens, | |
| duration_ms=agent_duration_ms, | |
| is_final=is_final, | |
| ) | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| messages[agent_id] = f"[Error: {e}]" | |
| actual_exec_order.append(agent_id) | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_error( | |
| run_id=run_id, | |
| error=e, | |
| agent_id=agent_id, | |
| error_type=type(e).__name__, | |
| ) | |
| # Check early stopping | |
| remaining = [a for a in exec_order if a not in messages and a not in skipped_by_hooks] | |
| should_stop, reason = self._check_early_stop( | |
| agent_id, | |
| messages.get(agent_id), | |
| messages, | |
| actual_exec_order, | |
| remaining, | |
| query, | |
| total_tokens, | |
| ) | |
| if should_stop: | |
| early_stopped = True | |
| early_stop_reason = reason | |
| break | |
| # Apply topology hooks | |
| if self.config.enable_dynamic_topology: | |
| action = self._apply_topology_hooks( | |
| agent_id, | |
| messages.get(agent_id), | |
| None, | |
| messages, | |
| actual_exec_order, | |
| remaining, | |
| query, | |
| total_tokens, | |
| role_graph, | |
| ) | |
| if action is not None: | |
| if action.early_stop: | |
| early_stopped = True | |
| early_stop_reason = action.early_stop_reason | |
| break | |
| if action.skip_agents: | |
| skipped_by_hooks.update(action.skip_agents) | |
| topology_modifications += self._apply_graph_modifications(role_graph, action) | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| run_error = e | |
| final_id = self._determine_final_agent(final_agent_id, actual_exec_order, messages) | |
| final_answer = messages.get(final_id, "") | |
| # Finalize callbacks | |
| self._finalize_run( | |
| run_id=run_id, | |
| success=run_error is None, | |
| executed_agents=len(actual_exec_order), | |
| final_answer=final_answer, | |
| error=run_error, | |
| executed_agent_ids=actual_exec_order, | |
| ) | |
| if run_error: | |
| raise run_error | |
| do_update = update_states if update_states is not None else self.config.update_states | |
| agent_states = self._build_agent_states(messages, agent_lookup) if do_update else None | |
| return MACPResult( | |
| messages=messages, | |
| final_answer=final_answer, | |
| final_agent_id=final_id, | |
| execution_order=actual_exec_order, | |
| agent_states=agent_states, | |
| total_tokens=total_tokens, | |
| total_time=time.time() - start_time, | |
| pruned_agents=excluded_agents if excluded_agents else None, | |
| early_stopped=early_stopped, | |
| early_stop_reason=early_stop_reason, | |
| topology_modifications=topology_modifications, | |
| ) | |
| async def _arun_simple( # noqa: PLR0912, PLR0915 | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None, | |
| start_agent_id: str | None, | |
| *, | |
| update_states: bool | None = None, | |
| filter_unreachable: bool = True, | |
| callbacks: list[Handler] | None = None, | |
| ) -> MACPResult: | |
| """ | |
| Async sequential execution without adaptation. | |
| Supports multi-model: each agent uses its own LLM caller. | |
| Supports filtering of isolated nodes to save tokens. | |
| """ | |
| if not self._has_any_async_caller(): | |
| msg = "async_llm_caller, async_llm_callers, or llm_factory is required for async execution" | |
| raise ValueError(msg) | |
| start_time = time.time() | |
| base = self._prepare_base_context(role_graph) | |
| if base is None: | |
| return MACPResult(messages={}, final_answer="", final_agent_id="", execution_order=[]) | |
| _task_idx, a_agents, agent_ids, query, agent_lookup, agent_names = base | |
| # Filter isolated nodes | |
| excluded_agents: list[str] = [] | |
| effective_agent_ids = agent_ids | |
| effective_a = a_agents | |
| if filter_unreachable and (start_agent_id is not None or final_agent_id is not None): | |
| relevant, excluded_agents = filter_reachable_agents(a_agents, agent_ids, start_agent_id, final_agent_id) | |
| if relevant and len(relevant) < len(agent_ids): | |
| indices = [agent_ids.index(aid) for aid in relevant] | |
| indices_t = torch.tensor(indices, dtype=torch.long) | |
| effective_a = a_agents[indices_t][:, indices_t] | |
| effective_agent_ids = relevant | |
| exec_order = build_execution_order(effective_a, effective_agent_ids, role_graph.role_sequence) | |
| # Initialize memory (with effective agents after filtering) | |
| self._init_memory(effective_agent_ids) | |
| # Initialize callbacks | |
| run_id = self._init_run( | |
| graph_name=getattr(role_graph, "name", None), | |
| num_agents=len(effective_agent_ids), | |
| query=query, | |
| execution_order=exec_order, | |
| callbacks=callbacks, | |
| ) | |
| task_connected = self._get_task_connected_agents(role_graph) | |
| messages: dict[str, str] = {} | |
| total_tokens = 0 | |
| actual_exec_order: list[str] = [] | |
| early_stopped = False | |
| early_stop_reason: str | None = None | |
| topology_modifications = 0 | |
| skipped_by_hooks: set[str] = set() | |
| run_error: BaseException | None = None | |
| # Get disabled nodes from graph | |
| disabled_nodes: set[str] = getattr(role_graph, "disabled_nodes", set()) | |
| try: | |
| for step_idx, agent_id in enumerate(exec_order): | |
| # Check if agent was skipped by hooks | |
| if agent_id in skipped_by_hooks: | |
| continue | |
| # Check if node is disabled | |
| if agent_id in disabled_nodes: | |
| if agent_id not in excluded_agents: | |
| excluded_agents.append(agent_id) | |
| continue | |
| agent = agent_lookup.get(agent_id) | |
| if agent is None: | |
| continue | |
| incoming_ids = get_incoming_agents(agent_id, effective_a, effective_agent_ids) | |
| incoming_messages = {aid: messages[aid] for aid in incoming_ids if aid in messages} | |
| include_query = self._should_include_query(agent_id, task_connected) | |
| memory_context = self._get_memory_context(agent_id) | |
| prompt = self._build_prompt( | |
| agent, query, incoming_messages, agent_names, memory_context, include_query=include_query | |
| ) | |
| # Notify callbacks of agent start | |
| # prompt is now a StructuredPrompt | |
| prompt_text = prompt.text | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_start( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| agent_name=agent_names.get(agent_id, agent_id), | |
| step_index=step_idx, | |
| prompt=prompt_text[: self.config.prompt_preview_length], | |
| predecessors=incoming_ids, | |
| ) | |
| agent_start_time = time.time() | |
| try: | |
| # Get async caller for this specific agent (multi-model support). | |
| # When only async_structured_llm_caller is configured, | |
| # async_caller may be None — _acall_llm handles that. | |
| async_caller = self._get_async_caller_for_agent(agent_id, agent) | |
| if async_caller is None and self.async_structured_llm_caller is None: | |
| error_msg = f"No async LLM caller available for agent {agent_id}" | |
| messages[agent_id] = f"[Error: {error_msg}]" | |
| actual_exec_order.append(agent_id) | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_error( | |
| run_id=run_id, | |
| error=ValueError(error_msg), | |
| agent_id=agent_id, | |
| error_type="NoCallerError", | |
| ) | |
| continue | |
| response = await asyncio.wait_for( | |
| self._acall_llm(async_caller, prompt), | |
| timeout=self.config.timeout, | |
| ) | |
| agent_tokens = self.token_counter(prompt_text) + self.token_counter(response) | |
| agent_duration_ms = (time.time() - agent_start_time) * 1000 | |
| messages[agent_id] = response | |
| total_tokens += agent_tokens | |
| self._save_to_memory(agent_id, response, incoming_ids) | |
| actual_exec_order.append(agent_id) | |
| # Notify callbacks of agent end | |
| if self._callback_manager: | |
| is_final = agent_id == final_agent_id or (final_agent_id is None and agent_id == exec_order[-1]) | |
| self._callback_manager.on_agent_end( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| output=response, | |
| agent_name=agent_names.get(agent_id, agent_id), | |
| step_index=step_idx, | |
| tokens_used=agent_tokens, | |
| duration_ms=agent_duration_ms, | |
| is_final=is_final, | |
| ) | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| messages[agent_id] = f"[Error: {e}]" | |
| actual_exec_order.append(agent_id) | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_error( | |
| run_id=run_id, | |
| error=e, | |
| agent_id=agent_id, | |
| error_type=type(e).__name__, | |
| ) | |
| # Check early stopping | |
| remaining = [a for a in exec_order if a not in messages and a not in skipped_by_hooks] | |
| should_stop, reason = self._check_early_stop( | |
| agent_id, | |
| messages.get(agent_id), | |
| messages, | |
| actual_exec_order, | |
| remaining, | |
| query, | |
| total_tokens, | |
| ) | |
| if should_stop: | |
| early_stopped = True | |
| early_stop_reason = reason | |
| break | |
| # Apply async topology hooks | |
| if self.config.enable_dynamic_topology: | |
| action = await self._apply_async_topology_hooks( | |
| agent_id, | |
| messages.get(agent_id), | |
| None, | |
| messages, | |
| actual_exec_order, | |
| remaining, | |
| query, | |
| total_tokens, | |
| role_graph, | |
| ) | |
| if action is not None: | |
| if action.early_stop: | |
| early_stopped = True | |
| early_stop_reason = action.early_stop_reason | |
| break | |
| if action.skip_agents: | |
| skipped_by_hooks.update(action.skip_agents) | |
| topology_modifications += self._apply_graph_modifications(role_graph, action) | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| run_error = e | |
| final_id = self._determine_final_agent(final_agent_id, actual_exec_order, messages) | |
| final_answer = messages.get(final_id, "") | |
| # Finalize callbacks | |
| self._finalize_run( | |
| run_id=run_id, | |
| success=run_error is None, | |
| executed_agents=len(actual_exec_order), | |
| final_answer=final_answer, | |
| error=run_error, | |
| executed_agent_ids=actual_exec_order, | |
| ) | |
| if run_error: | |
| raise run_error | |
| do_update = update_states if update_states is not None else self.config.update_states | |
| agent_states = self._build_agent_states(messages, agent_lookup) if do_update else None | |
| return MACPResult( | |
| messages=messages, | |
| final_answer=final_answer, | |
| final_agent_id=final_id, | |
| execution_order=actual_exec_order, | |
| agent_states=agent_states, | |
| total_tokens=total_tokens, | |
| total_time=time.time() - start_time, | |
| pruned_agents=excluded_agents if excluded_agents else None, | |
| early_stopped=early_stopped, | |
| early_stop_reason=early_stop_reason, | |
| topology_modifications=topology_modifications, | |
| ) | |
| def _run_adaptive( # noqa: PLR0912, PLR0915 | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None, | |
| start_agent_id: str | None, | |
| *, | |
| update_states: bool | None = None, | |
| filter_unreachable: bool = True, | |
| callbacks: list[Handler] | None = None, | |
| ) -> MACPResult: | |
| """ | |
| Adaptive sync execution with conditional edges and fallback. | |
| Supports multi-model: each agent uses its own LLM caller. | |
| Supports filtering of isolated nodes to save tokens. | |
| Per-call ``callbacks`` are merged with ``RunnerConfig.callbacks`` | |
| and context-manager callbacks (similar to ``_run_simple``). | |
| """ | |
| if not self._has_any_caller(): | |
| msg = "llm_caller, llm_callers, or llm_factory is required for synchronous execution" | |
| raise ValueError(msg) | |
| if self._scheduler is None: | |
| msg = "Scheduler not initialized for adaptive mode" | |
| raise ValueError(msg) | |
| start_time = time.time() | |
| base = self._prepare_base_context(role_graph) | |
| if base is None: | |
| return MACPResult(messages={}, final_answer="", final_agent_id="", execution_order=[]) | |
| task_idx, a_agents, agent_ids, query, agent_lookup, agent_names = base | |
| # Initialize memory | |
| self._init_memory(agent_ids) | |
| p_matrix = self._extract_p_matrix(role_graph, task_idx) | |
| # Get conditions from graph for conditional routing | |
| edge_conditions = self._get_edge_conditions(role_graph) | |
| # Initial context for conditions | |
| condition_ctx = ConditionContext( | |
| source_agent="", | |
| target_agent="", | |
| messages={}, | |
| step_results={}, | |
| query=query, | |
| ) | |
| plan = self._scheduler.build_plan( | |
| a_agents, | |
| agent_ids, | |
| p_matrix, | |
| start_agent=start_agent_id, | |
| end_agent=final_agent_id, | |
| edge_conditions=edge_conditions, | |
| condition_context=condition_ctx, | |
| filter_unreachable=filter_unreachable, | |
| ) | |
| # Initialize callbacks (per-call + config + context-manager) | |
| run_id = self._init_run( | |
| graph_name=getattr(role_graph, "name", None), | |
| num_agents=len(agent_ids), | |
| query=query, | |
| execution_order=plan.execution_order, | |
| callbacks=callbacks, | |
| ) | |
| messages: dict[str, str] = {} | |
| step_results: dict[str, StepResult] = {} | |
| execution_order: list[str] = [] | |
| fallback_attempts: dict[str, int] = {} | |
| topology_changed_count = 0 | |
| fallback_count = 0 | |
| pruned_agents: list[str] = [] | |
| errors: list[ExecutionError] = [] | |
| step_idx = 0 | |
| run_error: BaseException | None = None | |
| try: | |
| while not plan.is_complete: | |
| step = plan.get_current_step() | |
| if step is None: | |
| break | |
| # Skip agents whose conditions were not met | |
| if step.agent_id in plan.condition_skipped: | |
| plan.advance() | |
| continue | |
| should_prune, reason = self._scheduler.should_prune( | |
| step, plan, step_results.get(execution_order[-1]) if execution_order else None | |
| ) | |
| if should_prune: | |
| plan.mark_skipped(step.agent_id) | |
| pruned_agents.append(step.agent_id) | |
| errors.append( | |
| ExecutionError( | |
| message=f"Pruned: {reason}", | |
| agent_id=step.agent_id, | |
| recoverable=False, | |
| ) | |
| ) | |
| continue | |
| # Notify callbacks of agent start | |
| agent_name = agent_names.get(step.agent_id, step.agent_id) | |
| if self._callback_manager: | |
| _agent_obj = agent_lookup.get(step.agent_id) | |
| _prompt_preview_text = "" | |
| if _agent_obj is not None: | |
| _inc = {p: messages[p] for p in step.predecessors if p in messages} | |
| _mem = self._get_memory_context(step.agent_id) | |
| _sp = self._build_prompt(_agent_obj, query, _inc, agent_names, _mem) | |
| _prompt_preview_text = _sp.text | |
| self._callback_manager.on_agent_start( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| agent_name=agent_name, | |
| step_index=step_idx, | |
| prompt=_prompt_preview_text[: self.config.prompt_preview_length], | |
| predecessors=step.predecessors, | |
| ) | |
| agent_start_time = time.time() | |
| result = self._execute_step(step, messages, agent_lookup, agent_names, query) | |
| step_results[step.agent_id] = result | |
| execution_order.append(step.agent_id) | |
| if result.success: | |
| messages[step.agent_id] = result.response or "" | |
| plan.mark_completed(step.agent_id, result.tokens_used) | |
| self._save_to_memory(step.agent_id, result.response or "", step.predecessors) | |
| # Notify callbacks of agent end | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_end( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| output=result.response or "", | |
| agent_name=agent_name, | |
| step_index=step_idx, | |
| tokens_used=result.tokens_used, | |
| duration_ms=(time.time() - agent_start_time) * 1000, | |
| is_final=False, | |
| ) | |
| else: | |
| plan.mark_failed(step.agent_id) | |
| errors.append( | |
| ExecutionError( | |
| message=result.error or "Unknown error", | |
| agent_id=step.agent_id, | |
| recoverable=True, | |
| ) | |
| ) | |
| # Notify callbacks of agent error | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_error( | |
| run_id=run_id, | |
| error=Exception(result.error or "Unknown error"), | |
| agent_id=step.agent_id, | |
| error_type="ExecutionError", | |
| ) | |
| attempts = fallback_attempts.get(step.agent_id, 0) | |
| if self._scheduler.should_use_fallback(step, result, attempts): | |
| for fb_agent in step.fallback_agents: | |
| if fb_agent not in plan.completed and fb_agent not in plan.failed: | |
| plan.insert_fallback(fb_agent, plan.current_index - 1) | |
| fallback_count += 1 | |
| break | |
| fallback_attempts[step.agent_id] = attempts + 1 | |
| # Topology pipeline: conditional edges + user hooks → plan | |
| if self._run_topology_pipeline( | |
| plan, | |
| step.agent_id, | |
| a_agents, | |
| agent_ids, | |
| step_results, | |
| messages, | |
| query, | |
| execution_order, | |
| plan.tokens_used, | |
| role_graph, | |
| ): | |
| topology_changed_count += 1 | |
| step_idx += 1 | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| run_error = e | |
| final_id = self._determine_final_agent(final_agent_id, execution_order, messages) | |
| final_answer = messages.get(final_id, "") | |
| # Finalize callbacks | |
| self._finalize_run( | |
| run_id=run_id, | |
| success=run_error is None, | |
| executed_agents=len(execution_order), | |
| final_answer=final_answer, | |
| error=run_error, | |
| executed_agent_ids=execution_order, | |
| ) | |
| if run_error: | |
| raise run_error | |
| do_update = update_states if update_states is not None else self.config.update_states | |
| agent_states = self._build_agent_states(messages, agent_lookup) if do_update else None | |
| return MACPResult( | |
| messages=messages, | |
| final_answer=final_answer, | |
| final_agent_id=final_id, | |
| execution_order=execution_order, | |
| agent_states=agent_states, | |
| step_results=step_results, | |
| total_tokens=plan.tokens_used, | |
| total_time=time.time() - start_time, | |
| topology_changed_count=topology_changed_count, | |
| fallback_count=fallback_count, | |
| pruned_agents=pruned_agents, | |
| errors=errors if errors else None, | |
| ) | |
| async def _arun_adaptive( # noqa: PLR0912, PLR0915 | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None, | |
| start_agent_id: str | None, | |
| *, | |
| update_states: bool | None = None, | |
| filter_unreachable: bool = True, | |
| callbacks: list[Handler] | None = None, | |
| ) -> MACPResult: | |
| """ | |
| Adaptive async execution with parallelism and conditional edges. | |
| Supports multi-model: each agent uses its own LLM caller. | |
| Supports filtering of isolated nodes to save tokens. | |
| Per-call ``callbacks`` are merged with ``RunnerConfig.callbacks`` | |
| and context-manager callbacks (similar to ``_arun_simple``). | |
| """ | |
| if not self._has_any_async_caller(): | |
| msg = "async_llm_caller, async_llm_callers, or llm_factory is required for async execution" | |
| raise ValueError(msg) | |
| if self._scheduler is None: | |
| msg = "Scheduler not initialized for adaptive mode" | |
| raise ValueError(msg) | |
| start_time = time.time() | |
| base = self._prepare_base_context(role_graph) | |
| if base is None: | |
| return MACPResult(messages={}, final_answer="", final_agent_id="", execution_order=[]) | |
| task_idx, a_agents, agent_ids, query, agent_lookup, agent_names = base | |
| # Initialize memory | |
| self._init_memory(agent_ids) | |
| p_matrix = self._extract_p_matrix(role_graph, task_idx) | |
| # Get conditions from graph for conditional routing | |
| edge_conditions = self._get_edge_conditions(role_graph) | |
| # Context for conditions | |
| condition_ctx = ConditionContext( | |
| source_agent="", | |
| target_agent="", | |
| messages={}, | |
| step_results={}, | |
| query=query, | |
| ) | |
| plan = self._scheduler.build_plan( | |
| a_agents, | |
| agent_ids, | |
| p_matrix, | |
| start_agent=start_agent_id, | |
| end_agent=final_agent_id, | |
| edge_conditions=edge_conditions, | |
| condition_context=condition_ctx, | |
| filter_unreachable=filter_unreachable, | |
| ) | |
| # Initialize callbacks (per-call + config + context-manager) | |
| run_id = self._init_run( | |
| graph_name=getattr(role_graph, "name", None), | |
| num_agents=len(agent_ids), | |
| query=query, | |
| execution_order=plan.execution_order, | |
| callbacks=callbacks, | |
| ) | |
| messages: dict[str, str] = {} | |
| step_results: dict[str, StepResult] = {} | |
| execution_order: list[str] = [] | |
| fallback_attempts: dict[str, int] = {} | |
| topology_changed_count = 0 | |
| fallback_count = 0 | |
| pruned_agents: list[str] = [] | |
| errors: list[ExecutionError] = [] | |
| step_idx = 0 | |
| run_error: BaseException | None = None | |
| try: | |
| while not plan.is_complete: | |
| parallel_group = self._get_parallel_group(plan, messages.keys()) | |
| if not parallel_group: | |
| break | |
| valid_steps = [] | |
| for step in parallel_group: | |
| # Skip agents whose conditions were not met | |
| if step.agent_id in plan.condition_skipped: | |
| plan.advance() | |
| continue | |
| should_prune, reason = self._scheduler.should_prune(step, plan, None) | |
| if should_prune: | |
| plan.mark_skipped(step.agent_id) | |
| pruned_agents.append(step.agent_id) | |
| errors.append( | |
| ExecutionError( | |
| message=f"Pruned: {reason}", | |
| agent_id=step.agent_id, | |
| recoverable=False, | |
| ) | |
| ) | |
| else: | |
| valid_steps.append(step) | |
| if not valid_steps: | |
| # All steps in the group were skipped or pruned, so the plan could stall. | |
| # Break explicitly: if the plan is already complete the while condition | |
| # would end the loop anyway, otherwise this prevents an infinite loop. | |
| break | |
| # Notify callbacks of agent start for all steps in group | |
| group_start_times: dict[str, float] = {} | |
| for step in valid_steps: | |
| agent_name = agent_names.get(step.agent_id, step.agent_id) | |
| if self._callback_manager: | |
| _agent_obj = agent_lookup.get(step.agent_id) | |
| _prompt_preview_text = "" | |
| if _agent_obj is not None: | |
| _inc = {p: messages[p] for p in step.predecessors if p in messages} | |
| _mem = self._get_memory_context(step.agent_id) | |
| _sp = self._build_prompt(_agent_obj, query, _inc, agent_names, _mem) | |
| _prompt_preview_text = _sp.text | |
| self._callback_manager.on_agent_start( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| agent_name=agent_name, | |
| step_index=step_idx, | |
| prompt=_prompt_preview_text[: self.config.prompt_preview_length], | |
| predecessors=step.predecessors, | |
| ) | |
| group_start_times[step.agent_id] = time.time() | |
| step_idx += 1 | |
| if self.config.enable_parallel and len(valid_steps) > 1: | |
| results = await self._execute_parallel(valid_steps, messages, agent_lookup, agent_names, query) | |
| else: | |
| results = [] | |
| for step in valid_steps: | |
| r = await self._execute_step_async(step, messages, agent_lookup, agent_names, query) | |
| results.append((step, r)) | |
| for step, result in results: | |
| step_results[step.agent_id] = result | |
| execution_order.append(step.agent_id) | |
| agent_name = agent_names.get(step.agent_id, step.agent_id) | |
| agent_start_time = group_start_times.get(step.agent_id, time.time()) | |
| if result.success: | |
| messages[step.agent_id] = result.response or "" | |
| plan.mark_completed(step.agent_id, result.tokens_used) | |
| self._save_to_memory(step.agent_id, result.response or "", step.predecessors) | |
| # Notify callbacks of agent end | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_end( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| output=result.response or "", | |
| agent_name=agent_name, | |
| step_index=execution_order.index(step.agent_id), | |
| tokens_used=result.tokens_used, | |
| duration_ms=(time.time() - agent_start_time) * 1000, | |
| is_final=False, | |
| ) | |
| else: | |
| plan.mark_failed(step.agent_id) | |
| errors.append( | |
| ExecutionError( | |
| message=result.error or "Unknown error", | |
| agent_id=step.agent_id, | |
| recoverable=True, | |
| ) | |
| ) | |
| # Notify callbacks of agent error | |
| if self._callback_manager: | |
| self._callback_manager.on_agent_error( | |
| run_id=run_id, | |
| error=Exception(result.error or "Unknown error"), | |
| agent_id=step.agent_id, | |
| error_type="ExecutionError", | |
| ) | |
| attempts = fallback_attempts.get(step.agent_id, 0) | |
| if self._scheduler.should_use_fallback(step, result, attempts): | |
| for fb_agent in step.fallback_agents: | |
| if fb_agent not in plan.completed and fb_agent not in plan.failed: | |
| plan.insert_fallback(fb_agent, plan.current_index - 1) | |
| fallback_count += 1 | |
| break | |
| fallback_attempts[step.agent_id] = attempts + 1 | |
| # Topology pipeline for each executed agent in the group | |
| for step, _result in results: | |
| if await self._arun_topology_pipeline( | |
| plan, | |
| step.agent_id, | |
| a_agents, | |
| agent_ids, | |
| step_results, | |
| messages, | |
| query, | |
| execution_order, | |
| plan.tokens_used, | |
| role_graph, | |
| ): | |
| topology_changed_count += 1 | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| run_error = e | |
| final_id = self._determine_final_agent(final_agent_id, execution_order, messages) | |
| final_answer = messages.get(final_id, "") | |
| # Finalize callbacks | |
| self._finalize_run( | |
| run_id=run_id, | |
| success=run_error is None, | |
| executed_agents=len(execution_order), | |
| final_answer=final_answer, | |
| error=run_error, | |
| executed_agent_ids=execution_order, | |
| ) | |
| if run_error: | |
| raise run_error | |
| do_update = update_states if update_states is not None else self.config.update_states | |
| agent_states = self._build_agent_states(messages, agent_lookup) if do_update else None | |
| return MACPResult( | |
| messages=messages, | |
| final_answer=final_answer, | |
| final_agent_id=final_id, | |
| execution_order=execution_order, | |
| agent_states=agent_states, | |
| step_results=step_results, | |
| total_tokens=plan.tokens_used, | |
| total_time=time.time() - start_time, | |
| topology_changed_count=topology_changed_count, | |
| fallback_count=fallback_count, | |
| pruned_agents=pruned_agents, | |
| errors=errors if errors else None, | |
| ) | |
| def _get_task_index(self, role_graph: Any) -> int: | |
| """Get the rustworkx index of the task node or raise an error.""" | |
| if role_graph.task_node is None: | |
| msg = "RoleGraph has no task_node set" | |
| raise ValueError(msg) | |
| task_idx = role_graph.get_node_index(role_graph.task_node) | |
| if task_idx is None: | |
| msg = f"Task node '{role_graph.task_node}' not found" | |
| raise ValueError(msg) | |
| return task_idx | |
| def _get_agent_ids( | |
| self, | |
| role_graph: Any, | |
| task_idx: int, | |
| ) -> tuple[list[str], dict[str, int]]: | |
| """Return the list of agent_ids (excluding task) and the id->adjacency index map.""" | |
| agent_ids = [] | |
| id_to_idx = {} | |
| adj_idx = 0 | |
| for agent in role_graph.agents: | |
| graph_idx = role_graph.get_node_index(agent.agent_id) | |
| if graph_idx == task_idx: | |
| continue | |
| agent_ids.append(agent.agent_id) | |
| id_to_idx[agent.agent_id] = adj_idx | |
| adj_idx += 1 | |
| return agent_ids, id_to_idx | |
| def _extract_p_matrix(self, role_graph: Any, task_idx: int) -> torch.Tensor | None: | |
| """Return the probability matrix without the task row/column.""" | |
| if role_graph.p_matrix is None: | |
| return None | |
| n_nodes = role_graph.p_matrix.shape[0] | |
| mask = torch.ones(n_nodes, dtype=torch.bool) | |
| mask[task_idx] = False | |
| return role_graph.p_matrix[mask][:, mask] | |
| def _get_edge_conditions(self, role_graph: Any) -> dict[tuple[str, str], Any]: | |
| """Get all edge conditions from the graph.""" | |
| if hasattr(role_graph, "get_all_edge_conditions"): | |
| return role_graph.get_all_edge_conditions() | |
| # Fallback: check individual attributes | |
| conditions: dict[tuple[str, str], Any] = {} | |
| if hasattr(role_graph, "edge_condition_names"): | |
| conditions.update(role_graph.edge_condition_names) | |
| if hasattr(role_graph, "edge_conditions"): | |
| conditions.update(role_graph.edge_conditions) | |
| return conditions | |
| def _build_agent_names(self, role_graph: Any) -> dict[str, str]: | |
| """Map id -> display_name/role for building the prompt.""" | |
| return {a.agent_id: a.display_name or getattr(a, "role", a.agent_id) for a in role_graph.agents} | |
| def _get_task_connected_agents(self, role_graph: Any) -> set[str]: | |
| """Get the set of agents directly connected to the task node.""" | |
| if role_graph.task_node is None: | |
| return set() | |
| task_idx = role_graph.get_node_index(role_graph.task_node) | |
| if task_idx is None or role_graph.A_com is None: | |
| return set() | |
| connected = set() | |
| for agent in role_graph.agents: | |
| agent_idx = role_graph.get_node_index(agent.agent_id) | |
| if agent_idx is not None and agent_idx != task_idx and role_graph.A_com[task_idx, agent_idx] > 0: | |
| connected.add(agent.agent_id) | |
| return connected | |
| def _should_include_query(self, agent_id: str, task_connected: set[str]) -> bool: | |
| """Determine whether to include the query in the agent's prompt.""" | |
| if self.config.broadcast_task_to_all: | |
| return True | |
| return agent_id in task_connected | |
| def _run_agent_with_tools( # noqa: PLR0912, PLR0915 | |
| self, | |
| caller: Any, | |
| prompt: "str | StructuredPrompt", | |
| agent: Any, | |
| ) -> tuple[str, int]: | |
| """ | |
| Execute an agent with automatic tools support. | |
| If the agent has tools, they are ALWAYS used via native function calling. | |
| Tools are obtained from: | |
| 1. BaseTool objects directly in agent.tools | |
| 2. Tool names registered in the global registry or config.tool_registry | |
| ``prompt`` may be a plain ``str`` (legacy) or a | |
| ``StructuredPrompt`` (modern). For plain LLM calls the method | |
| dispatches via :meth:`_call_llm` which picks the structured | |
| caller when available. | |
| When a ``structured_llm_caller`` is registered the tool-calling | |
| loop also uses structured messages (system/user/tool roles) so | |
| the LLM receives proper role separation throughout the entire | |
| tool-calling conversation. | |
| Args: | |
| caller: LLM caller (must support the tools parameter) | |
| prompt: Agent prompt (str or StructuredPrompt) | |
| agent: Agent profile (AgentProfile with tools) | |
| Returns: | |
| tuple[str, int]: (response, number of tokens) | |
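| Example caller shape (illustrative; only the attributes consumed below | |
| are listed, everything else is an assumption): | |
| def caller(prompt, tools=None): | |
| # return a plain str, or an LLMResponse-like object exposing | |
| # .content, .has_tool_calls and .tool_calls (each with .name / .arguments) | |
| ... | |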
| """ | |
| import inspect | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| # Normalise prompt — always have both flat text and structured form | |
| prompt_text = prompt.text if isinstance(prompt, StructuredPrompt) else prompt | |
| # Check if the agent has tools | |
| if not TOOLS_AVAILABLE: | |
| response = self._call_llm(caller, prompt) if isinstance(prompt, StructuredPrompt) else caller(prompt) | |
| return response, self.token_counter(prompt_text) + self.token_counter(response) | |
| # Get agent tools (method from AgentProfile) | |
| agent_tools = [] | |
| if hasattr(agent, "get_tool_objects"): | |
| agent_tools = agent.get_tool_objects() | |
| elif hasattr(agent, "tools") and agent.tools: | |
| # Fallback for old format (names only) | |
| from tools import get_registry | |
| registry = self.config.tool_registry or get_registry() | |
| tool_names = agent.tools | |
| agent_tools = registry.get_tools(tool_names) | |
| if not agent_tools: | |
| # Agent has no tools — plain call (use structured dispatch) | |
| response = self._call_llm(caller, prompt) if isinstance(prompt, StructuredPrompt) else caller(prompt) | |
| return response, self.token_counter(prompt_text) + self.token_counter(response) | |
| # Check that caller supports tools | |
| sig = inspect.signature(caller) | |
| supports_tools = "tools" in sig.parameters | |
| if not supports_tools: | |
| # Caller does not support tools — plain call | |
| response = self._call_llm(caller, prompt) if isinstance(prompt, StructuredPrompt) else caller(prompt) | |
| return response, self.token_counter(prompt_text) + self.token_counter(response) | |
| # Get schemas for tools | |
| tool_schemas = [t.to_openai_schema() for t in agent_tools] | |
| if not tool_schemas: | |
| response = self._call_llm(caller, prompt) if isinstance(prompt, StructuredPrompt) else caller(prompt) | |
| return response, self.token_counter(prompt_text) + self.token_counter(response) | |
| # Get registry for executing tools | |
| from tools import ToolCall, get_registry | |
| registry = self.config.tool_registry or get_registry() | |
| # Register all agent tools in the registry (for BaseTool objects) | |
| for t in agent_tools: | |
| if not registry.has(t.name): | |
| registry.register(t) | |
| # Execute tool calling loop | |
| total_tokens = 0 | |
| # Build the conversation for the tool-calling loop. | |
| # When structured messages are available we keep a proper | |
| # multi-turn conversation; otherwise we accumulate flat text. | |
| if isinstance(prompt, StructuredPrompt): | |
| tool_messages: list[dict[str, str]] = list(prompt.messages) | |
| else: | |
| tool_messages = [{"role": "user", "content": prompt}] | |
| current_prompt = prompt_text # flat fallback | |
| # Cache for tool calls: (tool_name, args_json) -> result | |
| # Prevents repeated execution of identical calls | |
| tool_cache: dict[str, str] = {} | |
| logger.debug("Agent has tools: %s", [s["function"]["name"] for s in tool_schemas]) | |
| llm_response: Any = None | |
| use_structured_tools = isinstance(prompt, StructuredPrompt) and self.structured_llm_caller is not None | |
| for iteration in range(self.config.max_tool_iterations): | |
| # Call LLM with tools | |
| logger.debug("Tool calling iteration %d", iteration + 1) | |
| # Prefer structured messages when a structured caller is available; | |
| # this gives the LLM proper system/user/assistant/tool role separation | |
| # throughout the entire tool-calling conversation. | |
| if use_structured_tools: | |
| llm_response = caller(tool_messages, tools=tool_schemas) | |
| else: | |
| llm_response = caller(current_prompt, tools=tool_schemas) | |
| if isinstance(llm_response, str): | |
| # Caller returned a string, not an LLMResponse | |
| return llm_response, self.token_counter(current_prompt) + self.token_counter(llm_response) | |
| # Token counting: use the actual prompt content sent to the LLM | |
| if use_structured_tools: | |
| prompt_tokens = sum(self.token_counter(m.get("content", "")) for m in tool_messages) | |
| else: | |
| prompt_tokens = self.token_counter(current_prompt) | |
| total_tokens += prompt_tokens | |
| if llm_response.content: | |
| total_tokens += self.token_counter(llm_response.content) | |
| # If there are no tool_calls — return the response | |
| if not llm_response.has_tool_calls: | |
| content = llm_response.content or "" | |
| if content: | |
| logger.debug("No tool calls, returning content: %s...", content[:50]) | |
| return content, total_tokens | |
| # Execute tool_calls with caching | |
| tool_results: list[str] = [] | |
| for tc in llm_response.tool_calls: | |
| # Create a cache key from the name and arguments | |
| import json as json_module | |
| cache_key = f"{tc.name}:{json_module.dumps(tc.arguments, sort_keys=True)}" | |
| if cache_key in tool_cache: | |
| # Already called with these arguments — use cache | |
| output = tool_cache[cache_key] | |
| logger.debug("Tool cache hit: %s(%s) -> %s...", tc.name, tc.arguments, output[:50]) | |
| else: | |
| # New call — execute and cache | |
| logger.debug("Executing tool: %s(%s)", tc.name, tc.arguments) | |
| tool_call = ToolCall(name=tc.name, arguments=tc.arguments) | |
| result = registry.execute(tool_call) | |
| output = result.output if result.success else f"Error: {result.error}" | |
| tool_cache[cache_key] = output | |
| logger.debug("Tool result: %s...", output[:100]) | |
| tool_results.append(f"[{tc.name}]: {output}") | |
| # Add results to the conversation for the next iteration. | |
| # Structured path: append assistant + tool messages to the | |
| # conversation so the LLM sees the full multi-turn history. | |
| tool_results_text = "\n".join(tool_results) | |
| if isinstance(prompt, StructuredPrompt): | |
| # Append the assistant's tool-call reply | |
| if llm_response.content: | |
| tool_messages.append({"role": "assistant", "content": llm_response.content}) | |
| # Append tool results as a user message (compatible with | |
| # all providers; some accept role="tool" but "user" is universal) | |
| tool_messages.append({"role": "user", "content": f"Tool results:\n{tool_results_text}"}) | |
| # Flat text fallback (used by legacy callers) | |
| current_prompt = f"{prompt_text}\n\nTool results:\n{tool_results_text}" | |
| # Reached max_iterations, return the last content | |
| return ((llm_response.content or "") if llm_response else ""), total_tokens | |
| def _build_system_prompt_parts(self, agent: Any) -> list[str]: | |
| """Build system prompt parts: persona, description, tools, output_schema.""" | |
| import json as _json | |
| parts: list[str] = [] | |
| if hasattr(agent, "persona") and agent.persona: | |
| parts.append(f"You are {agent.persona}.") | |
| elif hasattr(agent, "role") and agent.role: | |
| parts.append(f"You are a {agent.role}.") | |
| if hasattr(agent, "description") and agent.description: | |
| parts.append(agent.description.strip()) | |
| # ── tools → brief mention so the agent is aware of its capabilities | |
| tool_names: list[str] = [] | |
| if hasattr(agent, "get_tool_names"): | |
| tool_names = agent.get_tool_names() | |
| elif hasattr(agent, "tools") and agent.tools: | |
| tool_names = [t if isinstance(t, str) else getattr(t, "name", str(t)) for t in agent.tools] | |
| if tool_names: | |
| parts.append(f"Available tools: {', '.join(tool_names)}.") | |
| # ── output_schema → compact format instruction ─────────────── | |
| output_schema_json = self._extract_schema_json(agent, "output_schema") | |
| if output_schema_json: | |
| schema_text = _json.dumps(output_schema_json, ensure_ascii=False, separators=(",", ":")) | |
| parts.append(f"Respond with JSON matching: {schema_text}") | |
| return parts | |
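| # Illustrative output of _build_system_prompt_parts for an agent with | |
| # persona="a senior data analyst", tools=["web_search"], and output schema | |
| # {"type": "object"} (all values assumed): | |
| #   ["You are a senior data analyst.", | |
| #    "Available tools: web_search.", | |
| #    'Respond with JSON matching: {"type":"object"}'] | |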
| def _build_user_prompt_parts( | |
| self, | |
| agent: Any, | |
| query: str, | |
| incoming_messages: dict[str, str], | |
| agent_names: dict[str, str], | |
| memory_context: list[dict[str, Any]] | None, | |
| *, | |
| include_query: bool, | |
| ) -> list[str]: | |
| """Build user prompt parts: query, input_schema, memory, incoming messages.""" | |
| import json as _json | |
| user_parts: list[str] = [] | |
| # Task query is added only if include_query=True | |
| if include_query and query: | |
| user_parts.append(f"Task: {query}") | |
| # ── input_schema hint (compact) ────────────────────────────── | |
| input_schema_json = self._extract_schema_json(agent, "input_schema") | |
| if input_schema_json: | |
| schema_text = _json.dumps(input_schema_json, ensure_ascii=False, separators=(",", ":")) | |
| user_parts.append(f"\nInput format: {schema_text}") | |
| # Include memory context from SharedMemoryPool | |
| if memory_context: | |
| user_parts.append("\nPrevious context:") | |
| for msg in memory_context: | |
| role = msg.get("role", "unknown") | |
| content = msg.get("content", "") | |
| user_parts.append(f"[{role}]: {content}") | |
| if incoming_messages: | |
| user_parts.append("\nMessages from other agents:") | |
| for sender_id, message in incoming_messages.items(): | |
| sender_name = agent_names.get(sender_id, sender_id) | |
| user_parts.append(f"\n[{sender_name}]:\n{message}") | |
| user_parts.append("\nProvide your response:") | |
| return user_parts | |
| def _build_state_text_parts(self, agent: Any) -> list[str]: | |
| """Build state text parts for flat string representation.""" | |
| agent_state: list[dict[str, Any]] = [] | |
| if hasattr(agent, "state"): | |
| agent_state = list(agent.state) if agent.state else [] | |
| state_text_parts: list[str] = [] | |
| if agent_state: | |
| state_text_parts.append("\nConversation history:") | |
| for entry in agent_state: | |
| entry_role = entry.get("role", "unknown") | |
| entry_content = entry.get("content", "") | |
| if entry_content: | |
| state_text_parts.append(f"[{entry_role}]: {entry_content}") | |
| return state_text_parts | |
| def _build_structured_messages( | |
| self, | |
| system_prompt: str, | |
| agent_state: list[dict[str, Any]], | |
| user_content: str, | |
| flat_user: str, | |
| use_structured_state: bool, | |
| ) -> list[dict[str, str]]: | |
| """Build structured messages list for modern chat LLMs.""" | |
| messages: list[dict[str, str]] = [ | |
| {"role": "system", "content": system_prompt}, | |
| ] | |
| if use_structured_state: | |
| # Replay agent.state as proper assistant/user turns so the | |
| # LLM sees real conversation history with correct roles. | |
| for entry in agent_state: | |
| entry_role = entry.get("role", "user") | |
| entry_content = entry.get("content", "") | |
| if entry_content: | |
| # Map roles: "agent"/"assistant" → "assistant", everything else → "user" | |
| msg_role = "assistant" if entry_role in ("agent", "assistant") else "user" | |
| messages.append({"role": msg_role, "content": entry_content}) | |
| # user_content does NOT contain state (it's in separate messages above) | |
| messages.append({"role": "user", "content": user_content}) | |
| else: | |
| # Legacy path or no state: state is inlined in user_content | |
| messages.append({"role": "user", "content": flat_user}) | |
| return messages | |
| def _build_prompt( | |
| self, | |
| agent: Any, | |
| query: str, | |
| incoming_messages: dict[str, str], | |
| agent_names: dict[str, str], | |
| memory_context: list[dict[str, Any]] | None = None, | |
| *, | |
| include_query: bool = True, | |
| ) -> StructuredPrompt: | |
| """ | |
| Build the agent prompt with persona/description, state, schemas, memory, and messages. | |
| Returns a ``StructuredPrompt`` that carries **both** representations: | |
| * ``prompt.text`` — legacy flat string (backward-compatible) | |
| * ``prompt.messages`` — ``[{"role": "system", ...}, {"role": "user", ...}]`` | |
| The system message includes: | |
| - persona / role identity | |
| - description | |
| - output_schema instructions (expected response format) | |
| The user message includes: | |
| - task query | |
| - agent conversation state (previous turns) | |
| - memory context from SharedMemoryPool | |
| - incoming messages from other agents | |
| - input_schema hint (expected input structure) | |
| When a ``structured_llm_caller`` is registered the runner sends | |
| ``prompt.messages`` directly to the LLM, giving it proper | |
| system/user role separation. Otherwise ``prompt.text`` is used | |
| with the legacy ``llm_caller(str) -> str`` interface. | |
| Args: | |
| agent: Agent object with description/persona | |
| query: User query string | |
| incoming_messages: Messages from other agents | |
| agent_names: Mapping of agent IDs to names | |
| memory_context: Optional list of memory entries | |
| include_query: Whether to include the task query in the prompt. | |
| Controlled via config.broadcast_task_to_all. | |
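| Example (illustrative; ``writer`` is an assumed agent and ``runner`` an | |
| already-constructed MACPRunner): | |
| prompt = runner._build_prompt(writer, "Summarize X", {}, {"w1": "Writer"}) | |
| prompt.messages[0]["role"]   # "system": persona, description, output_schema | |
| prompt.messages[-1]["role"]  # "user": task, memory, incoming messages | |
| prompt.text                  # flat string for the legacy llm_caller(str) -> str | |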
| """ | |
| # Build system prompt | |
| system_parts = self._build_system_prompt_parts(agent) | |
| system_prompt = "\n\n".join(system_parts) if system_parts else "You are a helpful assistant." | |
| # Build user prompt parts | |
| user_parts = self._build_user_prompt_parts( | |
| agent, query, incoming_messages, agent_names, memory_context, include_query=include_query | |
| ) | |
| user_content = "".join(user_parts) | |
| # Build state text parts | |
| agent_state: list[dict[str, Any]] = [] | |
| if hasattr(agent, "state"): | |
| agent_state = list(agent.state) if agent.state else [] | |
| use_structured_state = bool(agent_state) and self.structured_llm_caller is not None | |
| state_text_parts = self._build_state_text_parts(agent) | |
| # Build flat string (legacy) | |
| flat_user = "".join(state_text_parts) + user_content if state_text_parts else user_content | |
| flat = f"{system_prompt}\n\n{flat_user}" | |
| # Build structured messages (modern) | |
| messages = self._build_structured_messages( | |
| system_prompt, agent_state, user_content, flat_user, use_structured_state | |
| ) | |
| return StructuredPrompt(text=flat, messages=messages) | |
| def _extract_schema_json(self, agent: Any, attr: str) -> dict[str, Any] | None: | |
| """ | |
| Extract a JSON Schema dict from an agent's schema attribute. | |
| Supports: | |
| - ``dict`` — returned as-is (already a JSON Schema) | |
| - Pydantic ``BaseModel`` subclass — converted via ``model_json_schema()`` | |
| - ``None`` / missing — returns ``None`` | |
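| Example (illustrative; ``Report`` is an assumed Pydantic model): | |
| class Report(BaseModel): | |
| summary: str | |
| # agent.output_schema = Report  -> returns Report.model_json_schema() | |
| # agent.output_schema = {...}   -> the dict is returned unchanged | |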
| """ | |
| schema = getattr(agent, attr, None) | |
| if schema is None: | |
| return None | |
| if isinstance(schema, dict): | |
| return schema | |
| # Pydantic model class | |
| try: | |
| from pydantic import BaseModel | |
| if isinstance(schema, type) and issubclass(schema, BaseModel): | |
| return schema.model_json_schema() | |
| except Exception:  # noqa: BLE001 | |
| # Pydantic may be unavailable or the value may not be a BaseModel class; | |
| # this is expected, so fall through to the final ``return None``. | |
| pass | |
| return None | |
| def _execute_step( | |
| self, | |
| step: Any, | |
| messages: dict[str, str], | |
| agent_lookup: dict[str, Any], | |
| agent_names: dict[str, str], | |
| query: str, | |
| ) -> StepResult: | |
| """ | |
| Execute a step synchronously with retries and token counting. | |
| Supports multi-model: uses the caller for the specific agent. | |
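| Retries use exponential backoff: with, say, retry_delay=1.0, | |
| retry_backoff=2.0 and max_retries=3, a failing step is retried after | |
| roughly 1s, 2s and 4s before the failure is returned in the StepResult. | |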
| """ | |
| agent = agent_lookup.get(step.agent_id) | |
| if agent is None: | |
| return StepResult( | |
| agent_id=step.agent_id, | |
| success=False, | |
| error=f"Agent '{step.agent_id}' not found", | |
| ) | |
| # Get caller for this specific agent (multi-model support) | |
| caller = self._get_caller_for_agent(step.agent_id, agent) | |
| if caller is None: | |
| return StepResult( | |
| agent_id=step.agent_id, | |
| success=False, | |
| error=f"No LLM caller available for agent '{step.agent_id}'", | |
| ) | |
| incoming = {p: messages[p] for p in step.predecessors if p in messages} | |
| memory_context = self._get_memory_context(step.agent_id) | |
| prompt = self._build_prompt(agent, query, incoming, agent_names, memory_context) | |
| last_error = None | |
| delay = self.config.retry_delay | |
| for attempt in range(self.config.max_retries + 1): | |
| try: | |
| # Execute with tools support | |
| response, tokens = self._run_agent_with_tools( | |
| caller=caller, | |
| prompt=prompt, | |
| agent=agent, | |
| ) | |
| quality = 1.0 | |
| if self._scheduler and self._scheduler.pruning.quality_scorer: | |
| quality = self._scheduler.pruning.quality_scorer(response) | |
| return StepResult( | |
| agent_id=step.agent_id, | |
| success=True, | |
| response=response, | |
| tokens_used=tokens, | |
| quality_score=quality, | |
| ) | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| last_error = str(e) | |
| if attempt < self.config.max_retries: | |
| time.sleep(delay) | |
| delay *= self.config.retry_backoff | |
| return StepResult( | |
| agent_id=step.agent_id, | |
| success=False, | |
| error=last_error, | |
| ) | |
| async def _execute_step_async( | |
| self, | |
| step: Any, | |
| messages: dict[str, str], | |
| agent_lookup: dict[str, Any], | |
| agent_names: dict[str, str], | |
| query: str, | |
| ) -> StepResult: | |
| """ | |
| Execute a step asynchronously with retries and timeout. | |
| Supports multi-model: uses the async caller for the specific agent. | |
| """ | |
| agent = agent_lookup.get(step.agent_id) | |
| if agent is None: | |
| return StepResult( | |
| agent_id=step.agent_id, | |
| success=False, | |
| error=f"Agent '{step.agent_id}' not found", | |
| ) | |
| # Get async caller for this specific agent (multi-model support). | |
| # _get_async_caller_for_agent may return None when only | |
| # async_structured_llm_caller is configured — _acall_llm handles | |
| # that case internally, so we only error when *neither* is available. | |
| async_caller = self._get_async_caller_for_agent(step.agent_id, agent) | |
| if async_caller is None and self.async_structured_llm_caller is None: | |
| return StepResult( | |
| agent_id=step.agent_id, | |
| success=False, | |
| error=f"No async LLM caller available for agent '{step.agent_id}'", | |
| ) | |
| incoming = {p: messages[p] for p in step.predecessors if p in messages} | |
| memory_context = self._get_memory_context(step.agent_id) | |
| prompt = self._build_prompt(agent, query, incoming, agent_names, memory_context) | |
| last_error = None | |
| delay = self.config.retry_delay | |
| for attempt in range(self.config.max_retries + 1): | |
| try: | |
| response = await asyncio.wait_for( | |
| self._acall_llm(async_caller, prompt), | |
| timeout=self.config.timeout, | |
| ) | |
| tokens = self.token_counter(prompt.text) + self.token_counter(response) | |
| quality = 1.0 | |
| if self._scheduler and self._scheduler.pruning.quality_scorer: | |
| quality = self._scheduler.pruning.quality_scorer(response) | |
| return StepResult( | |
| agent_id=step.agent_id, | |
| success=True, | |
| response=response, | |
| tokens_used=tokens, | |
| quality_score=quality, | |
| ) | |
| except TimeoutError: | |
| last_error = f"Timeout after {self.config.timeout}s" | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| last_error = str(e) | |
| if attempt < self.config.max_retries: | |
| await asyncio.sleep(delay) | |
| delay *= self.config.retry_backoff | |
| return StepResult( | |
| agent_id=step.agent_id, | |
| success=False, | |
| error=last_error, | |
| ) | |
| async def _execute_parallel( | |
| self, | |
| steps: list[Any], | |
| messages: dict[str, str], | |
| agent_lookup: dict[str, Any], | |
| agent_names: dict[str, str], | |
| query: str, | |
| ) -> list[tuple[Any, StepResult]]: | |
| """Execute a group of steps in parallel asynchronously.""" | |
| tasks = [self._execute_step_async(step, messages, agent_lookup, agent_names, query) for step in steps] | |
| results = await asyncio.gather(*tasks, return_exceptions=True) | |
| output = [] | |
| for step, result in zip(steps, results, strict=False): | |
| if isinstance(result, Exception): | |
| sr = StepResult( | |
| agent_id=step.agent_id, | |
| success=False, | |
| error=str(result), | |
| ) | |
| else: | |
| sr = result | |
| output.append((step, sr)) | |
| return output | |
| def _get_parallel_group( | |
| self, | |
| plan: ExecutionPlan, | |
| completed_agents: Any, | |
| ) -> list[Any]: | |
| """Return a group of steps ready for parallel execution.""" | |
| completed = set(completed_agents) | |
| group: list[Any] = [] | |
| for step in plan.remaining_steps: | |
| if ( | |
| step.agent_id in plan.completed | |
| or step.agent_id in plan.skipped | |
| or step.agent_id in plan.condition_skipped | |
| ): | |
| continue | |
| predecessors_done = all( | |
| p in completed or p in plan.skipped or p in plan.condition_skipped for p in step.predecessors | |
| ) | |
| if predecessors_done: | |
| group.append(step) | |
| if len(group) >= self.config.max_parallel_size: | |
| break | |
| return group | |
| def _determine_final_agent( | |
| self, | |
| requested: str | None, | |
| exec_order: list[str], | |
| messages: dict[str, str], | |
| ) -> str: | |
| """Select the final agent: the requested one or the last in execution order.""" | |
| if requested and requested in messages: | |
| return requested | |
| if exec_order: | |
| return exec_order[-1] | |
| return "" | |
| def _build_agent_states( | |
| self, | |
| messages: dict[str, str], | |
| agent_lookup: dict[str, Any], | |
| ) -> dict[str, list[dict[str, Any]]]: | |
| """Build updated agent states by appending responses to history.""" | |
| states: dict[str, list[dict[str, Any]]] = {} | |
| for agent_id, response in messages.items(): | |
| agent = agent_lookup.get(agent_id) | |
| if agent is not None: | |
| new_state = list(getattr(agent, "state", [])) | |
| new_state.append({"role": "assistant", "content": response}) | |
| states[agent_id] = new_state | |
| return states | |
| def _collect_hidden_states( | |
| self, | |
| agent_lookup: dict[str, Any], | |
| ) -> dict[str, HiddenState]: | |
| """Collect the current hidden_state/embedding of agents into a dictionary.""" | |
| hidden_states: dict[str, HiddenState] = {} | |
| for agent_id, agent in agent_lookup.items(): | |
| hs = HiddenState() | |
| if hasattr(agent, "hidden_state") and agent.hidden_state is not None: | |
| hs.tensor = agent.hidden_state | |
| if hasattr(agent, "embedding") and agent.embedding is not None: | |
| hs.embedding = agent.embedding | |
| if hs.tensor is not None or hs.embedding is not None: | |
| hidden_states[agent_id] = hs | |
| return hidden_states | |
| def _combine_hidden_states( | |
| self, | |
| states: list[HiddenState], | |
| ) -> HiddenState | None: | |
| """Combine a list of hidden states according to the hidden_combine_strategy.""" | |
| if not states: | |
| return None | |
| tensors = [s.tensor for s in states if s.tensor is not None] | |
| embeddings = [s.embedding for s in states if s.embedding is not None] | |
| combined = HiddenState() | |
| if tensors: | |
| combined.tensor = self._combine_tensors(tensors) | |
| if embeddings: | |
| combined.embedding = self._combine_tensors(embeddings) | |
| return combined if (combined.tensor is not None or combined.embedding is not None) else None | |
| def _combine_tensors(self, tensors: list[torch.Tensor]) -> torch.Tensor: | |
| """Combine a list of tensors according to the strategy (mean/sum/concat/attention).""" | |
| if len(tensors) == 1: | |
| return tensors[0] | |
| stacked = torch.stack(tensors) | |
| if self.config.hidden_combine_strategy == "mean": | |
| return stacked.mean(dim=0) | |
| if self.config.hidden_combine_strategy == "sum": | |
| return stacked.sum(dim=0) | |
| if self.config.hidden_combine_strategy == "concat": | |
| return torch.cat(tensors, dim=-1) | |
| if self.config.hidden_combine_strategy == "attention": | |
| weights = torch.softmax(torch.ones(len(tensors)), dim=0) | |
| return (stacked * weights.view(-1, *([1] * (stacked.dim() - 1)))).sum(dim=0) | |
| return stacked.mean(dim=0) | |
| def _get_incoming_hidden( | |
| self, | |
| _agent_id: str, | |
| incoming_ids: list[str], | |
| hidden_states: dict[str, HiddenState], | |
| ) -> HiddenState | None: | |
| """Get the combined hidden state of predecessors.""" | |
| if not self.config.enable_hidden_channels: | |
| return None | |
| incoming_states = [hidden_states[aid] for aid in incoming_ids if aid in hidden_states] | |
| return self._combine_hidden_states(incoming_states) | |
| def _update_agent_hidden_state( | |
| self, | |
| agent: Any, | |
| response: str, | |
| incoming_hidden: HiddenState | None, | |
| hidden_encoder: Any | None = None, | |
| ) -> HiddenState: | |
| """Update the agent's hidden_state based on the response and incoming hidden state.""" | |
| new_hidden = HiddenState() | |
| if hasattr(agent, "embedding") and agent.embedding is not None: | |
| new_hidden.embedding = agent.embedding | |
| if hidden_encoder is not None: | |
| try: | |
| encoded = hidden_encoder.encode([response]) | |
| if isinstance(encoded, torch.Tensor) and encoded.numel() > 0: | |
| new_hidden.tensor = encoded[0] | |
| except (ValueError, TypeError, RuntimeError): | |
| pass # Ignore encoding errors | |
| if new_hidden.tensor is None and incoming_hidden is not None: | |
| new_hidden.tensor = incoming_hidden.tensor | |
| new_hidden.metadata = { | |
| "last_response_length": len(response), | |
| "has_incoming": incoming_hidden is not None, | |
| } | |
| return new_hidden | |
| def run_round_with_hidden( # noqa: PLR0912, PLR0915 | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None = None, | |
| hidden_encoder: Any | None = None, | |
| ) -> MACPResult: | |
| """ | |
| Synchronous round with hidden state transfer between agents. | |
| Supports multi-model: each agent uses its own LLM caller. | |
| Supports conditional edge evaluation. | |
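| Example (illustrative; ``runner``, ``graph`` and ``encoder`` are assumed | |
| to exist): | |
| result = runner.run_round_with_hidden(graph, hidden_encoder=encoder) | |
| print(result.final_answer) | |
| print(list(result.hidden_states))  # agent ids that produced a hidden state | |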
| """ | |
| if not self._has_any_caller(): | |
| msg = "llm_caller, llm_callers, or llm_factory is required" | |
| raise ValueError(msg) | |
| original_hidden_setting = self.config.enable_hidden_channels | |
| self.config.enable_hidden_channels = True | |
| start_time = time.time() | |
| try: | |
| base = self._prepare_base_context(role_graph) | |
| except Exception: | |
| self.config.enable_hidden_channels = original_hidden_setting | |
| raise | |
| if base is None: | |
| self.config.enable_hidden_channels = original_hidden_setting | |
| return MACPResult(messages={}, final_answer="", final_agent_id="", execution_order=[]) | |
| task_idx, a_agents, agent_ids, query, agent_lookup, agent_names = base | |
| # Initialize memory | |
| self._init_memory(agent_ids) | |
| hidden_states = self._collect_hidden_states(agent_lookup) | |
| messages: dict[str, str] = {} | |
| step_results: dict[str, StepResult] = {} | |
| execution_order: list[str] = [] | |
| fallback_attempts: dict[str, int] = {} | |
| topology_changed_count = 0 | |
| fallback_count = 0 | |
| pruned_agents: list[str] = [] | |
| errors: list[ExecutionError] = [] | |
| total_tokens = 0 | |
| # Adaptive mode: plan + conditional edge evaluation after each step | |
| if self.config.adaptive and self._scheduler is not None: | |
| p_matrix = self._extract_p_matrix(role_graph, task_idx) | |
| edge_conditions = self._get_edge_conditions(role_graph) | |
| condition_ctx = ConditionContext( | |
| source_agent="", | |
| target_agent="", | |
| messages={}, | |
| step_results={}, | |
| query=query, | |
| ) | |
| plan = self._scheduler.build_plan( | |
| a_agents, | |
| agent_ids, | |
| p_matrix, | |
| start_agent=None, | |
| end_agent=final_agent_id, | |
| edge_conditions=edge_conditions, | |
| condition_context=condition_ctx, | |
| filter_unreachable=True, | |
| ) | |
| while not plan.is_complete: | |
| step = plan.get_current_step() | |
| if step is None: | |
| break | |
| if step.agent_id in plan.condition_skipped: | |
| plan.advance() | |
| continue | |
| should_prune, reason = self._scheduler.should_prune( | |
| step, plan, step_results.get(execution_order[-1]) if execution_order else None | |
| ) | |
| if should_prune: | |
| plan.mark_skipped(step.agent_id) | |
| pruned_agents.append(step.agent_id) | |
| errors.append( | |
| ExecutionError( | |
| message=f"Pruned: {reason}", | |
| agent_id=step.agent_id, | |
| recoverable=False, | |
| ) | |
| ) | |
| continue | |
| agent_id = step.agent_id | |
| agent = agent_lookup.get(agent_id) | |
| if agent is None: | |
| plan.advance() | |
| continue | |
| incoming_ids = get_incoming_agents(agent_id, a_agents, agent_ids) | |
| incoming_messages = {aid: messages[aid] for aid in incoming_ids if aid in messages} | |
| incoming_hidden = self._get_incoming_hidden(agent_id, incoming_ids, hidden_states) | |
| memory_context = self._get_memory_context(agent_id) | |
| prompt = self._build_prompt(agent, query, incoming_messages, agent_names, memory_context) | |
| if incoming_hidden and incoming_hidden.metadata: | |
| context_hint = self._format_hidden_context(incoming_hidden) | |
| if context_hint: | |
| suffix = f"\n\n[Context: {context_hint}]" | |
| prompt = StructuredPrompt( | |
| text=prompt.text + suffix, | |
| messages=[ | |
| *prompt.messages[:-1], | |
| {"role": "user", "content": prompt.messages[-1]["content"] + suffix}, | |
| ], | |
| ) | |
| try: | |
| caller = self._get_caller_for_agent(agent_id, agent) | |
| if caller is None: | |
| error_msg = f"No LLM caller available for agent {agent_id}" | |
| messages[agent_id] = f"[Error: {error_msg}]" | |
| result = StepResult(agent_id=agent_id, success=False, error=error_msg) | |
| step_results[agent_id] = result | |
| execution_order.append(agent_id) | |
| plan.mark_failed(agent_id) | |
| errors.append( | |
| ExecutionError( | |
| message=error_msg, | |
| agent_id=agent_id, | |
| recoverable=True, | |
| ) | |
| ) | |
| # Fallback when caller is missing | |
| attempts = fallback_attempts.get(agent_id, 0) | |
| if self._scheduler.should_use_fallback(step, result, attempts): | |
| for fb_agent in step.fallback_agents: | |
| if fb_agent not in plan.completed and fb_agent not in plan.failed: | |
| plan.insert_fallback(fb_agent, plan.current_index - 1) | |
| fallback_count += 1 | |
| break | |
| fallback_attempts[agent_id] = attempts + 1 | |
| continue | |
| response, tokens = self._run_agent_with_tools( | |
| caller=caller, | |
| prompt=prompt, | |
| agent=agent, | |
| ) | |
| messages[agent_id] = response | |
| total_tokens += tokens | |
| execution_order.append(agent_id) | |
| self._save_to_memory(agent_id, response, incoming_ids) | |
| hidden_states[agent_id] = self._update_agent_hidden_state( | |
| agent, response, incoming_hidden, hidden_encoder | |
| ) | |
| result = StepResult( | |
| agent_id=agent_id, | |
| success=True, | |
| response=response, | |
| tokens_used=tokens, | |
| ) | |
| step_results[agent_id] = result | |
| plan.mark_completed(agent_id, tokens) | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| messages[agent_id] = f"[Error: {e}]" | |
| result = StepResult(agent_id=agent_id, success=False, error=str(e)) | |
| step_results[agent_id] = result | |
| execution_order.append(agent_id) | |
| plan.mark_failed(agent_id) | |
| errors.append( | |
| ExecutionError( | |
| message=str(e), | |
| agent_id=agent_id, | |
| recoverable=True, | |
| ) | |
| ) | |
| # Fallback on execution error | |
| attempts = fallback_attempts.get(agent_id, 0) | |
| if self._scheduler.should_use_fallback(step, result, attempts): | |
| for fb_agent in step.fallback_agents: | |
| if fb_agent not in plan.completed and fb_agent not in plan.failed: | |
| plan.insert_fallback(fb_agent, plan.current_index - 1) | |
| fallback_count += 1 | |
| break | |
| fallback_attempts[agent_id] = attempts + 1 | |
| # Topology pipeline: conditional edges + user hooks → plan | |
| if self._run_topology_pipeline( | |
| plan, | |
| agent_id, | |
| a_agents, | |
| agent_ids, | |
| step_results, | |
| messages, | |
| query, | |
| execution_order, | |
| total_tokens, | |
| role_graph, | |
| ): | |
| topology_changed_count += 1 | |
| else: | |
| # Non-adaptive mode: execute the plan linearly | |
| exec_order = build_execution_order(a_agents, agent_ids, role_graph.role_sequence) | |
| for agent_id in exec_order: | |
| agent = agent_lookup.get(agent_id) | |
| if agent is None: | |
| continue | |
| incoming_ids = get_incoming_agents(agent_id, a_agents, agent_ids) | |
| incoming_messages = {aid: messages[aid] for aid in incoming_ids if aid in messages} | |
| incoming_hidden = self._get_incoming_hidden(agent_id, incoming_ids, hidden_states) | |
| memory_context = self._get_memory_context(agent_id) | |
| prompt = self._build_prompt(agent, query, incoming_messages, agent_names, memory_context) | |
| if incoming_hidden and incoming_hidden.metadata: | |
| context_hint = self._format_hidden_context(incoming_hidden) | |
| if context_hint: | |
| suffix = f"\n\n[Context: {context_hint}]" | |
| prompt = StructuredPrompt( | |
| text=prompt.text + suffix, | |
| messages=[ | |
| *prompt.messages[:-1], | |
| {"role": "user", "content": prompt.messages[-1]["content"] + suffix}, | |
| ], | |
| ) | |
| try: | |
| caller = self._get_caller_for_agent(agent_id, agent) | |
| if caller is None: | |
| error_msg = f"No LLM caller available for agent {agent_id}" | |
| messages[agent_id] = f"[Error: {error_msg}]" | |
| result = StepResult(agent_id=agent_id, success=False, error=error_msg) | |
| step_results[agent_id] = result | |
| execution_order.append(agent_id) | |
| errors.append( | |
| ExecutionError( | |
| message=error_msg, | |
| agent_id=agent_id, | |
| recoverable=True, | |
| ) | |
| ) | |
| continue | |
| response, tokens = self._run_agent_with_tools( | |
| caller=caller, | |
| prompt=prompt, | |
| agent=agent, | |
| ) | |
| messages[agent_id] = response | |
| total_tokens += tokens | |
| execution_order.append(agent_id) | |
| self._save_to_memory(agent_id, response, incoming_ids) | |
| hidden_states[agent_id] = self._update_agent_hidden_state( | |
| agent, response, incoming_hidden, hidden_encoder | |
| ) | |
| result = StepResult( | |
| agent_id=agent_id, | |
| success=True, | |
| response=response, | |
| tokens_used=tokens, | |
| ) | |
| step_results[agent_id] = result | |
| except (ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| messages[agent_id] = f"[Error: {e}]" | |
| result = StepResult(agent_id=agent_id, success=False, error=str(e)) | |
| step_results[agent_id] = result | |
| execution_order.append(agent_id) | |
| errors.append( | |
| ExecutionError( | |
| message=str(e), | |
| agent_id=agent_id, | |
| recoverable=True, | |
| ) | |
| ) | |
| final_id = self._determine_final_agent(final_agent_id, execution_order, messages) | |
| agent_states = self._build_agent_states(messages, agent_lookup) | |
| result = MACPResult( | |
| messages=messages, | |
| final_answer=messages.get(final_id, ""), | |
| final_agent_id=final_id, | |
| execution_order=execution_order, | |
| topology_changed_count=topology_changed_count, | |
| fallback_count=fallback_count, | |
| agent_states=agent_states, | |
| step_results=step_results, | |
| total_tokens=total_tokens, | |
| total_time=time.time() - start_time, | |
| pruned_agents=pruned_agents, | |
| errors=errors if errors else None, | |
| hidden_states=hidden_states, | |
| ) | |
| self.config.enable_hidden_channels = original_hidden_setting | |
| return result | |
| def _format_hidden_context(self, hidden: HiddenState) -> str: | |
| """Format hidden state metadata for inclusion in the prompt.""" | |
| parts = [] | |
| if hidden.metadata and "last_response_length" in hidden.metadata: | |
| parts.append(f"previous response length: {hidden.metadata['last_response_length']}") | |
| return ", ".join(parts) if parts else "" | |
| # ========================================================================= | |
| # STREAMING EXECUTION METHODS | |
| # ========================================================================= | |
| def stream( | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None = None, | |
| *, | |
| update_states: bool | None = None, | |
| ) -> Iterator[AnyStreamEvent]: | |
| """ | |
| Stream execution events for real-time output. | |
| Yields events as agents are executed, allowing real-time monitoring | |
| and display of intermediate results. | |
| Args: | |
| role_graph: The RoleGraph to execute | |
| final_agent_id: Override which agent produces final answer | |
| update_states: Whether to update agent states after execution | |
| Yields: | |
| StreamEvent instances for each execution phase | |
| Example: | |
| for event in runner.stream(graph): | |
| if event.event_type == StreamEventType.AGENT_OUTPUT: | |
| print(f"{event.agent_id}: {event.content}") | |
| elif event.event_type == StreamEventType.TOKEN: | |
| print(event.token, end="", flush=True) | |
| """ | |
| if not self._has_any_caller() and self.streaming_llm_caller is None: | |
| msg = "llm_caller, llm_callers, llm_factory, or streaming_llm_caller required for streaming" | |
| raise ValueError(msg) | |
| if self.config.adaptive: | |
| yield from self._stream_adaptive(role_graph, final_agent_id, update_states=update_states) | |
| else: | |
| yield from self._stream_simple(role_graph, final_agent_id, update_states=update_states) | |
| async def astream( | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None = None, | |
| *, | |
| update_states: bool | None = None, | |
| ) -> AsyncIterator[AnyStreamEvent]: | |
| """ | |
| Async streaming execution for real-time output. | |
| Async version of stream() for use in async contexts. | |
| Args: | |
| role_graph: The RoleGraph to execute | |
| final_agent_id: Override which agent produces final answer | |
| update_states: Whether to update agent states after execution | |
| Yields: | |
| StreamEvent instances for each execution phase | |
| Example: | |
| async for event in runner.astream(graph): | |
| match event.event_type: | |
| case StreamEventType.AGENT_START: | |
| print(f"Agent {event.agent_id} started") | |
| case StreamEventType.AGENT_OUTPUT: | |
| print(f"Output: {event.content}") | |
| """ | |
| if not self._has_any_async_caller() and self.async_streaming_llm_caller is None: | |
| msg = "async_llm_caller, async_llm_callers, llm_factory, or async_streaming_llm_caller required" | |
| raise ValueError(msg) | |
| if self.config.adaptive: | |
| async for event in self._astream_adaptive(role_graph, final_agent_id, update_states=update_states): | |
| yield event | |
| else: | |
| async for event in self._astream_simple(role_graph, final_agent_id, update_states=update_states): | |
| yield event | |
| def _stream_simple( # noqa: PLR0915 | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None, | |
| *, | |
| update_states: bool | None, | |
| ) -> Iterator[StreamEvent]: | |
| """Simple sequential streaming execution.""" | |
| run_id = str(uuid.uuid4())[:8] | |
| start_time = time.time() | |
| base = self._prepare_base_context(role_graph) | |
| if base is None: | |
| yield RunEndEvent( | |
| run_id=run_id, success=True, final_answer="", final_agent_id="", total_time=time.time() - start_time | |
| ) | |
| return | |
| task_idx, a_agents, agent_ids, query, agent_lookup, agent_names = base | |
| del task_idx # not used directly in simple streaming | |
| exec_order = build_execution_order(a_agents, agent_ids, role_graph.role_sequence) | |
| self._init_memory(agent_ids) | |
| task_connected = self._get_task_connected_agents(role_graph) | |
| # Emit run start | |
| yield RunStartEvent( | |
| run_id=run_id, | |
| query=query, | |
| num_agents=len(exec_order), | |
| execution_order=exec_order, | |
| config_summary={ | |
| "adaptive": False, | |
| "timeout": self.config.timeout, | |
| "enable_memory": self.config.enable_memory, | |
| "broadcast_task_to_all": self.config.broadcast_task_to_all, | |
| }, | |
| ) | |
| messages: dict[str, str] = {} | |
| total_tokens = 0 | |
| errors: list[str] = [] | |
| for step_idx, agent_id in enumerate(exec_order): | |
| agent = agent_lookup.get(agent_id) | |
| if agent is None: | |
| continue | |
| agent_name = agent_names.get(agent_id, agent_id) | |
| incoming_ids = get_incoming_agents(agent_id, a_agents, agent_ids) | |
| incoming_messages = {aid: messages[aid] for aid in incoming_ids if aid in messages} | |
| include_query = self._should_include_query(agent_id, task_connected) | |
| memory_context = self._get_memory_context(agent_id) | |
| prompt = self._build_prompt( | |
| agent, query, incoming_messages, agent_names, memory_context, include_query=include_query | |
| ) | |
| # prompt is now a StructuredPrompt — use .text for preview/streaming | |
| prompt_text = prompt.text | |
| # Emit agent start | |
| yield AgentStartEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| agent_name=agent_name, | |
| step_index=step_idx, | |
| predecessors=incoming_ids, | |
| prompt_preview=prompt_text[: self.config.prompt_preview_length], | |
| ) | |
| step_start = time.time() | |
| try: | |
| # Get caller for this specific agent (multi-model support) | |
| caller = self._get_caller_for_agent(agent_id, agent) | |
| if caller is None: | |
| error_msg = f"No LLM caller available for agent {agent_id}" | |
| errors.append(f"{agent_id}: {error_msg}") | |
| messages[agent_id] = f"[Error: {error_msg}]" | |
| yield AgentErrorEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| error_type="ValueError", | |
| error_message=error_msg, | |
| will_retry=False, | |
| ) | |
| continue | |
| # Use streaming LLM if available and enabled | |
| if self.streaming_llm_caller and self.config.enable_token_streaming: | |
| response_parts: list[str] = [] | |
| token_idx = 0 | |
| for token in self.streaming_llm_caller(prompt_text): | |
| response_parts.append(token) | |
| yield TokenEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| token=token, | |
| token_index=token_idx, | |
| is_first=(token_idx == 0), | |
| is_last=False, | |
| ) | |
| token_idx += 1 | |
| # Mark last token | |
| if response_parts: | |
| yield TokenEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| token="", | |
| token_index=token_idx, | |
| is_first=False, | |
| is_last=True, | |
| ) | |
| response = "".join(response_parts) | |
| tokens = self.token_counter(prompt_text) + self.token_counter(response) | |
| else: | |
| # Use regular LLM caller for this agent (with tools support) | |
| # prompt (StructuredPrompt) is passed through — _run_agent_with_tools | |
| # dispatches via _call_llm when structured_llm_caller is available | |
| response, tokens = self._run_agent_with_tools( | |
| caller=caller, | |
| prompt=prompt, | |
| agent=agent, | |
| ) | |
| messages[agent_id] = response | |
| total_tokens += tokens | |
| self._save_to_memory(agent_id, response, incoming_ids) | |
| is_final = (step_idx == len(exec_order) - 1) or (agent_id == final_agent_id) | |
| yield AgentOutputEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| agent_name=agent_name, | |
| content=response, | |
| tokens_used=tokens, | |
| duration_ms=(time.time() - step_start) * 1000, | |
| is_final=is_final, | |
| ) | |
| except (TimeoutError, ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| error_msg = str(e) | |
| errors.append(f"{agent_id}: {error_msg}") | |
| messages[agent_id] = f"[Error: {e}]" | |
| yield AgentErrorEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| error_type=type(e).__name__, | |
| error_message=error_msg, | |
| will_retry=False, | |
| ) | |
| final_id = self._determine_final_agent(final_agent_id, exec_order, messages) | |
| do_update = update_states if update_states is not None else self.config.update_states | |
| agent_states = self._build_agent_states(messages, agent_lookup) if do_update else None | |
| yield RunEndEvent( | |
| run_id=run_id, | |
| success=len(errors) == 0, | |
| final_answer=messages.get(final_id, ""), | |
| final_agent_id=final_id, | |
| total_tokens=total_tokens, | |
| total_time=time.time() - start_time, | |
| executed_agents=list(messages.keys()), | |
| errors=errors, | |
| agent_states=agent_states, | |
| ) | |
| async def _astream_simple( # noqa: PLR0912, PLR0915 | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None, | |
| *, | |
| update_states: bool | None, | |
| ) -> AsyncIterator[StreamEvent]: | |
| """ | |
| Async streaming execution with optional parallel support. | |
| When ``config.enable_parallel`` is ``True``, independent agents | |
| (those whose predecessors have all completed) are executed | |
| concurrently via ``asyncio.gather``. This is determined by | |
| :func:`get_parallel_groups` which partitions the topological | |
| order into dependency-based levels. | |
| When ``enable_parallel`` is ``False`` (or the graph is purely | |
| sequential), agents are executed one-by-one — identical to the | |
| synchronous ``_stream_simple`` but using async I/O. | |
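| Example (illustrative): for a graph with edges A -> C and B -> C the groups | |
| are roughly [[A, B], [C]]: A and B run concurrently via asyncio.gather, | |
| then C runs once both predecessor messages are available. | |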
| """ | |
| run_id = str(uuid.uuid4())[:8] | |
| start_time = time.time() | |
| base = self._prepare_base_context(role_graph) | |
| if base is None: | |
| yield RunEndEvent( | |
| run_id=run_id, success=True, final_answer="", final_agent_id="", total_time=time.time() - start_time | |
| ) | |
| return | |
| _task_idx, a_agents, agent_ids, query, agent_lookup, agent_names = base | |
| exec_order = build_execution_order(a_agents, agent_ids, role_graph.role_sequence) | |
| self._init_memory(agent_ids) | |
| # When parallel execution is enabled, partition agents into | |
| # dependency-based groups so independent agents run concurrently. | |
| if self.config.enable_parallel: | |
| groups = get_parallel_groups(a_agents, agent_ids) | |
| else: | |
| # Each agent is its own group — strictly sequential. | |
| groups = [[aid] for aid in exec_order] | |
| yield RunStartEvent( | |
| run_id=run_id, | |
| query=query, | |
| num_agents=len(exec_order), | |
| execution_order=exec_order, | |
| config_summary={ | |
| "parallel": self.config.enable_parallel, | |
| }, | |
| ) | |
| messages: dict[str, str] = {} | |
| total_tokens = 0 | |
| errors: list[str] = [] | |
| step_idx = 0 | |
| for group_idx, group in enumerate(groups): | |
| # Filter out unknown agents | |
| group_agents = [aid for aid in group if agent_lookup.get(aid) is not None] | |
| if not group_agents: | |
| continue | |
| # Emit ParallelStartEvent when the group has >1 agent | |
| is_parallel_group = self.config.enable_parallel and len(group_agents) > 1 | |
| if is_parallel_group: | |
| yield ParallelStartEvent( | |
| run_id=run_id, | |
| agent_ids=group_agents, | |
| group_index=group_idx, | |
| ) | |
| # Emit AgentStartEvent for every agent in the group | |
| agent_prompts: dict[str, StructuredPrompt] = {} | |
| for agent_id in group_agents: | |
| agent = agent_lookup[agent_id] | |
| agent_name = agent_names.get(agent_id, agent_id) | |
| incoming_ids = get_incoming_agents(agent_id, a_agents, agent_ids) | |
| incoming_messages = {aid: messages[aid] for aid in incoming_ids if aid in messages} | |
| memory_context = self._get_memory_context(agent_id) | |
| prompt = self._build_prompt(agent, query, incoming_messages, agent_names, memory_context) | |
| agent_prompts[agent_id] = prompt | |
| yield AgentStartEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| agent_name=agent_name, | |
| step_index=step_idx, | |
| predecessors=incoming_ids, | |
| prompt_preview=prompt.text[: self.config.prompt_preview_length], | |
| ) | |
| step_idx += 1 | |
| # ── Execute the group ───────────────────────────────────── | |
| if is_parallel_group: | |
| # Run all agents in the group concurrently | |
| async def _call_agent(aid: str, prompt: StructuredPrompt) -> tuple[str, str | None, str | None]: | |
| """Returns (agent_id, response, error).""" | |
| try: | |
| resp = await asyncio.wait_for( | |
| self._acall_llm(self.async_llm_caller, prompt), | |
| timeout=self.config.timeout, | |
| ) | |
| except ( | |
| TimeoutError, | |
| ExecutionError, | |
| ValueError, | |
| TypeError, | |
| KeyError, | |
| RuntimeError, | |
| OSError, | |
| ) as exc: | |
| return (aid, None, str(exc)) | |
| else: | |
| return (aid, resp, None) | |
| results = await asyncio.gather(*[_call_agent(aid, agent_prompts[aid]) for aid in group_agents]) | |
| successful: list[str] = [] | |
| failed: list[str] = [] | |
| for agent_id, response, error in results: | |
| agent_name = agent_names.get(agent_id, agent_id) | |
| if error is None and response is not None: | |
| messages[agent_id] = response | |
| prompt = agent_prompts[agent_id] | |
| tokens = self.token_counter(prompt.text) + self.token_counter(response) | |
| total_tokens += tokens | |
| incoming_ids = get_incoming_agents(agent_id, a_agents, agent_ids) | |
| self._save_to_memory(agent_id, response, incoming_ids) | |
| successful.append(agent_id) | |
| yield AgentOutputEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| agent_name=agent_name, | |
| content=response, | |
| tokens_used=tokens, | |
| is_final=(agent_id == final_agent_id), | |
| ) | |
| else: | |
| errors.append(f"{agent_id}: {error}") | |
| messages[agent_id] = f"[Error: {error}]" | |
| failed.append(agent_id) | |
| yield AgentErrorEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| error_type="ExecutionError", | |
| error_message=error or "Unknown error", | |
| ) | |
| yield ParallelEndEvent( | |
| run_id=run_id, | |
| agent_ids=group_agents, | |
| group_index=group_idx, | |
| successful=successful, | |
| failed=failed, | |
| ) | |
| else: | |
| # Sequential execution — one agent at a time | |
| for agent_id in group_agents: | |
| agent_name = agent_names.get(agent_id, agent_id) | |
| prompt = agent_prompts[agent_id] | |
| step_start = time.time() | |
| try: | |
| # Use async streaming LLM if available | |
| if self.async_streaming_llm_caller and self.config.enable_token_streaming: | |
| response_parts: list[str] = [] | |
| token_idx = 0 | |
| async for token in self.async_streaming_llm_caller(prompt.text): | |
| response_parts.append(token) | |
| yield TokenEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| token=token, | |
| token_index=token_idx, | |
| is_first=(token_idx == 0), | |
| is_last=False, | |
| ) | |
| token_idx += 1 | |
| if response_parts: | |
| yield TokenEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| token="", | |
| token_index=token_idx, | |
| is_first=False, | |
| is_last=True, | |
| ) | |
| response = "".join(response_parts) | |
| else: | |
| if self.async_llm_caller is None and self.async_structured_llm_caller is None: | |
| error_msg = f"No async LLM caller available for agent {agent_id}" | |
| errors.append(f"{agent_id}: {error_msg}") | |
| messages[agent_id] = f"[Error: {error_msg}]" | |
| yield AgentErrorEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| error_type="ValueError", | |
| error_message=error_msg, | |
| ) | |
| continue | |
| response = await asyncio.wait_for( | |
| self._acall_llm(self.async_llm_caller, prompt), | |
| timeout=self.config.timeout, | |
| ) | |
| messages[agent_id] = response | |
| tokens = self.token_counter(prompt.text) + self.token_counter(response) | |
| total_tokens += tokens | |
| incoming_ids = get_incoming_agents(agent_id, a_agents, agent_ids) | |
| self._save_to_memory(agent_id, response, incoming_ids) | |
| is_final = (agent_id == final_agent_id) or ( | |
| group_idx == len(groups) - 1 and agent_id == group_agents[-1] | |
| ) | |
| yield AgentOutputEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| agent_name=agent_name, | |
| content=response, | |
| tokens_used=tokens, | |
| duration_ms=(time.time() - step_start) * 1000, | |
| is_final=is_final, | |
| ) | |
| except (TimeoutError, ExecutionError, ValueError, TypeError, KeyError, RuntimeError, OSError) as e: | |
| error_msg = str(e) | |
| errors.append(f"{agent_id}: {error_msg}") | |
| messages[agent_id] = f"[Error: {e}]" | |
| yield AgentErrorEvent( | |
| run_id=run_id, | |
| agent_id=agent_id, | |
| error_type=type(e).__name__, | |
| error_message=error_msg, | |
| ) | |
| final_id = self._determine_final_agent(final_agent_id, exec_order, messages) | |
| do_update = update_states if update_states is not None else self.config.update_states | |
| agent_states = self._build_agent_states(messages, agent_lookup) if do_update else None | |
| yield RunEndEvent( | |
| run_id=run_id, | |
| success=len(errors) == 0, | |
| final_answer=messages.get(final_id, ""), | |
| final_agent_id=final_id, | |
| total_tokens=total_tokens, | |
| total_time=time.time() - start_time, | |
| executed_agents=list(messages.keys()), | |
| errors=errors, | |
| agent_states=agent_states, | |
| ) | |
| def _stream_adaptive( # noqa: PLR0912, PLR0915 | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None, | |
| *, | |
| update_states: bool | None, | |
| ) -> Iterator[StreamEvent]: | |
| """Adaptive streaming execution with conditional edges, pruning, and fallback.""" | |
| if self._scheduler is None: | |
| msg = "Scheduler not initialized for adaptive mode" | |
| raise ValueError(msg) | |
| run_id = str(uuid.uuid4())[:8] | |
| start_time = time.time() | |
| base = self._prepare_base_context(role_graph) | |
| if base is None: | |
| yield RunEndEvent(run_id=run_id, success=True, total_time=0) | |
| return | |
| task_idx, a_agents, agent_ids, query, agent_lookup, agent_names = base | |
| self._init_memory(agent_ids) | |
| p_matrix = self._extract_p_matrix(role_graph, task_idx) | |
| edge_conditions = self._get_edge_conditions(role_graph) | |
| condition_ctx = ConditionContext( | |
| source_agent="", | |
| target_agent="", | |
| messages={}, | |
| step_results={}, | |
| query=query, | |
| ) | |
| plan = self._scheduler.build_plan( | |
| a_agents, | |
| agent_ids, | |
| p_matrix, | |
| end_agent=final_agent_id, | |
| edge_conditions=edge_conditions, | |
| condition_context=condition_ctx, | |
| ) | |
| yield RunStartEvent( | |
| run_id=run_id, | |
| query=query, | |
| num_agents=len(agent_ids), | |
| execution_order=plan.execution_order, | |
| config_summary={"adaptive": True, "policy": self.config.routing_policy.value}, | |
| ) | |
| messages: dict[str, str] = {} | |
| step_results: dict[str, StepResult] = {} | |
| execution_order: list[str] = [] | |
| fallback_attempts: dict[str, int] = {} | |
| topology_changed_count = 0 | |
| errors: list[str] = [] | |
| total_tokens = 0 | |
| step_idx = 0 | |
| while not plan.is_complete: | |
| step = plan.get_current_step() | |
| if step is None: | |
| break | |
| # Skip agents whose conditions were not met | |
| if step.agent_id in plan.condition_skipped: | |
| plan.advance() | |
| continue | |
| # Check for pruning | |
| should_prune, reason = self._scheduler.should_prune( | |
| step, plan, step_results.get(execution_order[-1]) if execution_order else None | |
| ) | |
| if should_prune: | |
| plan.mark_skipped(step.agent_id) | |
| yield PruneEvent( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| reason=reason, | |
| ) | |
| continue | |
| agent = agent_lookup.get(step.agent_id) | |
| if agent is None: | |
| plan.advance() | |
| continue | |
| agent_name = agent_names.get(step.agent_id, step.agent_id) | |
| incoming = {p: messages[p] for p in step.predecessors if p in messages} | |
| memory_context = self._get_memory_context(step.agent_id) | |
| prompt = self._build_prompt(agent, query, incoming, agent_names, memory_context) | |
| yield AgentStartEvent( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| agent_name=agent_name, | |
| step_index=step_idx, | |
| predecessors=step.predecessors, | |
| prompt_preview=prompt.text[: self.config.prompt_preview_length], | |
| ) | |
| step_start = time.time() | |
| result = self._execute_step(step, messages, agent_lookup, agent_names, query) | |
| step_results[step.agent_id] = result | |
| execution_order.append(step.agent_id) | |
| if result.success: | |
| messages[step.agent_id] = result.response or "" | |
| plan.mark_completed(step.agent_id, result.tokens_used) | |
| total_tokens += result.tokens_used | |
| self._save_to_memory(step.agent_id, result.response or "", step.predecessors) | |
| yield AgentOutputEvent( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| agent_name=agent_name, | |
| content=result.response or "", | |
| tokens_used=result.tokens_used, | |
| duration_ms=(time.time() - step_start) * 1000, | |
| ) | |
| else: | |
| plan.mark_failed(step.agent_id) | |
| errors.append(f"{step.agent_id}: {result.error}") | |
| yield AgentErrorEvent( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| error_type="ExecutionError", | |
| error_message=result.error or "Unknown error", | |
| ) | |
| # Handle fallback | |
| attempts = fallback_attempts.get(step.agent_id, 0) | |
| if self._scheduler.should_use_fallback(step, result, attempts): | |
| for fb_agent in step.fallback_agents: | |
| if fb_agent not in plan.completed and fb_agent not in plan.failed: | |
| plan.insert_fallback(fb_agent, plan.current_index - 1) | |
| yield FallbackEvent( | |
| run_id=run_id, | |
| failed_agent_id=step.agent_id, | |
| fallback_agent_id=fb_agent, | |
| attempt=attempts + 1, | |
| ) | |
| break | |
| fallback_attempts[step.agent_id] = attempts + 1 | |
| # Topology pipeline: conditional edges + user hooks → plan | |
| old_remaining = [s.agent_id for s in plan.remaining_steps] | |
| if self._run_topology_pipeline( | |
| plan, | |
| step.agent_id, | |
| a_agents, | |
| agent_ids, | |
| step_results, | |
| messages, | |
| query, | |
| execution_order, | |
| total_tokens, | |
| role_graph, | |
| ): | |
| topology_changed_count += 1 | |
| new_remaining = [s.agent_id for s in plan.remaining_steps] | |
| if old_remaining != new_remaining: | |
| yield TopologyChangedEvent( | |
| run_id=run_id, | |
| reason="Topology pipeline: conditional edges", | |
| old_remaining=old_remaining, | |
| new_remaining=new_remaining, | |
| change_count=topology_changed_count, | |
| ) | |
| step_idx += 1 | |
| final_id = self._determine_final_agent(final_agent_id, execution_order, messages) | |
| do_update = update_states if update_states is not None else self.config.update_states | |
| agent_states = self._build_agent_states(messages, agent_lookup) if do_update else None | |
| yield RunEndEvent( | |
| run_id=run_id, | |
| success=len(errors) == 0, | |
| final_answer=messages.get(final_id, ""), | |
| final_agent_id=final_id, | |
| total_tokens=total_tokens, | |
| total_time=time.time() - start_time, | |
| executed_agents=execution_order, | |
| errors=errors, | |
| agent_states=agent_states, | |
| ) | |
| async def _astream_adaptive( # noqa: PLR0912, PLR0915 | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None, | |
| *, | |
| update_states: bool | None, | |
| ) -> AsyncIterator[StreamEvent]: | |
| """Async adaptive streaming with parallel execution support.""" | |
| if self._scheduler is None: | |
| msg = "Scheduler not initialized for adaptive mode" | |
| raise ValueError(msg) | |
| run_id = str(uuid.uuid4())[:8] | |
| start_time = time.time() | |
| base = self._prepare_base_context(role_graph) | |
| if base is None: | |
| yield RunEndEvent(run_id=run_id, success=True, total_time=0) | |
| return | |
| task_idx, a_agents, agent_ids, query, agent_lookup, agent_names = base | |
| self._init_memory(agent_ids) | |
| p_matrix = self._extract_p_matrix(role_graph, task_idx) | |
| edge_conditions = self._get_edge_conditions(role_graph) | |
| condition_ctx = ConditionContext( | |
| source_agent="", | |
| target_agent="", | |
| messages={}, | |
| step_results={}, | |
| query=query, | |
| ) | |
| plan = self._scheduler.build_plan( | |
| a_agents, | |
| agent_ids, | |
| p_matrix, | |
| end_agent=final_agent_id, | |
| edge_conditions=edge_conditions, | |
| condition_context=condition_ctx, | |
| ) | |
| yield RunStartEvent( | |
| run_id=run_id, | |
| query=query, | |
| num_agents=len(agent_ids), | |
| execution_order=plan.execution_order, | |
| config_summary={ | |
| "adaptive": True, | |
| "parallel": self.config.enable_parallel, | |
| "policy": self.config.routing_policy.value, | |
| }, | |
| ) | |
| messages: dict[str, str] = {} | |
| step_results: dict[str, StepResult] = {} | |
| execution_order: list[str] = [] | |
| fallback_attempts: dict[str, int] = {} | |
| topology_changed_count = 0 | |
| errors: list[str] = [] | |
| total_tokens = 0 | |
| group_idx = 0 | |
| while not plan.is_complete: | |
| parallel_group = self._get_parallel_group(plan, messages.keys()) | |
| if not parallel_group: | |
| break | |
| # Filter condition-skipped and pruned steps | |
| valid_steps = [] | |
| for step in parallel_group: | |
| # Skip agents whose conditions were not met | |
| if step.agent_id in plan.condition_skipped: | |
| plan.advance() | |
| continue | |
| should_prune, reason = self._scheduler.should_prune(step, plan, None) | |
| if should_prune: | |
| plan.mark_skipped(step.agent_id) | |
| yield PruneEvent(run_id=run_id, agent_id=step.agent_id, reason=reason) | |
| else: | |
| valid_steps.append(step) | |
| if not valid_steps: | |
| # All steps in the group are skipped or pruned — break to avoid stalling. | |
| break | |
| # Emit parallel start event | |
| if self.config.enable_parallel and len(valid_steps) > 1: | |
| yield ParallelStartEvent( | |
| run_id=run_id, | |
| agent_ids=[s.agent_id for s in valid_steps], | |
| group_index=group_idx, | |
| ) | |
| # Emit agent start events | |
| for step in valid_steps: | |
| agent_name = agent_names.get(step.agent_id, step.agent_id) | |
| incoming = {p: messages[p] for p in step.predecessors if p in messages} | |
| agent = agent_lookup.get(step.agent_id) | |
| if agent: | |
| memory_context = self._get_memory_context(step.agent_id) | |
| prompt = self._build_prompt(agent, query, incoming, agent_names, memory_context) | |
| yield AgentStartEvent( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| agent_name=agent_name, | |
| predecessors=step.predecessors, | |
| prompt_preview=prompt.text[: self.config.prompt_preview_length], | |
| ) | |
| # Execute steps (parallel or sequential) | |
| if self.config.enable_parallel and len(valid_steps) > 1: | |
| results = await self._execute_parallel(valid_steps, messages, agent_lookup, agent_names, query) | |
| else: | |
| results = [] | |
| for step in valid_steps: | |
| r = await self._execute_step_async(step, messages, agent_lookup, agent_names, query) | |
| results.append((step, r)) | |
| # Process results and emit events | |
| successful: list[str] = [] | |
| failed: list[str] = [] | |
| for step, result in results: | |
| step_results[step.agent_id] = result | |
| execution_order.append(step.agent_id) | |
| agent_name = agent_names.get(step.agent_id, step.agent_id) | |
| if result.success: | |
| messages[step.agent_id] = result.response or "" | |
| plan.mark_completed(step.agent_id, result.tokens_used) | |
| total_tokens += result.tokens_used | |
| self._save_to_memory(step.agent_id, result.response or "", step.predecessors) | |
| successful.append(step.agent_id) | |
| yield AgentOutputEvent( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| agent_name=agent_name, | |
| content=result.response or "", | |
| tokens_used=result.tokens_used, | |
| ) | |
| else: | |
| plan.mark_failed(step.agent_id) | |
| errors.append(f"{step.agent_id}: {result.error}") | |
| failed.append(step.agent_id) | |
| yield AgentErrorEvent( | |
| run_id=run_id, | |
| agent_id=step.agent_id, | |
| error_message=result.error or "Unknown error", | |
| ) | |
| # Handle fallback | |
| attempts = fallback_attempts.get(step.agent_id, 0) | |
| if self._scheduler.should_use_fallback(step, result, attempts): | |
| for fb_agent in step.fallback_agents: | |
| if fb_agent not in plan.completed and fb_agent not in plan.failed: | |
| plan.insert_fallback(fb_agent, plan.current_index - 1) | |
| yield FallbackEvent( | |
| run_id=run_id, | |
| failed_agent_id=step.agent_id, | |
| fallback_agent_id=fb_agent, | |
| attempt=attempts + 1, | |
| ) | |
| break | |
| fallback_attempts[step.agent_id] = attempts + 1 | |
| # Emit parallel end event | |
| if self.config.enable_parallel and len(valid_steps) > 1: | |
| yield ParallelEndEvent( | |
| run_id=run_id, | |
| agent_ids=[s.agent_id for s in valid_steps], | |
| group_index=group_idx, | |
| successful=successful, | |
| failed=failed, | |
| ) | |
| # Topology pipeline for each executed agent in the group | |
| old_remaining = [s.agent_id for s in plan.remaining_steps] | |
| for step, _result in results: | |
| if await self._arun_topology_pipeline( | |
| plan, | |
| step.agent_id, | |
| a_agents, | |
| agent_ids, | |
| step_results, | |
| messages, | |
| query, | |
| execution_order, | |
| total_tokens, | |
| role_graph, | |
| ): | |
| topology_changed_count += 1 | |
| new_remaining = [s.agent_id for s in plan.remaining_steps] | |
| if old_remaining != new_remaining: | |
| yield TopologyChangedEvent( | |
| run_id=run_id, | |
| reason="Topology pipeline: conditional edges", | |
| old_remaining=old_remaining, | |
| new_remaining=new_remaining, | |
| change_count=topology_changed_count, | |
| ) | |
| group_idx += 1 | |
| final_id = self._determine_final_agent(final_agent_id, execution_order, messages) | |
| do_update = update_states if update_states is not None else self.config.update_states | |
| agent_states = self._build_agent_states(messages, agent_lookup) if do_update else None | |
| yield RunEndEvent( | |
| run_id=run_id, | |
| success=len(errors) == 0, | |
| final_answer=messages.get(final_id, ""), | |
| final_agent_id=final_id, | |
| total_tokens=total_tokens, | |
| total_time=time.time() - start_time, | |
| executed_agents=execution_order, | |
| errors=errors, | |
| agent_states=agent_states, | |
| ) | |
| def stream_to_result( | |
| self, | |
| role_graph: Any, | |
| final_agent_id: str | None = None, | |
| *, | |
| update_states: bool | None = None, | |
| ) -> tuple[Iterator[StreamEvent], MACPResult]: | |
| """ | |
| Stream execution and also return final MACPResult. | |
| Useful when you want both a live streaming display and the complete result. | |
| Returns: | |
| Tuple of (event iterator, MACPResult) | |
| Example: | |
| stream, result_future = runner.stream_to_result(graph) | |
| for event in stream: | |
| print(event) | |
| print(result_future.final_answer)  # populated once the stream is exhausted | |
| """ | |
| events: list[StreamEvent] = [] | |
| messages: dict[str, str] = {} | |
| final_answer = "" | |
| final_agent = "" | |
| execution_order: list[str] = [] | |
| total_tokens = 0 | |
| total_time = 0.0 | |
| errors_list: list[str] = [] | |
| def collecting_stream() -> Iterator[StreamEvent]: | |
| nonlocal final_answer, final_agent, total_tokens, total_time, errors_list | |
| for event in self.stream(role_graph, final_agent_id, update_states=update_states): | |
| events.append(event) | |
| if isinstance(event, AgentOutputEvent): | |
| messages[event.agent_id] = event.content | |
| execution_order.append(event.agent_id) | |
| elif isinstance(event, RunEndEvent): | |
| final_answer = event.final_answer | |
| final_agent = event.final_agent_id | |
| total_tokens = event.total_tokens | |
| total_time = event.total_time | |
| errors_list = event.errors | |
| yield event | |
| stream = collecting_stream() | |
| # Create a lazy result that becomes valid after stream is exhausted | |
| class LazyResult: | |
| def __init__(self, runner: MACPRunner): | |
| self._runner = runner | |
| self._result: MACPResult | None = None | |
| def __getattr__(self, name: str) -> Any: | |
| if self._result is None: | |
| self._result = MACPResult( | |
| messages=messages, | |
| final_answer=final_answer, | |
| final_agent_id=final_agent, | |
| execution_order=execution_order, | |
| total_tokens=total_tokens, | |
| total_time=total_time, | |
| errors=[ExecutionError(message=e, agent_id="", recoverable=False) for e in errors_list] | |
| if errors_list | |
| else None, | |
| ) | |
| return getattr(self._result, name) | |
| # LazyResult provides a proxy to access result attributes lazily | |
| lazy_result: MACPResult = LazyResult(self) # type: ignore[assignment] | |
| return stream, lazy_result | |