Spaces:
Sleeping
Sleeping
bump: bringing in evaluation code from @ptoroisaza
Browse files- analysis_src/__pycache__/consistency.cpython-310.pyc +0 -0
- analysis_src/__pycache__/extract_consistency_data.cpython-310.pyc +0 -0
- analysis_src/__pycache__/extract_discovery_trajectory.cpython-310.pyc +0 -0
- analysis_src/__pycache__/extract_inference_data.cpython-310.pyc +0 -0
- analysis_src/__pycache__/extract_majority_vote_data.cpython-310.pyc +0 -0
- analysis_src/__pycache__/extract_tool_failures.cpython-310.pyc +0 -0
- analysis_src/__pycache__/model_styles.cpython-310.pyc +0 -0
- analysis_src/__pycache__/utils.cpython-310.pyc +0 -0
- analysis_src/consistency.py +513 -0
- analysis_src/extract_consistency_data.py +251 -0
- analysis_src/extract_discovery_trajectory.py +928 -0
- analysis_src/extract_exploration.py +623 -0
- analysis_src/extract_inference_data.py +595 -0
- analysis_src/extract_majority_vote_data.py +507 -0
- analysis_src/extract_tool_failures.py +560 -0
- analysis_src/model_styles.py +241 -0
- analysis_src/utils.py +155 -0
- evaluation.ipynb +0 -0
analysis_src/__pycache__/consistency.cpython-310.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
analysis_src/__pycache__/extract_consistency_data.cpython-310.pyc
ADDED
|
Binary file (6.73 kB). View file
|
|
|
analysis_src/__pycache__/extract_discovery_trajectory.cpython-310.pyc
ADDED
|
Binary file (21.9 kB). View file
|
|
|
analysis_src/__pycache__/extract_inference_data.cpython-310.pyc
ADDED
|
Binary file (14.5 kB). View file
|
|
|
analysis_src/__pycache__/extract_majority_vote_data.cpython-310.pyc
ADDED
|
Binary file (12.6 kB). View file
|
|
|
analysis_src/__pycache__/extract_tool_failures.cpython-310.pyc
ADDED
|
Binary file (13 kB). View file
|
|
|
analysis_src/__pycache__/model_styles.cpython-310.pyc
ADDED
|
Binary file (4.52 kB). View file
|
|
|
analysis_src/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (4.85 kB). View file
|
|
|
analysis_src/consistency.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
# TODO: Needs to be integrated into the itbench_leaderboard module
|
| 4 |
+
# This script calculates ICC (Intraclass Correlation Coefficient) and other
|
| 5 |
+
# consistency metrics for agent evaluation results.
|
| 6 |
+
|
| 7 |
+
Consistency Analysis for Agent Leaderboard Results.
|
| 8 |
+
|
| 9 |
+
Computes ICC (Intraclass Correlation Coefficient) to measure the reliability
|
| 10 |
+
and consistency of agent responses across multiple trials per scenario.
|
| 11 |
+
|
| 12 |
+
ICC answers: "Of all the variance observed, how much is due to actual scenario
|
| 13 |
+
difficulty (signal) vs. random model variability (noise/flakiness)?"
|
| 14 |
+
|
| 15 |
+
Interpretation:
|
| 16 |
+
ICC > 0.9: Excellent consistency
|
| 17 |
+
ICC 0.75-0.9: Good consistency
|
| 18 |
+
ICC 0.5-0.75: Moderate consistency
|
| 19 |
+
ICC < 0.5: Poor consistency (high flakiness)
|
| 20 |
+
|
| 21 |
+
Usage:
|
| 22 |
+
python -m itbench_leaderboard.consistency --results-dir leaderboard_results/results
|
| 23 |
+
python -m itbench_leaderboard.consistency --results-file path/to/results.json
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import json
|
| 28 |
+
import sys
|
| 29 |
+
from dataclasses import dataclass, field
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Optional
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class ConsistencyMetrics:
    """Bundle of consistency/reliability metrics for one (agent, metric) pair.

    ICC close to 1.0 means scenario difficulty explains most variance
    (consistent agent); ICC close to 0 means trial-to-trial noise dominates.
    """

    # Core ICC metrics
    icc: float
    flakiness_ratio: float  # defined as 1 - ICC

    # ANOVA components of the one-way model
    msb: float  # Mean Square Between (between-scenario variance)
    msw: float  # Mean Square Within (within-scenario variance)

    # Within-scenario consistency summaries
    mean_within_std: float
    mean_agreement_rate: float
    repeatability_coefficient: float

    # Summary stats
    n_scenarios: int
    n_trials: int
    n_flaky_scenarios: int
    flaky_scenarios: list = field(default_factory=list)

    # Per-scenario breakdown keyed by scenario id
    scenario_details: dict = field(default_factory=dict)

    def __str__(self) -> str:
        # Assemble the human-readable summary line by line, then join.
        lines = [
            f"ICC: {self.icc:.4f} (flakiness: {self.flakiness_ratio:.4f})",
            f"MSB (between): {self.msb:.4f}, MSW (within): {self.msw:.4f}",
            f"Mean within-std: {self.mean_within_std:.4f}",
            f"Agreement rate: {self.mean_agreement_rate:.4f}",
            f"Repeatability coef: {self.repeatability_coefficient:.4f}",
            f"Flaky scenarios: {self.n_flaky_scenarios}/{self.n_scenarios}",
        ]
        return "\n".join(lines)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def load_results(filepath: Path) -> dict:
    """Load a results JSON file.

    Args:
        filepath: Path to the results JSON file.

    Returns:
        The parsed JSON document as a dict.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Pin the encoding: JSON is UTF-8 by spec, and relying on the platform
    # default locale encoding breaks on Windows for non-ASCII content.
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def extract_trial_scores(
    results: dict,
    metric: str = "root_cause_entity_f1"
) -> dict[str, list[float]]:
    """
    Extract per-trial scores for a given metric from results.

    Scenarios with no runs are omitted; a missing or null score within a
    run is recorded as 0.0.

    Args:
        results: Loaded JSON results
        metric: The metric name to extract (default: root_cause_entity_f1)

    Returns:
        Dict mapping scenario_id -> list of trial scores
    """
    per_scenario: dict[str, list[float]] = {}

    for sid, sdata in results.get("scenarios", {}).items():
        values = []
        for run in sdata.get("runs", []):
            raw = run.get("scores", {}).get(metric)
            # Treat missing/null scores as 0.0 so trial counts stay aligned.
            values.append(0.0 if raw is None else float(raw))

        # Only keep scenarios that actually produced at least one trial.
        if values:
            per_scenario[sid] = values

    return per_scenario
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def calculate_agreement_rate(trials: list[float], tolerance: float = 0.1) -> float:
    """
    Calculate agreement rate between trial pairs.

    A pair of trials "agrees" when their absolute difference is within
    `tolerance`. Fewer than two trials trivially agree.

    Args:
        trials: List of trial scores
        tolerance: Maximum difference to consider as "agreement"

    Returns:
        Fraction of trial pairs that agree (0-1)
    """
    from itertools import combinations

    if len(trials) < 2:
        return 1.0

    # Count agreeing pairs among all unordered pairs.
    total = 0
    agree = 0
    for left, right in combinations(trials, 2):
        total += 1
        if abs(left - right) <= tolerance:
            agree += 1
    return agree / total
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def compute_icc(scenario_trials: dict[str, list[float]]) -> ConsistencyMetrics:
    """
    Compute ICC(1,1) - one-way random effects model.

    The ICC formula:
        ICC = (MSB - MSW) / (MSB + (k-1) * MSW)

    Where:
        MSB = k * Var(scenario_means)        [between-scenario variance]
        MSW = Mean(Var(trials per scenario)) [within-scenario variance]
        k   = number of trials per scenario

    Ragged inputs (scenarios with differing trial counts) are truncated to
    the minimum trial count for the ANOVA computation.

    Args:
        scenario_trials: Dict mapping scenario_id -> list of trial scores

    Returns:
        ConsistencyMetrics with ICC and related metrics (ICC is NaN when
        there is no data or no variance at all; clipped to 0 when negative)
    """
    # Convert to numpy array
    scenarios = list(scenario_trials.keys())

    # Ensure all scenarios have same number of trials
    n_trials_list = [len(trials) for trials in scenario_trials.values()]
    if len(set(n_trials_list)) > 1:
        # Pad or truncate to minimum
        # NOTE(review): only truncation happens here — extra trials beyond
        # the minimum count are dropped from the ANOVA; no padding occurs.
        k = min(n_trials_list)
        scores = np.array([scenario_trials[s][:k] for s in scenarios])
    else:
        k = n_trials_list[0] if n_trials_list else 0
        scores = np.array([scenario_trials[s] for s in scenarios])

    n_scenarios = len(scenarios)

    # Degenerate input: no scenarios or no trials -> NaN ICC, neutral stats.
    if n_scenarios == 0 or k == 0:
        return ConsistencyMetrics(
            icc=float('nan'),
            flakiness_ratio=float('nan'),
            msb=0.0,
            msw=0.0,
            mean_within_std=0.0,
            mean_agreement_rate=1.0,
            repeatability_coefficient=0.0,
            n_scenarios=0,
            n_trials=0,
            n_flaky_scenarios=0,
        )

    # Calculate scenario means
    scenario_means = np.mean(scores, axis=1)

    # Between-scenario variance (MSB)
    # MSB = k * Var(scenario means); sample variance (ddof=1) needs >= 2 scenarios.
    msb = k * np.var(scenario_means, ddof=1) if n_scenarios > 1 else 0.0

    # Within-scenario variance (MSW)
    # MSW = average of within-scenario variances; zero when only one trial each.
    within_vars = np.var(scores, axis=1, ddof=1) if k > 1 else np.zeros(n_scenarios)
    msw = np.mean(within_vars)

    # ICC(1,1) formula
    denominator = msb + (k - 1) * msw
    if denominator > 0:
        icc = (msb - msw) / denominator
        icc = max(0.0, icc)  # ICC can be negative, clip to 0
    else:
        # Zero denominator: NaN when there is literally no variance anywhere
        # (undefined ratio), else 0 (no between-scenario signal).
        icc = float('nan') if msw == 0 and msb == 0 else 0.0

    # Within-scenario standard deviations
    within_stds = np.std(scores, axis=1, ddof=1) if k > 1 else np.zeros(n_scenarios)
    mean_within_std = np.mean(within_stds)

    # Agreement rates
    # NOTE(review): agreement uses the FULL (untruncated) trial lists, so it
    # can disagree with the truncated ANOVA view on ragged input — confirm
    # this is intended.
    agreement_rates = [
        calculate_agreement_rate(scenario_trials[s])
        for s in scenarios
    ]
    mean_agreement_rate = np.mean(agreement_rates)

    # Repeatability coefficient (95% of repeat differences < RC)
    rc = 1.96 * np.sqrt(2 * msw) if msw > 0 else 0.0

    # Identify flaky scenarios (high within-variance)
    flaky_threshold = 0.3
    flaky_scenarios = [
        (s, float(std))
        for s, std in zip(scenarios, within_stds)
        if std > flaky_threshold
    ]

    # Per-scenario details
    scenario_details = {}
    for i, s in enumerate(scenarios):
        scenario_details[s] = {
            "trials": scenario_trials[s],
            "mean": float(scenario_means[i]),
            "std": float(within_stds[i]),
            "agreement_rate": agreement_rates[i],
            "is_flaky": within_stds[i] > flaky_threshold,
        }

    return ConsistencyMetrics(
        icc=float(icc),
        flakiness_ratio=float(1 - icc) if not np.isnan(icc) else float('nan'),
        msb=float(msb),
        msw=float(msw),
        mean_within_std=float(mean_within_std),
        mean_agreement_rate=float(mean_agreement_rate),
        repeatability_coefficient=float(rc),
        n_scenarios=n_scenarios,
        n_trials=k,
        n_flaky_scenarios=len(flaky_scenarios),
        flaky_scenarios=flaky_scenarios,
        scenario_details=scenario_details,
    )
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def analyze_results_file(
    filepath: Path,
    metrics: list[str] | None = None,
) -> dict[str, ConsistencyMetrics]:
    """
    Analyze a single results file for multiple metrics.

    Metrics that yield no trial data are silently omitted from the result.

    Args:
        filepath: Path to the results JSON file
        metrics: List of metrics to analyze. Defaults to common metrics.

    Returns:
        Dict mapping metric_name -> ConsistencyMetrics
    """
    selected = metrics if metrics is not None else [
        "root_cause_entity_f1",
        "root_cause_proximity_with_fp_f1",
        "propagation_chain",
    ]

    data = load_results(filepath)

    report: dict[str, ConsistencyMetrics] = {}
    for name in selected:
        trials = extract_trial_scores(data, name)
        # Skip metrics absent from this file rather than reporting NaNs.
        if trials:
            report[name] = compute_icc(trials)
    return report
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def compare_models(
    results_dir: Path,
    model_patterns: list[str],
    metric: str = "root_cause_entity_f1",
) -> dict[str, ConsistencyMetrics]:
    """
    Compare ICC across multiple models.

    For each pattern, the first matching `*.json` file in `results_dir` is
    analyzed; patterns with no match are reported on stderr and skipped.

    Args:
        results_dir: Directory containing results JSON files
        model_patterns: List of model name patterns to match
        metric: The metric to analyze

    Returns:
        Dict mapping model_name -> ConsistencyMetrics
    """
    comparison: dict[str, ConsistencyMetrics] = {}

    for pattern in model_patterns:
        # Find matching file
        candidates = list(results_dir.glob(f"*{pattern}*.json"))
        if not candidates:
            print(f"Warning: No file found for pattern '{pattern}'", file=sys.stderr)
            continue

        chosen = candidates[0]
        print(f"Analyzing: {chosen.name}")

        data = load_results(chosen)
        trials = extract_trial_scores(data, metric)
        if not trials:
            continue

        # Prefer the self-reported agent name; fall back to the file stem.
        label = data.get("agent_name", chosen.stem)
        comparison[label] = compute_icc(trials)

    return comparison
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def print_comparison_table(
    comparison: dict[str, ConsistencyMetrics],
    metric: str,
) -> None:
    """Print a formatted comparison table, best ICC first."""
    rule = "=" * 80
    print(f"\n{rule}")
    print(f"ICC Comparison for metric: {metric}")
    print(f"{rule}\n")

    # Header
    print(f"{'Model':<55} {'ICC':>8} {'Flaky%':>8} {'Std':>8} {'Agree%':>8}")
    print("-" * 91)

    # Sort by ICC descending; NaN ICCs sink to the bottom.
    def rank(item):
        value = item[1].icc
        return -1 if np.isnan(value) else value

    for model, cm in sorted(comparison.items(), key=rank, reverse=True):
        # Truncate model name if too long
        shown = model if len(model) <= 55 else model[:52] + "..."

        icc_text = "N/A" if np.isnan(cm.icc) else f"{cm.icc:.4f}"
        flaky_text = (
            "N/A"
            if np.isnan(cm.flakiness_ratio)
            else f"{cm.flakiness_ratio*100:.1f}%"
        )

        print(
            f"{shown:<55} "
            f"{icc_text:>8} "
            f"{flaky_text:>8} "
            f"{cm.mean_within_std:>8.4f} "
            f"{cm.mean_agreement_rate*100:>7.1f}%"
        )

    print("\nInterpretation:")
    print(" ICC > 0.9: Excellent consistency")
    print(" ICC 0.75-0.9: Good consistency")
    print(" ICC 0.5-0.75: Moderate consistency")
    print(" ICC < 0.5: Poor consistency (high flakiness)")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def print_detailed_report(
    model_name: str,
    metrics_analysis: dict[str, ConsistencyMetrics],
) -> None:
    """Print detailed report for a single model, one section per metric."""
    rule = "=" * 80
    print(f"\n{rule}")
    print(f"Detailed Consistency Report: {model_name}")
    print(f"{rule}\n")

    for metric_name, cm in metrics_analysis.items():
        print(f"\n--- {metric_name} ---")
        print(cm)

        if not cm.flaky_scenarios:
            continue

        # Show the ten worst offenders, highest within-scenario std first.
        print("\nFlaky scenarios (std > 0.3):")
        worst = sorted(cm.flaky_scenarios, key=lambda pair: pair[1], reverse=True)
        for scenario, std in worst[:10]:
            trials = cm.scenario_details.get(scenario, {}).get("trials", [])
            print(f" {scenario}: std={std:.3f}, trials={trials}")
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def main():
    """CLI entry point: run single-file or multi-model consistency analysis.

    Two modes:
      * --results-file: detailed per-metric report for one results file.
      * default: compare the models in --models across --results-dir.
    Optionally writes a JSON summary via --output-json.
    """
    parser = argparse.ArgumentParser(
        description="Calculate ICC and consistency metrics for leaderboard results",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    parser.add_argument(
        "--results-dir",
        type=Path,
        default=Path("leaderboard_results/results"),
        help="Directory containing results JSON files",
    )

    parser.add_argument(
        "--results-file",
        type=Path,
        help="Analyze a single results file",
    )

    parser.add_argument(
        "--models",
        nargs="+",
        default=[
            "react with code_Azure_o4-mini",
            "react with code_Azure_gpt-5.1-2025-11-13",
            "react with code_gcp_gemini-3-pro-preview",
            "react with code_GCP_gemini-2.5-pro",
        ],
        help="Model name patterns to compare",
    )

    parser.add_argument(
        "--metric",
        type=str,
        default="root_cause_entity_f1",
        help="Metric to analyze (default: root_cause_entity_f1)",
    )

    parser.add_argument(
        "--all-metrics",
        action="store_true",
        help="Analyze all common metrics",
    )

    parser.add_argument(
        "--detailed",
        action="store_true",
        help="Show detailed per-scenario breakdown",
    )

    parser.add_argument(
        "--output-json",
        type=Path,
        help="Save results to JSON file",
    )

    args = parser.parse_args()

    # Determine metrics to analyze
    if args.all_metrics:
        metrics = [
            "root_cause_entity_f1",
            "root_cause_entity_precision",
            "root_cause_entity_recall",
            "root_cause_proximity_with_fp_f1",
            "propagation_chain",
            "fault_localization_component_identification",
        ]
    else:
        metrics = [args.metric]

    # Accumulates model -> metric -> metric-summary dicts for --output-json.
    results_to_save = {}

    if args.results_file:
        # Single file analysis
        print(f"Analyzing: {args.results_file}")
        analysis = analyze_results_file(args.results_file, metrics)

        results = load_results(args.results_file)
        model_name = results.get("agent_name", args.results_file.stem)

        print_detailed_report(model_name, analysis)

        # NOTE(review): unlike the comparison branch below, NaN values are
        # stored as-is here; json.dump will emit non-standard `NaN` tokens —
        # confirm downstream consumers accept that.
        results_to_save[model_name] = {
            m: {
                "icc": cm.icc,
                "flakiness_ratio": cm.flakiness_ratio,
                "mean_within_std": cm.mean_within_std,
                "mean_agreement_rate": cm.mean_agreement_rate,
                "n_flaky_scenarios": cm.n_flaky_scenarios,
                "n_scenarios": cm.n_scenarios,
            }
            for m, cm in analysis.items()
        }
    else:
        # Multi-model comparison
        for metric in metrics:
            comparison = compare_models(args.results_dir, args.models, metric)
            print_comparison_table(comparison, metric)

            # Store results (NaN -> None so the JSON stays standard-compliant)
            for model, cm in comparison.items():
                if model not in results_to_save:
                    results_to_save[model] = {}
                results_to_save[model][metric] = {
                    "icc": cm.icc if not np.isnan(cm.icc) else None,
                    "flakiness_ratio": cm.flakiness_ratio if not np.isnan(cm.flakiness_ratio) else None,
                    "mean_within_std": cm.mean_within_std,
                    "mean_agreement_rate": cm.mean_agreement_rate,
                    "n_flaky_scenarios": cm.n_flaky_scenarios,
                    "n_scenarios": cm.n_scenarios,
                }

            if args.detailed:
                for model, cm in comparison.items():
                    print_detailed_report(model, {metric: cm})

    # Save to JSON if requested
    if args.output_json:
        with open(args.output_json, "w") as f:
            json.dump(results_to_save, f, indent=2)
        print(f"\nResults saved to: {args.output_json}")
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
| 513 |
+
|
analysis_src/extract_consistency_data.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Extract consistency (ICC) and performance data for all 'react with code' agents.
|
| 4 |
+
|
| 5 |
+
This script reads directly from the run directories (not JSON result files)
|
| 6 |
+
to ensure all trials are captured.
|
| 7 |
+
|
| 8 |
+
Output is saved to paper_analysis/react with code/resources/figures/consistency/ as CSV files for plotting.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import json
|
| 12 |
+
import sys
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from dataclasses import dataclass, field
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
# Add project root to path
|
| 20 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 21 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 22 |
+
|
| 23 |
+
from src.consistency import (
|
| 24 |
+
compute_icc,
|
| 25 |
+
ConsistencyMetrics,
|
| 26 |
+
)
|
| 27 |
+
from src.utils import (
|
| 28 |
+
get_model_name,
|
| 29 |
+
find_react_with_code_dirs,
|
| 30 |
+
read_judge_outputs_from_dir,
|
| 31 |
+
extract_trial_scores_from_judge_outputs,
|
| 32 |
+
get_runs_stats,
|
| 33 |
+
filter_scenarios_with_min_runs,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Paths
# LEADERBOARD_DIR holds per-agent run directories; RESULTS_JSON_DIR the
# aggregated JSON results; OUTPUT_DIR receives the CSV/JSON artifacts.
LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"
RESULTS_JSON_DIR = LEADERBOARD_DIR / "results"
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "consistency"

# Minimum runs per scenario required for inclusion (ICC needs repeats)
MIN_RUNS_PER_SCENARIO = 3

# Minimum scenarios needed after filtering (must have at least this many with 3+ runs)
MIN_QUALIFYING_SCENARIOS = 20

# Metrics to analyze
METRICS = [
    "root_cause_entity_f1",
    "root_cause_entity_precision",
    "root_cause_entity_recall",
    "root_cause_proximity_with_fp_f1",
    "propagation_chain",
    "fault_localization_component_identification",
]

# Short names for display (falls back to the raw metric name if missing)
METRIC_SHORT_NAMES = {
    "root_cause_entity_f1": "RC Entity F1",
    "root_cause_entity_precision": "RC Entity Prec",
    "root_cause_entity_recall": "RC Entity Rec",
    "root_cause_proximity_with_fp_f1": "RC Proximity F1",
    "propagation_chain": "Prop. Chain",
    "fault_localization_component_identification": "Fault Loc.",
}
|
| 66 |
+
|
| 67 |
+
def extract_all_data() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Extract ICC and performance data for all agents by reading from directories.

    Scans LEADERBOARD_DIR for 'react with code' agent directories, drops
    agents with no judge outputs or with fewer than MIN_QUALIFYING_SCENARIOS
    scenarios having MIN_RUNS_PER_SCENARIO+ runs, then computes per-metric
    ICC and mean performance for the rest.

    Returns:
        - icc_df: ICC scores per model per metric
        - perf_df: Performance averages per model per metric
        - scenario_df: Per-scenario breakdown (root_cause_entity_f1 only)
    """
    agent_dirs = find_react_with_code_dirs(LEADERBOARD_DIR)
    print(f"Found {len(agent_dirs)} 'react with code' agent directories:")
    for d in agent_dirs:
        print(f" - {d.name}")

    icc_records = []
    perf_records = []
    scenario_records = []

    valid_models = []
    skipped_models = []  # (model_name, reason, n_qualifying) triples

    for agent_dir in tqdm(agent_dirs, desc="Reading agent data"):
        model_name = get_model_name(agent_dir.name)

        print(f"\nReading: {agent_dir.name}")
        # Read directly from the run directories (not the result JSONs) so
        # every trial is captured; shape assumed scenario_id -> runs —
        # TODO confirm against read_judge_outputs_from_dir.
        scenario_data = read_judge_outputs_from_dir(agent_dir)

        n_scenarios, min_runs, max_runs, n_qualifying = get_runs_stats(scenario_data, MIN_RUNS_PER_SCENARIO)

        # Gate 1: nothing to analyze at all.
        if n_scenarios == 0:
            print(f" SKIPPING {model_name}: No judge outputs found")
            skipped_models.append((model_name, "No data", 0))
            continue

        # Gate 2: too few scenarios have enough repeat runs for a stable ICC.
        if n_qualifying < MIN_QUALIFYING_SCENARIOS:
            print(f" SKIPPING {model_name}: Only {n_qualifying}/{n_scenarios} scenarios have {MIN_RUNS_PER_SCENARIO}+ runs")
            skipped_models.append((model_name, f"{n_qualifying}/{n_scenarios} qualifying", n_qualifying))
            continue

        # Filter to only include scenarios with enough runs
        scenario_data = filter_scenarios_with_min_runs(scenario_data, MIN_RUNS_PER_SCENARIO)
        n_scenarios_filtered = len(scenario_data)

        print(f" Processing: {model_name} ({n_scenarios_filtered} scenarios with {MIN_RUNS_PER_SCENARIO}+ runs)")
        valid_models.append(model_name)

        for metric in tqdm(METRICS, desc=f" {model_name} metrics", leave=False):
            # Extract trial scores
            scenario_trials = extract_trial_scores_from_judge_outputs(scenario_data, metric)

            if not scenario_trials:
                continue

            # Calculate performance average (pooled over every trial of
            # every scenario, not a per-scenario mean of means)
            all_scores = [s for trials in scenario_trials.values() for s in trials]
            perf_avg = np.mean(all_scores) if all_scores else 0.0

            perf_records.append({
                "model": model_name,
                "metric": METRIC_SHORT_NAMES.get(metric, metric),
                "metric_raw": metric,
                "performance": perf_avg,
            })

            # ICC calculation
            try:
                icc_metrics = compute_icc(scenario_trials)

                # NaN ICC is coerced to 0.0 (and flakiness to 1.0) so the
                # CSV stays numeric for plotting.
                icc_records.append({
                    "model": model_name,
                    "metric": METRIC_SHORT_NAMES.get(metric, metric),
                    "metric_raw": metric,
                    "icc": icc_metrics.icc if not np.isnan(icc_metrics.icc) else 0.0,
                    "flakiness": icc_metrics.flakiness_ratio if not np.isnan(icc_metrics.flakiness_ratio) else 1.0,
                    "within_std": icc_metrics.mean_within_std,
                    "agreement_rate": icc_metrics.mean_agreement_rate,
                    "n_flaky_scenarios": icc_metrics.n_flaky_scenarios,
                    "n_scenarios": icc_metrics.n_scenarios,
                })

                # Per-scenario data (only for root_cause_entity_f1)
                if metric == "root_cause_entity_f1":
                    for scenario_id, details in icc_metrics.scenario_details.items():
                        scenario_records.append({
                            "model": model_name,
                            "scenario": scenario_id,
                            "mean": details["mean"],
                            "std": details["std"],
                            "trials": details["trials"],
                            "is_flaky": details["is_flaky"],
                        })
            except Exception as e:
                # Best-effort per metric: log and keep processing the rest.
                print(f" Error computing ICC for {metric}: {e}")
                continue

    if skipped_models:
        print(f"\n⚠️ Skipped {len(skipped_models)} models:")
        for name, reason, _ in skipped_models:
            print(f" - {name}: {reason}")

    print(f"\n✓ Included {len(valid_models)} models: {valid_models}")

    icc_df = pd.DataFrame(icc_records)
    perf_df = pd.DataFrame(perf_records)
    scenario_df = pd.DataFrame(scenario_records)

    return icc_df, perf_df, scenario_df
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def save_data(icc_df: pd.DataFrame, perf_df: pd.DataFrame, scenario_df: pd.DataFrame):
    """Persist the three extracted tables as CSV plus a small JSON summary.

    Writes icc_data.csv, performance_data.csv, scenario_data.csv and
    analysis_summary.json under OUTPUT_DIR, creating the directory if needed,
    and echoes every written path to stdout.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Map each target file to the frame that should land there; insertion
    # order drives both the writes and the printed listing below.
    targets = {
        OUTPUT_DIR / "icc_data.csv": icc_df,
        OUTPUT_DIR / "performance_data.csv": perf_df,
        OUTPUT_DIR / "scenario_data.csv": scenario_df,
    }
    for path, frame in targets.items():
        frame.to_csv(path, index=False)

    print(f"\nData saved to:")
    for path in targets:
        print(f" - {path}")

    # Also save a summary JSON
    summary = {
        "models": icc_df["model"].unique().tolist(),
        "metrics": icc_df["metric"].unique().tolist(),
        "n_scenarios": int(icc_df["n_scenarios"].max()) if len(icc_df) > 0 else 0,
        "min_runs_required": MIN_RUNS_PER_SCENARIO,
    }

    summary_path = OUTPUT_DIR / "analysis_summary.json"
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f" - {summary_path}")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def print_summary(icc_df: pd.DataFrame, perf_df: pd.DataFrame):
|
| 208 |
+
"""Print summary tables."""
|
| 209 |
+
print("\n" + "="*80)
|
| 210 |
+
print("ICC Summary (root_cause_entity_f1)")
|
| 211 |
+
print("="*80)
|
| 212 |
+
|
| 213 |
+
rc_icc = icc_df[icc_df["metric_raw"] == "root_cause_entity_f1"].copy()
|
| 214 |
+
rc_icc = rc_icc.sort_values("icc", ascending=False)
|
| 215 |
+
|
| 216 |
+
print(f"\n{'Model':<20} {'ICC':>8} {'Flaky%':>8} {'Std':>8} {'Agree%':>8}")
|
| 217 |
+
print("-" * 56)
|
| 218 |
+
for _, row in rc_icc.iterrows():
|
| 219 |
+
print(f"{row['model']:<20} {row['icc']:>8.4f} {row['flakiness']*100:>7.1f}% {row['within_std']:>8.4f} {row['agreement_rate']*100:>7.1f}%")
|
| 220 |
+
|
| 221 |
+
print("\n" + "="*80)
|
| 222 |
+
print("Performance Summary (root_cause_entity_f1)")
|
| 223 |
+
print("="*80)
|
| 224 |
+
|
| 225 |
+
rc_perf = perf_df[perf_df["metric_raw"] == "root_cause_entity_f1"].copy()
|
| 226 |
+
rc_perf = rc_perf.sort_values("performance", ascending=False)
|
| 227 |
+
|
| 228 |
+
print(f"\n{'Model':<20} {'Avg Score':>12}")
|
| 229 |
+
print("-" * 34)
|
| 230 |
+
for _, row in rc_perf.iterrows():
|
| 231 |
+
print(f"{row['model']:<20} {row['performance']:>12.4f}")
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def main():
    """CLI entry point: extract the consistency data, persist it, summarize it."""
    print("Extracting consistency data for 'react with code' agents...")
    print(f"Reading from directories: {LEADERBOARD_DIR}")
    print(f"Output directory: {OUTPUT_DIR}")
    print(f"Minimum runs per scenario: {MIN_RUNS_PER_SCENARIO}")

    icc, perf, scenarios = extract_all_data()

    # Nothing to do when extraction produced no ICC rows.
    if not len(icc):
        print("No data extracted!")
        return

    save_data(icc, perf, scenarios)
    print_summary(icc, perf)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# Allow running this extractor directly as a script.
if __name__ == "__main__":
    main()
|
analysis_src/extract_discovery_trajectory.py
ADDED
|
@@ -0,0 +1,928 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Root Cause Discovery Trajectory Analysis
|
| 4 |
+
|
| 5 |
+
Analyzes how agents discover root cause entities:
|
| 6 |
+
- T_encounter: When GT entity first appears in tool output
|
| 7 |
+
- T_investigate: When agent actively queries GT entity
|
| 8 |
+
- T_assert: When agent asserts GT entity as root cause
|
| 9 |
+
- T_exonerate: When agent dismisses GT entity (if ever)
|
| 10 |
+
- T_recover: When agent corrects after exoneration
|
| 11 |
+
|
| 12 |
+
Metrics computed:
|
| 13 |
+
- Discovery efficiency (how early GT appears)
|
| 14 |
+
- Investigation delay (turns between seeing and investigating)
|
| 15 |
+
- Assertion delay (turns to confirm after investigating)
|
| 16 |
+
- Recovery rate (% of trials with successful recovery)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
import sys
|
| 21 |
+
import re
|
| 22 |
+
import yaml
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from dataclasses import dataclass, field, asdict
|
| 25 |
+
from typing import Optional, List, Dict, Any
|
| 26 |
+
import numpy as np
|
| 27 |
+
import pandas as pd
|
| 28 |
+
import matplotlib.pyplot as plt
|
| 29 |
+
import seaborn as sns
|
| 30 |
+
import plotly.graph_objects as go
|
| 31 |
+
from tqdm import tqdm
|
| 32 |
+
|
| 33 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 34 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 35 |
+
|
| 36 |
+
from src.utils import find_latest_rollout_file
|
| 37 |
+
|
| 38 |
+
from src.model_styles import (
|
| 39 |
+
get_display_name, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, PLOT_PARAMETERS
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Matches fully-qualified K8s resource references of the form
# namespace/Kind/name (case-insensitive); the three groups capture the
# namespace, the resource Kind, and the resource name respectively.
K8S_ENTITY_PATTERN = re.compile(r'([\w-]+)/(Deployment|Service|Pod|ReplicaSet|ResourceQuota|StatefulSet|DaemonSet|Job|CronJob|ConfigMap|Secret|Endpoints|Ingress|PersistentVolumeClaim|PersistentVolume|ServiceAccount|Role|RoleBinding|ClusterRole|ClusterRoleBinding|NetworkPolicy|HorizontalPodAutoscaler|Node)/([\w-]+)', re.IGNORECASE)

def extract_k8s_entities(text: str) -> List[str]:
    """Return every ``namespace/Kind/name`` entity reference found in *text*.

    Matches are returned in order of occurrence, with casing exactly as it
    appeared in the input (the pattern itself is case-insensitive).
    """
    return [
        f"{namespace}/{kind}/{name}"
        for namespace, kind, name in K8S_ENTITY_PATTERN.findall(text)
    ]
|
| 55 |
+
|
| 56 |
+
# Paths
# NOTE(review): PROJECT_ROOT is already assigned identically near the imports
# above; this re-definition is redundant — consider removing one of the two.
PROJECT_ROOT = Path(__file__).parent.parent
LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"      # per-model rollout trees
GT_DIR = PROJECT_ROOT / "data" / "itbench-snapshots"          # ground_truth.yaml per scenario
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "discovery"   # where analysis artifacts are written
|
| 61 |
+
|
| 62 |
+
@dataclass
class GroundTruth:
    """Ground truth root cause entity info for one scenario.

    Built by load_ground_truth() from the scenario's ground_truth.yaml.
    Holds the root-cause group plus enough context (propagation chain,
    per-group filters, full group list) for exploration/coverage analysis.
    """
    scenario: str       # scenario directory name
    entity_name: str    # display name of the root-cause entity
    entity_kind: str    # K8s kind of the root-cause group (e.g. Deployment)
    group_id: str       # id of the group flagged root_cause in the YAML
    filters: List[str]  # regex patterns to match entity
    aliases: List[str]  # related entity group IDs
    propagation_entities: set = field(default_factory=set)  # All entities involved in propagation
    all_entities: list = field(default_factory=list)  # All entities defined in the scenario
    entity_filters: Dict[str, List[str]] = field(default_factory=dict)  # group_id -> filters mapping for all entities
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
class EntityMention:
    """A single mention of the ground-truth entity in the agent's trajectory.

    Recorded by parse_rollout() whenever the entity shows up in a tool
    output, a tool call, plan reasoning, or the final answer command.
    """
    turn: int           # turn number at which the mention occurred
    mention_type: str  # 'encounter', 'investigate', 'assert', 'exonerate'
    context: str  # 'tool_output', 'tool_args', 'reasoning', 'final_output'
    text_snippet: str   # truncated excerpt of the matching text
    sentiment: str  # 'positive', 'negative', 'neutral'
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@dataclass
class TrajectoryAnalysis:
    """Analysis results for a single trial of one model on one scenario.

    Produced by parse_rollout(); model/scenario/trial and the judge-derived
    fields (final_success, root_cause_f1, stage 4) are filled in afterwards
    by analyze_model().
    """
    model: str
    scenario: str
    trial: int
    total_turns: int
    gt_entity: str      # display name of the ground-truth root-cause entity

    # Key timestamps (turn numbers, None if not found)
    t_encounter: Optional[int] = None    # GT first appears in tool output (passive)
    t_investigate: Optional[int] = None  # agent first queries GT in tool args (active)
    t_assert: Optional[int] = None       # agent first asserts GT as root cause
    t_exonerate: Optional[int] = None    # agent first dismisses GT
    t_recover: Optional[int] = None      # assertion after an earlier exoneration

    # Final outcome (from judge scores if available)
    final_success: bool = False  # Did the final answer include GT?
    root_cause_f1: Optional[float] = None

    # Pipeline stage reached (for funnel analysis)
    # 0=none, 1=encounter, 2=investigate, 3=assert, 4=success
    max_stage_reached: int = 0

    # All mentions for detailed analysis
    mentions: List[EntityMention] = field(default_factory=list)

    # Exploration metrics
    total_entities_available: int = 0    # denominator used for exploration (see parse_rollout)
    unique_entities_encountered: int = 0
    unique_entities_investigated: int = 0
    exploration_ratio: float = 0.0  # investigated / available

    # Coverage metrics
    on_chain_investigated: int = 0
    off_chain_investigated: int = 0  # Detoured
    propagation_coverage: float = 0.0  # % of chain entities investigated
    detour_rate: float = 0.0  # off_chain / total_investigated

    # Computed metrics
    discovery_efficiency: Optional[float] = None  # t_encounter / total_turns
    investigation_delay: Optional[int] = None  # t_investigate - t_encounter
    assertion_delay: Optional[int] = None  # t_assert - t_investigate
    had_recovery: bool = False
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def check_entity_match(text: str, entity_info: Dict) -> bool:
    """Return True when *text* mentions the entity described by *entity_info*.

    Matching is case-insensitive substring matching against the entity's
    ``id`` and against each ``filter`` pattern with the regex scaffolding
    (``\\b``, ``-.*``, ``.*``) stripped off. Empty terms never match.
    """
    haystack = text.lower()

    ident = entity_info.get('id', '').lower()
    if ident and ident in haystack:
        return True

    for raw_pattern in entity_info.get('filter', []):
        # Reduce the regex pattern to a plain search term.
        needle = raw_pattern.replace('\\b', '').replace('-.*', '').replace('.*', '')
        if needle and needle.lower() in haystack:
            return True

    return False
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def load_ground_truth(scenario: str) -> Optional[GroundTruth]:
    """Load and parse a scenario's ground_truth.yaml into a GroundTruth.

    Returns None when the YAML file is missing or no group in it is flagged
    ``root_cause``. Besides the root-cause group itself, this collects the
    alias bundles containing it, the fault-propagation chain, and a
    group_id -> search-term mapping for every group in the scenario.
    """
    yaml_path = GT_DIR / scenario / "ground_truth.yaml"
    if not yaml_path.exists():
        return None

    with open(yaml_path) as fh:
        data = yaml.safe_load(fh)

    groups = data.get('groups', [])

    # The root cause is the first group carrying root_cause: true.
    rc_group = next((g for g in groups if g.get('root_cause', False)), None)
    if not rc_group:
        return None
    rc_id = rc_group['id']

    # Entity metadata lives on the first fault record, when present.
    faults = data.get('fault', [])
    entity_info = (faults[0] if faults else {}).get('entity', {})

    # Every alias bundle that includes the root-cause id contributes all
    # of its member ids.
    alias_ids = []
    for bundle in data.get('aliases', []):
        if rc_id in bundle:
            alias_ids.extend(bundle)

    # The propagation chain is the union of all source/target endpoints,
    # plus the root cause itself (normally already present as a source).
    chain = set()
    for hop in data.get('propagations', []):
        for endpoint in ('source', 'target'):
            if endpoint in hop:
                chain.add(hop[endpoint])
    chain.add(rc_id)

    # group_id -> search terms: the declared filters, the id itself, and
    # the display name when it differs from the id.
    filter_map = {}
    for g in groups:
        gid = g.get('id', '')
        declared = g.get('filter', [])
        terms = list(declared) if declared else []
        if gid:
            terms.append(gid)
        display = g.get('name', '')
        if display and display != gid:
            terms.append(display)
        filter_map[gid] = terms

    gt = GroundTruth(
        scenario=scenario,
        entity_name=entity_info.get('name', rc_id),
        entity_kind=rc_group.get('kind', 'Unknown'),
        group_id=rc_id,
        filters=rc_group.get('filter', []),
        aliases=alias_ids,
        propagation_entities=chain,
        entity_filters=filter_map
    )

    # Keep the full group list around for exploration analysis.
    gt.all_entities = groups
    return gt
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def entity_matches(text: str, gt: GroundTruth) -> bool:
    """Check if *text* mentions the ground truth entity.

    Matching is case-insensitive substring matching, tried in order against:
    the entity display name, the group id (with '-' treated as ' ' or
    removed), each filter pattern (regex scaffolding stripped), and each
    alias id (hyphen-insensitive).

    Fix: filter/alias terms that reduce to the empty string are now skipped —
    '' is a substring of every string, so an empty term used to make every
    text match. check_entity_match() and is_entity_on_chain() already guard
    against this; this brings entity_matches() in line with them.
    """
    text_lower = text.lower()

    # Check direct name match
    if gt.entity_name.lower() in text_lower:
        return True

    # Check group_id match, tolerant of '-' vs ' ' vs no separator
    if gt.group_id.lower().replace('-', ' ') in text_lower.replace('-', ' '):
        return True
    if gt.group_id.lower().replace('-', '') in text_lower.replace('-', ''):
        return True

    # Check filter patterns (strip regex scaffolding down to a search term)
    for pattern in gt.filters:
        search_term = pattern.replace('\\b', '').replace('-.*', '').replace('.*', '')
        # Skip empty terms: '' would match any text (see docstring).
        if search_term and search_term.lower() in text_lower:
            return True

    # Check aliases (hyphen-insensitive)
    for alias in gt.aliases:
        alias_term = alias.replace('-', ' ').lower()
        if alias_term and alias_term in text_lower.replace('-', ' '):
            return True

    return False
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def is_entity_on_chain(entity_str: str, gt: GroundTruth) -> Optional[str]:
    """
    Check if an entity string matches any entity in the fault propagation chain.
    Returns the matched group_id if on-chain, None if off-chain.

    entity_str: e.g., "otel-demo/Pod/frontend-abc123" or just "frontend"
    """
    needle = entity_str.lower()

    for group_id in gt.propagation_entities:
        gid_lower = group_id.lower()

        # Substring match in either direction against the group id itself.
        if gid_lower in needle or needle in gid_lower:
            return group_id

        # Otherwise strip the regex scaffolding from each filter pattern and
        # substring-match; terms of <= 2 chars are skipped as too noisy.
        for raw in gt.entity_filters.get(group_id, []):
            term = raw.replace('\\b', '').replace('-.*', '').replace('.*', '').replace('\\', '')
            if len(term) > 2 and term.lower() in needle:
                return group_id

    return None
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def get_entity_group_match(entity_str: str, gt: GroundTruth) -> Optional[str]:
    """
    Check if an entity string matches any entity group in the scenario.
    Returns the matched group_id if found, None otherwise.
    """
    needle = entity_str.lower()

    for group in gt.all_entities:
        gid = group.get('id', '')
        display = group.get('name', '')

        # Group id: substring match in either direction.
        if gid:
            gid_lower = gid.lower()
            if gid_lower in needle or needle in gid_lower:
                return gid

        # Display name: same bidirectional substring check.
        if display:
            disp_lower = display.lower()
            if disp_lower in needle or needle in disp_lower:
                return gid

        # Filter patterns: strip regex scaffolding; skip terms of <= 2 chars.
        for raw in group.get('filter', []):
            term = raw.replace('\\b', '').replace('-.*', '').replace('.*', '').replace('\\', '')
            if len(term) > 2 and term.lower() in needle:
                return gid

    return None
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def classify_sentiment(text: str, gt: GroundTruth) -> str:
    """Classify a mention as 'positive' (asserted as root cause),
    'negative' (exonerated), or 'neutral'.

    Classification looks for assertion/exoneration phrases anywhere in
    *text*; positive phrases win over negative ones when both occur.

    Note: the ``gt`` parameter is kept for interface compatibility with
    callers, but the entity itself is not consulted — the previously
    computed ``entity_term`` local was unused and has been removed.
    """
    text_lower = text.lower()

    # Positive indicators (asserting as root cause)
    positive_patterns = [
        r'root\s*cause',
        r'is\s+the\s+cause',
        r'caused\s+by',
        r'source\s+of\s+(the\s+)?problem',
        r'culprit',
        r'responsible\s+for',
        r'likely\s+cause',
        r'appears\s+to\s+be\s+the\s+issue',
        r'primary\s+issue',
        r'main\s+issue',
    ]

    # Negative indicators (exonerating)
    negative_patterns = [
        r'not\s+the\s+(root\s*)?cause',
        r'ruled\s+out',
        r'is\s+not\s+responsible',
        r'working\s+(correctly|normally|fine)',
        r'healthy',
        r'no\s+issues?\s+(found|detected)',
        r'can\s+be\s+excluded',
        r'unlikely\s+to\s+be',
    ]

    for pattern in positive_patterns:
        if re.search(pattern, text_lower):
            return 'positive'

    for pattern in negative_patterns:
        if re.search(pattern, text_lower):
            return 'negative'

    return 'neutral'
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def get_latest_rollout(trial_dir: Path) -> Optional[Path]:
    """Return the most recently modified rollout-*.jsonl under
    ``trial_dir/sessions`` (searched recursively), or None when the
    sessions directory or any rollout file is missing.

    NOTE(review): analyze_model() uses src.utils.find_latest_rollout_file
    instead of this helper — confirm whether this one is still needed.
    """
    sessions = trial_dir / "sessions"
    if not sessions.exists():
        return None

    candidates = [p for p in sessions.glob("**/rollout-*.jsonl")]
    if not candidates:
        return None

    # Newest by filesystem modification time.
    return max(candidates, key=lambda f: f.stat().st_mtime)
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def get_judge_score(trial_dir: Path) -> Optional[float]:
    """Return ``root_cause_entity_f1`` from the trial's judge_output.json.

    Returns None when the file is missing, unreadable, malformed, or does
    not contain the score.

    Fix: the bare ``except:`` also swallowed SystemExit/KeyboardInterrupt;
    narrowed to ``except Exception`` so only real read/parse/shape errors
    are treated as "no score".
    """
    judge_path = trial_dir / "judge_output.json"
    if not judge_path.exists():
        return None

    try:
        with open(judge_path) as f:
            judge_data = json.load(f)
        return judge_data.get('flat_scores', {}).get('root_cause_entity_f1')
    except Exception:
        # Unreadable / malformed judge output (or unexpected structure,
        # e.g. flat_scores not a dict) -> treat as missing score.
        return None
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def parse_rollout(rollout_path: Path, gt: GroundTruth) -> TrajectoryAnalysis:
    """Parse a rollout JSONL file and extract mentions of the GT entity.

    Scans every JSON line of the rollout:
      - 'turn_context' records advance the turn counter;
      - 'response_item' records are inspected for tool outputs (passive
        encounters), tool calls (active investigations), 'update_plan'
        reasoning (assertions / exonerations), and 'shell' commands that
        write the final answer.

    Returns a TrajectoryAnalysis whose model/scenario/trial fields are left
    for the caller (analyze_model) to fill in; final_success and stage 4
    are likewise set later from the judge score.
    """
    mentions = []
    turn_num = 0        # current turn; advanced by each 'turn_context' record
    total_turns = 0

    # First turn at which each pipeline event was observed (None = never).
    t_encounter = None
    t_investigate = None
    t_assert = None
    t_exonerate = None
    t_recover = None

    # Exploration tracking: unique K8s entities (regex-extracted via
    # extract_k8s_entities) seen passively in tool outputs vs. actively
    # referenced in tool arguments. The "available" universe is unknown with
    # this regex approach, so encountered entities serve as the denominator.
    encountered_entities = set()
    investigated_entities = set()

    # Track which entity groups were investigated (on-chain vs off-chain)
    on_chain_groups_investigated = set()
    off_chain_groups_investigated = set()
    all_groups_investigated = set()

    with open(rollout_path) as f:
        for line in f:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed lines rather than aborting the whole trial.
                continue

            if obj.get('type') == 'turn_context':
                turn_num += 1
                total_turns = turn_num

            if obj.get('type') != 'response_item':
                continue

            payload = obj.get('payload', {})

            # Check tool outputs (encounter: the entity showed up passively)
            if payload.get('type') == 'function_call_output':
                output = str(payload.get('output', ''))

                # Check for root cause match
                if entity_matches(output, gt):
                    sentiment = classify_sentiment(output, gt)
                    mentions.append(EntityMention(
                        turn=turn_num,
                        mention_type='encounter',
                        context='tool_output',
                        text_snippet=output[:200],
                        sentiment=sentiment
                    ))
                    if t_encounter is None:
                        t_encounter = turn_num

                # Broad exploration check using Regex
                found_entities = extract_k8s_entities(output)
                for entity in found_entities:
                    encountered_entities.add(entity)

            # Check tool arguments (investigate: the agent actively queried)
            if payload.get('type') == 'function_call':
                args = payload.get('arguments', {})
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except:
                        # NOTE(review): bare except — any failure to parse the
                        # argument string falls back to wrapping it verbatim.
                        args = {'raw': args}
                args_str = json.dumps(args)

                # Root cause check
                if entity_matches(args_str, gt):
                    mentions.append(EntityMention(
                        turn=turn_num,
                        mention_type='investigate',
                        context='tool_args',
                        text_snippet=args_str[:200],
                        sentiment='neutral'
                    ))
                    if t_investigate is None:
                        t_investigate = turn_num

                # Broad exploration check using Regex
                found_entities = extract_k8s_entities(args_str)
                for entity in found_entities:
                    investigated_entities.add(entity)

                    # Classify as on-chain (in the fault propagation chain)
                    # or off-chain (a detour to some other scenario entity).
                    on_chain_group = is_entity_on_chain(entity, gt)
                    if on_chain_group:
                        on_chain_groups_investigated.add(on_chain_group)
                        all_groups_investigated.add(on_chain_group)
                    else:
                        # Check if it matches any entity in scenario at all
                        any_group = get_entity_group_match(entity, gt)
                        if any_group:
                            off_chain_groups_investigated.add(any_group)
                            all_groups_investigated.add(any_group)

                # Check update_plan for assertions/reasoning: the plan
                # explanation's sentiment decides assert vs exonerate.
                if payload.get('name') == 'update_plan':
                    explanation = args.get('explanation', '')
                    if entity_matches(explanation, gt):
                        sentiment = classify_sentiment(explanation, gt)
                        mention_type = 'assert' if sentiment == 'positive' else ('exonerate' if sentiment == 'negative' else 'investigate')
                        mentions.append(EntityMention(
                            turn=turn_num,
                            mention_type=mention_type,
                            context='reasoning',
                            text_snippet=explanation[:200],
                            sentiment=sentiment
                        ))

                        if mention_type == 'assert' and t_assert is None:
                            t_assert = turn_num
                        elif mention_type == 'exonerate' and t_exonerate is None:
                            t_exonerate = turn_num

                # Check shell commands for final output
                if payload.get('name') == 'shell':
                    cmd = args.get('command', [])
                    cmd_str = ' '.join(cmd) if isinstance(cmd, list) else str(cmd)

                    # Look for output generation with root cause assertions
                    if ('output.json' in cmd_str or 'root_cause' in cmd_str.lower()) and entity_matches(cmd_str, gt):
                        sentiment = classify_sentiment(cmd_str, gt)
                        if sentiment == 'positive' or 'root_cause' in cmd_str.lower():
                            mentions.append(EntityMention(
                                turn=turn_num,
                                mention_type='assert',
                                context='final_output',
                                text_snippet=cmd_str[:300],
                                sentiment='positive'
                            ))
                            if t_assert is None:
                                t_assert = turn_num

    # Check for recovery (exoneration followed by assertion)
    had_recovery = False
    if t_exonerate is not None and t_assert is not None and t_exonerate < t_assert:
        had_recovery = True
        t_recover = t_assert

    # Compute metrics (None when the prerequisite events never happened)
    discovery_efficiency = t_encounter / total_turns if t_encounter and total_turns > 0 else None
    investigation_delay = t_investigate - t_encounter if t_investigate and t_encounter else None
    assertion_delay = t_assert - t_investigate if t_assert and t_investigate else None

    # Compute max stage reached (without final success - that comes from judge)
    # 0=none, 1=encounter, 2=investigate, 3=assert
    max_stage = 0
    if t_encounter is not None:
        max_stage = 1
    if t_investigate is not None:
        max_stage = 2
    if t_assert is not None:
        max_stage = 3

    # Exploration metrics: with the regex approach the true universe of
    # entities is unknown, so "encountered" doubles as the available set and
    # the exploration ratio answers "what fraction of seen entities did the
    # agent actually query?".
    num_encountered = len(encountered_entities)
    num_investigated = len(investigated_entities)

    expl_ratio = num_investigated / num_encountered if num_encountered > 0 else 0.0

    # Coverage metrics: on-chain (fault propagation) vs off-chain (detoured)
    n_on_chain = len(on_chain_groups_investigated)
    n_off_chain = len(off_chain_groups_investigated)
    total_investigated_groups = len(all_groups_investigated)

    # Propagation coverage: what % of the fault propagation chain was investigated?
    n_propagation_entities = len(gt.propagation_entities)
    prop_coverage = n_on_chain / n_propagation_entities if n_propagation_entities > 0 else 0.0

    # Detour rate: what % of investigated entities were off-chain (not in fault propagation)?
    det_rate = n_off_chain / total_investigated_groups if total_investigated_groups > 0 else 0.0

    return TrajectoryAnalysis(
        model="",  # Set by caller
        scenario="",  # Set by caller
        trial=0,  # Set by caller
        total_turns=total_turns,
        gt_entity=gt.entity_name,
        t_encounter=t_encounter,
        t_investigate=t_investigate,
        t_assert=t_assert,
        t_exonerate=t_exonerate,
        t_recover=t_recover,
        max_stage_reached=max_stage,
        mentions=mentions,
        total_entities_available=num_encountered,  # Using encountered as the "available" set
        unique_entities_encountered=num_encountered,
        unique_entities_investigated=num_investigated,
        exploration_ratio=expl_ratio,
        # Coverage metrics (on-chain vs off-chain)
        on_chain_investigated=n_on_chain,
        off_chain_investigated=n_off_chain,
        propagation_coverage=prop_coverage,
        detour_rate=det_rate,
        # Computed metrics
        discovery_efficiency=discovery_efficiency,
        investigation_delay=investigation_delay,
        assertion_delay=assertion_delay,
        had_recovery=had_recovery
    )
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
def analyze_model(model_dir: Path, gt_cache: Dict[str, GroundTruth]) -> List[TrajectoryAnalysis]:
    """Parse every trial rollout for one model directory.

    Walks Scenario-* subdirectories, then numeric trial subdirectories,
    tagging each parsed trajectory with model/scenario/trial metadata and
    the judge's F1 score. Trials without ground truth or a rollout file are
    skipped; unexpected parse errors are reported and the trial is skipped.
    """
    analyses: List[TrajectoryAnalysis] = []
    model_name = model_dir.name.replace("react with code_", "").split("_07ccdb1")[0]

    scenario_dirs = [p for p in sorted(model_dir.iterdir())
                     if p.is_dir() and p.name.startswith("Scenario-")]
    for scenario_dir in tqdm(scenario_dirs, desc=f" {model_name} scenarios"):
        scenario = scenario_dir.name
        gt = gt_cache.get(scenario)
        if gt is None:
            # No ground truth loaded for this scenario; nothing to score against.
            continue

        trial_dirs = [p for p in sorted(scenario_dir.iterdir())
                      if p.is_dir() and p.name.isdigit()]
        for trial_dir in tqdm(trial_dirs, desc=f" {scenario} trials"):
            rollout_path = find_latest_rollout_file(trial_dir)
            if rollout_path is None:
                continue

            trial_num = int(trial_dir.name)
            try:
                analysis = parse_rollout(rollout_path, gt)
                analysis.model = model_name
                analysis.scenario = scenario
                analysis.trial = trial_num

                # The judge's score decides whether the trial counts as a
                # final success (stage 4), regardless of earlier stages.
                f1_score = get_judge_score(trial_dir)
                analysis.root_cause_f1 = f1_score
                if f1_score is not None and f1_score > 0:
                    analysis.final_success = True
                    analysis.max_stage_reached = 4  # Success!

                analyses.append(analysis)
            except Exception as e:
                print(f"Error processing {model_name}/{scenario}/{trial_num}: {e}")

    return analyses
|
| 645 |
+
|
| 646 |
+
|
| 647 |
+
def plot_pipeline_funnel(summary_df: pd.DataFrame) -> None:
    """
    Figure 1: Stacked bar showing where trials drop off in the pipeline.

    Pipeline stages:
    - Encounter: GT entity appears in tool OUTPUT (passive - agent didn't ask for it)
    - Investigate: GT entity appears in tool ARGUMENTS (active - agent explicitly queried it)
    - Assert: Agent declares GT as root cause
    - Success: Judge confirms correct answer

    Side effects: shows the figure and saves fig_conversion_funnel.png into
    OUTPUT_DIR. Expects the per-model columns produced by extract_all_data()
    (n_trials, n_stage_*_*, encounter_rate, success_rate).
    """
    # Filter out mistral (no data) and prepare data
    data = summary_df[summary_df['encounter_rate'] > 0].copy()
    data['model_clean'] = data['model'].apply(get_display_name)
    # Sort ascending so the best model ends up at the top of the horizontal bars.
    data = data.sort_values('success_rate', ascending=True)

    # Stack: none, encounter_only, investigate_only, assert_only, success
    # Normalize to percentages
    n_trials = data['n_trials']

    none_pct = data['n_stage_0_none'] / n_trials * 100
    enc_pct = data['n_stage_1_encounter_only'] / n_trials * 100
    inv_pct = data['n_stage_2_investigate_only'] / n_trials * 100
    ass_pct = data['n_stage_3_assert_only'] / n_trials * 100
    suc_pct = data['n_stage_4_success'] / n_trials * 100

    n_models = len(data)
    y = np.arange(n_models)
    bar_height = 0.7

    plt.rcParams.update(PLOT_PARAMETERS)

    STAGE_COLORS = {
        'none': '#d73027',  # Red - never encountered GT
        'encounter': '#fc8d59',  # Orange - saw but didn't investigate
        'investigate': '#fee08b',  # Yellow - investigated but didn't assert
        'assert': '#d9ef8b',  # Light green - asserted but wrong final answer
        'success': '#1a9850',  # Green - success
    }

    # Create figure sized to fill half column with legend
    fig, ax = plt.subplots(figsize=(DOUBLE_COLUMN_WIDTH, 2.5))

    # Plot stacked bars with GT prefix labels; each segment starts where the
    # previous stages end (cumulative `left` offsets).
    ax.barh(y, none_pct, height=bar_height, label='RC never seen', color=STAGE_COLORS['none'],
            edgecolor='white', linewidth=0.3)
    ax.barh(y, enc_pct, height=bar_height, left=none_pct, label='RC seen, not queried',
            color=STAGE_COLORS['encounter'], edgecolor='white', linewidth=0.3)
    ax.barh(y, inv_pct, height=bar_height, left=none_pct + enc_pct, label='RC queried, not asserted',
            color=STAGE_COLORS['investigate'], edgecolor='white', linewidth=0.3)
    ax.barh(y, ass_pct, height=bar_height, left=none_pct + enc_pct + inv_pct, label='RC asserted, not in output',
            color=STAGE_COLORS['assert'], edgecolor='white', linewidth=0.3)
    ax.barh(y, suc_pct, height=bar_height, left=none_pct + enc_pct + inv_pct + ass_pct, label='RC asserted, in output',
            color=STAGE_COLORS['success'], edgecolor='white', linewidth=0.3)

    # Add percentage labels to each stack
    min_pct_threshold = 4  # Only show labels for segments >= 4% (thinner slices can't fit text)
    label_fontsize = MIN_FONT_SIZE - 3

    for i, model_idx in enumerate(y):
        # (segment width, x position of the segment's midpoint) per stage.
        segments = [
            (none_pct.iloc[i], none_pct.iloc[i] / 2),
            (enc_pct.iloc[i], none_pct.iloc[i] + enc_pct.iloc[i] / 2),
            (inv_pct.iloc[i], none_pct.iloc[i] + enc_pct.iloc[i] + inv_pct.iloc[i] / 2),
            (ass_pct.iloc[i], none_pct.iloc[i] + enc_pct.iloc[i] + inv_pct.iloc[i] + ass_pct.iloc[i] / 2),
            (suc_pct.iloc[i], none_pct.iloc[i] + enc_pct.iloc[i] + inv_pct.iloc[i] + ass_pct.iloc[i] + suc_pct.iloc[i] / 2)
        ]

        for pct, x_pos in segments:
            if pct >= min_pct_threshold:
                ax.text(x_pos, model_idx, f'{pct:.0f}%',
                        ha='center', va='center', fontsize=label_fontsize,
                        color='black', weight='bold')

    ax.set_yticks(y)
    ax.set_yticklabels(data['model_clean'], fontsize=MIN_FONT_SIZE)
    ax.set_xlabel('Trials (%)', fontsize=MIN_FONT_SIZE)
    ax.set_xlim(0, 100)
    ax.set_ylim(-0.5, n_models - 0.5)
    ax.tick_params(axis='x', labelsize=MIN_FONT_SIZE)

    # Legend below the plot - 2 columns, positioned below x-axis label
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.18), ncol=2,
              frameon=False, fontsize=MIN_FONT_SIZE, columnspacing=0.8,
              handletextpad=0.3, handlelength=1.0)

    # Tight margins - more bottom space for legend
    fig.subplots_adjust(left=0.28, right=0.99, top=0.99, bottom=0.38)

    plt.title("Root Cause Entity Discovery Funnel")
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_conversion_funnel.png")
    plt.close(fig)
    print("Saved: fig_conversion_funnel.png")
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def extract_all_data():
    """Load ground truths, analyze every agent model, and write CSV outputs.

    Side effects:
        - Creates OUTPUT_DIR if needed.
        - Writes discovery_trials.csv (one row per trial) and
          discovery_summary.csv (one row per model) into OUTPUT_DIR.
        - Prints progress and per-model summary statistics.

    Returns:
        (summary_df, trial_df, trials_n): per-model summary DataFrame,
        per-trial DataFrame, and the total number of analyzed trials.
    """
    # Create output directory
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Load all ground truths, keyed by scenario directory name.
    print("\nLoading ground truth data...")
    gt_cache = {}
    scenario_dirs = [d for d in GT_DIR.iterdir() if d.is_dir() and d.name.startswith("Scenario-")]
    for scenario_dir in tqdm(scenario_dirs, desc="Loading ground truths"):
        gt = load_ground_truth(scenario_dir.name)
        if gt:
            gt_cache[scenario_dir.name] = gt
    print(f"Loaded {len(gt_cache)} ground truth files")

    # Find react with code agents
    model_dirs = [d for d in LEADERBOARD_DIR.iterdir()
                  if d.is_dir() and d.name.startswith("react with code_")]
    print(f"Found {len(model_dirs)} agent models")

    # Analyze each model
    all_results = []
    for model_dir in tqdm(model_dirs, desc="Analyzing models"):
        model_name = model_dir.name.replace("react with code_", "").split("_07ccdb1")[0]
        print(f"\nAnalyzing {model_name}...")

        results = analyze_model(model_dir, gt_cache)
        all_results.extend(results)

        # Summary stats (per-model progress report only; CSVs come later)
        if results:
            encounters = [r for r in results if r.t_encounter is not None]
            asserts = [r for r in results if r.t_assert is not None]
            recoveries = [r for r in results if r.had_recovery]

            print(f" Trials: {len(results)}")
            print(f" Encounters: {len(encounters)} ({100*len(encounters)/len(results):.1f}%)")
            print(f" Assertions: {len(asserts)} ({100*len(asserts)/len(results):.1f}%)")
            print(f" Recoveries: {len(recoveries)} ({100*len(recoveries)/len(results):.1f}%)")

    # Convert to DataFrame
    print("\n" + "=" * 60)
    print("Generating output files...")

    # Summary per trial — column names here define the discovery_trials.csv schema.
    trial_data = []
    for r in all_results:
        trial_data.append({
            'model': r.model,
            'scenario': r.scenario,
            'trial': r.trial,
            'total_turns': r.total_turns,
            'gt_entity': r.gt_entity,
            't_encounter': r.t_encounter,
            't_investigate': r.t_investigate,
            't_assert': r.t_assert,
            't_exonerate': r.t_exonerate,
            't_recover': r.t_recover,
            'max_stage_reached': r.max_stage_reached,
            'final_success': r.final_success,
            'root_cause_f1': r.root_cause_f1,
            'discovery_efficiency': r.discovery_efficiency,
            'investigation_delay': r.investigation_delay,
            'assertion_delay': r.assertion_delay,
            'had_recovery': r.had_recovery,
            'n_mentions': len(r.mentions),
            'total_entities_available': r.total_entities_available,
            'unique_entities_encountered': r.unique_entities_encountered,
            'unique_entities_investigated': r.unique_entities_investigated,
            'exploration_ratio': r.exploration_ratio,
            # Coverage metrics (on-chain vs off-chain)
            'on_chain_investigated': r.on_chain_investigated,
            'off_chain_investigated': r.off_chain_investigated,
            'propagation_coverage': r.propagation_coverage,
            'detour_rate': r.detour_rate
        })

    trial_df = pd.DataFrame(trial_data)
    trial_df.to_csv(OUTPUT_DIR / "discovery_trials.csv", index=False)
    print(f"Saved: {OUTPUT_DIR / 'discovery_trials.csv'}")

    # Summary per model — aggregates trial rows into discovery_summary.csv.
    model_summary = []
    for model in trial_df['model'].unique():
        model_data = trial_df[trial_df['model'] == model]
        n_total = len(model_data)

        # Funnel stages: count trials reaching each stage
        # Stage 0: none, 1: encounter, 2: investigate, 3: assert, 4: success
        stage_counts = model_data['max_stage_reached'].value_counts().to_dict()

        # Cumulative: how many reached AT LEAST this stage
        n_encounter = len(model_data[model_data['max_stage_reached'] >= 1])
        n_investigate = len(model_data[model_data['max_stage_reached'] >= 2])
        n_assert = len(model_data[model_data['max_stage_reached'] >= 3])
        n_success = len(model_data[model_data['max_stage_reached'] >= 4])

        # Filter to trials where we found something
        with_encounter = model_data[model_data['t_encounter'].notna()]
        with_assert = model_data[model_data['t_assert'].notna()]
        with_recovery = model_data[model_data['had_recovery'] == True]
        with_success = model_data[model_data['final_success'] == True]

        model_summary.append({
            'model': model,
            'n_trials': n_total,
            'n_scenarios': model_data['scenario'].nunique(),
            # Funnel rates (cumulative, relative to total trials)
            'encounter_rate': n_encounter / n_total if n_total > 0 else 0,
            'investigate_rate': n_investigate / n_total if n_total > 0 else 0,
            'assertion_rate': n_assert / n_total if n_total > 0 else 0,
            'success_rate': n_success / n_total if n_total > 0 else 0,
            # Conversion rate: given encounter, did model declare it as root cause?
            # This handles multi-root-cause scenarios better
            'conversion_rate': n_success / n_encounter if n_encounter > 0 else 0,
            # Drop-off at each stage (exclusive counts)
            'n_stage_0_none': stage_counts.get(0, 0),
            'n_stage_1_encounter_only': stage_counts.get(1, 0),
            'n_stage_2_investigate_only': stage_counts.get(2, 0),
            'n_stage_3_assert_only': stage_counts.get(3, 0),
            'n_stage_4_success': stage_counts.get(4, 0),
            # Legacy metrics
            'recovery_rate': len(with_recovery) / n_total if n_total > 0 else 0,
            'avg_t_encounter': with_encounter['t_encounter'].mean() if len(with_encounter) > 0 else None,
            'avg_t_assert': with_assert['t_assert'].mean() if len(with_assert) > 0 else None,
            'avg_total_turns': model_data['total_turns'].mean(),
            'avg_discovery_efficiency': with_encounter['discovery_efficiency'].mean() if len(with_encounter) > 0 else None,
            'avg_investigation_delay': with_encounter['investigation_delay'].mean() if len(with_encounter) > 0 else None,
            'avg_assertion_delay': with_assert['assertion_delay'].mean() if len(with_assert) > 0 else None,
            'avg_f1': with_success['root_cause_f1'].mean() if len(with_success) > 0 else None,
            'avg_exploration_ratio': model_data['exploration_ratio'].mean(),
            'avg_entities_investigated': model_data['unique_entities_investigated'].mean(),
            # Coverage metrics (fault propagation coverage)
            'avg_on_chain_investigated': model_data['on_chain_investigated'].mean(),
            'avg_off_chain_investigated': model_data['off_chain_investigated'].mean(),
            'avg_propagation_coverage': model_data['propagation_coverage'].mean(),
            'avg_detour_rate': model_data['detour_rate'].mean()
        })

    summary_df = pd.DataFrame(model_summary)
    summary_df.to_csv(OUTPUT_DIR / "discovery_summary.csv", index=False)
    print(f"Saved: {OUTPUT_DIR / 'discovery_summary.csv'}")

    trials_n = len(all_results)

    return summary_df, trial_df, trials_n
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
def main():
    """Run the full extraction pipeline and print the funnel/drop-off tables."""
    banner = "=" * 60
    print(banner)
    print("Root Cause Discovery Trajectory Analysis")
    print(banner)

    summary_df, trial_df, trials_n = extract_all_data()

    wide = "=" * 80
    rule = "-" * 80

    # Cumulative funnel: share of trials reaching at least each stage.
    print("\n" + wide)
    print("Discovery Pipeline Funnel:")
    print(rule)
    print(f"{'Model':<25} {'Trials':>7} {'Encntr':>8} {'Invest':>8} {'Assert':>8} {'Success':>8}")
    print(rule)
    for row in summary_df.itertuples():
        print(f"{row.model:<25} {row.n_trials:>7} "
              f"{row.encounter_rate*100:>7.0f}% "
              f"{row.investigate_rate*100:>7.0f}% "
              f"{row.assertion_rate*100:>7.0f}% "
              f"{row.success_rate*100:>7.0f}%")

    # Exclusive counts: the stage at which each trial stopped.
    print("\n" + wide)
    print("Drop-off Analysis (where trials stopped):")
    print(rule)
    print(f"{'Model':<25} {'None':>7} {'Enc→X':>7} {'Inv→X':>7} {'Ass→X':>7} {'✓':>7}")
    print(rule)
    for row in summary_df.itertuples():
        print(f"{row.model:<25} "
              f"{row.n_stage_0_none:>7} "
              f"{row.n_stage_1_encounter_only:>7} "
              f"{row.n_stage_2_investigate_only:>7} "
              f"{row.n_stage_3_assert_only:>7} "
              f"{row.n_stage_4_success:>7}")

    print(f"\nTotal trials analyzed: {trials_n}")
    print(f"\nOutput saved to: {OUTPUT_DIR}")
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
# Script entry point: run the full trajectory analysis when executed directly.
if __name__ == "__main__":
    main()
|
| 928 |
+
|
analysis_src/extract_exploration.py
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Exploration Breadth Analysis by Diagnosis Correctness
|
| 4 |
+
|
| 5 |
+
Creates a plot comparing exploration breadth between:
|
| 6 |
+
- Correct diagnoses (recall > 0, i.e., root_cause_f1 > 0)
|
| 7 |
+
- Incorrect diagnoses (recall = 0, i.e., root_cause_f1 == 0)
|
| 8 |
+
|
| 9 |
+
Uses semantic entity grouping to avoid counting "frontend deployment" and
|
| 10 |
+
"frontend service" as separate entities.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import sys
|
| 15 |
+
import re
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Optional, List, Dict, Set, Tuple
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import numpy as np
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
import seaborn as sns
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
# Publication settings - ICML half column
HALF_COLUMN_WIDTH = 3.25  # inches
# NOTE(review): this local MIN_FONT_SIZE seeds the rcParams below, but it is
# later shadowed by the MIN_FONT_SIZE imported from src.model_styles — confirm
# the two values agree.
MIN_FONT_SIZE = 8

# Global matplotlib defaults for publication-quality output.
plt.rcParams.update({
    'font.size': MIN_FONT_SIZE,
    'font.family': 'serif',
    'axes.labelsize': MIN_FONT_SIZE,
    'axes.titlesize': MIN_FONT_SIZE + 1,
    'xtick.labelsize': MIN_FONT_SIZE,
    'ytick.labelsize': MIN_FONT_SIZE,
    'legend.fontsize': MIN_FONT_SIZE,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'axes.spines.top': False,
    'axes.spines.right': False,
})
|
| 43 |
+
|
| 44 |
+
# Paths
PROJECT_ROOT = Path(__file__).parent.parent
# Make project-local packages (src.*) importable when run as a script.
sys.path.insert(0, str(PROJECT_ROOT))

from src.utils import find_latest_rollout_file

from src.model_styles import (
    get_display_name, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, PLOT_PARAMETERS
)

# Paths
LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"  # per-model trial rollouts
GT_DIR = PROJECT_ROOT / "data" / "itbench-snapshots"  # ground-truth scenario snapshots
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "discovery"  # analysis outputs

# Regex for K8s entities: matches "namespace/Kind/name" triples for the listed
# resource kinds (case-insensitive). Capture groups: (namespace, kind, name).
K8S_ENTITY_PATTERN = re.compile(
    r'([\w-]+)/(Deployment|Service|Pod|ReplicaSet|ResourceQuota|StatefulSet|'
    r'DaemonSet|Job|CronJob|ConfigMap|Secret|Endpoints|Ingress|'
    r'PersistentVolumeClaim|PersistentVolume|ServiceAccount|Role|RoleBinding|'
    r'ClusterRole|ClusterRoleBinding|NetworkPolicy|HorizontalPodAutoscaler|'
    r'Node|Schedule|NetworkChaos|StressChaos|PodChaos)/([\w-]+)',
    re.IGNORECASE
)
|
| 68 |
+
|
| 69 |
+
# Service name normalization patterns.
# Maps lowercase name variants (as seen in rollouts) to one canonical logical
# service name, so e.g. "checkoutservice" and "checkout" count as one entity.
# Used by normalize_entity_to_logical: exact match first, then longest-prefix.
SERVICE_NORMALIZATIONS = {
    # Map specific variations to canonical names
    'frontend-proxy': 'frontend-proxy',
    'frontendproxy': 'frontend-proxy',
    'frontend': 'frontend',
    'checkout': 'checkout',
    'checkoutservice': 'checkout',
    'cart': 'cart',
    'cartservice': 'cart',
    'shipping': 'shipping',
    'shippingservice': 'shipping',
    'product-catalog': 'product-catalog',
    'productcatalog': 'product-catalog',
    'productcatalogservice': 'product-catalog',
    'recommendation': 'recommendation',
    'recommendationservice': 'recommendation',
    'email': 'email',
    'emailservice': 'email',
    'payment': 'payment',
    'paymentservice': 'payment',
    'currency': 'currency',
    'currencyservice': 'currency',
    'ad': 'ad',
    'adservice': 'ad',
    'fraud-detection': 'fraud-detection',
    'frauddetection': 'fraud-detection',
    'frauddetectionservice': 'fraud-detection',
    'load-generator': 'load-generator',
    'loadgenerator': 'load-generator',
    'flagd': 'flagd',
    'otel-collector': 'otel-collector',
    'valkey': 'valkey',
    'valkey-cart': 'valkey',  # valkey instance for cart
    'redis': 'valkey',  # alias
    'kafka': 'kafka',
    'quote': 'quote',
    'quoteservice': 'quote',
    'accounting': 'accounting',
    'accountingservice': 'accounting',
    'otel-demo': 'otel-demo',  # namespace
    'imageprovider': 'imageprovider',
    'flagdui': 'flagdui',
    'opensearch': 'opensearch',
    'grafana': 'grafana',
    'jaeger': 'jaeger',
    'prometheus': 'prometheus',
}
|
| 117 |
+
|
| 118 |
+
# Model name mapping for cleaner labels.
# NOTE(review): not referenced in the functions visible in this chunk —
# presumably consumed by plotting code further down; verify before removing.
MODEL_NAMES = {
    'Azure_gpt-5.1-2025-11-13': 'GPT-5.1',
    'Azure_o4-mini': 'o4-mini',
    'GCP_gemini-2.5-pro': 'Gemini 2.5 Pro',
    'gcp_gemini-3-pro-preview': 'Gemini 3 Pro',
    'gemini-3-pro-preview': 'Gemini 3 Pro',
    'gemini-3-flash-preview': 'Gemini 3 Flash',
    'moonshotai_kimi-k2-thinking': 'Kimi K2',
    'aws_claude-opus-4-5': 'Claude Opus 4.5',
    'openai_gpt-oss-120b': 'GPT-OSS-120B',
}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def normalize_entity_to_logical(entity: str,
                                normalizations: Optional[Dict[str, str]] = None) -> str:
    """
    Normalize an entity to its logical/canonical service name.

    e.g., "otel-demo/Deployment/frontend-abc123" -> "frontend"
         "otel-demo/Service/checkout" -> "checkout"
         "chaos-mesh/NetworkChaos/xyz" -> "chaos:networkchaos"

    Args:
        entity: Entity path of the form "namespace/Kind/name" (any case).
        normalizations: Optional mapping of lowercase name variants to
            canonical names; defaults to SERVICE_NORMALIZATIONS. Exposed so
            other scenarios can supply their own alias table.

    Returns:
        Canonical lowercase service name, a "chaos:<kind>" tag for
        chaos-mesh entities, or the cleaned name when no alias matches.
    """
    if normalizations is None:
        normalizations = SERVICE_NORMALIZATIONS

    parts = entity.lower().split('/')

    # Chaos-mesh resources are grouped by chaos kind rather than service name.
    # (Rewritten from the original accidental conditional expression
    # `if 'chaos-mesh' in parts[0] if parts else '':` — str.split always
    # returns at least one element, so a plain guard suffices.)
    if 'chaos-mesh' in parts[0]:
        if len(parts) >= 2:
            return f"chaos:{parts[1]}"
        return "chaos"

    # Get the name part: third component of "ns/Kind/name" when present,
    # otherwise the last component. (The original's `else: return
    # entity.lower()` branch was unreachable and has been removed.)
    name = parts[2] if len(parts) >= 3 else parts[-1]

    # Strip pod suffixes (e.g., frontend-5d4f6b7c8d-xyz9a -> frontend):
    # ReplicaSets append a hex hash and Pods a further 5-char suffix.
    name = re.sub(r'-[a-f0-9]{8,10}-[a-z0-9]{5}$', '', name)  # Pod suffix (RS hash + Pod hash)
    name = re.sub(r'-[a-f0-9]{8,10}$', '', name)  # ReplicaSet suffix only (hex hash)

    # Also strip numeric suffixes like -1, -2 from entity names
    name = re.sub(r'-\d+$', '', name)

    # First check for exact match (most reliable)
    if name in normalizations:
        return normalizations[name]

    # Try matching with service name variations.
    # Sort by length descending so longer patterns match first
    # (frontend-proxy before frontend).
    for pattern in sorted(normalizations.keys(), key=len, reverse=True):
        canonical = normalizations[pattern]
        if name == pattern:
            return canonical
        # e.g., "checkoutservice" starts with "checkout"
        if name.startswith(pattern) and (
            len(name) == len(pattern) or
            name[len(pattern):].startswith('service') or
            name[len(pattern):].startswith('-')
        ):
            return canonical

    # Fallback: return cleaned name
    return name
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def extract_k8s_entities(text: str) -> List[str]:
    """Return every "namespace/Kind/name" entity reference found in *text*.

    Matches come from K8S_ENTITY_PATTERN in order of appearance; duplicates
    are preserved.
    """
    return [f"{ns}/{kind}/{res}"
            for ns, kind, res in K8S_ENTITY_PATTERN.findall(text)]
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def extract_logical_entities(text: str) -> Set[str]:
    """Extract K8s entities from *text* and collapse them to logical names."""
    return set(map(normalize_entity_to_logical, extract_k8s_entities(text)))
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def get_latest_rollout(trial_dir: Path) -> Optional[Path]:
    """Return the most recently modified rollout-*.jsonl under trial_dir/sessions.

    Returns None when the sessions directory is missing or contains no
    rollout files.
    """
    sessions = trial_dir / "sessions"
    if sessions.exists():
        candidates = list(sessions.glob("**/rollout-*.jsonl"))
        if candidates:
            # Newest by filesystem modification time.
            return max(candidates, key=lambda p: p.stat().st_mtime)
    return None
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def get_judge_f1(trial_dir: Path) -> float:
    """Return root_cause_entity_f1 from the trial's judge output.

    Reads <trial_dir>/judge_output.json and looks up
    flat_scores.root_cause_entity_f1. Returns 0.0 when the file is missing,
    unreadable, malformed, or the score is absent or null.
    """
    judge_path = trial_dir / "judge_output.json"
    if not judge_path.exists():
        return 0.0

    try:
        with open(judge_path) as f:
            judge_data = json.load(f)
        # `or 0.0` also maps an explicit null score to 0.0.
        return judge_data.get('flat_scores', {}).get('root_cause_entity_f1', 0.0) or 0.0
    except (OSError, ValueError, AttributeError, TypeError):
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate. ValueError covers json.JSONDecodeError;
        # AttributeError/TypeError cover non-dict JSON documents.
        return 0.0
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def count_semantic_entities_investigated(rollout_path: Path) -> int:
    """
    Count unique semantic entity groups investigated in a rollout.

    Uses normalization to group similar entities:
    - otel-demo/Deployment/frontend and otel-demo/Service/frontend -> 1 entity ("frontend")
    - otel-demo/Pod/frontend-abc123 and otel-demo/Pod/frontend-xyz456 -> 1 entity ("frontend")

    Only tool-call *arguments* are scanned (active querying), not tool
    outputs. Malformed JSONL lines and non-object records are skipped.
    """
    investigated_logical = set()

    with open(rollout_path) as f:
        for line in f:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue

            # Guard against valid-JSON lines that are not objects (e.g. arrays),
            # which would otherwise raise AttributeError on .get().
            if not isinstance(obj, dict) or obj.get('type') != 'response_item':
                continue

            payload = obj.get('payload', {})
            if not isinstance(payload, dict):
                continue

            # Check tool arguments (investigation = active querying)
            if payload.get('type') == 'function_call':
                args = payload.get('arguments', {})
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except json.JSONDecodeError:
                        # Narrowed from a bare `except:`; keep the raw string
                        # so entity mentions inside it are still matched.
                        args = {'raw': args}
                args_str = json.dumps(args)

                # Extract and normalize entities
                logical_entities = extract_logical_entities(args_str)
                investigated_logical.update(logical_entities)

    return len(investigated_logical)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def analyze_all_trials() -> pd.DataFrame:
    """
    Analyze all trials from react with code agents.

    Walks every "react with code" agent directory under LEADERBOARD_DIR,
    scoring each trial via the judge output and counting the semantic
    entity groups its rollout investigated.

    Returns:
        DataFrame with columns: model, scenario, trial, root_cause_f1,
        is_correct, semantic_entities_investigated.
    """
    records = []

    # Only "react with code" agents participate in this analysis.
    agent_dirs = [p for p in LEADERBOARD_DIR.iterdir()
                  if p.is_dir() and p.name.startswith("react with code_")]
    print(f"Found {len(agent_dirs)} agent models")

    for agent_dir in tqdm(agent_dirs, desc="Processing models"):
        # Strip the agent prefix and the trailing commit-hash suffix.
        model_name = agent_dir.name.replace("react with code_", "").split("_07ccdb1")[0]
        print(f"Processing {model_name}...")

        scenario_dirs = sorted(
            p for p in agent_dir.iterdir()
            if p.is_dir() and p.name.startswith("Scenario-")
        )
        for scenario_dir in tqdm(scenario_dirs, desc=f" {model_name} scenarios", leave=False):
            scenario = scenario_dir.name

            trial_dirs = sorted(
                p for p in scenario_dir.iterdir()
                if p.is_dir() and p.name.isdigit()
            )
            for trial_dir in tqdm(trial_dirs, desc=f" {scenario} trials", leave=False):
                trial_num = int(trial_dir.name)
                rollout_path = get_latest_rollout(trial_dir)
                if rollout_path is None:
                    continue  # trial has no recorded rollout

                try:
                    f1_score = get_judge_f1(trial_dir)
                    semantic_count = count_semantic_entities_investigated(rollout_path)
                    records.append({
                        'model': model_name,
                        'scenario': scenario,
                        'trial': trial_num,
                        'root_cause_f1': f1_score,
                        'is_correct': f1_score > 0,
                        'semantic_entities_investigated': semantic_count
                    })
                except Exception as e:
                    print(f" Error processing {model_name}/{scenario}/{trial_num}: {e}")

    return pd.DataFrame(records)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
def clean_model_name(name: str) -> str:
    """Map a raw model identifier to its display name, falling back to the raw name."""
    try:
        return MODEL_NAMES[name]
    except KeyError:
        return name
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def plot_exploration_by_correctness(df: pd.DataFrame):
    """
    Plot comparing exploration breadth between correct and incorrect diagnoses.
    Creates a grouped bar chart or box plot.

    Produces three figures in OUTPUT_DIR (each as .pdf and .png):
      1. Per-model grouped horizontal bars of mean entities investigated.
      2. Per-model box plots split by correctness.
      3. Overall violin plot, plus a Mann-Whitney U test printed to stdout.

    NOTE(review): mutates the caller's `df` in place by adding the
    'correctness' and 'model_clean' columns — confirm callers tolerate this.

    Returns:
        DataFrame of per-model correct/incorrect means, stds and counts
        (only models that have at least one trial of each outcome).
    """
    # Aggregate by model and correctness
    agg = df.groupby(['model', 'is_correct']).agg({
        'semantic_entities_investigated': ['mean', 'std', 'count']
    }).reset_index()
    # Flatten the MultiIndex columns produced by the nested agg spec.
    agg.columns = ['model', 'is_correct', 'mean_entities', 'std_entities', 'n_trials']

    # Pivot for easier plotting
    correct_df = agg[agg['is_correct'] == True].set_index('model')
    incorrect_df = agg[agg['is_correct'] == False].set_index('model')

    # Get all models that have both correct and incorrect trials
    models_both = set(correct_df.index) & set(incorrect_df.index)

    # Create comparison data
    comparison_data = []
    for model in models_both:
        comparison_data.append({
            'model': model,
            'model_clean': clean_model_name(model),
            'correct_mean': correct_df.loc[model, 'mean_entities'],
            'correct_std': correct_df.loc[model, 'std_entities'],
            'correct_n': correct_df.loc[model, 'n_trials'],
            'incorrect_mean': incorrect_df.loc[model, 'mean_entities'],
            'incorrect_std': incorrect_df.loc[model, 'std_entities'],
            'incorrect_n': incorrect_df.loc[model, 'n_trials'],
        })

    comp_df = pd.DataFrame(comparison_data)
    comp_df = comp_df.sort_values('correct_mean', ascending=True)

    # === Figure 1: Grouped bar chart ===
    fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH, 3.0))

    y = np.arange(len(comp_df))
    bar_height = 0.35

    # Incorrect (red) and Correct (green) bars
    bars_incorrect = ax.barh(y - bar_height/2, comp_df['incorrect_mean'],
                             height=bar_height, label='Incorrect (recall=0)',
                             color='#d62728', edgecolor='black', linewidth=0.3, alpha=0.8)
    bars_correct = ax.barh(y + bar_height/2, comp_df['correct_mean'],
                           height=bar_height, label='Correct (recall>0)',
                           color='#2ca02c', edgecolor='black', linewidth=0.3, alpha=0.8)

    ax.set_yticks(y)
    ax.set_yticklabels(comp_df['model_clean'])
    ax.set_xlabel('Avg. Semantic Entity Groups Investigated')

    # Add value labels just past the end of each bar, color-matched.
    for i, (bar_i, bar_c) in enumerate(zip(bars_incorrect, bars_correct)):
        # Incorrect
        ax.text(bar_i.get_width() + 0.1, bar_i.get_y() + bar_i.get_height()/2,
                f'{bar_i.get_width():.1f}', va='center', ha='left',
                fontsize=MIN_FONT_SIZE - 1, color='#d62728')
        # Correct
        ax.text(bar_c.get_width() + 0.1, bar_c.get_y() + bar_c.get_height()/2,
                f'{bar_c.get_width():.1f}', va='center', ha='left',
                fontsize=MIN_FONT_SIZE - 1, color='#2ca02c')

    ax.legend(loc='lower right', frameon=False, fontsize=MIN_FONT_SIZE)

    plt.tight_layout()
    fig.savefig(OUTPUT_DIR / "fig_exploration_by_correctness.pdf")
    fig.savefig(OUTPUT_DIR / "fig_exploration_by_correctness.png")
    plt.close(fig)
    print(f"Saved: fig_exploration_by_correctness.pdf/png")

    # === Figure 2: Box plot distribution ===
    fig2, ax2 = plt.subplots(figsize=(HALF_COLUMN_WIDTH * 1.5, 3.5))

    # Prepare data for box plot (adds columns to df in place — see docstring)
    df['correctness'] = df['is_correct'].map({True: 'Correct\n(recall>0)', False: 'Incorrect\n(recall=0)'})
    df['model_clean'] = df['model'].apply(clean_model_name)

    # Order models by overall median exploration
    model_order = df.groupby('model_clean')['semantic_entities_investigated'].median().sort_values().index.tolist()

    # Create box plot with hue
    sns.boxplot(data=df, x='model_clean', y='semantic_entities_investigated',
                hue='correctness', order=model_order, ax=ax2,
                palette={'Correct\n(recall>0)': '#2ca02c', 'Incorrect\n(recall=0)': '#d62728'},
                linewidth=0.5, fliersize=2)

    ax2.set_xlabel('')
    ax2.set_ylabel('Semantic Entity Groups Investigated')
    ax2.tick_params(axis='x', rotation=45)
    ax2.legend(title='', loc='upper left', frameon=False, fontsize=MIN_FONT_SIZE)

    plt.tight_layout()
    fig2.savefig(OUTPUT_DIR / "fig_exploration_by_correctness_boxplot.pdf")
    fig2.savefig(OUTPUT_DIR / "fig_exploration_by_correctness_boxplot.png")
    plt.close(fig2)
    print(f"Saved: fig_exploration_by_correctness_boxplot.pdf/png")

    # === Figure 3: Aggregated across all models ===
    fig3, ax3 = plt.subplots(figsize=(HALF_COLUMN_WIDTH * 0.8, 2.5))

    correct_all = df[df['is_correct'] == True]['semantic_entities_investigated']
    incorrect_all = df[df['is_correct'] == False]['semantic_entities_investigated']

    # Violin plot for overall distribution (incorrect left, correct right)
    parts = ax3.violinplot([incorrect_all, correct_all], positions=[0, 1],
                           showmeans=True, showmedians=True)

    # Color the violins
    colors = ['#d62728', '#2ca02c']
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor(colors[i])
        pc.set_alpha(0.7)

    # Style the other elements (mean/median/extrema lines)
    for partname in ['cmeans', 'cmedians', 'cbars', 'cmins', 'cmaxes']:
        if partname in parts:
            parts[partname].set_edgecolor('black')
            parts[partname].set_linewidth(0.5)

    ax3.set_xticks([0, 1])
    ax3.set_xticklabels(['Incorrect\n(recall=0)', 'Correct\n(recall>0)'])
    ax3.set_ylabel('Semantic Entities Investigated')

    # Add mean values as text
    ax3.text(0, incorrect_all.mean() + 0.5, f'μ={incorrect_all.mean():.1f}',
             ha='center', fontsize=MIN_FONT_SIZE, color='#d62728')
    ax3.text(1, correct_all.mean() + 0.5, f'μ={correct_all.mean():.1f}',
             ha='center', fontsize=MIN_FONT_SIZE, color='#2ca02c')

    # Add n counts near the bottom of the axis
    ax3.text(0, ax3.get_ylim()[0] + 0.5, f'n={len(incorrect_all)}',
             ha='center', fontsize=MIN_FONT_SIZE - 1)
    ax3.text(1, ax3.get_ylim()[0] + 0.5, f'n={len(correct_all)}',
             ha='center', fontsize=MIN_FONT_SIZE - 1)

    plt.tight_layout()
    fig3.savefig(OUTPUT_DIR / "fig_exploration_overall_correctness.pdf")
    fig3.savefig(OUTPUT_DIR / "fig_exploration_overall_correctness.png")
    plt.close(fig3)
    print(f"Saved: fig_exploration_overall_correctness.pdf/png")

    # Print statistics
    print("\n" + "=" * 60)
    print("Exploration Breadth by Diagnosis Correctness")
    print("=" * 60)
    print(f"\nOverall Statistics:")
    print(f" Correct diagnoses (n={len(correct_all)}): mean={correct_all.mean():.2f}, median={correct_all.median():.1f}")
    print(f" Incorrect diagnoses (n={len(incorrect_all)}): mean={incorrect_all.mean():.2f}, median={incorrect_all.median():.1f}")

    # Statistical test: are the two exploration distributions different?
    from scipy import stats
    stat, pvalue = stats.mannwhitneyu(correct_all, incorrect_all, alternative='two-sided')
    print(f"\n Mann-Whitney U test: U={stat:.0f}, p={pvalue:.4f}")

    print(f"\nPer-Model Comparison:")
    print(f"{'Model':<20} {'Correct':>12} {'Incorrect':>12} {'Diff':>8}")
    print("-" * 55)
    for _, row in comp_df.sort_values('correct_mean', ascending=False).iterrows():
        diff = row['correct_mean'] - row['incorrect_mean']
        print(f"{row['model_clean']:<20} {row['correct_mean']:>10.1f} (n={int(row['correct_n'])}) "
              f"{row['incorrect_mean']:>10.1f} (n={int(row['incorrect_n'])}) {diff:>+7.1f}")

    return comp_df
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def plot_success_by_exploration_bins(df: pd.DataFrame):
    """
    Plot showing success rate as a function of exploration breadth.
    This shows a clear dose-response relationship.

    Saves two figures to OUTPUT_DIR (.pdf and .png each): a standalone
    success-rate-by-bin bar chart, and a combined two-panel figure with
    the bar chart plus a violin plot of exploration by outcome.

    NOTE(review): mutates the caller's `df` in place by adding the
    'exploration_bin' column.
    """
    # Create exploration bins.
    # NOTE(review): pd.cut is right-inclusive and excludes the left edge by
    # default, so trials with 0 entities fall into no bin and are dropped
    # from the bar chart — confirm this is intended.
    bins = [0, 2, 4, 6, 8, 10, 100]
    labels = ['0-2', '3-4', '5-6', '7-8', '9-10', '11+']
    df['exploration_bin'] = pd.cut(df['semantic_entities_investigated'],
                                   bins=bins, labels=labels)

    # Calculate success rate per bin (bins with no trials are skipped)
    bin_stats = []
    for label in labels:
        subset = df[df['exploration_bin'] == label]
        if len(subset) > 0:
            success_rate = (subset['root_cause_f1'] > 0).mean() * 100
            bin_stats.append({
                'bin': label,
                'success_rate': success_rate,
                'n': len(subset)
            })

    stats_df = pd.DataFrame(bin_stats)

    # Create figure
    fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH, 2.5))

    x = np.arange(len(stats_df))
    bars = ax.bar(x, stats_df['success_rate'],
                  color='#4a90d9', edgecolor='black', linewidth=0.5)

    ax.set_xticks(x)
    ax.set_xticklabels(stats_df['bin'])
    ax.set_xlabel('Semantic Entities Investigated')
    ax.set_ylabel('Correct Diagnosis Rate (%)')

    # Add value labels on bars: percentage above, sample size inside the bar
    for i, (bar, row) in enumerate(zip(bars, stats_df.itertuples())):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2, height + 1,
                f'{height:.0f}%', ha='center', va='bottom',
                fontsize=MIN_FONT_SIZE)
        ax.text(bar.get_x() + bar.get_width()/2, 2,
                f'n={row.n}', ha='center', va='bottom',
                fontsize=MIN_FONT_SIZE - 1, color='white')

    ax.set_ylim(0, 60)

    plt.tight_layout()
    fig.savefig(OUTPUT_DIR / "fig_exploration_success_rate.pdf")
    fig.savefig(OUTPUT_DIR / "fig_exploration_success_rate.png")
    plt.close(fig)
    print(f"Saved: fig_exploration_success_rate.pdf/png")

    # Also create a combined figure with both views
    fig2, axes = plt.subplots(1, 2, figsize=(HALF_COLUMN_WIDTH * 2 + 0.3, 2.5))

    # Left: Success rate by exploration bins
    ax1 = axes[0]
    bars1 = ax1.bar(x, stats_df['success_rate'],
                    color='#4a90d9', edgecolor='black', linewidth=0.5)
    ax1.set_xticks(x)
    ax1.set_xticklabels(stats_df['bin'])
    ax1.set_xlabel('Entities Investigated')
    ax1.set_ylabel('Correct Diagnosis Rate (%)')
    ax1.set_title('(a) Success vs Exploration', fontsize=MIN_FONT_SIZE + 1)
    for bar, row in zip(bars1, stats_df.itertuples()):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                 f'{bar.get_height():.0f}%', ha='center', va='bottom',
                 fontsize=MIN_FONT_SIZE - 1)
    ax1.set_ylim(0, 60)

    # Right: Exploration distribution by correctness (violin)
    ax2 = axes[1]
    correct = df[df['is_correct'] == True]['semantic_entities_investigated']
    incorrect = df[df['is_correct'] == False]['semantic_entities_investigated']

    parts = ax2.violinplot([incorrect, correct], positions=[0, 1],
                           showmeans=True, showmedians=True)
    colors = ['#d62728', '#2ca02c']
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor(colors[i])
        pc.set_alpha(0.7)
    # Style the mean/median/extrema lines of the violins
    for partname in ['cmeans', 'cmedians', 'cbars', 'cmins', 'cmaxes']:
        if partname in parts:
            parts[partname].set_edgecolor('black')
            parts[partname].set_linewidth(0.5)

    ax2.set_xticks([0, 1])
    ax2.set_xticklabels(['Incorrect', 'Correct'])
    ax2.set_ylabel('Entities Investigated')
    ax2.set_title('(b) Exploration by Outcome', fontsize=MIN_FONT_SIZE + 1)
    ax2.text(0, incorrect.mean() + 1, f'μ={incorrect.mean():.1f}',
             ha='center', fontsize=MIN_FONT_SIZE - 1, color='#d62728')
    ax2.text(1, correct.mean() + 1, f'μ={correct.mean():.1f}',
             ha='center', fontsize=MIN_FONT_SIZE - 1, color='#2ca02c')

    plt.tight_layout()
    fig2.savefig(OUTPUT_DIR / "fig_exploration_combined.pdf")
    fig2.savefig(OUTPUT_DIR / "fig_exploration_combined.png")
    plt.close(fig2)
    print(f"Saved: fig_exploration_combined.pdf/png")
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
def main():
    """Entry point: load (or extract) per-trial data and render all figures.

    Extraction walks every rollout file and is slow, so results are cached
    as a CSV under OUTPUT_DIR and reused on subsequent runs.
    """
    banner = "=" * 60
    print(banner)
    print("Exploration Breadth by Diagnosis Correctness Analysis")
    print(banner)

    # Check if we can use cached data or need to re-extract
    cache_path = OUTPUT_DIR / "exploration_by_correctness.csv"
    if not cache_path.exists():
        print("\nExtracting data from rollout files (this may take a while)...")
        df = analyze_all_trials()
        df.to_csv(cache_path, index=False)
        print(f"Saved cache to: {cache_path}")
    else:
        print(f"\nLoading cached data from {cache_path}")
        df = pd.read_csv(cache_path)

    print(f"\nLoaded {len(df)} trials from {df['model'].nunique()} models")

    # Generate plots
    print("\nGenerating figures...")
    plot_exploration_by_correctness(df)
    plot_success_by_exploration_bins(df)  # NEW: dose-response plot

    print(f"\nDone! Figures saved to: {OUTPUT_DIR}")
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
# Run the full analysis when this file is executed as a script.
if __name__ == "__main__":
    main()
|
| 623 |
+
|
analysis_src/extract_inference_data.py
ADDED
|
@@ -0,0 +1,595 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Extract inference request and token usage data for all 'react with code' agents.
|
| 4 |
+
|
| 5 |
+
This script reads session.jsonl files to count inference requests and estimate token usage.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
import ast
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 21 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 22 |
+
|
| 23 |
+
from src.utils import (
|
| 24 |
+
get_model_name,
|
| 25 |
+
find_react_with_code_dirs,
|
| 26 |
+
get_runs_stats,
|
| 27 |
+
filter_scenarios_with_min_runs,
|
| 28 |
+
find_latest_rollout_file
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
from src.model_styles import (
|
| 32 |
+
get_model_style, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, get_color_palette, PLOT_PARAMETERS
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# Paths
|
| 36 |
+
LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"
|
| 37 |
+
RESULTS_JSON_DIR = LEADERBOARD_DIR / "results"
|
| 38 |
+
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "inferences"
|
| 39 |
+
|
| 40 |
+
# Minimum runs per scenario required
|
| 41 |
+
MIN_RUNS_PER_SCENARIO = 3
|
| 42 |
+
MIN_QUALIFYING_SCENARIOS = 20
|
| 43 |
+
|
| 44 |
+
# Token estimation factor (chars per token)
|
| 45 |
+
CHARS_PER_TOKEN = 4
|
| 46 |
+
|
| 47 |
+
def extract_tokens_from_rollout(rollout_file: Path) -> dict:
    """
    Extract token counts and tool usage from a rollout file.

    Counts:
    - INPUT: system prompt + user messages + tool outputs
    - OUTPUT: assistant messages + tool call arguments
    - TOOLS: counts by tool name, including code execution

    Token counts are estimated from character counts via CHARS_PER_TOKEN.
    Returns None when the rollout file cannot be read or a record has an
    unexpected shape.
    """
    system_prompt_chars = 0
    user_input_chars = 0
    assistant_output_chars = 0
    tool_call_chars = 0
    tool_output_chars = 0

    assistant_msg_count = 0
    tool_call_count = 0
    tool_counts = {}  # tool_name -> count
    code_execution_count = 0  # Specifically track code/python execution

    # Tool names that indicate code execution
    CODE_TOOLS = ['execute_python', 'run_python', 'python', 'execute_code',
                  'run_code', 'shell', 'bash', 'terminal', 'exec']

    try:
        with open(rollout_file) as f:
            for line in f:
                try:
                    d = json.loads(line)
                    msg_type = d.get('type', '')
                    payload = d.get('payload', {})

                    if msg_type == 'session_meta':
                        # System prompt
                        instructions = payload.get('instructions', '')
                        system_prompt_chars += len(str(instructions))

                    elif msg_type == 'response_item':
                        item_type = payload.get('type', '')
                        role = payload.get('role', '')

                        if item_type == 'message':
                            content = payload.get('content', [])
                            if isinstance(content, list):
                                text = ' '.join([
                                    c.get('text', '') if isinstance(c, dict) else str(c)
                                    for c in content
                                ])
                            else:
                                text = str(content)

                            if role == 'user':
                                user_input_chars += len(text)
                            elif role == 'assistant':
                                assistant_output_chars += len(text)
                                assistant_msg_count += 1

                        elif item_type == 'function_call':
                            # Tool call (model output)
                            name = payload.get('name', '')
                            arguments = payload.get('arguments', '')
                            tool_call_chars += len(str(name)) + len(str(arguments))
                            tool_call_count += 1

                            # Track tool usage
                            tool_counts[name] = tool_counts.get(name, 0) + 1

                            # A call is code execution when the tool name matches a
                            # known code tool OR the arguments look like Python
                            # source. Count each call at most ONCE — the previous
                            # version double-counted python-named code tools by
                            # incrementing again inside the name-match branch.
                            name_lower = name.lower()
                            args_str = str(arguments).lower()
                            if (any(code_tool in name_lower for code_tool in CODE_TOOLS)
                                    or 'def ' in args_str or 'import ' in args_str):
                                code_execution_count += 1

                        elif item_type == 'function_call_output':
                            # Tool output (fed back to the model as input)
                            output = payload.get('output', '')
                            tool_output_chars += len(str(output))

                except json.JSONDecodeError:
                    continue
    except Exception:
        # Unreadable file or unexpected record shape: signal "no data".
        return None

    # INPUT = system + user + tool outputs (fed back to model)
    input_chars = system_prompt_chars + user_input_chars + tool_output_chars
    # OUTPUT = assistant responses + tool call arguments
    output_chars = assistant_output_chars + tool_call_chars

    return {
        'system_prompt_chars': system_prompt_chars,
        'user_input_chars': user_input_chars,
        'assistant_output_chars': assistant_output_chars,
        'tool_call_chars': tool_call_chars,
        'tool_output_chars': tool_output_chars,
        'input_chars': input_chars,
        'output_chars': output_chars,
        'input_tokens': input_chars // CHARS_PER_TOKEN,
        'output_tokens': output_chars // CHARS_PER_TOKEN,
        'assistant_msg_count': assistant_msg_count,
        'tool_call_count': tool_call_count,
        'tool_counts': tool_counts,
        'code_execution_count': code_execution_count,
    }
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def extract_session_stats(session_file: Path) -> dict:
    """
    Extract inference stats from session.jsonl and rollout files.

    Uses the latest rollout file for accurate token counting.

    Token sources, in order of preference:
      1. Real API usage from traces/stdout.log (``turn.completed`` records,
         present for OpenAI models).
      2. Character-based estimates from the latest rollout file.

    Returns None when the session file is missing or unreadable.
    """
    if not session_file.exists():
        return None

    trial_dir = session_file.parent

    # Count inference requests from session.jsonl: one per response_item
    inference_count = 0
    try:
        with open(session_file) as f:
            for line in f:
                try:
                    d = json.loads(line)
                    if d.get('type') == 'response_item':
                        inference_count += 1
                except json.JSONDecodeError:
                    continue
    except Exception as e:
        print(f" Warning: Error reading {session_file}: {e}")
        return None

    # First check stdout.log for real token counts (OpenAI models)
    stdout_log = trial_dir / "traces" / "stdout.log"
    has_real_tokens = False
    input_tokens = 0
    output_tokens = 0
    cached_input_tokens = 0

    if stdout_log.exists():
        try:
            with open(stdout_log) as f:
                for line in f:
                    try:
                        d = json.loads(line)
                        if d.get('type') == 'turn.completed':
                            usage = d.get('usage', {})
                            input_tokens = usage.get('input_tokens', 0)
                            output_tokens = usage.get('output_tokens', 0)
                            cached_input_tokens = usage.get('cached_input_tokens', 0)
                            # NOTE(review): stops at the FIRST turn.completed with
                            # nonzero usage — confirm whether later turns should be
                            # summed instead of ignored.
                            if input_tokens > 0 or output_tokens > 0:
                                has_real_tokens = True
                                break
                    except json.JSONDecodeError:
                        continue
        except Exception:
            # Best effort: fall back to rollout-based estimates below.
            pass

    # Extract from latest rollout file for tokens (if needed) and tool counts
    tool_call_count = 0
    tool_counts = {}
    code_execution_count = 0

    latest_rollout = find_latest_rollout_file(trial_dir)
    if latest_rollout:
        rollout_stats = extract_tokens_from_rollout(latest_rollout)
        if rollout_stats:
            # Use rollout tokens if no real API token data
            if not has_real_tokens:
                input_tokens = rollout_stats['input_tokens']
                output_tokens = rollout_stats['output_tokens']

            # Always use rollout for tool counts
            tool_call_count = rollout_stats['tool_call_count']
            tool_counts = rollout_stats['tool_counts']
            code_execution_count = rollout_stats['code_execution_count']

    return {
        'inference_count': inference_count,
        'input_tokens': input_tokens,
        'cached_input_tokens': cached_input_tokens,
        'output_tokens': output_tokens,
        'total_tokens': input_tokens + output_tokens,
        'has_real_tokens': has_real_tokens,
        'tool_call_count': tool_call_count,
        'tool_counts': tool_counts,
        'code_execution_count': code_execution_count,
    }
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def read_agent_stats(agent_dir: Path) -> dict[str, list[dict]]:
    """
    Collect per-trial session statistics for every scenario of one agent.

    Returns:
        Dict mapping scenario_id -> list of stats dicts (one per trial).
        Scenarios with no readable trials are omitted entirely.
    """
    per_scenario: dict[str, list[dict]] = {}

    scenario_dirs = (
        child for child in agent_dir.iterdir()
        if child.is_dir() and child.name.startswith("Scenario")
    )
    for scenario_dir in scenario_dirs:
        trial_stats = []
        for trial_dir in sorted(scenario_dir.iterdir()):
            if trial_dir.is_dir():
                stats = extract_session_stats(trial_dir / "session.jsonl")
                if stats:
                    trial_stats.append(stats)
        if trial_stats:
            per_scenario[scenario_dir.name] = trial_stats

    return per_scenario
|
| 268 |
+
|
| 269 |
+
def load_performance_data() -> pd.DataFrame:
    """Load per-model performance (root-cause F1) from the consistency analysis.

    Returns an empty DataFrame when the consistency results have not been
    generated yet.
    """
    perf_file = PROJECT_ROOT / "data" / "output" / "consistency" / "performance_data.csv"
    if not perf_file.exists():
        return pd.DataFrame()

    frame = pd.read_csv(perf_file)
    f1_rows = frame[frame["metric_raw"] == "root_cause_entity_f1"]
    return f1_rows[["model", "performance"]]
|
| 276 |
+
|
| 277 |
+
def extract_all_data() -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Extract inference data for all agents.

    Walks every 'react with code' agent directory, reads per-trial session
    stats, and aggregates token/tool-usage statistics. Agents with no data
    or too few qualifying scenarios are skipped with a console message.

    Returns:
        - summary_df: Aggregated stats per model
        - detail_df: Per-scenario stats
    """
    agent_dirs = find_react_with_code_dirs(LEADERBOARD_DIR)
    print(f"Found {len(agent_dirs)} 'react with code' agent directories")

    summary_records = []
    detail_records = []

    for agent_dir in tqdm(agent_dirs, desc="Processing agents"):
        model_name = get_model_name(agent_dir.name)

        print(f"\nReading: {agent_dir.name}")
        scenario_data = read_agent_stats(agent_dir)

        # min_runs/max_runs are computed but intentionally unused here;
        # only the scenario counts drive the skip decisions below.
        n_scenarios, min_runs, max_runs, n_qualifying = get_runs_stats(scenario_data, MIN_RUNS_PER_SCENARIO)

        if n_scenarios == 0:
            print(f"  SKIPPING {model_name}: No session data found")
            continue

        if n_qualifying < MIN_QUALIFYING_SCENARIOS:
            print(f"  SKIPPING {model_name}: Only {n_qualifying}/{n_scenarios} scenarios have {MIN_RUNS_PER_SCENARIO}+ runs")
            continue

        # Filter scenarios: keep only those with the required number of runs.
        scenario_data = filter_scenarios_with_min_runs(scenario_data, MIN_RUNS_PER_SCENARIO)
        n_scenarios_filtered = len(scenario_data)

        print(f"  Processing: {model_name} ({n_scenarios_filtered} scenarios)")

        # Aggregate across all scenarios and trials for this model.
        all_inference_counts = []
        all_input_tokens = []
        all_output_tokens = []
        all_total_tokens = []
        all_cached_tokens = []
        all_tool_call_counts = []
        all_code_execution_counts = []
        aggregated_tool_counts = {}

        for scenario_id, trials in tqdm(scenario_data.items(), desc=f"  {model_name} scenarios", leave=False):
            for trial in trials:
                all_inference_counts.append(trial['inference_count'])
                all_input_tokens.append(trial['input_tokens'])
                all_output_tokens.append(trial['output_tokens'])
                all_total_tokens.append(trial['total_tokens'])
                # .get with default 0: older session files may lack these keys.
                all_cached_tokens.append(trial.get('cached_input_tokens', 0))
                all_tool_call_counts.append(trial.get('tool_call_count', 0))
                all_code_execution_counts.append(trial.get('code_execution_count', 0))

                # Aggregate tool counts across all trials of this model.
                for tool_name, count in trial.get('tool_counts', {}).items():
                    aggregated_tool_counts[tool_name] = aggregated_tool_counts.get(tool_name, 0) + count

                detail_records.append({
                    'model': model_name,
                    'scenario': scenario_id,
                    'inference_count': trial['inference_count'],
                    'input_tokens': trial['input_tokens'],
                    'cached_input_tokens': trial.get('cached_input_tokens', 0),
                    'output_tokens': trial['output_tokens'],
                    'total_tokens': trial['total_tokens'],
                    'tool_call_count': trial.get('tool_call_count', 0),
                    'code_execution_count': trial.get('code_execution_count', 0),
                })

        # Summary stats (one row per model).
        summary_records.append({
            'model': model_name,
            'n_scenarios': n_scenarios_filtered,
            'n_trials': len(all_inference_counts),
            'avg_inference_count': np.mean(all_inference_counts),
            'std_inference_count': np.std(all_inference_counts),
            'avg_input_tokens': np.mean(all_input_tokens),
            'avg_cached_tokens': np.mean(all_cached_tokens),
            'avg_output_tokens': np.mean(all_output_tokens),
            'avg_total_tokens': np.mean(all_total_tokens),
            'total_inference_count': sum(all_inference_counts),
            'total_tokens': sum(all_total_tokens),
            'avg_tool_call_count': np.mean(all_tool_call_counts) if all_tool_call_counts else 0,
            'total_tool_calls': sum(all_tool_call_counts),
            'avg_code_execution_count': np.mean(all_code_execution_counts) if all_code_execution_counts else 0,
            'total_code_executions': sum(all_code_execution_counts),
            # Keep only the 10 most-used tools, sorted by descending count.
            'top_tools': dict(sorted(aggregated_tool_counts.items(), key=lambda x: -x[1])[:10]),
        })

    summary_df = pd.DataFrame(summary_records)
    detail_df = pd.DataFrame(detail_records)

    # Merge with performance data (left join keeps models without scores).
    perf_df = load_performance_data()
    if len(perf_df) > 0:
        summary_df = pd.merge(summary_df, perf_df, on='model', how='left')

    return summary_df, detail_df
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def save_data(summary_df: pd.DataFrame, detail_df: pd.DataFrame):
    """Write the summary and per-trial detail tables to CSV under OUTPUT_DIR."""
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    targets = [
        (summary_df, OUTPUT_DIR / "inference_summary.csv"),
        (detail_df, OUTPUT_DIR / "inference_detail.csv"),
    ]
    for frame, destination in targets:
        frame.to_csv(destination, index=False)

    print(f"\nData saved to:")
    print(f"  - {targets[0][1]}")
    print(f"  - {targets[1][1]}")
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def print_summary(summary_df: pd.DataFrame):
    """Print a fixed-width per-model summary table to stdout."""
    header_rule = "=" * 80
    print("\n" + header_rule)
    print("Inference Summary")
    print(header_rule)

    # Models with the most inference requests are listed first.
    ordered = summary_df.sort_values("avg_inference_count", ascending=False)

    print(f"\n{'Model':<25} {'Avg Infer':>10} {'Avg Tokens':>12} {'Avg In':>10} {'Avg Out':>10}")
    print("-" * 70)
    for row in ordered.itertuples(index=False):
        print(f"{row.model:<25} {row.avg_inference_count:>10.1f} {row.avg_total_tokens:>12.0f} {row.avg_input_tokens:>10.0f} {row.avg_output_tokens:>10.0f}")
|
| 407 |
+
|
| 408 |
+
def plot_tool_usage(summary_df: pd.DataFrame):
    """
    Figure: Tool usage per model - total tool calls and code execution.

    Produces a two-panel horizontal bar chart (tool calls on the left,
    code executions on the right) and saves it as fig_tool_usage.png.
    Silently returns if the summary has no tool-usage columns.
    """
    plt.rcParams.update(PLOT_PARAMETERS)

    if 'avg_tool_call_count' not in summary_df.columns:
        print("Skipping tool usage: no tool data")
        return

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(SINGLE_COLUMN_WIDTH * 2, 2.5))

    # Ascending sort so the heaviest user ends up at the top of the barh chart.
    data = summary_df.sort_values("avg_tool_call_count", ascending=True)

    color_palette = get_color_palette(len(data))
    colors = [color_palette[i % len(color_palette)] for i in range(len(data))]

    # Left: Total tool calls
    bars1 = ax1.barh(data["model"], data["avg_tool_call_count"], color=colors,
                     edgecolor='black', linewidth=0.5)
    ax1.set_xlabel("Avg. Tool Calls per Scenario")

    # Value labels just to the right of each bar.
    for bar, val in zip(bars1, data["avg_tool_call_count"]):
        ax1.text(val + 1, bar.get_y() + bar.get_height()/2,
                 f'{val:.0f}', va='center', ha='left', fontsize=MIN_FONT_SIZE - 1)

    # 15% headroom so the value labels are not clipped.
    ax1.set_xlim(0, data["avg_tool_call_count"].max() * 1.15)

    # Right: Code executions
    bars2 = ax2.barh(data["model"], data["avg_code_execution_count"], color=colors,
                     edgecolor='black', linewidth=0.5)
    ax2.set_xlabel("Avg. Code Executions per Scenario")

    for bar, val in zip(bars2, data["avg_code_execution_count"]):
        if val > 0:
            ax2.text(val + 0.5, bar.get_y() + bar.get_height()/2,
                     f'{val:.0f}', va='center', ha='left', fontsize=MIN_FONT_SIZE - 1)

    # Lower bound of 1 keeps the axis sane when no model executed code.
    ax2.set_xlim(0, max(data["avg_code_execution_count"].max() * 1.3, 1))
    # Model names are already on the left panel; hide duplicates.
    ax2.set_yticklabels([])

    # NOTE(review): plt.title targets the current axes (the right panel),
    # not the whole figure — fig.suptitle may have been intended; confirm.
    plt.title("Tool Call Distribution")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_tool_usage.png")
    plt.close(fig)
    print("Saved: fig_tool_usage.png")
