llm-zero-lite-experiments / src /run_experiment.py
kishan51's picture
Add files using upload-large-folder tool
4f99f73 verified
Raw
History Blame Contribute Delete
8.86 kB
import argparse
import json
import os
import shutil
import subprocess
import sys
# Keep optional TensorFlow and advisory logs out of experiment output.
# These are applied before transformers is imported further down; setdefault means a value
# already exported in the caller's environment always takes precedence over these defaults.
for _env_name, _env_value in (
    ("USE_TF", "0"),
    ("TF_CPP_MIN_LOG_LEVEL", "3"),
    ("TF_ENABLE_ONEDNN_OPTS", "0"),
    ("TRANSFORMERS_VERBOSITY", "error"),
    ("ACCELERATE_LOG_LEVEL", "error"),
    ("TOKENIZERS_PARALLELISM", "false"),
    ("VLLM_LOGGING_LEVEL", "WARNING"),
):
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value
from copy import deepcopy
from pathlib import Path
import pandas as pd
import yaml
from transformers import set_seed
from src.controllers import apply_decision, choose_decision, validate_decision
from src.data import build_countdown_dataset
from src.evaluate import evaluate_checkpoint
from src.logging_utils import aggregate_train_logs
from src.train_stage import train_continuous, train_stage
def run_stage_worker(request_path):
    """Train a single stage inside a dedicated worker process.

    This is what ``--stage-worker`` dispatches to. It reads the JSON request written by
    ``train_stage_in_subprocess``, rebuilds the training dataset from the stage config,
    runs ``train_stage`` and writes ``stage_result.json`` next to the request file so the
    parent process can read back the checkpoint and training metrics.

    Args:
        request_path: path to ``stage_request.json`` with keys ``base_model_name``,
            ``config``, ``stage_dir`` and (optionally) ``previous_adapter``.
    """
    request_path = Path(request_path)
    request = json.loads(request_path.read_text())
    config = request["config"]
    dataset = build_countdown_dataset(
        config["train_size"],
        config["seed"],
        config.get("num_numbers", 4),
        config.get("disable_thinking", False),
    )
    # This is a fresh interpreter (launched via ``python -m src.run_experiment``), so the
    # parent's set_seed() calls never reach it. Reseed right before training, mirroring
    # run_experiment, which reseeds just before its own training, so every stage worker
    # starts training from the same seeded RNG state.
    set_seed(config["seed"])
    checkpoint, metrics = train_stage(
        request["base_model_name"],
        request.get("previous_adapter"),
        dataset,
        config,
        request["stage_dir"],
    )
    request_path.with_name("stage_result.json").write_text(json.dumps({
        "checkpoint": checkpoint,
        "train_metrics": metrics,
    }, indent=2))
def train_stage_in_subprocess(base_model_name, previous_adapter, config, stage_dir):
    """Run one training stage in a separate Python process and return its result.

    The request is handed to the child as ``stage_request.json`` inside ``stage_dir``. The
    child (``python -m src.run_experiment --stage-worker``) answers with
    ``stage_result.json`` in the same directory. A non-zero exit status from the child
    raises ``subprocess.CalledProcessError``.

    Returns:
        ``(checkpoint, train_metrics)`` as reported by the worker.
    """
    stage_dir = Path(stage_dir).resolve()
    payload = {
        "base_model_name": base_model_name,
        "previous_adapter": previous_adapter,
        "config": config,
        "stage_dir": str(stage_dir),
    }
    request_file = stage_dir / "stage_request.json"
    request_file.write_text(json.dumps(payload, indent=2))

    command = [sys.executable, "-m", "src.run_experiment", "--stage-worker", str(request_file)]
    subprocess.run(command, check=True)

    reply = json.loads((stage_dir / "stage_result.json").read_text())
    return reply["checkpoint"], reply["train_metrics"]
# Hyperparameters copied into every metrics record so each history.csv row is self-describing.
_HPARAM_KEYS = ("learning_rate", "beta", "temperature", "max_completion_length", "num_generations")


def _hparams(config):
    """Return the logged hyperparameter subset of ``config`` (KeyError if one is missing)."""
    return {key: config[key] for key in _HPARAM_KEYS}


def _countdown_split(config, size, seed):
    """Build a Countdown dataset of ``size`` examples using the config's task options."""
    return build_countdown_dataset(
        size,
        seed,
        config.get("num_numbers", 4),
        config.get("disable_thinking", False),
    )


