"""Parallel sub-agent fan-out for the Python engine (mirror of the Rust CLI's `task_batch`). Where the Router runs ONE task through a tier ladder, fan-out runs MANY independent tasks at once: each gets its own SmallCodeAgent + fresh Workspace and they run concurrently via asyncio.gather, bounded so local inference isn't oversubscribed. Use it for independent work — exploring/solving several things in parallel — when each subtask doesn't depend on the others' output. Cheap when each agent is a small local (e.g. the fine-tuned 1.5B) model: wall-clock is ~the slowest job, not the sum. """ from __future__ import annotations import asyncio from collections.abc import AsyncIterator from dataclasses import dataclass, field from .agent import SmallCodeAgent, Step from .config import Preset, load_preset from .live_run import LiveFrame from .router import _verify from .trace_collector import TraceEvent from .ui_trace import merge_step_metadata MAX_CONCURRENCY = 4 @dataclass class FanoutResult: index: int task: str final: str steps: list[Step] model: str verified: bool files: dict[str, str] = field(default_factory=dict) error: str | None = None trace_events: list[TraceEvent] = field(default_factory=list) agent: SmallCodeAgent | None = None async def fan_out(tasks: list[str], preset: Preset | None = None, model: str | None = None, max_steps: int = 12, concurrency: int = MAX_CONCURRENCY) -> list[FanoutResult]: """Run `tasks` concurrently, each in its own agent/workspace. `model` defaults to the preset's entry tier (the cheap small model — the natural choice for fanning out). Results are returned in input order. """ results: list[FanoutResult] = [] async for frame in fan_out_live(tasks, preset=preset, model=model, max_steps=max_steps, concurrency=concurrency): if frame.done and isinstance(frame.result, list): results = frame.result return results async def fan_out_live( tasks: list[str], preset: Preset | None = None, model: str | None = None, max_steps: int = 12, concurrency: int = MAX_CONCURRENCY, poll_interval: float = 0.35, ) -> AsyncIterator[LiveFrame]: """Yield aggregate live frames while fan-out jobs run.""" if not tasks: yield LiveFrame(done=True, result=[]) return preset = preset or load_preset() model = model or preset.default_model sem = asyncio.Semaphore(max(1, concurrency)) agents: list[SmallCodeAgent] = [] for i, t in enumerate(tasks): agents.append(SmallCodeAgent(preset=preset, model=model, max_steps=max_steps)) async def _job(index: int, task: str, agent: SmallCodeAgent) -> FanoutResult: async with sem: try: final, steps = await agent.run(task) ok = False if (agent.hit_max_steps or agent.errored) else _verify(agent) events = merge_step_metadata(agent.trace_collector.snapshot(), agent.raw_history()) return FanoutResult( index=index, task=task, final=final, steps=steps, model=model, verified=bool(ok), files=agent.files(), trace_events=events, agent=agent, ) except Exception as e: return FanoutResult(index=index, task=task, final="", steps=[], model=model, verified=False, error=str(e)) finally: agent.cleanup() job_tasks = [ asyncio.create_task(_job(i, t, agents[i])) for i, t in enumerate(tasks) ] try: while not all(j.done() for j in job_tasks): # Mid-run we must NOT call current_steps()/history() on a live agent # (the Rust agent isn't reentrant and would deadlock). Read only the # trace collectors (plain lists) and workspace files (disk). events: list[TraceEvent] = [] all_files: dict[str, str] = {} for i, agent in enumerate(agents): events.extend(agent.trace_collector.snapshot()) for path, content in agent.files().items(): all_files[f"[{i + 1}] {path}"] = content yield LiveFrame(steps=[], events=events, files=all_files) await asyncio.sleep(poll_interval) results = [await j for j in job_tasks] results.sort(key=lambda r: r.index) yield LiveFrame(done=True, result=results) finally: for j in job_tasks: if not j.done(): j.cancel() def summarize(results: list[FanoutResult]) -> str: """Aggregate fan-out results into one labeled summary (mirrors the Rust output).""" out = [f"Ran {len(results)} subagents in parallel. Results:\n"] for r in results: head = f"=== [{r.index + 1}] {r.model} {'OK' if r.verified else 'unverified'} ===" body = r.error and f"error: {r.error}" or r.final.strip() out.append(f"{head}\n{body}\n") return "\n".join(out).rstrip()