curriculum-cot-code / small_model_20empty /run_small_baseline_pipeline.py
Avra98's picture
Initial code dump (rebuttal-ready snapshot)
76de008 verified
from __future__ import annotations
import argparse
import json
import os
import re
import subprocess
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
# Directory holding this script, and the directory one level above it.
CURRENT_DIR = Path(__file__).resolve().parent
PARENT_DIR = CURRENT_DIR.parent
# Put the parent directory at the front of sys.path before the import below;
# presumably that is where checkpoint_utils lives -- confirm against the repo layout.
if str(PARENT_DIR) not in sys.path:
    sys.path.insert(0, str(PARENT_DIR))
from checkpoint_utils import final_checkpoint_root, normalize_to_final_checkpoint_root
# Defaults used by parse_args() for --checkpoint_root, --cache_dir and --model_name.
DEFAULT_CHECKPOINT_ROOT = Path(final_checkpoint_root("small_model_20empty", "baseline"))
DEFAULT_CACHE_DIR = Path("/home/ubuntu/curriculum-CoT/.hf_cache")
DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
# Defaults used by parse_args() for the Weights & Biases group and project names.
DEFAULT_WANDB_GROUP = "small_model_20empty_baseline_pipeline"
DEFAULT_SFT_PROJECT = "sudoku-small-20empty-baseline-sft"
DEFAULT_GRPO_PROJECT = "sudoku-small-20empty-baseline-grpo"
# Training entry points launched as subprocesses by the command builders.
SFT_SCRIPT = PARENT_DIR / "multi_output_cell_policy" / "sft_multi_output_train.py"
GRPO_SCRIPT = PARENT_DIR / "multi_output_cell_policy" / "grpo_multi_output_train.py"
# File name whose presence inside a stage directory marks that stage as finished.
STAGE_COMPLETE_MARKER = "_stage_complete.json"
@dataclass
class Artifact:
    """A checkpoint or adapter directory discovered on disk for one stage/phase."""

    # Filesystem path of the checkpoint (or of the stage directory itself).
    path: str
    # Curriculum stage index; the latest_* helpers set -1 and
    # discover_latest_artifact() overwrites it with the real stage.
    stage: int
    # Training phase label; "sft" or "grpo" in this file.
    phase: str
    # Step number parsed from the directory name; a root-level GRPO adapter
    # uses the sentinel 10**9.
    step: int
    # st_mtime of the artifact directory, used as the primary recency key.
    mtime: float
    # Stage directory the artifact was found in.
    source_dir: str
