why-agent / agent /graph.py
MapoTofu9's picture
deploy: HF Spaces
5d30bdc
Raw
History Blame Contribute Delete
21 kB
"""LangGraph state machine for why-agent root-cause investigation.

The graph orchestrates the six-phase loop:

    plan → decompose → drill → cross_check → critique → report
               ↑                                |
               └────────────────────────────────┘  (if evidence weak)
"""
from __future__ import annotations
import logging
import os
import re
from datetime import UTC, datetime
from typing import Literal
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import StructuredTool
from langgraph.graph import END, StateGraph
from agent.client import get_llm
from agent.prompts import _render_critique, _render_system
from agent.state import (
EvidenceEntry,
Hypothesis,
InvestigationState,
Phase,
ToolResult,
)
from agent.tools.schemas import (
ComparePeriodsInput,
DecomposeMetricInput,
InspectSchemaInput,
RunSqlInput,
)
logger = logging.getLogger(__name__)

# Number of weak critique verdicts after which the report is forced anyway.
MAX_RETRIES: int = 3
# Hard cap on evidence entries across all phases — once reached, routing goes
# straight to critique, which prevents infinite tool-call loops.
MAX_TOOL_CALLS: int = 50
def _iso_now() -> str:
return datetime.now(UTC).isoformat()
def _format_hypotheses(hypotheses: list[Hypothesis]) -> str:
    """Render hypotheses one per line for inclusion in LLM prompts.

    Each line shows the id, description, status and the comma-separated
    supporting-evidence references ("none" when there are none). An empty
    list yields a fixed placeholder sentence.
    """
    if not hypotheses:
        return "No hypotheses yet."
    return "\n".join(
        f" [{hyp.id}] {hyp.description} "
        f"(status={hyp.status}, supporting_evidence={', '.join(hyp.supporting_evidence) or 'none'})"
        for hyp in hypotheses
    )
def _format_evidence(
    evidence: list[EvidenceEntry],
    full_output: bool = False,
    *,
    max_chars: int = 400,
) -> str:
    """Render evidence entries one per line for inclusion in LLM prompts.

    Args:
        evidence: Collected entries; each contributes its ``phase.value``,
            ``tool_name`` and stringified ``output``.
        full_output: Include each entry's whole output instead of a prefix.
        max_chars: Prefix length used when ``full_output`` is False.

    Returns:
        Newline-joined lines, or a placeholder sentence when ``evidence`` is
        empty. Entries whose output carries a non-empty ``"error"`` value are
        tagged ``[ERROR]``. Curly braces in the output text are doubled
        (``{`` -> ``{{``). NOTE(review): presumably so the text survives a
        later ``str.format``-style template render; confirm in agent.prompts.
    """
    if not evidence:
        return "No evidence collected yet."
    lines = []
    for entry in evidence:
        output = entry.output
        # Tag real failures only: a present-but-empty "error" field (e.g. a
        # dumped result model with error=None) is not a failure.
        is_error = isinstance(output, dict) and bool(output.get("error"))
        err_tag = " [ERROR]" if is_error else ""
        raw = str(output)
        snippet = raw if full_output else raw[:max_chars]
        snippet = snippet.replace("{", "{{").replace("}", "}}")
        lines.append(f" [{entry.phase.value}] {entry.tool_name}{err_tag}: {snippet}")
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Tool wrappers — each manages its own DuckDB connection so StructuredTool
# schema generation (which inspects function signatures) never sees conn.
# ---------------------------------------------------------------------------
def _make_tool_wrapper(name: str):
    """Build the callable that backs the StructuredTool called *name*.

    The callable opens (and always closes) its own DuckDB connection per call,
    so ``conn`` never appears in the signature StructuredTool inspects.
    ``inspect_schema`` does not take a connection, so none is opened for it.

    The tool input may be passed as one positional object (an input-model
    instance or a plain dict of its fields), or as keyword arguments. The
    keyword form is how langchain_core's StructuredTool invokes ``func`` when
    ``args_schema`` is set; confirm for the pinned langchain version.

    Args:
        name: One of ``inspect_schema``, ``run_sql``, ``compare_periods``,
            ``decompose_metric``.

    Returns:
        A function returning the tool result as a dict (``model_dump()``).
        Calling it raises ``ValueError`` for an unknown *name*, before any
        connection is opened.
    """
    known_tools = ("inspect_schema", "run_sql", "compare_periods", "decompose_metric")

    def wrapper(
        args: InspectSchemaInput | RunSqlInput | ComparePeriodsInput | DecomposeMetricInput | dict | None = None,
        **kwargs,
    ):
        if name not in known_tools:
            raise ValueError(f"Unknown tool {name!r}; expected one of: {', '.join(known_tools)}.")
        input_model = {
            "inspect_schema": InspectSchemaInput,
            "run_sql": RunSqlInput,
            "compare_periods": ComparePeriodsInput,
            "decompose_metric": DecomposeMetricInput,
        }[name]
        if args is None:
            # Keyword form: the schema fields arrive as individual kwargs.
            args = input_model(**kwargs)
        elif isinstance(args, dict):
            args = input_model(**args)

        if name == "inspect_schema":
            from agent.tools.inspect_schema import inspect_schema

            return inspect_schema(args).model_dump()  # type: ignore[arg-type]

        if name == "run_sql":
            from agent.tools.run_sql import run_sql as tool_fn
        elif name == "compare_periods":
            from agent.tools.compare_periods import compare_periods as tool_fn  # type: ignore[no-redef]
        else:
            from agent.tools.decompose_metric import decompose_metric as tool_fn  # type: ignore[no-redef]
        from agent.tools.run_sql import build_connection

        conn = build_connection(os.getenv("PARQUET_DIR", "data/parquet"))
        try:
            return tool_fn(args, conn).model_dump()  # type: ignore[arg-type]
        finally:
            conn.close()

    return wrapper
