Spaces:
Running
Running
| """ | |
| Multi-seed training wrapper for LexiMind. | |
| Runs training across multiple seeds and aggregates results with mean ± std. | |
| This addresses the single-seed limitation identified in review feedback. | |
| Usage: | |
| python scripts/train_multiseed.py --seeds 17 42 123 --config training=full | |
| python scripts/train_multiseed.py --seeds 17 42 123 456 789 --config training=medium | |
| Author: Oliver Perrin | |
| Date: February 2026 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, List | |
| import numpy as np | |
def run_single_seed(seed: int, config_overrides: str, base_dir: Path) -> Dict:
    """Train one seed in a subprocess and return its training history.

    Launches ``scripts/train.py`` (resolved relative to the current working
    directory) with the interpreter that is running this script. The
    ``seed=``, ``checkpoint_out=``, ``history_out=`` and ``labels_out=``
    overrides all point into ``base_dir / f"seed_{seed}"``, which is created
    if it does not exist.

    Args:
        seed: Seed forwarded to the training script as ``seed=<seed>``.
        config_overrides: Whitespace-separated extra overrides appended to
            the command line. An empty string adds nothing.
        base_dir: Directory under which the per-seed directory is created.

    Returns:
        The JSON object parsed from ``training_history.json``. An empty dict
        is returned when training exits non-zero, or when the history file
        is missing, unreadable, malformed, or not a JSON object.
    """
    seed_dir = base_dir / f"seed_{seed}"
    seed_dir.mkdir(parents=True, exist_ok=True)

    cmd = [
        sys.executable,
        "scripts/train.py",
        f"seed={seed}",
        f"checkpoint_out={seed_dir}/checkpoints/best.pt",
        f"history_out={seed_dir}/training_history.json",
        f"labels_out={seed_dir}/labels.json",
    ]
    if config_overrides:
        cmd.extend(config_overrides.split())

    print(f"\n{'=' * 60}")
    print(f"Training seed {seed}")
    print(f"{'=' * 60}")
    print(f" Command: {' '.join(cmd)}")

    # Output is deliberately not captured so training progress streams live.
    result = subprocess.run(cmd, capture_output=False)
    if result.returncode != 0:
        print(f" WARNING: Seed {seed} training failed (exit code {result.returncode})")
        return {}

    history_path = seed_dir / "training_history.json"
    try:
        with open(history_path) as f:
            history = json.load(f)
    except FileNotFoundError:
        # The training script wrote no history; treated as "nothing to report".
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError. A bad
        # history file must not abort the remaining seeds of a long run.
        print(f" WARNING: Seed {seed} training history could not be read ({exc})")
        return {}

    if not isinstance(history, dict):
        print(f" WARNING: Seed {seed} training history is not a JSON object")
        return {}
    return history
def run_evaluation(seed: int, base_dir: Path, extra_args: List[str] | None = None) -> Dict:
    """Evaluate one seed's checkpoint in a subprocess and return the report.

    Runs ``scripts/evaluate.py`` (resolved relative to the current working
    directory) on ``base_dir / f"seed_{seed}" / "checkpoints" / "best.pt"``
    with ``--skip-bertscore``, ``--tune-thresholds`` and ``--bootstrap``.
    The report is written to ``evaluation_report.json`` in the seed directory.

    Args:
        seed: Seed whose output directory is evaluated.
        base_dir: Directory that contains the per-seed directories.
        extra_args: Optional additional command-line arguments appended
            verbatim to the evaluation command.

    Returns:
        The JSON object parsed from the evaluation report. An empty dict is
        returned when the checkpoint is missing, evaluation exits non-zero,
        or the report is missing, unreadable, malformed, or not a JSON object.
    """
    seed_dir = base_dir / f"seed_{seed}"
    checkpoint = seed_dir / "checkpoints" / "best.pt"
    labels = seed_dir / "labels.json"
    output = seed_dir / "evaluation_report.json"

    if not checkpoint.exists():
        print(f" Skipping eval for seed {seed}: no checkpoint found")
        return {}

    cmd = [
        sys.executable,
        "scripts/evaluate.py",
        f"--checkpoint={checkpoint}",
        f"--labels={labels}",
        f"--output={output}",
        "--skip-bertscore",
        "--tune-thresholds",
        "--bootstrap",
    ]
    if extra_args:
        cmd.extend(extra_args)

    print(f"\n Evaluating seed {seed}...")
    # Output is deliberately not captured so evaluation progress streams live.
    result = subprocess.run(cmd, capture_output=False)
    if result.returncode != 0:
        print(f" WARNING: Seed {seed} evaluation failed")
        return {}

    try:
        with open(output) as f:
            report = json.load(f)
    except FileNotFoundError:
        # The evaluation script wrote no report; nothing to aggregate.
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError. One
        # bad report must not abort evaluation of the remaining seeds.
        print(f" WARNING: Seed {seed} evaluation report could not be read ({exc})")
        return {}

    if not isinstance(report, dict):
        # Aggregation iterates the report with .items(), so reject other shapes.
        print(f" WARNING: Seed {seed} evaluation report is not a JSON object")
        return {}
    return report
def aggregate_results(all_results: Dict[int, Dict]) -> Dict:
    """Aggregate per-seed evaluation results into summary statistics.

    Each per-seed result is read as ``{task: {metric_name: value}}``. Only
    numeric (``int``/``float``) values directly under a task dict are
    collected. Non-dict task entries, nested structures, booleans, and the
    bookkeeping fields ``num_samples`` / ``num_classes`` are skipped.

    Args:
        all_results: Mapping from seed to that seed's evaluation report.

    Returns:
        A dict keyed by ``"<task>/<metric_name>"`` in sorted key order. Each
        value holds ``mean``, ``std``, ``min``, ``max`` and ``n_seeds`` (the
        number of seeds that reported the metric). ``std`` is numpy's default
        population standard deviation (``ddof=0``). Returns an empty dict when
        there is nothing to aggregate.
    """
    skipped_names = {"num_samples", "num_classes"}

    # Collect every value of each "<task>/<metric>" path across seeds.
    metric_values: Dict[str, List[float]] = {}
    for results in all_results.values():
        if not isinstance(results, dict):
            continue
        for task, task_metrics in results.items():
            if not isinstance(task_metrics, dict):
                continue
            for metric_name, value in task_metrics.items():
                if metric_name in skipped_names:
                    continue
                # bool is a subclass of int; a True/False flag is not a metric
                # and must not be averaged as 1.0/0.0.
                if isinstance(value, bool) or not isinstance(value, (int, float)):
                    continue
                metric_values.setdefault(f"{task}/{metric_name}", []).append(float(value))

    aggregated: Dict[str, Dict[str, float]] = {}
    for key in sorted(metric_values):
        values = metric_values[key]
        arr = np.array(values)
        aggregated[key] = {
            "mean": float(arr.mean()),
            "std": float(arr.std()),
            "min": float(arr.min()),
            "max": float(arr.max()),
            "n_seeds": len(values),
        }
    return aggregated
def print_summary(aggregated: Dict, seeds: List[int]) -> None:
    """Print a per-task table of ``mean ± std`` for the aggregated metrics.

    Args:
        aggregated: Mapping from ``"<task>/<metric>"`` to a stats dict that
            has at least ``"mean"`` and ``"std"`` entries.
        seeds: Seeds to list in the summary header.
    """
    rule = "=" * 70
    print(f"\n{rule}")
    print(f"MULTI-SEED RESULTS SUMMARY ({len(seeds)} seeds: {seeds})")
    print(f"{rule}")

    # Bucket the flat "<task>/<metric>" keys by task.
    by_task: Dict[str, Dict[str, Dict]] = {}
    for path, summary in aggregated.items():
        task_name, metric_name = path.split("/", 1)
        if task_name not in by_task:
            by_task[task_name] = {}
        by_task[task_name][metric_name] = summary

    for task_name in sorted(by_task):
        print(f"\n {task_name.upper()}:")
        task_stats = by_task[task_name]
        for metric_name in sorted(task_stats):
            mean = task_stats[metric_name]["mean"]
            std = task_stats[metric_name]["std"]
            # Accuracy-like metrics are shown as percentages, the rest raw.
            if "accuracy" not in metric_name:
                print(f" {metric_name:25s}: {mean:.4f} ± {std:.4f}")
            else:
                print(f" {metric_name:25s}: {mean * 100:.1f}% ± {std * 100:.1f}%")
def main(argv: List[str] | None = None) -> None:
    """Command-line entry point: train, evaluate, and aggregate across seeds.

    For every requested seed this runs training (unless ``--skip-training``)
    and then evaluation (unless ``--skip-eval``). If at least one seed yields
    an evaluation report, the aggregated statistics are printed and written to
    ``<output-dir>/aggregated_results.json``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing ``main()`` callers are
            unaffected.
    """
    parser = argparse.ArgumentParser(description="Multi-seed training for LexiMind")
    parser.add_argument(
        "--seeds", nargs="+", type=int, default=[17, 42, 123], help="Random seeds to train with"
    )
    parser.add_argument(
        "--config", type=str, default="", help="Hydra config overrides (e.g., 'training=full')"
    )
    parser.add_argument(
        "--output-dir", type=Path, default=Path("outputs/multiseed"), help="Base output directory"
    )
    # The help strings below describe what the flags actually do: evaluation
    # still runs under --skip-training, and nothing is aggregated under
    # --skip-eval because aggregation only consumes evaluation reports.
    parser.add_argument(
        "--skip-training",
        action="store_true",
        help="Skip training; evaluate existing checkpoints and aggregate the results",
    )
    parser.add_argument(
        "--skip-eval",
        action="store_true",
        help="Skip evaluation and aggregation; only run training",
    )
    args = parser.parse_args(argv)
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Training phase
    if not args.skip_training:
        for seed in args.seeds:
            run_single_seed(seed, args.config, args.output_dir)

    # Evaluation phase: keep only seeds that produced a non-empty report.
    all_eval_results: Dict[int, Dict] = {}
    if not args.skip_eval:
        for seed in args.seeds:
            result = run_evaluation(seed, args.output_dir)
            if result:
                all_eval_results[seed] = result

    if not all_eval_results:
        print("\nNo evaluation results to aggregate.")
        return

    # Report the seeds that actually contributed, not every requested seed,
    # so the summary header cannot overstate the sample size.
    completed_seeds = list(all_eval_results)
    missing_seeds = [seed for seed in args.seeds if seed not in all_eval_results]
    if missing_seeds:
        print(f"\n WARNING: No evaluation results for seeds {missing_seeds}; excluded from aggregate")

    # Aggregate and save
    aggregated = aggregate_results(all_eval_results)
    print_summary(aggregated, completed_seeds)

    output_path = args.output_dir / "aggregated_results.json"
    with open(output_path, "w") as f:
        json.dump(
            {
                "seeds": args.seeds,
                "completed_seeds": completed_seeds,
                "per_seed": {str(k): v for k, v in all_eval_results.items()},
                "aggregated": aggregated,
            },
            f,
            indent=2,
        )
    print(f"\n Saved to: {output_path}")


if __name__ == "__main__":
    main()