def parse_args() -> argparse.Namespace:
    """Parse the pipeline's command-line options from ``sys.argv``.

    Options are declared in a single ordered table and registered in that
    order, so ``--help`` lists them in the same sequence as the table.

    Returns:
        The populated ``argparse.Namespace``.
    """

    def text(default: str) -> Dict[str, Any]:
        return {"type": str, "default": default}

    def whole(default: int) -> Dict[str, Any]:
        return {"type": int, "default": default}

    def real(default: float) -> Dict[str, Any]:
        return {"type": float, "default": default}

    def switch() -> Dict[str, Any]:
        return {"action": "store_true"}

    option_table = [
        # Paths, model and run identity.
        ("--python_executable", text(sys.executable)),
        ("--checkpoint_root", text(str(DEFAULT_CHECKPOINT_ROOT))),
        ("--output_root", text("")),
        ("--run_tag", text("")),
        ("--train_jsonl", text("")),
        ("--cache_dir", text(str(DEFAULT_CACHE_DIR))),
        ("--model_name", text(DEFAULT_MODEL_NAME)),
        ("--seed", whole(0)),
        # Curriculum range.
        ("--total_empties_hint", whole(20)),
        ("--min_stage", whole(1)),
        ("--max_stage", whole(4)),
        # Device placement and warm-start adapters.
        ("--sft_gpu_id", whole(0)),
        ("--grpo_gpu_id", whole(1)),
        ("--stage1_init_adapter_dir", text("")),
        ("--bootstrap_adapter_dir", text("")),
        ("--distributed_gpu_ids", text("")),
        ("--sft_num_processes", whole(1)),
        ("--grpo_num_processes", whole(1)),
        # Weights & Biases.
        ("--wandb_mode", text("online")),
        ("--use_wandb", switch()),
        ("--wandb_entity", text("")),
        ("--wandb_group", text(DEFAULT_WANDB_GROUP)),
        ("--wandb_sft_project", text(DEFAULT_SFT_PROJECT)),
        ("--wandb_grpo_project", text(DEFAULT_GRPO_PROJECT)),
        # SFT hyper-parameters.
        ("--sft_num_epochs", real(1.0)),
        ("--sft_learning_rate_stage1", real(2e-4)),
        ("--sft_learning_rate_later", real(5e-5)),
        ("--sft_gradient_accumulation_steps", whole(8)),
        ("--sft_enable_gradient_checkpointing", switch()),
        ("--sft_logging_steps", whole(10)),
        ("--sft_eval_steps", whole(100)),
        ("--sft_save_steps", whole(100)),
        ("--sft_eval_rows", whole(20)),
        ("--sft_max_completion_length", whole(24)),
        # GRPO hyper-parameters.
        ("--grpo_num_train_epochs", real(0.5)),
        ("--grpo_learning_rate", real(1e-6)),
        ("--grpo_per_device_train_batch_size", whole(2)),
        ("--grpo_gradient_accumulation_steps", whole(4)),
        ("--grpo_enable_gradient_checkpointing", switch()),
        ("--grpo_logging_steps", whole(5)),
        ("--grpo_eval_steps", whole(25)),
        ("--grpo_save_steps", whole(25)),
        ("--grpo_eval_rows", whole(20)),
        ("--grpo_num_generations", whole(2)),
        ("--grpo_max_prompt_length", whole(1024)),
        ("--grpo_max_completion_length", whole(24)),
        ("--grpo_beta", real(0.0)),
        # Budgets and debugging aids.
        ("--phase_max_wall_clock_seconds", whole(21600)),
        ("--limit_train_rows", whole(0)),
        ("--sft_stage_max_steps", text("")),
        ("--grpo_stage_max_steps", text("")),
        ("--dry_run", switch()),
    ]

    parser = argparse.ArgumentParser()
    for option_name, settings in option_table:
        parser.add_argument(option_name, **settings)
    return parser.parse_args()
def stage_dir_pattern(stage: int, phase: str, empties: int) -> str:
    """Return the glob pattern for a stage/phase output directory name.

    The stage number appears twice (zero-padded to two digits, then plain),
    and the trailing ``*`` lets the pattern match suffixed directory names.
    """
    template = "stage{0:02d}_{1}_i{0}_{2}empty*"
    return template.format(stage, phase, empties)
def extract_numeric_suffix(name: str, prefix: str) -> Optional[int]:
    """Return the integer following ``prefix`` when ``name`` is exactly prefix+digits.

    Returns ``None`` if ``name`` does not start with ``prefix`` or if anything
    other than one or more digits follows it.
    """
    # The prefix is escaped so it is matched literally, not as a regex.
    pattern = re.compile(re.escape(prefix) + r"(\d+)")
    found = pattern.fullmatch(name)
    if found is None:
        return None
    return int(found.group(1))
def stage_complete_path(stage_dir: Path) -> Path:
    """Return where the completion marker file lives inside ``stage_dir``."""
    return stage_dir.joinpath(STAGE_COMPLETE_MARKER)
def is_stage_complete(stage_dir: Path) -> bool:
    """Report whether ``stage_dir`` contains its completion marker as a regular file."""
    marker = stage_complete_path(stage_dir)
    return marker.is_file()
def output_root_has_stage_artifacts(path: Path) -> bool:
    """Tell whether ``path`` already looks like a pipeline output root.

    A path qualifies if it exists and either holds ``pipeline_state.json`` or
    has at least one direct child named ``stageNN_*`` (two digits).
    """
    if path.exists():
        if path.joinpath("pipeline_state.json").exists():
            return True
        # A single match is enough; stop at the first one.
        for _match in path.glob("stage[0-9][0-9]_*"):
            return True
    return False