# ---------------------------------------------------------------------------
# Cached tool definitions — built once so the LLM sees stable schemas.
# ---------------------------------------------------------------------------
# Built lazily on the first _get_tools() call, then reused for every call.
_CACHED_TOOLS: list[StructuredTool] | None = None


def _get_tools():  # type: ignore[reportReturnType]
    """Return the StructuredTool list bound to the LLM, building it on first use.

    Reusing one list keeps the tool schemas presented to the LLM identical
    across calls.
    """
    global _CACHED_TOOLS
    if _CACHED_TOOLS is not None:
        return _CACHED_TOOLS

    # (tool name, input schema, description shown to the LLM), in bind order.
    tool_specs = (
        (
            "inspect_schema",
            InspectSchemaInput,
            "List tables (no arg) or describe one table (cols, types, business meaning).",
        ),
        (
            "run_sql",
            RunSqlInput,
            "Execute a read-only SELECT against DuckDB. Returns {rows, truncated, row_count, execution_ms}.",
        ),
        (
            "compare_periods",
            ComparePeriodsInput,
            "Headline diff: by how much did metric change between two windows? Returns {before_value, after_value, abs_delta, pct_delta}.",
        ),
        (
            "decompose_metric",
            DecomposeMetricInput,
            "Drill-down: WHICH slice of metric drove the movement? Returns ranked slices by anomaly score.",
        ),
    )
    _CACHED_TOOLS = [
        StructuredTool.from_function(
            name=tool_name,
            func=_make_tool_wrapper(tool_name),
            args_schema=schema,
            description=summary,
        )
        for tool_name, schema, summary in tool_specs
    ]
    return _CACHED_TOOLS
# ---------------------------------------------------------------------------
# Tool executor node
# ---------------------------------------------------------------------------
def _run_tool(tool_name: str, args: dict, conn) -> dict:
    """Validate *args* for *tool_name*, run the tool, and return its output dict.

    Never raises. An unknown tool name, or any exception from validation or
    from the tool itself, becomes an ``{"error": ..., "hint": ...}`` dict so
    the LLM sees the failure and can retry.
    """
    try:
        # "_tool_call_id" is bookkeeping stored next to the real arguments;
        # it is not a tool parameter, so keep it out of input validation.
        tool_args = {key: value for key, value in args.items() if key != "_tool_call_id"}
        if tool_name == "inspect_schema":
            from agent.tools.inspect_schema import inspect_schema

            return inspect_schema(InspectSchemaInput(**tool_args)).model_dump()
        if tool_name == "run_sql":
            from agent.tools.run_sql import run_sql

            return run_sql(RunSqlInput(**tool_args), conn).model_dump()
        if tool_name == "compare_periods":
            from agent.tools.compare_periods import compare_periods

            return compare_periods(ComparePeriodsInput(**tool_args), conn).model_dump()
        if tool_name == "decompose_metric":
            from agent.tools.decompose_metric import decompose_metric

            return decompose_metric(DecomposeMetricInput(**tool_args), conn).model_dump()
    except Exception as exc:  # deliberate boundary: surface tool failures to the LLM
        logger.warning("Tool %s raised (converted to dict): %s", tool_name, exc)
        return {"error": str(exc), "hint": "Retry with corrected arguments."}
    return {
        "error": f"Unknown tool {tool_name!r}",
        "hint": "Use one of: inspect_schema, run_sql, compare_periods, decompose_metric.",
    }