|
| 456 |
+
|
| 457 |
+
def plot_inference_vs_performance(summary_df: pd.DataFrame):
    """
    Figure 3: Inference count vs Performance scatter.

    One labeled point per model; requires a 'performance' column (merged in
    by extract_all_data from the consistency analysis), otherwise returns.
    """
    if 'performance' not in summary_df.columns:
        print("Skipping inference vs performance: no performance data")
        return

    plt.rcParams.update(PLOT_PARAMETERS)

    fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH, SINGLE_COLUMN_WIDTH))

    # Models without a performance score cannot be plotted.
    data = summary_df.dropna(subset=['performance'])

    # Manual label offsets to avoid overlap.
    # Format: (dx points, dy points, horizontal alignment, vertical alignment).
    label_offsets = {
        "GPT-5.1": (-5, -8, "right", "top"),
        "o4-mini": (5, -8, "left", "top"),
        "GPT-OSS-120B": (5, 3, "left", "bottom"),
        "Gemini-2.5-Pro": (-5, 3, "right", "bottom"),
        "Gemini-3-Flash": (5, 3, "left", "bottom"),
        "gemini-3-pro-preview": (5, 3, "left", "bottom"),
        "Kimi-K2": (5, 3, "left", "bottom"),
    }

    # Get color palette (one color per model, cycled if needed)
    color_palette = get_color_palette(len(data))

    # Scatter plot
    for i, (_, row) in enumerate(data.iterrows()):
        ax.scatter(row["avg_inference_count"], row["performance"],
                   c=[color_palette[i % len(color_palette)]], s=60, edgecolors='black',
                   linewidth=0.5, zorder=10)

        # Label with custom offset; unknown models fall back to right/above.
        offset = label_offsets.get(row["model"], (5, 3, "left", "bottom"))
        ax.annotate(row["model"],
                    (row["avg_inference_count"], row["performance"]),
                    xytext=(offset[0], offset[1]), textcoords='offset points',
                    fontsize=MIN_FONT_SIZE - 1, ha=offset[2], va=offset[3])

    ax.set_xlabel("Avg. Inference Requests")
    ax.set_ylabel("Performance (RC Entity F1)")
    ax.set_xlim(0, data["avg_inference_count"].max() * 1.2)
    # Fixed y-range: F1 scores here are expected below 0.7.
    ax.set_ylim(0, 0.7)

    plt.title("Inference Requests vs. Performance")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_inference_vs_performance.png")
    plt.close(fig)
    print("Saved: fig_inference_vs_performance.png")