def latest_sft_checkpoint(stage_dir: Path) -> Optional[Artifact]:
    """Pick the most recent ``checkpoint-step-<N>`` directory inside ``stage_dir``.

    Candidates are ranked by ``(mtime, step)``; on an exact tie the entry seen
    first during directory iteration wins. The returned artifact has
    ``stage=-1``. Returns ``None`` when no matching sub-directory exists.
    """
    candidates: List[Artifact] = []
    for entry in stage_dir.iterdir():
        if entry.is_dir():
            step_number = extract_numeric_suffix(entry.name, "checkpoint-step-")
            if step_number is not None:
                candidates.append(
                    Artifact(
                        path=str(entry),
                        stage=-1,
                        phase="sft",
                        step=step_number,
                        mtime=entry.stat().st_mtime,
                        source_dir=str(stage_dir),
                    )
                )
    if not candidates:
        return None
    # max() returns the first maximal element, matching a strict ">" scan.
    return max(candidates, key=lambda item: (item.mtime, item.step))
def latest_grpo_artifact(stage_dir: Path) -> Optional[Artifact]:
    """Pick the most recent GRPO adapter in ``stage_dir``.

    Two kinds of candidate are considered: the stage directory itself when it
    holds ``adapter_model.safetensors`` (recorded with the sentinel step
    ``10**9``), and every ``checkpoint-<N>`` sub-directory that holds that
    file. Candidates are ranked by ``(mtime, step)``; on an exact tie the
    earlier candidate wins, and the root adapter is always considered first.
    Returns ``None`` when nothing qualifies.
    """
    adapter_filename = "adapter_model.safetensors"
    candidates: List[Artifact] = []

    if (stage_dir / adapter_filename).exists():
        candidates.append(
            Artifact(
                path=str(stage_dir),
                stage=-1,
                phase="grpo",
                step=10**9,
                mtime=stage_dir.stat().st_mtime,
                source_dir=str(stage_dir),
            )
        )

    for entry in stage_dir.iterdir():
        if not entry.is_dir():
            continue
        step_number = extract_numeric_suffix(entry.name, "checkpoint-")
        # Skip non-checkpoint directories and checkpoints without adapter weights.
        if step_number is None or not (entry / adapter_filename).exists():
            continue
        candidates.append(
            Artifact(
                path=str(entry),
                stage=-1,
                phase="grpo",
                step=step_number,
                mtime=entry.stat().st_mtime,
                source_dir=str(stage_dir),
            )
        )

    if not candidates:
        return None
    return max(candidates, key=lambda item: (item.mtime, item.step))
def discover_latest_artifact(
    checkpoint_root: Path,
    *,
    stage: int,
    phase: str,
    empties: int,
    require_complete: bool = True,
) -> Optional[Artifact]:
    """Search ``checkpoint_root`` recursively for the newest artifact of a stage/phase.

    Args:
        checkpoint_root: Directory tree to search.
        stage: Stage index; used in the directory pattern and stamped on the result.
        phase: ``"sft"`` selects SFT checkpoints; any other value selects GRPO artifacts.
        empties: Empty-cell count embedded in the directory name.
        require_complete: When true, stage directories lacking the completion
            marker are ignored.

    Returns:
        The artifact with the greatest ``(mtime, step)`` across all matching
        stage directories, or ``None`` if none was found.
    """
    finder = latest_sft_checkpoint if phase == "sft" else latest_grpo_artifact
    name_pattern = stage_dir_pattern(stage, phase, empties)

    winner: Optional[Artifact] = None
    for candidate_dir in checkpoint_root.rglob(name_pattern):
        if not candidate_dir.is_dir():
            continue
        if require_complete and not is_stage_complete(candidate_dir):
            continue
        found = finder(candidate_dir)
        if found is None:
            continue
        # The per-directory finders do not know the stage; fill it in here.
        found.stage = stage
        found.phase = phase
        if winner is None or (winner.mtime, winner.step) < (found.mtime, found.step):
            winner = found
    return winner
def choose_output_root(args: argparse.Namespace, checkpoint_root: Path) -> Path:
    """Decide which directory this pipeline run writes into.

    * With ``--output_root``: the normalized, resolved root is used as is,
      unless it already contains stage artifacts, in which case a run-tag
      sub-directory is appended.
    * Without it: ``checkpoint_root/<run_tag>/baseline_pipeline_<E>empty_<S>stage_small``.

    The run tag is ``--run_tag`` when non-blank, else a ``%Y%m%d_%H%M%S`` timestamp.
    """

    def resolve_run_tag() -> str:
        return str(args.run_tag).strip() or time.strftime("%Y%m%d_%H%M%S")

    if not str(args.output_root).strip():
        leaf = f"baseline_pipeline_{args.total_empties_hint}empty_{args.max_stage}stage_small"
        return checkpoint_root / resolve_run_tag() / leaf

    normalized = normalize_to_final_checkpoint_root(args.output_root, "small_model_20empty", "baseline")
    requested_root = Path(normalized).resolve()
    if not output_root_has_stage_artifacts(requested_root):
        return requested_root
    return requested_root / resolve_run_tag()
