# File-viewer metadata captured with this source (not part of the module):
# commit c9c58c4 - "fix(tests): Address CodeRabbit review feedback" (23.6 kB)
"""
Advanced Orchestrator using Microsoft Agent Framework.
This orchestrator uses the ChatAgent pattern from Microsoft's agent-framework-core
package for multi-agent coordination. It provides richer orchestration capabilities
including specialized agents (Search, Hypothesis, Judge, Report) coordinated by
a manager agent.
Note: Previously named 'orchestrator_magentic.py' - renamed to eliminate confusion
with the 'magentic' PyPI package (which is a different library).
Design Patterns:
- Mediator: Manager agent coordinates between specialized agents
- Strategy: Different agents implement different strategies for their tasks
- Observer: Event stream allows UI to observe progress
"""
import asyncio
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import structlog
from agent_framework import (
MAGENTIC_EVENT_TYPE_ORCHESTRATOR,
ORCH_MSG_KIND_INSTRUCTION,
ORCH_MSG_KIND_TASK_LEDGER,
AgentRunUpdateEvent,
ChatAgent,
ExecutorCompletedEvent,
MagenticBuilder,
WorkflowOutputEvent,
)
from src.agents.magentic_agents import (
create_hypothesis_agent,
create_judge_agent,
create_report_agent,
create_search_agent,
)
from src.agents.state import get_magentic_state, init_magentic_state
from src.clients.base import BaseChatClient
from src.clients.factory import get_chat_client
from src.config.domain import ResearchDomain, get_domain_config
from src.orchestrators.base import OrchestratorProtocol
from src.utils.config import settings
from src.utils.models import AgentEvent
from src.utils.service_loader import get_embedding_service_if_available
# Import only for type annotations; avoids a hard runtime dependency.
if TYPE_CHECKING:
    from src.services.embedding_protocol import EmbeddingServiceProtocol

# Module-level structlog logger; instances bind orchestrator-specific context.
logger = structlog.get_logger()

# Agent ID constants - prevents silent breakage if agents are renamed.
# These match the participant keys registered in MagenticBuilder.participants()
# and are used as substring markers when classifying executor/agent IDs.
REPORTER_AGENT_ID = "reporter"
SEARCHER_AGENT_ID = "searcher"
JUDGE_AGENT_ID = "judge"
HYPOTHESIZER_AGENT_ID = "hypothesizer"
@dataclass
class WorkflowState:
    """Tracks mutable state during workflow execution.

    One instance lives for the duration of a single run() call; the event
    loop in run() mutates it as workflow events arrive.
    """

    # Coordination round counter reported in emitted events.
    iteration: int = 0
    # True once the reporter executor has completed at least once.
    reporter_ran: bool = False
    # Accumulates streamed text for the agent currently speaking.
    current_message_buffer: str = ""
    # author_name of the agent whose stream is being buffered.
    current_agent_id: str | None = None
    # Length of the last fully streamed message; used to suppress
    # duplicate final payloads in _handle_final_event().
    last_streamed_length: int = 0
    # Guards against processing more than one WorkflowOutputEvent.
    final_event_received: bool = False
class AdvancedOrchestrator(OrchestratorProtocol):
    """
    Advanced orchestrator using Microsoft Agent Framework ChatAgent pattern.

    Each agent has an internal LLM that understands natural language
    instructions from the manager and can call tools appropriately.

    This orchestrator provides:
    - Multi-agent coordination (Search, Hypothesis, Judge, Report)
    - Manager agent for workflow orchestration
    - Streaming events for real-time UI updates
    - Configurable timeouts and round limits
    """

    def __init__(
        self,
        max_rounds: int = 5,
        chat_client: BaseChatClient | None = None,
        provider: str | None = None,
        api_key: str | None = None,
        domain: ResearchDomain | str | None = None,
        timeout_seconds: float | None = None,
    ) -> None:
        """Initialize the advanced orchestrator.

        Args:
            max_rounds: Maximum number of coordination rounds.
            chat_client: Optional pre-configured chat client.
            provider: Optional provider override ("openai", "huggingface").
            api_key: Optional API key override.
            domain: Research domain for customization.
            timeout_seconds: Optional timeout override (defaults to settings).
        """
        self._max_rounds = max_rounds
        # Default domain when none is supplied.
        self.domain = domain or ResearchDomain.SEXUAL_HEALTH
        self.domain_config = get_domain_config(self.domain)
        self._timeout_seconds = timeout_seconds or settings.advanced_timeout
        self.logger = logger.bind(orchestrator="advanced")
        # Use provided client or create one via factory
        self._chat_client = chat_client or get_chat_client(
            provider=provider,
            api_key=api_key,
        )
        # Store API key for service initialization
        self._api_key = api_key
        # Event stream for UI updates
        self._events: list[AgentEvent] = []
        # Initialize services lazily
        self._embedding_service: EmbeddingServiceProtocol | None = None
        # Track execution statistics
        # NOTE(review): these counters are initialized here but never updated
        # in the code visible in this file - confirm they are maintained
        # elsewhere or wire them into run().
        self.stats: dict[str, int] = {
            "rounds": 0,
            "searches": 0,
            "hypotheses": 0,
            "reports": 0,
            "errors": 0,
        }
