rohan-arora-ibm committed on
Commit
0b73078
·
unverified ·
1 Parent(s): 7225d57

bump: bringing in evaluation code from @ptoroisaza

Browse files
analysis_src/__pycache__/consistency.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
analysis_src/__pycache__/extract_consistency_data.cpython-310.pyc ADDED
Binary file (6.73 kB). View file
 
analysis_src/__pycache__/extract_discovery_trajectory.cpython-310.pyc ADDED
Binary file (21.9 kB). View file
 
analysis_src/__pycache__/extract_inference_data.cpython-310.pyc ADDED
Binary file (14.5 kB). View file
 
analysis_src/__pycache__/extract_majority_vote_data.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
analysis_src/__pycache__/extract_tool_failures.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
analysis_src/__pycache__/model_styles.cpython-310.pyc ADDED
Binary file (4.52 kB). View file
 
analysis_src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.85 kB). View file
 
analysis_src/consistency.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ # TODO: Needs to be integrated into the itbench_leaderboard module
4
+ # This script calculates ICC (Intraclass Correlation Coefficient) and other
5
+ # consistency metrics for agent evaluation results.
6
+
7
+ Consistency Analysis for Agent Leaderboard Results.
8
+
9
+ Computes ICC (Intraclass Correlation Coefficient) to measure the reliability
10
+ and consistency of agent responses across multiple trials per scenario.
11
+
12
+ ICC answers: "Of all the variance observed, how much is due to actual scenario
13
+ difficulty (signal) vs. random model variability (noise/flakiness)?"
14
+
15
+ Interpretation:
16
+ ICC > 0.9: Excellent consistency
17
+ ICC 0.75-0.9: Good consistency
18
+ ICC 0.5-0.75: Moderate consistency
19
+ ICC < 0.5: Poor consistency (high flakiness)
20
+
21
+ Usage:
22
+ python -m itbench_leaderboard.consistency --results-dir leaderboard_results/results
23
+ python -m itbench_leaderboard.consistency --results-file path/to/results.json
24
+ """
25
+
26
+ import argparse
27
+ import json
28
+ import sys
29
+ from dataclasses import dataclass, field
30
+ from pathlib import Path
31
+ from typing import Optional
32
+
33
+ import numpy as np
34
+
35
+
36
@dataclass
class ConsistencyMetrics:
    """Container for all consistency metrics.

    Produced by ``compute_icc`` and rendered via ``__str__`` by
    ``print_detailed_report``.
    """

    # Core ICC metrics
    icc: float  # ICC(1,1); NaN when it cannot be estimated
    flakiness_ratio: float  # 1 - ICC

    # ANOVA components
    msb: float  # Mean Square Between (between-scenario variance)
    msw: float  # Mean Square Within (within-scenario variance)

    # Within-scenario consistency
    mean_within_std: float  # average per-scenario std of trial scores
    mean_agreement_rate: float  # average pairwise agreement across scenarios
    repeatability_coefficient: float  # 1.96 * sqrt(2 * MSW)

    # Summary stats
    n_scenarios: int
    n_trials: int  # trials per scenario used in the computation
    n_flaky_scenarios: int
    flaky_scenarios: list = field(default_factory=list)  # (scenario_id, std) pairs

    # Per-scenario breakdown
    scenario_details: dict = field(default_factory=dict)  # scenario_id -> stats dict

    def __str__(self) -> str:
        # Multi-line human-readable summary.
        return (
            f"ICC: {self.icc:.4f} (flakiness: {self.flakiness_ratio:.4f})\n"
            f"MSB (between): {self.msb:.4f}, MSW (within): {self.msw:.4f}\n"
            f"Mean within-std: {self.mean_within_std:.4f}\n"
            f"Agreement rate: {self.mean_agreement_rate:.4f}\n"
            f"Repeatability coef: {self.repeatability_coefficient:.4f}\n"
            f"Flaky scenarios: {self.n_flaky_scenarios}/{self.n_scenarios}"
        )
71
+
72
+
73
def load_results(filepath: Path) -> dict:
    """Parse and return the JSON results document at *filepath*."""
    raw = Path(filepath).read_text()
    return json.loads(raw)
77
+
78
+
79
def extract_trial_scores(
    results: dict,
    metric: str = "root_cause_entity_f1"
) -> dict[str, list[float]]:
    """
    Extract per-trial scores for a given metric from results.

    Args:
        results: Loaded JSON results
        metric: The metric name to extract (default: root_cause_entity_f1)

    Returns:
        Dict mapping scenario_id -> list of trial scores; scenarios with no
        runs are omitted, and missing/null scores count as 0.0.
    """
    def _run_score(run: dict) -> float:
        value = run.get("scores", {}).get(metric)
        # A null score is treated as a zero-valued trial.
        return 0.0 if value is None else float(value)

    per_scenario = {
        scenario_id: [_run_score(run) for run in data.get("runs", [])]
        for scenario_id, data in results.get("scenarios", {}).items()
    }

    # Keep only scenarios that recorded at least one run.
    return {sid: trials for sid, trials in per_scenario.items() if trials}
113
+
114
+
115
def calculate_agreement_rate(trials: list[float], tolerance: float = 0.1) -> float:
    """
    Calculate agreement rate between trial pairs.

    Args:
        trials: List of trial scores
        tolerance: Maximum difference to consider as "agreement"

    Returns:
        Fraction of trial pairs that agree (0-1); 1.0 when fewer than
        two trials exist (nothing to disagree about).
    """
    from itertools import combinations

    if len(trials) < 2:
        return 1.0

    total_pairs = 0
    agreeing_pairs = 0
    for first, second in combinations(trials, 2):
        total_pairs += 1
        if abs(first - second) <= tolerance:
            agreeing_pairs += 1

    return agreeing_pairs / total_pairs
134
+
135
+
136
def compute_icc(scenario_trials: dict[str, list[float]]) -> ConsistencyMetrics:
    """
    Compute ICC(1,1) - one-way random effects model.

    The ICC formula:
        ICC = (MSB - MSW) / (MSB + (k-1) * MSW)

    Where:
        MSB = k * Var(scenario_means)   [between-scenario variance]
        MSW = Mean(Var(trials per scenario))   [within-scenario variance]
        k = number of trials per scenario

    Args:
        scenario_trials: Dict mapping scenario_id -> list of trial scores

    Returns:
        ConsistencyMetrics with ICC and related metrics
    """
    # Convert to numpy array
    scenarios = list(scenario_trials.keys())

    # Ensure all scenarios have same number of trials
    n_trials_list = [len(trials) for trials in scenario_trials.values()]
    if len(set(n_trials_list)) > 1:
        # Pad or truncate to minimum
        # (in practice only truncation happens: each scenario is cut to the
        # shortest trial count so `scores` is rectangular)
        k = min(n_trials_list)
        scores = np.array([scenario_trials[s][:k] for s in scenarios])
    else:
        k = n_trials_list[0] if n_trials_list else 0
        scores = np.array([scenario_trials[s] for s in scenarios])

    n_scenarios = len(scenarios)

    # Degenerate input (no scenarios or zero trials): ICC is undefined -> NaN.
    if n_scenarios == 0 or k == 0:
        return ConsistencyMetrics(
            icc=float('nan'),
            flakiness_ratio=float('nan'),
            msb=0.0,
            msw=0.0,
            mean_within_std=0.0,
            mean_agreement_rate=1.0,
            repeatability_coefficient=0.0,
            n_scenarios=0,
            n_trials=0,
            n_flaky_scenarios=0,
        )

    # Calculate scenario means
    scenario_means = np.mean(scores, axis=1)

    # Between-scenario variance (MSB)
    # MSB = k * Var(scenario means); a sample variance needs >= 2 scenarios
    msb = k * np.var(scenario_means, ddof=1) if n_scenarios > 1 else 0.0

    # Within-scenario variance (MSW)
    # MSW = average of within-scenario variances; zero when k == 1
    within_vars = np.var(scores, axis=1, ddof=1) if k > 1 else np.zeros(n_scenarios)
    msw = np.mean(within_vars)

    # ICC(1,1) formula
    denominator = msb + (k - 1) * msw
    if denominator > 0:
        icc = (msb - msw) / denominator
        icc = max(0.0, icc)  # ICC can be negative, clip to 0
    else:
        # Zero variance everywhere carries no information -> NaN; otherwise 0.
        icc = float('nan') if msw == 0 and msb == 0 else 0.0

    # Within-scenario standard deviations
    within_stds = np.std(scores, axis=1, ddof=1) if k > 1 else np.zeros(n_scenarios)
    mean_within_std = np.mean(within_stds)

    # Agreement rates
    # NOTE(review): these use the full (untruncated) trial lists, while the
    # variance terms above may have been truncated to k trials — confirm intended.
    agreement_rates = [
        calculate_agreement_rate(scenario_trials[s])
        for s in scenarios
    ]
    mean_agreement_rate = np.mean(agreement_rates)

    # Repeatability coefficient (95% of repeat differences < RC)
    rc = 1.96 * np.sqrt(2 * msw) if msw > 0 else 0.0

    # Identify flaky scenarios (high within-variance)
    flaky_threshold = 0.3
    flaky_scenarios = [
        (s, float(std))
        for s, std in zip(scenarios, within_stds)
        if std > flaky_threshold
    ]

    # Per-scenario details
    scenario_details = {}
    for i, s in enumerate(scenarios):
        scenario_details[s] = {
            "trials": scenario_trials[s],
            "mean": float(scenario_means[i]),
            "std": float(within_stds[i]),
            "agreement_rate": agreement_rates[i],
            "is_flaky": within_stds[i] > flaky_threshold,
        }

    return ConsistencyMetrics(
        icc=float(icc),
        flakiness_ratio=float(1 - icc) if not np.isnan(icc) else float('nan'),
        msb=float(msb),
        msw=float(msw),
        mean_within_std=float(mean_within_std),
        mean_agreement_rate=float(mean_agreement_rate),
        repeatability_coefficient=float(rc),
        n_scenarios=n_scenarios,
        n_trials=k,
        n_flaky_scenarios=len(flaky_scenarios),
        flaky_scenarios=flaky_scenarios,
        scenario_details=scenario_details,
    )
250
+
251
+
252
def analyze_results_file(
    filepath: Path,
    metrics: list[str] | None = None,
) -> dict[str, ConsistencyMetrics]:
    """
    Analyze a single results file for multiple metrics.

    Args:
        filepath: Path to the results JSON file
        metrics: List of metrics to analyze. Defaults to common metrics.

    Returns:
        Dict mapping metric_name -> ConsistencyMetrics (metrics with no
        extracted trials are omitted)
    """
    chosen = metrics if metrics is not None else [
        "root_cause_entity_f1",
        "root_cause_proximity_with_fp_f1",
        "propagation_chain",
    ]

    results = load_results(filepath)

    analysis: dict[str, ConsistencyMetrics] = {}
    for name in chosen:
        trials = extract_trial_scores(results, name)
        if not trials:
            continue
        analysis[name] = compute_icc(trials)

    return analysis
282
+
283
+
284
def compare_models(
    results_dir: Path,
    model_patterns: list[str],
    metric: str = "root_cause_entity_f1",
) -> dict[str, ConsistencyMetrics]:
    """
    Compare ICC across multiple models.

    Args:
        results_dir: Directory containing results JSON files
        model_patterns: List of model name patterns to match
        metric: The metric to analyze

    Returns:
        Dict mapping model_name -> ConsistencyMetrics
    """
    comparison: dict[str, ConsistencyMetrics] = {}

    for pattern in model_patterns:
        # Find matching file
        candidates = list(results_dir.glob(f"*{pattern}*.json"))
        if not candidates:
            print(f"Warning: No file found for pattern '{pattern}'", file=sys.stderr)
            continue

        # Use the first file glob returns for the pattern.
        result_path = candidates[0]
        print(f"Analyzing: {result_path.name}")

        data = load_results(result_path)
        trials = extract_trial_scores(data, metric)
        if not trials:
            continue

        label = data.get("agent_name", result_path.stem)
        comparison[label] = compute_icc(trials)

    return comparison
321
+
322
+
323
def print_comparison_table(
    comparison: dict[str, ConsistencyMetrics],
    metric: str,
) -> None:
    """Render the per-model ICC comparison as a fixed-width text table."""
    print(f"\n{'='*80}")
    print(f"ICC Comparison for metric: {metric}")
    print(f"{'='*80}\n")

    # Header
    print(f"{'Model':<55} {'ICC':>8} {'Flaky%':>8} {'Std':>8} {'Agree%':>8}")
    print("-" * 91)

    def _rank(item):
        # NaN ICCs sink to the bottom of the descending sort.
        value = item[1].icc
        return value if not np.isnan(value) else -1

    # Sort by ICC descending
    for model, metrics in sorted(comparison.items(), key=_rank, reverse=True):
        # Truncate model name if too long
        display_name = model[:52] + "..." if len(model) > 55 else model

        icc_str = f"{metrics.icc:.4f}" if not np.isnan(metrics.icc) else "N/A"
        flaky_pct = f"{metrics.flakiness_ratio*100:.1f}%" if not np.isnan(metrics.flakiness_ratio) else "N/A"

        print(
            f"{display_name:<55} "
            f"{icc_str:>8} "
            f"{flaky_pct:>8} "
            f"{metrics.mean_within_std:>8.4f} "
            f"{metrics.mean_agreement_rate*100:>7.1f}%"
        )

    print("\nInterpretation:")
    print(" ICC > 0.9: Excellent consistency")
    print(" ICC 0.75-0.9: Good consistency")
    print(" ICC 0.5-0.75: Moderate consistency")
    print(" ICC < 0.5: Poor consistency (high flakiness)")
363
+
364
+
365
def print_detailed_report(
    model_name: str,
    metrics_analysis: dict[str, ConsistencyMetrics],
) -> None:
    """Print a per-metric consistency report, highlighting flaky scenarios."""
    banner = "=" * 80
    print(f"\n{banner}")
    print(f"Detailed Consistency Report: {model_name}")
    print(f"{banner}\n")

    for metric_name, cm in metrics_analysis.items():
        print(f"\n--- {metric_name} ---")
        print(cm)

        if not cm.flaky_scenarios:
            continue

        print(f"\nFlaky scenarios (std > 0.3):")
        # Show at most the ten scenarios with the highest within-trial std.
        worst_first = sorted(cm.flaky_scenarios, key=lambda pair: pair[1], reverse=True)
        for scenario, std in worst_first[:10]:
            trials = cm.scenario_details.get(scenario, {}).get("trials", [])
            print(f" {scenario}: std={std:.3f}, trials={trials}")
384
+
385
+
386
def main():
    """CLI entry point: run single-file analysis or multi-model comparison.

    Two modes:
      * --results-file: detailed report for one results JSON.
      * default: compare models matched by --models patterns in --results-dir.
    Optionally writes a summary JSON via --output-json.
    """
    parser = argparse.ArgumentParser(
        description="Calculate ICC and consistency metrics for leaderboard results",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )

    parser.add_argument(
        "--results-dir",
        type=Path,
        default=Path("leaderboard_results/results"),
        help="Directory containing results JSON files",
    )

    parser.add_argument(
        "--results-file",
        type=Path,
        help="Analyze a single results file",
    )

    parser.add_argument(
        "--models",
        nargs="+",
        default=[
            "react with code_Azure_o4-mini",
            "react with code_Azure_gpt-5.1-2025-11-13",
            "react with code_gcp_gemini-3-pro-preview",
            "react with code_GCP_gemini-2.5-pro",
        ],
        help="Model name patterns to compare",
    )

    parser.add_argument(
        "--metric",
        type=str,
        default="root_cause_entity_f1",
        help="Metric to analyze (default: root_cause_entity_f1)",
    )

    parser.add_argument(
        "--all-metrics",
        action="store_true",
        help="Analyze all common metrics",
    )

    parser.add_argument(
        "--detailed",
        action="store_true",
        help="Show detailed per-scenario breakdown",
    )

    parser.add_argument(
        "--output-json",
        type=Path,
        help="Save results to JSON file",
    )

    args = parser.parse_args()

    # Determine metrics to analyze
    if args.all_metrics:
        metrics = [
            "root_cause_entity_f1",
            "root_cause_entity_precision",
            "root_cause_entity_recall",
            "root_cause_proximity_with_fp_f1",
            "propagation_chain",
            "fault_localization_component_identification",
        ]
    else:
        metrics = [args.metric]

    # model_name -> metric -> summary-dict, serialized at the end if requested.
    results_to_save = {}

    if args.results_file:
        # Single file analysis
        print(f"Analyzing: {args.results_file}")
        analysis = analyze_results_file(args.results_file, metrics)

        results = load_results(args.results_file)
        model_name = results.get("agent_name", args.results_file.stem)

        print_detailed_report(model_name, analysis)

        # NOTE(review): unlike the comparison branch below, NaN ICC values are
        # stored as-is here; json.dump emits them as non-standard `NaN` literals.
        results_to_save[model_name] = {
            m: {
                "icc": cm.icc,
                "flakiness_ratio": cm.flakiness_ratio,
                "mean_within_std": cm.mean_within_std,
                "mean_agreement_rate": cm.mean_agreement_rate,
                "n_flaky_scenarios": cm.n_flaky_scenarios,
                "n_scenarios": cm.n_scenarios,
            }
            for m, cm in analysis.items()
        }
    else:
        # Multi-model comparison
        for metric in metrics:
            comparison = compare_models(args.results_dir, args.models, metric)
            print_comparison_table(comparison, metric)

            # Store results (NaN mapped to None for valid JSON)
            for model, cm in comparison.items():
                if model not in results_to_save:
                    results_to_save[model] = {}
                results_to_save[model][metric] = {
                    "icc": cm.icc if not np.isnan(cm.icc) else None,
                    "flakiness_ratio": cm.flakiness_ratio if not np.isnan(cm.flakiness_ratio) else None,
                    "mean_within_std": cm.mean_within_std,
                    "mean_agreement_rate": cm.mean_agreement_rate,
                    "n_flaky_scenarios": cm.n_flaky_scenarios,
                    "n_scenarios": cm.n_scenarios,
                }

            if args.detailed:
                for model, cm in comparison.items():
                    print_detailed_report(model, {metric: cm})

    # Save to JSON if requested
    if args.output_json:
        with open(args.output_json, "w") as f:
            json.dump(results_to_save, f, indent=2)
        print(f"\nResults saved to: {args.output_json}")


if __name__ == "__main__":
    main()
513
+
analysis_src/extract_consistency_data.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract consistency (ICC) and performance data for all 'react with code' agents.
4
+
5
+ This script reads directly from the run directories (not JSON result files)
6
+ to ensure all trials are captured.
7
+
8
+ Output is saved to paper_analysis/react with code/resources/figures/consistency/ as CSV files for plotting.
9
+ """
10
+
11
+ import json
12
+ import sys
13
+ from pathlib import Path
14
+ from dataclasses import dataclass, field
15
+ import numpy as np
16
+ import pandas as pd
17
+ from tqdm import tqdm
18
+
19
+ # Add project root to path
20
+ PROJECT_ROOT = Path(__file__).parent.parent
21
+ sys.path.insert(0, str(PROJECT_ROOT))
22
+
23
+ from src.consistency import (
24
+ compute_icc,
25
+ ConsistencyMetrics,
26
+ )
27
+ from src.utils import (
28
+ get_model_name,
29
+ find_react_with_code_dirs,
30
+ read_judge_outputs_from_dir,
31
+ extract_trial_scores_from_judge_outputs,
32
+ get_runs_stats,
33
+ filter_scenarios_with_min_runs,
34
+ )
35
+
36
# Paths
LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"
# NOTE(review): RESULTS_JSON_DIR is not referenced elsewhere in this script's
# visible code — confirm whether it is still needed.
RESULTS_JSON_DIR = LEADERBOARD_DIR / "results"
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "consistency"

# Minimum runs per scenario required for inclusion
MIN_RUNS_PER_SCENARIO = 3

# Minimum scenarios needed after filtering (must have at least this many with 3+ runs)
MIN_QUALIFYING_SCENARIOS = 20

# Metrics to analyze
METRICS = [
    "root_cause_entity_f1",
    "root_cause_entity_precision",
    "root_cause_entity_recall",
    "root_cause_proximity_with_fp_f1",
    "propagation_chain",
    "fault_localization_component_identification",
]

