| """Hierarchical orchestrator using middleware and sub-teams.""" |
|
|
| import asyncio |
| from collections.abc import AsyncGenerator |
|
|
| import structlog |
|
|
| from src.agents.judge_agent_llm import LLMSubIterationJudge |
| from src.agents.magentic_agents import create_search_agent |
| from src.middleware.sub_iteration import SubIterationMiddleware, SubIterationTeam |
| from src.services.embeddings import get_embedding_service |
| from src.state import init_magentic_state |
| from src.utils.models import AgentEvent |
|
|
# Module-level structlog logger; the orchestrator below uses it for its
# info / warning / error records.
logger = structlog.get_logger()
|
|
|
|
class ResearchTeam(SubIterationTeam):
    """Adapts Magentic ChatAgent to SubIterationTeam protocol."""

    def __init__(self) -> None:
        """Create the search agent that this team delegates every task to."""
        self.agent = create_search_agent()

    async def execute(self, task: str) -> str:
        """Run the agent on ``task`` and return its latest assistant text.

        Args:
            task: The task text handed to the agent's ``run`` method.

        Returns:
            The text of the most recent message whose role is ``"assistant"``
            and whose text is non-empty, converted with ``str``; the fixed
            fallback string when the response carries no such message.
        """
        fallback = "No response from agent."

        reply = await self.agent.run(task)
        history = reply.messages
        if not history:
            return fallback

        # Scan newest-first so the last assistant message wins.
        latest_text = next(
            (
                entry.text
                for entry in reversed(history)
                if entry.role == "assistant" and entry.text
            ),
            None,
        )
        if latest_text is None:
            return fallback
        return str(latest_text)
|
|
|
|
class HierarchicalOrchestrator:
    """Orchestrator that uses hierarchical teams and sub-iterations.

    Wires a ``ResearchTeam`` and an ``LLMSubIterationJudge`` into a
    ``SubIterationMiddleware`` (constructed with ``max_iterations=5``) and
    exposes a run as a stream of ``AgentEvent`` objects.
    """

    def __init__(self) -> None:
        """Build the team, the judge, and the middleware that drives them."""
        self.team = ResearchTeam()
        self.judge = LLMSubIterationJudge()
        self.middleware = SubIterationMiddleware(self.team, self.judge, max_iterations=5)

    async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
        """Run the middleware for ``query`` and stream its events.

        The event sequence is: one ``"started"`` event, then every event the
        middleware pushes through the callback (in queue order), then either a
        ``"complete"`` event carrying the result and assessment or an
        ``"error"`` event if awaiting the middleware task raised.

        If the consumer stops iterating early, or the surrounding task is
        cancelled, the background middleware task and any pending queue
        reader are cancelled instead of being left running.

        Args:
            query: The research query passed to the middleware.

        Yields:
            ``AgentEvent`` objects as described above.
        """
        logger.info("Starting hierarchical orchestrator", query=query)

        try:
            service = get_embedding_service()
            init_magentic_state(service)
        except Exception as e:
            # Deliberate best effort: a failing embedding service must not
            # abort the run, so fall back to the default state.
            logger.warning(
                "Embedding service initialization failed, using default state",
                error=str(e),
            )
            init_magentic_state()

        yield AgentEvent(type="started", message=f"Starting research: {query}")

        # The middleware reports progress through a callback; the queue turns
        # those pushes into something this generator can yield from.
        queue: asyncio.Queue[AgentEvent | None] = asyncio.Queue()

        async def event_callback(event: AgentEvent) -> None:
            await queue.put(event)

        task_future = asyncio.create_task(self.middleware.run(query, event_callback))
        get_event: asyncio.Task[AgentEvent | None] | None = None

        try:
            while not task_future.done():
                # Wait for whichever comes first: a new event or the end of
                # the middleware run.
                get_event = asyncio.create_task(queue.get())
                done, _ = await asyncio.wait(
                    {task_future, get_event}, return_when=asyncio.FIRST_COMPLETED
                )

                if get_event in done:
                    event = get_event.result()
                    if event is not None:
                        yield event
                else:
                    # The run finished first; the reader is no longer needed.
                    get_event.cancel()

            # Flush events that were queued but not yet read when the run
            # finished.
            while not queue.empty():
                ev = queue.get_nowait()
                if ev is not None:
                    yield ev

            try:
                result, assessment = await task_future

                assessment_text = assessment.reasoning if assessment else "None"
                yield AgentEvent(
                    type="complete",
                    message=(
                        f"Research complete.\n\nResult:\n{result}\n\nAssessment:\n{assessment_text}"
                    ),
                    data={"assessment": assessment.model_dump() if assessment else None},
                )
            except Exception as e:
                logger.error("Orchestrator failed", error=str(e))
                yield AgentEvent(type="error", message=f"Orchestrator failed: {e}")
        finally:
            # Reached on normal completion (nothing left pending, so a no-op)
            # and when the generator is closed or cancelled mid-stream.
            # asyncio.wait() does not cancel the tasks it was waiting on, so
            # anything still pending is cancelled here to avoid leaving it
            # running in the background.
            for pending in (get_event, task_future):
                if pending is not None and not pending.done():
                    pending.cancel()
|
|