def _init_embedding_service(self) -> "EmbeddingServiceProtocol | None":
"""Initialize embedding service if available."""
return get_embedding_service_if_available(api_key=self._api_key)
def _build_workflow(self) -> Any:
"""Build the workflow with ChatAgent participants."""
# Create agents with internal LLMs
search_agent = create_search_agent(self._chat_client, domain=self.domain)
judge_agent = create_judge_agent(self._chat_client, domain=self.domain)
hypothesis_agent = create_hypothesis_agent(self._chat_client, domain=self.domain)
report_agent = create_report_agent(self._chat_client, domain=self.domain)
# Manager chat client (orchestrates the agents)
manager_client = self._chat_client
manager_agent = ChatAgent(chat_client=manager_client)
return (
MagenticBuilder()
.participants(
searcher=search_agent,
hypothesizer=hypothesis_agent,
judge=judge_agent,
reporter=report_agent,
)
.with_standard_manager(
agent=manager_agent,
max_round_count=self._max_rounds,
max_stall_count=3,
max_reset_count=2,
)
.build()
)
    def _create_task_prompt(self, query: str) -> str:
        """Create the initial task prompt for the manager agent.

        The prompt embeds the domain's report focus and the user query, and
        spells out the delegation workflow the manager should follow -
        including the hard rule to hand off to ReportAgent once the Judge
        declares the evidence sufficient.
        """
        # NOTE: this text is sent to the manager LLM verbatim; keep wording
        # (especially "SUFFICIENT EVIDENCE" / "STOP SEARCHING") in sync with
        # the JudgeAgent's output conventions.
        return f"""Research {self.domain_config.report_focus} for: {query}
## CRITICAL RULE
When JudgeAgent says "SUFFICIENT EVIDENCE" or "STOP SEARCHING":
β†’ IMMEDIATELY delegate to ReportAgent for synthesis
β†’ Do NOT continue searching or gathering more evidence
β†’ The Judge has determined evidence quality is adequate
## Standard Workflow
1. SearchAgent: Find evidence from PubMed, ClinicalTrials.gov, and Europe PMC
2. HypothesisAgent: Generate mechanistic hypotheses (Drug -> Target -> Pathway -> Effect)
3. JudgeAgent: Evaluate if evidence is sufficient
4. If insufficient -> SearchAgent refines search based on gaps
5. If sufficient -> ReportAgent synthesizes final report
Focus on:
- Identifying specific molecular targets
- Understanding mechanism of action
- Finding clinical evidence supporting hypotheses
The final output should be a structured research report."""
def _get_agent_semantic_name(self, agent_id: str) -> str:
"""Map internal agent ID to user-facing semantic name."""
name = agent_id.lower()
if SEARCHER_AGENT_ID in name:
return "SearchAgent"
if JUDGE_AGENT_ID in name:
return "JudgeAgent"
if HYPOTHESIZER_AGENT_ID in name:
return "HypothesisAgent"
if REPORTER_AGENT_ID in name:
return "ReportAgent"
return "ManagerAgent"
async def _init_workflow_events(self, query: str) -> AsyncGenerator[AgentEvent, None]:
"""Yield initialization events."""
yield AgentEvent(
type="started",
message=f"Starting research (Advanced mode): {query}",
iteration=0,
)
yield AgentEvent(
type="progress",
message="Loading embedding service (LlamaIndex/ChromaDB)...",
iteration=0,
)
    async def _synthesize_fallback(
        self,
        iteration: int,
        reason: str,
    ) -> AsyncGenerator[AgentEvent, None]:
        """
        Unified fallback synthesis for all termination scenarios.

        This method handles synthesis when the workflow terminates without
        a proper report from ReportAgent. It's a safety net for:
        - Timeout scenarios
        - Manager model failing to delegate to ReportAgent (7B model limitation)
        - Max rounds reached without synthesis

        Args:
            iteration: Current workflow iteration count
            reason: Why synthesis is being forced ("timeout", "no_reporter", "max_rounds")

        Yields:
            A "synthesizing" status event followed by a "complete" event.
            On failure, a "complete" event carrying the error is yielded so
            the UI always terminates cleanly.
        """
        # Human-readable status per termination reason; unknown reasons get
        # a generic message via .get() below.
        status_messages = {
            "timeout": "Workflow timed out. Synthesizing available evidence...",
            "no_reporter": "Synthesizing research findings...",
            "max_rounds": "Max rounds reached. Synthesizing findings...",
        }
        try:
            # Pull whatever evidence was accumulated in shared research memory.
            state = get_magentic_state()
            evidence_summary = await state.memory.get_context_summary()
            # Fresh ReportAgent instance dedicated to this salvage synthesis.
            report_agent = create_report_agent(self._chat_client, domain=self.domain)
            yield AgentEvent(
                type="synthesizing",
                message=status_messages.get(reason, "Synthesizing..."),
                iteration=iteration,
            )
            synthesis_result = await report_agent.run(
                "Synthesize research report from this evidence. "
                f"If evidence is sparse, say so.\n\n{evidence_summary}"
            )
            yield AgentEvent(
                type="complete",
                message=synthesis_result.text,
                data={"reason": f"{reason}_synthesis", "iterations": iteration},
                iteration=iteration,
            )
        except Exception as synth_error:
            # Deliberately broad: any failure here must still end the stream
            # with a "complete" event rather than crash the UI.
            logger.error("Fallback synthesis failed", reason=reason, error=str(synth_error))
            yield AgentEvent(
                type="complete",
                message=f"Research completed. Synthesis failed: {synth_error}",
                data={"reason": f"{reason}_synthesis_failed", "iterations": iteration},
                iteration=iteration,
            )
    async def run(  # noqa: PLR0915 - Complex but necessary for event stream handling
        self,
        query: str,
    ) -> AsyncGenerator[AgentEvent, None]:
        """
        Run the full multi-agent research workflow for a query.

        Args:
            query: User's research question

        Yields:
            AgentEvent objects for real-time UI updates: startup/progress
            events, streamed agent text, and a terminal "complete" or
            "error" event.
        """
        logger.info("Starting Advanced orchestrator", query=query)
        async for event in self._init_workflow_events(query):
            yield event
        # Initialize context state
        embedding_service = self._init_embedding_service()
        yield AgentEvent(
            type="progress",
            message="Initializing research memory...",
            iteration=0,
        )
        init_magentic_state(query, embedding_service)
        yield AgentEvent(
            type="progress",
            message="Building agent team (Search, Judge, Hypothesis, Report)...",
            iteration=0,
        )
        workflow = self._build_workflow()
        task = self._create_task_prompt(query)
        # UX FIX: Yield thinking state before blocking workflow call
        # The workflow.run_stream() blocks for 2+ minutes on first LLM call
        yield AgentEvent(
            type="thinking",
            message=(
                f"Multi-agent reasoning in progress (Limit: {self._max_rounds} Manager rounds)... "
                "Allocating time for deep research..."
            ),
            iteration=0,
        )
        state = WorkflowState()
        # NOTE(review): state.iteration is never incremented anywhere in this
        # loop, so every event reports iteration 0 - confirm whether executor
        # completions should bump it (and whether self.stats should be updated).
        try:
            async with asyncio.timeout(self._timeout_seconds):
                async for event in workflow.run_stream(task):
                    # 1. Handle Streaming (Source of Truth for Content)
                    if isinstance(event, AgentRunUpdateEvent) and event.data:
                        # Check metadata to filter internal orchestrator messages
                        props = getattr(event.data, "additional_properties", None) or {}
                        event_type = props.get("magentic_event_type")
                        msg_kind = props.get("orchestrator_message_kind")
                        # Filter out internal orchestrator messages (task_ledger, instruction)
                        if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR:
                            if msg_kind in (ORCH_MSG_KIND_TASK_LEDGER, ORCH_MSG_KIND_INSTRUCTION):
                                continue  # Skip internal coordination messages
                        author = getattr(event.data, "author_name", None)
                        # Detect agent switch to clear buffer
                        if author != state.current_agent_id:
                            state.current_message_buffer = ""
                            state.current_agent_id = author
                        text = getattr(event.data, "text", None)
                        if text:
                            state.current_message_buffer += text
                            yield AgentEvent(
                                type="streaming",
                                message=text,
                                data={"agent_id": author},
                                iteration=state.iteration,
                            )
                        continue
                    # 2. Handle Completion Signal
                    if isinstance(event, ExecutorCompletedEvent):
                        # Internal state tracking only - NO UI events
                        # P1 FIX: Track if ReportAgent produced output
                        agent_name = getattr(event, "executor_id", "") or "unknown"
                        if REPORTER_AGENT_ID in agent_name.lower():
                            state.reporter_ran = True
                        # P2 BUG FIX: Save length before clearing
                        state.last_streamed_length = len(state.current_message_buffer)
                        # Clear buffer after consuming
                        state.current_message_buffer = ""
                        continue
                    # 3. Handle Final Events Inline (P2 Duplicate Report Fix + P1 Forced Synthesis)
                    if isinstance(event, WorkflowOutputEvent):
                        if state.final_event_received:
                            continue  # Skip duplicate final events
                        state.final_event_received = True
                        # P1 FIX: Force synthesis if ReportAgent never ran
                        if not state.reporter_ran:
                            logger.warning(
                                "ReportAgent never ran - forcing synthesis",
                                iterations=state.iteration,
                            )
                            async for synth_event in self._synthesize_fallback(
                                state.iteration, "no_reporter"
                            ):
                                yield synth_event
                        else:
                            yield self._handle_final_event(
                                event, state.iteration, state.last_streamed_length
                            )
                        continue
                    # 4. Handle other events normally
                    agent_event = self._process_event(event, state.iteration)
                    if agent_event:
                        yield agent_event
            # GUARANTEE: Always emit termination event if stream ends without one
            # (e.g., max rounds reached)
            if not state.final_event_received:
                logger.warning(
                    "Workflow ended without final event",
                    iterations=state.iteration,
                )
                # P1 FIX: Force synthesis if ReportAgent never ran
                if not state.reporter_ran:
                    async for synth_event in self._synthesize_fallback(
                        state.iteration, "max_rounds"
                    ):
                        yield synth_event
                else:
                    yield AgentEvent(
                        type="complete",
                        message=(
                            f"Research completed after {state.iteration} agent rounds. "
                            "Max iterations reached - results may be partial. "
                            "Try a more specific query for better results."
                        ),
                        data={
                            "iterations": state.iteration,
                            "reason": "max_rounds_reached",
                        },
                        iteration=state.iteration,
                    )
        except TimeoutError:
            # asyncio.timeout() expired: salvage whatever evidence was gathered.
            async for event in self._synthesize_fallback(state.iteration, "timeout"):
                yield event
        except Exception as e:
            # Terminal safety net: surface the failure as an "error" event
            # instead of letting the exception escape the generator.
            logger.error("Workflow failed", error=str(e))
            yield AgentEvent(
                type="error",
                message=f"Workflow error: {e!s}",
                iteration=state.iteration,
            )
