| """LangGraph state machine for why-agent root-cause investigation. |
| |
| The graph orchestrates the six-phase loop: |
| plan → decompose → drill → cross_check → critique → report |
| ↑ | |
| └──────────┘ (if evidence weak) |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import os |
| import re |
| from datetime import UTC, datetime |
| from typing import Literal |
|
|
| from langchain_core.messages import HumanMessage, SystemMessage |
| from langchain_core.tools import StructuredTool |
| from langgraph.graph import END, StateGraph |
|
|
| from agent.client import get_llm |
| from agent.prompts import _render_critique, _render_system |
| from agent.state import ( |
| EvidenceEntry, |
| Hypothesis, |
| InvestigationState, |
| Phase, |
| ToolResult, |
| ) |
| from agent.tools.schemas import ( |
| ComparePeriodsInput, |
| DecomposeMetricInput, |
| InspectSchemaInput, |
| RunSqlInput, |
| ) |
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)


# Failed critique rounds allowed before critique() forces the report anyway.
MAX_RETRIES = 3
# Evidence-entry ceiling: once len(state.evidence) reaches it, routing after an
# LLM turn goes straight to critique.
MAX_TOOL_CALLS = 50
|
|
|
|
| def _iso_now() -> str: |
| return datetime.now(UTC).isoformat() |
|
|
|
|
def _format_hypotheses(hypotheses: list[Hypothesis]) -> str:
    """Render the hypotheses as one line each for inclusion in a prompt.

    Args:
        hypotheses: Hypotheses to list; each must expose ``id``,
            ``description``, ``status`` and ``supporting_evidence``.

    Returns:
        Newline-joined lines, or ``"No hypotheses yet."`` when there are none.
    """
    if not hypotheses:
        return "No hypotheses yet."

    def describe(hyp) -> str:
        # An empty evidence list is spelled out as "none" rather than left blank.
        cited = ", ".join(hyp.supporting_evidence) or "none"
        return f" [{hyp.id}] {hyp.description} (status={hyp.status}, supporting_evidence={cited})"

    return "\n".join(describe(hyp) for hyp in hypotheses)
|
|
|
|
def _format_evidence(evidence: list[EvidenceEntry], full_output: bool = False) -> str:
    """Render collected evidence as one line per entry.

    Args:
        evidence: Entries to render; each must expose ``phase.value``,
            ``tool_name`` and ``output``.
        full_output: When false (the default) each stringified output is cut
            to its first 400 characters; when true it is kept whole.

    Returns:
        Newline-joined lines, or ``"No evidence collected yet."`` when the
        list is empty. Curly braces in the output text are doubled.
    """
    if not evidence:
        return "No evidence collected yet."

    rendered = []
    for entry in evidence:
        text = str(entry.output)
        if not full_output:
            text = text[:400]
        # Braces are doubled after truncation, so the 400-character limit
        # applies to the raw text.
        text = text.replace("{", "{{").replace("}", "}}")
        marker = " [ERROR]" if "error" in entry.output else ""
        rendered.append(f" [{entry.phase.value}] {entry.tool_name}{marker}: {text}")
    return "\n".join(rendered)
|
|
|
|
| |
| |
| |
| |
|
|
|
|
def _make_tool_wrapper(name: str):
    """Build the callable that backs the StructuredTool registered as *name*.

    The returned function takes one validated input model, opens a connection
    on the directory named by ``PARQUET_DIR`` (default ``data/parquet``), runs
    the matching tool and returns its result as a plain dict via
    ``model_dump()``. The connection is always closed, even when the tool
    raises.

    An unrecognised *name* yields an ``{"error", "hint"}`` dict shaped like
    the one ``execute_tools`` produces for unknown tools, instead of silently
    returning ``None`` after opening a connection for nothing.

    NOTE(review): the wrapper accepts a single ``args`` model — confirm that
    matches how StructuredTool invokes ``func`` if these tools are ever run
    through LangChain directly rather than through ``execute_tools``.

    Args:
        name: Tool identifier: ``inspect_schema``, ``run_sql``,
            ``compare_periods`` or ``decompose_metric``.

    Returns:
        A one-argument callable returning the tool's result dict.
    """
    known_tools = ("inspect_schema", "run_sql", "compare_periods", "decompose_metric")

    def wrapper(
        args: InspectSchemaInput | RunSqlInput | ComparePeriodsInput | DecomposeMetricInput,
    ):
        # Reject unknown names before importing anything or touching the data.
        if name not in known_tools:
            return {
                "error": f"Unknown tool {name!r}",
                "hint": "Use one of: inspect_schema, run_sql, compare_periods, decompose_metric.",
            }

        from agent.tools.run_sql import build_connection

        conn = build_connection(os.getenv("PARQUET_DIR", "data/parquet"))
        try:
            # Tool modules are imported lazily, only for the tool being run.
            if name == "inspect_schema":
                from agent.tools.inspect_schema import inspect_schema as _fn

                # inspect_schema is the only tool that does not take the connection.
                return _fn(args).model_dump()
            if name == "run_sql":
                from agent.tools.run_sql import run_sql as _fn

                return _fn(args, conn).model_dump()
            if name == "compare_periods":
                from agent.tools.compare_periods import compare_periods as _fn

                return _fn(args, conn).model_dump()

            from agent.tools.decompose_metric import decompose_metric as _fn

            return _fn(args, conn).model_dump()
        finally:
            conn.close()

    return wrapper
|
|
|
|
| |
| |
| |
|
|
_CACHED_TOOLS: list[StructuredTool] | None = None


def _get_tools():
    """Return the four StructuredTool objects, building them on first use.

    The list is stored in the module-level ``_CACHED_TOOLS`` so later calls
    return the same objects without rebuilding them.
    """
    global _CACHED_TOOLS
    if _CACHED_TOOLS is not None:
        return _CACHED_TOOLS

    # (tool name, input schema, description shown to the model), in the order
    # the tools are exposed.
    specs = (
        (
            "inspect_schema",
            InspectSchemaInput,
            "List tables (no arg) or describe one table (cols, types, business meaning).",
        ),
        (
            "run_sql",
            RunSqlInput,
            "Execute a read-only SELECT against DuckDB. Returns {rows, truncated, row_count, execution_ms}.",
        ),
        (
            "compare_periods",
            ComparePeriodsInput,
            "Headline diff: by how much did metric change between two windows? Returns {before_value, after_value, abs_delta, pct_delta}.",
        ),
        (
            "decompose_metric",
            DecomposeMetricInput,
            "Drill-down: WHICH slice of metric drove the movement? Returns ranked slices by anomaly score.",
        ),
    )
    _CACHED_TOOLS = [
        StructuredTool.from_function(
            name=tool_name,
            func=_make_tool_wrapper(tool_name),
            args_schema=schema,
            description=blurb,
        )
        for tool_name, schema, blurb in specs
    ]
    return _CACHED_TOOLS
|
|
|
|
| |
| |
| |
|
|
|
|
| def execute_tools(state: InvestigationState) -> InvestigationState: |
| """Run every pending tool call and append an EvidenceEntry for each.""" |
| if not state.pending_tool_calls: |
| return state |
|
|
| from agent.tools.run_sql import build_connection |
|
|
| conn = build_connection(os.getenv("PARQUET_DIR", "data/parquet")) |
| try: |
| batch_reasoning = state.pending_reasoning |
| state.pending_reasoning = None |
| for tc in state.pending_tool_calls: |
| args = tc.args |
| tool_name = tc.tool_name |
| output: dict = {} |
| _t0 = datetime.now(UTC) |
| try: |
| if tool_name == "inspect_schema": |
| from agent.tools.inspect_schema import inspect_schema as _fn |
|
|
| inp = InspectSchemaInput(**args) |
| output = _fn(inp).model_dump() |
| elif tool_name == "run_sql": |
| from agent.tools.run_sql import run_sql as _fn |
|
|
| inp = RunSqlInput(**args) |
| output = _fn(inp, conn).model_dump() |
| elif tool_name == "compare_periods": |
| from agent.tools.compare_periods import compare_periods as _fn |
|
|
| inp = ComparePeriodsInput(**args) |
| output = _fn(inp, conn).model_dump() |
| elif tool_name == "decompose_metric": |
| from agent.tools.decompose_metric import decompose_metric as _fn |
|
|
| inp = DecomposeMetricInput(**args) |
| output = _fn(inp, conn).model_dump() |
| else: |
| output = { |
| "error": f"Unknown tool {tool_name!r}", |
| "hint": "Use one of: inspect_schema, run_sql, compare_periods, decompose_metric.", |
| } |
| except Exception as exc: |
| logger.warning("Tool %s raised (converted to dict): %s", tool_name, exc) |
| output = {"error": str(exc), "hint": "Retry with corrected arguments."} |
|
|
| |
| from langchain_core.messages import ToolMessage |
|
|
| tc.output = output |
| state.messages.append( |
| ToolMessage( |
| content=str(output), |
| tool_call_id=tc.args.get("_tool_call_id", ""), |
| ) |
| ) |
|
|
| entry = EvidenceEntry( |
| phase=state.phase, |
| tool_name=tool_name, |
| args=args, |
| output=output, |
| timestamp=_iso_now(), |
| reasoning=batch_reasoning, |
| duration_ms=(datetime.now(UTC) - _t0).total_seconds() * 1000, |
| ) |
| batch_reasoning = None |
| state.add_evidence(entry) |
|
|
| state.pending_tool_calls = [] |
| return state |
| finally: |
| conn.close() |
|
|
|
|
| |
| |
| |
|
|
|
|
def llm_call(state: InvestigationState) -> InvestigationState:
    """Send messages to the LLM; collect tool calls into pending_tool_calls.

    Builds a system prompt from the current phase, hypotheses, evidence and
    critique feedback, invokes the tool-bound model, and records the outcome
    on ``state``:

    * the response is appended to ``state.messages``;
    * its text, with ``<think>...</think>`` spans removed, becomes
      ``state.pending_reasoning`` (``None`` when nothing is left);
    * during the PLAN phase, the first question-type keyword found in that
      text sets ``state.question_type`` if it is still unset;
    * requested tool calls replace ``state.pending_tool_calls``.

    Args:
        state: Investigation state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    llm = get_llm()

    system_content = _render_system(
        phase=state.phase.value,
        hypotheses=_format_hypotheses(state.hypotheses),
        evidence_summary=_format_evidence(state.evidence),
        critique_feedback=state.critique_feedback,
    )

    all_messages = [SystemMessage(content=system_content), *state.messages]
    # The user question is added to the outgoing list only (never stored in
    # state.messages), so it is re-added whenever the history has no human turn.
    if not any(isinstance(m, HumanMessage) for m in all_messages):
        all_messages.append(HumanMessage(content=state.user_question))
    response = llm.bind_tools(_get_tools()).invoke(all_messages)

    state.messages.append(response)

    # Content is either a plain string or a list of blocks; only dict blocks
    # typed "text" contribute to the reasoning text.
    if isinstance(response.content, str):
        raw_content = response.content
    elif isinstance(response.content, list):
        raw_content = " ".join(
            block.get("text", "")
            for block in response.content
            if isinstance(block, dict) and block.get("type") == "text"
        )
    else:
        raw_content = ""
    state.pending_reasoning = (
        re.sub(r"<think>.*?</think>", "", raw_content, flags=re.DOTALL).strip() or None
    )

    # Classify the question once, from the planning turn's own wording.
    if state.phase == Phase.PLAN and state.question_type is None and state.pending_reasoning:
        r = state.pending_reasoning.upper()
        if "CROSS-SECTIONAL" in r or "CROSS_SECTIONAL" in r:
            state.question_type = "CROSS_SECTIONAL"
        elif "TIME-SERIES" in r or "TIME_SERIES" in r:
            state.question_type = "TIME_SERIES"
        elif "EXPLORATORY" in r:
            state.question_type = "EXPLORATORY"

    state.pending_tool_calls = [
        ToolResult(
            tool_name=tc["name"],
            # The call id travels inside args so execute_tools can echo it on
            # the ToolMessage. An absent or None id becomes "" because the
            # ToolMessage needs a string there.
            args={**tc["args"], "_tool_call_id": tc.get("id") or ""},
            output={},
        )
        for tc in response.tool_calls or []
    ]

    return state