# Short names for display (raw metric name -> table/plot label)
METRIC_SHORT_NAMES = {
    "root_cause_entity_f1": "RC Entity F1",
    "root_cause_entity_precision": "RC Entity Prec",
    "root_cause_entity_recall": "RC Entity Rec",
    "root_cause_proximity_with_fp_f1": "RC Proximity F1",
    "propagation_chain": "Prop. Chain",
    "fault_localization_component_identification": "Fault Loc.",
}
66
+
67
def extract_all_data() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Extract ICC and performance data for all agents by reading from directories.

    Returns:
        - icc_df: ICC scores per model per metric
        - perf_df: Performance averages per model per metric
        - scenario_df: Per-scenario breakdown (root_cause_entity_f1 only)
    """
    agent_dirs = find_react_with_code_dirs(LEADERBOARD_DIR)
    print(f"Found {len(agent_dirs)} 'react with code' agent directories:")
    for d in agent_dirs:
        print(f" - {d.name}")

    icc_records = []
    perf_records = []
    scenario_records = []

    valid_models = []
    skipped_models = []  # (model_name, reason, n_qualifying) triples

    for agent_dir in tqdm(agent_dirs, desc="Reading agent data"):
        model_name = get_model_name(agent_dir.name)

        print(f"\nReading: {agent_dir.name}")
        scenario_data = read_judge_outputs_from_dir(agent_dir)

        # NOTE(review): min_runs / max_runs are returned but never used below.
        n_scenarios, min_runs, max_runs, n_qualifying = get_runs_stats(scenario_data, MIN_RUNS_PER_SCENARIO)

        if n_scenarios == 0:
            print(f" SKIPPING {model_name}: No judge outputs found")
            skipped_models.append((model_name, "No data", 0))
            continue

        # Require a minimum population of sufficiently-sampled scenarios.
        if n_qualifying < MIN_QUALIFYING_SCENARIOS:
            print(f" SKIPPING {model_name}: Only {n_qualifying}/{n_scenarios} scenarios have {MIN_RUNS_PER_SCENARIO}+ runs")
            skipped_models.append((model_name, f"{n_qualifying}/{n_scenarios} qualifying", n_qualifying))
            continue

        # Filter to only include scenarios with enough runs
        scenario_data = filter_scenarios_with_min_runs(scenario_data, MIN_RUNS_PER_SCENARIO)
        n_scenarios_filtered = len(scenario_data)

        print(f" Processing: {model_name} ({n_scenarios_filtered} scenarios with {MIN_RUNS_PER_SCENARIO}+ runs)")
        valid_models.append(model_name)

        for metric in tqdm(METRICS, desc=f" {model_name} metrics", leave=False):
            # Extract trial scores
            scenario_trials = extract_trial_scores_from_judge_outputs(scenario_data, metric)

            if not scenario_trials:
                continue

            # Calculate performance average across every trial of every scenario
            all_scores = [s for trials in scenario_trials.values() for s in trials]
            perf_avg = np.mean(all_scores) if all_scores else 0.0

            perf_records.append({
                "model": model_name,
                "metric": METRIC_SHORT_NAMES.get(metric, metric),
                "metric_raw": metric,
                "performance": perf_avg,
            })

            # ICC calculation
            try:
                icc_metrics = compute_icc(scenario_trials)

                # NaN ICC is coerced to 0.0 (and flakiness to 1.0) so the
                # downstream CSV stays numeric.
                icc_records.append({
                    "model": model_name,
                    "metric": METRIC_SHORT_NAMES.get(metric, metric),
                    "metric_raw": metric,
                    "icc": icc_metrics.icc if not np.isnan(icc_metrics.icc) else 0.0,
                    "flakiness": icc_metrics.flakiness_ratio if not np.isnan(icc_metrics.flakiness_ratio) else 1.0,
                    "within_std": icc_metrics.mean_within_std,
                    "agreement_rate": icc_metrics.mean_agreement_rate,
                    "n_flaky_scenarios": icc_metrics.n_flaky_scenarios,
                    "n_scenarios": icc_metrics.n_scenarios,
                })

                # Per-scenario data (only for root_cause_entity_f1)
                if metric == "root_cause_entity_f1":
                    for scenario_id, details in icc_metrics.scenario_details.items():
                        scenario_records.append({
                            "model": model_name,
                            "scenario": scenario_id,
                            "mean": details["mean"],
                            "std": details["std"],
                            "trials": details["trials"],
                            "is_flaky": details["is_flaky"],
                        })
            except Exception as e:
                # Best-effort: one failing metric should not abort the model.
                print(f" Error computing ICC for {metric}: {e}")
                continue

    if skipped_models:
        print(f"\n⚠️ Skipped {len(skipped_models)} models:")
        for name, reason, _ in skipped_models:
            print(f" - {name}: {reason}")

    print(f"\n✓ Included {len(valid_models)} models: {valid_models}")

    icc_df = pd.DataFrame(icc_records)
    perf_df = pd.DataFrame(perf_records)
    scenario_df = pd.DataFrame(scenario_records)

    return icc_df, perf_df, scenario_df
174
+
175
+
176
def save_data(icc_df: pd.DataFrame, perf_df: pd.DataFrame, scenario_df: pd.DataFrame):
    """Save extracted data to CSV files plus a small summary JSON."""
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    outputs = [
        (OUTPUT_DIR / "icc_data.csv", icc_df),
        (OUTPUT_DIR / "performance_data.csv", perf_df),
        (OUTPUT_DIR / "scenario_data.csv", scenario_df),
    ]
    for path, frame in outputs:
        frame.to_csv(path, index=False)

    print(f"\nData saved to:")
    for path, _ in outputs:
        print(f" - {path}")

    # Also save a summary JSON
    summary = {
        "models": icc_df["model"].unique().tolist(),
        "metrics": icc_df["metric"].unique().tolist(),
        "n_scenarios": int(icc_df["n_scenarios"].max()) if len(icc_df) > 0 else 0,
        "min_runs_required": MIN_RUNS_PER_SCENARIO,
    }

    summary_path = OUTPUT_DIR / "analysis_summary.json"
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)
    print(f" - {summary_path}")
205
+
206
+
207
def print_summary(icc_df: pd.DataFrame, perf_df: pd.DataFrame):
    """Print summary tables."""
    rule = "=" * 80

    # --- ICC table ---
    print("\n" + rule)
    print("ICC Summary (root_cause_entity_f1)")
    print(rule)

    icc_rows = (
        icc_df[icc_df["metric_raw"] == "root_cause_entity_f1"]
        .copy()
        .sort_values("icc", ascending=False)
    )

    print(f"\n{'Model':<20} {'ICC':>8} {'Flaky%':>8} {'Std':>8} {'Agree%':>8}")
    print("-" * 56)
    for _, row in icc_rows.iterrows():
        print(f"{row['model']:<20} {row['icc']:>8.4f} {row['flakiness']*100:>7.1f}% {row['within_std']:>8.4f} {row['agreement_rate']*100:>7.1f}%")

    # --- Performance table ---
    print("\n" + rule)
    print("Performance Summary (root_cause_entity_f1)")
    print(rule)

    perf_rows = (
        perf_df[perf_df["metric_raw"] == "root_cause_entity_f1"]
        .copy()
        .sort_values("performance", ascending=False)
    )

    print(f"\n{'Model':<20} {'Avg Score':>12}")
    print("-" * 34)
    for _, row in perf_rows.iterrows():
        print(f"{row['model']:<20} {row['performance']:>12.4f}")
232
+
233
+
234
def main():
    """Entry point: extract, persist, and summarize consistency data."""
    header = [
        "Extracting consistency data for 'react with code' agents...",
        f"Reading from directories: {LEADERBOARD_DIR}",
        f"Output directory: {OUTPUT_DIR}",
        f"Minimum runs per scenario: {MIN_RUNS_PER_SCENARIO}",
    ]
    for line in header:
        print(line)

    icc_df, perf_df, scenario_df = extract_all_data()

    # Nothing qualified for analysis — bail out early.
    if len(icc_df) == 0:
        print("No data extracted!")
        return

    save_data(icc_df, perf_df, scenario_df)
    print_summary(icc_df, perf_df)


if __name__ == "__main__":
    main()
analysis_src/extract_discovery_trajectory.py ADDED
@@ -0,0 +1,928 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Root Cause Discovery Trajectory Analysis
4
+
5
+ Analyzes how agents discover root cause entities:
6
+ - T_encounter: When GT entity first appears in tool output
7
+ - T_investigate: When agent actively queries GT entity
8
+ - T_assert: When agent asserts GT entity as root cause
9
+ - T_exonerate: When agent dismisses GT entity (if ever)
10
+ - T_recover: When agent corrects after exoneration
11
+
12
+ Metrics computed:
13
+ - Discovery efficiency (how early GT appears)
14
+ - Investigation delay (turns between seeing and investigating)
15
+ - Assertion delay (turns to confirm after investigating)
16
+ - Recovery rate (% of trials with successful recovery)
17
+ """
18
+
19
+ import json
20
+ import sys
21
+ import re
22
+ import yaml
23
+ from pathlib import Path
24
+ from dataclasses import dataclass, field, asdict
25
+ from typing import Optional, List, Dict, Any
26
+ import numpy as np
27
+ import pandas as pd
28
+ import matplotlib.pyplot as plt
29
+ import seaborn as sns
30
+ import plotly.graph_objects as go
31
+ from tqdm import tqdm
32
+
33
+ PROJECT_ROOT = Path(__file__).parent.parent
34
+ sys.path.insert(0, str(PROJECT_ROOT))
35
+
36
+ from src.utils import find_latest_rollout_file
37
+
38
+ from src.model_styles import (
39
+ get_display_name, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, PLOT_PARAMETERS
40
+ )
41
+
42
# Improved regex to capture standard K8s resource patterns: namespace/Kind/name
# Captures: (namespace, Kind, name)
K8S_ENTITY_PATTERN = re.compile(r'([\w-]+)/(Deployment|Service|Pod|ReplicaSet|ResourceQuota|StatefulSet|DaemonSet|Job|CronJob|ConfigMap|Secret|Endpoints|Ingress|PersistentVolumeClaim|PersistentVolume|ServiceAccount|Role|RoleBinding|ClusterRole|ClusterRoleBinding|NetworkPolicy|HorizontalPodAutoscaler|Node)/([\w-]+)', re.IGNORECASE)

def extract_k8s_entities(text: str) -> List[str]:
    """Extract all K8s entities matching the standard pattern.

    Each match is normalized to the "namespace/Kind/name" form, preserving
    the casing found in *text*.
    """
    return [
        f"{namespace}/{kind}/{name}"
        for namespace, kind, name in K8S_ENTITY_PATTERN.findall(text)
    ]
55
+
56
# Paths
# NOTE(review): PROJECT_ROOT is already defined identically near the top of this
# module (before sys.path is patched); this re-definition is redundant but harmless.
PROJECT_ROOT = Path(__file__).parent.parent
LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"
GT_DIR = PROJECT_ROOT / "data" / "itbench-snapshots"
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "discovery"
61
+
62
@dataclass
class GroundTruth:
    """Ground truth root cause entity info for one scenario."""
    scenario: str  # scenario identifier
    entity_name: str  # name of the root-cause entity
    entity_kind: str  # K8s kind of the root-cause entity
    group_id: str  # group identifier of the root-cause entity
    filters: List[str]  # regex patterns to match entity
    aliases: List[str]  # related entity group IDs
    propagation_entities: set = field(default_factory=set)  # All entities involved in propagation
    all_entities: list = field(default_factory=list)  # All entities defined in the scenario
    entity_filters: Dict[str, List[str]] = field(default_factory=dict)  # group_id -> filters mapping for all entities
74
+
75
+
76
@dataclass
class EntityMention:
    """A mention of an entity in the agent's trajectory.

    One record per place the ground-truth entity (or a related entity)
    surfaced while replaying a rollout file.
    """
    turn: int            # 1-based turn number the mention occurred in
    mention_type: str  # 'encounter', 'investigate', 'assert', 'exonerate'
    context: str  # 'tool_output', 'tool_args', 'reasoning', 'final_output'
    text_snippet: str    # truncated excerpt of the surrounding text
    sentiment: str  # 'positive', 'negative', 'neutral'
84
+
85
+
86
@dataclass
class TrajectoryAnalysis:
    """Analysis results for a single trial (one rollout of one scenario).

    Produced by ``parse_rollout``; ``model``/``scenario``/``trial`` and the
    judge-derived fields are filled in afterwards by ``analyze_model``.
    """
    model: str
    scenario: str
    trial: int
    total_turns: int
    gt_entity: str       # ground-truth entity name this trial was scored against

    # Key timestamps (turn numbers, None if not found)
    t_encounter: Optional[int] = None    # GT first appeared in a tool OUTPUT (passive)
    t_investigate: Optional[int] = None  # GT first appeared in tool ARGUMENTS (active)
    t_assert: Optional[int] = None       # GT first declared as root cause
    t_exonerate: Optional[int] = None    # GT first ruled out
    t_recover: Optional[int] = None      # assertion after an earlier exoneration

    # Final outcome (from judge scores if available)
    final_success: bool = False  # Did the final answer include GT?
    root_cause_f1: Optional[float] = None

    # Pipeline stage reached (for funnel analysis)
    # 0=none, 1=encounter, 2=investigate, 3=assert, 4=success
    max_stage_reached: int = 0

    # All mentions for detailed analysis
    mentions: List[EntityMention] = field(default_factory=list)

    # Exploration metrics
    total_entities_available: int = 0     # proxy: equals unique_entities_encountered
    unique_entities_encountered: int = 0  # distinct entities seen in tool outputs
    unique_entities_investigated: int = 0 # distinct entities queried via tool args
    exploration_ratio: float = 0.0  # investigated / available

    # Coverage metrics
    on_chain_investigated: int = 0        # groups on the fault-propagation chain that were queried
    off_chain_investigated: int = 0  # Detoured
    propagation_coverage: float = 0.0  # % of chain entities investigated
    detour_rate: float = 0.0  # off_chain / total_investigated

    # Computed metrics
    discovery_efficiency: Optional[float] = None  # t_encounter / total_turns
    investigation_delay: Optional[int] = None  # t_investigate - t_encounter
    assertion_delay: Optional[int] = None  # t_assert - t_investigate
    had_recovery: bool = False  # True when t_exonerate < t_assert
130
+
131
+
132
def check_entity_match(text: str, entity_info: Dict) -> bool:
    """Check if text matches an arbitrary entity definition.

    An entity matches when its ``id`` appears in the text, or when any of
    its ``filter`` patterns (with regex syntax stripped down to a plain
    substring) appears in the text. Comparison is case-insensitive.
    """
    haystack = text.lower()

    # Direct match on the entity id, when one is defined.
    ident = entity_info.get('id', '').lower()
    if ident and ident in haystack:
        return True

    # Fall back to filter patterns: strip the regex machinery so a plain
    # substring test is enough; empty terms never match.
    stripped_terms = (
        pattern.replace('\\b', '').replace('-.*', '').replace('.*', '')
        for pattern in entity_info.get('filter', [])
    )
    return any(term and term.lower() in haystack for term in stripped_terms)
149
+
150
+
151
def load_ground_truth(scenario: str) -> Optional[GroundTruth]:
    """Load and parse ground truth YAML for a scenario.

    Reads ``GT_DIR/<scenario>/ground_truth.yaml`` and returns a populated
    ``GroundTruth``, or None when the file is missing or no group is
    flagged ``root_cause: true``.
    """
    gt_path = GT_DIR / scenario / "ground_truth.yaml"
    if not gt_path.exists():
        return None

    with open(gt_path) as f:
        gt_data = yaml.safe_load(f)

    # Find the root cause group: the first group flagged root_cause.
    root_cause_group = None
    all_groups = gt_data.get('groups', [])

    for group in all_groups:
        if group.get('root_cause', False):
            root_cause_group = group
            break

    if not root_cause_group:
        return None

    # Get fault entity info (only the first fault entry is used).
    fault_list = gt_data.get('fault', [])
    fault_info = fault_list[0] if fault_list else {}
    entity_info = fault_info.get('entity', {})

    # Collect all aliases: every alias group that contains the root-cause
    # id contributes all of its members (the root-cause id itself included).
    aliases = []
    for alias_group in gt_data.get('aliases', []):
        if root_cause_group['id'] in alias_group:
            aliases.extend(alias_group)

    # Collect all entities in the propagation chain (both endpoints of
    # every propagation edge).
    propagation_entities = set()
    for prop in gt_data.get('propagations', []):
        if 'source' in prop:
            propagation_entities.add(prop['source'])
        if 'target' in prop:
            propagation_entities.add(prop['target'])

    # Add root cause itself if not already there (it should be as source)
    propagation_entities.add(root_cause_group['id'])

    # Build entity_filters mapping: group_id -> list of filter patterns.
    # The group id and the 'name' field double as extra filters so plain
    # substring matching can find the entity too.
    entity_filters = {}
    for group in all_groups:
        group_id = group.get('id', '')
        filters = group.get('filter', [])
        # Also use the group id itself and 'name' field as filters
        name = group.get('name', '')
        all_filters = list(filters) if filters else []
        if group_id:
            all_filters.append(group_id)
        if name and name != group_id:
            all_filters.append(name)
        entity_filters[group_id] = all_filters

    gt_obj = GroundTruth(
        scenario=scenario,
        entity_name=entity_info.get('name', root_cause_group['id']),
        entity_kind=root_cause_group.get('kind', 'Unknown'),
        group_id=root_cause_group['id'],
        filters=root_cause_group.get('filter', []),
        aliases=aliases,
        propagation_entities=propagation_entities,
        entity_filters=entity_filters
    )

    # Attach all entities for exploration analysis
    gt_obj.all_entities = all_groups
    return gt_obj
222
+
223
+
224
def entity_matches(text: str, gt: GroundTruth) -> bool:
    """Check if *text* mentions the ground truth entity.

    Matching is deliberately fuzzy and case-insensitive: the entity name,
    the group id (with hyphens normalized to spaces or removed), any filter
    pattern (regex syntax stripped to a substring), and any alias all count
    as a mention.
    """
    text_lower = text.lower()

    # Check direct name match
    if gt.entity_name.lower() in text_lower:
        return True

    # Check group_id match, hyphen-insensitively, so "cart-service",
    # "cart service" and "cartservice" all line up.
    if gt.group_id.lower().replace('-', ' ') in text_lower.replace('-', ' '):
        return True
    if gt.group_id.lower().replace('-', '') in text_lower.replace('-', ''):
        return True

    # Check filter patterns: strip regex syntax and substring-match.
    # Guard against patterns that reduce to the empty string (e.g. ".*"),
    # which would otherwise match EVERY text ('' is a substring of
    # anything) — consistent with check_entity_match.
    for pattern in gt.filters:
        search_term = pattern.replace('\\b', '').replace('-.*', '').replace('.*', '')
        if search_term and search_term.lower() in text_lower:
            return True

    # Check aliases (hyphen-insensitive); skip empty alias strings for the
    # same match-everything reason.
    for alias in gt.aliases:
        alias_term = alias.replace('-', ' ').lower()
        if alias_term and alias_term in text_lower.replace('-', ' '):
            return True

    return False
252
+
253
+
254
def is_entity_on_chain(entity_str: str, gt: GroundTruth) -> Optional[str]:
    """
    Check if an entity string matches any entity in the fault propagation chain.
    Returns the matched group_id if on-chain, None if off-chain.

    entity_str: e.g., "otel-demo/Pod/frontend-abc123" or just "frontend"
    """
    needle = entity_str.lower()

    for group_id in gt.propagation_entities:
        gid = group_id.lower()

        # Substring containment in either direction counts as a match.
        if gid in needle or needle in gid:
            return group_id

        # Otherwise try the group's filter patterns with regex syntax
        # removed; terms of <= 2 chars are skipped to avoid spurious hits.
        for pattern in gt.entity_filters.get(group_id, []):
            term = (pattern.replace('\\b', '')
                           .replace('-.*', '')
                           .replace('.*', '')
                           .replace('\\', ''))
            if len(term) > 2 and term.lower() in needle:
                return group_id

    return None
280
+
281
+
282
def get_entity_group_match(entity_str: str, gt: GroundTruth) -> Optional[str]:
    """
    Check if an entity string matches any entity group in the scenario.
    Returns the matched group_id if found, None otherwise.
    """
    needle = entity_str.lower()

    def overlaps(candidate: str) -> bool:
        # Case-insensitive substring containment in either direction.
        lowered = candidate.lower()
        return lowered in needle or needle in lowered

    for group in gt.all_entities:
        gid = group.get('id', '')

        # Match on the group id itself.
        if gid and overlaps(gid):
            return gid

        # Match on the display name.
        display = group.get('name', '')
        if display and overlaps(display):
            return gid

        # Match on the filter patterns, regex syntax stripped; terms of
        # <= 2 chars are skipped to avoid spurious hits.
        for pattern in group.get('filter', []):
            term = (pattern.replace('\\b', '')
                           .replace('-.*', '')
                           .replace('.*', '')
                           .replace('\\', ''))
            if len(term) > 2 and term.lower() in needle:
                return gid

    return None
310
+
311
+
312
def classify_sentiment(text: str, gt: GroundTruth) -> str:
    """Classify if mention is positive (asserting), negative (exonerating), or neutral.

    Purely lexical: scans the lowercased text for assertion phrases
    ("root cause", "culprit", ...) and exoneration phrases ("ruled out",
    "healthy", ...). Positive wins ties because it is checked first.
    ``gt`` is unused but kept for interface compatibility with callers.
    """
    text_lower = text.lower()

    # Positive indicators (asserting as root cause)
    positive_patterns = [
        r'root\s*cause',
        r'is\s+the\s+cause',
        r'caused\s+by',
        r'source\s+of\s+(the\s+)?problem',
        r'culprit',
        r'responsible\s+for',
        r'likely\s+cause',
        r'appears\s+to\s+be\s+the\s+issue',
        r'primary\s+issue',
        r'main\s+issue',
    ]

    # Negative indicators (exonerating)
    negative_patterns = [
        r'not\s+the\s+(root\s*)?cause',
        r'ruled\s+out',
        r'is\s+not\s+responsible',
        r'working\s+(correctly|normally|fine)',
        r'healthy',
        r'no\s+issues?\s+(found|detected)',
        r'can\s+be\s+excluded',
        r'unlikely\s+to\s+be',
    ]

    if any(re.search(pattern, text_lower) for pattern in positive_patterns):
        return 'positive'
    if any(re.search(pattern, text_lower) for pattern in negative_patterns):
        return 'negative'
    return 'neutral'
354
+
355
+
356
def get_latest_rollout(trial_dir: Path) -> Optional[Path]:
    """Return the newest rollout JSONL under ``trial_dir/sessions``.

    Recursively scans for ``rollout-*.jsonl`` files and picks the one with
    the greatest modification time; None when the sessions directory is
    missing or contains no rollout files.
    """
    sessions = trial_dir / "sessions"
    if not sessions.exists():
        return None

    newest: Optional[Path] = None
    newest_mtime = float("-inf")
    # Track the most recently modified candidate as we walk the tree.
    for candidate in sessions.glob("**/rollout-*.jsonl"):
        mtime = candidate.stat().st_mtime
        if mtime > newest_mtime:
            newest, newest_mtime = candidate, mtime
    return newest
368
+
369
+
370
def get_judge_score(trial_dir: Path) -> Optional[float]:
    """Get root_cause_entity_f1 from judge output.

    Reads ``trial_dir/judge_output.json`` and returns the
    ``flat_scores.root_cause_entity_f1`` value, or None when the file is
    missing, unreadable, malformed, or not shaped as expected.
    """
    judge_path = trial_dir / "judge_output.json"
    if not judge_path.exists():
        return None

    try:
        with open(judge_path) as f:
            judge_data = json.load(f)
        # AttributeError covers payloads where the top level or
        # 'flat_scores' is not a dict.
        return judge_data.get('flat_scores', {}).get('root_cause_entity_f1')
    except (OSError, json.JSONDecodeError, AttributeError):
        # A broken judge file means "no score" for this trial; callers
        # treat None as "not judged" rather than aborting the whole run.
        # (Previously a bare `except:` that also swallowed KeyboardInterrupt.)
        return None
