#!/usr/bin/env python3
"""
Exploration Breadth Analysis by Diagnosis Correctness

Creates plots comparing exploration breadth between:
- Correct diagnoses (recall > 0, i.e., root_cause_f1 > 0)
- Incorrect diagnoses (recall = 0, i.e., root_cause_f1 == 0)

Uses semantic entity grouping to avoid counting "frontend deployment" and
"frontend service" as separate entities.
"""

import json
import sys
import re
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Dict, Set, Tuple

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Publication settings - ICML half column
HALF_COLUMN_WIDTH = 3.25  # inches
MIN_FONT_SIZE = 8

plt.rcParams.update({
    'font.size': MIN_FONT_SIZE,
    'font.family': 'serif',
    'axes.labelsize': MIN_FONT_SIZE,
    'axes.titlesize': MIN_FONT_SIZE + 1,
    'xtick.labelsize': MIN_FONT_SIZE,
    'ytick.labelsize': MIN_FONT_SIZE,
    'legend.fontsize': MIN_FONT_SIZE,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# Paths
PROJECT_ROOT = Path(__file__).parent.parent
# The project root must be on sys.path before the analysis_src imports below.
sys.path.insert(0, str(PROJECT_ROOT))

from analysis_src.utils import find_latest_rollout_file
# NOTE(review): this import rebinds MIN_FONT_SIZE. The rcParams above were
# built from the local value (8); every use below sees the imported value.
# Confirm the two agree.
from analysis_src.model_styles import (
    get_display_name,
    MIN_FONT_SIZE,
    SINGLE_COLUMN_WIDTH,
    DOUBLE_COLUMN_WIDTH,
    PLOT_PARAMETERS
)

# Paths
LEADERBOARD_DIR = PROJECT_ROOT / "ITBench-SRE-Agent" / "ITBench-Trajectories" / "ReAct-Agent-Trajectories"
GT_DIR = PROJECT_ROOT / "data" / "itbench-snapshots"
OUTPUT_DIR = PROJECT_ROOT / "ITBench-SRE-Agent" / "ITBench-Trajectories" / "output" / "discovery"

# Regex for K8s entities written as "<namespace>/<Kind>/<name>"
K8S_ENTITY_PATTERN = re.compile(
    r'([\w-]+)/(Deployment|Service|Pod|ReplicaSet|ResourceQuota|StatefulSet|'
    r'DaemonSet|Job|CronJob|ConfigMap|Secret|Endpoints|Ingress|'
    r'PersistentVolumeClaim|PersistentVolume|ServiceAccount|Role|RoleBinding|'
    r'ClusterRole|ClusterRoleBinding|NetworkPolicy|HorizontalPodAutoscaler|'
    r'Node|Schedule|NetworkChaos|StressChaos|PodChaos)/([\w-]+)',
    re.IGNORECASE
)

# Service name normalization patterns
SERVICE_NORMALIZATIONS = {
    # Map specific variations to canonical names
    'frontend-proxy': 'frontend-proxy',
    'frontendproxy': 'frontend-proxy',
    'frontend': 'frontend',
    'checkout': 'checkout',
    'checkoutservice': 'checkout',
    'cart': 'cart',
    'cartservice': 'cart',
    'shipping': 'shipping',
    'shippingservice': 'shipping',
    'product-catalog': 'product-catalog',
    'productcatalog': 'product-catalog',
    'productcatalogservice': 'product-catalog',
    'recommendation': 'recommendation',
    'recommendationservice': 'recommendation',
    'email': 'email',
    'emailservice': 'email',
    'payment': 'payment',
    'paymentservice': 'payment',
    'currency': 'currency',
    'currencyservice': 'currency',
    'ad': 'ad',
    'adservice': 'ad',
    'fraud-detection': 'fraud-detection',
    'frauddetection': 'fraud-detection',
    'frauddetectionservice': 'fraud-detection',
    'load-generator': 'load-generator',
    'loadgenerator': 'load-generator',
    'flagd': 'flagd',
    'otel-collector': 'otel-collector',
    'valkey': 'valkey',
    'valkey-cart': 'valkey',  # valkey instance for cart
    'redis': 'valkey',  # alias
    'kafka': 'kafka',
    'quote': 'quote',
    'quoteservice': 'quote',
    'accounting': 'accounting',
    'accountingservice': 'accounting',
    'otel-demo': 'otel-demo',  # namespace
    'imageprovider': 'imageprovider',
    'flagdui': 'flagdui',
    'opensearch': 'opensearch',
    'grafana': 'grafana',
    'jaeger': 'jaeger',
    'prometheus': 'prometheus',
}

# Model name mapping for cleaner labels
MODEL_NAMES = {
    'Azure_gpt-5.1-2025-11-13': 'GPT-5.1',
    'Azure_o4-mini': 'o4-mini',
    'GCP_gemini-2.5-pro': 'Gemini 2.5 Pro',
    'gcp_gemini-3-pro-preview': 'Gemini 3 Pro',
    'gemini-3-pro-preview': 'Gemini 3 Pro',
    'gemini-3-flash-preview': 'Gemini 3 Flash',
    'moonshotai_kimi-k2-thinking': 'Kimi K2',
    'aws_claude-opus-4-5': 'Claude Opus 4.5',
    'openai_gpt-oss-120b': 'GPT-OSS-120B',
}

# Keys of SERVICE_NORMALIZATIONS, longest first, so that "frontend-proxy" is
# tried before "frontend" during prefix matching. Computed once at import.
_PATTERNS_LONGEST_FIRST = sorted(SERVICE_NORMALIZATIONS, key=len, reverse=True)

# Suffixes stripped from workload names, applied in this order.
_POD_SUFFIX_RE = re.compile(r'-[a-f0-9]{8,10}-[a-z0-9]{5}$')  # -<rs-hash>-<pod-id>
_REPLICASET_SUFFIX_RE = re.compile(r'-[a-f0-9]{8,10}$')       # -<rs-hash>
_NUMERIC_SUFFIX_RE = re.compile(r'-\d+$')                     # -1, -2, ...

# Colours shared by every figure in this module.
_COLOR_INCORRECT = '#d62728'
_COLOR_CORRECT = '#2ca02c'

# Column layout of the per-trial table built by analyze_all_trials().
_TRIAL_COLUMNS = [
    'model',
    'scenario',
    'trial',
    'root_cause_f1',
    'is_correct',
    'semantic_entities_investigated',
]

# Column layout of the per-model table returned by
# plot_exploration_by_correctness().
_COMPARISON_COLUMNS = [
    'model',
    'model_clean',
    'correct_mean',
    'correct_std',
    'correct_n',
    'incorrect_mean',
    'incorrect_std',
    'incorrect_n',
]


def normalize_entity_to_logical(entity: str) -> str:
    """
    Normalize an entity to its logical/canonical service name.

    The entity is lower-cased first, so the result is always lower-case.

    e.g., "otel-demo/Deployment/frontend-abc123" -> "frontend"
          "otel-demo/Service/checkout" -> "checkout"
          "chaos-mesh/NetworkChaos/xyz" -> "chaos:networkchaos"
    """
    # str.split always yields at least one element, so parts[0] is safe.
    parts = entity.lower().split('/')

    # Handle chaos-mesh specially: group by chaos kind, not by object name.
    if 'chaos-mesh' in parts[0]:
        return f"chaos:{parts[1]}" if len(parts) >= 2 else "chaos"

    # The name is the third component of "<ns>/<kind>/<name>"; for shorter
    # inputs fall back to the last component.
    name = parts[2] if len(parts) >= 3 else parts[-1]

    # Strip pod suffixes (e.g., frontend-5d4f6b7c8d-xyz9a -> frontend).
    # A ReplicaSet adds -<rs-hash> and a Pod adds -<pod-id> after it.
    name = _POD_SUFFIX_RE.sub('', name)
    name = _REPLICASET_SUFFIX_RE.sub('', name)
    # Also strip numeric suffixes like -1, -2 from entity names.
    name = _NUMERIC_SUFFIX_RE.sub('', name)

    # First check for exact match (most reliable).
    if name in SERVICE_NORMALIZATIONS:
        return SERVICE_NORMALIZATIONS[name]

    # Otherwise accept a known name followed by "service..." or "-...",
    # e.g. "checkoutservice-v2" or "frontend-canary".
    for pattern in _PATTERNS_LONGEST_FIRST:
        if name.startswith(pattern) and name[len(pattern):].startswith(('service', '-')):
            return SERVICE_NORMALIZATIONS[pattern]

    # Fallback: return cleaned name.
    return name


def extract_k8s_entities(text: str) -> List[str]:
    """Extract all K8s entities from text as "<namespace>/<Kind>/<name>" strings.

    Duplicates are kept and the order of appearance is preserved.
    """
    return [
        f"{namespace}/{kind}/{name}"
        for namespace, kind, name in K8S_ENTITY_PATTERN.findall(text)
    ]


def extract_logical_entities(text: str) -> Set[str]:
    """Extract K8s entities from text and normalize them to logical names."""
    return {normalize_entity_to_logical(e) for e in extract_k8s_entities(text)}


def get_latest_rollout(trial_dir: Path) -> Optional[Path]:
    """Get the most recently modified rollout file from a trial directory.

    Searches ``<trial_dir>/sessions`` recursively for ``rollout-*.jsonl``.
    Returns None when the sessions directory or any rollout file is missing.
    """
    sessions_dir = trial_dir / "sessions"
    if not sessions_dir.exists():
        return None

    rollout_files = list(sessions_dir.glob("**/rollout-*.jsonl"))
    if not rollout_files:
        return None

    return max(rollout_files, key=lambda p: p.stat().st_mtime)


def get_judge_f1(trial_dir: Path) -> float:
    """Get root_cause_entity_f1 from ``<trial_dir>/judge_output.json``.

    Best-effort: returns 0.0 when the file is missing, unreadable, not valid
    JSON, not shaped as ``{"flat_scores": {"root_cause_entity_f1": ...}}``,
    or when the score is null or not numeric.
    """
    judge_path = trial_dir / "judge_output.json"
    if not judge_path.exists():
        return 0.0

    try:
        with open(judge_path, encoding="utf-8") as f:
            judge_data = json.load(f)
        score = judge_data.get('flat_scores', {}).get('root_cause_entity_f1', 0.0)
        return float(score) if score else 0.0
    except (OSError, ValueError, TypeError, AttributeError):
        # ValueError covers JSON and Unicode decode errors; AttributeError and
        # TypeError cover an unexpected JSON shape or score type.
        return 0.0


def count_semantic_entities_investigated(rollout_path: Path) -> int:
    """
    Count unique semantic entity groups investigated in a rollout.

    Only the arguments of ``function_call`` payloads inside ``response_item``
    records are scanned (investigation = active querying). Lines that are not
    valid JSON, or not JSON objects, are skipped.

    Uses normalization to group similar entities:
    - otel-demo/Deployment/frontend and otel-demo/Service/frontend -> 1 entity ("frontend")
    - otel-demo/Pod/frontend-abc123 and otel-demo/Pod/frontend-xyz456 -> 1 entity ("frontend")
    """
    investigated_logical: Set[str] = set()

    with open(rollout_path, encoding="utf-8") as f:
        for line in f:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue

            if not isinstance(obj, dict) or obj.get('type') != 'response_item':
                continue

            payload = obj.get('payload', {})
            if not isinstance(payload, dict) or payload.get('type') != 'function_call':
                continue

            args = payload.get('arguments', {})
            if isinstance(args, str):
                # Arguments may arrive as a JSON-encoded string; when they do
                # not parse, keep the raw text so it is still scanned.
                try:
                    args = json.loads(args)
                except json.JSONDecodeError:
                    args = {'raw': args}

            investigated_logical.update(extract_logical_entities(json.dumps(args)))

    return len(investigated_logical)


def analyze_all_trials() -> pd.DataFrame:
    """
    Analyze all trials from react with code agents.

    Walks ``LEADERBOARD_DIR/react with code_*/Scenario-*/<digits>/``.
    Trials without a rollout file are skipped; trials that raise while being
    processed are reported on stdout and skipped.

    Returns DataFrame with model, scenario, trial, root_cause_f1, is_correct,
    semantic_entities_investigated (empty, with those columns, when nothing
    is found).
    """
    results = []

    if not LEADERBOARD_DIR.is_dir():
        print(f"Trajectory directory not found: {LEADERBOARD_DIR}")
        return pd.DataFrame(results, columns=_TRIAL_COLUMNS)

    # Find react with code agents
    model_dirs = [
        d for d in LEADERBOARD_DIR.iterdir()
        if d.is_dir() and d.name.startswith("react with code_")
    ]
    print(f"Found {len(model_dirs)} agent models")

    for model_dir in tqdm(model_dirs, desc="Processing models"):
        model_name = model_dir.name.replace("react with code_", "").split("_07ccdb1")[0]
        print(f"Processing {model_name}...")

        scenario_dirs = [
            d for d in sorted(model_dir.iterdir())
            if d.is_dir() and d.name.startswith("Scenario-")
        ]

        for scenario_dir in tqdm(scenario_dirs, desc=f" {model_name} scenarios", leave=False):
            scenario = scenario_dir.name
            trial_dirs = [
                d for d in sorted(scenario_dir.iterdir())
                if d.is_dir() and d.name.isdigit()
            ]

            for trial_dir in tqdm(trial_dirs, desc=f" {scenario} trials", leave=False):
                trial_num = int(trial_dir.name)

                rollout_path = get_latest_rollout(trial_dir)
                if rollout_path is None:
                    continue

                try:
                    f1_score = get_judge_f1(trial_dir)
                    semantic_count = count_semantic_entities_investigated(rollout_path)
                    results.append({
                        'model': model_name,
                        'scenario': scenario,
                        'trial': trial_num,
                        'root_cause_f1': f1_score,
                        'is_correct': f1_score > 0,
                        'semantic_entities_investigated': semantic_count
                    })
                except Exception as e:
                    # Per-trial boundary: one bad trial must not abort the run.
                    print(f" Error processing {model_name}/{scenario}/{trial_num}: {e}")

    return pd.DataFrame(results, columns=_TRIAL_COLUMNS)


def clean_model_name(name: str) -> str:
    """Map a raw model directory name to its display label (or return it unchanged)."""
    return MODEL_NAMES.get(name, name)


def _save_figure(fig, stem: str) -> None:
    """Tighten the layout, save ``fig`` as <stem>.pdf and <stem>.png in OUTPUT_DIR, close it."""
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / f"{stem}.pdf")
    fig.savefig(OUTPUT_DIR / f"{stem}.png")
    plt.close(fig)
    print(f"Saved: {stem}.pdf/png")


def _draw_outcome_violins(ax, incorrect, correct) -> None:
    """Draw incorrect (x=0, red) and correct (x=1, green) violins on ``ax``."""
    parts = ax.violinplot([incorrect, correct], positions=[0, 1],
                          showmeans=True, showmedians=True)

    # Color the violins
    for body, color in zip(parts['bodies'], (_COLOR_INCORRECT, _COLOR_CORRECT)):
        body.set_facecolor(color)
        body.set_alpha(0.7)

    # Style the other elements
    for partname in ('cmeans', 'cmedians', 'cbars', 'cmins', 'cmaxes'):
        if partname in parts:
            parts[partname].set_edgecolor('black')
            parts[partname].set_linewidth(0.5)


def plot_exploration_by_correctness(df: pd.DataFrame) -> pd.DataFrame:
    """
    Plot comparing exploration breadth between correct and incorrect diagnoses.

    Writes three figures (PDF + PNG) to OUTPUT_DIR: a per-model grouped bar
    chart, a per-model box plot, and an overall violin plot. Prints summary
    statistics and a two-sided Mann-Whitney U test. The violin plot and the
    test are skipped when either outcome group is empty.

    Args:
        df: Per-trial table with columns 'model', 'is_correct' and
            'semantic_entities_investigated'. It is not modified.

    Returns:
        Per-model comparison table, restricted to models that have both
        correct and incorrect trials, sorted by 'correct_mean' ascending.
    """
    if df.empty:
        print("No trials to plot; skipping exploration-by-correctness figures.")
        return pd.DataFrame(columns=_COMPARISON_COLUMNS)

    # Work on a copy so the helper columns added below do not leak into the
    # caller's DataFrame.
    df = df.copy()

    # Aggregate by model and correctness
    agg = df.groupby(['model', 'is_correct']).agg({
        'semantic_entities_investigated': ['mean', 'std', 'count']
    }).reset_index()
    agg.columns = ['model', 'is_correct', 'mean_entities', 'std_entities', 'n_trials']

    # Split for easier plotting
    correct_df = agg[agg['is_correct'].eq(True)].set_index('model')
    incorrect_df = agg[agg['is_correct'].eq(False)].set_index('model')

    # Only models that have both correct and incorrect trials can be compared.
    # Sorted so the row order does not depend on set iteration order.
    models_both = sorted(set(correct_df.index) & set(incorrect_df.index))

    comparison_data = [
        {
            'model': model,
            'model_clean': clean_model_name(model),
            'correct_mean': correct_df.loc[model, 'mean_entities'],
            'correct_std': correct_df.loc[model, 'std_entities'],
            'correct_n': correct_df.loc[model, 'n_trials'],
            'incorrect_mean': incorrect_df.loc[model, 'mean_entities'],
            'incorrect_std': incorrect_df.loc[model, 'std_entities'],
            'incorrect_n': incorrect_df.loc[model, 'n_trials'],
        }
        for model in models_both
    ]
    # Explicit columns keep sort_values working when no model qualifies.
    comp_df = pd.DataFrame(comparison_data, columns=_COMPARISON_COLUMNS)
    comp_df = comp_df.sort_values('correct_mean', ascending=True)

    # === Figure 1: Grouped bar chart ===
    fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH, 3.0))

    y = np.arange(len(comp_df))
    bar_height = 0.35

    # Incorrect (red) and Correct (green) bars
    bars_incorrect = ax.barh(y - bar_height/2, comp_df['incorrect_mean'],
                             height=bar_height, label='Incorrect (recall=0)',
                             color=_COLOR_INCORRECT, edgecolor='black',
                             linewidth=0.3, alpha=0.8)
    bars_correct = ax.barh(y + bar_height/2, comp_df['correct_mean'],
                           height=bar_height, label='Correct (recall>0)',
                           color=_COLOR_CORRECT, edgecolor='black',
                           linewidth=0.3, alpha=0.8)

    ax.set_yticks(y)
    ax.set_yticklabels(comp_df['model_clean'])
    ax.set_xlabel('Avg. Semantic Entity Groups Investigated')

    # Add value labels just past the end of each bar
    for bars, color in ((bars_incorrect, _COLOR_INCORRECT), (bars_correct, _COLOR_CORRECT)):
        for bar in bars:
            ax.text(bar.get_width() + 0.1, bar.get_y() + bar.get_height()/2,
                    f'{bar.get_width():.1f}', va='center', ha='left',
                    fontsize=MIN_FONT_SIZE - 1, color=color)

    ax.legend(loc='lower right', frameon=False, fontsize=MIN_FONT_SIZE)
    _save_figure(fig, "fig_exploration_by_correctness")

    # === Figure 2: Box plot distribution ===
    fig2, ax2 = plt.subplots(figsize=(HALF_COLUMN_WIDTH * 1.5, 3.5))

    # Prepare data for box plot
    df['correctness'] = df['is_correct'].map({True: 'Correct\n(recall>0)',
                                              False: 'Incorrect\n(recall=0)'})
    df['model_clean'] = df['model'].apply(clean_model_name)

    # Order models by overall median exploration
    model_order = (
        df.groupby('model_clean')['semantic_entities_investigated']
        .median()
        .sort_values()
        .index.tolist()
    )

    # Create box plot with hue
    sns.boxplot(data=df, x='model_clean', y='semantic_entities_investigated',
                hue='correctness', order=model_order, ax=ax2,
                palette={'Correct\n(recall>0)': _COLOR_CORRECT,
                         'Incorrect\n(recall=0)': _COLOR_INCORRECT},
                linewidth=0.5, fliersize=2)

    ax2.set_xlabel('')
    ax2.set_ylabel('Semantic Entity Groups Investigated')
    ax2.tick_params(axis='x', rotation=45)
    ax2.legend(title='', loc='upper left', frameon=False, fontsize=MIN_FONT_SIZE)
    _save_figure(fig2, "fig_exploration_by_correctness_boxplot")

    # === Figure 3: Aggregated across all models ===
    correct_all = df[df['is_correct'].eq(True)]['semantic_entities_investigated']
    incorrect_all = df[df['is_correct'].eq(False)]['semantic_entities_investigated']
    # A violin and the Mann-Whitney test both need data in each group.
    have_both_outcomes = not (correct_all.empty or incorrect_all.empty)

    if have_both_outcomes:
        fig3, ax3 = plt.subplots(figsize=(HALF_COLUMN_WIDTH * 0.8, 2.5))

        # Violin plot for overall distribution
        _draw_outcome_violins(ax3, incorrect_all, correct_all)

        ax3.set_xticks([0, 1])
        ax3.set_xticklabels(['Incorrect\n(recall=0)', 'Correct\n(recall>0)'])
        ax3.set_ylabel('Semantic Entities Investigated')

        # Add mean values as text
        ax3.text(0, incorrect_all.mean() + 0.5, f'μ={incorrect_all.mean():.1f}',
                 ha='center', fontsize=MIN_FONT_SIZE, color=_COLOR_INCORRECT)
        ax3.text(1, correct_all.mean() + 0.5, f'μ={correct_all.mean():.1f}',
                 ha='center', fontsize=MIN_FONT_SIZE, color=_COLOR_CORRECT)

        # Add n counts
        ax3.text(0, ax3.get_ylim()[0] + 0.5, f'n={len(incorrect_all)}',
                 ha='center', fontsize=MIN_FONT_SIZE - 1)
        ax3.text(1, ax3.get_ylim()[0] + 0.5, f'n={len(correct_all)}',
                 ha='center', fontsize=MIN_FONT_SIZE - 1)

        _save_figure(fig3, "fig_exploration_overall_correctness")
    else:
        print("Skipping overall violin plot: need both correct and incorrect trials.")

    # Print statistics
    print("\n" + "=" * 60)
    print("Exploration Breadth by Diagnosis Correctness")
    print("=" * 60)

    print(f"\nOverall Statistics:")
    print(f" Correct diagnoses (n={len(correct_all)}): "
          f"mean={correct_all.mean():.2f}, median={correct_all.median():.1f}")
    print(f" Incorrect diagnoses (n={len(incorrect_all)}): "
          f"mean={incorrect_all.mean():.2f}, median={incorrect_all.median():.1f}")

    # Statistical test
    if have_both_outcomes:
        from scipy import stats
        stat, pvalue = stats.mannwhitneyu(correct_all, incorrect_all, alternative='two-sided')
        print(f"\n Mann-Whitney U test: U={stat:.0f}, p={pvalue:.4f}")
    else:
        print("\n Mann-Whitney U test skipped: one outcome group is empty.")

    print(f"\nPer-Model Comparison:")
    print(f"{'Model':<20} {'Correct':>12} {'Incorrect':>12} {'Diff':>8}")
    print("-" * 55)
    for _, row in comp_df.sort_values('correct_mean', ascending=False).iterrows():
        diff = row['correct_mean'] - row['incorrect_mean']
        print(f"{row['model_clean']:<20} {row['correct_mean']:>10.1f} (n={int(row['correct_n'])}) "
              f"{row['incorrect_mean']:>10.1f} (n={int(row['incorrect_n'])}) {diff:>+7.1f}")

    return comp_df


