Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Exploration Breadth Analysis by Diagnosis Correctness | |
| Creates a plot comparing exploration breadth between: | |
| - Correct diagnoses (recall > 0, i.e., root_cause_f1 > 0) | |
| - Incorrect diagnoses (recall = 0, i.e., root_cause_f1 == 0) | |
| Uses semantic entity grouping to avoid counting "frontend deployment" and | |
| "frontend service" as separate entities. | |
| """ | |
| import json | |
| import sys | |
| import re | |
| from pathlib import Path | |
| from dataclasses import dataclass | |
| from typing import Optional, List, Dict, Set, Tuple | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from tqdm import tqdm | |
# Publication settings - ICML half column
HALF_COLUMN_WIDTH = 3.25  # inches
# Base font size; the label, tick, legend and title sizes below are derived from it.
MIN_FONT_SIZE = 8
# Applied once at import time through matplotlib's global rcParams, so every
# figure created later in this script picks up the same style.
plt.rcParams.update({
    'font.size': MIN_FONT_SIZE,
    'font.family': 'serif',
    'axes.labelsize': MIN_FONT_SIZE,
    'axes.titlesize': MIN_FONT_SIZE + 1,  # titles are one point larger than body text
    'xtick.labelsize': MIN_FONT_SIZE,
    'ytick.labelsize': MIN_FONT_SIZE,
    'legend.fontsize': MIN_FONT_SIZE,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    # Hide the top and right axes borders.
    'axes.spines.top': False,
    'axes.spines.right': False,
})
# Paths
# Parent of the directory that contains this file (presumably the repository root).
PROJECT_ROOT = Path(__file__).parent.parent
# Must run before the `analysis_src` imports below so that package resolves.
sys.path.insert(0, str(PROJECT_ROOT))
from analysis_src.utils import find_latest_rollout_file
from analysis_src.model_styles import (
    get_display_name, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, PLOT_PARAMETERS
)
# NOTE(review): the import above rebinds MIN_FONT_SIZE, replacing the value of 8
# assigned earlier in this file. plt.rcParams was already configured with the
# earlier value, while the plotting functions below read the imported one -
# confirm both values agree.
# Paths
# Root that is scanned for per-model agent trajectory directories.
LEADERBOARD_DIR = PROJECT_ROOT / "ITBench-SRE-Agent" / "ITBench-Trajectories" / "ReAct-Agent-Trajectories"
# NOTE(review): GT_DIR is not referenced anywhere else in the visible source.
GT_DIR = PROJECT_ROOT / "data" / "itbench-snapshots"
# Destination for the cached CSV and all generated figures.
OUTPUT_DIR = PROJECT_ROOT / "ITBench-SRE-Agent" / "ITBench-Trajectories" / "output" / "discovery"
# Regex for K8s entities
# Matches references of the form "<first>/<Kind>/<name>", where <first> and <name>
# are runs of word characters or hyphens and <Kind> is one of the resource kinds
# listed below (matched case-insensitively). The three capture groups are
# (first segment, kind, name), in that order.
K8S_ENTITY_PATTERN = re.compile(
    r'([\w-]+)/(Deployment|Service|Pod|ReplicaSet|ResourceQuota|StatefulSet|'
    r'DaemonSet|Job|CronJob|ConfigMap|Secret|Endpoints|Ingress|'
    r'PersistentVolumeClaim|PersistentVolume|ServiceAccount|Role|RoleBinding|'
    r'ClusterRole|ClusterRoleBinding|NetworkPolicy|HorizontalPodAutoscaler|'
    r'Node|Schedule|NetworkChaos|StressChaos|PodChaos)/([\w-]+)',
    re.IGNORECASE
)
# Service name normalization patterns
# Keys are lowercase name variants; values are the canonical service name.
# normalize_entity_to_logical() first looks a cleaned name up here exactly and
# then falls back to prefix matching against these keys, which is why canonical
# names also appear as keys mapping to themselves.
SERVICE_NORMALIZATIONS = {
    # Map specific variations to canonical names
    'frontend-proxy': 'frontend-proxy',
    'frontendproxy': 'frontend-proxy',
    'frontend': 'frontend',
    'checkout': 'checkout',
    'checkoutservice': 'checkout',
    'cart': 'cart',
    'cartservice': 'cart',
    'shipping': 'shipping',
    'shippingservice': 'shipping',
    'product-catalog': 'product-catalog',
    'productcatalog': 'product-catalog',
    'productcatalogservice': 'product-catalog',
    'recommendation': 'recommendation',
    'recommendationservice': 'recommendation',
    'email': 'email',
    'emailservice': 'email',
    'payment': 'payment',
    'paymentservice': 'payment',
    'currency': 'currency',
    'currencyservice': 'currency',
    'ad': 'ad',
    'adservice': 'ad',
    'fraud-detection': 'fraud-detection',
    'frauddetection': 'fraud-detection',
    'frauddetectionservice': 'fraud-detection',
    'load-generator': 'load-generator',
    'loadgenerator': 'load-generator',
    'flagd': 'flagd',
    'otel-collector': 'otel-collector',
    'valkey': 'valkey',
    'valkey-cart': 'valkey',  # valkey instance for cart
    'redis': 'valkey',  # alias
    'kafka': 'kafka',
    'quote': 'quote',
    'quoteservice': 'quote',
    'accounting': 'accounting',
    'accountingservice': 'accounting',
    'otel-demo': 'otel-demo',  # namespace
    'imageprovider': 'imageprovider',
    'flagdui': 'flagdui',
    'opensearch': 'opensearch',
    'grafana': 'grafana',
    'jaeger': 'jaeger',
    'prometheus': 'prometheus',
}
# Model name mapping for cleaner labels
# Keys are the model identifiers that analyze_all_trials() derives from the
# trajectory directory names; values are the labels shown in figures and tables.
# Identifiers missing from this table are displayed unchanged by clean_model_name().
MODEL_NAMES = {
    'Azure_gpt-5.1-2025-11-13': 'GPT-5.1',
    'Azure_o4-mini': 'o4-mini',
    'GCP_gemini-2.5-pro': 'Gemini 2.5 Pro',
    'gcp_gemini-3-pro-preview': 'Gemini 3 Pro',
    'gemini-3-pro-preview': 'Gemini 3 Pro',
    'gemini-3-flash-preview': 'Gemini 3 Flash',
    'moonshotai_kimi-k2-thinking': 'Kimi K2',
    'aws_claude-opus-4-5': 'Claude Opus 4.5',
    'openai_gpt-oss-120b': 'GPT-OSS-120B',
}
def normalize_entity_to_logical(entity: str) -> str:
    """
    Normalize an entity reference to its logical/canonical service name.

    The entity is lowercased before any processing, so every returned name is
    lowercase.

    e.g., "otel-demo/Deployment/frontend-abc123" -> "frontend"
          "otel-demo/Service/checkout" -> "checkout"
          "chaos-mesh/NetworkChaos/xyz" -> "chaos:networkchaos"

    Args:
        entity: Slash-separated reference, typically "<namespace>/<Kind>/<name>".

    Returns:
        The canonical name from SERVICE_NORMALIZATIONS when the cleaned name
        matches one of its keys (exactly, or as a prefix followed by "service"
        or "-"); "chaos" / "chaos:<kind>" when the first segment contains
        "chaos-mesh"; otherwise the cleaned name itself.
    """
    # str.split always yields at least one element, so parts[0] is safe.
    parts = entity.lower().split('/')

    # Handle chaos-mesh specially: group by chaos kind, not by experiment name.
    if 'chaos-mesh' in parts[0]:
        return f"chaos:{parts[1]}" if len(parts) >= 2 else "chaos"

    # The name is the third segment when present, otherwise the last segment.
    name = parts[2] if len(parts) >= 3 else parts[-1]

    # Strip pod suffixes (e.g., frontend-5d4f6b7c8d-xyz9a -> frontend).
    # ReplicaSet adds -<hash> and Pod adds a further -<5 chars>.
    name = re.sub(r'-[a-f0-9]{8,10}-[a-z0-9]{5}$', '', name)  # Pod suffix (RS hash + Pod hash)
    name = re.sub(r'-[a-f0-9]{8,10}$', '', name)  # ReplicaSet suffix only
    # Also strip numeric suffixes like -1, -2 from entity names
    name = re.sub(r'-\d+$', '', name)

    # Exact match first (most reliable).
    canonical = SERVICE_NORMALIZATIONS.get(name)
    if canonical is not None:
        return canonical

    # Prefix match: longest patterns first so "frontend-proxy" wins over
    # "frontend". Exact equality was handled above, so only the
    # "<pattern>service..." and "<pattern>-..." forms remain.
    for pattern in sorted(SERVICE_NORMALIZATIONS, key=len, reverse=True):
        if name.startswith(pattern) and name[len(pattern):].startswith(('service', '-')):
            return SERVICE_NORMALIZATIONS[pattern]

    # Fallback: return cleaned name
    return name
def extract_k8s_entities(text: str) -> List[str]:
    """Return every K8s entity reference found in text, in order of appearance.

    Each match of K8S_ENTITY_PATTERN is re-assembled from its three captured
    segments into a single slash-separated string. Duplicates are kept.
    """
    return [
        f"{first}/{kind}/{name}"
        for first, kind, name in K8S_ENTITY_PATTERN.findall(text)
    ]
def extract_logical_entities(text: str) -> Set[str]:
    """Return the set of logical (normalized) entity names referenced in text."""
    logical_names: Set[str] = set()
    for raw_entity in extract_k8s_entities(text):
        logical_names.add(normalize_entity_to_logical(raw_entity))
    return logical_names
def get_latest_rollout(trial_dir: Path) -> Optional[Path]:
    """Return the most recently modified rollout file of a trial, if any.

    Searches trial_dir/"sessions" recursively for files named
    "rollout-*.jsonl" and picks the one with the largest modification time.
    Returns None when the sessions directory is missing or holds no rollout
    files.
    """
    sessions_dir = trial_dir / "sessions"
    if not sessions_dir.exists():
        return None

    candidates = sessions_dir.glob("**/rollout-*.jsonl")
    # default=None covers the "no rollout files" case without a separate check.
    return max(candidates, key=lambda candidate: candidate.stat().st_mtime, default=None)
def get_judge_f1(trial_dir: Path) -> float:
    """Get root_cause_entity_f1 from a trial's judge output.

    Reads trial_dir/"judge_output.json" and returns
    flat_scores.root_cause_entity_f1 as a float.

    Args:
        trial_dir: Directory of a single trial.

    Returns:
        The score, or 0.0 when the file is missing or unreadable, is not valid
        JSON, does not have the expected structure, or the score is
        null/zero/non-numeric.
    """
    judge_path = trial_dir / "judge_output.json"
    try:
        with open(judge_path, encoding="utf-8") as f:
            judge_data = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file, or invalid JSON/encoding (both ValueError subclasses).
        return 0.0

    flat_scores = judge_data.get('flat_scores') if isinstance(judge_data, dict) else None
    if not isinstance(flat_scores, dict):
        return 0.0

    score = flat_scores.get('root_cause_entity_f1')
    if not score:
        # Covers a missing key, null and 0.
        return 0.0
    try:
        return float(score)
    except (TypeError, ValueError):
        return 0.0
def count_semantic_entities_investigated(rollout_path: Path) -> int:
    """
    Count unique semantic entity groups investigated in a rollout.

    Only tool-call arguments are inspected: JSONL records whose "type" is
    "response_item" and whose payload "type" is "function_call". Entities found
    in those arguments are normalized so that similar references collapse:
    - otel-demo/Deployment/frontend and otel-demo/Service/frontend -> 1 entity ("frontend")
    - otel-demo/Pod/frontend-abc123 and otel-demo/Pod/frontend-xyz456 -> 1 entity ("frontend")

    Lines that are not valid JSON, are not JSON objects, or carry a non-object
    payload are skipped instead of aborting the whole rollout.

    Args:
        rollout_path: Path to a rollout JSONL file.

    Returns:
        Number of distinct logical entities referenced in tool-call arguments.
    """
    investigated_logical: Set[str] = set()
    # errors="replace" keeps one undecodable byte from discarding the rollout.
    with open(rollout_path, encoding="utf-8", errors="replace") as f:
        for line in f:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict) or obj.get('type') != 'response_item':
                continue
            payload = obj.get('payload')
            # Check tool arguments only (investigation = active querying).
            if not isinstance(payload, dict) or payload.get('type') != 'function_call':
                continue
            args = payload.get('arguments', {})
            if isinstance(args, str):
                # Arguments may arrive as a JSON-encoded string; fall back to
                # the raw text when they are not valid JSON.
                try:
                    args = json.loads(args)
                except ValueError:
                    args = {'raw': args}
            # Re-serialize so entities nested anywhere in the arguments are found.
            args_str = json.dumps(args)
            investigated_logical.update(extract_logical_entities(args_str))
    return len(investigated_logical)
def analyze_all_trials() -> pd.DataFrame:
    """
    Analyze all trials from the "react with code" agents under LEADERBOARD_DIR.

    Expected layout: <LEADERBOARD_DIR>/react with code_<model>.../Scenario-*/<trial number>/.
    Trials without a rollout file are skipped. A trial whose processing raises
    is reported and skipped so one bad trial does not abort the whole run.

    NOTE(review): get_judge_f1() yields 0.0 when the judge output is missing,
    so such trials are recorded as incorrect - confirm this is intended.

    Returns:
        DataFrame with columns model, scenario, trial, root_cause_f1,
        is_correct, semantic_entities_investigated. The columns are present
        even when no trial was found.
    """
    agent_prefix = "react with code_"
    columns = ['model', 'scenario', 'trial', 'root_cause_f1', 'is_correct',
               'semantic_entities_investigated']
    results = []

    if not LEADERBOARD_DIR.is_dir():
        print(f"Trajectory directory not found: {LEADERBOARD_DIR}")
        return pd.DataFrame(results, columns=columns)

    # Find react with code agents (sorted for a reproducible row order).
    model_dirs = sorted(d for d in LEADERBOARD_DIR.iterdir()
                        if d.is_dir() and d.name.startswith(agent_prefix))
    print(f"Found {len(model_dirs)} agent models")

    for model_dir in tqdm(model_dirs, desc="Processing models"):
        # Drop the agent prefix and everything from the "_07ccdb1" marker on.
        model_name = model_dir.name.replace(agent_prefix, "").split("_07ccdb1")[0]
        # tqdm.write prints without breaking the active progress bars.
        tqdm.write(f"Processing {model_name}...")
        scenario_dirs = [d for d in sorted(model_dir.iterdir())
                         if d.is_dir() and d.name.startswith("Scenario-")]
        for scenario_dir in tqdm(scenario_dirs, desc=f" {model_name} scenarios", leave=False):
            scenario = scenario_dir.name
            trial_dirs = [d for d in sorted(scenario_dir.iterdir())
                          if d.is_dir() and d.name.isdigit()]
            for trial_dir in tqdm(trial_dirs, desc=f" {scenario} trials", leave=False):
                trial_num = int(trial_dir.name)
                rollout_path = get_latest_rollout(trial_dir)
                if rollout_path is None:
                    continue
                try:
                    f1_score = get_judge_f1(trial_dir)
                    semantic_count = count_semantic_entities_investigated(rollout_path)
                except Exception as e:  # best-effort: report and keep going
                    tqdm.write(f" Error processing {model_name}/{scenario}/{trial_num}: {e}")
                    continue
                results.append({
                    'model': model_name,
                    'scenario': scenario,
                    'trial': trial_num,
                    'root_cause_f1': f1_score,
                    'is_correct': f1_score > 0,
                    'semantic_entities_investigated': semantic_count,
                })
    return pd.DataFrame(results, columns=columns)
def clean_model_name(name: str) -> str:
    """Return the display label for a model identifier.

    Identifiers without an entry in MODEL_NAMES are returned unchanged.
    """
    try:
        return MODEL_NAMES[name]
    except KeyError:
        return name
def _summarize_models_by_outcome(df: pd.DataFrame) -> pd.DataFrame:
    """Build one row per model that has both correct and incorrect trials."""
    agg = df.groupby(['model', 'is_correct']).agg({
        'semantic_entities_investigated': ['mean', 'std', 'count']
    }).reset_index()
    agg.columns = ['model', 'is_correct', 'mean_entities', 'std_entities', 'n_trials']

    correct_df = agg[agg['is_correct'].eq(True)].set_index('model')
    incorrect_df = agg[agg['is_correct'].eq(False)].set_index('model')

    rows = []
    # sorted() makes the row order reproducible (set iteration order is not).
    for model in sorted(set(correct_df.index) & set(incorrect_df.index)):
        rows.append({
            'model': model,
            'model_clean': clean_model_name(model),
            'correct_mean': correct_df.loc[model, 'mean_entities'],
            'correct_std': correct_df.loc[model, 'std_entities'],
            'correct_n': correct_df.loc[model, 'n_trials'],
            'incorrect_mean': incorrect_df.loc[model, 'mean_entities'],
            'incorrect_std': incorrect_df.loc[model, 'std_entities'],
            'incorrect_n': incorrect_df.loc[model, 'n_trials'],
        })
    # Explicit columns keep the frame usable even when no model qualifies.
    columns = ['model', 'model_clean', 'correct_mean', 'correct_std', 'correct_n',
               'incorrect_mean', 'incorrect_std', 'incorrect_n']
    return pd.DataFrame(rows, columns=columns).sort_values('correct_mean', ascending=True)


def _plot_outcome_bars(comp_df: pd.DataFrame) -> None:
    """Figure 1: per-model grouped horizontal bars of mean entities by outcome."""
    fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH, 3.0))
    y = np.arange(len(comp_df))
    bar_height = 0.35
    # Incorrect (red) and Correct (green) bars
    bars_incorrect = ax.barh(y - bar_height / 2, comp_df['incorrect_mean'],
                             height=bar_height, label='Incorrect (recall=0)',
                             color='#d62728', edgecolor='black', linewidth=0.3, alpha=0.8)
    bars_correct = ax.barh(y + bar_height / 2, comp_df['correct_mean'],
                           height=bar_height, label='Correct (recall>0)',
                           color='#2ca02c', edgecolor='black', linewidth=0.3, alpha=0.8)
    ax.set_yticks(y)
    ax.set_yticklabels(comp_df['model_clean'])
    ax.set_xlabel('Avg. Semantic Entity Groups Investigated')
    # Value label just past the end of every bar, in the bar's own color.
    for bars, color in ((bars_incorrect, '#d62728'), (bars_correct, '#2ca02c')):
        for bar in bars:
            ax.text(bar.get_width() + 0.1, bar.get_y() + bar.get_height() / 2,
                    f'{bar.get_width():.1f}', va='center', ha='left',
                    fontsize=MIN_FONT_SIZE - 1, color=color)
    ax.legend(loc='lower right', frameon=False, fontsize=MIN_FONT_SIZE)
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / "fig_exploration_by_correctness.pdf")
    fig.savefig(OUTPUT_DIR / "fig_exploration_by_correctness.png")
    plt.close(fig)
    print("Saved: fig_exploration_by_correctness.pdf/png")


def _plot_outcome_boxplot(df: pd.DataFrame) -> None:
    """Figure 2: per-model box plots of entities investigated, split by outcome."""
    correct_label = 'Correct\n(recall>0)'
    incorrect_label = 'Incorrect\n(recall=0)'
    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    plot_df = df.assign(
        correctness=df['is_correct'].map({True: correct_label, False: incorrect_label}),
        model_clean=df['model'].apply(clean_model_name),
    )
    # Order models by overall median exploration
    model_order = (plot_df.groupby('model_clean')['semantic_entities_investigated']
                   .median().sort_values().index.tolist())

    fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH * 1.5, 3.5))
    sns.boxplot(data=plot_df, x='model_clean', y='semantic_entities_investigated',
                hue='correctness', order=model_order, ax=ax,
                palette={correct_label: '#2ca02c', incorrect_label: '#d62728'},
                linewidth=0.5, fliersize=2)
    ax.set_xlabel('')
    ax.set_ylabel('Semantic Entity Groups Investigated')
    ax.tick_params(axis='x', rotation=45)
    ax.legend(title='', loc='upper left', frameon=False, fontsize=MIN_FONT_SIZE)
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / "fig_exploration_by_correctness_boxplot.pdf")
    fig.savefig(OUTPUT_DIR / "fig_exploration_by_correctness_boxplot.png")
    plt.close(fig)
    print("Saved: fig_exploration_by_correctness_boxplot.pdf/png")


def _plot_outcome_violin(correct_all: pd.Series, incorrect_all: pd.Series) -> None:
    """Figure 3: violin plot of entities investigated, aggregated over all models."""
    fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH * 0.8, 2.5))
    parts = ax.violinplot([incorrect_all, correct_all], positions=[0, 1],
                          showmeans=True, showmedians=True)
    # Color the violins
    for body, color in zip(parts['bodies'], ('#d62728', '#2ca02c')):
        body.set_facecolor(color)
        body.set_alpha(0.7)
    # Style the other elements
    for partname in ('cmeans', 'cmedians', 'cbars', 'cmins', 'cmaxes'):
        if partname in parts:
            parts[partname].set_edgecolor('black')
            parts[partname].set_linewidth(0.5)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(['Incorrect\n(recall=0)', 'Correct\n(recall>0)'])
    ax.set_ylabel('Semantic Entities Investigated')
    # Add mean values as text
    ax.text(0, incorrect_all.mean() + 0.5, f'μ={incorrect_all.mean():.1f}',
            ha='center', fontsize=MIN_FONT_SIZE, color='#d62728')
    ax.text(1, correct_all.mean() + 0.5, f'μ={correct_all.mean():.1f}',
            ha='center', fontsize=MIN_FONT_SIZE, color='#2ca02c')
    # Add n counts just above the bottom of the axes
    y_bottom = ax.get_ylim()[0]
    ax.text(0, y_bottom + 0.5, f'n={len(incorrect_all)}',
            ha='center', fontsize=MIN_FONT_SIZE - 1)
    ax.text(1, y_bottom + 0.5, f'n={len(correct_all)}',
            ha='center', fontsize=MIN_FONT_SIZE - 1)
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / "fig_exploration_overall_correctness.pdf")
    fig.savefig(OUTPUT_DIR / "fig_exploration_overall_correctness.png")
    plt.close(fig)
    print("Saved: fig_exploration_overall_correctness.pdf/png")


def plot_exploration_by_correctness(df: pd.DataFrame):
    """
    Plot exploration breadth for correct vs. incorrect diagnoses and print stats.

    Writes three figures (PDF and PNG each) to OUTPUT_DIR: a per-model grouped
    bar chart, a per-model box plot, and a violin plot aggregated over all
    models. It then prints overall statistics, a two-sided Mann-Whitney U test,
    and a per-model comparison table. Figures that need data from both outcome
    groups are skipped with a message when one group is missing.

    Args:
        df: Trial table with columns model, is_correct and
            semantic_entities_investigated. It is not modified.

    Returns:
        Per-model comparison DataFrame (only models having both correct and
        incorrect trials), sorted by ascending mean for correct trials.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # === Figure 1: Grouped bar chart ===
    comp_df = _summarize_models_by_outcome(df)
    if comp_df.empty:
        print("Skipping fig_exploration_by_correctness: "
              "no model has both correct and incorrect trials")
    else:
        _plot_outcome_bars(comp_df)

    # === Figure 2: Box plot distribution ===
    _plot_outcome_boxplot(df)

    # === Figure 3: Aggregated across all models ===
    correct_all = df[df['is_correct'].eq(True)]['semantic_entities_investigated']
    incorrect_all = df[df['is_correct'].eq(False)]['semantic_entities_investigated']
    # Violin plots and the rank test both need at least one trial per group.
    have_both_groups = not correct_all.empty and not incorrect_all.empty
    if have_both_groups:
        _plot_outcome_violin(correct_all, incorrect_all)
    else:
        print("Skipping fig_exploration_overall_correctness: one outcome group is empty")

    # Print statistics
    print("\n" + "=" * 60)
    print("Exploration Breadth by Diagnosis Correctness")
    print("=" * 60)
    print("\nOverall Statistics:")
    print(f" Correct diagnoses (n={len(correct_all)}): "
          f"mean={correct_all.mean():.2f}, median={correct_all.median():.1f}")
    print(f" Incorrect diagnoses (n={len(incorrect_all)}): "
          f"mean={incorrect_all.mean():.2f}, median={incorrect_all.median():.1f}")

    # Statistical test
    if have_both_groups:
        from scipy import stats
        stat, pvalue = stats.mannwhitneyu(correct_all, incorrect_all, alternative='two-sided')
        print(f"\n Mann-Whitney U test: U={stat:.0f}, p={pvalue:.4f}")

    print("\nPer-Model Comparison:")
    print(f"{'Model':<20} {'Correct':>12} {'Incorrect':>12} {'Diff':>8}")
    print("-" * 55)
    for _, row in comp_df.sort_values('correct_mean', ascending=False).iterrows():
        diff = row['correct_mean'] - row['incorrect_mean']
        print(f"{row['model_clean']:<20} {row['correct_mean']:>10.1f} (n={int(row['correct_n'])}) "
              f"{row['incorrect_mean']:>10.1f} (n={int(row['incorrect_n'])}) {diff:>+7.1f}")
    return comp_df