382
+
383
+
384
def parse_rollout(rollout_path: Path, gt: GroundTruth) -> TrajectoryAnalysis:
    """Parse a rollout JSONL file and extract entity mentions and metrics.

    Replays the agent trajectory line by line, tracking:
      * when the ground-truth entity was first seen in a tool OUTPUT
        (encounter), first queried via tool ARGUMENTS (investigate),
        asserted as root cause, or exonerated;
      * the set of K8s entities the agent saw vs. actively queried
        (exploration breadth);
      * which scenario entity groups on/off the fault-propagation chain
        were investigated (coverage / detour metrics).

    Returns a TrajectoryAnalysis with model/scenario/trial left for the
    caller to fill in.
    """
    mentions = []
    turn_num = 0
    total_turns = 0

    # First-occurrence turn numbers for each pipeline event (None = never).
    t_encounter = None
    t_investigate = None
    t_assert = None
    t_exonerate = None
    t_recover = None

    # Exploration tracking: entities are normalized "namespace/Kind/name"
    # strings pulled out of tool outputs (seen) and tool arguments
    # (actively queried) via K8S_ENTITY_PATTERN. The universe of possible
    # entities is unknown here, so "encountered" later doubles as the
    # denominator for the exploration ratio.
    encountered_entities = set()
    investigated_entities = set()

    # Track which entity groups were investigated (on-chain vs off-chain)
    on_chain_groups_investigated = set()
    off_chain_groups_investigated = set()
    all_groups_investigated = set()

    with open(rollout_path) as f:
        for line in f:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # Skip malformed lines rather than failing the whole trial.
                continue

            # A turn_context record marks the start of a new agent turn.
            if obj.get('type') == 'turn_context':
                turn_num += 1
                total_turns = turn_num

            if obj.get('type') != 'response_item':
                continue

            payload = obj.get('payload', {})

            # Check tool outputs (encounter: passive — the agent did not
            # necessarily ask for this entity).
            if payload.get('type') == 'function_call_output':
                output = str(payload.get('output', ''))

                # Check for root cause match
                if entity_matches(output, gt):
                    sentiment = classify_sentiment(output, gt)
                    mentions.append(EntityMention(
                        turn=turn_num,
                        mention_type='encounter',
                        context='tool_output',
                        text_snippet=output[:200],
                        sentiment=sentiment
                    ))
                    if t_encounter is None:
                        t_encounter = turn_num

                # Broad exploration check using Regex
                found_entities = extract_k8s_entities(output)
                for entity in found_entities:
                    encountered_entities.add(entity)

            # Check tool arguments (investigate: active — the agent
            # explicitly named the entity in a call).
            if payload.get('type') == 'function_call':
                args = payload.get('arguments', {})
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    except:
                        # Non-JSON argument strings are wrapped so the
                        # later json.dumps/text search still works.
                        args = {'raw': args}
                args_str = json.dumps(args)

                # Root cause check
                if entity_matches(args_str, gt):
                    mentions.append(EntityMention(
                        turn=turn_num,
                        mention_type='investigate',
                        context='tool_args',
                        text_snippet=args_str[:200],
                        sentiment='neutral'
                    ))
                    if t_investigate is None:
                        t_investigate = turn_num

                # Broad exploration check using Regex
                found_entities = extract_k8s_entities(args_str)
                for entity in found_entities:
                    investigated_entities.add(entity)

                    # Classify as on-chain (part of the fault propagation
                    # chain) or off-chain (a detour into some other
                    # scenario entity).
                    on_chain_group = is_entity_on_chain(entity, gt)
                    if on_chain_group:
                        on_chain_groups_investigated.add(on_chain_group)
                        all_groups_investigated.add(on_chain_group)
                    else:
                        # Check if it matches any entity in scenario at all
                        any_group = get_entity_group_match(entity, gt)
                        if any_group:
                            off_chain_groups_investigated.add(any_group)
                            all_groups_investigated.add(any_group)

                # Check update_plan explanations for assertions/exonerations
                # expressed in the agent's reasoning.
                if payload.get('name') == 'update_plan':
                    explanation = args.get('explanation', '')
                    if entity_matches(explanation, gt):
                        sentiment = classify_sentiment(explanation, gt)
                        mention_type = 'assert' if sentiment == 'positive' else ('exonerate' if sentiment == 'negative' else 'investigate')
                        mentions.append(EntityMention(
                            turn=turn_num,
                            mention_type=mention_type,
                            context='reasoning',
                            text_snippet=explanation[:200],
                            sentiment=sentiment
                        ))

                        if mention_type == 'assert' and t_assert is None:
                            t_assert = turn_num
                        elif mention_type == 'exonerate' and t_exonerate is None:
                            t_exonerate = turn_num

                # Check shell commands for final output generation
                # (e.g. writing output.json with the root-cause verdict).
                if payload.get('name') == 'shell':
                    cmd = args.get('command', [])
                    cmd_str = ' '.join(cmd) if isinstance(cmd, list) else str(cmd)

                    # Look for output generation with root cause assertions
                    if ('output.json' in cmd_str or 'root_cause' in cmd_str.lower()) and entity_matches(cmd_str, gt):
                        sentiment = classify_sentiment(cmd_str, gt)
                        if sentiment == 'positive' or 'root_cause' in cmd_str.lower():
                            mentions.append(EntityMention(
                                turn=turn_num,
                                mention_type='assert',
                                context='final_output',
                                text_snippet=cmd_str[:300],
                                sentiment='positive'
                            ))
                            if t_assert is None:
                                t_assert = turn_num

    # Check for recovery (exoneration followed later by an assertion).
    had_recovery = False
    if t_exonerate is not None and t_assert is not None and t_exonerate < t_assert:
        had_recovery = True
        t_recover = t_assert

    # Compute derived timing metrics (turns are 1-based, so truthiness
    # checks here are safe — a recorded turn is never 0).
    discovery_efficiency = t_encounter / total_turns if t_encounter and total_turns > 0 else None
    investigation_delay = t_investigate - t_encounter if t_investigate and t_encounter else None
    assertion_delay = t_assert - t_investigate if t_assert and t_investigate else None

    # Compute max stage reached (without final success - that comes from judge)
    # 0=none, 1=encounter, 2=investigate, 3=assert
    max_stage = 0
    if t_encounter is None:
        max_stage = 0
    if t_encounter is not None:
        max_stage = 1
    if t_investigate is not None:
        max_stage = 2
    if t_assert is not None:
        max_stage = 3

    # Exploration metrics.
    # The true universe of scenario entities is unknown from the rollout
    # alone, so the encountered set is used as the "available" denominator.
    num_encountered = len(encountered_entities)
    num_investigated = len(investigated_entities)

    # Ratio: What % of things seen were actually investigated?
    expl_ratio = num_investigated / num_encountered if num_encountered > 0 else 0.0

    # Coverage metrics: on-chain (fault propagation) vs off-chain (detoured)
    n_on_chain = len(on_chain_groups_investigated)
    n_off_chain = len(off_chain_groups_investigated)
    total_investigated_groups = len(all_groups_investigated)

    # Propagation coverage: what % of the fault propagation chain was investigated?
    n_propagation_entities = len(gt.propagation_entities)
    prop_coverage = n_on_chain / n_propagation_entities if n_propagation_entities > 0 else 0.0

    # Detour rate: what % of investigated entities were off-chain (not in fault propagation)?
    det_rate = n_off_chain / total_investigated_groups if total_investigated_groups > 0 else 0.0

    return TrajectoryAnalysis(
        model="",  # Set by caller
        scenario="",  # Set by caller
        trial=0,  # Set by caller
        total_turns=total_turns,
        gt_entity=gt.entity_name,
        t_encounter=t_encounter,
        t_investigate=t_investigate,
        t_assert=t_assert,
        t_exonerate=t_exonerate,
        t_recover=t_recover,
        max_stage_reached=max_stage,
        mentions=mentions,
        total_entities_available=num_encountered,  # Using encountered as the "available" set
        unique_entities_encountered=num_encountered,
        unique_entities_investigated=num_investigated,
        exploration_ratio=expl_ratio,
        # Coverage metrics (on-chain vs off-chain)
        on_chain_investigated=n_on_chain,
        off_chain_investigated=n_off_chain,
        propagation_coverage=prop_coverage,
        detour_rate=det_rate,
        # Computed metrics
        discovery_efficiency=discovery_efficiency,
        investigation_delay=investigation_delay,
        assertion_delay=assertion_delay,
        had_recovery=had_recovery
    )
605
+
606
+
607
def analyze_model(model_dir: Path, gt_cache: Dict[str, GroundTruth]) -> List[TrajectoryAnalysis]:
    """Analyze all trials for a model.

    Walks ``model_dir/Scenario-*/<trial>/`` directories, parses the latest
    rollout of each trial against its cached ground truth, and attaches the
    judge's F1 score to decide final success. Trials whose scenario has no
    ground truth or no rollout file are skipped; per-trial parsing errors
    are printed and skipped so one bad trial does not abort the model.
    """
    results = []
    # Directory names look like "react with code_<model>_07ccdb1..."; strip
    # the prefix and commit suffix to recover the bare model name.
    model_name = model_dir.name.replace("react with code_", "").split("_07ccdb1")[0]

    scenario_dirs = [d for d in sorted(model_dir.iterdir()) if d.is_dir() and d.name.startswith("Scenario-")]
    for scenario_dir in tqdm(scenario_dirs, desc=f" {model_name} scenarios"):
        scenario = scenario_dir.name
        gt = gt_cache.get(scenario)
        if gt is None:
            continue

        # Trial directories are plain integers under the scenario.
        trial_dirs = [d for d in sorted(scenario_dir.iterdir()) if d.is_dir() and d.name.isdigit()]
        for trial_dir in tqdm(trial_dirs, desc=f" {scenario} trials"):
            trial_num = int(trial_dir.name)
            rollout_path = find_latest_rollout_file(trial_dir)

            if rollout_path is None:
                continue

            try:
                analysis = parse_rollout(rollout_path, gt)
                analysis.model = model_name
                analysis.scenario = scenario
                analysis.trial = trial_num

                # Get judge score to determine final success; any positive
                # F1 counts as a success and promotes the trial to stage 4.
                f1_score = get_judge_score(trial_dir)
                analysis.root_cause_f1 = f1_score
                if f1_score is not None and f1_score > 0:
                    analysis.final_success = True
                    analysis.max_stage_reached = 4  # Success!

                results.append(analysis)
            except Exception as e:
                print(f"Error processing {model_name}/{scenario}/{trial_num}: {e}")

    return results
645
+
646
+
647
def plot_pipeline_funnel(summary_df: pd.DataFrame):
    """
    Figure 1: Stacked bar showing where trials drop off in the pipeline.

    Pipeline stages:
    - Encounter: GT entity appears in tool OUTPUT (passive - agent didn't ask for it)
    - Investigate: GT entity appears in tool ARGUMENTS (active - agent explicitly queried it)
    - Assert: Agent declares GT as root cause
    - Success: Judge confirms correct answer

    Saves the figure to ``OUTPUT_DIR/fig_conversion_funnel.png``.
    """
    # Filter out models with no data (encounter_rate == 0) and prepare data
    data = summary_df[summary_df['encounter_rate'] > 0].copy()
    data['model_clean'] = data['model'].apply(get_display_name)
    data = data.sort_values('success_rate', ascending=True)

    # Stack: none, encounter_only, investigate_only, assert_only, success
    # Normalize each model's exclusive stage counts to percentages.
    n_trials = data['n_trials']

    none_pct = data['n_stage_0_none'] / n_trials * 100
    enc_pct = data['n_stage_1_encounter_only'] / n_trials * 100
    inv_pct = data['n_stage_2_investigate_only'] / n_trials * 100
    ass_pct = data['n_stage_3_assert_only'] / n_trials * 100
    suc_pct = data['n_stage_4_success'] / n_trials * 100

    n_models = len(data)
    y = np.arange(n_models)
    bar_height = 0.7

    plt.rcParams.update(PLOT_PARAMETERS)

    STAGE_COLORS = {
        'none': '#d73027',        # Red - never encountered GT
        'encounter': '#fc8d59',   # Orange - saw but didn't investigate
        'investigate': '#fee08b', # Yellow - investigated but didn't assert
        'assert': '#d9ef8b',      # Light green - asserted but wrong final answer
        'success': '#1a9850',     # Green - success
    }

    # Create figure sized to fill half column with legend
    fig, ax = plt.subplots(figsize=(DOUBLE_COLUMN_WIDTH, 2.5))

    # Plot stacked bars with GT prefix labels; each segment starts where
    # the previous stages end (left=cumulative sum).
    ax.barh(y, none_pct, height=bar_height, label='RC never seen', color=STAGE_COLORS['none'],
            edgecolor='white', linewidth=0.3)
    ax.barh(y, enc_pct, height=bar_height, left=none_pct, label='RC seen, not queried',
            color=STAGE_COLORS['encounter'], edgecolor='white', linewidth=0.3)
    ax.barh(y, inv_pct, height=bar_height, left=none_pct + enc_pct, label='RC queried, not asserted',
            color=STAGE_COLORS['investigate'], edgecolor='white', linewidth=0.3)
    ax.barh(y, ass_pct, height=bar_height, left=none_pct + enc_pct + inv_pct, label='RC asserted, not in output',
            color=STAGE_COLORS['assert'], edgecolor='white', linewidth=0.3)
    ax.barh(y, suc_pct, height=bar_height, left=none_pct + enc_pct + inv_pct + ass_pct, label='RC asserted, in output',
            color=STAGE_COLORS['success'], edgecolor='white', linewidth=0.3)

    # Add percentage labels centered in each segment; only segments at
    # least min_pct_threshold percent wide get a label so text fits.
    min_pct_threshold = 4
    label_fontsize = MIN_FONT_SIZE - 3

    for i, model_idx in enumerate(y):
        # (segment width, x position of the segment's midpoint)
        segments = [
            (none_pct.iloc[i], none_pct.iloc[i] / 2),
            (enc_pct.iloc[i], none_pct.iloc[i] + enc_pct.iloc[i] / 2),
            (inv_pct.iloc[i], none_pct.iloc[i] + enc_pct.iloc[i] + inv_pct.iloc[i] / 2),
            (ass_pct.iloc[i], none_pct.iloc[i] + enc_pct.iloc[i] + inv_pct.iloc[i] + ass_pct.iloc[i] / 2),
            (suc_pct.iloc[i], none_pct.iloc[i] + enc_pct.iloc[i] + inv_pct.iloc[i] + ass_pct.iloc[i] + suc_pct.iloc[i] / 2)
        ]

        for pct, x_pos in segments:
            if pct >= min_pct_threshold:
                ax.text(x_pos, model_idx, f'{pct:.0f}%',
                        ha='center', va='center', fontsize=label_fontsize,
                        color='black', weight='bold')

    ax.set_yticks(y)
    ax.set_yticklabels(data['model_clean'], fontsize=MIN_FONT_SIZE)
    ax.set_xlabel('Trials (%)', fontsize=MIN_FONT_SIZE)
    ax.set_xlim(0, 100)
    ax.set_ylim(-0.5, n_models - 0.5)
    ax.tick_params(axis='x', labelsize=MIN_FONT_SIZE)

    # Legend below the plot - 2 columns, positioned below x-axis label
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.18), ncol=2,
              frameon=False, fontsize=MIN_FONT_SIZE, columnspacing=0.8,
              handletextpad=0.3, handlelength=1.0)

    # Tight margins - more bottom space for legend
    fig.subplots_adjust(left=0.28, right=0.99, top=0.99, bottom=0.38)

    plt.title("Root Cause Entity Discovery Funnel")
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_conversion_funnel.png")
    plt.close(fig)
    print("Saved: fig_conversion_funnel.png")
740
+
741
+
742
def extract_all_data():
    """Run the full trajectory analysis and write per-trial/per-model CSVs.

    Loads every scenario's ground truth, analyzes every "react with code"
    agent model under LEADERBOARD_DIR, and writes ``discovery_trials.csv``
    (one row per trial) and ``discovery_summary.csv`` (one row per model)
    to OUTPUT_DIR.

    Returns (summary_df, trial_df, total number of trials analyzed).
    """
    # Create output directory
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Load all ground truths once; scenarios without a usable ground truth
    # are simply absent from the cache and skipped later.
    print("\nLoading ground truth data...")
    gt_cache = {}
    scenario_dirs = [d for d in GT_DIR.iterdir() if d.is_dir() and d.name.startswith("Scenario-")]
    for scenario_dir in tqdm(scenario_dirs, desc="Loading ground truths"):
        gt = load_ground_truth(scenario_dir.name)
        if gt:
            gt_cache[scenario_dir.name] = gt
    print(f"Loaded {len(gt_cache)} ground truth files")

    # Find react with code agents
    model_dirs = [d for d in LEADERBOARD_DIR.iterdir()
                  if d.is_dir() and d.name.startswith("react with code_")]
    print(f"Found {len(model_dirs)} agent models")

    # Analyze each model
    all_results = []
    for model_dir in tqdm(model_dirs, desc="Analyzing models"):
        model_name = model_dir.name.replace("react with code_", "").split("_07ccdb1")[0]
        print(f"\nAnalyzing {model_name}...")

        results = analyze_model(model_dir, gt_cache)
        all_results.extend(results)

        # Summary stats
        if results:
            encounters = [r for r in results if r.t_encounter is not None]
            asserts = [r for r in results if r.t_assert is not None]
            recoveries = [r for r in results if r.had_recovery]

            print(f" Trials: {len(results)}")
            print(f" Encounters: {len(encounters)} ({100*len(encounters)/len(results):.1f}%)")
            print(f" Assertions: {len(asserts)} ({100*len(asserts)/len(results):.1f}%)")
            print(f" Recoveries: {len(recoveries)} ({100*len(recoveries)/len(results):.1f}%)")

    # Convert to DataFrame
    print("\n" + "=" * 60)
    print("Generating output files...")

    # Summary per trial: flatten each TrajectoryAnalysis into one CSV row.
    trial_data = []
    for r in all_results:
        trial_data.append({
            'model': r.model,
            'scenario': r.scenario,
            'trial': r.trial,
            'total_turns': r.total_turns,
            'gt_entity': r.gt_entity,
            't_encounter': r.t_encounter,
            't_investigate': r.t_investigate,
            't_assert': r.t_assert,
            't_exonerate': r.t_exonerate,
            't_recover': r.t_recover,
            'max_stage_reached': r.max_stage_reached,
            'final_success': r.final_success,
            'root_cause_f1': r.root_cause_f1,
            'discovery_efficiency': r.discovery_efficiency,
            'investigation_delay': r.investigation_delay,
            'assertion_delay': r.assertion_delay,
            'had_recovery': r.had_recovery,
            'n_mentions': len(r.mentions),
            'total_entities_available': r.total_entities_available,
            'unique_entities_encountered': r.unique_entities_encountered,
            'unique_entities_investigated': r.unique_entities_investigated,
            'exploration_ratio': r.exploration_ratio,
            # Coverage metrics (on-chain vs off-chain)
            'on_chain_investigated': r.on_chain_investigated,
            'off_chain_investigated': r.off_chain_investigated,
            'propagation_coverage': r.propagation_coverage,
            'detour_rate': r.detour_rate
        })

    trial_df = pd.DataFrame(trial_data)
    trial_df.to_csv(OUTPUT_DIR / "discovery_trials.csv", index=False)
    print(f"Saved: {OUTPUT_DIR / 'discovery_trials.csv'}")

    # Summary per model
    model_summary = []
    for model in trial_df['model'].unique():
        model_data = trial_df[trial_df['model'] == model]
        n_total = len(model_data)

        # Funnel stages: count trials reaching each stage
        # Stage 0: none, 1: encounter, 2: investigate, 3: assert, 4: success
        stage_counts = model_data['max_stage_reached'].value_counts().to_dict()

        # Cumulative: how many reached AT LEAST this stage
        n_encounter = len(model_data[model_data['max_stage_reached'] >= 1])
        n_investigate = len(model_data[model_data['max_stage_reached'] >= 2])
        n_assert = len(model_data[model_data['max_stage_reached'] >= 3])
        n_success = len(model_data[model_data['max_stage_reached'] >= 4])

        # Filter to trials where we found something (for conditional means
        # below — avoids averaging over NaN-only columns).
        with_encounter = model_data[model_data['t_encounter'].notna()]
        with_assert = model_data[model_data['t_assert'].notna()]
        with_recovery = model_data[model_data['had_recovery'] == True]
        with_success = model_data[model_data['final_success'] == True]

        model_summary.append({
            'model': model,
            'n_trials': n_total,
            'n_scenarios': model_data['scenario'].nunique(),
            # Funnel rates (cumulative, relative to total trials)
            'encounter_rate': n_encounter / n_total if n_total > 0 else 0,
            'investigate_rate': n_investigate / n_total if n_total > 0 else 0,
            'assertion_rate': n_assert / n_total if n_total > 0 else 0,
            'success_rate': n_success / n_total if n_total > 0 else 0,
            # Conversion rate: given encounter, did model declare it as root cause?
            # This handles multi-root-cause scenarios better
            'conversion_rate': n_success / n_encounter if n_encounter > 0 else 0,
            # Drop-off at each stage (exclusive counts)
            'n_stage_0_none': stage_counts.get(0, 0),
            'n_stage_1_encounter_only': stage_counts.get(1, 0),
            'n_stage_2_investigate_only': stage_counts.get(2, 0),
            'n_stage_3_assert_only': stage_counts.get(3, 0),
            'n_stage_4_success': stage_counts.get(4, 0),
            # Legacy metrics
            'recovery_rate': len(with_recovery) / n_total if n_total > 0 else 0,
            'avg_t_encounter': with_encounter['t_encounter'].mean() if len(with_encounter) > 0 else None,
            'avg_t_assert': with_assert['t_assert'].mean() if len(with_assert) > 0 else None,
            'avg_total_turns': model_data['total_turns'].mean(),
            'avg_discovery_efficiency': with_encounter['discovery_efficiency'].mean() if len(with_encounter) > 0 else None,
            'avg_investigation_delay': with_encounter['investigation_delay'].mean() if len(with_encounter) > 0 else None,
            'avg_assertion_delay': with_assert['assertion_delay'].mean() if len(with_assert) > 0 else None,
            'avg_f1': with_success['root_cause_f1'].mean() if len(with_success) > 0 else None,
            'avg_exploration_ratio': model_data['exploration_ratio'].mean(),
            'avg_entities_investigated': model_data['unique_entities_investigated'].mean(),
            # Coverage metrics (fault propagation coverage)
            'avg_on_chain_investigated': model_data['on_chain_investigated'].mean(),
            'avg_off_chain_investigated': model_data['off_chain_investigated'].mean(),
            'avg_propagation_coverage': model_data['propagation_coverage'].mean(),
            'avg_detour_rate': model_data['detour_rate'].mean()
        })

    summary_df = pd.DataFrame(model_summary)
    summary_df.to_csv(OUTPUT_DIR / "discovery_summary.csv", index=False)
    print(f"Saved: {OUTPUT_DIR / 'discovery_summary.csv'}")

    trials_n = len(all_results)

    return summary_df, trial_df, trials_n