def plot_success_by_exploration_bins(df: pd.DataFrame) -> None:
    """
    Plot success rate as a function of exploration breadth.

    Trials are bucketed by 'semantic_entities_investigated' into
    0-2, 3-4, 5-6, 7-8, 9-10 and 11+; a trial counts as a success when
    'root_cause_f1' > 0. Writes a stand-alone bar chart and a combined
    two-panel figure (bar chart + outcome violins) to OUTPUT_DIR.

    Args:
        df: Per-trial table with columns 'semantic_entities_investigated',
            'root_cause_f1' and 'is_correct'. It is not modified.
    """
    # Work on a copy so the bin column does not leak into the caller's frame.
    df = df.copy()

    # Create exploration bins. include_lowest keeps trials with 0 entities in
    # the "0-2" bin (intervals are otherwise left-open), and the open-ended
    # top edge keeps very broad trials in "11+".
    bins = [0, 2, 4, 6, 8, 10, np.inf]
    labels = ['0-2', '3-4', '5-6', '7-8', '9-10', '11+']
    df['exploration_bin'] = pd.cut(df['semantic_entities_investigated'],
                                   bins=bins, labels=labels, include_lowest=True)

    # Calculate success rate per bin (empty bins are omitted from the chart)
    bin_stats = []
    for label in labels:
        subset = df[df['exploration_bin'] == label]
        if len(subset) > 0:
            success_rate = (subset['root_cause_f1'] > 0).mean() * 100
            bin_stats.append({
                'bin': label,
                'success_rate': success_rate,
                'n': len(subset)
            })

    if not bin_stats:
        print("No trials to plot; skipping success-rate figures.")
        return

    stats_df = pd.DataFrame(bin_stats)
    x = np.arange(len(stats_df))
    # Keep the original 0-60 range unless a bar (plus its label) would be clipped.
    y_top = max(60.0, float(stats_df['success_rate'].max()) + 10.0)

    # Create figure
    fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH, 2.5))

    bars = ax.bar(x, stats_df['success_rate'], color='#4a90d9',
                  edgecolor='black', linewidth=0.5)

    ax.set_xticks(x)
    ax.set_xticklabels(stats_df['bin'])
    ax.set_xlabel('Semantic Entities Investigated')
    ax.set_ylabel('Correct Diagnosis Rate (%)')

    # Add value labels on bars
    for bar, row in zip(bars, stats_df.itertuples()):
        height = bar.get_height()
        center = bar.get_x() + bar.get_width()/2
        ax.text(center, height + 1, f'{height:.0f}%',
                ha='center', va='bottom', fontsize=MIN_FONT_SIZE)
        ax.text(center, 2, f'n={row.n}',
                ha='center', va='bottom', fontsize=MIN_FONT_SIZE - 1, color='white')

    ax.set_ylim(0, y_top)
    _save_figure(fig, "fig_exploration_success_rate")

    # Also create a combined figure with both views
    fig2, axes = plt.subplots(1, 2, figsize=(HALF_COLUMN_WIDTH * 2 + 0.3, 2.5))

    # Left: Success rate by exploration bins
    ax1 = axes[0]
    bars1 = ax1.bar(x, stats_df['success_rate'], color='#4a90d9',
                    edgecolor='black', linewidth=0.5)
    ax1.set_xticks(x)
    ax1.set_xticklabels(stats_df['bin'])
    ax1.set_xlabel('Entities Investigated')
    ax1.set_ylabel('Correct Diagnosis Rate (%)')
    ax1.set_title('(a) Success vs Exploration', fontsize=MIN_FONT_SIZE + 1)

    for bar in bars1:
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                 f'{bar.get_height():.0f}%', ha='center', va='bottom',
                 fontsize=MIN_FONT_SIZE - 1)
    ax1.set_ylim(0, y_top)

    # Right: Exploration distribution by correctness (violin)
    ax2 = axes[1]
    correct = df[df['is_correct'].eq(True)]['semantic_entities_investigated']
    incorrect = df[df['is_correct'].eq(False)]['semantic_entities_investigated']

    ax2.set_xticks([0, 1])
    ax2.set_xticklabels(['Incorrect', 'Correct'])
    ax2.set_ylabel('Entities Investigated')
    ax2.set_title('(b) Exploration by Outcome', fontsize=MIN_FONT_SIZE + 1)

    # A violin cannot be drawn for an empty group; leave the panel blank then.
    if not (correct.empty or incorrect.empty):
        _draw_outcome_violins(ax2, incorrect, correct)
        ax2.set_xticks([0, 1])
        ax2.set_xticklabels(['Incorrect', 'Correct'])
        ax2.text(0, incorrect.mean() + 1, f'μ={incorrect.mean():.1f}',
                 ha='center', fontsize=MIN_FONT_SIZE - 1, color=_COLOR_INCORRECT)
        ax2.text(1, correct.mean() + 1, f'μ={correct.mean():.1f}',
                 ha='center', fontsize=MIN_FONT_SIZE - 1, color=_COLOR_CORRECT)

    _save_figure(fig2, "fig_exploration_combined")


def main() -> None:
    """Load (or extract and cache) per-trial data, then generate all figures."""
    print("=" * 60)
    print("Exploration Breadth by Diagnosis Correctness Analysis")
    print("=" * 60)

    # The cache and every figure are written here; make sure it exists.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Check if we can use cached data or need to re-extract
    cache_path = OUTPUT_DIR / "exploration_by_correctness.csv"

    if cache_path.exists():
        print(f"\nLoading cached data from {cache_path}")
        df = pd.read_csv(cache_path)
    else:
        print("\nExtracting data from rollout files (this may take a while)...")
        df = analyze_all_trials()
        # Do not cache an empty result: it would mask trajectories added later.
        if not df.empty:
            df.to_csv(cache_path, index=False)
            print(f"Saved cache to: {cache_path}")

    if df.empty:
        print("\nNo trials found; nothing to plot.")
        return

    print(f"\nLoaded {len(df)} trials from {df['model'].nunique()} models")

    # Generate plots
    print("\nGenerating figures...")
    plot_exploration_by_correctness(df)
    plot_success_by_exploration_bins(df)  # dose-response plot

    print(f"\nDone! Figures saved to: {OUTPUT_DIR}")


if __name__ == "__main__":
    main()