"""Run staged or continuous GRPO experiments on the Countdown task.

Two entry modes share this file:

* normal mode: ``--config CONFIG --method METHOD --run-name NAME`` runs a
  full experiment via :func:`run_experiment`;
* worker mode: ``--stage-worker REQUEST_JSON`` trains a single stage in a
  child process (launched by :func:`train_stage_in_subprocess`).
"""
import argparse
import json
import os
import shutil
import subprocess
import sys

# Keep optional TensorFlow and advisory logs out of experiment output.
# These are set before transformers / the src modules are imported below;
# setdefault lets a value already exported by the caller take precedence.
os.environ.setdefault("USE_TF", "0")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
os.environ.setdefault("ACCELERATE_LOG_LEVEL", "error")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("VLLM_LOGGING_LEVEL", "WARNING")

from copy import deepcopy
from pathlib import Path

import pandas as pd
import yaml
from transformers import set_seed

from src.controllers import apply_decision, choose_decision, validate_decision
from src.data import build_countdown_dataset
from src.evaluate import evaluate_checkpoint
from src.logging_utils import aggregate_train_logs
from src.train_stage import train_continuous, train_stage

# Hyperparameters copied into every metrics record so each history.csv row
# carries the settings it was produced with.
LOGGED_HPARAMS = (
    "learning_rate",
    "beta",
    "temperature",
    "max_completion_length",
    "num_generations",
)

# The eval set is generated from seed + this offset, i.e. a different seed
# than the training set.
EVAL_SEED_OFFSET = 10_000

# File the stage worker writes and the parent reads back, inside stage_dir.
STAGE_RESULT_NAME = "stage_result.json"


def _build_dataset(config, size, seed):
    """Build a Countdown dataset of ``size`` examples with the task options from ``config``."""
    return build_countdown_dataset(
        size,
        seed,
        config.get("num_numbers", 4),
        config.get("disable_thinking", False),
    )


def _logged_hparams(config):
    """Return the LOGGED_HPARAMS entries of ``config`` (KeyError if one is missing)."""
    return {key: config[key] for key in LOGGED_HPARAMS}


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON."""
    path.write_text(json.dumps(payload, indent=2))


def _write_json_atomic(path, payload):
    """Write JSON to a temporary sibling file, then rename it over ``path``.

    Used for the baseline cache, which is shared between runs: a crash during
    the write must not leave a truncated metrics.json that later runs would
    load and fail to parse.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    tmp_path.write_text(json.dumps(payload, indent=2))
    os.replace(tmp_path, path)


def run_stage_worker(request_path):
    """Worker-mode entry point: train one stage described by a JSON request file.

    The request (written by :func:`train_stage_in_subprocess`) holds
    ``base_model_name``, ``previous_adapter``, ``config`` and ``stage_dir``.
    The resulting checkpoint and training metrics are written to
    ``stage_result.json`` next to the request file.
    """
    request_path = Path(request_path)
    request = json.loads(request_path.read_text())
    config = request["config"]
    # NOTE(review): set_seed() is not called in this process, and the parent's
    # seeding does not carry over into a fresh interpreter. Presumably
    # train_stage seeds the trainer from config — confirm.
    dataset = _build_dataset(config, config["train_size"], config["seed"])
    checkpoint, metrics = train_stage(
        request["base_model_name"],
        request.get("previous_adapter"),
        dataset,
        config,
        request["stage_dir"],
    )
    _write_json(
        request_path.with_name(STAGE_RESULT_NAME),
        {"checkpoint": checkpoint, "train_metrics": metrics},
    )


def train_stage_in_subprocess(base_model_name, previous_adapter, config, stage_dir):
    """Train one stage in a child Python process and return its result.

    The child re-runs this module with ``--stage-worker``, so ``src`` must be
    importable from the current working directory.
    NOTE(review): presumably a subprocess is used so model/GPU memory is
    released between stages — confirm.

    Args:
        base_model_name: Model identifier forwarded to ``train_stage``.
        previous_adapter: Checkpoint to continue from, or None.
        config: JSON-serialisable stage configuration.
        stage_dir: Directory for the request/result files and stage outputs.

    Returns:
        ``(checkpoint, train_metrics)`` as reported by the worker.

    Raises:
        subprocess.CalledProcessError: if the worker exits with a non-zero status.
    """
    stage_dir = Path(stage_dir).resolve()
    request_path = stage_dir / "stage_request.json"
    result_path = stage_dir / STAGE_RESULT_NAME
    _write_json(request_path, {
        "base_model_name": base_model_name,
        "previous_adapter": previous_adapter,
        "config": config,
        "stage_dir": str(stage_dir),
    })
    subprocess.run(
        [sys.executable, "-m", "src.run_experiment", "--stage-worker", str(request_path)],
        check=True,
    )
    result = json.loads(result_path.read_text())
    return result["checkpoint"], result["train_metrics"]