887
+
888
+
889
def main():
    """Run the discovery-trajectory analysis and print summary tables."""
    print("=" * 60)
    print("Root Cause Discovery Trajectory Analysis")
    print("=" * 60)

    summary_df, trial_df, trials_n = extract_all_data()

    # Print summary table with funnel (cumulative rates per model).
    print("\n" + "=" * 80)
    print("Discovery Pipeline Funnel:")
    print("-" * 80)
    print(f"{'Model':<25} {'Trials':>7} {'Encntr':>8} {'Invest':>8} {'Assert':>8} {'Success':>8}")
    print("-" * 80)
    for _, row in summary_df.iterrows():
        print(f"{row['model']:<25} {row['n_trials']:>7} "
              f"{row['encounter_rate']*100:>7.0f}% "
              f"{row['investigate_rate']*100:>7.0f}% "
              f"{row['assertion_rate']*100:>7.0f}% "
              f"{row['success_rate']*100:>7.0f}%")

    # Exclusive stage counts: where each trial stopped in the funnel.
    print("\n" + "=" * 80)
    print("Drop-off Analysis (where trials stopped):")
    print("-" * 80)
    print(f"{'Model':<25} {'None':>7} {'Enc→X':>7} {'Inv→X':>7} {'Ass→X':>7} {'✓':>7}")
    print("-" * 80)
    for _, row in summary_df.iterrows():
        print(f"{row['model']:<25} "
              f"{row['n_stage_0_none']:>7} "
              f"{row['n_stage_1_encounter_only']:>7} "
              f"{row['n_stage_2_investigate_only']:>7} "
              f"{row['n_stage_3_assert_only']:>7} "
              f"{row['n_stage_4_success']:>7}")

    print(f"\nTotal trials analyzed: {trials_n}")
    print(f"\nOutput saved to: {OUTPUT_DIR}")
924
+
925
+
926
# Script entry point: run the full analysis when executed directly.
if __name__ == "__main__":
    main()
928
+
analysis_src/extract_exploration.py ADDED
@@ -0,0 +1,623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Exploration Breadth Analysis by Diagnosis Correctness
4
+
5
+ Creates a plot comparing exploration breadth between:
6
+ - Correct diagnoses (recall > 0, i.e., root_cause_f1 > 0)
7
+ - Incorrect diagnoses (recall = 0, i.e., root_cause_f1 == 0)
8
+
9
+ Uses semantic entity grouping to avoid counting "frontend deployment" and
10
+ "frontend service" as separate entities.
11
+ """
12
+
13
+ import json
14
+ import sys
15
+ import re
16
+ from pathlib import Path
17
+ from dataclasses import dataclass
18
+ from typing import Optional, List, Dict, Set, Tuple
19
+ import pandas as pd
20
+ import numpy as np
21
+ import matplotlib.pyplot as plt
22
+ import seaborn as sns
23
+ from tqdm import tqdm
24
+
25
+ # Publication settings - ICML half column
26
+ HALF_COLUMN_WIDTH = 3.25 # inches
27
+ MIN_FONT_SIZE = 8
28
+
29
+ plt.rcParams.update({
30
+ 'font.size': MIN_FONT_SIZE,
31
+ 'font.family': 'serif',
32
+ 'axes.labelsize': MIN_FONT_SIZE,
33
+ 'axes.titlesize': MIN_FONT_SIZE + 1,
34
+ 'xtick.labelsize': MIN_FONT_SIZE,
35
+ 'ytick.labelsize': MIN_FONT_SIZE,
36
+ 'legend.fontsize': MIN_FONT_SIZE,
37
+ 'figure.dpi': 150,
38
+ 'savefig.dpi': 300,
39
+ 'savefig.bbox': 'tight',
40
+ 'axes.spines.top': False,
41
+ 'axes.spines.right': False,
42
+ })
43
+
44
+ # Paths
45
+ PROJECT_ROOT = Path(__file__).parent.parent
46
+ sys.path.insert(0, str(PROJECT_ROOT))
47
+
48
+ from src.utils import find_latest_rollout_file
49
+
50
+ from src.model_styles import (
51
+ get_display_name, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, PLOT_PARAMETERS
52
+ )
53
+
54
+ # Paths
55
+ LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"
56
+ GT_DIR = PROJECT_ROOT / "data" / "itbench-snapshots"
57
+ OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "discovery"
58
+
59
+ # Regex for K8s entities
60
+ K8S_ENTITY_PATTERN = re.compile(
61
+ r'([\w-]+)/(Deployment|Service|Pod|ReplicaSet|ResourceQuota|StatefulSet|'
62
+ r'DaemonSet|Job|CronJob|ConfigMap|Secret|Endpoints|Ingress|'
63
+ r'PersistentVolumeClaim|PersistentVolume|ServiceAccount|Role|RoleBinding|'
64
+ r'ClusterRole|ClusterRoleBinding|NetworkPolicy|HorizontalPodAutoscaler|'
65
+ r'Node|Schedule|NetworkChaos|StressChaos|PodChaos)/([\w-]+)',
66
+ re.IGNORECASE
67
+ )
68
+
69
+ # Service name normalization patterns
70
+ SERVICE_NORMALIZATIONS = {
71
+ # Map specific variations to canonical names
72
+ 'frontend-proxy': 'frontend-proxy',
73
+ 'frontendproxy': 'frontend-proxy',
74
+ 'frontend': 'frontend',
75
+ 'checkout': 'checkout',
76
+ 'checkoutservice': 'checkout',
77
+ 'cart': 'cart',
78
+ 'cartservice': 'cart',
79
+ 'shipping': 'shipping',
80
+ 'shippingservice': 'shipping',
81
+ 'product-catalog': 'product-catalog',
82
+ 'productcatalog': 'product-catalog',
83
+ 'productcatalogservice': 'product-catalog',
84
+ 'recommendation': 'recommendation',
85
+ 'recommendationservice': 'recommendation',
86
+ 'email': 'email',
87
+ 'emailservice': 'email',
88
+ 'payment': 'payment',
89
+ 'paymentservice': 'payment',
90
+ 'currency': 'currency',
91
+ 'currencyservice': 'currency',
92
+ 'ad': 'ad',
93
+ 'adservice': 'ad',
94
+ 'fraud-detection': 'fraud-detection',
95
+ 'frauddetection': 'fraud-detection',
96
+ 'frauddetectionservice': 'fraud-detection',
97
+ 'load-generator': 'load-generator',
98
+ 'loadgenerator': 'load-generator',
99
+ 'flagd': 'flagd',
100
+ 'otel-collector': 'otel-collector',
101
+ 'valkey': 'valkey',
102
+ 'valkey-cart': 'valkey', # valkey instance for cart
103
+ 'redis': 'valkey', # alias
104
+ 'kafka': 'kafka',
105
+ 'quote': 'quote',
106
+ 'quoteservice': 'quote',
107
+ 'accounting': 'accounting',
108
+ 'accountingservice': 'accounting',
109
+ 'otel-demo': 'otel-demo', # namespace
110
+ 'imageprovider': 'imageprovider',
111
+ 'flagdui': 'flagdui',
112
+ 'opensearch': 'opensearch',
113
+ 'grafana': 'grafana',
114
+ 'jaeger': 'jaeger',
115
+ 'prometheus': 'prometheus',
116
+ }
117
+
118
+ # Model name mapping for cleaner labels
119
+ MODEL_NAMES = {
120
+ 'Azure_gpt-5.1-2025-11-13': 'GPT-5.1',
121
+ 'Azure_o4-mini': 'o4-mini',
122
+ 'GCP_gemini-2.5-pro': 'Gemini 2.5 Pro',
123
+ 'gcp_gemini-3-pro-preview': 'Gemini 3 Pro',
124
+ 'gemini-3-pro-preview': 'Gemini 3 Pro',
125
+ 'gemini-3-flash-preview': 'Gemini 3 Flash',
126
+ 'moonshotai_kimi-k2-thinking': 'Kimi K2',
127
+ 'aws_claude-opus-4-5': 'Claude Opus 4.5',
128
+ 'openai_gpt-oss-120b': 'GPT-OSS-120B',
129
+ }
130
+
131
+
132
+ def normalize_entity_to_logical(entity: str) -> str:
133
+ """
134
+ Normalize an entity to its logical/canonical service name.
135
+
136
+ e.g., "otel-demo/Deployment/frontend-abc123" -> "frontend"
137
+ "otel-demo/Service/checkout" -> "checkout"
138
+ "chaos-mesh/NetworkChaos/xyz" -> "chaos:NetworkChaos"
139
+ """
140
+ parts = entity.lower().split('/')
141
+
142
+ # Handle chaos-mesh specially
143
+ if 'chaos-mesh' in parts[0] if parts else '':
144
+ if len(parts) >= 2:
145
+ return f"chaos:{parts[1]}"
146
+ return "chaos"
147
+
148
+ # Get the name part (last component)
149
+ if len(parts) >= 3:
150
+ name = parts[2]
151
+ elif len(parts) >= 1:
152
+ name = parts[-1]
153
+ else:
154
+ return entity.lower()
155
+
156
+ # Strip pod suffixes (e.g., frontend-5d4f6b7c8d-xyz9a -> frontend)
157
+ # Pattern: name followed by hash-like suffixes from ReplicaSets/Pods
158
+ # ReplicaSet adds -<hash10> and Pod adds -<hash5>
159
+ # e.g., frontend-5d4f6b7c8d-xyz9a -> strip -5d4f6b7c8d-xyz9a
160
+ name = re.sub(r'-[a-f0-9]{8,10}-[a-z0-9]{5}$', '', name) # Pod suffix (RS hash + Pod hash)
161
+ name = re.sub(r'-[a-f0-9]{8,10}$', '', name) # ReplicaSet suffix only (10-char hex hash)
162
+
163
+ # Also strip numeric suffixes like -1, -2 from entity names
164
+ name = re.sub(r'-\d+$', '', name)
165
+
166
+ # First check for exact match (most reliable)
167
+ if name in SERVICE_NORMALIZATIONS:
168
+ return SERVICE_NORMALIZATIONS[name]
169
+
170
+ # Try matching with service name variations
171
+ # Sort by length descending so longer patterns match first (frontend-proxy before frontend)
172
+ for pattern in sorted(SERVICE_NORMALIZATIONS.keys(), key=len, reverse=True):
173
+ canonical = SERVICE_NORMALIZATIONS[pattern]
174
+ # Exact match or name starts with pattern followed by typical suffixes
175
+ if name == pattern:
176
+ return canonical
177
+ # e.g., "checkoutservice" starts with "checkout"
178
+ if name.startswith(pattern) and (
179
+ len(name) == len(pattern) or
180
+ name[len(pattern):].startswith('service') or
181
+ name[len(pattern):].startswith('-')
182
+ ):
183
+ return canonical
184
+
185
+ # Fallback: return cleaned name
186
+ return name
187
+
188
+
189
+ def extract_k8s_entities(text: str) -> List[str]:
190
+ """Extract all K8s entities from text."""
191
+ matches = K8S_ENTITY_PATTERN.findall(text)
192
+ entities = []
193
+ for m in matches:
194
+ entity = f"{m[0]}/{m[1]}/{m[2]}"
195
+ entities.append(entity)
196
+ return entities
197
+
198
+
199
+ def extract_logical_entities(text: str) -> Set[str]:
200
+ """Extract and normalize entities to logical names."""
201
+ raw_entities = extract_k8s_entities(text)
202
+ return {normalize_entity_to_logical(e) for e in raw_entities}
203
+
204
+
205
+ def get_latest_rollout(trial_dir: Path) -> Optional[Path]:
206
+ """Get the latest rollout file from a trial directory."""
207
+ sessions_dir = trial_dir / "sessions"
208
+ if not sessions_dir.exists():
209
+ return None
210
+
211
+ rollout_files = list(sessions_dir.glob("**/rollout-*.jsonl"))
212
+ if not rollout_files:
213
+ return None
214
+
215
+ return max(rollout_files, key=lambda p: p.stat().st_mtime)
216
+
217
+
218
+ def get_judge_f1(trial_dir: Path) -> float:
219
+ """Get root_cause_entity_f1 from judge output."""
220
+ judge_path = trial_dir / "judge_output.json"
221
+ if not judge_path.exists():
222
+ return 0.0
223
+
224
+ try:
225
+ with open(judge_path) as f:
226
+ judge_data = json.load(f)
227
+ return judge_data.get('flat_scores', {}).get('root_cause_entity_f1', 0.0) or 0.0
228
+ except:
229
+ return 0.0
230
+
231
+
232
+ def count_semantic_entities_investigated(rollout_path: Path) -> int:
233
+ """
234
+ Count unique semantic entity groups investigated in a rollout.
235
+
236
+ Uses normalization to group similar entities:
237
+ - otel-demo/Deployment/frontend and otel-demo/Service/frontend -> 1 entity ("frontend")
238
+ - otel-demo/Pod/frontend-abc123 and otel-demo/Pod/frontend-xyz456 -> 1 entity ("frontend")
239
+ """
240
+ investigated_logical = set()
241
+
242
+ with open(rollout_path) as f:
243
+ for line in f:
244
+ try:
245
+ obj = json.loads(line)
246
+ except json.JSONDecodeError:
247
+ continue
248
+
249
+ if obj.get('type') != 'response_item':
250
+ continue
251
+
252
+ payload = obj.get('payload', {})
253
+
254
+ # Check tool arguments (investigation = active querying)
255
+ if payload.get('type') == 'function_call':
256
+ args = payload.get('arguments', {})
257
+ if isinstance(args, str):
258
+ try:
259
+ args = json.loads(args)
260
+ except:
261
+ args = {'raw': args}
262
+ args_str = json.dumps(args)
263
+
264
+ # Extract and normalize entities
265
+ logical_entities = extract_logical_entities(args_str)
266
+ investigated_logical.update(logical_entities)
267
+
268
+ return len(investigated_logical)
269
+
270
+
271
+ def analyze_all_trials() -> pd.DataFrame:
272
+ """
273
+ Analyze all trials from react with code agents.
274
+ Returns DataFrame with model, scenario, trial, f1_score, semantic_entities_investigated.
275
+ """
276
+ results = []
277
+
278
+ # Find react with code agents
279
+ model_dirs = [d for d in LEADERBOARD_DIR.iterdir()
280
+ if d.is_dir() and d.name.startswith("react with code_")]
281
+
282
+ print(f"Found {len(model_dirs)} agent models")
283
+
284
+ for model_dir in tqdm(model_dirs, desc="Processing models"):
285
+ model_name = model_dir.name.replace("react with code_", "").split("_07ccdb1")[0]
286
+ print(f"Processing {model_name}...")
287
+
288
+ scenario_dirs = [d for d in sorted(model_dir.iterdir()) if d.is_dir() and d.name.startswith("Scenario-")]
289
+ for scenario_dir in tqdm(scenario_dirs, desc=f" {model_name} scenarios", leave=False):
290
+ scenario = scenario_dir.name
291
+
292
+ trial_dirs = [d for d in sorted(scenario_dir.iterdir()) if d.is_dir() and d.name.isdigit()]
293
+ for trial_dir in tqdm(trial_dirs, desc=f" {scenario} trials", leave=False):
294
+ trial_num = int(trial_dir.name)
295
+ rollout_path = get_latest_rollout(trial_dir)
296
+
297
+ if rollout_path is None:
298
+ continue
299
+
300
+ try:
301
+ f1_score = get_judge_f1(trial_dir)
302
+ semantic_count = count_semantic_entities_investigated(rollout_path)
303
+
304
+ results.append({
305
+ 'model': model_name,
306
+ 'scenario': scenario,
307
+ 'trial': trial_num,
308
+ 'root_cause_f1': f1_score,
309
+ 'is_correct': f1_score > 0,
310
+ 'semantic_entities_investigated': semantic_count
311
+ })
312
+ except Exception as e:
313
+ print(f" Error processing {model_name}/{scenario}/{trial_num}: {e}")
314
+
315
+ return pd.DataFrame(results)
316
+
317
+
318
+ def clean_model_name(name: str) -> str:
319
+ return MODEL_NAMES.get(name, name)
320
+
321
+
322
+ def plot_exploration_by_correctness(df: pd.DataFrame):
323
+ """
324
+ Plot comparing exploration breadth between correct and incorrect diagnoses.
325
+ Creates a grouped bar chart or box plot.
326
+ """
327
+ # Aggregate by model and correctness
328
+ agg = df.groupby(['model', 'is_correct']).agg({
329
+ 'semantic_entities_investigated': ['mean', 'std', 'count']
330
+ }).reset_index()
331
+ agg.columns = ['model', 'is_correct', 'mean_entities', 'std_entities', 'n_trials']
332
+
333
+ # Pivot for easier plotting
334
+ correct_df = agg[agg['is_correct'] == True].set_index('model')
335
+ incorrect_df = agg[agg['is_correct'] == False].set_index('model')
336
+
337
+ # Get all models that have both correct and incorrect trials
338
+ models_both = set(correct_df.index) & set(incorrect_df.index)
339
+
340
+ # Create comparison data
341
+ comparison_data = []
342
+ for model in models_both:
343
+ comparison_data.append({
344
+ 'model': model,
345
+ 'model_clean': clean_model_name(model),
346
+ 'correct_mean': correct_df.loc[model, 'mean_entities'],
347
+ 'correct_std': correct_df.loc[model, 'std_entities'],
348
+ 'correct_n': correct_df.loc[model, 'n_trials'],
349
+ 'incorrect_mean': incorrect_df.loc[model, 'mean_entities'],
350
+ 'incorrect_std': incorrect_df.loc[model, 'std_entities'],
351
+ 'incorrect_n': incorrect_df.loc[model, 'n_trials'],
352
+ })
353
+
354
+ comp_df = pd.DataFrame(comparison_data)
355
+ comp_df = comp_df.sort_values('correct_mean', ascending=True)
356
+
357
+ # === Figure 1: Grouped bar chart ===
358
+ fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH, 3.0))
359
+
360
+ y = np.arange(len(comp_df))
361
+ bar_height = 0.35
362
+
363
+ # Incorrect (red) and Correct (green) bars
364
+ bars_incorrect = ax.barh(y - bar_height/2, comp_df['incorrect_mean'],
365
+ height=bar_height, label='Incorrect (recall=0)',
366
+ color='#d62728', edgecolor='black', linewidth=0.3, alpha=0.8)
367
+ bars_correct = ax.barh(y + bar_height/2, comp_df['correct_mean'],
368
+ height=bar_height, label='Correct (recall>0)',
369
+ color='#2ca02c', edgecolor='black', linewidth=0.3, alpha=0.8)
370
+
371
+ ax.set_yticks(y)
372
+ ax.set_yticklabels(comp_df['model_clean'])
373
+ ax.set_xlabel('Avg. Semantic Entity Groups Investigated')
374
+
375
+ # Add value labels
376
+ for i, (bar_i, bar_c) in enumerate(zip(bars_incorrect, bars_correct)):
377
+ # Incorrect
378
+ ax.text(bar_i.get_width() + 0.1, bar_i.get_y() + bar_i.get_height()/2,
379
+ f'{bar_i.get_width():.1f}', va='center', ha='left',
380
+ fontsize=MIN_FONT_SIZE - 1, color='#d62728')
381
+ # Correct
382
+ ax.text(bar_c.get_width() + 0.1, bar_c.get_y() + bar_c.get_height()/2,
383
+ f'{bar_c.get_width():.1f}', va='center', ha='left',
384
+ fontsize=MIN_FONT_SIZE - 1, color='#2ca02c')
385
+
386
+ ax.legend(loc='lower right', frameon=False, fontsize=MIN_FONT_SIZE)
387
+
388
+ plt.tight_layout()
389
+ fig.savefig(OUTPUT_DIR / "fig_exploration_by_correctness.pdf")
390
+ fig.savefig(OUTPUT_DIR / "fig_exploration_by_correctness.png")
391
+ plt.close(fig)
392
+ print(f"Saved: fig_exploration_by_correctness.pdf/png")
393
+
394
+ # === Figure 2: Box plot distribution ===
395
+ fig2, ax2 = plt.subplots(figsize=(HALF_COLUMN_WIDTH * 1.5, 3.5))
396
+
397
+ # Prepare data for box plot
398
+ df['correctness'] = df['is_correct'].map({True: 'Correct\n(recall>0)', False: 'Incorrect\n(recall=0)'})
399
+ df['model_clean'] = df['model'].apply(clean_model_name)
400
+
401
+ # Order models by overall median exploration
402
+ model_order = df.groupby('model_clean')['semantic_entities_investigated'].median().sort_values().index.tolist()
403
+
404
+ # Create box plot with hue
405
+ sns.boxplot(data=df, x='model_clean', y='semantic_entities_investigated',
406
+ hue='correctness', order=model_order, ax=ax2,
407
+ palette={'Correct\n(recall>0)': '#2ca02c', 'Incorrect\n(recall=0)': '#d62728'},
408
+ linewidth=0.5, fliersize=2)
409
+
410
+ ax2.set_xlabel('')
411
+ ax2.set_ylabel('Semantic Entity Groups Investigated')
412
+ ax2.tick_params(axis='x', rotation=45)
413
+ ax2.legend(title='', loc='upper left', frameon=False, fontsize=MIN_FONT_SIZE)
414
+
415
+ plt.tight_layout()
416
+ fig2.savefig(OUTPUT_DIR / "fig_exploration_by_correctness_boxplot.pdf")
417
+ fig2.savefig(OUTPUT_DIR / "fig_exploration_by_correctness_boxplot.png")
418
+ plt.close(fig2)
419
+ print(f"Saved: fig_exploration_by_correctness_boxplot.pdf/png")
420
+
421
+ # === Figure 3: Aggregated across all models ===
422
+ fig3, ax3 = plt.subplots(figsize=(HALF_COLUMN_WIDTH * 0.8, 2.5))
423
+
424
+ correct_all = df[df['is_correct'] == True]['semantic_entities_investigated']
425
+ incorrect_all = df[df['is_correct'] == False]['semantic_entities_investigated']
426
+
427
+ # Violin plot for overall distribution
428
+ parts = ax3.violinplot([incorrect_all, correct_all], positions=[0, 1],
429
+ showmeans=True, showmedians=True)
430
+
431
+ # Color the violins
432
+ colors = ['#d62728', '#2ca02c']
433
+ for i, pc in enumerate(parts['bodies']):
434
+ pc.set_facecolor(colors[i])
435
+ pc.set_alpha(0.7)
436
+
437
+ # Style the other elements
438
+ for partname in ['cmeans', 'cmedians', 'cbars', 'cmins', 'cmaxes']:
439
+ if partname in parts:
440
+ parts[partname].set_edgecolor('black')
441
+ parts[partname].set_linewidth(0.5)
442
+
443
+ ax3.set_xticks([0, 1])
444
+ ax3.set_xticklabels(['Incorrect\n(recall=0)', 'Correct\n(recall>0)'])
445
+ ax3.set_ylabel('Semantic Entities Investigated')
446
+
447
+ # Add mean values as text
448
+ ax3.text(0, incorrect_all.mean() + 0.5, f'μ={incorrect_all.mean():.1f}',
449
+ ha='center', fontsize=MIN_FONT_SIZE, color='#d62728')
450
+ ax3.text(1, correct_all.mean() + 0.5, f'μ={correct_all.mean():.1f}',
451
+ ha='center', fontsize=MIN_FONT_SIZE, color='#2ca02c')
452
+
453
+ # Add n counts
454
+ ax3.text(0, ax3.get_ylim()[0] + 0.5, f'n={len(incorrect_all)}',
455
+ ha='center', fontsize=MIN_FONT_SIZE - 1)
456
+ ax3.text(1, ax3.get_ylim()[0] + 0.5, f'n={len(correct_all)}',
457
+ ha='center', fontsize=MIN_FONT_SIZE - 1)
458
+
459
+ plt.tight_layout()
460
+ fig3.savefig(OUTPUT_DIR / "fig_exploration_overall_correctness.pdf")
461
+ fig3.savefig(OUTPUT_DIR / "fig_exploration_overall_correctness.png")
462
+ plt.close(fig3)
463
+ print(f"Saved: fig_exploration_overall_correctness.pdf/png")
464
+
465
+ # Print statistics
466
+ print("\n" + "=" * 60)
467
+ print("Exploration Breadth by Diagnosis Correctness")
468
+ print("=" * 60)
469
+ print(f"\nOverall Statistics:")
470
+ print(f" Correct diagnoses (n={len(correct_all)}): mean={correct_all.mean():.2f}, median={correct_all.median():.1f}")
471
+ print(f" Incorrect diagnoses (n={len(incorrect_all)}): mean={incorrect_all.mean():.2f}, median={incorrect_all.median():.1f}")
472
+
473
+ # Statistical test
474
+ from scipy import stats
475
+ stat, pvalue = stats.mannwhitneyu(correct_all, incorrect_all, alternative='two-sided')
476
+ print(f"\n Mann-Whitney U test: U={stat:.0f}, p={pvalue:.4f}")
477
+
478
+ print(f"\nPer-Model Comparison:")
479
+ print(f"{'Model':<20} {'Correct':>12} {'Incorrect':>12} {'Diff':>8}")
480
+ print("-" * 55)
481
+ for _, row in comp_df.sort_values('correct_mean', ascending=False).iterrows():
482
+ diff = row['correct_mean'] - row['incorrect_mean']
483
+ print(f"{row['model_clean']:<20} {row['correct_mean']:>10.1f} (n={int(row['correct_n'])}) "
484
+ f"{row['incorrect_mean']:>10.1f} (n={int(row['incorrect_n'])}) {diff:>+7.1f}")
485
+
486
+ return comp_df
487
+
488
+
489
+ def plot_success_by_exploration_bins(df: pd.DataFrame):
490
+ """
491
+ Plot showing success rate as a function of exploration breadth.
492
+ This shows a clear dose-response relationship.
493
+ """
494
+ # Create exploration bins
495
+ bins = [0, 2, 4, 6, 8, 10, 100]
496
+ labels = ['0-2', '3-4', '5-6', '7-8', '9-10', '11+']
497
+ df['exploration_bin'] = pd.cut(df['semantic_entities_investigated'],
498
+ bins=bins, labels=labels)
499
+
500
+ # Calculate success rate per bin
501
+ bin_stats = []
502
+ for label in labels:
503
+ subset = df[df['exploration_bin'] == label]
504
+ if len(subset) > 0:
505
+ success_rate = (subset['root_cause_f1'] > 0).mean() * 100
506
+ bin_stats.append({
507
+ 'bin': label,
508
+ 'success_rate': success_rate,
509
+ 'n': len(subset)
510
+ })
511
+
512
+ stats_df = pd.DataFrame(bin_stats)
513
+
514
+ # Create figure
515
+ fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH, 2.5))
516
+
517
+ x = np.arange(len(stats_df))
518
+ bars = ax.bar(x, stats_df['success_rate'],
519
+ color='#4a90d9', edgecolor='black', linewidth=0.5)
520
+
521
+ ax.set_xticks(x)
522
+ ax.set_xticklabels(stats_df['bin'])
523
+ ax.set_xlabel('Semantic Entities Investigated')
524
+ ax.set_ylabel('Correct Diagnosis Rate (%)')
525
+
526
+ # Add value labels on bars
527
+ for i, (bar, row) in enumerate(zip(bars, stats_df.itertuples())):
528
+ height = bar.get_height()
529
+ ax.text(bar.get_x() + bar.get_width()/2, height + 1,
530
+ f'{height:.0f}%', ha='center', va='bottom',
531
+ fontsize=MIN_FONT_SIZE)
532
+ ax.text(bar.get_x() + bar.get_width()/2, 2,
533
+ f'n={row.n}', ha='center', va='bottom',
534
+ fontsize=MIN_FONT_SIZE - 1, color='white')
535
+
536
+ ax.set_ylim(0, 60)
537
+
538
+ plt.tight_layout()
539
+ fig.savefig(OUTPUT_DIR / "fig_exploration_success_rate.pdf")
540
+ fig.savefig(OUTPUT_DIR / "fig_exploration_success_rate.png")
541
+ plt.close(fig)
542
+ print(f"Saved: fig_exploration_success_rate.pdf/png")
543
+
544
+ # Also create a combined figure with both views
545
+ fig2, axes = plt.subplots(1, 2, figsize=(HALF_COLUMN_WIDTH * 2 + 0.3, 2.5))
546
+
547
+ # Left: Success rate by exploration bins
548
+ ax1 = axes[0]
549
+ bars1 = ax1.bar(x, stats_df['success_rate'],
550
+ color='#4a90d9', edgecolor='black', linewidth=0.5)
551
+ ax1.set_xticks(x)
552
+ ax1.set_xticklabels(stats_df['bin'])
553
+ ax1.set_xlabel('Entities Investigated')
554
+ ax1.set_ylabel('Correct Diagnosis Rate (%)')
555
+ ax1.set_title('(a) Success vs Exploration', fontsize=MIN_FONT_SIZE + 1)
556
+ for bar, row in zip(bars1, stats_df.itertuples()):
557
+ ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
558
+ f'{bar.get_height():.0f}%', ha='center', va='bottom',
559
+ fontsize=MIN_FONT_SIZE - 1)
560
+ ax1.set_ylim(0, 60)
561
+
562
+ # Right: Exploration distribution by correctness (violin)
563
+ ax2 = axes[1]
564
+ correct = df[df['is_correct'] == True]['semantic_entities_investigated']
565
+ incorrect = df[df['is_correct'] == False]['semantic_entities_investigated']
566
+
567
+ parts = ax2.violinplot([incorrect, correct], positions=[0, 1],
568
+ showmeans=True, showmedians=True)
569
+ colors = ['#d62728', '#2ca02c']
570
+ for i, pc in enumerate(parts['bodies']):
571
+ pc.set_facecolor(colors[i])
572
+ pc.set_alpha(0.7)
573
+ for partname in ['cmeans', 'cmedians', 'cbars', 'cmins', 'cmaxes']:
574
+ if partname in parts:
575
+ parts[partname].set_edgecolor('black')
576
+ parts[partname].set_linewidth(0.5)
577
+
578
+ ax2.set_xticks([0, 1])
579
+ ax2.set_xticklabels(['Incorrect', 'Correct'])
580
+ ax2.set_ylabel('Entities Investigated')
581
+ ax2.set_title('(b) Exploration by Outcome', fontsize=MIN_FONT_SIZE + 1)
582
+ ax2.text(0, incorrect.mean() + 1, f'μ={incorrect.mean():.1f}',
583
+ ha='center', fontsize=MIN_FONT_SIZE - 1, color='#d62728')
584
+ ax2.text(1, correct.mean() + 1, f'μ={correct.mean():.1f}',
585
+ ha='center', fontsize=MIN_FONT_SIZE - 1, color='#2ca02c')
586
+
587
+ plt.tight_layout()
588
+ fig2.savefig(OUTPUT_DIR / "fig_exploration_combined.pdf")
589
+ fig2.savefig(OUTPUT_DIR / "fig_exploration_combined.png")
590
+ plt.close(fig2)
591
+ print(f"Saved: fig_exploration_combined.pdf/png")
592
+
593
+
594
+ def main():
595
+ print("=" * 60)
596
+ print("Exploration Breadth by Diagnosis Correctness Analysis")
597
+ print("=" * 60)
598
+
599
+ # Check if we can use cached data or need to re-extract
600
+ cache_path = OUTPUT_DIR / "exploration_by_correctness.csv"
601
+
602
+ if cache_path.exists():
603
+ print(f"\nLoading cached data from {cache_path}")
604
+ df = pd.read_csv(cache_path)
605
+ else:
606
+ print("\nExtracting data from rollout files (this may take a while)...")
607
+ df = analyze_all_trials()
608
+ df.to_csv(cache_path, index=False)
609
+ print(f"Saved cache to: {cache_path}")
610
+
611
+ print(f"\nLoaded {len(df)} trials from {df['model'].nunique()} models")
612
+
613
+ # Generate plots
614
+ print("\nGenerating figures...")
615
+ plot_exploration_by_correctness(df)
616
+ plot_success_by_exploration_bins(df) # NEW: dose-response plot
617
+
618
+ print(f"\nDone! Figures saved to: {OUTPUT_DIR}")
619
+
620
+
621
+ if __name__ == "__main__":
622
+ main()
623
+
analysis_src/extract_inference_data.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract inference request and token usage data for all 'react with code' agents.
4
+
5
+ This script reads session.jsonl files to count inference requests and estimate token usage.
6
+ """
7
+
8
+ import json
9
+ import sys
10
+ from pathlib import Path
11
+ from dataclasses import dataclass
12
+ import ast
13
+ import numpy as np
14
+ import pandas as pd
15
+ import matplotlib.pyplot as plt
16
+ import seaborn as sns
17
+ from tqdm import tqdm
18
+
19
+
20
+ PROJECT_ROOT = Path(__file__).parent.parent
21
+ sys.path.insert(0, str(PROJECT_ROOT))
22
+
23
+ from src.utils import (
24
+ get_model_name,
25
+ find_react_with_code_dirs,
26
+ get_runs_stats,
27
+ filter_scenarios_with_min_runs,
28
+ find_latest_rollout_file
29
+ )
30
+
31
+ from src.model_styles import (
32
+ get_model_style, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, get_color_palette, PLOT_PARAMETERS
33
+ )
34
+
35
+ # Paths
36
+ LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"
37
+ RESULTS_JSON_DIR = LEADERBOARD_DIR / "results"
38
+ OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "inferences"
39
+
40
+ # Minimum runs per scenario required
41
+ MIN_RUNS_PER_SCENARIO = 3
42
+ MIN_QUALIFYING_SCENARIOS = 20
43
+
44
+ # Token estimation factor (chars per token)
45
+ CHARS_PER_TOKEN = 4
46
+
47
+ def extract_tokens_from_rollout(rollout_file: Path) -> dict:
48
+ """
49
+ Extract token counts and tool usage from a rollout file.
50
+
51
+ Counts:
52
+ - INPUT: system prompt + user messages + tool outputs
53
+ - OUTPUT: assistant messages + tool call arguments
54
+ - TOOLS: counts by tool name, including code execution
55
+ """
56
+ system_prompt_chars = 0
57
+ user_input_chars = 0
58
+ assistant_output_chars = 0
59
+ tool_call_chars = 0
60
+ tool_output_chars = 0
61
+
62
+ assistant_msg_count = 0
63
+ tool_call_count = 0
64
+ tool_counts = {} # tool_name -> count
65
+ code_execution_count = 0 # Specifically track code/python execution
66
+
67
+ # Tool names that indicate code execution
68
+ CODE_TOOLS = ['execute_python', 'run_python', 'python', 'execute_code',
69
+ 'run_code', 'shell', 'bash', 'terminal', 'exec']
70
+
71
+ try:
72
+ with open(rollout_file) as f:
73
+ for line in f:
74
+ try:
75
+ d = json.loads(line)
76
+ msg_type = d.get('type', '')
77
+ payload = d.get('payload', {})
78
+
79
+ if msg_type == 'session_meta':
80
+ # System prompt
81
+ instructions = payload.get('instructions', '')
82
+ system_prompt_chars += len(str(instructions))
83
+
84
+ elif msg_type == 'response_item':
85
+ item_type = payload.get('type', '')
86
+ role = payload.get('role', '')
87
+
88
+ if item_type == 'message':
89
+ content = payload.get('content', [])
90
+ if isinstance(content, list):
91
+ text = ' '.join([
92
+ c.get('text', '') if isinstance(c, dict) else str(c)
93
+ for c in content
94
+ ])
95
+ else:
96
+ text = str(content)
97
+
98
+ if role == 'user':
99
+ user_input_chars += len(text)
100
+ elif role == 'assistant':
101
+ assistant_output_chars += len(text)
102
+ assistant_msg_count += 1
103
+
104
+ elif item_type == 'function_call':
105
+ # Tool call (output)
106
+ name = payload.get('name', '')
107
+ arguments = payload.get('arguments', '')
108
+ tool_call_chars += len(str(name)) + len(str(arguments))
109
+ tool_call_count += 1
110
+
111
+ # Track tool usage
112
+ tool_counts[name] = tool_counts.get(name, 0) + 1
113
+
114
+ # Check if it's code execution
115
+ name_lower = name.lower()
116
+ if any(code_tool in name_lower for code_tool in CODE_TOOLS):
117
+ code_execution_count += 1
118
+ # Also check if arguments contain python code patterns
119
+ args_str = str(arguments).lower()
120
+ if 'python' in name_lower or ('def ' in args_str or 'import ' in args_str):
121
+ code_execution_count += 1
122
+
123
+ elif item_type == 'function_call_output':
124
+ # Tool output (input to model)
125
+ output = payload.get('output', '')
126
+ tool_output_chars += len(str(output))
127
+
128
+ except json.JSONDecodeError:
129
+ continue
130
+ except Exception as e:
131
+ return None
132
+
133
+ # INPUT = system + user + tool outputs (fed back to model)
134
+ input_chars = system_prompt_chars + user_input_chars + tool_output_chars
135
+ # OUTPUT = assistant responses + tool call arguments
136
+ output_chars = assistant_output_chars + tool_call_chars
137
+
138
+ return {
139
+ 'system_prompt_chars': system_prompt_chars,
140
+ 'user_input_chars': user_input_chars,
141
+ 'assistant_output_chars': assistant_output_chars,
142
+ 'tool_call_chars': tool_call_chars,
143
+ 'tool_output_chars': tool_output_chars,
144
+ 'input_chars': input_chars,
145
+ 'output_chars': output_chars,
146
+ 'input_tokens': input_chars // CHARS_PER_TOKEN,
147
+ 'output_tokens': output_chars // CHARS_PER_TOKEN,
148
+ 'assistant_msg_count': assistant_msg_count,
149
+ 'tool_call_count': tool_call_count,
150
+ 'tool_counts': tool_counts,
151
+ 'code_execution_count': code_execution_count,
152
+ }
153
+
154
+
155
def extract_session_stats(session_file: Path) -> dict | None:
    """
    Extract inference stats from session.jsonl and rollout files.

    Token counts are resolved with a two-level fallback:
      1. real API usage parsed from traces/stdout.log (``turn.completed`` events),
      2. otherwise character-based estimates from the latest rollout file.
    Tool-call counts always come from the rollout file when one exists.

    Args:
        session_file: Path to a trial's ``session.jsonl``; its parent is the
            trial directory that also holds ``traces/`` and rollout files.

    Returns:
        Dict of counters (inference_count, input/output/cached token counts,
        tool_call_count, tool_counts, code_execution_count, has_real_tokens),
        or None when the session file is missing or unreadable.
    """
    if not session_file.exists():
        return None

    trial_dir = session_file.parent

    # Count inference requests from session.jsonl
    # Each 'response_item' record corresponds to one model inference request.
    inference_count = 0
    try:
        with open(session_file) as f:
            for line in f:
                try:
                    d = json.loads(line)
                    if d.get('type') == 'response_item':
                        inference_count += 1
                except json.JSONDecodeError:
                    # Tolerate partially-written/corrupt JSONL lines.
                    continue
    except Exception as e:
        print(f" Warning: Error reading {session_file}: {e}")
        return None

    # First check stdout.log for real token counts (OpenAI models)
    stdout_log = trial_dir / "traces" / "stdout.log"
    has_real_tokens = False
    input_tokens = 0
    output_tokens = 0
    cached_input_tokens = 0

    if stdout_log.exists():
        try:
            with open(stdout_log) as f:
                for line in f:
                    try:
                        d = json.loads(line)
                        if d.get('type') == 'turn.completed':
                            usage = d.get('usage', {})
                            input_tokens = usage.get('input_tokens', 0)
                            output_tokens = usage.get('output_tokens', 0)
                            cached_input_tokens = usage.get('cached_input_tokens', 0)
                            if input_tokens > 0 or output_tokens > 0:
                                has_real_tokens = True
                                # NOTE(review): stops at the FIRST turn.completed
                                # event with nonzero usage — presumably usage is
                                # cumulative per turn; confirm the log only ever
                                # carries one relevant event per trial.
                                break
                    except json.JSONDecodeError:
                        continue
        except Exception:
            # stdout.log is optional; fall back to rollout-based estimates.
            pass

    # Extract from latest rollout file for tokens (if needed) and tool counts
    tool_call_count = 0
    tool_counts = {}
    code_execution_count = 0

    latest_rollout = find_latest_rollout_file(trial_dir)
    if latest_rollout:
        rollout_stats = extract_tokens_from_rollout(latest_rollout)
        if rollout_stats:
            # Use rollout tokens if no real API token data
            if not has_real_tokens:
                input_tokens = rollout_stats['input_tokens']
                output_tokens = rollout_stats['output_tokens']

            # Always use rollout for tool counts
            tool_call_count = rollout_stats['tool_call_count']
            tool_counts = rollout_stats['tool_counts']
            code_execution_count = rollout_stats['code_execution_count']

    return {
        'inference_count': inference_count,
        'input_tokens': input_tokens,
        'cached_input_tokens': cached_input_tokens,
        'output_tokens': output_tokens,
        'total_tokens': input_tokens + output_tokens,
        'has_real_tokens': has_real_tokens,
        'tool_call_count': tool_call_count,
        'tool_counts': tool_counts,
        'code_execution_count': code_execution_count,
    }