def _baseline_cache_dir(runs_dir, config):
    """Return the cache directory shared by runs whose base-model evaluation is identical."""
    model_short = config["model_name"].split("/")[-1].lower().replace(".", "p")
    # num_numbers defaults to 4 exactly as in dataset construction, so configs that omit it
    # no longer raise KeyError here.
    name = (
        f"_base_eval_{model_short}_n{config.get('num_numbers', 4)}_eval{config['eval_size']}_"
        f"k{config.get('eval_num_samples', 4)}_seed{config['seed']}"
    )
    # The eval dataset is built with disable_thinking, so those baselines must not share a
    # cache entry with default ones. The suffix is only added when the flag is set, so
    # existing caches for default configs keep their names.
    if config.get("disable_thinking", False):
        name += "_nothink"
    return Path(runs_dir) / name


def _load_or_evaluate_baseline(runs_dir, config, eval_data):
    """Return cached base-model eval metrics, evaluating the base model (no adapter) on a miss."""
    baseline_dir = _baseline_cache_dir(runs_dir, config)
    metrics_path = baseline_dir / "metrics.json"
    if metrics_path.exists():
        return json.loads(metrics_path.read_text())
    baseline_dir.mkdir(parents=True, exist_ok=True)
    baseline_eval = evaluate_checkpoint(
        config["model_name"], None, eval_data, config, baseline_dir / "eval_samples.jsonl"
    )
    metrics_path.write_text(json.dumps(baseline_eval, indent=2))
    return baseline_eval