|
| 510 |
+
|
| 511 |
+
def plot_tool_breakdown_heatmap(summary_df: pd.DataFrame):
    """
    Generate a heatmap showing which tools each agent uses most.

    Expects each summary row to carry a 'top_tools' dict (tool -> count)
    and 'total_tool_calls'/'n_scenarios' columns; saves the result as
    fig_tool_usage_heatmap.png.
    """

    # Flatten the per-model top_tools dicts into long-format records.
    tool_usage = []

    for _, row in summary_df.iterrows():
        # NOTE(review): if the DataFrame was re-read from CSV, 'top_tools'
        # would be a string, not a dict — tools.items() below assumes the
        # in-memory dict form; confirm callers pass the un-serialized frame.
        if pd.isna(row.get('top_tools')):
            print("pd.isna")
            continue

        tools = row['top_tools']
        total_calls = row['total_tool_calls']
        if total_calls == 0:
            print("No tool calls")
            continue

        for tool, count in tools.items():
            tool_usage.append({
                'model': row['model'],
                'tool': tool,
                'count': count,
                # Normalize so models with different scenario counts compare.
                'avg_per_scenario': count / row['n_scenarios']
            })


    df = pd.DataFrame(tool_usage)
    if len(df) == 0:
        print("No tool usage data found")
        return

    # Pivot for heatmap: rows = models, columns = tools.
    pivot_df = df.pivot(index='model', columns='tool', values='avg_per_scenario').fillna(0)

    # Filter to top 10 most used tools across all models
    # top_tools = pivot_df.sum().sort_values(ascending=False).head(10).index
    # (head(10) disabled: all tools are currently shown, ordered by usage.)
    top_tools = pivot_df.sum().sort_values(ascending=False).index
    pivot_df = pivot_df[top_tools]

    # Sort models by total tool usage (temporary 'total' column, then dropped).
    pivot_df['total'] = pivot_df.sum(axis=1)
    pivot_df = pivot_df.sort_values('total', ascending=False).drop('total', axis=1)

    # Plot
    # NOTE(review): this mutates the module-level PLOT_PARAMETERS dict, so
    # every later plot in the process inherits font.size=8 — likely
    # unintended; consider updating a copy instead.
    PLOT_PARAMETERS['font.size'] = 8
    plt.rcParams.update(PLOT_PARAMETERS)

    fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH * 2, 4))

    sns.heatmap(pivot_df, annot=True, fmt='.1f', cmap='YlOrRd', ax=ax,
                cbar_kws={'label': 'Avg. Calls per Scenario'})

    ax.set_xlabel("")
    ax.set_ylabel("")
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.title("Tool Call Distribution")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_tool_usage_heatmap.png")
    plt.close(fig)
    print("Saved: fig_tool_usage_heatmap.png")