def _success_rate_by_exploration_bin(df: pd.DataFrame) -> pd.DataFrame:
    """Return success rate (%) and trial count for every non-empty exploration bin."""
    # Open-ended outer edges: with pd.cut's right-closed intervals, edges of
    # 0 and 100 would silently drop trials with 0 or more than 100 entities.
    bins = [-np.inf, 2, 4, 6, 8, 10, np.inf]
    labels = ['0-2', '3-4', '5-6', '7-8', '9-10', '11+']
    # Kept as a local Series so the caller's DataFrame is not modified.
    exploration_bin = pd.cut(df['semantic_entities_investigated'], bins=bins, labels=labels)

    rows = []
    for label in labels:
        subset = df[exploration_bin == label]
        if len(subset) > 0:
            rows.append({
                'bin': label,
                'success_rate': (subset['root_cause_f1'] > 0).mean() * 100,
                'n': len(subset),
            })
    return pd.DataFrame(rows, columns=['bin', 'success_rate', 'n'])


def plot_success_by_exploration_bins(df: pd.DataFrame):
    """
    Plot the correct-diagnosis rate as a function of exploration breadth.

    Trials are grouped into bins by the number of semantic entities
    investigated (0-2, 3-4, 5-6, 7-8, 9-10, 11+), and the share of trials with
    root_cause_f1 > 0 is plotted per non-empty bin. Two figures are written to
    OUTPUT_DIR (PDF and PNG each): the bar chart alone, and a combined figure
    that pairs it with a violin plot of exploration by outcome. The combined
    figure is skipped when one outcome group has no trials.

    Args:
        df: Trial table with columns semantic_entities_investigated,
            root_cause_f1 and is_correct. It is not modified.
    """
    stats_df = _success_rate_by_exploration_bin(df)
    if stats_df.empty:
        print("Skipping exploration success-rate figures: no trials to plot")
        return

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    x = np.arange(len(stats_df))
    # Keep the original 0-60 range, but grow it when a bar (plus its value
    # label) would otherwise be clipped.
    y_top = max(60.0, float(stats_df['success_rate'].max()) + 10.0)

    # Create figure
    fig, ax = plt.subplots(figsize=(HALF_COLUMN_WIDTH, 2.5))
    bars = ax.bar(x, stats_df['success_rate'],
                  color='#4a90d9', edgecolor='black', linewidth=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(stats_df['bin'])
    ax.set_xlabel('Semantic Entities Investigated')
    ax.set_ylabel('Correct Diagnosis Rate (%)')
    # Add value labels on bars: percentage above, trial count inside the base.
    for bar, row in zip(bars, stats_df.itertuples()):
        height = bar.get_height()
        center = bar.get_x() + bar.get_width() / 2
        ax.text(center, height + 1, f'{height:.0f}%', ha='center', va='bottom',
                fontsize=MIN_FONT_SIZE)
        ax.text(center, 2, f'n={row.n}', ha='center', va='bottom',
                fontsize=MIN_FONT_SIZE - 1, color='white')
    ax.set_ylim(0, y_top)
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / "fig_exploration_success_rate.pdf")
    fig.savefig(OUTPUT_DIR / "fig_exploration_success_rate.png")
    plt.close(fig)
    print("Saved: fig_exploration_success_rate.pdf/png")

    # Also create a combined figure with both views
    correct = df[df['is_correct'].eq(True)]['semantic_entities_investigated']
    incorrect = df[df['is_correct'].eq(False)]['semantic_entities_investigated']
    if correct.empty or incorrect.empty:
        # violinplot cannot draw an empty dataset.
        print("Skipping fig_exploration_combined: one outcome group is empty")
        return

    fig2, axes = plt.subplots(1, 2, figsize=(HALF_COLUMN_WIDTH * 2 + 0.3, 2.5))

    # Left: Success rate by exploration bins
    ax1 = axes[0]
    bars1 = ax1.bar(x, stats_df['success_rate'],
                    color='#4a90d9', edgecolor='black', linewidth=0.5)
    ax1.set_xticks(x)
    ax1.set_xticklabels(stats_df['bin'])
    ax1.set_xlabel('Entities Investigated')
    ax1.set_ylabel('Correct Diagnosis Rate (%)')
    ax1.set_title('(a) Success vs Exploration', fontsize=MIN_FONT_SIZE + 1)
    for bar in bars1:
        ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                 f'{bar.get_height():.0f}%', ha='center', va='bottom',
                 fontsize=MIN_FONT_SIZE - 1)
    ax1.set_ylim(0, y_top)

    # Right: Exploration distribution by correctness (violin)
    ax2 = axes[1]
    parts = ax2.violinplot([incorrect, correct], positions=[0, 1],
                           showmeans=True, showmedians=True)
    for body, color in zip(parts['bodies'], ('#d62728', '#2ca02c')):
        body.set_facecolor(color)
        body.set_alpha(0.7)
    for partname in ('cmeans', 'cmedians', 'cbars', 'cmins', 'cmaxes'):
        if partname in parts:
            parts[partname].set_edgecolor('black')
            parts[partname].set_linewidth(0.5)
    ax2.set_xticks([0, 1])
    ax2.set_xticklabels(['Incorrect', 'Correct'])
    ax2.set_ylabel('Entities Investigated')
    ax2.set_title('(b) Exploration by Outcome', fontsize=MIN_FONT_SIZE + 1)
    ax2.text(0, incorrect.mean() + 1, f'μ={incorrect.mean():.1f}',
             ha='center', fontsize=MIN_FONT_SIZE - 1, color='#d62728')
    ax2.text(1, correct.mean() + 1, f'μ={correct.mean():.1f}',
             ha='center', fontsize=MIN_FONT_SIZE - 1, color='#2ca02c')
    fig2.tight_layout()
    fig2.savefig(OUTPUT_DIR / "fig_exploration_combined.pdf")
    fig2.savefig(OUTPUT_DIR / "fig_exploration_combined.png")
    plt.close(fig2)
    print("Saved: fig_exploration_combined.pdf/png")
