ThejasRao's picture
Deploy ModPilot Investigation Engine
7302343
Raw
History Blame Contribute Delete
8.4 kB
"""Orchestrator loop — runs tools per tier plan, enforces budgets, early-stops.
Spec: docs/04-InvestigationEngine.md §3, docs/Specs.md §7.2.
For each tier the Strategy Selector picks a plan (FAST=2 tools, STANDARD=4,
DEEP=5+ with potential re-loop). The Orchestrator iterates the plan; each
iteration checks budgets *before* running, invokes the tool with isolated
exception handling, appends the result to the Evidence Accumulator, then
asks the convergence policy whether enough signal has accumulated to stop.
The Orchestrator owns:
- tier → tool plan mapping (overridable per call for testing)
- budget enforcement (tool count + wall-clock time)
- per-tool exception isolation: a single tool throwing must not blow
up the investigation; we record a `failure` ToolResult and keep going
- convergence policy: 1 strong-signal success on FAST, 2 on STANDARD/DEEP
- structured logging at start, per-tool, and stop
It does NOT decide the verdict — that's the Reasoner. It does NOT persist —
that's the API handler wrapping `run()` with a session.
"""
from __future__ import annotations

import time
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Literal

import structlog

from orchestrator.tools import (
    EvidenceAccumulator,
    ToolContext,
    ToolName,
    ToolRegistry,
    ToolResult,
)

if TYPE_CHECKING:
    from collections.abc import Callable

    from orchestrator.strategy import StrategyDecision
# Module-level structlog logger. `Orchestrator.run` binds per-investigation
# context (correlation_id, subreddit_id, tier, plan_length) onto it.
logger = structlog.get_logger(__name__)
# === Result =============================================================
# Why the orchestrator stopped iterating its plan. A `Literal` (instead of a
# bare `str` with the values listed in a comment) lets type checkers reject
# misspelled reasons at every assignment site.
StopReason = Literal["plan_complete", "converged", "budget_tool", "budget_time"]
@dataclass(frozen=True)
class OrchestratorResult:
    """Summary of one `Orchestrator.run()` invocation.

    Carries the collected evidence plus timing and stop metadata. The
    dataclass is frozen (fields cannot be reassigned), but `accumulator` and
    `plan` are themselves mutable objects.
    """

    # Set from ToolContext.correlation_id by Orchestrator.run.
    correlation_id: str
    # Set from ToolContext.subreddit_id by Orchestrator.run.
    subreddit_id: str
    # StrategyDecision.tier the run was executed for.
    tier: str
    # Every ToolResult appended during the run, including "skipped"
    # (unregistered tool) and "failure" (tool raised) entries.
    accumulator: EvidenceAccumulator
    # Wall-clock start, timezone-aware UTC.
    started_at: datetime
    # Wall-clock end, timezone-aware UTC.
    completed_at: datetime
    # Measured on the orchestrator's injected clock, not derived from
    # started_at/completed_at.
    total_latency_ms: int
    # Plan entries processed, including skipped and failed tools.
    tools_run: int
    # True when the loop stopped on a budget or convergence check; converging
    # on the final plan entry also sets this.
    early_stopped: bool
    stop_reason: StopReason
    # The plan that was iterated: the tier default or the caller's override.
    plan: list[ToolName] = field(default_factory=list)