def _record_metrics(run_dir, stage_dir, history, metrics):
    """Write a stage's metrics.json, append it to ``history``, rewrite history.csv and echo it."""
    (stage_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
    history.append(metrics)
    pd.DataFrame(history).to_csv(run_dir / "history.csv", index=False)
    print(json.dumps(metrics, indent=2))


def _run_continuous(method, config, run_name, run_dir, train_data, eval_data, history):
    """Make one train_continuous call, then evaluate each checkpoint it returns as a stage."""
    checkpoints, log_path, total_wall_clock = train_continuous(
        config["model_name"], train_data, config, run_dir
    )
    for stage, checkpoint in enumerate(checkpoints):
        stage_dir = run_dir / f"stage_{stage}"
        stage_dir.mkdir(exist_ok=True)
        (stage_dir / "config.json").write_text(json.dumps(config, indent=2))
        eval_metrics = evaluate_checkpoint(
            config["model_name"], checkpoint, eval_data, config, stage_dir / "eval_samples.jsonl"
        )
        step_min = stage * config["steps_per_stage"]
        step_max = (stage + 1) * config["steps_per_stage"]
        # All stages share one training log; aggregate only this stage's step window.
        train_metrics = aggregate_train_logs(log_path, step_min=step_min, step_max=step_max)
        metrics = {
            "method": method,
            "run_name": run_name,
            "stage": stage,
            "global_train_steps": step_max,
            **train_metrics,
            **eval_metrics,
            # train_continuous reports a single total, so every stage record repeats it.
            "wall_clock_seconds": total_wall_clock,
            **_hparams(config),
        }
        _record_metrics(run_dir, stage_dir, history, metrics)


def _run_staged(method, config, run_name, run_dir, eval_data, history, best_accuracy):
    """Train stage by stage in worker processes, letting the controller adjust config in between."""
    stage_config = deepcopy(config)
    previous_adapter = best_adapter = None
    num_stages = config["num_stages"]
    for stage in range(num_stages):
        stage_dir = run_dir / f"stage_{stage}"
        stage_dir.mkdir()
        (stage_dir / "config.json").write_text(json.dumps(stage_config, indent=2))
        print(f"\n=== {run_name}: stage {stage} ===")
        checkpoint, train_metrics = train_stage_in_subprocess(
            config["model_name"], previous_adapter, stage_config, stage_dir
        )
        eval_metrics = evaluate_checkpoint(
            config["model_name"], checkpoint, eval_data, stage_config, stage_dir / "eval_samples.jsonl"
        )
        metrics = {
            "method": method,
            "run_name": run_name,
            "stage": stage,
            "global_train_steps": (stage + 1) * config["steps_per_stage"],
            **train_metrics,
            **eval_metrics,
            **_hparams(stage_config),
        }
        _record_metrics(run_dir, stage_dir, history, metrics)
        if metrics["eval_accuracy"] > best_accuracy:
            best_accuracy, best_adapter = metrics["eval_accuracy"], checkpoint
        if stage == num_stages - 1:
            break  # no controller decision is needed after the final stage
        decision = validate_decision(choose_decision(method, stage_config, metrics, history))
        (stage_dir / "decision.json").write_text(json.dumps(decision, indent=2))
        if decision["early_stop"]:
            break
        # Continue from the latest adapter, or roll back to the best one seen so far.
        # best_adapter is still None if no stage has beaten the baseline yet.
        previous_adapter = best_adapter if decision["rollback_to_best_checkpoint"] else checkpoint
        stage_config = apply_decision(stage_config, decision)


def run_experiment(method, config, run_name, runs_dir="runs", overwrite=False):
    """Run one experiment end to end and write its artifacts under ``runs_dir/run_name``.

    Steps:
      1. (Re)create the run directory.
      2. Build the train/eval Countdown datasets.
      3. Evaluate the base model with no adapter. The result is cached across runs in a
         ``_base_eval_*`` directory under ``runs_dir``.
      4. Train. ``method == "continuous_grpo"`` uses one continuous run. Any other method
         trains stage by stage, with the controller adjusting the config between stages.

    Each stage's metrics go to ``stage_<i>/metrics.json`` and are accumulated in
    ``history.csv``.

    Args:
        method: training method name. ``"continuous_grpo"`` selects the single-run path.
            Other values are passed to ``choose_decision`` between stages.
        config: experiment config mapping (seeds, dataset sizes, model name, stage counts
            and the logged hyperparameters).
        run_name: run directory name; also exported as ``WANDB_RUN_GROUP``.
        runs_dir: parent directory for runs and for the shared baseline cache.
        overwrite: delete an existing run directory instead of failing.

    Raises:
        FileExistsError: if the run directory exists and ``overwrite`` is false.
    """
    os.environ["WANDB_PROJECT"] = config.get("wandb_project", "llm-zero-lite")
    os.environ["WANDB_RUN_GROUP"] = run_name
    run_dir = Path(runs_dir) / run_name
    if run_dir.exists():
        if not overwrite:
            raise FileExistsError(f"run already exists: {run_dir}. Pass --overwrite to replace it.")
        shutil.rmtree(run_dir)
    run_dir.mkdir(parents=True)
    set_seed(config["seed"])
    train_data = _countdown_split(config, config["train_size"], config["seed"])
    # The eval split uses a seed offset so it is generated separately from the training split.
    eval_data = _countdown_split(config, config["eval_size"], config["seed"] + 10_000)
    (run_dir / "experiment_config.json").write_text(json.dumps(config, indent=2))

    baseline_eval = _load_or_evaluate_baseline(runs_dir, config, eval_data)
    baseline_metrics = {
        "method": method,
        "run_name": run_name,
        "stage": -1,
        "global_train_steps": 0,
        **baseline_eval,
        **_hparams(config),
    }
    (run_dir / "baseline_metrics.json").write_text(json.dumps(baseline_metrics, indent=2))
    # Cached and uncached baseline evaluation must leave training with identical RNG state.
    set_seed(config["seed"])
    history = [baseline_metrics]

    if method == "continuous_grpo":
        _run_continuous(method, config, run_name, run_dir, train_data, eval_data, history)
        return
    _run_staged(
        method, config, run_name, run_dir, eval_data, history, baseline_metrics["eval_accuracy"]
    )
def main():
    """Command-line entry point.

    Two modes:
      * ``--stage-worker REQUEST``: internal mode used by ``train_stage_in_subprocess``;
        trains a single stage from the request file and exits.
      * otherwise ``--config``, ``--method`` and ``--run-name`` are required and a full
        experiment is run via ``run_experiment``.
    """
    parser = argparse.ArgumentParser(description="Run a Countdown GRPO training experiment.")
    parser.add_argument("--config", help="path to the YAML experiment config")
    parser.add_argument(
        "--method",
        choices=["continuous_grpo", "fixed_grpo", "rule_controller", "llm_controller"],
    )
    parser.add_argument("--run-name", help="run directory name under --runs-dir")
    parser.add_argument("--runs-dir", default="runs")
    parser.add_argument(
        "--overwrite", action="store_true", help="replace an existing run directory"
    )
    parser.add_argument(
        "--stage-worker",
        help="internal: path to a stage_request.json written by train_stage_in_subprocess",
    )
    args = parser.parse_args()
    if args.stage_worker:
        run_stage_worker(args.stage_worker)
        return
    if not args.config or not args.method or not args.run_name:
        parser.error("--config, --method, and --run-name are required")
    with open(args.config, encoding="utf-8") as file:
        config = yaml.safe_load(file)
    # safe_load yields None for an empty file (or a scalar/list for other YAML documents).
    # Fail with a clear CLI error instead of an obscure TypeError inside run_experiment.
    if not isinstance(config, dict):
        parser.error(f"{args.config} must contain a YAML mapping of config keys")
    run_experiment(args.method, config, args.run_name, args.runs_dir, args.overwrite)
# Also the entry point of the stage-worker subprocesses that train_stage_in_subprocess
# launches via ``python -m src.run_experiment --stage-worker <request>``.
if __name__ == "__main__":
    main()