def default_train_jsonl(args: argparse.Namespace) -> Path:
    """Return the resolved training JSONL path.

    A non-blank ``--train_jsonl`` wins; otherwise the path is
    ``PARENT_DIR/data/sudoku_t3_<E>empty_value_qwen_text.jsonl`` where ``<E>``
    is ``--total_empties_hint``.
    """
    user_supplied = str(args.train_jsonl).strip()
    if user_supplied:
        return Path(args.train_jsonl).resolve()
    file_name = "sudoku_t3_{}empty_value_qwen_text.jsonl".format(int(args.total_empties_hint))
    return (PARENT_DIR / "data" / file_name).resolve()
def phase_output_dir(output_root: Path, *, stage: int, phase: str, empties: int) -> Path:
    """Return the directory under ``output_root`` for one stage/phase training run."""
    dir_name = "stage{0:02d}_{1}_i{0}_{2}empty".format(stage, phase, empties)
    return output_root.joinpath(dir_name)
def run_command(command: List[str], *, env: Dict[str, str], dry_run: bool) -> None:
    """Echo ``command`` and, unless ``dry_run`` is set, execute it.

    Args:
        command: Argument vector passed to ``subprocess.run`` (no shell).
        env: Complete environment for the child process.
        dry_run: When true, only print the command.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero status.
    """
    print("")
    print("Running command:")
    # list2cmdline quotes each argument independently, so one call over the
    # whole vector yields the same text as joining per-argument results.
    print(subprocess.list2cmdline(command))
    if dry_run:
        print("Dry run enabled; command not executed.")
        return
    # The child writes straight to the inherited stdout. Flush our buffered
    # banner first so it cannot appear after the child's output when stdout
    # is a pipe or a file.
    sys.stdout.flush()
    subprocess.run(command, env=env, check=True)
def parse_stage_int_map(raw: str) -> Dict[int, int]:
    """Parse a ``"stage:value,stage:value"`` string into an ``int -> int`` mapping.

    Whitespace around entries and around each part is ignored, and blank
    entries (e.g. from a trailing comma) are skipped. ``None`` or an empty
    string yields an empty mapping. If a stage repeats, the last entry wins.

    Raises:
        ValueError: If an entry lacks the ``:`` separator or either part is
            not an integer. The message names the offending entry.
    """
    mapping: Dict[int, int] = {}
    for part in str(raw or "").split(","):
        item = part.strip()
        if not item:
            continue
        problem = (
            f"Invalid stage mapping entry {item!r} in {raw!r}; "
            "expected 'stage:value' with integer parts, e.g. '1:100,2:200'."
        )
        stage_text, separator, value_text = item.partition(":")
        if not separator:
            raise ValueError(problem)
        try:
            stage_key = int(stage_text.strip())
            stage_value = int(value_text.strip())
        except ValueError as exc:
            raise ValueError(problem) from exc
        mapping[stage_key] = stage_value
    return mapping
def resolve_stage_value(mapping: Dict[int, int], stage: int) -> int:
    """Look up the per-stage override for ``stage``, returning 0 when none is set."""
    stage_key = int(stage)
    if stage_key in mapping:
        return int(mapping[stage_key])
    return 0