def _handle_final_event(
self,
event: WorkflowOutputEvent,
iteration: int,
last_streamed_length: int,
) -> AgentEvent:
"""Handle final workflow events with duplicate content suppression (P2 Bug Fix)."""
# DECISION: Did we stream substantial content?
if last_streamed_length > 100:
# YES: Final event is a SIGNAL, not a payload
return AgentEvent(
type="complete",
message="Research complete.",
data={
"iterations": iteration,
"streamed_chars": last_streamed_length,
},
iteration=iteration,
)
# NO: Final event must carry the payload (tool-only turn, cache hit)
text = self._extract_text(event.data) if event.data else "Research complete"
return AgentEvent(
type="complete",
message=text,
data={"iterations": iteration},
iteration=iteration,
)
def _extract_text(self, message: Any) -> str:
"""
Defensively extract text from a message object.
Handles ChatMessage objects from both OpenAI and HuggingFace clients.
ChatMessage has: .text (str), .contents (list of content objects)
Also handles plain string messages (e.g., WorkflowOutputEvent.data).
"""
if not message:
return ""
# Priority 0: Handle plain string messages (e.g., WorkflowOutputEvent.data)
if isinstance(message, str):
# Filter out obvious repr-style noise
if not (message.startswith("<") and "object at" in message):
return message
return ""
# Priority 1: .text (standard ChatMessage text content)
if hasattr(message, "text") and message.text:
text = message.text
# Verify it's actually a string, not the object itself
if isinstance(text, str) and not (text.startswith("<") and "object at" in text):
return text
# Priority 2: .contents (list of FunctionCallContent, TextContent, etc.)
# This handles tool call responses from HuggingFace
if hasattr(message, "contents") and message.contents:
parts = []
for content in message.contents:
# TextContent has .text
if hasattr(content, "text") and content.text:
parts.append(str(content.text))
# FunctionCallContent has .name and .arguments
elif hasattr(content, "name"):
parts.append(f"[Tool: {content.name}]")
if parts:
return " ".join(parts)
# Priority 3: .content (legacy - some frameworks use singular)
if hasattr(message, "content") and message.content:
content = message.content
if isinstance(content, str):
return content
if isinstance(content, list):
return " ".join([str(c.text) for c in content if hasattr(c, "text")])
# Fallback: Return empty string instead of repr
# The repr is useless for display purposes
return ""
def _smart_truncate(self, text: str, max_len: int = 200) -> str:
"""Truncate at sentence boundary to avoid cutting words."""
if len(text) <= max_len:
return text
# Find last sentence boundary before limit
truncated = text[:max_len]
last_period = truncated.rfind(". ")
if last_period > max_len // 2:
return truncated[: last_period + 1]
# Fallback to word boundary
return truncated.rsplit(" ", 1)[0] + "..."
def _process_event(self, event: Any, iteration: int) -> AgentEvent | None:
"""Process workflow event into AgentEvent."""
# Handle orchestrator messages (formerly MagenticOrchestratorMessageEvent)
# We check the event type string directly
if getattr(event, "type", "") == MAGENTIC_EVENT_TYPE_ORCHESTRATOR:
kind = getattr(event, "kind", "")
message = getattr(event, "message", "")
# FILTERING: Skip internal framework bookkeeping
if kind in (ORCH_MSG_KIND_TASK_LEDGER, ORCH_MSG_KIND_INSTRUCTION):
return None
# TRANSFORMATION: Handle user_task BEFORE text extraction
# (user_task uses static message, doesn't need text content)
if kind == "user_task":
return AgentEvent(
type="progress",
message="Manager assigning research task to agents...",
iteration=iteration,
)
# For other manager events, extract and validate text
text = self._extract_text(message)
if not text:
return None
# Default fallback for other manager events
return AgentEvent(
type="judging",
message=f"Manager ({kind}): {self._smart_truncate(text)}",
iteration=iteration,
)
# NOTE: The following event types are handled inline in run() loop and never reach
# this method due to `continue` statements:
# - ExecutorCompletedEvent: Accumulator Pattern
# - AgentRunUpdateEvent: Accumulator Pattern
# - WorkflowOutputEvent: P2 Duplicate Fix via _handle_final_event()
return None
def _create_deprecated_alias() -> type["AdvancedOrchestrator"]:
    """Build the legacy MagenticOrchestrator alias that warns on use."""
    import warnings

    deprecation_message = (
        "MagenticOrchestrator is deprecated, use AdvancedOrchestrator instead. "
        "The name 'magentic' was confusing with the 'magentic' PyPI package."
    )

    class MagenticOrchestrator(AdvancedOrchestrator):
        """Deprecated alias for AdvancedOrchestrator.

        .. deprecated:: 0.1.0
            Use :class:`AdvancedOrchestrator` instead.
        """

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            """Initialize deprecated MagenticOrchestrator (use AdvancedOrchestrator)."""
            warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
            super().__init__(*args, **kwargs)

    return MagenticOrchestrator
# Backwards compatibility alias with deprecation warning, so existing
# `from ... import MagenticOrchestrator` callers keep working.
MagenticOrchestrator = _create_deprecated_alias()