def _baseline_cache_dir(config, runs_dir):
    """Return the shared cache directory for the untrained model's evaluation.

    The name encodes every config value that affects the baseline eval data or
    sampling visible here, so runs with different settings do not share a
    cached result.
    """
    model_short = config["model_name"].split("/")[-1].lower().replace(".", "p")
    name = (
        f"_base_eval_{model_short}_n{config.get('num_numbers', 4)}_eval{config['eval_size']}_"
        f"k{config.get('eval_num_samples', 4)}_seed{config['seed']}"
    )
    if config.get("disable_thinking", False):
        # eval_data is built with disable_thinking, so it must be part of the key.
        # The suffix is only added when set, so existing cache names stay valid.
        name += "_nothink"
    return Path(runs_dir) / name


def _load_or_compute_baseline(config, eval_data, runs_dir):
    """Return cached baseline eval metrics, evaluating the base model on a cache miss."""
    baseline_dir = _baseline_cache_dir(config, runs_dir)
    metrics_path = baseline_dir / "metrics.json"
    if metrics_path.exists():
        return json.loads(metrics_path.read_text())
    baseline_dir.mkdir(parents=True, exist_ok=True)
    baseline_eval = evaluate_checkpoint(
        config["model_name"], None, eval_data, config, baseline_dir / "eval_samples.jsonl"
    )
    _write_json_atomic(metrics_path, baseline_eval)
    return baseline_eval


def _record_stage(run_dir, stage_dir, metrics, history):
    """Save one stage's metrics, append them to ``history`` and rewrite history.csv."""
    _write_json(stage_dir / "metrics.json", metrics)
    history.append(metrics)
    # Rewritten each stage so the CSV always reflects every stage completed so far.
    pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)
    print(json.dumps(metrics, indent=2))


def _run_continuous(method, config, run_name, run_dir, train_data, eval_data, history):
    """Train once without stage boundaries, then evaluate each returned checkpoint."""
    checkpoints, log_path, total_wall_clock = train_continuous(
        config["model_name"], train_data, config, run_dir
    )
    steps_per_stage = config["steps_per_stage"]
    for stage, checkpoint in enumerate(checkpoints):
        stage_dir = run_dir / f"stage_{stage}"
        stage_dir.mkdir(exist_ok=True)
        _write_json(stage_dir / "config.json", config)
        eval_metrics = evaluate_checkpoint(
            config["model_name"], checkpoint, eval_data, config, stage_dir / "eval_samples.jsonl"
        )
        step_min = stage * steps_per_stage
        step_max = step_min + steps_per_stage
        train_metrics = aggregate_train_logs(log_path, step_min=step_min, step_max=step_max)
        metrics = {
            "method": method,
            "run_name": run_name,
            "stage": stage,
            "global_train_steps": step_max,
            **train_metrics,
            **eval_metrics,
            # A single training job, so every stage reports the same total time.
            "wall_clock_seconds": total_wall_clock,
            **_logged_hparams(config),
        }
        _record_stage(run_dir, stage_dir, metrics, history)


