| """Execution orchestration: bridges the web API to MACPRunner.""" | |
| import asyncio | |
| import os | |
| import uuid | |
| from datetime import UTC, datetime | |
| from typing import Any | |
| from backend.models.execution import ( | |
| EarlyStopConditionSchema, | |
| EarlyStopType, | |
| LLMProviderConfig, | |
| RunnerConfigSchema, | |
| TopologyHookSchema, | |
| TopologyHookType, | |
| ) | |
| from backend.services.graph_service import build_gmas_graph | |
| from backend.services.storage_service import storage | |


class RunState:
    """Tracks a single execution run."""

    def __init__(self, run_id: str, graph_data: dict[str, Any], task_query: str):
        self.run_id = run_id
        self.graph_data = graph_data
        self.task_query = task_query
        self.status: str = "pending"
        self.queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue()
        self.task: asyncio.Task | None = None
        self.result: dict[str, Any] | None = None
        self.events: list[dict[str, Any]] = []
        self.started_at: str = datetime.now(UTC).isoformat()
        self.completed_at: str | None = None
        self.cancelled: bool = False


# In-memory store for active runs (completed runs are also persisted to disk).
_active_runs: dict[str, RunState] = {}


def get_active_run(run_id: str) -> RunState | None:
    """Return the in-memory state for a run, or None if it is unknown."""
    return _active_runs.get(run_id)


async def start_execution(
    graph_data: dict[str, Any],
    task_query: str,
    config: RunnerConfigSchema | None = None,
    llm_provider: LLMProviderConfig | None = None,
) -> str:
    """Start an async execution and return the run_id."""
    run_id = str(uuid.uuid4())[:12]
    run_state = RunState(run_id=run_id, graph_data=graph_data, task_query=task_query)
    _active_runs[run_id] = run_state
    run_state.task = asyncio.create_task(_run_execution(run_state, config, llm_provider))
    return run_id
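

# A minimal usage sketch (illustrative only, not part of the service API):
# start a run, then drain its queue until the ``None`` sentinel arrives. A
# real caller would forward each event over a WebSocket rather than print it.
# The name ``_example_consume_run`` is hypothetical.
async def _example_consume_run(graph_data: dict[str, Any]) -> None:
    run_id = await start_execution(graph_data, task_query="Summarize the findings")
    run_state = get_active_run(run_id)
    assert run_state is not None
    # ``None`` is the end-of-run sentinel pushed by ``_run_execution``
    while (event := await run_state.queue.get()) is not None:
        print(event["event_type"], event.get("agent_id", ""))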


async def _run_execution(
    run_state: RunState,
    config_schema: RunnerConfigSchema | None,
    llm_provider: LLMProviderConfig | None,
) -> None:
    """Execute the workflow via arun_round() with callback-based event emission.

    Uses ``runner.arun_round()`` instead of ``runner.astream()`` because only
    the ``arun_round()`` code path supports early stopping and topology hooks.
    A ``BaseCallbackHandler`` subclass bridges each callback into the
    WebSocket event queue so the frontend still receives real-time updates.
    """
    run_state.status = "running"
    try:
        from callbacks.base import BaseCallbackHandler
        from execution import MACPRunner
        from execution.runner import RunnerConfig

        # ----- callback handler that pushes events to the WS queue -----
        class _EventBridge(BaseCallbackHandler):
            """Converts MACPRunner callbacks into event dicts for the frontend."""

            def _emit(self, event: dict[str, Any]) -> None:
                event.setdefault("run_id", run_state.run_id)
                event.setdefault("timestamp", datetime.now(UTC).isoformat())
                run_state.events.append(event)
                run_state.queue.put_nowait(event)

            # Run lifecycle
            def on_run_start(self, *, run_id, query, num_agents=0,
                             execution_order=None, **kw):
                self._emit({
                    "event_type": "run_start",
                    "num_agents": num_agents,
                    "execution_order": execution_order or [],
                })

            def on_run_end(self, *, run_id, output, success=True, error=None,
                           total_tokens=0, total_time_ms=0.0,
                           executed_agents=None, **kw):
                self._emit({
                    "event_type": "run_end",
                    "final_answer": output,
                    "success": success,
                    "total_tokens": total_tokens,
                    "total_time": total_time_ms / 1000.0,
                    "executed_agents": executed_agents or [],
                    "error": str(error) if error else None,
                })

            # Agent lifecycle
            def on_agent_start(self, *, run_id, agent_id, agent_name="",
                               step_index=0, prompt="", predecessors=None, **kw):
                self._emit({
                    "event_type": "agent_start",
                    "agent_id": agent_id,
                    "agent_name": agent_name,
                })

            def on_agent_end(self, *, run_id, agent_id, output, agent_name="",
                             step_index=0, tokens_used=0, duration_ms=0.0,
                             is_final=False, **kw):
                self._emit({
                    "event_type": "agent_output",
                    "agent_id": agent_id,
                    "agent_name": agent_name,
                    "content": output,
                    "tokens_used": tokens_used,
                    "duration_ms": duration_ms,
                })

            def on_agent_error(self, error, *, run_id, agent_id, error_type="",
                               will_retry=False, attempt=0, max_attempts=0, **kw):
                self._emit({
                    "event_type": "agent_error",
                    "agent_id": agent_id,
                    "error_type": error_type,
                    "error_message": str(error),
                    "will_retry": will_retry,
                })

            # Topology / dynamic graph
            def on_topology_changed(self, *, run_id, reason, old_remaining,
                                    new_remaining, change_count=0, **kw):
                self._emit({
                    "event_type": "topology_changed",
                    "content": reason,
                })

            # Prune / fallback
            def on_prune(self, *, run_id, agent_id, reason, **kw):
                self._emit({
                    "event_type": "prune",
                    "agent_id": agent_id,
                    "content": reason,
                })

            def on_fallback(self, *, run_id, failed_agent_id, fallback_agent_id,
                            reason="", **kw):
                self._emit({
                    "event_type": "fallback",
                    "agent_id": failed_agent_id,
                    "content": f"Fallback to {fallback_agent_id}: {reason}",
                })

            # Parallel execution
            def on_parallel_start(self, *, run_id, agent_ids, group_index=0, **kw):
                self._emit({
                    "event_type": "parallel_start",
                    "agent_ids": agent_ids,
                })

            def on_parallel_end(self, *, run_id, agent_ids, group_index=0,
                                successful=None, failed=None, **kw):
                self._emit({
                    "event_type": "parallel_end",
                    "agent_ids": agent_ids,
                })

            # Memory
            def on_memory_read(self, *, run_id, agent_id, entries_count=0,
                               keys=None, **kw):
                self._emit({
                    "event_type": "memory_read",
                    "agent_id": agent_id,
                })

            def on_memory_write(self, *, run_id, agent_id, key, value_size=0, **kw):
                self._emit({
                    "event_type": "memory_write",
                    "agent_id": agent_id,
                })

            # Budget
            def on_budget_warning(self, *, run_id, budget_type, current, limit,
                                  ratio=0.0, **kw):
                self._emit({
                    "event_type": "budget_warning",
                    "content": f"{budget_type}: {current}/{limit}",
                })

            def on_budget_exceeded(self, *, run_id, budget_type, current, limit,
                                   action_taken="", **kw):
                self._emit({
                    "event_type": "budget_exceeded",
                    "content": f"{budget_type}: {current}/{limit} — {action_taken}",
                })

        handler = _EventBridge()

        # Build graph
        graph_data = run_state.graph_data
        if run_state.task_query:
            graph_data["task_query"] = run_state.task_query
        graph = build_gmas_graph(graph_data)

        # Build runner config (always include the callback handler)
        runner_config = RunnerConfig(callbacks=[handler])
        if config_schema:
            early_stops = _build_early_stop_conditions(config_schema.early_stop_conditions)
            topo_hooks = _build_topology_hooks(config_schema.topology_hooks)
            enable_dyn = config_schema.enable_dynamic_topology or bool(early_stops) or bool(topo_hooks)
            runner_config = RunnerConfig(
                timeout=config_schema.timeout,
                adaptive=config_schema.adaptive,
                enable_parallel=config_schema.enable_parallel,
                max_parallel_size=config_schema.max_parallel_size,
                max_retries=config_schema.max_retries,
                enable_memory=config_schema.enable_memory,
                memory_context_limit=config_schema.memory_context_limit,
                broadcast_task_to_all=config_schema.broadcast_task_to_all,
                enable_dynamic_topology=enable_dyn,
                max_tool_iterations=config_schema.max_tool_iterations,
                early_stop_conditions=early_stops,
                async_topology_hooks=topo_hooks,
                callbacks=[handler],
            )

        # Resolve LLM caller
        sync_caller = _build_llm_caller(llm_provider)

        async def async_caller(prompt: str) -> str:
            return await asyncio.to_thread(sync_caller, prompt)

        runner = MACPRunner(async_llm_caller=async_caller, config=runner_config)

        # Use arun_round() — the only code path that supports early stopping & topology hooks
        result = await runner.arun_round(graph)

        # Emit early_stop event if the runner stopped early
        if result.early_stopped:
            handler._emit({
                "event_type": "early_stop",
                "content": result.early_stop_reason or "Early stop triggered",
            })

        run_state.status = "completed"
        run_state.completed_at = datetime.now(UTC).isoformat()

        # Extract result from the last run_end event (emitted by the handler)
        for ev in reversed(run_state.events):
            if ev.get("event_type") == "run_end":
                run_state.result = ev
                break
    except asyncio.CancelledError:
        run_state.status = "cancelled"
        await run_state.queue.put({"event_type": "cancelled", "run_id": run_state.run_id})
    except Exception as exc:
        run_state.status = "error"
        error_event = {
            "event_type": "error",
            "run_id": run_state.run_id,
            "error": str(exc),
            "timestamp": datetime.now(UTC).isoformat(),
        }
        run_state.events.append(error_event)
        await run_state.queue.put(error_event)
    finally:
        run_state.completed_at = run_state.completed_at or datetime.now(UTC).isoformat()
        await run_state.queue.put(None)  # Sentinel
        # Persist run
        _persist_run(run_state)
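

# For reference, a typical event as it leaves the queue (values illustrative;
# the exact keys depend on which callback fired, plus the ``run_id`` and
# ``timestamp`` defaults added by ``_EventBridge._emit``):
#
#   {
#       "event_type": "agent_output",
#       "agent_id": "researcher",
#       "agent_name": "Researcher",
#       "content": "...",
#       "tokens_used": 412,
#       "duration_ms": 1830.5,
#       "run_id": "3f9c2a1b-7d4",
#       "timestamp": "2025-01-01T12:00:00+00:00",
#   }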


def _build_llm_caller(provider: LLMProviderConfig | None):
    """Build a synchronous LLM caller from provider config."""
    if provider is None:
        # Return a mock caller for testing
        def mock_caller(prompt: str) -> str:
            return f"[Mock LLM Response] Received prompt of {len(prompt)} characters."

        return mock_caller

    # Resolve API key; a leading "$" means "read from this environment variable"
    api_key = provider.api_key
    if api_key and api_key.startswith("$"):
        api_key = os.environ.get(api_key[1:], "")
    base_url = provider.base_url
    model = provider.default_model or "gpt-4"

    try:
        from openai import OpenAI

        client = OpenAI(api_key=api_key, base_url=base_url)

        def openai_caller(prompt: str) -> str:
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
            )
            return response.choices[0].message.content or ""

        return openai_caller
    except ImportError:
        def fallback_caller(prompt: str) -> str:
            return f"[No LLM client available] Prompt length: {len(prompt)}"

        return fallback_caller
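

# A sketch of the "$ENV_VAR" convention above (illustrative; assumes
# ``LLMProviderConfig`` accepts these fields as keyword arguments, which is
# how they are read in ``_build_llm_caller``):
#
#   provider = LLMProviderConfig(
#       api_key="$OPENAI_API_KEY",  # resolved from os.environ when the caller is built
#       base_url="https://api.openai.com/v1",
#       default_model="gpt-4",
#   )
#   caller = _build_llm_caller(provider)
#   print(caller("Say hello"))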


def _build_early_stop_conditions(schemas: list[EarlyStopConditionSchema]) -> list:
    """Convert UI early-stop schemas into framework EarlyStopCondition objects."""
    if not schemas:
        return []

    from execution.runner import EarlyStopCondition

    conditions = []
    for s in schemas:
        if s.type == EarlyStopType.KEYWORD and s.keyword:
            conditions.append(EarlyStopCondition.on_keyword(s.keyword))
        elif s.type == EarlyStopType.TOKEN_LIMIT and s.max_tokens:
            conditions.append(EarlyStopCondition.on_token_limit(s.max_tokens))
        elif s.type == EarlyStopType.AGENT_COUNT and s.max_agents:
            conditions.append(EarlyStopCondition.on_agent_count(s.max_agents))
    return conditions
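

# For example (illustrative; assumes ``EarlyStopConditionSchema`` accepts
# these fields as keyword arguments, matching how they are read above):
#
#   schemas = [
#       EarlyStopConditionSchema(type=EarlyStopType.KEYWORD, keyword="FINAL ANSWER"),
#       EarlyStopConditionSchema(type=EarlyStopType.TOKEN_LIMIT, max_tokens=20_000),
#   ]
#   conditions = _build_early_stop_conditions(schemas)  # two EarlyStopCondition objects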


def _build_topology_hooks(schemas: list[TopologyHookSchema]) -> list:
    """Convert UI topology-hook schemas into async hook callables.

    The ``arun_round()`` execution path reads from ``async_topology_hooks``,
    so every hook must be an async callable ``(StepContext, RoleGraph) -> TopologyAction | None``.
    """
    if not schemas:
        return []

    from execution.runner import TopologyAction

    hooks = []
    # Each hook binds its loop variables via default arguments (e.g. ``_kw=kw``)
    # so that every closure captures its own values instead of the last
    # iteration's.
    for s in schemas:
        if s.type == TopologyHookType.STOP_ON_KEYWORD and s.keyword:
            kw = s.keyword

            async def _stop_hook(ctx, _graph, _kw=kw):
                if _kw.lower() in (ctx.response or "").lower():
                    return TopologyAction(early_stop=True, early_stop_reason=f"Keyword '{_kw}' found")
                return None

            hooks.append(_stop_hook)
        elif s.type == TopologyHookType.SKIP_ON_TOKEN_BUDGET and s.token_threshold:
            threshold = s.token_threshold

            async def _budget_hook(ctx, _graph, _th=threshold):
                if ctx.total_tokens > _th:
                    return TopologyAction(skip_agents=list(ctx.remaining_agents))
                return None

            hooks.append(_budget_hook)
        elif s.type == TopologyHookType.FORCE_REVIEWER_ON_ERROR and s.reviewer_agent_id:
            reviewer = s.reviewer_agent_id

            async def _reviewer_hook(ctx, _graph, _rev=reviewer):
                if ctx.step_result and not getattr(ctx.step_result, "success", True):
                    return TopologyAction(force_agents=[_rev])
                return None

            hooks.append(_reviewer_hook)
        elif s.type == TopologyHookType.INSERT_CHAIN_ON_KEYWORD and s.keyword and s.source_agent and s.target_agent:
            kw, src, tgt = s.keyword, s.source_agent, s.target_agent

            async def _insert_hook(ctx, _graph, _kw=kw, _src=src, _tgt=tgt):
                if _kw.lower() in (ctx.response or "").lower():
                    return TopologyAction(insert_chains=[(_src, _tgt)])
                return None

            hooks.append(_insert_hook)
        elif s.type == TopologyHookType.ADD_EDGE_ON_KEYWORD and s.keyword and s.source_agent and s.target_agent:
            kw, src, tgt, w = s.keyword, s.source_agent, s.target_agent, s.weight

            async def _add_edge_hook(ctx, _graph, _kw=kw, _src=src, _tgt=tgt, _w=w):
                if _kw.lower() in (ctx.response or "").lower():
                    return TopologyAction(add_edges=[(_src, _tgt, _w)])
                return None

            hooks.append(_add_edge_hook)
        elif s.type == TopologyHookType.REDIRECT_END_ON_KEYWORD and s.keyword and s.target_agent:
            kw, tgt = s.keyword, s.target_agent

            async def _redirect_hook(ctx, _graph, _kw=kw, _tgt=tgt):
                if _kw.lower() in (ctx.response or "").lower():
                    return TopologyAction(new_end_agent=_tgt)
                return None

            hooks.append(_redirect_hook)
        elif s.type == TopologyHookType.SKIP_AGENT_ON_KEYWORD and s.keyword and s.target_agent:
            kw, tgt = s.keyword, s.target_agent

            async def _skip_hook(ctx, _graph, _kw=kw, _tgt=tgt):
                if _kw.lower() in (ctx.response or "").lower():
                    return TopologyAction(skip_agents=[_tgt])
                return None

            hooks.append(_skip_hook)
    return hooks
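

# For example (illustrative; assumes ``TopologyHookSchema`` accepts these
# fields as keyword arguments, matching how they are read above):
#
#   schemas = [
#       TopologyHookSchema(type=TopologyHookType.STOP_ON_KEYWORD, keyword="DONE"),
#       TopologyHookSchema(
#           type=TopologyHookType.FORCE_REVIEWER_ON_ERROR,
#           reviewer_agent_id="reviewer",
#       ),
#   ]
#   hooks = _build_topology_hooks(schemas)
#   # Each hook is an async callable invoked with (StepContext, RoleGraph);
#   # returning a TopologyAction requests a change, returning None does nothing.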


def cancel_execution(run_id: str) -> bool:
    """Cancel a running execution."""
    run_state = _active_runs.get(run_id)
    if run_state and run_state.task and not run_state.task.done():
        run_state.cancelled = True
        run_state.task.cancel()
        return True
    return False
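

# Cancellation flows through ``_run_execution``: ``task.cancel()`` raises
# ``asyncio.CancelledError`` inside the run, which sets the status to
# "cancelled", emits a "cancelled" event, and, via the ``finally`` block,
# pushes the ``None`` sentinel and persists the run.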


def get_run_history() -> list[dict[str, Any]]:
    """Get all persisted runs."""
    return storage.list_runs()


def get_run_detail(run_id: str) -> dict[str, Any] | None:
    """Get a specific run's details."""
    # Check active first
    active = _active_runs.get(run_id)
    if active:
        return {
            "run_id": active.run_id,
            "status": active.status,
            "events": active.events,
            "result": active.result,
            "started_at": active.started_at,
            "completed_at": active.completed_at,
        }
    return storage.get_run(run_id)


def _persist_run(run_state: RunState) -> None:
    """Save completed run to disk."""
    storage.save_run(
        run_state.run_id,
        {
            "run_id": run_state.run_id,
            "status": run_state.status,
            "task_query": run_state.task_query,
            "events": run_state.events,
            "result": run_state.result,
            "started_at": run_state.started_at,
            "completed_at": run_state.completed_at,
        },
    )