237
+
238
+
239
def read_agent_stats(agent_dir: Path) -> dict[str, list[dict]]:
    """
    Collect per-trial session stats for every scenario under an agent directory.

    Only subdirectories whose names start with "Scenario" are scanned, and
    scenarios that yield no readable trial stats are dropped from the result.

    Returns:
        Dict mapping scenario_id -> list of stats (one per trial)
    """
    results: dict[str, list[dict]] = {}

    for scenario_path in agent_dir.iterdir():
        if not (scenario_path.is_dir() and scenario_path.name.startswith("Scenario")):
            continue

        # One stats dict per trial subdirectory whose session.jsonl is usable.
        trial_stats = []
        for trial_path in sorted(scenario_path.iterdir()):
            if not trial_path.is_dir():
                continue
            stats = extract_session_stats(trial_path / "session.jsonl")
            if stats:
                trial_stats.append(stats)

        if trial_stats:
            results[scenario_path.name] = trial_stats

    return results
268
+
269
def load_performance_data() -> pd.DataFrame:
    """Load per-model performance (root-cause entity F1) from the consistency analysis.

    Returns an empty DataFrame when the performance CSV has not been generated.
    """
    perf_path = PROJECT_ROOT / "data" / "output" / "consistency" / "performance_data.csv"
    if not perf_path.exists():
        return pd.DataFrame()

    raw = pd.read_csv(perf_path)
    # Keep only the F1 rows, reduced to the columns needed for merging.
    f1_rows = raw[raw["metric_raw"] == "root_cause_entity_f1"]
    return f1_rows[["model", "performance"]]