def make_env(*, gpu_id: int, wandb_mode: str, gpu_ids: str, num_processes: int) -> Dict[str, str]:
    """Build the child-process environment for one training phase.

    Starts from a copy of the current environment and sets
    ``CUDA_DEVICE_ORDER``, ``WANDB__SERVICE_WAIT`` and ``WANDB_MODE``.
    ``CUDA_VISIBLE_DEVICES`` is chosen as follows:

    * multi-process with ``gpu_ids`` given: the first ``num_processes`` ids;
    * multi-process without ``gpu_ids``: left as inherited;
    * single-process: the first id in ``gpu_ids`` if any, otherwise ``gpu_id``.
    """
    process_count = int(num_processes)
    gpu_list: List[str] = []
    for token in str(gpu_ids or "").split(","):
        cleaned = token.strip()
        if cleaned:
            gpu_list.append(cleaned)

    env = dict(os.environ)
    env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if process_count > 1:
        if gpu_list:
            env["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_list[:process_count])
    elif gpu_list:
        env["CUDA_VISIBLE_DEVICES"] = str(gpu_list[0])
    else:
        env["CUDA_VISIBLE_DEVICES"] = str(int(gpu_id))
    env["WANDB__SERVICE_WAIT"] = "300"
    env["WANDB_MODE"] = str(wandb_mode)
    return env
def build_sft_command(
    args: argparse.Namespace,
    *,
    train_jsonl: Path,
    output_dir: Path,
    stage: int,
    init_adapter_dir: Optional[str],
    stage_max_steps: int,
) -> List[str]:
    """Assemble the argument vector that launches the SFT training script.

    Args:
        args: Parsed pipeline options.
        train_jsonl: Training data file passed as ``--train_jsonl``.
        output_dir: Directory passed as ``--output_dir``.
        stage: Curriculum stage, passed as ``--stage_i``; stage <= 1 uses
            ``sft_learning_rate_stage1``, later stages ``sft_learning_rate_later``.
        init_adapter_dir: Adapter to start from; omitted when falsy.
        stage_max_steps: Passed as ``--max_steps`` only when positive.

    Returns:
        The command as a list of strings. With more than one process the
        script is launched through ``torch.distributed.run`` and ``--gpu_id``
        is forced to 0; otherwise it runs under ``python -u`` with
        ``--gpu_id`` set to ``sft_gpu_id``.
    """
    num_processes = max(1, int(args.sft_num_processes))
    if num_processes > 1:
        command = [
            args.python_executable,
            "-m",
            "torch.distributed.run",
            "--standalone",
            "--nproc_per_node",
            str(num_processes),
            str(SFT_SCRIPT),
        ]
    else:
        command = [args.python_executable, "-u", str(SFT_SCRIPT)]

    learning_rate = args.sft_learning_rate_stage1 if stage <= 1 else args.sft_learning_rate_later
    command.extend(
        [
            "--model_name",
            args.model_name,
            "--train_jsonl",
            str(train_jsonl),
            "--output_dir",
            str(output_dir),
            "--cache_dir",
            args.cache_dir,
            "--seed",
            str(int(args.seed)),
            "--gpu_id",
            str(0 if num_processes > 1 else int(args.sft_gpu_id)),
            "--stage_i",
            str(int(stage)),
            "--total_empties_hint",
            str(int(args.total_empties_hint)),
            "--num_epochs",
            str(float(args.sft_num_epochs)),
            "--learning_rate",
            str(float(learning_rate)),
            "--gradient_accumulation_steps",
            str(int(args.sft_gradient_accumulation_steps)),
        ]
    )
    # Append the optional switch explicitly. Filtering "" out of the finished
    # list would also delete legitimately empty option values (for example an
    # empty --cache_dir), leaving a flag that swallows the next token.
    if args.sft_enable_gradient_checkpointing:
        command.append("--enable_gradient_checkpointing")
    command.extend(
        [
            "--logging_steps",
            str(int(args.sft_logging_steps)),
            "--eval_steps",
            str(int(args.sft_eval_steps)),
            "--save_steps",
            str(int(args.sft_save_steps)),
            "--eval_rows",
            str(int(args.sft_eval_rows)),
            "--max_completion_length",
            str(int(args.sft_max_completion_length)),
            "--max_wall_clock_seconds",
            str(int(args.phase_max_wall_clock_seconds)),
        ]
    )

    if int(args.limit_train_rows) > 0:
        command.extend(["--limit_train_rows", str(int(args.limit_train_rows))])
    if int(stage_max_steps) > 0:
        command.extend(["--max_steps", str(int(stage_max_steps))])

    # NOTE(review): the original indentation was lost; this assumes the three
    # W&B statements are independent, i.e. project/run-name/mode are forwarded
    # even without --use_wandb. Confirm against the original layout.
    if args.use_wandb:
        command.extend(["--use_wandb"])
    if str(args.wandb_entity).strip():
        command.extend(["--wandb_entity", args.wandb_entity])
    command.extend(
        [
            "--wandb_project",
            args.wandb_sft_project,
            "--wandb_run_name",
            f"small_baseline_stage{stage:02d}_sft_i{stage}_{args.total_empties_hint}empty",
            "--wandb_mode",
            args.wandb_mode,
        ]
    )
    if init_adapter_dir:
        command.extend(["--init_adapter_dir", str(init_adapter_dir)])
    return command
def build_grpo_command(
    args: argparse.Namespace,
    *,
    train_jsonl: Path,
    output_dir: Path,
    stage: int,
    init_adapter_dir: str,
    stage_max_steps: int,
) -> List[str]:
    """Assemble the argument vector that launches the GRPO training script.

    Args:
        args: Parsed pipeline options.
        train_jsonl: Training data file passed as ``--train_jsonl``.
        output_dir: Directory passed as ``--output_dir``.
        stage: Curriculum stage, passed as ``--stage_i``.
        init_adapter_dir: Adapter to start from; always passed as ``--init_adapter_dir``.
        stage_max_steps: Passed as ``--max_steps`` only when positive.

    Returns:
        The command as a list of strings. With more than one process the
        script is launched through ``torch.distributed.run`` and ``--gpu_id``
        is forced to 0; otherwise it runs under ``python -u`` with
        ``--gpu_id`` set to ``grpo_gpu_id``.
    """
    num_processes = max(1, int(args.grpo_num_processes))
    if num_processes > 1:
        command = [
            args.python_executable,
            "-m",
            "torch.distributed.run",
            "--standalone",
            "--nproc_per_node",
            str(num_processes),
            str(GRPO_SCRIPT),
        ]
    else:
        command = [args.python_executable, "-u", str(GRPO_SCRIPT)]

    command.extend(
        [
            "--model_name",
            args.model_name,
            "--train_jsonl",
            str(train_jsonl),
            "--output_dir",
            str(output_dir),
            "--cache_dir",
            args.cache_dir,
            "--init_adapter_dir",
            str(init_adapter_dir),
            "--seed",
            str(int(args.seed)),
            "--gpu_id",
            str(0 if num_processes > 1 else int(args.grpo_gpu_id)),
            "--stage_i",
            str(int(stage)),
            "--total_empties_hint",
            str(int(args.total_empties_hint)),
            "--per_device_train_batch_size",
            str(int(args.grpo_per_device_train_batch_size)),
            "--gradient_accumulation_steps",
            str(int(args.grpo_gradient_accumulation_steps)),
        ]
    )
    # Append the optional switch explicitly. Filtering "" out of the finished
    # list would also delete legitimately empty option values (for example
    # --wandb_group ""), leaving a flag that swallows the next token.
    if args.grpo_enable_gradient_checkpointing:
        command.append("--enable_gradient_checkpointing")
    command.extend(
        [
            "--num_train_epochs",
            str(float(args.grpo_num_train_epochs)),
            "--learning_rate",
            str(float(args.grpo_learning_rate)),
            "--logging_steps",
            str(int(args.grpo_logging_steps)),
            "--save_steps",
            str(int(args.grpo_save_steps)),
            "--eval_steps",
            str(int(args.grpo_eval_steps)),
            "--eval_rows",
            str(int(args.grpo_eval_rows)),
            "--num_generations",
            str(int(args.grpo_num_generations)),
            "--max_prompt_length",
            str(int(args.grpo_max_prompt_length)),
            "--max_completion_length",
            str(int(args.grpo_max_completion_length)),
            "--beta",
            str(float(args.grpo_beta)),
            "--max_wall_clock_seconds",
            str(int(args.phase_max_wall_clock_seconds)),
            "--wandb_group",
            args.wandb_group,
        ]
    )

    if int(args.limit_train_rows) > 0:
        command.extend(["--limit_train_rows", str(int(args.limit_train_rows))])
    if int(stage_max_steps) > 0:
        command.extend(["--max_steps", str(int(stage_max_steps))])

    # NOTE(review): the original indentation was lost; this assumes the three
    # W&B statements are independent, i.e. project/run-name/mode are forwarded
    # even without --use_wandb. Confirm against the original layout.
    if args.use_wandb:
        command.extend(["--use_wandb"])
    if str(args.wandb_entity).strip():
        command.extend(["--wandb_entity", args.wandb_entity])
    command.extend(
        [
            "--wandb_project",
            args.wandb_grpo_project,
            "--wandb_run_name",
            f"small_baseline_stage{stage:02d}_grpo_i{stage}_{args.total_empties_hint}empty",
            "--wandb_mode",
            args.wandb_mode,
        ]
    )
    return command
def write_state(path: Path, payload: Dict[str, Any]) -> None:
    """Serialize ``payload`` to ``path`` as indented, key-sorted UTF-8 JSON.

    Missing parent directories are created; an existing file is overwritten.
    """
    parent_dir = path.parent
    parent_dir.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, sort_keys=True)