# === Default tier plans =================================================
# Specs §7.1 budget table. STANDARD uses the 4-tool variant while the
# 4-vs-5 question remains open. Each tier's plan extends the previous
# tier's, so the plans are assembled from shared segments.
_CORE_TOOLS: list[ToolName] = ["policy_match", "report_velocity"]
_HISTORY_TOOLS: list[ToolName] = ["user_history", "prior_actions"]
_DEFAULT_PLANS: dict[str, list[ToolName]] = {
    "FAST": [*_CORE_TOOLS],
    "STANDARD": [*_CORE_TOOLS, *_HISTORY_TOOLS],
    "DEEP": [*_CORE_TOOLS, *_HISTORY_TOOLS, "thread_context"],
}
# === Orchestrator =======================================================
class Orchestrator:
    """Stateless coordinator. Safe to share across requests.

    Instance state is limited to the tool registry and the clock; all
    per-investigation state lives in locals of `run()`.
    """

    def __init__(
        self,
        registry: ToolRegistry,
        *,
        # `perf_counter`-style clock injected so tests can drive elapsed
        # time deterministically. Returns seconds.
        clock: Callable[[], float] = time.perf_counter,
    ) -> None:
        self._registry = registry
        self._clock = clock

    def default_plan(self, tier: str) -> list[ToolName]:
        """Return a fresh copy of the built-in tool plan for `tier`.

        A copy is returned so callers cannot mutate the shared defaults.

        Raises:
            ValueError: if `tier` has no entry in `_DEFAULT_PLANS`.
        """
        try:
            return list(_DEFAULT_PLANS[tier])
        except KeyError as e:
            raise ValueError(f"no default plan for tier {tier!r}") from e

    def _ms_since(self, start: float) -> int:
        """Whole milliseconds elapsed on the injected clock since `start` (seconds)."""
        return int((self._clock() - start) * 1000)

    async def run(
        self,
        *,
        decision: StrategyDecision,
        context: ToolContext,
        plan: list[ToolName] | None = None,
    ) -> OrchestratorResult:
        """Execute a tool plan for one investigation and collect the evidence.

        Args:
            decision: Strategy Selector output; supplies `tier`,
                `tool_budget` and `time_budget_ms`.
            context: Passed unchanged to every tool's `run()`; its
                `correlation_id` and `subreddit_id` are echoed into logs
                and into the result.
            plan: Optional explicit tool order (mainly for tests). When
                omitted, the default plan for `decision.tier` is used.

        Returns:
            An `OrchestratorResult` whose accumulator holds one entry per
            processed plan item, plus timing and stop metadata.

        Raises:
            ValueError: if `plan` is omitted and `decision.tier` has no
                default plan.

        `Exception`s raised inside a tool's `run()` are recorded as
        `failure` results instead of propagating. `BaseException`s such as
        `asyncio.CancelledError` are not caught.
        """
        # Copy a caller-supplied plan so the frozen result never aliases a
        # list the caller may mutate later (or while tools are awaited).
        chosen_plan = (
            list(plan) if plan is not None else self.default_plan(decision.tier)
        )
        accumulator = EvidenceAccumulator()
        log = logger.bind(
            correlation_id=context.correlation_id,
            subreddit_id=context.subreddit_id,
            tier=decision.tier,
            plan_length=len(chosen_plan),
        )
        log.info("orchestrator.started", plan=chosen_plan)

        started_at = datetime.now(UTC)
        started_perf = self._clock()
        tools_run = 0
        stop_reason: StopReason = "plan_complete"
        early_stopped = False

        for tool_name in chosen_plan:
            # 1. Budget pre-check. Time is checked first, so when both budgets
            #    are exhausted the reported reason is "budget_time". The time
            #    budget is only enforced between tools; a single slow tool
            #    can overrun it.
            elapsed_ms = self._ms_since(started_perf)
            if elapsed_ms >= decision.time_budget_ms:
                stop_reason = "budget_time"
                early_stopped = True
                log.info("orchestrator.stop", reason=stop_reason, elapsed_ms=elapsed_ms)
                break
            if tools_run >= decision.tool_budget:
                stop_reason = "budget_tool"
                early_stopped = True
                log.info("orchestrator.stop", reason=stop_reason, tools_run=tools_run)
                break

            # 2. Resolve tool. Unregistered → record skip, keep going.
            # NOTE(review): a skipped tool still increments tools_run and so
            # consumes the tool budget — confirm this is intended.
            if not self._registry.has(tool_name):
                accumulator.append(
                    ToolResult(
                        tool=tool_name,
                        status="skipped",
                        summary=f"tool {tool_name!r} not registered",
                        latency_ms=0,
                    )
                )
                tools_run += 1
                log.warning("orchestrator.tool.unregistered", tool=tool_name)
                continue

            # 3. Run the tool with isolated exception handling: one failing
            #    tool becomes a `failure` row instead of aborting the run.
            tool = self._registry.get(tool_name)
            tool_started = self._clock()
            try:
                result = await tool.run(context)
            except Exception as exc:
                result = ToolResult(
                    tool=tool_name,
                    status="failure",
                    summary=f"tool raised: {type(exc).__name__}",
                    latency_ms=self._ms_since(tool_started),
                    error=str(exc),
                )
                log.warning(
                    "orchestrator.tool.raised",
                    tool=tool_name,
                    exc_type=type(exc).__name__,
                )
            accumulator.append(result)
            tools_run += 1
            log.info(
                "orchestrator.tool.completed",
                tool=tool_name,
                status=result.status,
                latency_ms=result.latency_ms,
            )

            # 4. Convergence check — stop early if we have enough signal.
            if _converged(accumulator, decision.tier):
                stop_reason = "converged"
                early_stopped = True
                log.info(
                    "orchestrator.stop",
                    reason=stop_reason,
                    tools_run=tools_run,
                    successful_evidence=len(accumulator.successful_entries()),
                )
                break

        completed_at = datetime.now(UTC)
        total_latency_ms = self._ms_since(started_perf)
        log.info(
            "orchestrator.completed",
            tools_run=tools_run,
            stop_reason=stop_reason,
            total_latency_ms=total_latency_ms,
        )
        return OrchestratorResult(
            correlation_id=context.correlation_id,
            subreddit_id=context.subreddit_id,
            tier=decision.tier,
            accumulator=accumulator,
            started_at=started_at,
            completed_at=completed_at,
            total_latency_ms=total_latency_ms,
            tools_run=tools_run,
            early_stopped=early_stopped,
            stop_reason=stop_reason,
            plan=chosen_plan,
        )
# === Convergence policy =================================================
def _converged(acc: EvidenceAccumulator, tier: str) -> bool:
    """Return True once `acc` holds enough strong signals to stop early.

    A simple, explainable convergence rule:

    * FAST tier: 1 strong-signal success is enough.
    * Any other tier (STANDARD / DEEP): require 2 strong signals, so the
      Reasoner has at least two corroborating evidence rows before the plan
      is short-circuited.

    A "strong signal" is a successful tool result whose `detail["signal"]`
    is `"high"`. Tools self-report this; if none do, convergence never
    triggers and the plan runs to its budget.

    A missing `detail` (None) counts as "no signal" rather than raising:
    this check runs outside the per-tool exception isolation in
    `Orchestrator.run`, so an error here would abort the whole investigation.
    """
    threshold = 1 if tier == "FAST" else 2
    strong = 0
    for entry in acc.successful_entries():
        detail = entry.detail or {}
        if detail.get("signal") == "high":
            strong += 1
            if strong >= threshold:
                return True
    return False
# Public surface of this module; underscore-prefixed helpers stay private.
__all__ = ["Orchestrator", "OrchestratorResult", "StopReason"]