| | """ |
| | Ablation Runner Module |
| | |
| | Runs ablation studies to definitively answer: |
| | "Is Council-Lite's benefit from structure or just more prompting/tokens?" |
| | |
| | Four experimental conditions: |
| | 1. Direct: Raw prompt → generators (no planning) |
| | 2. Single Planner: 1 LLM call → plan → generators |
| | 3. Council-Lite: 3 LLM calls → merge → generators |
| | 4. Extended Prompt: 1 LLM call with 3× token budget (controls for compute) |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import json |
| | import time |
| | from dataclasses import dataclass, field |
| | from datetime import datetime |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional |
| | from collections import defaultdict |
| |
|
| | from src.planner.single_planner import SinglePlanner, PlannerMetrics |
| | from src.planner.extended_prompt_planner import ExtendedPromptPlanner |
| | from src.planner.unified_planner import UnifiedPlanner |
| | from src.planner.schema import SemanticPlan |
| | from src.utils.seed import set_global_seed |
| |
|
| |
|
@dataclass
class AblationCondition:
    """Definition of an ablation condition."""

    name: str                      # short identifier, e.g. "direct"
    description: str               # human-readable summary of the condition
    planner_class: Optional[type]  # planner to instantiate; None = no planning stage
    expected_llm_calls: int        # number of LLM calls this condition is expected to make
    token_multiplier: float = 1.0  # token budget relative to the single-planner baseline

    def create_planner(self, **kwargs):
        """Instantiate this condition's planner, or return None for plan-free conditions."""
        cls = self.planner_class
        return None if cls is None else cls(**kwargs)