def execute_tools(state: InvestigationState) -> InvestigationState:
    """Run every pending tool call, recording a ToolMessage and an EvidenceEntry for each.

    Does nothing when there are no pending calls. All calls in the batch share
    one DuckDB connection, which is always closed. The reasoning captured with
    the batch (``state.pending_reasoning``) is consumed and attached to the
    first call's evidence entry only. Clears ``state.pending_tool_calls``
    afterwards.
    """
    if not state.pending_tool_calls:
        return state

    from langchain_core.messages import ToolMessage

    from agent.tools.run_sql import build_connection

    conn = build_connection(os.getenv("PARQUET_DIR", "data/parquet"))
    try:
        batch_reasoning = state.pending_reasoning
        state.pending_reasoning = None
        for tc in state.pending_tool_calls:
            started_at = datetime.now(UTC)
            output = _run_tool(tc.tool_name, tc.args, conn)
            tc.output = output
            # Answer the AI message's tool call so the LLM sees the result next turn.
            state.messages.append(
                ToolMessage(
                    content=str(output),
                    tool_call_id=tc.args.get("_tool_call_id", ""),
                )
            )
            state.add_evidence(
                EvidenceEntry(
                    phase=state.phase,
                    tool_name=tc.tool_name,
                    args=tc.args,
                    output=output,
                    timestamp=_iso_now(),
                    reasoning=batch_reasoning,
                    duration_ms=(datetime.now(UTC) - started_at).total_seconds() * 1000,
                )
            )
            batch_reasoning = None  # only the first call in the batch carries it
        state.pending_tool_calls = []
        return state
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# LLM call node
# ---------------------------------------------------------------------------
def _message_text(content) -> str:
    """Return the plain text of an LLM message's ``content``.

    ``content`` may be a string or a list of content blocks (OpenAI-compatible
    APIs). Only string items and ``{"type": "text"}`` blocks contribute; any
    other content type yields "".
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            block if isinstance(block, str) else block.get("text", "")
            for block in content
            if isinstance(block, str) or (isinstance(block, dict) and block.get("type") == "text")
        )
    return ""


def _detect_question_type(reasoning: str) -> str | None:
    """Return the question class named in *reasoning*, or None if none is named.

    Matching is case-insensitive, accepts hyphen or underscore spellings, and
    checks CROSS_SECTIONAL, then TIME_SERIES, then EXPLORATORY.
    """
    upper = reasoning.upper()
    if "CROSS-SECTIONAL" in upper or "CROSS_SECTIONAL" in upper:
        return "CROSS_SECTIONAL"
    if "TIME-SERIES" in upper or "TIME_SERIES" in upper:
        return "TIME_SERIES"
    if "EXPLORATORY" in upper:
        return "EXPLORATORY"
    return None


def llm_call(state: InvestigationState) -> InvestigationState:
    """Send the conversation to the LLM and queue any tool calls it requests.

    Effects on *state*:
    - appends the AI response to ``messages``;
    - sets ``pending_reasoning`` to the response text with ``<think>`` blocks
      removed (None when nothing remains);
    - in the PLAN phase, records ``question_type`` the first time the
      reasoning names one;
    - replaces ``pending_tool_calls`` with the response's tool calls. Each args
      dict also carries the call id under ``_tool_call_id``.
    """
    llm = get_llm()
    system_content = _render_system(
        phase=state.phase.value,
        hypotheses=_format_hypotheses(state.hypotheses),
        evidence_summary=_format_evidence(state.evidence),
        critique_feedback=state.critique_feedback,
    )
    history = list(state.messages)
    if not any(isinstance(m, HumanMessage) for m in history):
        # No user turn recorded: put the question first, before any AI/tool
        # messages, so the transcript stays in conversational order
        # (system -> user -> assistant -> tool ...). Appending it at the end
        # placed the question after tool results on every follow-up turn.
        history.insert(0, HumanMessage(content=state.user_question))
    all_messages = [SystemMessage(content=system_content), *history]

    response = llm.bind_tools(_get_tools()).invoke(all_messages)
    state.messages.append(response)

    # Keep the LLM's visible reasoning for display alongside the next tool calls.
    visible_text = re.sub(r"<think>.*?</think>", "", _message_text(response.content), flags=re.DOTALL)
    state.pending_reasoning = visible_text.strip() or None

    # The system prompt asks the agent to state its question classification in
    # the plan phase; persist it so critique can apply the matching checks.
    if state.phase == Phase.PLAN and state.question_type is None and state.pending_reasoning:
        detected = _detect_question_type(state.pending_reasoning)
        if detected is not None:
            state.question_type = detected

    state.pending_tool_calls = [
        ToolResult(
            tool_name=tc["name"],
            # A tool call's id may be missing or None; ToolMessage needs a string.
            args={**tc["args"], "_tool_call_id": tc.get("id") or ""},
            output={},
        )
        for tc in response.tool_calls or []
    ]
    return state
# ---------------------------------------------------------------------------
# Phase-stepping nodes
# ---------------------------------------------------------------------------
def _make_phase_node(phase: Phase):
    """Return a graph node that switches ``state.phase`` to *phase*, then runs ``llm_call``."""

    def phase_node(state: InvestigationState) -> InvestigationState:
        state.phase = phase
        return llm_call(state)

    return phase_node
# ---------------------------------------------------------------------------
# Critique node
# ---------------------------------------------------------------------------
def _record_weak_critique(state: InvestigationState, feedback: str | None) -> None:
    """Mark the critique as failed with *feedback*; force the report once retries run out."""
    state.critique_passed = False
    state.critique_feedback = feedback
    state.retry_count += 1
    if state.retry_count >= MAX_RETRIES:
        logger.warning("Max critique retries (%d) reached; forcing report.", MAX_RETRIES)
        state.critique_passed = True
        state.error = "Max critique retries reached. Evidence may be incomplete."


def critique(state: InvestigationState) -> InvestigationState:
    """Ask the LLM to judge evidence strength; decide whether to loop or report.

    Parses the first line containing ``VERDICT: <word>``. ``strong`` passes.
    Any other word fails, and the non-empty lines after the verdict become
    ``critique_feedback`` for the retry. Without a verdict line, a line that
    opens with "evidence is strong" or "proceed to report" passes; otherwise
    the critique fails with no feedback. Each failure increments
    ``retry_count``. At ``MAX_RETRIES`` the critique is forced to pass and
    ``state.error`` is set.
    """
    state.phase = Phase.CRITIQUE
    critique_prompt = _render_critique(
        user_question=state.user_question,
        hypotheses=_format_hypotheses(state.hypotheses),
        evidence_summary=_format_evidence(state.evidence, full_output=True),
        evidence_count=len(state.evidence),
        retry_count=state.retry_count,
        max_retries=MAX_RETRIES,
        question_type=state.question_type,
    )
    llm = get_llm()
    response = llm.invoke([HumanMessage(content=critique_prompt)])

    content = response.content
    if isinstance(content, str):
        text = content
    elif isinstance(content, list):
        # Content blocks (OpenAI-compatible APIs): parse their text, one block
        # per line, rather than the list's repr (which collapses all newlines).
        text = "\n".join(
            block if isinstance(block, str) else block.get("text", "")
            for block in content
            if isinstance(block, str) or (isinstance(block, dict) and block.get("type") == "text")
        )
    else:
        text = str(content)

    # Remove <think> / </think> markers but keep their content — Qwen3/MiniMax
    # thinking mode sometimes embeds the VERDICT line inside a think block.
    # Stripping the whole block would discard the verdict; removing only the
    # tags makes the full response visible to the parser below.
    text = re.sub(r"</?think>", "", text).strip()

    stripped_lines = [ln.strip() for ln in text.split("\n")]
    # Matching copies with Markdown bold / code markers removed ANYWHERE on the
    # line, so "**VERDICT:** strong", "VERDICT: **strong**" and
    # "`VERDICT: weak`" all parse. Feedback text keeps the original lines.
    match_lines = [re.sub(r"[*`]", "", ln).strip().lower() for ln in stripped_lines]

    # Search every line: the model may emit preamble prose before the verdict,
    # or inline text such as "After review, VERDICT: weak".
    verdict_idx: int | None = None
    verdict_word: str | None = None
    for i, ln in enumerate(match_lines):
        m = re.search(r"\bverdict\s*:\s*(\w+)", ln)
        if m:
            verdict_idx = i
            verdict_word = m.group(1)
            break

    if verdict_idx is not None:
        if verdict_word == "strong":
            state.critique_passed = True
            state.critique_feedback = None
        else:
            # Justification = lines after the VERDICT line — directed at the retry.
            justification = " ".join(ln for ln in stripped_lines[verdict_idx + 1 :] if ln)
            _record_weak_critique(state, justification or None)
    elif any(
        # Require the phrase to open the line (optionally preceded by "the") —
        # avoids a false positive on prose like "while the evidence is strong
        # for X, the after-period is missing." but still matches
        # "The evidence is strong."
        re.match(r"(the\s+)?evidence is strong|(the\s+)?proceed to report", ln)
        for ln in match_lines
    ):
        state.critique_passed = True
        state.critique_feedback = None
    else:
        _record_weak_critique(state, None)
    return state
# ---------------------------------------------------------------------------
# Report node
# ---------------------------------------------------------------------------
def report(state: InvestigationState) -> InvestigationState:
    """Ask the LLM for the final structured report and store it on the state.

    The prompt contains the user question, the hypotheses and the full
    (untruncated) evidence. ``state.final_report`` is set to a dict with keys
    ``user_question``, ``text``, ``hypotheses``, ``evidence_count``,
    ``critique_passed`` and ``error``. ``text`` is always a plain string:
    content blocks are flattened and ``<think>`` blocks removed, unless that
    would leave nothing.
    """
    state.phase = Phase.REPORT
    # _format_evidence doubles curly braces for template rendering, but this
    # prompt is an f-string sent verbatim, so undo the doubling to show the
    # real tool outputs.
    evidence_text = (
        _format_evidence(state.evidence, full_output=True)
        .replace("{{", "{")
        .replace("}}", "}")
    )
    report_prompt = (
        "You are writing the final report for an investigation.\n\n"
        f"**User question:** {state.user_question}\n\n"
        f"**Hypotheses considered:**\n{_format_hypotheses(state.hypotheses)}\n\n"
        "**Evidence (full tool outputs):**\n"
        f"{evidence_text}\n\n"
        "---\n\n"
        "Write a concise structured report with the following sections. Do NOT "
        "recap every tool call — distill what mattered. Reference specific "
        "numbers from the evidence; do not invent any.\n\n"
        "**1. Headline answer.** 1–3 sentences in plain prose. Lead with the "
        "dominant driver as a quantified claim (e.g., '~85% of the headline "
        "gap is audience selection, not campaign quality'). Then give the "
        "supporting numbers — the controlled comparison and the residual — "
        "and what they mean. If the investigation reframed the user's "
        "question, say so here. If you use bold, reserve it for the leading "
        "driver claim only — never bold supporting or secondary "
        "conclusions.\n\n"
        "**2. Evidence chain.** 3–6 numbered steps showing how you reached the "
        "answer. Each step references specific numbers from the evidence above. "
        "Show the progression: from the headline observation, through the moves "
        "that ruled in or out alternatives, to the conclusion.\n\n"
        "**3. Quantified attribution.** When the question compares entities or "
        "periods, decompose the headline gap arithmetically:\n"
        " - Aggregate gap: <number>\n"
        " - On overlap / controlled comparison: <number>\n"
        " - Selection or composition effect: <number> (~X% of total)\n"
        " - Genuine effect: <number> (~Y% of total)\n"
        "For exploratory questions where attribution doesn't apply, replace "
        "this section with the direct answer and supporting numbers.\n\n"
        "**4. Residual unexplained.** What part of the observation remains "
        "unaccounted for, and why. Be specific about whether the residual is "
        "because the data doesn't contain the relevant information (e.g., "
        "actual subject text, audience targeting criteria, real-world events) "
        "or because the investigation didn't reach it. Do not invent causes "
        "for the residual.\n\n"
        "**5. Confidence.** high / medium / low — and one sentence on what "
        "would raise your confidence (data you don't have, queries you didn't "
        "run, etc.).\n\n"
        "**6. Next steps.** 3–5 concrete actions an analyst should take next. "
        "For each, name the action type (e.g., A/B test, instrumentation, "
        "data request, follow-up query, qualitative review), state what it "
        "would prove or rule out, and note any data needed beyond this "
        "dataset.\n\n"
        "Make sure to mention any hypotheses that were investigated and ruled "
        "out — that's part of showing rigor."
    )
    llm = get_llm()
    # MiniMax rejects single HumanMessage; prepend a dummy HumanMessage to keep it happy.
    response = llm.invoke(
        [HumanMessage(content="Please answer."), HumanMessage(content=report_prompt)]
    )

    content = response.content
    if isinstance(content, list):
        # Content blocks (OpenAI-compatible APIs): keep only their text.
        raw_text = "\n".join(
            block if isinstance(block, str) else block.get("text", "")
            for block in content
            if isinstance(block, str) or (isinstance(block, dict) and block.get("type") == "text")
        )
    else:
        raw_text = str(content)
    report_text = re.sub(r"<think>.*?</think>", "", raw_text, flags=re.DOTALL).strip()
    if not report_text:
        # The whole answer was inside <think>: keep it (tags removed) rather
        # than store an empty report.
        report_text = re.sub(r"</?think>", "", raw_text).strip()

    state.final_report = {
        "user_question": state.user_question,
        "text": report_text,
        "hypotheses": [h.model_dump() for h in state.hypotheses],
        "evidence_count": len(state.evidence),
        "critique_passed": state.critique_passed,
        "error": state.error,
    }
    return state
# ---------------------------------------------------------------------------
# Build the graph
# ---------------------------------------------------------------------------
def build_graph():
    """Assemble and compile the investigation StateGraph.

    Phase order is plan -> decompose -> drill -> cross_check -> critique. Each
    phase node, and the shared ``llm_call`` node, routes to ``execute_tools``
    while the LLM keeps requesting tools. ``execute_tools`` always returns to
    ``llm_call``. Once the LLM stops calling tools, routing moves to the next
    phase. ``critique`` then either finishes through ``report`` or loops back
    to ``decompose``.
    """
    investigation_phases = (Phase.PLAN, Phase.DECOMPOSE, Phase.DRILL, Phase.CROSS_CHECK)

    graph = StateGraph(InvestigationState)
    for node_name, node_fn in (
        ("llm_call", llm_call),
        ("execute_tools", execute_tools),
        ("critique", critique),
        ("report", report),
    ):
        graph.add_node(node_name, node_fn)
    for phase in investigation_phases:
        graph.add_node(phase.value, _make_phase_node(phase))

    # Node reached once the LLM stops calling tools in the given phase; any
    # phase not listed falls through to "critique". On a critique retry the
    # pipeline re-enters at decompose -> drill -> cross_check -> critique.
    successor: dict[Phase, str] = {
        Phase.PLAN: "decompose",
        Phase.DECOMPOSE: "drill",
        Phase.DRILL: "cross_check",
        Phase.CROSS_CHECK: "critique",
    }

    def route_after_llm(state: InvestigationState) -> str:
        # The evidence cap wins over everything, including pending tool calls.
        # NOTE(review): those calls are then never answered with ToolMessages;
        # check that the provider accepts the dangling tool calls on a retry.
        if len(state.evidence) >= MAX_TOOL_CALLS:
            logger.warning("Tool call cap (%d) reached; forcing critique.", MAX_TOOL_CALLS)
            return "critique"
        return "execute_tools" if state.pending_tool_calls else successor.get(state.phase, "critique")

    def route_after_critique(state: InvestigationState) -> Literal["report", "decompose"]:
        if not state.critique_passed:
            return "decompose"
        return "report"

    for phase in investigation_phases:
        graph.add_conditional_edges(phase.value, route_after_llm)
    graph.add_edge("execute_tools", "llm_call")
    graph.add_conditional_edges("llm_call", route_after_llm)
    graph.add_conditional_edges("critique", route_after_critique)
    graph.add_edge("report", END)
    graph.set_entry_point("plan")
    return graph.compile()