276
+
277
def extract_all_data() -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Extract inference data for all agents.

    Walks every 'react with code' agent directory under LEADERBOARD_DIR,
    reads per-trial session stats, and aggregates them per model. Agents
    with no session data, or with fewer than MIN_QUALIFYING_SCENARIOS
    scenarios having MIN_RUNS_PER_SCENARIO+ runs, are skipped.

    Returns:
        - summary_df: Aggregated stats per model
        - detail_df: Per-scenario stats
    """
    agent_dirs = find_react_with_code_dirs(LEADERBOARD_DIR)
    print(f"Found {len(agent_dirs)} 'react with code' agent directories")

    summary_records = []
    detail_records = []

    for agent_dir in tqdm(agent_dirs, desc="Processing agents"):
        model_name = get_model_name(agent_dir.name)

        print(f"\nReading: {agent_dir.name}")
        scenario_data = read_agent_stats(agent_dir)

        n_scenarios, min_runs, max_runs, n_qualifying = get_runs_stats(scenario_data, MIN_RUNS_PER_SCENARIO)

        if n_scenarios == 0:
            print(f" SKIPPING {model_name}: No session data found")
            continue

        if n_qualifying < MIN_QUALIFYING_SCENARIOS:
            print(f" SKIPPING {model_name}: Only {n_qualifying}/{n_scenarios} scenarios have {MIN_RUNS_PER_SCENARIO}+ runs")
            continue

        # Filter scenarios
        scenario_data = filter_scenarios_with_min_runs(scenario_data, MIN_RUNS_PER_SCENARIO)
        n_scenarios_filtered = len(scenario_data)

        print(f" Processing: {model_name} ({n_scenarios_filtered} scenarios)")

        # Aggregate across all scenarios and trials
        all_inference_counts = []
        all_input_tokens = []
        all_output_tokens = []
        all_total_tokens = []
        all_cached_tokens = []
        all_tool_call_counts = []
        all_code_execution_counts = []
        aggregated_tool_counts = {}

        for scenario_id, trials in tqdm(scenario_data.items(), desc=f" {model_name} scenarios", leave=False):
            for trial in trials:
                all_inference_counts.append(trial['inference_count'])
                all_input_tokens.append(trial['input_tokens'])
                all_output_tokens.append(trial['output_tokens'])
                all_total_tokens.append(trial['total_tokens'])
                # .get() with defaults: older trial records may predate these keys.
                all_cached_tokens.append(trial.get('cached_input_tokens', 0))
                all_tool_call_counts.append(trial.get('tool_call_count', 0))
                all_code_execution_counts.append(trial.get('code_execution_count', 0))

                # Aggregate tool counts
                for tool_name, count in trial.get('tool_counts', {}).items():
                    aggregated_tool_counts[tool_name] = aggregated_tool_counts.get(tool_name, 0) + count

                # One detail row per (model, scenario, trial).
                detail_records.append({
                    'model': model_name,
                    'scenario': scenario_id,
                    'inference_count': trial['inference_count'],
                    'input_tokens': trial['input_tokens'],
                    'cached_input_tokens': trial.get('cached_input_tokens', 0),
                    'output_tokens': trial['output_tokens'],
                    'total_tokens': trial['total_tokens'],
                    'tool_call_count': trial.get('tool_call_count', 0),
                    'code_execution_count': trial.get('code_execution_count', 0),
                })

        # Summary stats (one row per model, averaged over all trials).
        summary_records.append({
            'model': model_name,
            'n_scenarios': n_scenarios_filtered,
            'n_trials': len(all_inference_counts),
            'avg_inference_count': np.mean(all_inference_counts),
            'std_inference_count': np.std(all_inference_counts),
            'avg_input_tokens': np.mean(all_input_tokens),
            'avg_cached_tokens': np.mean(all_cached_tokens),
            'avg_output_tokens': np.mean(all_output_tokens),
            'avg_total_tokens': np.mean(all_total_tokens),
            'total_inference_count': sum(all_inference_counts),
            'total_tokens': sum(all_total_tokens),
            'avg_tool_call_count': np.mean(all_tool_call_counts) if all_tool_call_counts else 0,
            'total_tool_calls': sum(all_tool_call_counts),
            'avg_code_execution_count': np.mean(all_code_execution_counts) if all_code_execution_counts else 0,
            'total_code_executions': sum(all_code_execution_counts),
            # Ten most frequently used tools for this model, highest first.
            'top_tools': dict(sorted(aggregated_tool_counts.items(), key=lambda x: -x[1])[:10]),
        })

    summary_df = pd.DataFrame(summary_records)
    detail_df = pd.DataFrame(detail_records)

    # Merge with performance data (left join keeps models without scores).
    perf_df = load_performance_data()
    if len(perf_df) > 0:
        summary_df = pd.merge(summary_df, perf_df, on='model', how='left')

    return summary_df, detail_df
378
+
379
+
380
def save_data(summary_df: pd.DataFrame, detail_df: pd.DataFrame):
    """Write the summary and per-scenario detail tables as CSV under OUTPUT_DIR."""
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Destination path -> frame to write there.
    targets = {
        OUTPUT_DIR / "inference_summary.csv": summary_df,
        OUTPUT_DIR / "inference_detail.csv": detail_df,
    }
    for path, frame in targets.items():
        frame.to_csv(path, index=False)

    print(f"\nData saved to:")
    for path in targets:
        print(f" - {path}")
393
+
394
+
395
def print_summary(summary_df: pd.DataFrame):
    """Print a console table of per-model inference statistics, busiest model first."""
    banner = "=" * 80
    print("\n" + banner)
    print("Inference Summary")
    print(banner)

    ordered = summary_df.sort_values("avg_inference_count", ascending=False)

    print(f"\n{'Model':<25} {'Avg Infer':>10} {'Avg Tokens':>12} {'Avg In':>10} {'Avg Out':>10}")
    print("-" * 70)
    for _, record in ordered.iterrows():
        print(
            f"{record['model']:<25} "
            f"{record['avg_inference_count']:>10.1f} "
            f"{record['avg_total_tokens']:>12.0f} "
            f"{record['avg_input_tokens']:>10.0f} "
            f"{record['avg_output_tokens']:>10.0f}"
        )
407
+
408
def plot_tool_usage(summary_df: pd.DataFrame):
    """
    Figure: Tool usage per model - total tool calls and code execution.

    Renders two side-by-side horizontal bar charts (avg. tool calls and avg.
    code executions per scenario) and saves fig_tool_usage.png to OUTPUT_DIR.
    Skips silently (with a message) when no tool data is present.
    """
    plt.rcParams.update(PLOT_PARAMETERS)

    if 'avg_tool_call_count' not in summary_df.columns:
        print("Skipping tool usage: no tool data")
        return

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(SINGLE_COLUMN_WIDTH * 2, 2.5))

    # Ascending sort so barh() places the largest bar at the top of the chart.
    data = summary_df.sort_values("avg_tool_call_count", ascending=True)

    color_palette = get_color_palette(len(data))
    colors = [color_palette[i % len(color_palette)] for i in range(len(data))]

    # Left: Total tool calls
    bars1 = ax1.barh(data["model"], data["avg_tool_call_count"], color=colors,
                     edgecolor='black', linewidth=0.5)
    ax1.set_xlabel("Avg. Tool Calls per Scenario")

    # Value labels just past the end of each bar.
    for bar, val in zip(bars1, data["avg_tool_call_count"]):
        ax1.text(val + 1, bar.get_y() + bar.get_height()/2,
                 f'{val:.0f}', va='center', ha='left', fontsize=MIN_FONT_SIZE - 1)

    # 15% headroom so the value labels are not clipped.
    ax1.set_xlim(0, data["avg_tool_call_count"].max() * 1.15)

    # Right: Code executions
    bars2 = ax2.barh(data["model"], data["avg_code_execution_count"], color=colors,
                     edgecolor='black', linewidth=0.5)
    ax2.set_xlabel("Avg. Code Executions per Scenario")

    for bar, val in zip(bars2, data["avg_code_execution_count"]):
        if val > 0:  # skip labels on zero-length bars
            ax2.text(val + 0.5, bar.get_y() + bar.get_height()/2,
                     f'{val:.0f}', va='center', ha='left', fontsize=MIN_FONT_SIZE - 1)

    # Floor of 1 keeps the axis sane when no model executed any code.
    ax2.set_xlim(0, max(data["avg_code_execution_count"].max() * 1.3, 1))
    ax2.set_yticklabels([])  # model names already shown on the left panel

    # NOTE(review): plt.title() targets the current axes (ax2 here), not the
    # whole figure — fig.suptitle() may have been intended; confirm.
    plt.title("Tool Call Distribution")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_tool_usage.png")
    plt.close(fig)
    print("Saved: fig_tool_usage.png")
456
+
457
def plot_inference_vs_performance(summary_df: pd.DataFrame):
    """
    Figure 3: Inference count vs Performance scatter.

    One labeled point per model; requires a 'performance' column (merged in
    by extract_all_data from the consistency analysis). Saves
    fig_inference_vs_performance.png to OUTPUT_DIR.
    """
    if 'performance' not in summary_df.columns:
        print("Skipping inference vs performance: no performance data")
        return

    plt.rcParams.update(PLOT_PARAMETERS)

    fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH, SINGLE_COLUMN_WIDTH))

    # Models without a merged performance score cannot be plotted.
    data = summary_df.dropna(subset=['performance'])

    # Manual label offsets to avoid overlap
    # (dx points, dy points, horizontal alignment, vertical alignment)
    label_offsets = {
        "GPT-5.1": (-5, -8, "right", "top"),
        "o4-mini": (5, -8, "left", "top"),
        "GPT-OSS-120B": (5, 3, "left", "bottom"),
        "Gemini-2.5-Pro": (-5, 3, "right", "bottom"),
        "Gemini-3-Flash": (5, 3, "left", "bottom"),
        "gemini-3-pro-preview": (5, 3, "left", "bottom"),
        "Kimi-K2": (5, 3, "left", "bottom"),
    }

    # Get color palette
    color_palette = get_color_palette(len(data))

    # Scatter plot: one call per model so each point gets its own color.
    for i, (_, row) in enumerate(data.iterrows()):
        ax.scatter(row["avg_inference_count"], row["performance"],
                   c=[color_palette[i % len(color_palette)]], s=60, edgecolors='black',
                   linewidth=0.5, zorder=10)

        # Label with custom offset (falls back to right-of-point placement).
        offset = label_offsets.get(row["model"], (5, 3, "left", "bottom"))
        ax.annotate(row["model"],
                    (row["avg_inference_count"], row["performance"]),
                    xytext=(offset[0], offset[1]), textcoords='offset points',
                    fontsize=MIN_FONT_SIZE - 1, ha=offset[2], va=offset[3])

    ax.set_xlabel("Avg. Inference Requests")
    ax.set_ylabel("Performance (RC Entity F1)")
    ax.set_xlim(0, data["avg_inference_count"].max() * 1.2)
    # Fixed y-range; assumes no model exceeds 0.7 F1 — revisit if scores rise.
    ax.set_ylim(0, 0.7)

    plt.title("Inference Requests vs. Performance")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_inference_vs_performance.png")
    plt.close(fig)
    print("Saved: fig_inference_vs_performance.png")
510
+
511
def plot_tool_breakdown_heatmap(summary_df: pd.DataFrame):
    """
    Generate a heatmap showing which tools each agent uses most.

    Rows are models (sorted by total tool usage), columns are tools (sorted by
    aggregate usage), cells are average calls per scenario. Saves
    fig_tool_usage_heatmap.png to OUTPUT_DIR.
    """

    # Flatten each model's top_tools mapping into long-format rows.
    tool_usage = []

    for _, row in summary_df.iterrows():
        if pd.isna(row.get('top_tools')):
            print("pd.isna")
            continue

        tools = row['top_tools']
        total_calls = row['total_tool_calls']
        if total_calls == 0:
            print("No tool calls")
            continue

        for tool, count in tools.items():
            tool_usage.append({
                'model': row['model'],
                'tool': tool,
                'count': count,
                'avg_per_scenario': count / row['n_scenarios']
            })

    df = pd.DataFrame(tool_usage)
    if len(df) == 0:
        print("No tool usage data found")
        return

    # Pivot for heatmap: model x tool matrix of avg calls per scenario.
    pivot_df = df.pivot(index='model', columns='tool', values='avg_per_scenario').fillna(0)

    # Order tool columns by aggregate usage across all models (most used first).
    # (A .head(10) cap was considered but intentionally left off.)
    top_tools = pivot_df.sum().sort_values(ascending=False).index
    pivot_df = pivot_df[top_tools]

    # Sort models by total tool usage, using a temporary 'total' column.
    pivot_df['total'] = pivot_df.sum(axis=1)
    pivot_df = pivot_df.sort_values('total', ascending=False).drop('total', axis=1)

    # Plot.
    # FIX: update rcParams from a merged local copy instead of mutating the
    # shared PLOT_PARAMETERS dict — the original wrote font.size=8 into the
    # module-level dict, silently changing every figure rendered afterwards.
    plt.rcParams.update({**PLOT_PARAMETERS, 'font.size': 8})

    fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH * 2, 4))

    sns.heatmap(pivot_df, annot=True, fmt='.1f', cmap='YlOrRd', ax=ax,
                cbar_kws={'label': 'Avg. Calls per Scenario'})

    ax.set_xlabel("")
    ax.set_ylabel("")
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.title("Tool Call Distribution")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_tool_usage_heatmap.png")
    plt.close(fig)
    print("Saved: fig_tool_usage_heatmap.png")
577
+
578
def main():
    """Entry point: extract inference data, persist it, and print a summary."""
    for message in (
        "Extracting inference data for 'react with code' agents...",
        f"Reading from directories: {LEADERBOARD_DIR}",
        f"Output directory: {OUTPUT_DIR}",
    ):
        print(message)

    summary_df, detail_df = extract_all_data()

    # Nothing qualified — avoid writing empty CSVs.
    if len(summary_df) == 0:
        print("No data extracted!")
        return

    save_data(summary_df, detail_df)
    print_summary(summary_df)


if __name__ == "__main__":
    main()
595
+
analysis_src/extract_majority_vote_data.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract Majority Vote and consistency breakdown data for all 'react with code' agents.
4
+
5
+ This script computes:
6
+ - Pass@k: At least 1 trial succeeds
7
+ - Majority@k: Majority of trials succeed
8
+ - All@k: All trials succeed
9
+ - Consistency breakdown: Consistent Correct, Consistent Wrong, Inconsistent
10
+
11
+ Output is saved to paper_analysis/react with code/resources/figures/consistency/ as CSV files.
12
+ """
13
+
14
+ import json
15
+ import sys
16
+ from pathlib import Path
17
+ from itertools import combinations
18
+ import numpy as np
19
+ import pandas as pd
20
+ import matplotlib.pyplot as plt
21
+ import seaborn as sns
22
+ from tqdm import tqdm
23
+
24
+ # Add project root to path
25
+ PROJECT_ROOT = Path(__file__).parent.parent
26
+ sys.path.insert(0, str(PROJECT_ROOT))
27
+
28
+ from src.utils import (
29
+ get_model_name,
30
+ find_react_with_code_dirs,
31
+ read_judge_outputs_from_dir,
32
+ extract_trial_scores_from_judge_outputs,
33
+ filter_scenarios_with_min_runs,
34
+ )
35
+
36
+ from src.model_styles import (
37
+ get_model_style, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, PLOT_PARAMETERS
38
+ )
39
+
40
+ # Paths
41
+ LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"
42
+ OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "consistency"
43
+
44
+ # Minimum runs per scenario required for inclusion
45
+ MIN_RUNS_PER_SCENARIO = 2
46
+
47
+ # Minimum scenarios needed after filtering
48
+ MIN_QUALIFYING_SCENARIOS = 20
49
+
50
+ # Success threshold for binary classification
51
+ SUCCESS_THRESHOLD = 0.5
52
+
53
def compute_majority_vote_metrics(
    scenario_trials: dict[str, list[float]],
    success_threshold: float = SUCCESS_THRESHOLD
) -> dict:
    """
    Compute pass@k / majority@k / all@k and a consistency breakdown.

    k is the smallest trial count across all scenarios; every scenario is
    truncated to its first k trials so the metrics are comparable. A trial
    counts as a success when its score >= success_threshold.

    Returns None when there are no scenarios (or no trials); otherwise a dict
    with per-scenario rates, raw counts, overall score mean/std, and a
    per-scenario detail list:
    - pass_at_k: At least 1 trial succeeds
    - majority_at_k: Majority of trials succeed
    - all_at_k / consistent_correct: All trials succeed
    - consistent_wrong: All trials fail
    - inconsistent: Mixed results
    """
    trial_lengths = [len(trials) for trials in scenario_trials.values()]
    if not trial_lengths:
        return None

    k = min(trial_lengths)
    n_scenarios = len(scenario_trials)
    if n_scenarios == 0 or k < 1:
        return None

    counters = {
        "pass": 0, "majority": 0, "all": 0,
        "correct": 0, "wrong": 0, "inconsistent": 0,
    }
    per_scenario = []
    pooled_scores = []

    for scenario_id, raw_trials in scenario_trials.items():
        trials = raw_trials[:k]
        pooled_scores.extend(trials)
        n_success = sum(1 for score in trials if score >= success_threshold)

        if n_success >= 1:
            counters["pass"] += 1
        if n_success > k / 2:  # strict majority (ties on even k do not count)
            counters["majority"] += 1

        if n_success == k:
            counters["all"] += 1
            counters["correct"] += 1
            consistency_type = "correct"
        elif n_success == 0:
            counters["wrong"] += 1
            consistency_type = "wrong"
        else:
            counters["inconsistent"] += 1
            consistency_type = "inconsistent"

        per_scenario.append({
            "scenario": scenario_id,
            "n_success": n_success,
            "n_trials": k,
            "majority_correct": n_success > k / 2,
            "consistency_type": consistency_type,
            "mean_score": np.mean(trials),
            "std_score": np.std(trials) if len(trials) > 1 else 0,
        })

    return {
        "n_scenarios": n_scenarios,
        "n_trials": k,
        "threshold": success_threshold,
        "pass_at_k": counters["pass"] / n_scenarios,
        "majority_at_k": counters["majority"] / n_scenarios,
        "all_at_k": counters["all"] / n_scenarios,
        "consistent_correct": counters["correct"] / n_scenarios,
        "consistent_wrong": counters["wrong"] / n_scenarios,
        "inconsistent": counters["inconsistent"] / n_scenarios,
        "n_pass": counters["pass"],
        "n_majority": counters["majority"],
        "n_all": counters["all"],
        "n_consistent_correct": counters["correct"],
        "n_consistent_wrong": counters["wrong"],
        "n_inconsistent": counters["inconsistent"],
        "overall_mean": np.mean(pooled_scores),
        "overall_std": np.std(pooled_scores),
        "scenario_details": per_scenario,
    }
143
+
144
+
145
# Metrics to extract: (judge-output key, human-readable display label).
# The key is passed to extract_trial_scores_from_judge_outputs; the label
# is used in output filenames and console tables.
METRICS = [
    ("root_cause_entity_f1", "F1"),
    ("root_cause_entity_precision", "Precision"),
    ("root_cause_entity_recall", "Recall"),
]
151
+
152
+
153
def extract_all_data() -> dict[str, tuple[pd.DataFrame, pd.DataFrame]]:
    """
    Extract majority vote data for all agents, for multiple metrics.

    Judge outputs are read once per agent; agents without judge outputs, or
    with fewer than MIN_QUALIFYING_SCENARIOS scenarios having
    MIN_RUNS_PER_SCENARIO+ runs, are skipped. Metrics are then computed per
    METRICS entry over the shared data.

    Returns:
        - dict mapping metric_name -> (summary_df, scenario_df)
    """
    agent_dirs = find_react_with_code_dirs(LEADERBOARD_DIR)
    print(f"Found {len(agent_dirs)} 'react with code' agent directories:")
    for d in agent_dirs:
        print(f" - {d.name}")

    # Read all judge outputs once
    agent_data = {}
    valid_models = []
    skipped_models = []  # (model_name, human-readable skip reason)

    for agent_dir in tqdm(agent_dirs, desc="Reading agent data"):
        model_name = get_model_name(agent_dir.name)

        print(f"\nReading: {agent_dir.name}")
        scenario_data = read_judge_outputs_from_dir(agent_dir)

        if not scenario_data:
            print(f" SKIPPING {model_name}: No judge outputs found")
            skipped_models.append((model_name, "No data"))
            continue

        # Filter scenarios with minimum runs
        scenario_data = filter_scenarios_with_min_runs(scenario_data, MIN_RUNS_PER_SCENARIO)
        n_qualifying = len(scenario_data)

        if n_qualifying < MIN_QUALIFYING_SCENARIOS:
            print(f" SKIPPING {model_name}: Only {n_qualifying} scenarios with {MIN_RUNS_PER_SCENARIO}+ runs")
            skipped_models.append((model_name, f"{n_qualifying} qualifying"))
            continue

        print(f" Processing: {model_name} ({n_qualifying} scenarios)")
        valid_models.append(model_name)
        agent_data[model_name] = scenario_data

    if skipped_models:
        print(f"\n⚠️ Skipped {len(skipped_models)} models:")
        for name, reason in skipped_models:
            print(f" - {name}: {reason}")

    print(f"\n✓ Included {len(valid_models)} models: {valid_models}")

    # Extract for each metric
    results = {}

    for metric_key, metric_label in tqdm(METRICS, desc="Processing metrics"):
        print(f"\n--- Extracting for metric: {metric_label} ({metric_key}) ---")

        summary_records = []
        scenario_records = []

        for model_name, scenario_data in tqdm(agent_data.items(), desc=f" {metric_label}", leave=False):
            # Extract scores for this metric
            scenario_trials = extract_trial_scores_from_judge_outputs(scenario_data, metric_key)

            # Compute majority vote metrics
            metrics = compute_majority_vote_metrics(scenario_trials)

            # None means no usable trials for this model/metric combination.
            if metrics is None:
                continue

            # Add to summary (one row per model per metric).
            summary_records.append({
                "model": model_name,
                "metric": metric_label,
                "n_scenarios": metrics["n_scenarios"],
                "n_trials": metrics["n_trials"],
                "pass_at_k": metrics["pass_at_k"],
                "majority_at_k": metrics["majority_at_k"],
                "all_at_k": metrics["all_at_k"],
                "consistent_correct": metrics["consistent_correct"],
                "consistent_wrong": metrics["consistent_wrong"],
                "inconsistent": metrics["inconsistent"],
                "overall_mean": metrics["overall_mean"],
                "overall_std": metrics["overall_std"],
            })

            # Add per-scenario data
            for detail in metrics["scenario_details"]:
                scenario_records.append({
                    "model": model_name,
                    "metric": metric_label,
                    "scenario": detail["scenario"],
                    "n_success": detail["n_success"],
                    "n_trials": detail["n_trials"],
                    "majority_correct": detail["majority_correct"],
                    "consistency_type": detail["consistency_type"],
                    "mean_score": detail["mean_score"],
                    "std_score": detail["std_score"],
                })

        summary_df = pd.DataFrame(summary_records)
        scenario_df = pd.DataFrame(scenario_records)
        results[metric_label] = (summary_df, scenario_df)

    return results