|
|
|
|
| |
| |
| |
|
|
|
|
def _make_phase_node(phase: Phase):
    """Return a graph node that enters *phase* and then runs one LLM turn."""

    def node(state: InvestigationState) -> InvestigationState:
        # Stamp the phase first so llm_call builds its prompt for this phase.
        state.phase = phase
        updated = llm_call(state)
        return updated

    return node
|
|
|
|
| |
| |
| |
|
|
|
|
def critique(state: InvestigationState) -> InvestigationState:
    """Ask the LLM to evaluate evidence strength; decide loop or report.

    The reply is scanned for the first ``verdict: <word>`` line. A ``strong``
    verdict passes the critique; any other verdict fails it and stores the
    non-empty lines after the verdict as ``state.critique_feedback``. With no
    verdict line, a line beginning "evidence is strong" or "proceed to report"
    (optionally preceded by "the") also passes; otherwise the critique fails
    with no feedback. Every failure increments ``state.retry_count``, and
    once it reaches ``MAX_RETRIES`` the critique is forced to pass with
    ``state.error`` set.

    Args:
        state: Investigation state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    state.phase = Phase.CRITIQUE

    def _register_weak_verdict() -> None:
        # Shared bookkeeping for every non-passing outcome.
        state.critique_passed = False
        state.retry_count += 1
        if state.retry_count >= MAX_RETRIES:
            logger.warning("Max critique retries (%d) reached; forcing report.", MAX_RETRIES)
            state.critique_passed = True
            state.error = "Max critique retries reached. Evidence may be incomplete."

    prompt = _render_critique(
        user_question=state.user_question,
        hypotheses=_format_hypotheses(state.hypotheses),
        evidence_summary=_format_evidence(state.evidence, full_output=True),
        evidence_count=len(state.evidence),
        retry_count=state.retry_count,
        max_retries=MAX_RETRIES,
        question_type=state.question_type,
    )
    reply = get_llm().invoke([HumanMessage(content=prompt)])

    content = reply.content
    text = content if isinstance(content, str) else str(content)
    # Only the <think> tags are dropped; the text between them is kept.
    text = re.sub(r"</?think>", "", text).strip()
    lines = [piece.strip() for piece in text.split("\n")]

    # First line carrying "verdict: <word>", after lower-casing and trimming
    # surrounding asterisks, spaces and backticks.
    verdict = next(
        (
            (pos, found.group(1))
            for pos, candidate in enumerate(lines)
            if (found := re.search(r"\bverdict\s*:\s*(\w+)", candidate.lower().strip("* `")))
        ),
        None,
    )

    if verdict is not None:
        verdict_idx, verdict_word = verdict
        if verdict_word == "strong":
            state.critique_passed = True
            state.critique_feedback = None
            return state
        # Everything non-blank after the verdict line is the justification.
        state.critique_feedback = " ".join(filter(None, lines[verdict_idx + 1 :])) or None
        _register_weak_verdict()
        return state

    reads_as_strong = any(
        re.match(r"(the\s+)?evidence is strong|(the\s+)?proceed to report", ln.lower().strip())
        for ln in lines
    )
    state.critique_feedback = None
    if reads_as_strong:
        state.critique_passed = True
    else:
        _register_weak_verdict()
    return state
|
|
|
|
| |
| |
| |
|
|
|
|
def report(state: InvestigationState) -> InvestigationState:
    """Assemble and store the final report dict.

    Sets the phase to REPORT, asks the LLM for a six-section write-up based
    on the user question, the hypotheses and the full evidence, and stores
    the result in ``state.final_report`` under these keys: ``user_question``,
    ``text``, ``hypotheses``, ``evidence_count``, ``critique_passed`` and
    ``error``.

    Args:
        state: Investigation state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    state.phase = Phase.REPORT
    # NOTE(review): _format_evidence doubles curly braces, and nothing here
    # un-doubles them before the text reaches the model — confirm that is
    # acceptable for this prompt.
    report_prompt = (
        "You are writing the final report for an investigation.\n\n"
        f"**User question:** {state.user_question}\n\n"
        f"**Hypotheses considered:**\n{_format_hypotheses(state.hypotheses)}\n\n"
        "**Evidence (full tool outputs):**\n"
        f"{_format_evidence(state.evidence, full_output=True)}\n\n"
        "---\n\n"
        "Write a concise structured report with the following sections. Do NOT "
        "recap every tool call — distill what mattered. Reference specific "
        "numbers from the evidence; do not invent any.\n\n"
        "**1. Headline answer.** 1–3 sentences in plain prose. Lead with the "
        "dominant driver as a quantified claim (e.g., '~85% of the headline "
        "gap is audience selection, not campaign quality'). Then give the "
        "supporting numbers — the controlled comparison and the residual — "
        "and what they mean. If the investigation reframed the user's "
        "question, say so here. If you use bold, reserve it for the leading "
        "driver claim only — never bold supporting or secondary "
        "conclusions.\n\n"
        "**2. Evidence chain.** 3–6 numbered steps showing how you reached the "
        "answer. Each step references specific numbers from the evidence above. "
        "Show the progression: from the headline observation, through the moves "
        "that ruled in or out alternatives, to the conclusion.\n\n"
        "**3. Quantified attribution.** When the question compares entities or "
        "periods, decompose the headline gap arithmetically:\n"
        " - Aggregate gap: <number>\n"
        " - On overlap / controlled comparison: <number>\n"
        " - Selection or composition effect: <number> (~X% of total)\n"
        " - Genuine effect: <number> (~Y% of total)\n"
        "For exploratory questions where attribution doesn't apply, replace "
        "this section with the direct answer and supporting numbers.\n\n"
        "**4. Residual unexplained.** What part of the observation remains "
        "unaccounted for, and why. Be specific about whether the residual is "
        "because the data doesn't contain the relevant information (e.g., "
        "actual subject text, audience targeting criteria, real-world events) "
        "or because the investigation didn't reach it. Do not invent causes "
        "for the residual.\n\n"
        "**5. Confidence.** high / medium / low — and one sentence on what "
        "would raise your confidence (data you don't have, queries you didn't "
        "run, etc.).\n\n"
        "**6. Next steps.** 3–5 concrete actions an analyst should take next. "
        "For each, name the action type (e.g., A/B test, instrumentation, "
        "data request, follow-up query, qualitative review), state what it "
        "would prove or rule out, and note any data needed beyond this "
        "dataset.\n\n"
        "Make sure to mention any hypotheses that were investigated and ruled "
        "out — that's part of showing rigor."
    )

    llm = get_llm()
    # NOTE(review): the prompt is preceded by a short "Please answer." human
    # turn; the reason is not visible in this file — confirm before changing.
    response = llm.invoke(
        [HumanMessage(content="Please answer."), HumanMessage(content=report_prompt)]
    )

    state.final_report = {
        "user_question": state.user_question,
        # Stored exactly as the model returned it.
        "text": response.content,
        "hypotheses": [h.model_dump() for h in state.hypotheses],
        "evidence_count": len(state.evidence),
        "critique_passed": state.critique_passed,
        "error": state.error,
    }
    return state