|
| 577 |
+
|
| 578 |
+
def main():
    """Entry point: extract inference data, persist it, and print a summary."""
    print("Extracting inference data for 'react with code' agents...")
    print(f"Reading from directories: {LEADERBOARD_DIR}")
    print(f"Output directory: {OUTPUT_DIR}")

    summary_df, detail_df = extract_all_data()

    # Bail out early when every agent directory was skipped.
    if not len(summary_df):
        print("No data extracted!")
        return

    save_data(summary_df, detail_df)
    print_summary(summary_df)
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
# Run the extraction pipeline only when executed as a script.
if __name__ == "__main__":
    main()
|
| 595 |
+
|
analysis_src/extract_majority_vote_data.py
ADDED
|
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Extract Majority Vote and consistency breakdown data for all 'react with code' agents.
|
| 4 |
+
|
| 5 |
+
This script computes:
|
| 6 |
+
- Pass@k: At least 1 trial succeeds
|
| 7 |
+
- Majority@k: Majority of trials succeed
|
| 8 |
+
- All@k: All trials succeed
|
| 9 |
+
- Consistency breakdown: Consistent Correct, Consistent Wrong, Inconsistent
|
| 10 |
+
|
| 11 |
+
Output is saved to paper_analysis/react with code/resources/figures/consistency/ as CSV files.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import json
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from itertools import combinations
|
| 18 |
+
import numpy as np
|
| 19 |
+
import pandas as pd
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
import seaborn as sns
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
# Add project root to path
|
| 25 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 26 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 27 |
+
|
| 28 |
+
from src.utils import (
|
| 29 |
+
get_model_name,
|
| 30 |
+
find_react_with_code_dirs,
|
| 31 |
+
read_judge_outputs_from_dir,
|
| 32 |
+
extract_trial_scores_from_judge_outputs,
|
| 33 |
+
filter_scenarios_with_min_runs,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from src.model_styles import (
|
| 37 |
+
get_model_style, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, PLOT_PARAMETERS
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Paths: input trajectories and where consistency CSVs/figures are written.
LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "consistency"

# Minimum runs per scenario required for inclusion in the analysis.
MIN_RUNS_PER_SCENARIO = 2

# Minimum scenarios needed after filtering; models below this are skipped.
MIN_QUALIFYING_SCENARIOS = 20

# Success threshold for binary classification (score >= threshold counts
# as a successful trial when computing pass/majority/all@k).
SUCCESS_THRESHOLD = 0.5
|
| 52 |
+
|
| 53 |
+
def compute_majority_vote_metrics(
    scenario_trials: dict[str, list[float]],
    success_threshold: float = SUCCESS_THRESHOLD
) -> dict:
    """Compute pass/majority/all@k and consistency-breakdown statistics.

    Every scenario is truncated to its first k trials, where k is the
    smallest trial count across scenarios, so each scenario contributes
    equally. A trial succeeds when its score is >= ``success_threshold``.

    Returns:
        Dict with rates and raw counts for:
        - pass_at_k: at least 1 trial succeeds
        - majority_at_k: strictly more than half the trials succeed
        - all_at_k / consistent_correct: all trials succeed
        - consistent_wrong: all trials fail
        - inconsistent: mixed results
        plus overall score mean/std and per-scenario detail records.
        Returns None when there are no scenarios or no trials.
    """
    trial_counts = [len(v) for v in scenario_trials.values()]
    if not trial_counts:
        return None

    k = min(trial_counts)
    n_scenarios = len(scenario_trials)
    if n_scenarios == 0 or k < 1:
        return None

    # Scenario-level tallies, keyed by outcome kind.
    tally = {"pass": 0, "majority": 0, "all": 0,
             "correct": 0, "wrong": 0, "inconsistent": 0}
    scenario_details = []
    all_scores = []

    for scenario_id, raw_trials in scenario_trials.items():
        trials = raw_trials[:k]
        all_scores.extend(trials)
        n_success = sum(1 for score in trials if score >= success_threshold)

        if n_success >= 1:
            tally["pass"] += 1
        if n_success > k / 2:
            tally["majority"] += 1

        if n_success == k:
            tally["all"] += 1
            tally["correct"] += 1
            consistency_type = "correct"
        elif n_success == 0:
            tally["wrong"] += 1
            consistency_type = "wrong"
        else:
            tally["inconsistent"] += 1
            consistency_type = "inconsistent"

        scenario_details.append({
            "scenario": scenario_id,
            "n_success": n_success,
            "n_trials": k,
            "majority_correct": n_success > k / 2,
            "consistency_type": consistency_type,
            "mean_score": np.mean(trials),
            "std_score": np.std(trials) if len(trials) > 1 else 0,
        })

    return {
        "n_scenarios": n_scenarios,
        "n_trials": k,
        "threshold": success_threshold,
        "pass_at_k": tally["pass"] / n_scenarios,
        "majority_at_k": tally["majority"] / n_scenarios,
        "all_at_k": tally["all"] / n_scenarios,
        "consistent_correct": tally["correct"] / n_scenarios,
        "consistent_wrong": tally["wrong"] / n_scenarios,
        "inconsistent": tally["inconsistent"] / n_scenarios,
        "n_pass": tally["pass"],
        "n_majority": tally["majority"],
        "n_all": tally["all"],
        "n_consistent_correct": tally["correct"],
        "n_consistent_wrong": tally["wrong"],
        "n_inconsistent": tally["inconsistent"],
        "overall_mean": np.mean(all_scores),
        "overall_std": np.std(all_scores),
        "scenario_details": scenario_details,
    }
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
# Metrics to extract.
# Each tuple is (key in the judge-output records, human-readable label
# used in tables/plots and in output filenames).
METRICS = [
    ("root_cause_entity_f1", "F1"),
    ("root_cause_entity_precision", "Precision"),
    ("root_cause_entity_recall", "Recall"),
]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def extract_all_data() -> dict[str, tuple[pd.DataFrame, pd.DataFrame]]:
    """
    Extract majority vote data for all agents, for multiple metrics.

    Judge outputs are read once per agent, then re-scored for each metric
    in METRICS. Agents with no judge outputs or with fewer than
    MIN_QUALIFYING_SCENARIOS qualifying scenarios are skipped.

    Returns:
        - dict mapping metric_name -> (summary_df, scenario_df)
    """
    agent_dirs = find_react_with_code_dirs(LEADERBOARD_DIR)
    print(f"Found {len(agent_dirs)} 'react with code' agent directories:")
    for d in agent_dirs:
        print(f"  - {d.name}")

    # Read all judge outputs once (shared across all metrics below).
    agent_data = {}
    valid_models = []
    skipped_models = []

    for agent_dir in tqdm(agent_dirs, desc="Reading agent data"):
        model_name = get_model_name(agent_dir.name)

        print(f"\nReading: {agent_dir.name}")
        scenario_data = read_judge_outputs_from_dir(agent_dir)

        if not scenario_data:
            print(f"  SKIPPING {model_name}: No judge outputs found")
            skipped_models.append((model_name, "No data"))
            continue

        # Filter scenarios with minimum runs.
        scenario_data = filter_scenarios_with_min_runs(scenario_data, MIN_RUNS_PER_SCENARIO)
        n_qualifying = len(scenario_data)

        if n_qualifying < MIN_QUALIFYING_SCENARIOS:
            print(f"  SKIPPING {model_name}: Only {n_qualifying} scenarios with {MIN_RUNS_PER_SCENARIO}+ runs")
            skipped_models.append((model_name, f"{n_qualifying} qualifying"))
            continue

        print(f"  Processing: {model_name} ({n_qualifying} scenarios)")
        valid_models.append(model_name)
        agent_data[model_name] = scenario_data

    if skipped_models:
        print(f"\n⚠️ Skipped {len(skipped_models)} models:")
        for name, reason in skipped_models:
            print(f"  - {name}: {reason}")

    print(f"\n✓ Included {len(valid_models)} models: {valid_models}")

    # Extract for each metric.
    results = {}

    for metric_key, metric_label in tqdm(METRICS, desc="Processing metrics"):
        print(f"\n--- Extracting for metric: {metric_label} ({metric_key}) ---")

        summary_records = []
        scenario_records = []

        for model_name, scenario_data in tqdm(agent_data.items(), desc=f"  {metric_label}", leave=False):
            # Extract scores for this metric.
            scenario_trials = extract_trial_scores_from_judge_outputs(scenario_data, metric_key)

            # Compute majority vote metrics.
            metrics = compute_majority_vote_metrics(scenario_trials)

            # None means no scenarios/trials survived — skip this model.
            if metrics is None:
                continue

            # Add to summary (one row per model per metric).
            summary_records.append({
                "model": model_name,
                "metric": metric_label,
                "n_scenarios": metrics["n_scenarios"],
                "n_trials": metrics["n_trials"],
                "pass_at_k": metrics["pass_at_k"],
                "majority_at_k": metrics["majority_at_k"],
                "all_at_k": metrics["all_at_k"],
                "consistent_correct": metrics["consistent_correct"],
                "consistent_wrong": metrics["consistent_wrong"],
                "inconsistent": metrics["inconsistent"],
                "overall_mean": metrics["overall_mean"],
                "overall_std": metrics["overall_std"],
            })

            # Add per-scenario data (one row per model per scenario).
            for detail in metrics["scenario_details"]:
                scenario_records.append({
                    "model": model_name,
                    "metric": metric_label,
                    "scenario": detail["scenario"],
                    "n_success": detail["n_success"],
                    "n_trials": detail["n_trials"],
                    "majority_correct": detail["majority_correct"],
                    "consistency_type": detail["consistency_type"],
                    "mean_score": detail["mean_score"],
                    "std_score": detail["std_score"],
                })

        summary_df = pd.DataFrame(summary_records)
        scenario_df = pd.DataFrame(scenario_records)
        results[metric_label] = (summary_df, scenario_df)

    return results
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def save_data(results: dict[str, tuple[pd.DataFrame, pd.DataFrame]]):
    """Save extracted data to CSV files for each metric.

    Writes one summary/scenario CSV pair per metric label, plus metric-free
    ``majority_vote_data.csv`` / ``majority_vote_scenarios.csv`` copies of
    the F1 results for backward compatibility with older consumers.

    Args:
        results: Mapping of metric label -> (summary_df, scenario_df)
            as produced by extract_all_data().
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # FIX: removed the dead all_summaries/all_scenarios accumulators that
    # were appended to but never read.
    for metric_label, (summary_df, scenario_df) in results.items():
        metric_suffix = metric_label.lower()

        summary_path = OUTPUT_DIR / f"majority_vote_data_{metric_suffix}.csv"
        scenario_path = OUTPUT_DIR / f"majority_vote_scenarios_{metric_suffix}.csv"

        summary_df.to_csv(summary_path, index=False)
        scenario_df.to_csv(scenario_path, index=False)

        print(f"\nData saved for {metric_label}:")
        print(f"  - {summary_path}")
        print(f"  - {scenario_path}")

    # Save combined (default to F1 for backward compatibility)
    if "F1" in results:
        f1_summary, f1_scenario = results["F1"]
        # Drop the metric column so the legacy single-metric schema is kept.
        f1_summary_compat = f1_summary.drop(columns=["metric"], errors="ignore")
        f1_scenario_compat = f1_scenario.drop(columns=["metric"], errors="ignore")
        f1_summary_compat.to_csv(OUTPUT_DIR / "majority_vote_data.csv", index=False)
        f1_scenario_compat.to_csv(OUTPUT_DIR / "majority_vote_scenarios.csv", index=False)
        print(f"\nBackward-compatible files (F1) saved to:")
        print(f"  - {OUTPUT_DIR / 'majority_vote_data.csv'}")
        print(f"  - {OUTPUT_DIR / 'majority_vote_scenarios.csv'}")
|
| 293 |
+
|
| 294 |
+
def print_summary(results: dict[str, tuple[pd.DataFrame, pd.DataFrame]]):
    """Print a fixed-width majority-vote summary table for each metric."""
    rule = "=" * 80
    pct_columns = ["majority_at_k", "pass_at_k", "all_at_k",
                   "consistent_correct", "consistent_wrong", "inconsistent"]

    for metric_label, (summary_df, _) in results.items():
        print("\n" + rule)
        print(f"Majority Vote Summary ({metric_label}, threshold={SUCCESS_THRESHOLD})")
        print(rule)

        # Best majority-vote models first.
        ranked = summary_df.sort_values("majority_at_k", ascending=False)

        print(f"\n{'Model':<20} {'Maj@k':>8} {'Pass@k':>8} {'All@k':>8} {'Cons✓':>8} {'Cons✗':>8} {'Incons':>8}")
        print("-" * 80)
        for _, row in ranked.iterrows():
            cells = [f"{row[col]*100:>7.1f}%" for col in pct_columns]
            print(f"{row['model']:<20} " + " ".join(cells))
|
| 313 |
+
|
| 314 |
+
def plot_majority_vs_performance(df: pd.DataFrame):
    """
    Figure: Majority@k vs Performance scatter plot.

    X axis is the overall mean score (RC entity F1); Y axis is Majority@k
    in percent. A green gradient, arrow, and star mark the ideal
    (high-performance, high-consistency) top-right corner.

    Args:
        df: Summary DataFrame with 'model', 'overall_mean' and
            'majority_at_k' columns.
    """
    # FIX: the original called plt.rcParams.update({PLOT_PARAMETERS}),
    # which wraps the dict in a set literal and raises TypeError
    # (dicts are unhashable). Pass the parameter dict directly.
    plt.rcParams.update(PLOT_PARAMETERS)

    fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH))

    # Axis limits
    x_min, x_max = 0, 1.0
    y_min, y_max = 0, 100

    # Gradient shading toward top-right (ideal): stacked translucent
    # rectangles get darker toward the corner.
    for i in range(5):
        alpha = 0.02 + i * 0.02
        x_start = 0.1 + i * 0.15
        y_start = 10 + i * 15
        rect = plt.Rectangle((x_start, y_start), x_max - x_start, y_max - y_start,
                             color='#2ecc71', alpha=alpha, zorder=0)
        ax.add_patch(rect)

    # Arrow pointing to ideal corner.
    ax.annotate('', xy=(0.85, 85), xytext=(0.55, 55),
                arrowprops=dict(arrowstyle='->', color='#27ae60', alpha=0.7, lw=1.5),
                zorder=2)
    ax.text(0.58, 58, 'better', fontsize=MIN_FONT_SIZE, style='italic',
            color='#27ae60', alpha=0.8, rotation=45, zorder=2)

    # Mark ideal corner with a star.
    ax.scatter([1.0], [100], marker='*', s=100, c='#27ae60', alpha=0.5, zorder=2)
    ax.text(0.92, 95, 'ideal', fontsize=MIN_FONT_SIZE - 1, color='#27ae60',
            alpha=0.7, ha='right')

    # Scatter points with model-specific colors and markers.
    for _, row in df.iterrows():
        style = get_model_style(row["model"])
        ax.scatter(row["overall_mean"], row["majority_at_k"] * 100,
                   c=style['color'], marker=style['marker'],
                   s=80, edgecolors='black', linewidth=0.5, zorder=10)

    # Labels with smart positioning: flip toward the interior near the
    # right/top edges so labels stay inside the axes.
    for _, row in df.iterrows():
        model = row["model"]
        x_pos = row["overall_mean"]
        y_pos = row["majority_at_k"] * 100

        dx, dy = 0.03, 2
        ha, va = "left", "center"

        if x_pos > 0.7:
            dx = -0.03
            ha = "right"
        if y_pos > 80:
            dy = -3
            va = "top"

        ax.text(x_pos + dx, y_pos + dy, model, fontsize=MIN_FONT_SIZE - 1,
                ha=ha, va=va, zorder=11)

    ax.set_xlabel("Performance (RC Entity F1)")
    ax.set_ylabel("Majority@k (%)")
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_majority_vs_performance.pdf")
    fig.savefig(OUTPUT_DIR / "fig_majority_vs_performance.png")
    plt.close(fig)
    print("Saved: fig_majority_vs_performance.pdf/png")
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def plot_pass_vs_majority(df: pd.DataFrame, metric: str = "F1", suffix: str = ""):
    """
    Figure: Scatter plot of Pass@k (x-axis) vs Majority@k (y-axis).

    Points above the diagonal are models whose majority vote recovers most of
    their best-case performance (more consistent across runs).

    Args:
        df: DataFrame with pass_at_k and majority_at_k columns
        metric: Name of metric for labeling (F1, Precision, Recall)
        suffix: Suffix for output filename (e.g., "_precision")
    """
    fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH, SINGLE_COLUMN_WIDTH))

    ax_min, ax_max = 0, 100

    # Diagonal line (pass@k == majority@k)
    ax.plot([ax_min, ax_max], [ax_min, ax_max], color='#444444', linestyle='--',
            linewidth=1.5, alpha=0.6, zorder=1)

    # Consistency region labels
    ax.text(8, 92, 'more\nconsistent', fontsize=MIN_FONT_SIZE + 1, color='#333333',
            ha='left', va='top', style='italic')
    ax.text(92, 8, 'less\nconsistent', fontsize=MIN_FONT_SIZE + 1, color='#333333',
            ha='right', va='bottom', style='italic')

    # Collect and plot points (percent scale on both axes)
    points = {}
    for _, row in df.iterrows():
        style = get_model_style(row["model"])
        x = row["pass_at_k"] * 100
        y = row["majority_at_k"] * 100
        ax.scatter(x, y, c=style['color'], marker=style['marker'],
                   s=50, edgecolors='black', linewidth=0.5, zorder=10)
        points[row["model"]] = {'x': x, 'y': y}

    line_color = '#444444'
    line_width = 1.2

    # Place labels with manual positioning (hand-tuned per model to avoid overlap)
    for model, p in points.items():
        x, y = p['x'], p['y']

        if 'GPT-OSS-120B' in model:
            # Label to the right, slightly below
            ax.text(x + 3, y - 2, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

        elif 'Gemini 2.5 Pro' in model:
            # TEAL CIRCLE: label slightly below and to the right
            ax.text(x + 3, y + 2, model, fontsize=MIN_FONT_SIZE, ha='left', va='bottom', zorder=11)

        elif 'o4-mini' in model:
            # YELLOW SQUARE: shorter line goes right then to label
            label_x = x + 12
            label_y = y
            # Horizontal line right (shorter)
            ax.plot([x, label_x], [y, y], color=line_color, linewidth=line_width, alpha=0.8, zorder=5)
            ax.text(label_x + 1, label_y, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

        elif 'GPT-5.1' in model:
            # GREEN SQUARE: line from left edge, goes left then up
            label_x = 5
            label_y = 25
            start_x = x - 2  # Left edge of the square marker
            # Horizontal line left from left edge midpoint
            ax.plot([start_x, label_x], [y, y], color=line_color, linewidth=line_width, alpha=0.8, zorder=5)
            # Vertical line up to label height
            ax.plot([label_x, label_x], [y, label_y], color=line_color, linewidth=line_width, alpha=0.8, zorder=5)
            ax.text(label_x, label_y + 1, model, fontsize=MIN_FONT_SIZE, ha='left', va='bottom', zorder=11)

        elif 'Claude Opus' in model:
            # Label to the right
            ax.text(x + 5, y, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

        elif 'Gemini 3 Pro' in model:
            # Label BELOW the circle, offset left
            ax.text(x - 18, y - 6, model, fontsize=MIN_FONT_SIZE, ha='left', va='top', zorder=11)

        elif 'Gemini 3 Flash' in model:
            # Label at right edge to avoid diagonal line
            ax.text(105, y + 4, model, fontsize=MIN_FONT_SIZE, ha='right', va='bottom', zorder=11)

        elif 'Kimi K2' in model:
            # Label to the right
            ax.text(x + 3, y, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

        else:
            # Default: label to the right
            ax.text(x + 3, y, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

    ax.set_xlabel(f"Pass@k (%) [{metric}]")
    ax.set_ylabel(f"Majority@k (%) [{metric}]")
    ax.set_xlim(ax_min, ax_max)
    ax.set_ylim(ax_min, ax_max)
    ax.set_aspect('equal')

    plt.title("Consistency: Pass@k vs. Majority@k")
    plt.tight_layout()
    plt.show()
    # BUG FIX: the output name was a garbled literal "(unknown).png" and the
    # `filename` variable was never used; interpolate it so `suffix` takes effect.
    filename = f"fig_pass_vs_majority{suffix}"
    fig.savefig(OUTPUT_DIR / f"{filename}.png")
    plt.close(fig)
    print(f"Saved: {filename}.png")
|
| 486 |
+
|
| 487 |
+
def main():
    """Entry point: extract majority-vote data, save it, and print a summary."""
    print("Extracting majority vote data for 'react with code' agents...")
    print(f"Reading from directories: {LEADERBOARD_DIR}")
    print(f"Output directory: {OUTPUT_DIR}")
    print(f"Success threshold: {SUCCESS_THRESHOLD}")
    print(f"Minimum runs per scenario: {MIN_RUNS_PER_SCENARIO}")
    print(f"Metrics: {[m[1] for m in METRICS]}")

    results = extract_all_data()

    if results:
        save_data(results)
        print_summary(results)
    else:
        print("No data extracted!")


if __name__ == "__main__":
    main()
|
| 507 |
+
|
analysis_src/extract_tool_failures.py
ADDED
|
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Extract tool call failure data for all 'react with code' agents.
|
| 4 |
+
|
| 5 |
+
This script reads rollout JSONL files to identify and categorize tool call failures.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import re
|
| 10 |
+
import sys
|
| 11 |
+
import ast
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from collections import defaultdict
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import seaborn as sns
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
|
| 21 |
+
# Add project root to path
|
| 22 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 23 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 24 |
+
|
| 25 |
+
from src.utils import (
|
| 26 |
+
get_model_name,
|
| 27 |
+
find_react_with_code_dirs,
|
| 28 |
+
get_runs_stats,
|
| 29 |
+
filter_scenarios_with_min_runs,
|
| 30 |
+
find_latest_rollout_file
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
from src.model_styles import (
|
| 34 |
+
get_model_style, get_color_palette, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, _COLORS, PLOT_PARAMETERS
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# Paths
LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"
RESULTS_JSON_DIR = LEADERBOARD_DIR / "results"
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "tool_failures"

# Minimum runs per scenario required
MIN_RUNS_PER_SCENARIO = 3
# Agents with fewer than this many qualifying scenarios are skipped entirely
# (see the filtering logic in extract_all_data).
MIN_QUALIFYING_SCENARIOS = 20

# Failure type patterns
# Maps category name -> list of regexes. classify_failure() walks the
# categories in this insertion order and returns on the first regex that
# matches (case-insensitively), so earlier categories take priority over the
# broad 'runtime' catch-all patterns at the end.
FAILURE_PATTERNS = {
    'python_syntax': [
        r'SyntaxError',
        r'IndentationError',
        r'TabError',
    ],
    'python_type': [
        r'TypeError',
        r'AttributeError',
        r'ValueError',
        r'KeyError',
        r'IndexError',
    ],
    'python_name': [
        r'NameError',
        r'UnboundLocalError',
        r'ModuleNotFoundError',
        r'ImportError',
    ],
    'file_not_found': [
        r'FileNotFoundError',
        r'No such file or directory',
        r'ENOENT',
        r'path does not exist',
    ],
    'permission_denied': [
        r'PermissionError',
        r'Permission denied',
        r'EACCES',
    ],
    'json_parse': [
        r'JSONDecodeError',
        r'json\.decoder\.JSONDecodeError',
        r'Expecting value',
        r'Invalid JSON',
    ],
    'timeout': [
        r'TimeoutError',
        r'timeout',
        r'Timed out',
        r'deadline exceeded',
    ],
    'memory': [
        r'MemoryError',
        r'out of memory',
        r'OOM',
        r'Cannot allocate memory',
    ],
    'connection': [
        r'ConnectionError',
        r'ConnectionRefusedError',
        r'Connection refused',
        r'ECONNREFUSED',
    ],
    'shell_command': [
        r'command not found',
        r'No such command',
        r'not recognized as',
    ],
    'assertion': [
        r'AssertionError',
    ],
    # Generic fallbacks; must stay last so specific categories win.
    'runtime': [
        r'RuntimeError',
        r'Exception',
        r'Error:',
    ],
}
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def classify_failure(output: str) -> tuple[str, str]:
    """
    Classify a failure based on the output string.

    Walks FAILURE_PATTERNS in insertion order; the first case-insensitive
    regex match wins, so more specific categories must precede generic ones.

    Args:
        output: Raw tool output / traceback text.

    Returns:
        (category, specific_error) where specific_error is the matched text,
        or a generic bucket when nothing matched.
    """
    for category, patterns in FAILURE_PATTERNS.items():
        for pattern in patterns:
            # Search once per pattern (the original searched twice: once for
            # the test and once for the return value).
            match = re.search(pattern, output, re.IGNORECASE)
            if match:
                return (category, match.group(0))

    # A traceback with no recognized exception name still counts as Python.
    if 'Traceback' in output:
        return ('other_python', 'Unknown Python Error')

    return ('other', 'Unknown Error')
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def extract_tool_calls_from_rollout(rollout_file: Path) -> dict:
    """
    Extract all tool calls and their outcomes from a rollout JSONL file.

    Returns dict with:
    - total_tool_calls: int
    - failed_tool_calls: int
    - failures: list of failure details
    - tool_call_counts: dict of tool_name -> count
    - tool_failure_counts: dict of tool_name -> failure_count
    Returns None when the file cannot be read at all.
    """
    calls_by_id = {}  # call_id -> {name, arguments, timestamp}
    n_calls = 0
    n_failed = 0
    failure_details = []
    calls_per_tool = defaultdict(int)
    failures_per_tool = defaultdict(int)

    try:
        with open(rollout_file) as fh:
            for raw_line in fh:
                # Skip lines that are not valid JSON records.
                try:
                    record = json.loads(raw_line)
                except json.JSONDecodeError:
                    continue

                if record.get('type') != 'response_item':
                    continue

                payload = record.get('payload', {})
                payload_kind = payload.get('type', '')

                if payload_kind == 'function_call':
                    cid = payload.get('call_id', '')
                    tool = payload.get('name', '')
                    calls_by_id[cid] = {
                        'name': tool,
                        'arguments': payload.get('arguments', ''),
                        'timestamp': record.get('timestamp', ''),
                    }
                    n_calls += 1
                    calls_per_tool[tool] += 1

                elif payload_kind == 'function_call_output':
                    cid = payload.get('call_id', '')
                    output = payload.get('output', '')

                    # failure_info stays None when the call succeeded.
                    failure_info = None
                    try:
                        parsed = json.loads(output)
                    except json.JSONDecodeError:
                        # Not JSON: fall back to scanning the raw text for
                        # error markers.
                        if 'Error' in output or 'error' in output or 'Traceback' in output:
                            category, error = classify_failure(output)
                            failure_info = {
                                'exit_code': None,
                                'category': category,
                                'error': error,
                                'output_snippet': output[:300],
                            }
                    else:
                        # Structured output: a non-zero exit code marks failure.
                        if isinstance(parsed, dict):
                            exit_code = parsed.get('metadata', {}).get('exit_code', 0)
                            output_text = parsed.get('output', '')
                            if exit_code != 0:
                                category, error = classify_failure(output_text)
                                failure_info = {
                                    'exit_code': exit_code,
                                    'category': category,
                                    'error': error,
                                    'output_snippet': output_text[:300] if output_text else '',
                                }

                    # Only count failures whose originating call was recorded.
                    if failure_info is not None and cid in calls_by_id:
                        n_failed += 1
                        tool = calls_by_id[cid]['name']
                        failures_per_tool[tool] += 1
                        failure_details.append({
                            'tool_name': tool,
                            'arguments': calls_by_id[cid]['arguments'][:200],
                            'timestamp': calls_by_id[cid]['timestamp'],
                            **failure_info,
                        })
    except Exception as e:
        print(f" Warning: Error reading {rollout_file}: {e}")
        return None

    return {
        'total_tool_calls': n_calls,
        'failed_tool_calls': n_failed,
        'failures': failure_details,
        'tool_call_counts': dict(calls_per_tool),
        'tool_failure_counts': dict(failures_per_tool),
    }
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def read_agent_stats(agent_dir: Path) -> dict[str, list[dict]]:
    """
    Read tool call stats from all scenarios/trials for an agent.

    Only directories named "Scenario*" are considered; scenarios without any
    readable rollout are omitted from the result.

    Returns:
        Dict mapping scenario_id -> list of stats (one per trial)
    """
    per_scenario = {}

    for scenario_dir in agent_dir.iterdir():
        if not (scenario_dir.is_dir() and scenario_dir.name.startswith("Scenario")):
            continue

        trial_stats = []
        for trial_dir in sorted(scenario_dir.iterdir()):
            if not trial_dir.is_dir():
                continue
            rollout = find_latest_rollout_file(trial_dir)
            if rollout:
                stats = extract_tool_calls_from_rollout(rollout)
                if stats:
                    trial_stats.append(stats)

        if trial_stats:
            per_scenario[scenario_dir.name] = trial_stats

    return per_scenario
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def extract_all_data() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Extract tool failure data for all agents.

    Agents are skipped when they have no rollout data or fewer than
    MIN_QUALIFYING_SCENARIOS scenarios with MIN_RUNS_PER_SCENARIO+ runs.

    Returns:
        - summary_df: Aggregated stats per model
        - detail_df: Per-trial failure stats
        - failures_df: Individual failure details
    """
    agent_dirs = find_react_with_code_dirs(LEADERBOARD_DIR)
    print(f"Found {len(agent_dirs)} 'react with code' agent directories")

    summary_records = []
    detail_records = []
    failure_records = []

    for agent_dir in tqdm(agent_dirs, desc="Processing agents"):
        model_name = get_model_name(agent_dir.name)

        print(f"\nProcessing: {agent_dir.name}")
        scenario_data = read_agent_stats(agent_dir)

        n_scenarios, min_runs, max_runs, n_qualifying = get_runs_stats(scenario_data, MIN_RUNS_PER_SCENARIO)

        if n_scenarios == 0:
            print(f" SKIPPING {model_name}: No rollout data found")
            continue

        if n_qualifying < MIN_QUALIFYING_SCENARIOS:
            print(f" SKIPPING {model_name}: Only {n_qualifying}/{n_scenarios} scenarios have {MIN_RUNS_PER_SCENARIO}+ runs")
            continue

        # Filter scenarios: keep only those with enough runs for fair comparison
        scenario_data = filter_scenarios_with_min_runs(scenario_data, MIN_RUNS_PER_SCENARIO)
        n_scenarios_filtered = len(scenario_data)

        print(f" Processing: {model_name} ({n_scenarios_filtered} scenarios)")

        # Aggregate across all scenarios and trials for this model
        all_total_calls = []
        all_failed_calls = []
        all_failure_rates = []  # per-trial failure percentage
        aggregated_tool_counts = defaultdict(int)
        aggregated_failure_counts = defaultdict(int)
        aggregated_category_counts = defaultdict(int)

        for scenario_id, trials in tqdm(scenario_data.items(), desc=f" {model_name} scenarios", leave=False):
            for trial_idx, trial in enumerate(trials):
                total = trial['total_tool_calls']
                failed = trial['failed_tool_calls']

                all_total_calls.append(total)
                all_failed_calls.append(failed)
                # Guard against trials with zero tool calls
                all_failure_rates.append(failed / total * 100 if total > 0 else 0)

                for tool_name, count in trial['tool_call_counts'].items():
                    aggregated_tool_counts[tool_name] += count

                for tool_name, count in trial['tool_failure_counts'].items():
                    aggregated_failure_counts[tool_name] += count

                # Count failure categories
                for failure in trial['failures']:
                    category = failure.get('category', 'other')
                    aggregated_category_counts[category] += 1

                    # Add to failure records (one row per individual failure)
                    failure_records.append({
                        'model': model_name,
                        'scenario': scenario_id,
                        'trial': trial_idx,
                        'tool_name': failure.get('tool_name', ''),
                        'category': category,
                        'error': failure.get('error', ''),
                        'exit_code': failure.get('exit_code'),
                        'output_snippet': failure.get('output_snippet', '')[:100],
                    })

                # One row per trial
                detail_records.append({
                    'model': model_name,
                    'scenario': scenario_id,
                    'trial': trial_idx,
                    'total_tool_calls': total,
                    'failed_tool_calls': failed,
                    'failure_rate_pct': failed / total * 100 if total > 0 else 0,
                })

        # Compute per-tool failure rates (aggregated over all trials)
        tool_failure_rates = {}
        for tool_name, total in aggregated_tool_counts.items():
            failures = aggregated_failure_counts.get(tool_name, 0)
            tool_failure_rates[tool_name] = {
                'total': total,
                'failures': failures,
                'rate': failures / total * 100 if total > 0 else 0
            }

        # One row per model; avg_failure_rate_pct averages the per-trial rates
        # (so every trial weighs equally regardless of its call volume).
        summary_records.append({
            'model': model_name,
            'n_scenarios': n_scenarios_filtered,
            'n_trials': len(all_total_calls),
            'total_tool_calls': sum(all_total_calls),
            'total_failed_calls': sum(all_failed_calls),
            'avg_tool_calls_per_trial': np.mean(all_total_calls),
            'avg_failed_calls_per_trial': np.mean(all_failed_calls),
            'avg_failure_rate_pct': np.mean(all_failure_rates),
            'std_failure_rate_pct': np.std(all_failure_rates),
            'failure_categories': dict(aggregated_category_counts),
            'tool_failure_rates': tool_failure_rates,
        })

    summary_df = pd.DataFrame(summary_records)
    detail_df = pd.DataFrame(detail_records)
    failures_df = pd.DataFrame(failure_records)

    return summary_df, detail_df, failures_df
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def save_data(summary_df: pd.DataFrame, detail_df: pd.DataFrame, failures_df: pd.DataFrame):
    """Write the three extracted DataFrames to CSV files under OUTPUT_DIR."""
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # (frame, destination) pairs; order determines the log order below.
    targets = [
        (summary_df, OUTPUT_DIR / "tool_failures_summary.csv"),
        (detail_df, OUTPUT_DIR / "tool_failures_detail.csv"),
        (failures_df, OUTPUT_DIR / "tool_failures_individual.csv"),
    ]
    for frame, path in targets:
        frame.to_csv(path, index=False)

    print(f"\nData saved to:")
    for _, path in targets:
        print(f" - {path}")
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def print_summary(summary_df: pd.DataFrame):
    """Print a fixed-width summary table, worst failure rate first."""
    print("\n" + "="*100)
    print("Tool Call Failure Summary")
    print("="*100)

    ordered = summary_df.sort_values("avg_failure_rate_pct", ascending=False)

    print(f"\n{'Model':<20} {'Trials':>8} {'Total Calls':>12} {'Failed':>10} {'Fail Rate':>10} {'Top Category':>20}")
    print("-" * 85)
    for _, row in ordered.iterrows():
        categories = row.get('failure_categories', {})
        # Most frequent failure category for this model, '-' when none.
        top_cat, top_cat_count = '-', 0
        if categories:
            top_cat = max(categories, key=categories.get)
            top_cat_count = categories[top_cat]

        print(f"{row['model']:<20} {row['n_trials']:>8} {row['total_tool_calls']:>12} "
              f"{row['total_failed_calls']:>10} {row['avg_failure_rate_pct']:>9.2f}% "
              f"{top_cat} ({top_cat_count})")
|
| 432 |
+
|
| 433 |
+
def plot_failure_rate_by_model(summary_df: pd.DataFrame):
    """
    Figure 1: Overall failure rate per model (horizontal bar chart),
    sorted so the lowest rate is at the bottom, with std-dev error bars.
    """
    plt.rcParams.update(PLOT_PARAMETERS)

    fig, ax = plt.subplots(figsize=(DOUBLE_COLUMN_WIDTH, 3.0))

    ordered = summary_df.sort_values("avg_failure_rate_pct", ascending=True)
    palette = get_color_palette(len(ordered))

    bars = ax.barh(ordered["model"], ordered["avg_failure_rate_pct"],
                   color=palette, edgecolor='black', linewidth=0.5)

    # Std-dev error bars across trials
    ax.errorbar(ordered["avg_failure_rate_pct"], range(len(ordered)),
                xerr=ordered["std_failure_rate_pct"], fmt='none',
                color='black', capsize=2, linewidth=0.5)

    # Annotate each bar with its mean rate, just past the error bar
    for bar, rate, spread in zip(bars, ordered["avg_failure_rate_pct"],
                                 ordered["std_failure_rate_pct"]):
        ax.text(rate + spread + 0.5, bar.get_y() + bar.get_height() / 2,
                f'{rate:.1f}%', va='center', ha='left', fontsize=MIN_FONT_SIZE - 1)

    ax.set_xlabel("Average Failure Rate (%)")
    ax.set_xlim(0, ordered["avg_failure_rate_pct"].max() + ordered["std_failure_rate_pct"].max() + 5)

    plt.title("Tool Call Failure Rate")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_failure_rate_by_model.png")
    plt.close(fig)
    print("Saved: fig_failure_rate_by_model.png")
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def parse_dict_column(col_str):
    """Parse a dictionary column stored as string.

    Args:
        col_str: A repr-style dict string (e.g. "{'a': 1}") as produced by
            DataFrame.to_csv, or NaN for a missing value.

    Returns:
        The parsed dict, or {} when the value is missing, empty, or malformed.
    """
    if pd.isna(col_str) or col_str == '{}':
        return {}
    try:
        return ast.literal_eval(col_str)
    except (ValueError, SyntaxError):
        # literal_eval raises ValueError for non-literal content and
        # SyntaxError for malformed input; a bare `except:` previously
        # swallowed everything, including KeyboardInterrupt.
        return {}
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def plot_failure_categories_stacked(summary_df: pd.DataFrame):
    """
    Figure 2: Failure category breakdown per model (stacked horizontal bar).

    Shows at most the 8 most frequent categories (by total count across all
    models); models are ordered by total failed calls.
    """
    plt.rcParams.update(PLOT_PARAMETERS)

    fig, ax = plt.subplots(figsize=(DOUBLE_COLUMN_WIDTH, 3.0))

    data = summary_df.copy()
    # NOTE: removed leftover debug print of data['failure_categories'] and the
    # commented-out parse_dict_column call; 'failure_categories' is expected to
    # already hold dicts here (as produced by extract_all_data).

    # Total count per category across all models (drives stacking order)
    all_categories = defaultdict(int)
    for cats in data['failure_categories']:
        for cat, count in cats.items():
            all_categories[cat] += count

    CATEGORY_COLORS = {
        'python_syntax': '#e41a1c',
        'python_type': '#377eb8',
        'python_name': '#4daf4a',
        'file_not_found': '#984ea3',
        'json_parse': '#ff7f00',
        'shell_command': '#a65628',
        'timeout': '#f781bf',
        'memory': '#999999',
        'other_python': '#66c2a5',
        'other': '#8da0cb',
    }

    # Sort categories by total count, keep the top 8
    sorted_cats = sorted(all_categories.keys(), key=lambda x: all_categories[x], reverse=True)[:8]

    # Build data for stacked bar
    data = data.sort_values('total_failed_calls', ascending=True)

    bottom = np.zeros(len(data))

    for cat in sorted_cats:
        values = [row['failure_categories'].get(cat, 0) for _, row in data.iterrows()]
        color = CATEGORY_COLORS.get(cat, '#888888')
        ax.barh(data['model'], values, left=bottom,
                label=cat.replace('_', ' ').title(), color=color,
                edgecolor='white', linewidth=0.3)
        bottom += values

    ax.set_xlabel("Number of Failed Tool Calls")
    ax.legend(loc='lower right', ncol=2, fontsize=MIN_FONT_SIZE - 1,
              framealpha=0.9, bbox_to_anchor=(1.0, 0.0))

    plt.title("Tool Failure Category Distribution")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_failure_categories_stacked.png")
    plt.close(fig)
    print("Saved: fig_failure_categories_stacked.png")
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def main():
    """Entry point: extract, persist, summarize, and plot tool-failure data."""
    print("Extracting tool call failure data for 'react with code' agents...")
    print(f"Reading from directories: {LEADERBOARD_DIR}")
    print(f"Output directory: {OUTPUT_DIR}")

    summary_df, detail_df, failures_df = extract_all_data()

    # Nothing qualified: bail out without writing files or plotting.
    if summary_df.empty:
        print("No data extracted!")
        return

    save_data(summary_df, detail_df, failures_df)
    print_summary(summary_df)

    plot_failure_categories_stacked(summary_df)


if __name__ == "__main__":
    main()
|
| 560 |
+
|
analysis_src/model_styles.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
Shared model styling configuration for paper analysis figures.

Provides consistent colors, markers, and display names across all agents
(EOG, React with Code, and future agents).

Usage:
    from paper_analysis.common.model_styles import get_model_style, MODEL_DISPLAY_NAMES

    style = get_model_style("GPT-5.1")
    ax.scatter(x, y, c=style['color'], marker=style['marker'], ...)
"""

# NOTE(review): seaborn is not referenced anywhere in this module's visible
# code — confirm it is unused before removing the import.
# NOTE(review): the usage example's import path (paper_analysis.common) looks
# stale relative to this file's location (analysis_src) — verify.
import seaborn as sns

# =============================================================================
# MODEL DISPLAY NAMES
# Maps various raw names (provider-prefixed IDs, version-stamped IDs, and
# already-shortened variants) to one standardized display name per model.
# =============================================================================

MODEL_DISPLAY_NAMES = {
    # OpenAI / Azure
    "Azure_gpt-5.1-2025-11-13": "GPT-5.1",
    "Azure_gpt-5.1-chat-2025-11-13": "GPT-5.1",
    "Azure_o4-mini": "o4-mini",
    "Azure_gpt-4o": "GPT-4o",
    "openai_gpt-oss-120b": "GPT-OSS-120B",
    "openai_gpt-oss-20b": "GPT-OSS-20B",
    # Google / GCP
    "GCP_gemini-2.5-pro": "Gemini 2.5 Pro",
    "gemini-2.5-pro": "Gemini 2.5 Pro",
    "Gemini-2.5-Pro": "Gemini 2.5 Pro",
    "gcp_gemini-3-pro-preview": "Gemini 3 Pro",
    "gemini-3-pro-preview": "Gemini 3 Pro",
    "Gemini-3-Pro": "Gemini 3 Pro",
    "gemini-3-flash-preview": "Gemini 3 Flash",
    "Gemini-3-Flash": "Gemini 3 Flash",
    "google_gemini-3-flash-preview": "Gemini 3 Flash",
    # Moonshot AI
    "moonshotai_kimi-k2-thinking": "Kimi K2",
    "kimi-k2-thinking": "Kimi K2",
    "Kimi-K2": "Kimi K2",
    # Anthropic / AWS
    "aws_claude-opus-4-5": "Claude Opus 4.5",
    "Claude-Opus-4.5": "Claude Opus 4.5",
    # Mistral AI
    "mistralai_mistral-large-2512": "Mistral Large",
    "Mistral-Large": "Mistral Large",
    # Alibaba / Qwen
    "qwen_qwen3-vl-32b-instruct": "Qwen3-VL-32B",
    # ServiceNow
    "ServiceNow-AI_Apriel-1.6-15b-Thinker": "Apriel-1.6-15B",
    # Minimax
    "minimax_minimax-m2.1": "Minimax M2.1",
}

# =============================================================================
# MODEL STYLES
# Defines color and marker for each model (by display name).
# Colors are colorblind-friendly; markers provide redundant encoding so plots
# remain readable in grayscale or for colorblind readers.
# =============================================================================

# Colorblind-friendly palette (based on IBM Design Library / Wong palette)
_COLORS = {
    'blue': '#0072B2',
    'orange': '#E69F00',
    'green': '#009E73',
    'pink': '#CC79A7',
    'light_blue': '#56B4E9',
    'yellow': '#F0E442',
    'red': '#D55E00',
    'gray': '#999999',
    'purple': '#9467BD',
    'brown': '#8C564B',
    'teal': '#17BECF',
}

# Matplotlib marker codes used for redundant (non-color) encoding.
_MARKERS = {
    'circle': 'o',
    'square': 's',
    'diamond': 'D',
    'triangle_up': '^',
    'triangle_down': 'v',
    'pentagon': 'p',
    'hexagon': 'h',
    'star': '*',
    'plus': 'P',
    'x': 'X',
}

# Model style definitions (display_name -> {color, marker}).
# Keys must be display names as produced by MODEL_DISPLAY_NAMES.
# NOTE(review): "Gemini 2.5 Pro" and "Qwen3-VL-32B" share the teal color
# (markers differ) — confirm this is acceptable for the figures.
MODEL_STYLES = {
    # Google models - shades of blue
    "Gemini 3 Flash": {
        'color': _COLORS['blue'],
        'marker': _MARKERS['circle'],
    },
    "Gemini 3 Pro": {
        'color': _COLORS['light_blue'],
        'marker': _MARKERS['circle'],
    },
    "Gemini 2.5 Pro": {
        'color': _COLORS['teal'],
        'marker': _MARKERS['circle'],
    },

    # OpenAI models - shades of green/orange
    "GPT-5.1": {
        'color': _COLORS['green'],
        'marker': _MARKERS['square'],
    },
    "GPT-4o": {
        'color': _COLORS['green'],
        'marker': _MARKERS['diamond'],
    },
    "o4-mini": {
        'color': _COLORS['yellow'],
        'marker': _MARKERS['square'],
    },
    "GPT-OSS-120B": {
        'color': _COLORS['orange'],
        'marker': _MARKERS['triangle_up'],
    },
    "GPT-OSS-20B": {
        'color': _COLORS['brown'],
        'marker': _MARKERS['triangle_down'],
    },

    # Anthropic models - pink
    "Claude Opus 4.5": {
        'color': _COLORS['pink'],
        'marker': _MARKERS['diamond'],
    },

    # Moonshot AI - red
    "Kimi K2": {
        'color': _COLORS['red'],
        'marker': _MARKERS['pentagon'],
    },

    # Mistral - purple
    "Mistral Large": {
        'color': _COLORS['purple'],
        'marker': _MARKERS['hexagon'],
    },

    # Minimax - gray
    "Minimax M2.1": {
        'color': _COLORS['gray'],
        'marker': _MARKERS['star'],
    },

    # Qwen - teal
    "Qwen3-VL-32B": {
        'color': _COLORS['teal'],
        'marker': _MARKERS['plus'],
    },

    # ServiceNow - brown
    "Apriel-1.6-15B": {
        'color': _COLORS['brown'],
        'marker': _MARKERS['x'],
    },
}

# Fallback style for models with no entry in MODEL_STYLES.
_DEFAULT_STYLE = {
    'color': _COLORS['gray'],
    'marker': _MARKERS['circle'],
}
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def get_display_name(raw_name: str) -> str:
    """Map a raw model identifier to its standardized display name.

    Unknown identifiers are returned unchanged.
    """
    try:
        return MODEL_DISPLAY_NAMES[raw_name]
    except KeyError:
        return raw_name
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def get_model_style(model_name: str) -> dict:
    """
    Look up the plotting style (color, marker) for a model.

    Args:
        model_name: Either a raw name or a display name

    Returns:
        Dict with 'color' and 'marker' keys; a neutral gray/circle
        default is returned for unknown models.
    """
    # First try the name as given, then its standardized display name.
    for candidate in (model_name, get_display_name(model_name)):
        if candidate in MODEL_STYLES:
            return MODEL_STYLES[candidate]

    # Unknown model: fall back to the shared default style.
    return _DEFAULT_STYLE
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def get_model_color(model_name: str) -> str:
    """Return only the plot color assigned to *model_name*."""
    style = get_model_style(model_name)
    return style['color']
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_model_marker(model_name: str) -> str:
    """Return only the plot marker assigned to *model_name*."""
    style = get_model_style(model_name)
    return style['marker']
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# For backward compatibility - create a color palette list
|
| 214 |
+
def get_color_palette(n_colors: int = 10):
    """Return the first *n_colors* entries of the colorblind-friendly palette.

    Kept for backward compatibility with callers expecting a plain color list.
    """
    palette_order = ['blue', 'orange', 'green', 'pink', 'light_blue',
                     'red', 'purple', 'brown', 'teal', 'yellow']
    selected = palette_order[:n_colors]
    return [_COLORS[name] for name in selected]
|
| 219 |
+
|
| 220 |
+
# Figure sizing constraints for ICML-format papers.
SINGLE_COLUMN_WIDTH = 3.25  # inches (ICML)
DOUBLE_COLUMN_WIDTH = 6.75  # inches (ICML)
MIN_FONT_SIZE = 10  # smallest readable font size in final print

# Matplotlib rcParams overrides for publication-quality figures
# (apply with plt.rcParams.update(PLOT_PARAMETERS)).
PLOT_PARAMETERS = {
    'font.size': MIN_FONT_SIZE,
    'font.family': 'serif',
    'axes.labelsize': MIN_FONT_SIZE + 1,
    'axes.titlesize': MIN_FONT_SIZE + 2,
    'xtick.labelsize': MIN_FONT_SIZE,
    'ytick.labelsize': MIN_FONT_SIZE,
    'legend.fontsize': MIN_FONT_SIZE,
    'figure.titlesize': MIN_FONT_SIZE + 2,
    'figure.dpi': 150,   # on-screen preview resolution
    'savefig.dpi': 300,  # print-quality export resolution
    'savefig.bbox': 'tight',
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.linewidth': 0.8,
    'lines.linewidth': 1.0,
    'patch.linewidth': 0.5,
}
|
analysis_src/utils.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
from pathlib import Path

# Model display names (short for figures)
# Follows ArtificialAnalysis.ai naming conventions
# NOTE(review): this table duplicates a subset of MODEL_DISPLAY_NAMES in
# analysis_src/model_styles.py (which has additional entries) — consider
# importing from one place so the two tables cannot drift apart.
MODEL_DISPLAY_NAMES = {
    # OpenAI / Azure
    "Azure_gpt-5.1-2025-11-13": "GPT-5.1",
    "Azure_gpt-5.1-chat-2025-11-13": "GPT-5.1",
    "Azure_o4-mini": "o4-mini",
    "Azure_gpt-4o": "GPT-4o",
    "openai_gpt-oss-120b": "GPT-OSS-120B",
    "openai_gpt-oss-20b": "GPT-OSS-20B",
    # Google / GCP
    "GCP_gemini-2.5-pro": "Gemini 2.5 Pro",
    "gemini-2.5-pro": "Gemini 2.5 Pro",
    "gcp_gemini-3-pro-preview": "Gemini 3 Pro",
    "gemini-3-pro-preview": "Gemini 3 Pro",
    "gemini-3-flash-preview": "Gemini 3 Flash",
    "google_gemini-3-flash-preview": "Gemini 3 Flash",
    # Moonshot AI
    "moonshotai_kimi-k2-thinking": "Kimi K2",
    "kimi-k2-thinking": "Kimi K2",
    # Anthropic / AWS
    "aws_claude-opus-4-5": "Claude Opus 4.5",
    # Mistral AI
    "mistralai_mistral-large-2512": "Mistral Large",
    # Alibaba / Qwen
    "qwen_qwen3-vl-32b-instruct": "Qwen3-VL-32B",
    # ServiceNow
    "ServiceNow-AI_Apriel-1.6-15b-Thinker": "Apriel-1.6-15B",
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_model_name(dirname: str) -> str:
    """Derive a display model name from an agent directory name.

    Strips the agent prefix and commit-hash suffix, then maps the remainder
    through MODEL_DISPLAY_NAMES (falling back to the stripped name).
    """
    stripped = dirname
    for token in ("react with code_", "_07ccdb1"):
        stripped = stripped.replace(token, "")
    return MODEL_DISPLAY_NAMES.get(stripped, stripped)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def find_react_with_code_dirs(leaderboard_dir: Path) -> list[Path]:
    """Locate all 'react with code' agent directories (non-backup), sorted."""
    matches = [
        entry
        for entry in leaderboard_dir.iterdir()
        if entry.is_dir()
        and entry.name.startswith("react with code_")
        and not entry.name.startswith("backup_")
    ]
    return sorted(matches)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def read_judge_outputs_from_dir(agent_dir: Path) -> dict[str, list[dict]]:
    """
    Read all judge_output.json files from an agent directory.

    Expects the layout Scenario*/<trial>/judge_output.json; scenarios with
    no readable judge outputs are omitted from the result.

    Returns:
        Dict mapping scenario_id -> list of judge outputs (one per trial)
    """
    scenario_data: dict[str, list[dict]] = {}

    for scenario_dir in agent_dir.iterdir():
        if not (scenario_dir.is_dir() and scenario_dir.name.startswith("Scenario")):
            continue

        trials = []
        # Trial subdirectories (1, 2, 3, ...) are read in sorted order so
        # trial indices are stable across runs.
        for trial_dir in sorted(scenario_dir.iterdir()):
            if not trial_dir.is_dir():
                continue
            judge_file = trial_dir / "judge_output.json"
            if not judge_file.exists():
                continue
            try:
                with open(judge_file) as fh:
                    trials.append(json.load(fh))
            except Exception as e:
                # Best-effort: a corrupt trial file is reported, not fatal.
                print(f" Warning: Error reading {judge_file}: {e}")

        if trials:
            scenario_data[scenario_dir.name] = trials

    return scenario_data
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def extract_trial_scores_from_judge_outputs(
    scenario_data: dict[str, list[dict]],
    metric: str
) -> dict[str, list[float]]:
    """
    Extract per-trial scores for a given metric from judge outputs.

    Missing or null metric values are coerced to 0.0; scenarios with an
    empty trial list are dropped from the result.

    Args:
        scenario_data: Dict mapping scenario_id -> list of judge outputs
        metric: The metric name to extract

    Returns:
        Dict mapping scenario_id -> list of trial scores
    """
    scenario_trials: dict[str, list[float]] = {}

    for scenario_id, trials in scenario_data.items():
        scores: list[float] = []
        for trial in trials:
            value = trial.get("flat_scores", {}).get(metric)
            # None/null means the judge produced no score for this metric.
            scores.append(0.0 if value is None else float(value))

        if scores:
            scenario_trials[scenario_id] = scores

    return scenario_trials
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_runs_stats(scenario_data: dict[str, list], min_runs_required: int) -> tuple[int, int, int, int]:
    """Summarize run counts: (n_scenarios, min_runs, max_runs, n_qualifying).

    All four values are zero when scenario_data is empty.
    """
    if not scenario_data:
        return 0, 0, 0, 0

    counts = [len(trials) for trials in scenario_data.values()]
    qualifying = len([c for c in counts if c >= min_runs_required])
    return len(counts), min(counts), max(counts), qualifying
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def filter_scenarios_with_min_runs(scenario_data: dict[str, list], min_runs_required: int) -> dict[str, list]:
    """Keep only scenarios that have at least min_runs_required runs."""
    filtered: dict[str, list] = {}
    for scenario_id, trials in scenario_data.items():
        if len(trials) >= min_runs_required:
            filtered[scenario_id] = trials
    return filtered
|
| 136 |
+
|
| 137 |
+
def find_latest_rollout_file(trial_dir: Path) -> Path:
    """Find the latest rollout file in a trial's sessions directory.

    Searches sessions/ recursively for files named rollout-*.jsonl and
    returns the one with the newest modification time.

    Args:
        trial_dir: Trial directory expected to contain a "sessions" subdir.

    Returns:
        Path to the latest rollout file, or None when the sessions
        directory is missing or contains no rollout files.
    """
    sessions_dir = trial_dir / "sessions"
    if not sessions_dir.exists():
        return None

    rollout_files = list(sessions_dir.rglob("rollout-*.jsonl"))
    if not rollout_files:
        return None

    # "Latest" is decided by mtime rather than by the timestamp embedded in
    # the filename (a name-sort variant was previously tried and removed).
    # NOTE(review): the two orderings can disagree if files are copied or
    # touched after creation — confirm mtime is the intended criterion.
    return max(rollout_files, key=lambda p: p.stat().st_mtime)
|
| 155 |
+
|
evaluation.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|