def main():
    """Load (or extract and cache) the per-trial data, then generate all figures.

    The trial table is cached as exploration_by_correctness.csv in OUTPUT_DIR;
    delete that file to force re-extraction from the rollout files.
    """
    print("=" * 60)
    print("Exploration Breadth by Diagnosis Correctness Analysis")
    print("=" * 60)

    # The cache and every figure are written here, so it must exist up front.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Check if we can use cached data or need to re-extract
    cache_path = OUTPUT_DIR / "exploration_by_correctness.csv"
    if cache_path.exists():
        print(f"\nLoading cached data from {cache_path}")
        df = pd.read_csv(cache_path)
    else:
        print("\nExtracting data from rollout files (this may take a while)...")
        df = analyze_all_trials()
        # Never cache an empty result: it would mask data that appears later.
        if not df.empty:
            df.to_csv(cache_path, index=False)
            print(f"Saved cache to: {cache_path}")

    if df.empty:
        print("\nNo trials found; nothing to plot.")
        return

    print(f"\nLoaded {len(df)} trials from {df['model'].nunique()} models")

    # Generate plots
    print("\nGenerating figures...")
    plot_exploration_by_correctness(df)
    plot_success_by_exploration_bins(df)
    print(f"\nDone! Figures saved to: {OUTPUT_DIR}")


if __name__ == "__main__":
    main()