|
|
|
|
| |
| |
| |
|
|
|
|
def build_graph():
    """Assemble and compile the investigation graph.

    Nodes: one per investigative phase (plan, decompose, drill, cross_check),
    plus ``llm_call``, ``execute_tools``, ``critique`` and ``report``. The
    entry point is ``plan`` and ``report`` leads to END.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """
    investigative_phases = (Phase.PLAN, Phase.DECOMPOSE, Phase.DRILL, Phase.CROSS_CHECK)

    # Where an LLM turn goes next when it requested no tool calls.
    _phase_next: dict[Phase, str] = {
        Phase.PLAN: "decompose",
        Phase.DECOMPOSE: "drill",
        Phase.DRILL: "cross_check",
        Phase.CROSS_CHECK: "critique",
    }

    def route_after_llm(state: InvestigationState) -> str:
        # The evidence cap is checked first, ahead of any pending tool calls.
        cap_reached = len(state.evidence) >= MAX_TOOL_CALLS
        if cap_reached:
            logger.warning("Tool call cap (%d) reached; forcing critique.", MAX_TOOL_CALLS)
            return "critique"
        if state.pending_tool_calls:
            return "execute_tools"
        return _phase_next.get(state.phase, "critique")

    def route_after_critique(state: InvestigationState) -> Literal["report", "decompose"]:
        # A failed critique loops back to the decompose phase.
        if state.critique_passed:
            return "report"
        return "decompose"

    builder = StateGraph(InvestigationState)

    fixed_nodes = (
        ("llm_call", llm_call),
        ("execute_tools", execute_tools),
        ("critique", critique),
        ("report", report),
    )
    for node_name, action in fixed_nodes:
        builder.add_node(node_name, action)
    for phase in investigative_phases:
        builder.add_node(phase.value, _make_phase_node(phase))

    # Every LLM-driven node shares the same router.
    for phase in investigative_phases:
        builder.add_conditional_edges(phase.value, route_after_llm)
    builder.add_edge("execute_tools", "llm_call")
    builder.add_conditional_edges("llm_call", route_after_llm)

    builder.add_conditional_edges("critique", route_after_critique)
    builder.add_edge("report", END)

    builder.set_entry_point("plan")
    return builder.compile()
|
|