def _run_staged(method, config, run_name, run_dir, eval_data, history, baseline_accuracy):
    """Train stage by stage, letting the controller adjust the config between stages."""
    stage_config = deepcopy(config)
    previous_adapter = best_adapter = None
    best_accuracy = baseline_accuracy
    num_stages = config["num_stages"]
    for stage in range(num_stages):
        stage_dir = run_dir / f"stage_{stage}"
        stage_dir.mkdir()
        _write_json(stage_dir / "config.json", stage_config)
        print(f"\n=== {run_name}: stage {stage} ===")
        checkpoint, train_metrics = train_stage_in_subprocess(
            config["model_name"], previous_adapter, stage_config, stage_dir
        )
        eval_metrics = evaluate_checkpoint(
            config["model_name"], checkpoint, eval_data, stage_config,
            stage_dir / "eval_samples.jsonl",
        )
        metrics = {
            "method": method,
            "run_name": run_name,
            "stage": stage,
            "global_train_steps": (stage + 1) * config["steps_per_stage"],
            **train_metrics,
            **eval_metrics,
            **_logged_hparams(stage_config),
        }
        _record_stage(run_dir, stage_dir, metrics, history)
        if metrics["eval_accuracy"] > best_accuracy:
            best_accuracy, best_adapter = metrics["eval_accuracy"], checkpoint
        if stage == num_stages - 1:
            break  # no decision is needed after the final stage
        decision = validate_decision(choose_decision(method, stage_config, metrics, history))
        _write_json(stage_dir / "decision.json", decision)
        if decision["early_stop"]:
            break
        # best_adapter is still None if no stage has beaten the baseline; a
        # rollback then continues from no adapter, as stage 0 does.
        previous_adapter = best_adapter if decision["rollback_to_best_checkpoint"] else checkpoint
        stage_config = apply_decision(stage_config, decision)


def run_experiment(method, config, run_name, runs_dir="runs", overwrite=False):
    """Run a full experiment and write its artefacts under ``runs_dir/run_name``.

    Evaluates the untrained model first. This result is cached in a shared
    ``runs_dir/_base_eval_*`` directory. The function then trains either
    continuously (``method == "continuous_grpo"``) or in controller-driven
    stages, and evaluates after each stage.

    Args:
        method: One of the ``--method`` choices accepted by :func:`main`.
        config: Experiment configuration dict (as loaded from YAML).
        run_name: Name of the run directory; also used as the W&B run group.
        runs_dir: Parent directory for runs and the baseline cache.
        overwrite: Delete an existing run directory instead of failing.

    Raises:
        FileExistsError: if the run directory exists and ``overwrite`` is False.
    """
    os.environ["WANDB_PROJECT"] = config.get("wandb_project", "llm-zero-lite")
    os.environ["WANDB_RUN_GROUP"] = run_name
    run_dir = Path(runs_dir) / run_name
    if run_dir.exists():
        if not overwrite:
            raise FileExistsError(f"run already exists: {run_dir}. Pass --overwrite to replace it.")
        shutil.rmtree(run_dir)
    run_dir.mkdir(parents=True)

    set_seed(config["seed"])
    train_data = _build_dataset(config, config["train_size"], config["seed"])
    eval_data = _build_dataset(config, config["eval_size"], config["seed"] + EVAL_SEED_OFFSET)
    _write_json(run_dir / "experiment_config.json", config)

    baseline_eval = _load_or_compute_baseline(config, eval_data, runs_dir)
    baseline_metrics = {
        "method": method,
        "run_name": run_name,
        "stage": -1,
        "global_train_steps": 0,
        **baseline_eval,
        **_logged_hparams(config),
    }
    _write_json(run_dir / "baseline_metrics.json", baseline_metrics)
    # Cached and uncached baseline evaluation must leave training with identical RNG state.
    set_seed(config["seed"])
    history = [baseline_metrics]
    baseline_accuracy = baseline_metrics["eval_accuracy"]

    if method == "continuous_grpo":
        _run_continuous(method, config, run_name, run_dir, train_data, eval_data, history)
        return
    _run_staged(method, config, run_name, run_dir, eval_data, history, baseline_accuracy)


def main():
    """Parse CLI arguments and dispatch to worker mode or a full experiment."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config")
    parser.add_argument("--method", choices=["continuous_grpo", "fixed_grpo", "rule_controller", "llm_controller"])
    parser.add_argument("--run-name")
    parser.add_argument("--runs-dir", default="runs")
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument(
        "--stage-worker",
        help="internal: path to a stage_request.json written by the parent process",
    )
    args = parser.parse_args()
    if args.stage_worker:
        run_stage_worker(args.stage_worker)
        return
    if not args.config or not args.method or not args.run_name:
        parser.error("--config, --method, and --run-name are required")
    with open(args.config, encoding="utf-8") as file:
        config = yaml.safe_load(file)
    run_experiment(args.method, config, args.run_name, args.runs_dir, args.overwrite)


if __name__ == "__main__":
    main()