255
+
256
+
257
def save_data(results: dict[str, tuple[pd.DataFrame, pd.DataFrame]]):
    """Save extracted data to CSV files for each metric.

    Writes one (summary, per-scenario) CSV pair per metric, plus unsuffixed
    F1 copies (without the 'metric' column) for backward compatibility with
    consumers of the original single-metric output.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # FIX: removed the dead all_summaries/all_scenarios accumulators — they
    # were appended to on every iteration but never read.
    for metric_label, (summary_df, scenario_df) in results.items():
        metric_suffix = metric_label.lower()

        summary_path = OUTPUT_DIR / f"majority_vote_data_{metric_suffix}.csv"
        scenario_path = OUTPUT_DIR / f"majority_vote_scenarios_{metric_suffix}.csv"

        summary_df.to_csv(summary_path, index=False)
        scenario_df.to_csv(scenario_path, index=False)

        print(f"\nData saved for {metric_label}:")
        print(f" - {summary_path}")
        print(f" - {scenario_path}")

    # Save combined (default to F1 for backward compatibility)
    if "F1" in results:
        f1_summary, f1_scenario = results["F1"]
        # Save without metric column for backward compat
        f1_summary_compat = f1_summary.drop(columns=["metric"], errors="ignore")
        f1_scenario_compat = f1_scenario.drop(columns=["metric"], errors="ignore")
        f1_summary_compat.to_csv(OUTPUT_DIR / "majority_vote_data.csv", index=False)
        f1_scenario_compat.to_csv(OUTPUT_DIR / "majority_vote_scenarios.csv", index=False)
        print(f"\nBackward-compatible files (F1) saved to:")
        print(f" - {OUTPUT_DIR / 'majority_vote_data.csv'}")
        print(f" - {OUTPUT_DIR / 'majority_vote_scenarios.csv'}")
293
+
294
def print_summary(results: dict[str, tuple[pd.DataFrame, pd.DataFrame]]):
    """Print a console leaderboard of consistency metrics, one table per metric."""
    banner = "=" * 80
    # Columns are stored as fractions; rendered below as percentages.
    rate_columns = [
        "majority_at_k", "pass_at_k", "all_at_k",
        "consistent_correct", "consistent_wrong", "inconsistent",
    ]

    for metric_label, (summary_df, _) in results.items():
        print("\n" + banner)
        print(f"Majority Vote Summary ({metric_label}, threshold={SUCCESS_THRESHOLD})")
        print(banner)

        ranked = summary_df.sort_values("majority_at_k", ascending=False)

        print(f"\n{'Model':<20} {'Maj@k':>8} {'Pass@k':>8} {'All@k':>8} {'Cons✓':>8} {'Cons✗':>8} {'Incons':>8}")
        print("-" * 80)
        for _, row in ranked.iterrows():
            cells = " ".join(f"{row[col] * 100:>7.1f}%" for col in rate_columns)
            print(f"{row['model']:<20} {cells}")
313
+
314
def plot_majority_vs_performance(df: pd.DataFrame):
    """
    Figure: Majority@k vs Performance scatter plot.

    One styled point per model over a gradient-shaded "better toward
    top-right" background. Saves fig_majority_vs_performance.pdf/png to
    OUTPUT_DIR.
    """
    # FIX: the original wrote plt.rcParams.update({PLOT_PARAMETERS}) — a set
    # literal containing a dict, which raises TypeError (dicts are
    # unhashable). Pass the parameter dict directly.
    plt.rcParams.update(PLOT_PARAMETERS)

    fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH))

    # Axis limits
    x_min, x_max = 0, 1.0
    y_min, y_max = 0, 100

    # Gradient shading toward top-right (ideal): stacked translucent rectangles.
    for i in range(5):
        alpha = 0.02 + i * 0.02
        x_start = 0.1 + i * 0.15
        y_start = 10 + i * 15
        rect = plt.Rectangle((x_start, y_start), x_max - x_start, y_max - y_start,
                             color='#2ecc71', alpha=alpha, zorder=0)
        ax.add_patch(rect)

    # Arrow pointing to ideal
    ax.annotate('', xy=(0.85, 85), xytext=(0.55, 55),
                arrowprops=dict(arrowstyle='->', color='#27ae60', alpha=0.7, lw=1.5),
                zorder=2)
    ax.text(0.58, 58, 'better', fontsize=MIN_FONT_SIZE, style='italic',
            color='#27ae60', alpha=0.8, rotation=45, zorder=2)

    # Mark ideal corner
    ax.scatter([1.0], [100], marker='*', s=100, c='#27ae60', alpha=0.5, zorder=2)
    ax.text(0.92, 95, 'ideal', fontsize=MIN_FONT_SIZE - 1, color='#27ae60',
            alpha=0.7, ha='right')

    # Scatter points with model-specific colors and markers
    for _, row in df.iterrows():
        style = get_model_style(row["model"])
        ax.scatter(row["overall_mean"], row["majority_at_k"] * 100,
                   c=style['color'], marker=style['marker'],
                   s=80, edgecolors='black', linewidth=0.5, zorder=10)

    # Labels with smart positioning: flip the offset near the right/top edges
    # so labels stay inside the axes.
    for _, row in df.iterrows():
        model = row["model"]
        x_pos = row["overall_mean"]
        y_pos = row["majority_at_k"] * 100

        dx, dy = 0.03, 2
        ha, va = "left", "center"

        if x_pos > 0.7:
            dx = -0.03
            ha = "right"
        if y_pos > 80:
            dy = -3
            va = "top"

        ax.text(x_pos + dx, y_pos + dy, model, fontsize=MIN_FONT_SIZE - 1,
                ha=ha, va=va, zorder=11)

    ax.set_xlabel("Performance (RC Entity F1)")
    ax.set_ylabel("Majority@k (%)")
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_majority_vs_performance.pdf")
    fig.savefig(OUTPUT_DIR / "fig_majority_vs_performance.png")
    plt.close(fig)
    print("Saved: fig_majority_vs_performance.pdf/png")
384
+
385
+
386
def plot_pass_vs_majority(df: pd.DataFrame, metric: str = "F1", suffix: str = ""):
    """
    Figure: Scatter plot of Pass@k (x-axis) vs Majority@k (y-axis).

    Points on the diagonal are perfectly consistent (every passing scenario
    also passes by majority); points far below it are inconsistent.

    Args:
        df: DataFrame with pass_at_k and majority_at_k columns
        metric: Name of metric for labeling (F1, Precision, Recall)
        suffix: Suffix for output filename (e.g. "_precision")
    """
    fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH, SINGLE_COLUMN_WIDTH))

    ax_min, ax_max = 0, 100

    # Diagonal line (Pass@k == Majority@k)
    ax.plot([ax_min, ax_max], [ax_min, ax_max], color='#444444', linestyle='--',
            linewidth=1.5, alpha=0.6, zorder=1)

    # Consistency region labels
    ax.text(8, 92, 'more\nconsistent', fontsize=MIN_FONT_SIZE + 1, color='#333333',
            ha='left', va='top', style='italic')
    ax.text(92, 8, 'less\nconsistent', fontsize=MIN_FONT_SIZE + 1, color='#333333',
            ha='right', va='bottom', style='italic')

    # Collect and plot points
    points = {}
    for _, row in df.iterrows():
        style = get_model_style(row["model"])
        x = row["pass_at_k"] * 100
        y = row["majority_at_k"] * 100
        ax.scatter(x, y, c=style['color'], marker=style['marker'],
                   s=50, edgecolors='black', linewidth=0.5, zorder=10)
        points[row["model"]] = {'x': x, 'y': y}

    line_color = '#444444'
    line_width = 1.2

    # Place labels with manual positioning (hand-tuned per model to avoid
    # overlaps in the published figure).
    for model, p in points.items():
        x, y = p['x'], p['y']

        if 'GPT-OSS-120B' in model:
            # Label to the right, slightly below
            ax.text(x + 3, y - 2, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

        elif 'Gemini 2.5 Pro' in model:
            # TEAL CIRCLE: label slightly below and to the right
            ax.text(x + 3, y + 2, model, fontsize=MIN_FONT_SIZE, ha='left', va='bottom', zorder=11)

        elif 'o4-mini' in model:
            # YELLOW SQUARE: shorter line goes right then to label
            label_x = x + 12
            label_y = y
            # Horizontal leader line to the right (shorter)
            ax.plot([x, label_x], [y, y], color=line_color, linewidth=line_width, alpha=0.8, zorder=5)
            ax.text(label_x + 1, label_y, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

        elif 'GPT-5.1' in model:
            # GREEN SQUARE: leader line from left edge, goes left then up
            label_x = 5
            label_y = 25
            start_x = x - 2  # Left edge of the square marker
            # Horizontal line left from left edge midpoint
            ax.plot([start_x, label_x], [y, y], color=line_color, linewidth=line_width, alpha=0.8, zorder=5)
            # Vertical line up to label height
            ax.plot([label_x, label_x], [y, label_y], color=line_color, linewidth=line_width, alpha=0.8, zorder=5)
            ax.text(label_x, label_y + 1, model, fontsize=MIN_FONT_SIZE, ha='left', va='bottom', zorder=11)

        elif 'Claude Opus' in model:
            # Label to the right
            ax.text(x + 5, y, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

        elif 'Gemini 3 Pro' in model:
            # Label BELOW the circle, offset left
            ax.text(x - 18, y - 6, model, fontsize=MIN_FONT_SIZE, ha='left', va='top', zorder=11)

        elif 'Gemini 3 Flash' in model:
            # NOTE(review): text anchored at x=105 with right-alignment, while
            # xlim is 0..100 and the comment says x=95 — confirm intent.
            ax.text(105, y + 4, model, fontsize=MIN_FONT_SIZE, ha='right', va='bottom', zorder=11)

        elif 'Kimi K2' in model:
            # Label to the right
            ax.text(x + 3, y, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

        else:
            # Default: label to the right
            ax.text(x + 3, y, model, fontsize=MIN_FONT_SIZE, ha='left', va='center', zorder=11)

    ax.set_xlabel(f"Pass@k (%) [{metric}]")
    ax.set_ylabel(f"Majority@k (%) [{metric}]")
    ax.set_xlim(ax_min, ax_max)
    ax.set_ylim(ax_min, ax_max)
    ax.set_aspect('equal')

    plt.title("Consistency: Pass@k vs. Majority@k")
    plt.tight_layout()
    plt.show()
    # FIX: the original saved to the literal "(unknown).png" — the computed
    # `filename` (which carries the per-metric suffix) was never used, so
    # every metric's figure overwrote the same file.
    filename = f"fig_pass_vs_majority{suffix}"
    fig.savefig(OUTPUT_DIR / f"{filename}.png")
    plt.close(fig)
    print(f"Saved: {filename}.png")
486
+
487
def main():
    """Entry point: extract majority-vote data, persist it, and print summaries."""
    for message in (
        "Extracting majority vote data for 'react with code' agents...",
        f"Reading from directories: {LEADERBOARD_DIR}",
        f"Output directory: {OUTPUT_DIR}",
        f"Success threshold: {SUCCESS_THRESHOLD}",
        f"Minimum runs per scenario: {MIN_RUNS_PER_SCENARIO}",
        f"Metrics: {[m[1] for m in METRICS]}",
    ):
        print(message)

    results = extract_all_data()

    # No model qualified for any metric — nothing to save or summarize.
    if not results:
        print("No data extracted!")
        return

    save_data(results)
    print_summary(results)


if __name__ == "__main__":
    main()
507
+
analysis_src/extract_tool_failures.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract tool call failure data for all 'react with code' agents.
4
+
5
+ This script reads rollout JSONL files to identify and categorize tool call failures.
6
+ """
7
+
8
+ import json
9
+ import re
10
+ import sys
11
+ import ast
12
+ from pathlib import Path
13
+ from dataclasses import dataclass, field
14
+ from collections import defaultdict
15
+ import numpy as np
16
+ import pandas as pd
17
+ import matplotlib.pyplot as plt
18
+ import seaborn as sns
19
+ from tqdm import tqdm
20
+
21
+ # Add project root to path
22
+ PROJECT_ROOT = Path(__file__).parent.parent
23
+ sys.path.insert(0, str(PROJECT_ROOT))
24
+
25
+ from src.utils import (
26
+ get_model_name,
27
+ find_react_with_code_dirs,
28
+ get_runs_stats,
29
+ filter_scenarios_with_min_runs,
30
+ find_latest_rollout_file
31
+ )
32
+
33
+ from src.model_styles import (
34
+ get_model_style, get_color_palette, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, _COLORS, PLOT_PARAMETERS
35
+ )
36
+
37
# Paths
LEADERBOARD_DIR = PROJECT_ROOT / "data" / "trajectories"  # per-agent trajectory dumps
RESULTS_JSON_DIR = LEADERBOARD_DIR / "results"  # NOTE(review): currently unused in this module
OUTPUT_DIR = PROJECT_ROOT / "data" / "output" / "tool_failures"  # CSVs and figures are written here

# Minimum runs per scenario required
MIN_RUNS_PER_SCENARIO = 3  # a scenario qualifies only with at least this many trials
MIN_QUALIFYING_SCENARIOS = 20  # agents with fewer qualifying scenarios are skipped entirely
45
+
46
# Failure type patterns.
# NOTE: dict order matters — classify_failure() returns the FIRST category
# whose pattern matches, so the generic 'runtime' bucket (which matches
# bare 'Exception' / 'Error:') must remain last.
FAILURE_PATTERNS = {
    'python_syntax': [
        r'SyntaxError',
        r'IndentationError',
        r'TabError',
    ],
    'python_type': [
        r'TypeError',
        r'AttributeError',
        r'ValueError',
        r'KeyError',
        r'IndexError',
    ],
    'python_name': [
        r'NameError',
        r'UnboundLocalError',
        r'ModuleNotFoundError',
        r'ImportError',
    ],
    'file_not_found': [
        r'FileNotFoundError',
        r'No such file or directory',
        r'ENOENT',
        r'path does not exist',
    ],
    'permission_denied': [
        r'PermissionError',
        r'Permission denied',
        r'EACCES',
    ],
    'json_parse': [
        r'JSONDecodeError',
        r'json\.decoder\.JSONDecodeError',
        r'Expecting value',
        r'Invalid JSON',
    ],
    'timeout': [
        r'TimeoutError',
        r'timeout',
        r'Timed out',
        r'deadline exceeded',
    ],
    'memory': [
        r'MemoryError',
        r'out of memory',
        r'OOM',
        r'Cannot allocate memory',
    ],
    'connection': [
        r'ConnectionError',
        r'ConnectionRefusedError',
        r'Connection refused',
        r'ECONNREFUSED',
    ],
    'shell_command': [
        r'command not found',
        r'No such command',
        r'not recognized as',
    ],
    'assertion': [
        r'AssertionError',
    ],
    'runtime': [
        r'RuntimeError',
        r'Exception',
        r'Error:',
    ],
}


def classify_failure(output: str) -> tuple[str, str]:
    """
    Classify a failure based on the output string.

    Patterns are tried case-insensitively in FAILURE_PATTERNS order, so more
    specific categories shadow the generic 'runtime' bucket.

    Args:
        output: raw tool-call output text to classify.

    Returns:
        (category, specific_error) — specific_error is the exact matched
        text, 'Unknown Python Error' for unmatched tracebacks, or
        'Unknown Error' when nothing matches.
    """
    for category, patterns in FAILURE_PATTERNS.items():
        for pattern in patterns:
            # Search once and reuse the match object (the original version
            # ran the same regex search twice per pattern).
            match = re.search(pattern, output, re.IGNORECASE)
            if match:
                return (category, match.group(0))

    # No known pattern matched, but a traceback still signals a Python error.
    if 'Traceback' in output:
        return ('other_python', 'Unknown Python Error')

    return ('other', 'Unknown Error')
136
+
137
+
138
def extract_tool_calls_from_rollout(rollout_file: Path) -> dict:
    """
    Extract all tool calls and their outcomes from a rollout file.

    The rollout is a JSONL stream; only records with type 'response_item'
    are inspected. A 'function_call' payload registers a pending call under
    its call_id; a later 'function_call_output' payload with the same
    call_id is checked for failure. An output counts as a failure when:
      - it parses as a JSON object whose metadata.exit_code is non-zero, or
      - it is not JSON and contains 'Error'/'error'/'Traceback'
        (heuristic; may over-count benign mentions of those words).

    Returns dict with:
        - total_tool_calls: int
        - failed_tool_calls: int
        - failures: list of failure details
        - tool_call_counts: dict of tool_name -> count
        - tool_failure_counts: dict of tool_name -> failure_count

    Returns None when the file itself cannot be read; individual malformed
    JSONL lines are skipped silently.
    """
    tool_calls = {}  # call_id -> {name, arguments, timestamp}
    total_calls = 0
    failed_calls = 0
    failures = []
    tool_call_counts = defaultdict(int)
    tool_failure_counts = defaultdict(int)

    try:
        with open(rollout_file) as f:
            for line in f:
                try:
                    d = json.loads(line)
                    if d.get('type') != 'response_item':
                        continue

                    payload = d.get('payload', {})
                    payload_type = payload.get('type', '')

                    if payload_type == 'function_call':
                        # Register the pending call so its output can be
                        # attributed back to the right tool later.
                        call_id = payload.get('call_id', '')
                        name = payload.get('name', '')
                        arguments = payload.get('arguments', '')
                        tool_calls[call_id] = {
                            'name': name,
                            'arguments': arguments,
                            'timestamp': d.get('timestamp', ''),
                        }
                        total_calls += 1
                        tool_call_counts[name] += 1

                    elif payload_type == 'function_call_output':
                        call_id = payload.get('call_id', '')
                        output = payload.get('output', '')

                        # Check if this is a failure
                        is_failure = False
                        failure_info = None

                        # Parse the output if it's JSON
                        try:
                            output_data = json.loads(output)
                            if isinstance(output_data, dict):
                                # Assumes a JSON envelope of the form
                                # {"output": ..., "metadata": {"exit_code": ...}}.
                                exit_code = output_data.get('metadata', {}).get('exit_code', 0)
                                output_text = output_data.get('output', '')
                                if exit_code != 0:
                                    is_failure = True
                                    category, error = classify_failure(output_text)
                                    failure_info = {
                                        'exit_code': exit_code,
                                        'category': category,
                                        'error': error,
                                        'output_snippet': output_text[:300] if output_text else '',
                                    }
                        except json.JSONDecodeError:
                            # Not JSON, check for error patterns in raw output.
                            # Both 'Error' and 'error' checks are needed: the
                            # substring test is case-sensitive.
                            if 'Error' in output or 'error' in output or 'Traceback' in output:
                                is_failure = True
                                category, error = classify_failure(output)
                                failure_info = {
                                    'exit_code': None,
                                    'category': category,
                                    'error': error,
                                    'output_snippet': output[:300],
                                }

                        # Outputs whose call_id was never registered are ignored.
                        if is_failure and call_id in tool_calls:
                            failed_calls += 1
                            tool_name = tool_calls[call_id]['name']
                            tool_failure_counts[tool_name] += 1
                            failures.append({
                                'tool_name': tool_name,
                                'arguments': tool_calls[call_id]['arguments'][:200],
                                'timestamp': tool_calls[call_id]['timestamp'],
                                **failure_info,
                            })

                except json.JSONDecodeError:
                    # Malformed JSONL line — skip it and keep scanning.
                    continue
    except Exception as e:
        # File-level failure (missing, unreadable, ...): warn and signal
        # the caller with None so this trial is dropped.
        print(f" Warning: Error reading {rollout_file}: {e}")
        return None

    return {
        'total_tool_calls': total_calls,
        'failed_tool_calls': failed_calls,
        'failures': failures,
        'tool_call_counts': dict(tool_call_counts),
        'tool_failure_counts': dict(tool_failure_counts),
    }
238
+
239
+
240
def read_agent_stats(agent_dir: Path) -> dict[str, list[dict]]:
    """
    Read tool call stats from all scenarios/trials for an agent.

    Only subdirectories named 'Scenario*' are considered; within each,
    every trial directory's latest rollout file is parsed.

    Returns:
        Dict mapping scenario_id -> list of stats (one per trial).
        Scenarios with no usable trials are omitted.
    """
    per_scenario: dict[str, list[dict]] = {}

    for child in agent_dir.iterdir():
        if not (child.is_dir() and child.name.startswith("Scenario")):
            continue

        collected = []
        for trial_path in sorted(child.iterdir()):
            if not trial_path.is_dir():
                continue
            rollout = find_latest_rollout_file(trial_path)
            stats = extract_tool_calls_from_rollout(rollout) if rollout else None
            if stats:
                collected.append(stats)

        if collected:
            per_scenario[child.name] = collected

    return per_scenario
270
+
271
+
272
+
273
+
274
def extract_all_data() -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Extract tool failure data for all agents.

    Agents are skipped when they have no rollout data at all, or when fewer
    than MIN_QUALIFYING_SCENARIOS scenarios reach MIN_RUNS_PER_SCENARIO runs.

    Returns:
        - summary_df: Aggregated stats per model
        - detail_df: Per-trial failure stats
        - failures_df: Individual failure details
    """
    agent_dirs = find_react_with_code_dirs(LEADERBOARD_DIR)
    print(f"Found {len(agent_dirs)} 'react with code' agent directories")

    summary_records = []
    detail_records = []
    failure_records = []

    for agent_dir in tqdm(agent_dirs, desc="Processing agents"):
        model_name = get_model_name(agent_dir.name)

        print(f"\nProcessing: {agent_dir.name}")
        scenario_data = read_agent_stats(agent_dir)

        # min_runs/max_runs are returned but not used here.
        n_scenarios, min_runs, max_runs, n_qualifying = get_runs_stats(scenario_data, MIN_RUNS_PER_SCENARIO)

        if n_scenarios == 0:
            print(f" SKIPPING {model_name}: No rollout data found")
            continue

        if n_qualifying < MIN_QUALIFYING_SCENARIOS:
            print(f" SKIPPING {model_name}: Only {n_qualifying}/{n_scenarios} scenarios have {MIN_RUNS_PER_SCENARIO}+ runs")
            continue

        # Filter scenarios
        scenario_data = filter_scenarios_with_min_runs(scenario_data, MIN_RUNS_PER_SCENARIO)
        n_scenarios_filtered = len(scenario_data)

        print(f" Processing: {model_name} ({n_scenarios_filtered} scenarios)")

        # Aggregate across all scenarios and trials
        all_total_calls = []
        all_failed_calls = []
        all_failure_rates = []
        aggregated_tool_counts = defaultdict(int)
        aggregated_failure_counts = defaultdict(int)
        aggregated_category_counts = defaultdict(int)

        for scenario_id, trials in tqdm(scenario_data.items(), desc=f" {model_name} scenarios", leave=False):
            for trial_idx, trial in enumerate(trials):
                total = trial['total_tool_calls']
                failed = trial['failed_tool_calls']

                all_total_calls.append(total)
                all_failed_calls.append(failed)
                # Guard against trials with zero tool calls.
                all_failure_rates.append(failed / total * 100 if total > 0 else 0)

                for tool_name, count in trial['tool_call_counts'].items():
                    aggregated_tool_counts[tool_name] += count

                for tool_name, count in trial['tool_failure_counts'].items():
                    aggregated_failure_counts[tool_name] += count

                # Count failure categories
                for failure in trial['failures']:
                    category = failure.get('category', 'other')
                    aggregated_category_counts[category] += 1

                    # Add to failure records (one row per individual failure)
                    failure_records.append({
                        'model': model_name,
                        'scenario': scenario_id,
                        'trial': trial_idx,
                        'tool_name': failure.get('tool_name', ''),
                        'category': category,
                        'error': failure.get('error', ''),
                        'exit_code': failure.get('exit_code'),
                        'output_snippet': failure.get('output_snippet', '')[:100],
                    })

                # One row per trial for the detail table.
                detail_records.append({
                    'model': model_name,
                    'scenario': scenario_id,
                    'trial': trial_idx,
                    'total_tool_calls': total,
                    'failed_tool_calls': failed,
                    'failure_rate_pct': failed / total * 100 if total > 0 else 0,
                })

        # Compute per-tool failure rates
        tool_failure_rates = {}
        for tool_name, total in aggregated_tool_counts.items():
            failures = aggregated_failure_counts.get(tool_name, 0)
            tool_failure_rates[tool_name] = {
                'total': total,
                'failures': failures,
                'rate': failures / total * 100 if total > 0 else 0
            }

        # One summary row per model; dict-valued columns are kept as
        # Python objects (stringified only on CSV export).
        summary_records.append({
            'model': model_name,
            'n_scenarios': n_scenarios_filtered,
            'n_trials': len(all_total_calls),
            'total_tool_calls': sum(all_total_calls),
            'total_failed_calls': sum(all_failed_calls),
            'avg_tool_calls_per_trial': np.mean(all_total_calls),
            'avg_failed_calls_per_trial': np.mean(all_failed_calls),
            'avg_failure_rate_pct': np.mean(all_failure_rates),
            'std_failure_rate_pct': np.std(all_failure_rates),
            'failure_categories': dict(aggregated_category_counts),
            'tool_failure_rates': tool_failure_rates,
        })

    summary_df = pd.DataFrame(summary_records)
    detail_df = pd.DataFrame(detail_records)
    failures_df = pd.DataFrame(failure_records)

    return summary_df, detail_df, failures_df
390
+
391
+
392
def save_data(summary_df: pd.DataFrame, detail_df: pd.DataFrame, failures_df: pd.DataFrame):
    """Save extracted data to CSV files.

    Writes the three frames into OUTPUT_DIR (created if needed) and
    echoes the destination paths.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # (frame, destination) pairs — kept in one place so the write loop
    # and the confirmation printout stay in sync.
    outputs = (
        (summary_df, OUTPUT_DIR / "tool_failures_summary.csv"),
        (detail_df, OUTPUT_DIR / "tool_failures_detail.csv"),
        (failures_df, OUTPUT_DIR / "tool_failures_individual.csv"),
    )

    for frame, destination in outputs:
        frame.to_csv(destination, index=False)

    print(f"\nData saved to:")
    for _, destination in outputs:
        print(f" - {destination}")