def mark_stage_complete(stage_dir: Path, artifact: Artifact) -> None:
    """Write the completion marker for ``stage_dir``.

    The marker is a JSON file holding the local completion timestamp and the
    fields of ``artifact``.
    """
    marker_file = stage_complete_path(stage_dir)
    marker_payload = {
        "completed_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "artifact": asdict(artifact),
    }
    write_state(marker_file, marker_payload)
def main() -> None:
    """Run the SFT -> GRPO curriculum for every stage in ``[min_stage, max_stage]``.

    For each stage:

    1. If a completed GRPO artifact already exists under the output root, it
       is reused and the stage is skipped.
    2. Otherwise SFT is launched unless a completed SFT checkpoint exists. The
       initial adapter is ``--bootstrap_adapter_dir`` for the first stage of
       this run (when given), ``--stage1_init_adapter_dir`` for stage 1, or
       the previous stage's GRPO artifact.
    3. GRPO is launched from the SFT checkpoint.

    After each phase the produced artifact is located and its stage directory
    is marked complete. ``pipeline_state.json`` in the output root is
    rewritten after every stage.

    In ``--dry_run`` mode commands are only printed and no completion markers
    are written, so a dry run cannot promote leftovers of an interrupted run
    to "complete".

    Raises:
        FileNotFoundError: If the training JSONL does not exist.
        RuntimeError: If a stage > 1 has no previous GRPO artifact to start
            from, or if a phase ran but produced no discoverable artifact.
        subprocess.CalledProcessError: If a training subprocess fails.
    """
    args = parse_args()
    checkpoint_root = Path(
        normalize_to_final_checkpoint_root(args.checkpoint_root, "small_model_20empty", "baseline")
    ).resolve()
    output_root = choose_output_root(args, checkpoint_root)
    train_jsonl = default_train_jsonl(args)
    state_path = output_root / "pipeline_state.json"
    sft_stage_max_steps = parse_stage_int_map(args.sft_stage_max_steps)
    grpo_stage_max_steps = parse_stage_int_map(args.grpo_stage_max_steps)
    output_root.mkdir(parents=True, exist_ok=True)
    if not train_jsonl.exists():
        raise FileNotFoundError(f"Missing train_jsonl: {train_jsonl}")

    state: Dict[str, Any] = {
        "updated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
        "train_jsonl": str(train_jsonl),
        "checkpoint_root": str(checkpoint_root),
        "output_root": str(output_root),
        "min_stage": int(args.min_stage),
        "max_stage": int(args.max_stage),
        "total_empties_hint": int(args.total_empties_hint),
        "model_name": str(args.model_name),
        "stages": [],
    }
    dry_run = bool(args.dry_run)
    empties = int(args.total_empties_hint)
    previous_grpo: Optional[Artifact] = None

    for stage in range(int(args.min_stage), int(args.max_stage) + 1):
        stage_record: Dict[str, Any] = {"stage": stage}
        # Only artifacts from stage directories carrying a completion marker
        # count as "existing" here.
        existing_sft = discover_latest_artifact(output_root, stage=stage, phase="sft", empties=empties)
        existing_grpo = discover_latest_artifact(output_root, stage=stage, phase="grpo", empties=empties)

        if existing_grpo is not None:
            # Stage already finished: record it and move on.
            previous_grpo = existing_grpo
            stage_record["status"] = "using_existing_grpo"
            stage_record["grpo_artifact"] = asdict(existing_grpo)
            if existing_sft is not None:
                stage_record["sft_artifact"] = asdict(existing_sft)
            state["stages"].append(stage_record)
            write_state(state_path, state)
            print(f"Stage {stage}: using existing GRPO artifact {existing_grpo.path}")
            continue

        if existing_sft is None:
            sft_output_dir = phase_output_dir(output_root, stage=stage, phase="sft", empties=empties)
            if stage == int(args.min_stage) and str(args.bootstrap_adapter_dir).strip():
                init_adapter_dir = str(args.bootstrap_adapter_dir).strip()
            elif stage == 1:
                init_adapter_dir = str(args.stage1_init_adapter_dir).strip() or None
            else:
                if previous_grpo is None:
                    raise RuntimeError(f"Missing previous GRPO artifact needed to launch baseline stage {stage} SFT.")
                init_adapter_dir = previous_grpo.path
            print(f"Stage {stage}: launching SFT into {sft_output_dir}")
            run_command(
                build_sft_command(
                    args,
                    train_jsonl=train_jsonl,
                    output_dir=sft_output_dir,
                    stage=stage,
                    init_adapter_dir=init_adapter_dir,
                    stage_max_steps=resolve_stage_value(sft_stage_max_steps, stage),
                ),
                env=make_env(
                    gpu_id=int(args.sft_gpu_id),
                    wandb_mode=args.wandb_mode,
                    gpu_ids=args.distributed_gpu_ids,
                    num_processes=int(args.sft_num_processes),
                ),
                dry_run=dry_run,
            )
            # The run that just finished has no marker yet, so look without
            # requiring one.
            existing_sft = discover_latest_artifact(
                output_root,
                stage=stage,
                phase="sft",
                empties=empties,
                require_complete=False,
            )
            if existing_sft is None and not args.dry_run:
                raise RuntimeError(f"Could not locate SFT checkpoint for stage {stage} after running SFT.")
            if existing_sft is not None:
                # In a dry run nothing was trained; anything found here is a
                # leftover of an earlier, possibly interrupted run and must
                # not be stamped as complete.
                if not dry_run:
                    mark_stage_complete(Path(existing_sft.source_dir), existing_sft)
                stage_record["sft_artifact"] = asdict(existing_sft)
        else:
            stage_record["sft_artifact"] = asdict(existing_sft)
            print(f"Stage {stage}: using existing SFT artifact {existing_sft.path}")

        if existing_sft is None:
            # Reachable only in dry-run mode: with no SFT checkpoint there is
            # nothing to build the GRPO command from, so stop here.
            stage_record["status"] = "dry_run_pending_grpo"
            state["stages"].append(stage_record)
            write_state(state_path, state)
            break

        grpo_output_dir = phase_output_dir(output_root, stage=stage, phase="grpo", empties=empties)
        print(f"Stage {stage}: launching GRPO into {grpo_output_dir}")
        run_command(
            build_grpo_command(
                args,
                train_jsonl=train_jsonl,
                output_dir=grpo_output_dir,
                stage=stage,
                init_adapter_dir=existing_sft.path,
                stage_max_steps=resolve_stage_value(grpo_stage_max_steps, stage),
            ),
            env=make_env(
                gpu_id=int(args.grpo_gpu_id),
                wandb_mode=args.wandb_mode,
                gpu_ids=args.distributed_gpu_ids,
                num_processes=int(args.grpo_num_processes),
            ),
            dry_run=dry_run,
        )
        existing_grpo = discover_latest_artifact(
            output_root,
            stage=stage,
            phase="grpo",
            empties=empties,
            require_complete=False,
        )
        if existing_grpo is None and not args.dry_run:
            raise RuntimeError(f"Could not locate GRPO artifact for stage {stage} after running GRPO.")
        if existing_grpo is not None:
            # Same dry-run rule as for SFT: never write markers for work that
            # was not actually performed.
            if not dry_run:
                mark_stage_complete(Path(existing_grpo.source_dir), existing_grpo)
            previous_grpo = existing_grpo
            stage_record["grpo_artifact"] = asdict(existing_grpo)
        stage_record["status"] = "launched"
        state["stages"].append(stage_record)
        write_state(state_path, state)

    state["updated_at"] = time.strftime("%Y-%m-%d %H:%M:%S")
    write_state(state_path, state)
    print("")
    print(f"Pipeline state written to: {state_path}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()