| |
|
| |
|
| | |
# Registry of the four experimental conditions, keyed by condition name.
# Built from a tuple of definitions so name/key can never drift apart.
ABLATION_CONDITIONS = {
    cond.name: cond
    for cond in (
        AblationCondition(
            name="direct",
            description="Raw prompt → generators (no planning)",
            planner_class=None,
            expected_llm_calls=0,
            token_multiplier=0.0,
        ),
        AblationCondition(
            name="single_planner",
            description="1 LLM call → plan → generators",
            planner_class=SinglePlanner,
            expected_llm_calls=1,
            token_multiplier=1.0,
        ),
        AblationCondition(
            name="council",
            description="3 LLM calls (council) → merge → generators",
            planner_class=UnifiedPlanner,
            expected_llm_calls=3,
            token_multiplier=3.0,
        ),
        AblationCondition(
            name="extended_prompt",
            description="1 LLM call with 3× token budget",
            planner_class=ExtendedPromptPlanner,
            expected_llm_calls=1,
            token_multiplier=3.0,
        ),
    )
}
| |
|
| |
|
@dataclass
class AblationResult:
    """Result from a single ablation run."""

    condition: str                                   # condition name this run belongs to
    prompt: str                                      # input prompt used
    seed: int                                        # random seed used
    success: bool                                    # True if the run completed without raising
    msci: Optional[float] = None                     # primary metric (None if unavailable)
    st_i: Optional[float] = None
    st_a: Optional[float] = None
    si_a: Optional[float] = None
    planner_metrics: Optional[Dict[str, Any]] = None  # serialized planner metrics, if any
    generation_time_ms: float = 0.0                  # wall-clock duration of the run
    error: Optional[str] = None                      # error message when success is False
    run_id: str = ""                                 # pipeline-assigned run identifier

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary (keys follow field declaration order)."""
        return {name: getattr(self, name) for name in self.__dataclass_fields__}
| |
|
| |
|
@dataclass
class AblationStudyConfig:
    """Configuration for ablation study."""

    name: str = "council_lite_ablation"
    # Defaults to every registered condition; evaluated lazily per instance.
    conditions: List[str] = field(default_factory=lambda: list(ABLATION_CONDITIONS.keys()))
    n_prompts: int = 50                     # how many prompts to run
    n_seeds: int = 3                        # seeds per prompt
    base_seed: int = 42                     # offset base for per-run seeds
    output_dir: str = "runs/ablation_study"
    use_ollama: bool = True
    deterministic: bool = True

    @property
    def total_runs(self) -> int:
        """Total number of runs."""
        per_prompt = self.n_seeds * len(self.conditions)
        return per_prompt * self.n_prompts
| |
|
| |
|
class AblationRunner:
    """
    Runs ablation studies across planning conditions.

    Key controls:
    - Same prompts across all conditions
    - Same seeds for reproducibility
    - Token budget tracking
    - Compute time tracking
    """

    def __init__(self, config: AblationStudyConfig):
        """Initialize the runner with a study configuration."""
        self.config = config
        # Every result in execution order.
        self.results: List[AblationResult] = []
        # The same results grouped by condition name, for per-condition stats.
        self.results_by_condition: Dict[str, List[AblationResult]] = defaultdict(list)

    def run_single(
        self,
        prompt: str,
        condition: str,
        seed: int,
        output_dir: Path,
    ) -> AblationResult:
        """
        Run a single ablation condition.

        Args:
            prompt: Input prompt
            condition: Condition name from ABLATION_CONDITIONS
            seed: Random seed
            output_dir: Output directory for this run

        Returns:
            AblationResult; on any exception a failed result is returned
            (success=False, error set) rather than propagating.

        Raises:
            ValueError: If `condition` is not a key of ABLATION_CONDITIONS.
        """
        # Imported lazily so importing this module does not pull in the
        # full generation pipeline.
        from src.pipeline.generate_and_evaluate import generate_and_evaluate

        if condition not in ABLATION_CONDITIONS:
            raise ValueError(f"Unknown condition: {condition}")

        cond_def = ABLATION_CONDITIONS[condition]

        if self.config.deterministic:
            set_global_seed(seed)

        start_time = time.time()
        planner_metrics = None

        try:
            # "direct" bypasses the planning stage entirely.
            mode = "direct" if condition == "direct" else "planner"

            if cond_def.planner_class:
                planner = cond_def.create_planner()
                # NOTE(review): the produced plan is never passed to
                # generate_and_evaluate below — confirm the pipeline re-plans
                # internally; otherwise this result is silently dropped.
                plan = planner.plan(prompt)
                if hasattr(planner, 'get_metrics'):
                    metrics = planner.get_metrics()
                    if metrics:
                        planner_metrics = metrics.to_dict()

            # NOTE(review): condition is hardcoded to "baseline" for every
            # ablation arm — verify this is intentional.
            bundle = generate_and_evaluate(
                prompt=prompt,
                out_dir=str(output_dir),
                use_ollama=self.config.use_ollama,
                deterministic=self.config.deterministic,
                seed=seed,
                mode=mode,
                condition="baseline",
            )

            end_time = time.time()

            return AblationResult(
                condition=condition,
                prompt=prompt,
                seed=seed,
                success=True,
                # scores.get(...) may return None; downstream code must tolerate it.
                msci=bundle.scores.get("msci"),
                st_i=bundle.scores.get("st_i"),
                st_a=bundle.scores.get("st_a"),
                si_a=bundle.scores.get("si_a"),
                planner_metrics=planner_metrics,
                generation_time_ms=(end_time - start_time) * 1000,
                run_id=bundle.run_id,
            )

        except Exception as e:
            # Deliberate catch-all: a single failing run is recorded as a
            # failed result so the rest of the study can continue.
            end_time = time.time()
            return AblationResult(
                condition=condition,
                prompt=prompt,
                seed=seed,
                success=False,
                error=str(e),
                generation_time_ms=(end_time - start_time) * 1000,
            )

    def run_study(
        self,
        prompts: List[str],
    ) -> Dict[str, Any]:
        """
        Run complete ablation study.

        Args:
            prompts: List of prompts to test

        Returns:
            Complete study results (also written to ablation_results.json)
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_base = Path(self.config.output_dir) / f"{self.config.name}_{timestamp}"
        output_base.mkdir(parents=True, exist_ok=True)

        print(f"\n{'=' * 70}")
        print(f"ABLATION STUDY: {self.config.name}")
        print(f"{'=' * 70}")
        print(f"Conditions: {self.config.conditions}")
        print(f"Prompts: {len(prompts)}")
        print(f"Seeds per prompt: {self.config.n_seeds}")
        print(f"Total runs: {self.config.total_runs}")
        print(f"Output: {output_base}")
        print(f"{'=' * 70}\n")

        for prompt_idx, prompt in enumerate(prompts[:self.config.n_prompts]):
            print(f"\nPrompt {prompt_idx + 1}/{self.config.n_prompts}: {prompt[:50]}...")

            for seed_offset in range(self.config.n_seeds):
                # Distinct seed per (prompt, seed_offset); the *100 stride keeps
                # seeds from colliding across prompts for n_seeds <= 100.
                seed = self.config.base_seed + prompt_idx * 100 + seed_offset

                for condition in self.config.conditions:
                    print(f"  [{condition}] seed={seed}...", end=" ")

                    result = self.run_single(
                        prompt=prompt,
                        condition=condition,
                        seed=seed,
                        output_dir=output_base / condition / f"prompt_{prompt_idx}_seed_{seed}",
                    )

                    self.results.append(result)
                    self.results_by_condition[condition].append(result)

                    if result.success:
                        # Fix: msci can be None even on success (scores.get may
                        # return None) — formatting None with :.4f would raise.
                        msci_str = f"{result.msci:.4f}" if result.msci is not None else "N/A"
                        print(f"MSCI={msci_str}")
                    else:
                        print(f"ERROR: {result.error[:40] if result.error else 'Unknown'}")

        report = self._generate_report(timestamp, prompts)

        # Persist the full report; default=str covers numpy scalars etc.
        results_path = output_base / "ablation_results.json"
        with results_path.open("w", encoding="utf-8") as f:
            json.dump(report, f, indent=2, ensure_ascii=False, default=str)

        print(f"\nResults saved to: {results_path}")
        self._print_summary(report)

        return report

    def _generate_report(
        self,
        timestamp: str,
        prompts: List[str],
    ) -> Dict[str, Any]:
        """Generate comprehensive ablation report."""
        import numpy as np
        from src.experiments.statistical_analysis import (
            paired_ttest,
            compare_all_pairs,
            descriptive_stats,
        )

        # Per-condition descriptive statistics and raw MSCI score lists.
        condition_stats = {}
        msci_by_condition = {}

        for condition, results in self.results_by_condition.items():
            successful = [r for r in results if r.success]
            msci_scores = [r.msci for r in successful if r.msci is not None]

            if msci_scores:
                msci_by_condition[condition] = msci_scores
                condition_stats[condition] = {
                    "n_total": len(results),
                    "n_successful": len(successful),
                    "success_rate": len(successful) / len(results),
                    "msci": descriptive_stats(msci_scores),
                    "mean_time_ms": np.mean([r.generation_time_ms for r in successful]),
                }

                # Token accounting, where the planner reported metrics.
                token_results = [r for r in successful if r.planner_metrics]
                if token_results:
                    # Fix: use .get and drop missing values instead of assuming
                    # every metrics dict carries "total_tokens" (KeyError before).
                    total_tokens = [
                        t
                        for t in (r.planner_metrics.get("total_tokens") for r in token_results)
                        if t is not None
                    ]
                    if total_tokens:
                        condition_stats[condition]["mean_tokens"] = np.mean(total_tokens)

        # Pairwise statistical tests across conditions (paired by run index).
        statistical_tests = {}

        if len(msci_by_condition) >= 2:
            # Pairing requires equal-length samples; truncate to the shortest.
            min_len = min(len(v) for v in msci_by_condition.values())

            if min_len >= 2:
                truncated = {k: v[:min_len] for k, v in msci_by_condition.items()}
                comparisons = compare_all_pairs(truncated, paired=True)

                for key, result in comparisons.items():
                    statistical_tests[key] = result.to_dict()

        ablation_analysis = self._analyze_ablation(msci_by_condition, condition_stats)

        return {
            "config": {
                "name": self.config.name,
                "conditions": self.config.conditions,
                "n_prompts": self.config.n_prompts,
                "n_seeds": self.config.n_seeds,
                "base_seed": self.config.base_seed,
            },
            "timestamp": timestamp,
            "n_prompts": len(prompts),
            "total_runs": len(self.results),
            "successful_runs": sum(1 for r in self.results if r.success),
            "condition_statistics": condition_stats,
            "statistical_tests": statistical_tests,
            "ablation_analysis": ablation_analysis,
            "raw_results": [r.to_dict() for r in self.results],
        }

    def _analyze_ablation(
        self,
        msci_by_condition: Dict[str, List[float]],
        condition_stats: Dict[str, Dict],
    ) -> Dict[str, Any]:
        """
        Perform ablation-specific analysis.

        Key questions:
        1. Does single_planner improve over direct?
        2. Does council improve over single_planner?
        3. Does extended_prompt match council? (controls for tokens)

        Args:
            msci_by_condition: MSCI score lists keyed by condition name.
            condition_stats: Per-condition descriptive stats (currently unused
                here, kept for signature stability).

        Returns:
            Dict with "research_questions" (per-comparison summaries) and
            "conclusions" (plain-English takeaways).
        """
        import numpy as np

        analysis = {
            "research_questions": {},
            "conclusions": [],
        }

        # Q1: does any planning at all help?
        if "direct" in msci_by_condition and "single_planner" in msci_by_condition:
            direct_mean = np.mean(msci_by_condition["direct"])
            single_mean = np.mean(msci_by_condition["single_planner"])
            diff = single_mean - direct_mean

            analysis["research_questions"]["planning_effect"] = {
                "comparison": "single_planner vs direct",
                "direct_mean": direct_mean,
                "single_planner_mean": single_mean,
                "difference": diff,
                "interpretation": "Planning improves MSCI" if diff > 0 else "No planning benefit",
            }

        # Q2: does the multi-agent council beat a single planner?
        if "single_planner" in msci_by_condition and "council" in msci_by_condition:
            single_mean = np.mean(msci_by_condition["single_planner"])
            council_mean = np.mean(msci_by_condition["council"])
            diff = council_mean - single_mean

            analysis["research_questions"]["council_structure"] = {
                "comparison": "council vs single_planner",
                "single_planner_mean": single_mean,
                "council_mean": council_mean,
                "difference": diff,
                "interpretation": "Multi-agent structure helps" if diff > 0 else "No structural benefit",
            }

        # Q3: token control — council vs a single call with the same budget.
        # The 0.01 MSCI threshold is an equivalence margin for "no difference".
        if "extended_prompt" in msci_by_condition and "council" in msci_by_condition:
            extended_mean = np.mean(msci_by_condition["extended_prompt"])
            council_mean = np.mean(msci_by_condition["council"])
            diff = council_mean - extended_mean

            analysis["research_questions"]["token_control"] = {
                "comparison": "council vs extended_prompt (same token budget)",
                "extended_prompt_mean": extended_mean,
                "council_mean": council_mean,
                "difference": diff,
                "interpretation": (
                    "Council benefit is from STRUCTURE (not just more tokens)"
                    if diff > 0.01 else
                    "Council benefit is from TOKENS (not structure)"
                    if diff < -0.01 else
                    "Council and extended_prompt are equivalent"
                ),
            }

        # Headline conclusion comes from the token-control comparison, since it
        # is the only one that separates structure from compute.
        if "token_control" in analysis["research_questions"]:
            tc = analysis["research_questions"]["token_control"]
            if tc["difference"] > 0.01:
                analysis["conclusions"].append(
                    "Council-Lite's benefit comes from its multi-agent STRUCTURE, "
                    "not just the increased token budget."
                )
            elif tc["difference"] < -0.01:
                analysis["conclusions"].append(
                    "Council-Lite's benefit is primarily from using more TOKENS. "
                    "The multi-agent structure provides no additional benefit."
                )
            else:
                analysis["conclusions"].append(
                    "Council-Lite and extended single prompting produce equivalent results. "
                    "The benefit is likely from increased compute/tokens."
                )

        return analysis

    def _print_summary(self, report: Dict[str, Any]):
        """Print formatted summary of the report to stdout."""
        print(f"\n{'=' * 70}")
        print("ABLATION STUDY SUMMARY")
        print(f"{'=' * 70}")

        stats = report.get("condition_statistics", {})
        print("\nConditions ranked by mean MSCI:")
        sorted_stats = sorted(stats.items(), key=lambda x: x[1].get("msci", {}).get("mean", 0), reverse=True)
        for i, (cond, s) in enumerate(sorted_stats, 1):
            msci = s.get("msci", {})
            tokens = s.get("mean_tokens", "N/A")
            print(f"  {i}. {cond}: MSCI={msci.get('mean', 0):.4f}±{msci.get('std', 0):.4f}, tokens={tokens}")

        ablation = report.get("ablation_analysis", {})
        conclusions = ablation.get("conclusions", [])
        if conclusions:
            print("\n--- KEY FINDINGS ---")
            for conclusion in conclusions:
                print(f"  • {conclusion}")

        print(f"\n{'=' * 70}")
|