408
+
409
+
410
def print_summary(summary_df: pd.DataFrame):
    """Print summary table.

    Models are listed worst-first (highest average failure rate at the
    top), each with its most frequent failure category.
    """
    divider = "=" * 100
    print("\n" + divider)
    print("Tool Call Failure Summary")
    print(divider)

    ordered = summary_df.sort_values("avg_failure_rate_pct", ascending=False)

    print(f"\n{'Model':<20} {'Trials':>8} {'Total Calls':>12} {'Failed':>10} {'Fail Rate':>10} {'Top Category':>20}")
    print("-" * 85)

    for _, row in ordered.iterrows():
        cats = row.get('failure_categories', {})
        top_cat, top_cat_count = '-', 0
        if cats:
            top_cat = max(cats, key=cats.get)
            top_cat_count = cats[top_cat]

        print(f"{row['model']:<20} {row['n_trials']:>8} {row['total_tool_calls']:>12} "
              f"{row['total_failed_calls']:>10} {row['avg_failure_rate_pct']:>9.2f}% "
              f"{top_cat} ({top_cat_count})")
432
+
433
def plot_failure_rate_by_model(summary_df: pd.DataFrame):
    """
    Figure 1: Overall failure rate per model (horizontal bar chart).

    Bars are sorted ascending so the lowest failure rate sits at the
    bottom; std-dev error bars and percentage labels are overlaid.
    Saves fig_failure_rate_by_model.png into OUTPUT_DIR.
    """
    plt.rcParams.update(PLOT_PARAMETERS)

    fig, ax = plt.subplots(figsize=(DOUBLE_COLUMN_WIDTH, 3.0))

    data = summary_df.sort_values("avg_failure_rate_pct", ascending=True)

    colors = get_color_palette(len(data))

    bars = ax.barh(data["model"], data["avg_failure_rate_pct"],
                   color=colors, edgecolor='black', linewidth=0.5)

    # Add error bars for std
    ax.errorbar(data["avg_failure_rate_pct"], range(len(data)),
                xerr=data["std_failure_rate_pct"], fmt='none',
                color='black', capsize=2, linewidth=0.5)

    # Add value labels, placed just past the end of each error bar
    for i, (bar, val, std) in enumerate(zip(bars, data["avg_failure_rate_pct"], data["std_failure_rate_pct"])):
        ax.text(val + std + 0.5, bar.get_y() + bar.get_height()/2,
                f'{val:.1f}%', va='center', ha='left', fontsize=MIN_FONT_SIZE - 1)

    ax.set_xlabel("Average Failure Rate (%)")
    # Leave headroom for the widest error bar plus its label.
    ax.set_xlim(0, data["avg_failure_rate_pct"].max() + data["std_failure_rate_pct"].max() + 5)

    plt.title("Tool Call Failure Rate")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_failure_rate_by_model.png")
    plt.close(fig)
    print("Saved: fig_failure_rate_by_model.png")
468
+
469
+
470
def parse_dict_column(col_str):
    """
    Parse a dictionary column stored as its string repr (e.g. after a
    round-trip through CSV).

    Args:
        col_str: raw cell value; may already be NaN/None or '{}'.

    Returns:
        The parsed dict/literal, or {} when the cell is missing, empty,
        or not a valid Python literal.
    """
    if pd.isna(col_str) or col_str == '{}':
        return {}
    try:
        return ast.literal_eval(col_str)
    except (ValueError, SyntaxError, TypeError):
        # literal_eval raises these for malformed or non-string input;
        # narrow catch replaces the original bare `except:`, which would
        # also have swallowed KeyboardInterrupt/SystemExit.
        return {}
478
+
479
+
480
def plot_failure_categories_stacked(summary_df: pd.DataFrame):
    """
    Figure 2: Failure category breakdown per model (stacked bar).

    Expects summary_df['failure_categories'] to hold dicts
    (category -> count) as produced by extract_all_data(). If the frame
    was re-loaded from CSV, run the column through parse_dict_column first.
    Only the 8 most frequent categories (across all models) are plotted.
    Saves fig_failure_categories_stacked.png into OUTPUT_DIR.
    """
    plt.rcParams.update(PLOT_PARAMETERS)

    fig, ax = plt.subplots(figsize=(DOUBLE_COLUMN_WIDTH, 3.0))

    data = summary_df.copy()

    # Aggregate category totals across models to pick which ones to show.
    # (Leftover debug print of the column removed.)
    all_categories = defaultdict(int)
    for cats in data['failure_categories']:
        for cat, count in cats.items():
            all_categories[cat] += count

    CATEGORY_COLORS = {
        'python_syntax': '#e41a1c',
        'python_type': '#377eb8',
        'python_name': '#4daf4a',
        'file_not_found': '#984ea3',
        'json_parse': '#ff7f00',
        'shell_command': '#a65628',
        'timeout': '#f781bf',
        'memory': '#999999',
        'other_python': '#66c2a5',
        'other': '#8da0cb',
    }

    # Keep the 8 categories with the highest overall counts.
    sorted_cats = sorted(all_categories.keys(), key=lambda x: all_categories[x], reverse=True)[:8]

    # Models with the fewest failures end up at the bottom of the chart.
    data = data.sort_values('total_failed_calls', ascending=True)

    bottom = np.zeros(len(data))

    for cat in sorted_cats:
        values = [row['failure_categories'].get(cat, 0) for _, row in data.iterrows()]
        color = CATEGORY_COLORS.get(cat, '#888888')
        ax.barh(data['model'], values, left=bottom,
                label=cat.replace('_', ' ').title(), color=color,
                edgecolor='white', linewidth=0.3)
        bottom += values

    ax.set_xlabel("Number of Failed Tool Calls")
    ax.legend(loc='lower right', ncol=2, fontsize=MIN_FONT_SIZE - 1,
              framealpha=0.9, bbox_to_anchor=(1.0, 0.0))

    plt.title("Tool Failure Category Distribution")

    plt.tight_layout()
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_failure_categories_stacked.png")
    plt.close(fig)
    print("Saved: fig_failure_categories_stacked.png")
540
+
541
+
542
def main():
    """CLI entry point: extract, persist, summarize, and plot tool failures."""
    for line in (
        "Extracting tool call failure data for 'react with code' agents...",
        f"Reading from directories: {LEADERBOARD_DIR}",
        f"Output directory: {OUTPUT_DIR}",
    ):
        print(line)

    summary_df, detail_df, failures_df = extract_all_data()

    # Guard clause: nothing extracted, nothing to report.
    if len(summary_df) == 0:
        print("No data extracted!")
        return

    save_data(summary_df, detail_df, failures_df)
    print_summary(summary_df)

    plot_failure_categories_stacked(summary_df)


if __name__ == "__main__":
    main()
560
+
analysis_src/model_styles.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Shared model styling configuration for paper analysis figures.
4
+
5
+ Provides consistent colors, markers, and display names across all agents
6
+ (EOG, React with Code, and future agents).
7
+
8
+ Usage:
9
+ from paper_analysis.common.model_styles import get_model_style, MODEL_DISPLAY_NAMES
10
+
11
+ style = get_model_style("GPT-5.1")
12
+ ax.scatter(x, y, c=style['color'], marker=style['marker'], ...)
13
+ """
14
+
15
+ import seaborn as sns
16
+
17
+ # =============================================================================
18
+ # MODEL DISPLAY NAMES
19
+ # Maps various raw names to standardized display names
20
+ # =============================================================================
21
+
22
+ MODEL_DISPLAY_NAMES = {
23
+ # OpenAI / Azure
24
+ "Azure_gpt-5.1-2025-11-13": "GPT-5.1",
25
+ "Azure_gpt-5.1-chat-2025-11-13": "GPT-5.1",
26
+ "Azure_o4-mini": "o4-mini",
27
+ "Azure_gpt-4o": "GPT-4o",
28
+ "openai_gpt-oss-120b": "GPT-OSS-120B",
29
+ "openai_gpt-oss-20b": "GPT-OSS-20B",
30
+ # Google / GCP
31
+ "GCP_gemini-2.5-pro": "Gemini 2.5 Pro",
32
+ "gemini-2.5-pro": "Gemini 2.5 Pro",
33
+ "Gemini-2.5-Pro": "Gemini 2.5 Pro",
34
+ "gcp_gemini-3-pro-preview": "Gemini 3 Pro",
35
+ "gemini-3-pro-preview": "Gemini 3 Pro",
36
+ "Gemini-3-Pro": "Gemini 3 Pro",
37
+ "gemini-3-flash-preview": "Gemini 3 Flash",
38
+ "Gemini-3-Flash": "Gemini 3 Flash",
39
+ "google_gemini-3-flash-preview": "Gemini 3 Flash",
40
+ # Moonshot AI
41
+ "moonshotai_kimi-k2-thinking": "Kimi K2",
42
+ "kimi-k2-thinking": "Kimi K2",
43
+ "Kimi-K2": "Kimi K2",
44
+ # Anthropic / AWS
45
+ "aws_claude-opus-4-5": "Claude Opus 4.5",
46
+ "Claude-Opus-4.5": "Claude Opus 4.5",
47
+ # Mistral AI
48
+ "mistralai_mistral-large-2512": "Mistral Large",
49
+ "Mistral-Large": "Mistral Large",
50
+ # Alibaba / Qwen
51
+ "qwen_qwen3-vl-32b-instruct": "Qwen3-VL-32B",
52
+ # ServiceNow
53
+ "ServiceNow-AI_Apriel-1.6-15b-Thinker": "Apriel-1.6-15B",
54
+ # Minimax
55
+ "minimax_minimax-m2.1": "Minimax M2.1",
56
+ }
57
+
58
+ # =============================================================================
59
+ # MODEL STYLES
60
+ # Defines color and marker for each model (by display name)
61
+ # Colors are colorblind-friendly, markers provide redundant encoding
62
+ # =============================================================================
63
+
64
# Colorblind-friendly palette (based on IBM Design Library / Wong palette)
_COLORS = {
    'blue': '#0072B2',
    'orange': '#E69F00',
    'green': '#009E73',
    'pink': '#CC79A7',
    'light_blue': '#56B4E9',
    'yellow': '#F0E442',
    'red': '#D55E00',
    'gray': '#999999',
    'purple': '#9467BD',
    'brown': '#8C564B',
    'teal': '#17BECF',
}

# Marker styles for redundant encoding
# (models that share a color remain distinguishable via marker shape,
# e.g. in grayscale prints)
_MARKERS = {
    'circle': 'o',
    'square': 's',
    'diamond': 'D',
    'triangle_up': '^',
    'triangle_down': 'v',
    'pentagon': 'p',
    'hexagon': 'h',
    'star': '*',
    'plus': 'P',
    'x': 'X',
}
92
+
93
# Model style definitions (display_name -> {color, marker})
# NOTE(review): some colors are reused across models (green: GPT-5.1/GPT-4o;
# brown: GPT-OSS-20B/Apriel; teal: Gemini 2.5 Pro/Qwen3-VL-32B) — the marker
# shape is what disambiguates them in a shared figure.
MODEL_STYLES = {
    # Google models - shades of blue
    "Gemini 3 Flash": {
        'color': _COLORS['blue'],
        'marker': _MARKERS['circle'],
    },
    "Gemini 3 Pro": {
        'color': _COLORS['light_blue'],
        'marker': _MARKERS['circle'],
    },
    "Gemini 2.5 Pro": {
        'color': _COLORS['teal'],
        'marker': _MARKERS['circle'],
    },

    # OpenAI models - shades of green/orange
    "GPT-5.1": {
        'color': _COLORS['green'],
        'marker': _MARKERS['square'],
    },
    "GPT-4o": {
        'color': _COLORS['green'],
        'marker': _MARKERS['diamond'],
    },
    "o4-mini": {
        'color': _COLORS['yellow'],
        'marker': _MARKERS['square'],
    },
    "GPT-OSS-120B": {
        'color': _COLORS['orange'],
        'marker': _MARKERS['triangle_up'],
    },
    "GPT-OSS-20B": {
        'color': _COLORS['brown'],
        'marker': _MARKERS['triangle_down'],
    },

    # Anthropic models - pink
    "Claude Opus 4.5": {
        'color': _COLORS['pink'],
        'marker': _MARKERS['diamond'],
    },

    # Moonshot AI - red
    "Kimi K2": {
        'color': _COLORS['red'],
        'marker': _MARKERS['pentagon'],
    },

    # Mistral - purple
    "Mistral Large": {
        'color': _COLORS['purple'],
        'marker': _MARKERS['hexagon'],
    },

    # Minimax - gray
    "Minimax M2.1": {
        'color': _COLORS['gray'],
        'marker': _MARKERS['star'],
    },

    # Qwen - teal
    "Qwen3-VL-32B": {
        'color': _COLORS['teal'],
        'marker': _MARKERS['plus'],
    },

    # ServiceNow - brown
    "Apriel-1.6-15B": {
        'color': _COLORS['brown'],
        'marker': _MARKERS['x'],
    },
}

# Default style for unknown models (gray circle)
_DEFAULT_STYLE = {
    'color': _COLORS['gray'],
    'marker': _MARKERS['circle'],
}
173
+
174
+
175
def get_display_name(raw_name: str) -> str:
    """Map a raw model identifier to its standardized display name.

    Unknown names pass through unchanged.
    """
    try:
        return MODEL_DISPLAY_NAMES[raw_name]
    except KeyError:
        return raw_name
178
+
179
+
180
def get_model_style(model_name: str) -> dict:
    """
    Get the style (color, marker) for a model.

    Args:
        model_name: Either raw name or display name

    Returns:
        Dict with 'color' and 'marker' keys; unknown models fall back to
        the shared default style.
    """
    # Prefer an exact hit, then retry after normalizing to a display name.
    for candidate in (model_name, get_display_name(model_name)):
        style = MODEL_STYLES.get(candidate)
        if style is not None:
            return style

    return _DEFAULT_STYLE
201
+
202
+
203
def get_model_color(model_name: str) -> str:
    """Get just the color for a model."""
    style = get_model_style(model_name)
    return style['color']
206
+
207
+
208
def get_model_marker(model_name: str) -> str:
    """Get just the marker for a model."""
    style = get_model_style(model_name)
    return style['marker']
211
+
212
+
213
# For backward compatibility - create a color palette list
def get_color_palette(n_colors: int = 10):
    """Get a colorblind-friendly palette with n colors (at most 10)."""
    palette_order = ('blue', 'orange', 'green', 'pink', 'light_blue',
                     'red', 'purple', 'brown', 'teal', 'yellow')
    selected = palette_order[:n_colors]
    return [_COLORS[name] for name in selected]
219
+
220
SINGLE_COLUMN_WIDTH = 3.25  # inches (ICML)
DOUBLE_COLUMN_WIDTH = 6.75  # inches (ICML)
MIN_FONT_SIZE = 10  # smallest font size used anywhere in a figure

# Shared matplotlib rcParams applied (via plt.rcParams.update) by every
# figure-producing function.
PLOT_PARAMETERS = {
    'font.size': MIN_FONT_SIZE,
    'font.family': 'serif',
    'axes.labelsize': MIN_FONT_SIZE + 1,
    'axes.titlesize': MIN_FONT_SIZE + 2,
    'xtick.labelsize': MIN_FONT_SIZE,
    'ytick.labelsize': MIN_FONT_SIZE,
    'legend.fontsize': MIN_FONT_SIZE,
    'figure.titlesize': MIN_FONT_SIZE + 2,
    'figure.dpi': 150,   # on-screen preview resolution
    'savefig.dpi': 300,  # print-quality output
    'savefig.bbox': 'tight',
    'axes.spines.top': False,
    'axes.spines.right': False,
    'axes.linewidth': 0.8,
    'lines.linewidth': 1.0,
    'patch.linewidth': 0.5,
}
analysis_src/utils.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+
4
# Model display names (short for figures)
# Follows ArtificialAnalysis.ai naming conventions
# Keys are raw model identifiers as they appear in run-directory names
# (often provider-prefixed); several raw keys intentionally map to the
# same display name.
MODEL_DISPLAY_NAMES = {
    # OpenAI / Azure
    "Azure_gpt-5.1-2025-11-13": "GPT-5.1",
    "Azure_gpt-5.1-chat-2025-11-13": "GPT-5.1",
    "Azure_o4-mini": "o4-mini",
    "Azure_gpt-4o": "GPT-4o",
    "openai_gpt-oss-120b": "GPT-OSS-120B",
    "openai_gpt-oss-20b": "GPT-OSS-20B",
    # Google / GCP
    "GCP_gemini-2.5-pro": "Gemini 2.5 Pro",
    "gemini-2.5-pro": "Gemini 2.5 Pro",
    "gcp_gemini-3-pro-preview": "Gemini 3 Pro",
    "gemini-3-pro-preview": "Gemini 3 Pro",
    "gemini-3-flash-preview": "Gemini 3 Flash",
    "google_gemini-3-flash-preview": "Gemini 3 Flash",
    # Moonshot AI
    "moonshotai_kimi-k2-thinking": "Kimi K2",
    "kimi-k2-thinking": "Kimi K2",
    # Anthropic / AWS
    "aws_claude-opus-4-5": "Claude Opus 4.5",
    # Mistral AI
    "mistralai_mistral-large-2512": "Mistral Large",
    # Alibaba / Qwen
    "qwen_qwen3-vl-32b-instruct": "Qwen3-VL-32B",
    # ServiceNow
    "ServiceNow-AI_Apriel-1.6-15b-Thinker": "Apriel-1.6-15B",
}
33
+
34
+
35
def get_model_name(dirname: str) -> str:
    """Extract model name from directory name.

    Strips the agent prefix and the hard-coded commit-hash suffix, then
    maps the remaining raw identifier to its short display name when one
    is registered; otherwise the stripped raw name is returned as-is.
    """
    raw = dirname.replace("react with code_", "").replace("_07ccdb1", "")
    return MODEL_DISPLAY_NAMES.get(raw, raw)
39
+
40
+
41
def find_react_with_code_dirs(leaderboard_dir: Path) -> list[Path]:
    """Find all 'react with code' agent directories (non-backup)."""
    # NOTE(review): a name starting with "react with code_" can never also
    # start with "backup_", so the second test is redundant; it is kept to
    # preserve the original behavior exactly.
    matches = [
        entry
        for entry in leaderboard_dir.iterdir()
        if entry.is_dir()
        and entry.name.startswith("react with code_")
        and not entry.name.startswith("backup_")
    ]
    return sorted(matches)
48
+
49
+
50
def read_judge_outputs_from_dir(agent_dir: Path) -> dict[str, list[dict]]:
    """
    Read all judge_output.json files from an agent directory.

    Expected layout: agent_dir/Scenario*/<trial>/judge_output.json,
    where <trial> is a per-trial subdirectory (1, 2, 3, ...).

    Returns:
        Dict mapping scenario_id -> list of judge outputs (one per trial)
    """
    results: dict[str, list[dict]] = {}

    for scenario_dir in agent_dir.iterdir():
        if not (scenario_dir.is_dir() and scenario_dir.name.startswith("Scenario")):
            continue

        parsed_trials = []
        # Trials are read in sorted order so the list index tracks the trial order.
        for trial_dir in sorted(scenario_dir.iterdir()):
            if not trial_dir.is_dir():
                continue
            judge_file = trial_dir / "judge_output.json"
            if not judge_file.exists():
                continue
            try:
                with open(judge_file) as fh:
                    parsed_trials.append(json.load(fh))
            except Exception as e:
                # Best-effort scan: report a malformed file but keep going.
                print(f" Warning: Error reading {judge_file}: {e}")

        # Scenarios with no readable trials are omitted entirely.
        if parsed_trials:
            results[scenario_dir.name] = parsed_trials

    return results
84
+
85
+
86
def extract_trial_scores_from_judge_outputs(
    scenario_data: dict[str, list[dict]],
    metric: str
) -> dict[str, list[float]]:
    """
    Extract per-trial scores for a given metric from judge outputs.

    A missing or null metric value is coerced to 0.0; scenarios with no
    trials are dropped from the result.

    Args:
        scenario_data: Dict mapping scenario_id -> list of judge outputs
        metric: The metric name to extract

    Returns:
        Dict mapping scenario_id -> list of trial scores
    """
    per_scenario: dict[str, list[float]] = {}

    for scenario_id, trials in scenario_data.items():
        trial_scores = []
        for trial in trials:
            value = trial.get("flat_scores", {}).get(metric)
            # None (missing or explicit null) counts as a zero score.
            trial_scores.append(0.0 if value is None else float(value))

        if trial_scores:
            per_scenario[scenario_id] = trial_scores

    return per_scenario
117
+
118
+
119
def get_runs_stats(scenario_data: dict[str, list], min_runs_required: int) -> tuple[int, int, int, int]:
    """Get run statistics: (n_scenarios, min_runs, max_runs, n_qualifying)."""
    if not scenario_data:
        # Nothing recorded: every statistic is zero.
        return 0, 0, 0, 0

    counts = sorted(len(trials) for trials in scenario_data.values())
    qualifying = len([c for c in counts if c >= min_runs_required])
    # counts is sorted, so its ends give the min and max run counts.
    return len(counts), counts[0], counts[-1], qualifying
127
+
128
+
129
def filter_scenarios_with_min_runs(scenario_data: dict[str, list], min_runs_required: int) -> dict[str, list]:
    """Filter to only include scenarios with >= min_runs_required runs."""
    kept: dict[str, list] = {}
    for scenario_id, trials in scenario_data.items():
        if len(trials) >= min_runs_required:
            kept[scenario_id] = trials
    return kept
136
+
137
+ def find_latest_rollout_file(trial_dir: Path) -> Path:
138
+ """Find the latest rollout file in a trial's sessions directory."""
139
+ sessions_dir = trial_dir / "sessions"
140
+ if not sessions_dir.exists():
141
+ return None
142
+
143
+ rollout_files = []
144
+ for rollout_file in sessions_dir.rglob("rollout-*.jsonl"):
145
+ rollout_files.append(rollout_file)
146
+
147
+ if not rollout_files:
148
+ return None
149
+
150
+ # Sort by filename (which includes timestamp) and return the latest
151
+ # rollout_files.sort(key=lambda f: f.name, reverse=True)
152
+ # return rollout_files[0]
153
+
154
+ return max(rollout_files, key=lambda p: p.stat().st_mtime)
155
+
evaluation.ipynb ADDED
The diff for this file is too large to render. See raw diff