# curriculum-cot-code / multi_output_cell_policy / sft_multi_output_train.py
# Author: Avra98
# Commit 76de008 (verified): "Initial code dump (rebuttal-ready snapshot)"
from __future__ import annotations
import argparse
import hashlib
import json
import math
import os
import random
import sys
import time
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Dict, List
import torch
import torch.distributed as dist
from peft import LoraConfig, PeftModel, get_peft_model
from torch.optim import AdamW
from torch.utils.data import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
# Directory containing this script, and the directory one level above it.
# The parent is prepended to sys.path so packages that live next to this
# script's directory can be imported when the file is run directly.
CURRENT_DIR = os.path.split(os.path.abspath(__file__))[0]
PARENT_DIR = os.path.split(CURRENT_DIR)[0]
if PARENT_DIR not in sys.path:
    # Prepend (not append) so these local packages take precedence.
    sys.path.insert(0, PARENT_DIR)
from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row
from checkpoint_utils import ensure_final_checkpoint_dir, save_checkpoint_and_update_final
from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt
from multi_output_cell_policy.rewards import score_prediction_text
from multi_output_cell_policy.shared_multi_output_policy import (
batched_completion_ce_loss,
build_supervised_completion,
completion_ce_loss,
make_solved_grid_from_row,
stage_i_consistent_values,
)
# Weights & Biases is an optional dependency. Any failure while importing it
# (not only ImportError) leaves ``wandb`` as None, which main() treats as
# "W&B logging unavailable".
try:
    import wandb
except Exception:
    wandb = None
@dataclass
class Args:
    """Typed container for the command-line options parsed by ``parse_args``.

    Fields appear in the same order as the argparse definitions, and an
    instance is built with ``Args(**vars(namespace))``. For the early-stopping
    thresholds and hard limits, a value of 0 disables that criterion.
    """

    # Model, data and filesystem locations.
    model_name: str
    train_jsonl: str
    eval_jsonl: str
    output_dir: str
    cache_dir: str
    init_adapter_dir: str
    # Reproducibility, device selection and curriculum stage.
    seed: int
    gpu_id: int
    stage_i: int
    total_empties_hint: int
    # Optimization.
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    num_epochs: float
    learning_rate: float
    weight_decay: float
    max_grad_norm: float
    enable_gradient_checkpointing: bool
    # Logging / evaluation / checkpoint cadence (in optimizer steps).
    logging_steps: int
    save_steps: int
    eval_steps: int
    # Evaluation settings.
    eval_rows: int
    max_completion_length: int
    # LoRA adapter hyper-parameters.
    lora_r: int
    lora_alpha: int
    lora_dropout: float
    # Weights & Biases.
    use_wandb: bool
    wandb_entity: str
    wandb_project: str
    wandb_run_name: str
    wandb_mode: str
    # Debugging and data subsetting.
    debug_print_limit: int
    limit_train_rows: int
    # Early-stopping thresholds and hard run limits.
    eval_exact_set_match_stop: float
    eval_value_precision_stop: float
    eval_value_recall_stop: float
    eval_solve_rate_stop: float
    min_steps_before_stop: int
    max_wall_clock_seconds: int
    max_steps: int
    # Training-set shaping and target-set size filters.
    multi_value_oversample_factor: int
    train_target_size_min: int
    train_target_size_max: int
    eval_target_size_min: int
    eval_target_size_max: int
def configure_hf_cache(cache_dir: str) -> str:
    """Route all Hugging Face cache locations under ``cache_dir``.

    The path is user-expanded and made absolute. ``<root>/hub`` and
    ``<root>/transformers`` are created, and HF_HOME, HF_HUB_CACHE,
    HUGGINGFACE_HUB_CACHE and TRANSFORMERS_CACHE are exported to point at
    them. HF_HUB_DISABLE_XET defaults to "1" unless it is already set.

    Returns:
        The normalized absolute cache root.
    """
    root = os.path.abspath(os.path.expanduser(cache_dir))
    subdirs = {name: os.path.join(root, name) for name in ("hub", "transformers")}
    for subdir in subdirs.values():
        os.makedirs(subdir, exist_ok=True)
    os.environ.update(
        {
            "HF_HOME": root,
            "HF_HUB_CACHE": subdirs["hub"],
            "HUGGINGFACE_HUB_CACHE": subdirs["hub"],
            "TRANSFORMERS_CACHE": subdirs["transformers"],
        }
    )
    os.environ.setdefault("HF_HUB_DISABLE_XET", "1")
    return root
def configure_wandb_dirs(output_dir: str) -> None:
    """Default W&B's run, cache and config dirs to ``<output_dir>/wandb_runtime``.

    The directory is always created. Any of WANDB_DIR, WANDB_CACHE_DIR or
    WANDB_CONFIG_DIR already present in the environment is left untouched.
    """
    runtime_dir = os.path.join(output_dir, "wandb_runtime")
    os.makedirs(runtime_dir, exist_ok=True)
    for env_var in ("WANDB_DIR", "WANDB_CACHE_DIR", "WANDB_CONFIG_DIR"):
        os.environ.setdefault(env_var, runtime_dir)
def pick_dtype() -> torch.dtype:
    """Return bfloat16 when CUDA is available and supports it, else float16.

    Note that float16 is also what a CPU-only host gets.
    """
    bf16_usable = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    return torch.bfloat16 if bf16_usable else torch.float16
def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]:
    """Parse a UTF-8 JSON Lines file into a list, skipping blank lines.

    Args:
        path: File to read.
        limit_rows: If positive, stop reading as soon as this many rows have
            been collected (later lines are never parsed). Zero or negative
            reads the whole file.

    Returns:
        The parsed rows in file order.
    """
    collected: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text:
                collected.append(json.loads(text))
                if limit_rows > 0 and len(collected) >= limit_rows:
                    break
    return collected
def target_size_allowed(target_size: int, min_size: int, max_size: int) -> bool:
    """Return whether ``target_size`` lies within the optional inclusive bounds.

    A bound that is zero or negative is treated as "no bound". All three
    values are coerced with ``int()``.
    """
    too_small = int(min_size) > 0 and int(target_size) < int(min_size)
    if too_small:
        return False
    too_large = int(max_size) > 0 and int(target_size) > int(max_size)
    return not too_large
def build_training_examples(
    rows: List[Dict[str, Any]],
    *,
    tokenizer: Any,
    stage_i: int,
    total_empties_hint: int,
    progress_every_rows: int = 10,
    progress_callback: Any = None,
    target_size_min: int | None = None,
    target_size_max: int | None = None,
    multi_value_oversample_factor: int | None = None,
) -> List[Dict[str, Any]]:
    """Expand puzzle rows into supervised (prompt, completion) training examples.

    For every cell example produced by ``build_cell_examples_from_row(row)``,
    the stage-i consistent value set is computed. Cells whose set size fails
    ``target_size_allowed`` are skipped. Each kept cell becomes one example
    holding the prompt, the supervised completion (with the tokenizer's EOS
    token appended when it has one) and bookkeeping fields. Cells whose set
    holds more than one value are repeated ``multi_value_oversample_factor``
    times.

    Args:
        rows: Puzzle rows as loaded from JSONL.
        tokenizer: Tokenizer passed to the prompt builder; also the source of
            the filter/oversampling defaults (see below).
        stage_i: Curriculum stage forwarded to the value/prompt helpers.
        total_empties_hint: Forwarded to ``build_multi_output_cell_prompt``.
        progress_every_rows: Report progress every N rows (plus the first and
            last row).
        progress_callback: Optional ``callback(rows_done, total_rows, n_examples)``.
        target_size_min / target_size_max: Inclusive bounds on the target-set
            size (<= 0 means unbounded). When None, the values main() stores on
            ``tokenizer._train_target_size_min/_max`` are used (default 0).
        multi_value_oversample_factor: Repeat count for multi-value targets
            (clamped to >= 1). When None, ``tokenizer._multi_value_oversample_factor``
            is used (default 1).

    Returns:
        List of dicts with keys ``prompt_text``, ``completion_text``,
        ``target_values``, ``grid``, ``solved`` and ``target_cell``.
    """
    # Backward compatibility: existing callers pass these settings only via
    # private attributes stashed on the tokenizer.
    if target_size_min is None:
        target_size_min = getattr(tokenizer, "_train_target_size_min", 0)
    if target_size_max is None:
        target_size_max = getattr(tokenizer, "_train_target_size_max", 0)
    if multi_value_oversample_factor is None:
        multi_value_oversample_factor = getattr(tokenizer, "_multi_value_oversample_factor", 1)

    # Loop invariants, computed once instead of per cell / per row.
    oversample = max(1, int(multi_value_oversample_factor))
    report_every = max(1, int(progress_every_rows))
    total_rows = len(rows)
    eos_text = getattr(tokenizer, "eos_token", None) or ""

    examples: List[Dict[str, Any]] = []
    for row_idx, row in enumerate(rows, start=1):
        solved = make_solved_grid_from_row(row)
        for ex in build_cell_examples_from_row(row):
            target_values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=stage_i)
            if not target_size_allowed(len(target_values), target_size_min, target_size_max):
                continue
            prompt = build_multi_output_cell_prompt(
                ex.grid,
                target_cell=ex.target_cell,
                stage_i=stage_i,
                tokenizer=tokenizer,
                turn_idx=ex.turn_idx,
                total_turns=ex.total_turns,
                prev_output_flag=None,
                total_empties_hint=total_empties_hint,
            )
            target_text = build_supervised_completion(ex, stage_i=stage_i) + eos_text
            repeat_count = oversample if len(target_values) > 1 else 1
            for _ in range(repeat_count):
                examples.append(
                    {
                        "prompt_text": prompt,
                        "completion_text": target_text,
                        # Fresh list per copy so repeated examples never alias each other.
                        "target_values": list(target_values),
                        "grid": ex.grid,
                        "solved": solved,
                        "target_cell": ex.target_cell,
                    }
                )
        if progress_callback is not None and (
            row_idx == 1 or row_idx == total_rows or row_idx % report_every == 0
        ):
            progress_callback(row_idx, total_rows, len(examples))
    return examples
def _prepared_data_dir(args: Args) -> str:
    """Return the directory for prepared SFT caches, creating it if needed.

    The location is fixed at ``PARENT_DIR/_prepared_data/multi_output_cell_policy``;
    ``args`` is currently unused.
    """
    prepared_root = os.path.join(PARENT_DIR, "_prepared_data")
    policy_dir = os.path.join(prepared_root, "multi_output_cell_policy")
    os.makedirs(policy_dir, exist_ok=True)
    return policy_dir
def _prepared_sft_cache_path(args: Args) -> str:
    """Return the prepared-examples cache file path for this configuration.

    The file name embeds a SHA-1 digest of every setting that influences
    ``build_training_examples``. That includes a digest of the training
    file's *contents*, so regenerating ``train_jsonl`` in place cannot
    silently reuse stale examples built from the old file. If the training
    file cannot be read, the content digest is recorded as null and only the
    path and settings key the cache.

    Bump ``completion_format_version`` whenever the prompt/completion format
    changes.
    """
    train_path = os.path.abspath(args.train_jsonl)
    content_hash = hashlib.sha1()
    try:
        with open(train_path, "rb") as handle:
            for chunk in iter(lambda: handle.read(1 << 20), b""):
                content_hash.update(chunk)
        train_content_sha1 = content_hash.hexdigest()
    except OSError:
        train_content_sha1 = None
    payload = json.dumps(
        {
            "completion_format_version": 2,
            "train_jsonl": train_path,
            "train_jsonl_sha1": train_content_sha1,
            "stage_i": int(args.stage_i),
            "total_empties_hint": int(args.total_empties_hint),
            "limit_train_rows": int(args.limit_train_rows),
            "model_name": str(args.model_name),
            "multi_value_oversample_factor": int(args.multi_value_oversample_factor),
            "train_target_size_min": int(args.train_target_size_min),
            "train_target_size_max": int(args.train_target_size_max),
        },
        sort_keys=True,
    ).encode("utf-8")
    digest = hashlib.sha1(payload).hexdigest()[:20]
    return os.path.join(_prepared_data_dir(args), f"sft_stage{int(args.stage_i):02d}_{digest}.jsonl")
def _to_jsonable(value: Any) -> Any:
if isinstance(value, dict):
return {k: _to_jsonable(v) for k, v in value.items()}
if isinstance(value, (list, tuple)):
return [_to_jsonable(v) for v in value]
if hasattr(value, "tolist"):
return _to_jsonable(value.tolist())
return value
def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None:
    """Atomically write ``rows`` to ``path`` as compact JSON Lines.

    Each row is passed through ``_to_jsonable`` and written to the sibling
    file ``<path>.tmp``. That file is flushed and fsync'd, then renamed over
    ``path`` with ``os.replace``, so a reader polling for ``path`` never sees
    a partially written file. If serialization or I/O fails, the temporary
    file is removed and the original exception propagates.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            for row in rows:
                f.write(json.dumps(_to_jsonable(row), separators=(",", ":")) + "\n")
            # Make the bytes durable before the rename publishes the file.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # tmp file was never created or is already gone
        raise
def _read_jsonl(path: str) -> List[Dict[str, Any]]:
rows: List[Dict[str, Any]] = []
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
rows.append(json.loads(line))
return rows
def _wait_for_cache(path: str, timeout_s: float = 7200.0) -> None:
start = time.time()
while not os.path.exists(path):
if time.time() - start > timeout_s:
raise TimeoutError(f"Timed out waiting for prepared cache: {path}")
time.sleep(2.0)
def load_or_build_sft_examples(
    args: Args,
    *,
    rows: List[Dict[str, Any]],
    tokenizer: Any,
    rank: int,
    world_size: int,
    progress_callback: Any = None,
) -> List[Dict[str, Any]]:
    """Return prepared SFT examples, building and caching them on rank 0 if needed.

    Behaviour depends on the cache file at ``_prepared_sft_cache_path(args)``:
    - If the file already exists, every rank reads it.
    - Otherwise rank 0 builds the examples and writes the file atomically;
      other ranks wait for the file to appear and then read it.

    Rank 0 also returns the content read back from the cache rather than its
    in-memory build. ``_write_jsonl`` turns tuples, tensors and similar values
    into plain lists, so reading back means every rank, and every later
    cache-hit run, trains on identically typed data.

    ``world_size`` is accepted but currently unused.
    """
    cache_path = _prepared_sft_cache_path(args)
    if os.path.exists(cache_path):
        return _read_jsonl(cache_path)
    if rank == 0:
        print(f"[dataset build][sft stage {args.stage_i}] building prepared cache: {cache_path}", flush=True)
        examples = build_training_examples(
            rows,
            tokenizer=tokenizer,
            stage_i=args.stage_i,
            total_empties_hint=args.total_empties_hint,
            progress_every_rows=10,
            progress_callback=progress_callback,
        )
        _write_jsonl(cache_path, examples)
        # Drop the in-memory build before re-reading to keep peak memory down.
        del examples
        return _read_jsonl(cache_path)
    _wait_for_cache(cache_path)
    return _read_jsonl(cache_path)
@torch.no_grad()
def run_eval(args: Args, rows: List[Dict[str, Any]], model: torch.nn.Module, tokenizer: Any, device: torch.device):
    """Greedy-decode predictions for eligible target cells and aggregate metrics.

    For each row, every cell example whose stage-i consistent value set size
    passes the ``eval_target_size_min/max`` filter is handled in three steps:
    - it is prompted;
    - it is decoded greedily with at most ``max_completion_length`` new tokens;
    - the decoded text is scored with ``score_prediction_text``.

    A row counts towards ``solve_rate`` only if it has at least one eligible
    cell. Such a row is solved when every eligible cell is an exact set match.
    Per-cell outputs of up to ``debug_print_limit`` rows are printed.

    The model's previous train/eval mode is restored on exit, including when
    an exception is raised.

    Returns:
        Dict of per-cell averages (``parse_rate``, ``strict_canonical_rate``,
        ``exact_set_match_rate``, ``includes_ground_truth_rate``,
        ``value_precision``, ``value_recall``, ``avg_predicted_set_size``,
        ``avg_num_i_consistent_values``, ``avg_num_non_i_consistent_values``)
        plus the per-row ``solve_rate``.
    """
    was_training = model.training
    model.eval()
    try:
        debug_limit = int(args.debug_print_limit)
        total_cells = 0
        parse_ok = 0.0
        canonical_ok = 0.0
        exact_set_match = 0.0
        includes_gt = 0.0
        precision_sum = 0.0
        recall_sum = 0.0
        predicted_size_sum = 0.0
        good_count_sum = 0.0
        bad_count_sum = 0.0
        solve_ok = 0
        solve_rows = 0
        printed = 0
        for row in rows:
            solved = make_solved_grid_from_row(row)
            row_all_exact = True
            row_has_eval_cell = False
            row_debug_lines: List[str] = []
            for ex in build_cell_examples_from_row(row):
                target_values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=args.stage_i)
                if not target_size_allowed(
                    len(target_values), int(args.eval_target_size_min), int(args.eval_target_size_max)
                ):
                    continue
                row_has_eval_cell = True
                prompt = build_multi_output_cell_prompt(
                    ex.grid,
                    target_cell=ex.target_cell,
                    stage_i=args.stage_i,
                    tokenizer=tokenizer,
                    turn_idx=ex.turn_idx,
                    total_turns=ex.total_turns,
                    prev_output_flag=None,
                    total_empties_hint=args.total_empties_hint,
                )
                enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
                enc = {k: v.to(device) for k, v in enc.items()}
                prompt_len = int(enc["input_ids"].shape[1])
                generated = model.generate(
                    **enc,
                    max_new_tokens=max(1, int(args.max_completion_length)),
                    do_sample=False,
                    # main() sets model.config.use_cache = False for training;
                    # re-enable the KV cache so decoding does not recompute the
                    # whole prefix for every new token.
                    use_cache=True,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                )
                pred_text = tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True).strip()
                info = score_prediction_text(
                    text=pred_text,
                    grid=ex.grid,
                    solved=solved,
                    target_cell=ex.target_cell,
                    stage_i=args.stage_i,
                    reward_good_value=1.0,
                    penalty_bad_value=1.75,
                    penalty_malformed=4.0,
                    penalty_empty=0.5,
                    penalty_singleton=1.5,
                )
                total_cells += 1
                parse_ok += float(info["parse_ok"])
                canonical_ok += float(info["strict_canonical"])
                exact_set_match += float(info["exact_set_match"])
                includes_gt += float(info["includes_ground_truth"])
                precision_sum += float(info["value_precision"])
                recall_sum += float(info["value_recall"])
                predicted_size_sum += float(info["num_predicted_values"])
                good_count_sum += float(info["num_i_consistent_values"])
                bad_count_sum += float(info["num_non_i_consistent_values"])
                if float(info["exact_set_match"]) < 0.5:
                    row_all_exact = False
                if printed < debug_limit:
                    row_debug_lines.append(
                        f"[baseline sft eval debug] true_values={info['target_values']} "
                        f"predicted_values={info['predicted_values']} output={pred_text!r}"
                    )
            if row_has_eval_cell:
                if printed < debug_limit and row_debug_lines:
                    print("[baseline sft eval debug] puzzle_outputs_begin", flush=True)
                    for line in row_debug_lines:
                        print(line, flush=True)
                    print("[baseline sft eval debug] puzzle_outputs_end", flush=True)
                    printed += 1
                solve_ok += int(row_all_exact)
                solve_rows += 1
        cell_denom = max(1, total_cells)
        metrics = {
            "parse_rate": float(parse_ok / cell_denom),
            "strict_canonical_rate": float(canonical_ok / cell_denom),
            "exact_set_match_rate": float(exact_set_match / cell_denom),
            "includes_ground_truth_rate": float(includes_gt / cell_denom),
            "value_precision": float(precision_sum / cell_denom),
            "value_recall": float(recall_sum / cell_denom),
            "avg_predicted_set_size": float(predicted_size_sum / cell_denom),
            "avg_num_i_consistent_values": float(good_count_sum / cell_denom),
            "avg_num_non_i_consistent_values": float(bad_count_sum / cell_denom),
            "solve_rate": float(solve_ok / max(1, solve_rows)),
        }
        print(
            f"[baseline sft eval] parse={metrics['parse_rate']:.3f} canonical={metrics['strict_canonical_rate']:.3f} "
            f"exact={metrics['exact_set_match_rate']:.3f} precision={metrics['value_precision']:.3f} "
            f"recall={metrics['value_recall']:.3f} solve={metrics['solve_rate']:.3f}",
            flush=True,
        )
        return metrics
    finally:
        model.train(was_training)
def parse_args(argv: List[str] | None = None) -> Args:
    """Parse command-line options into an ``Args`` instance.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Raises:
        SystemExit: raised by argparse on invalid input. This includes a
            non-positive ``--logging_steps``, ``--save_steps`` or
            ``--eval_steps``. The training loop uses these as modulo divisors,
            so 0 would otherwise crash with ZeroDivisionError after the first
            optimizer step.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    p.add_argument("--train_jsonl", type=str, required=True)
    p.add_argument("--eval_jsonl", type=str, default="")
    p.add_argument("--output_dir", type=str, required=True)
    # NOTE(review): machine-specific default path; consider overriding per host.
    p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum-CoT/.hf_cache")
    p.add_argument("--init_adapter_dir", type=str, default="")
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--gpu_id", type=int, default=0)
    p.add_argument("--stage_i", type=int, default=1)
    p.add_argument("--total_empties_hint", type=int, default=10)
    p.add_argument("--per_device_train_batch_size", type=int, default=1)
    p.add_argument("--gradient_accumulation_steps", type=int, default=8)
    p.add_argument("--num_epochs", type=float, default=1.0)
    p.add_argument("--learning_rate", type=float, default=2e-4)
    p.add_argument("--weight_decay", type=float, default=0.0)
    p.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.0,
        help="Clip global grad norm before each optimizer step (0 disables).",
    )
    p.add_argument("--enable_gradient_checkpointing", action="store_true")
    p.add_argument("--logging_steps", type=int, default=10)
    p.add_argument("--save_steps", type=int, default=100)
    p.add_argument("--eval_steps", type=int, default=100)
    p.add_argument("--eval_rows", type=int, default=20)
    p.add_argument("--max_completion_length", type=int, default=24)
    p.add_argument("--lora_r", type=int, default=16)
    p.add_argument("--lora_alpha", type=int, default=32)
    p.add_argument("--lora_dropout", type=float, default=0.05)
    p.add_argument("--use_wandb", action="store_true")
    p.add_argument("--wandb_entity", type=str, default="")
    p.add_argument("--wandb_project", type=str, default="sudoku-multi-output-sft")
    p.add_argument("--wandb_run_name", type=str, default="")
    p.add_argument("--wandb_mode", type=str, default="online")
    p.add_argument("--debug_print_limit", type=int, default=3)
    p.add_argument("--limit_train_rows", type=int, default=0)
    p.add_argument("--eval_exact_set_match_stop", type=float, default=0.0)
    p.add_argument("--eval_value_precision_stop", type=float, default=0.0)
    p.add_argument("--eval_value_recall_stop", type=float, default=0.0)
    p.add_argument("--eval_solve_rate_stop", type=float, default=0.0)
    p.add_argument("--min_steps_before_stop", type=int, default=0)
    p.add_argument("--max_wall_clock_seconds", type=int, default=0)
    p.add_argument("--max_steps", type=int, default=0)
    p.add_argument("--multi_value_oversample_factor", type=int, default=1)
    p.add_argument("--train_target_size_min", type=int, default=0)
    p.add_argument("--train_target_size_max", type=int, default=0)
    p.add_argument("--eval_target_size_min", type=int, default=0)
    p.add_argument("--eval_target_size_max", type=int, default=0)
    namespace = p.parse_args(argv)
    # Fail fast at startup instead of with ZeroDivisionError deep in training.
    for interval_name in ("logging_steps", "save_steps", "eval_steps"):
        interval_value = getattr(namespace, interval_name)
        if interval_value <= 0:
            p.error(f"--{interval_name} must be a positive integer (got {interval_value})")
    return Args(**vars(namespace))
def save_checkpoint(model: torch.nn.Module, tokenizer: Any, output_dir: str, step: int) -> None:
    """Save ``model`` and ``tokenizer`` as ``checkpoint-step-NNNNN`` under ``output_dir``.

    The step is zero-padded to five digits. The actual writing (and any
    "final" pointer update) is delegated to ``save_checkpoint_and_update_final``.
    """
    checkpoint_name = "checkpoint-step-" + format(step, "05d")
    save_checkpoint_and_update_final(model, tokenizer, output_dir, checkpoint_name)
def main() -> None:
    """Run LoRA supervised fine-tuning of the multi-output cell policy.

    The run proceeds in these stages:
    - Read the launcher environment (RANK / LOCAL_RANK / WORLD_SIZE), select
      the CUDA device and, when distributed, initialise NCCL.
    - Load the train/eval rows, the tokenizer, and the base model with a new
      or resumed LoRA adapter.
    - Build or load the prepared SFT examples.
    - Run a manual gradient-accumulation training loop. When distributed,
      gradients are averaged by hand with all-reduce.
    - Periodically log, evaluate (with optional early stopping), checkpoint,
      and save a final checkpoint at the end.
    """
    args = parse_args()
    preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip()
    rank = int(os.environ.get("RANK", "0"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    is_distributed = world_size > 1
    is_main_process = rank == 0

    # True only when this process itself restricts CUDA visibility to --gpu_id.
    masked_to_gpu_id = False
    if preset_visible_devices:
        if is_main_process:
            print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}", flush=True)
    elif not is_distributed and int(args.gpu_id) >= 0:
        # Single-process only. Distributed ranks select cuda:{LOCAL_RANK},
        # which is out of range for LOCAL_RANK > 0 once every rank is masked
        # down to a single GPU.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id))
        masked_to_gpu_id = True
    if is_distributed:
        if torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        dist.init_process_group(backend="nccl", timeout=timedelta(hours=2))
    set_seed(args.seed + rank)
    os.makedirs(args.output_dir, exist_ok=True)
    ensure_final_checkpoint_dir(args.output_dir)
    cache_dir = configure_hf_cache(args.cache_dir)
    configure_wandb_dirs(args.output_dir)

    wb_run = None
    if is_main_process and args.use_wandb and wandb is not None:
        init_kwargs = {
            "project": args.wandb_project,
            "name": args.wandb_run_name or None,
            "mode": args.wandb_mode,
        }
        if str(args.wandb_entity).strip():
            init_kwargs["entity"] = args.wandb_entity
        wb_run = wandb.init(**init_kwargs)
        print(f"W&B run id: {wb_run.id}", flush=True)
        print(f"W&B run URL: {wb_run.url}", flush=True)
        wandb.log({"prep/rows_done": 0.0, "prep/examples_built": 0.0, "prep/cache_hit": 0.0})

    rows = load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows)
    eval_source = args.eval_jsonl if str(args.eval_jsonl).strip() else args.train_jsonl
    eval_rows = load_jsonl_rows(eval_source, limit_rows=max(1, int(args.eval_rows)))
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"
    # build_training_examples reads these private attributes off the tokenizer.
    tokenizer._multi_value_oversample_factor = max(1, int(args.multi_value_oversample_factor))
    tokenizer._train_target_size_min = max(0, int(args.train_target_size_min))
    tokenizer._train_target_size_max = max(0, int(args.train_target_size_max))

    if torch.cuda.is_available():
        if is_distributed:
            device = torch.device(f"cuda:{local_rank}")
        else:
            gpu_index = max(0, int(args.gpu_id))
            # Once CUDA_VISIBLE_DEVICES lists only --gpu_id, that GPU is ordinal 0.
            # If CUDA was initialised before the mask was set, the mask may not
            # take effect and more devices stay visible; keep the original
            # ordinal in that case.
            if masked_to_gpu_id and torch.cuda.device_count() == 1:
                gpu_index = 0
            device = torch.device(f"cuda:{gpu_index}")
    else:
        device = torch.device("cpu")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name, cache_dir=cache_dir, torch_dtype=pick_dtype(), low_cpu_mem_usage=True
    )
    if str(args.init_adapter_dir).strip():
        model = PeftModel.from_pretrained(model, args.init_adapter_dir, is_trainable=True)
    else:
        lora = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        )
        model = get_peft_model(model, lora)
    if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    if hasattr(model, "config"):
        model.config.use_cache = False
    model.to(device)
    model.train()
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    def on_prep_progress(rows_done: int, total_rows: int, examples_built: int) -> None:
        if is_main_process:
            print(
                f"[dataset build][sft stage {args.stage_i}] rows={rows_done}/{total_rows} examples={examples_built}",
                flush=True,
            )
            if wb_run is not None:
                wandb.log({"prep/rows_done": float(rows_done), "prep/examples_built": float(examples_built)})

    # Measure the cache hit BEFORE loading: on a miss, load_or_build_sft_examples
    # writes the cache, so checking afterwards would always report a hit.
    prepared_cache_hit = is_main_process and os.path.exists(_prepared_sft_cache_path(args))
    train_examples = load_or_build_sft_examples(
        args,
        rows=rows,
        tokenizer=tokenizer,
        rank=rank,
        world_size=world_size,
        progress_callback=on_prep_progress,
    )
    if is_main_process and wb_run is not None:
        wandb.log(
            {
                "prep/cache_hit": float(prepared_cache_hit),
                "prep/examples_final": float(len(train_examples)),
            }
        )

    optimizer = AdamW((p for p in model.parameters() if p.requires_grad), lr=args.learning_rate, weight_decay=args.weight_decay)
    denom = max(1, int(args.gradient_accumulation_steps)) * max(1, int(args.per_device_train_batch_size))
    total_steps = max(1, math.ceil(len(train_examples) * args.num_epochs / denom))
    if int(args.max_steps) > 0:
        total_steps = min(total_steps, int(args.max_steps))
    step = 0
    # Monotonic clock so the wall-clock budget is immune to system clock changes.
    start_time = time.monotonic()

    def average_scalar(value: float) -> float:
        """Mean of ``value`` across ranks (identity when not distributed)."""
        if not is_distributed or not dist.is_initialized():
            return float(value)
        tensor = torch.tensor(float(value), device=device, dtype=torch.float32)
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        return float((tensor / float(world_size)).item())

    def all_reduce_gradients() -> None:
        """Average every populated ``.grad`` across ranks (manual replacement for DDP)."""
        if not is_distributed or not dist.is_initialized():
            return
        for param in model.parameters():
            if param.grad is None:
                continue
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(float(world_size))

    def sync_stop(local_stop: bool) -> bool:
        """True if any rank wants to stop (collective; every rank must call it)."""
        if not is_distributed or not dist.is_initialized():
            return bool(local_stop)
        tensor = torch.tensor(1 if local_stop else 0, device=device, dtype=torch.int64)
        dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
        return bool(int(tensor.item()) > 0)

    for epoch_idx in range(max(1, int(math.ceil(args.num_epochs)))):
        if is_distributed:
            # drop_last=False pads shards to equal length, so every rank runs the
            # same number of microbatches and collectives stay in lockstep.
            sampler = DistributedSampler(
                train_examples,
                num_replicas=world_size,
                rank=rank,
                shuffle=True,
                seed=args.seed,
                drop_last=False,
            )
            sampler.set_epoch(epoch_idx)
            order = list(iter(sampler))
        else:
            generator = torch.Generator()
            generator.manual_seed(args.seed + epoch_idx)
            order = torch.randperm(len(train_examples), generator=generator).tolist()
        # NOTE(review): gradients from a trailing partial accumulation window are
        # discarded here at the next epoch start (and at exit) - confirm intended.
        optimizer.zero_grad(set_to_none=True)
        accum_count = 0
        accum_ce_sum = 0.0
        microbatch_size = max(1, int(args.per_device_train_batch_size))
        for batch_start in range(0, len(order), microbatch_size):
            batch_indices = order[batch_start : batch_start + microbatch_size]
            batch_examples = [train_examples[ex_idx] for ex_idx in batch_indices]
            ce_full = batched_completion_ce_loss(
                model,
                tokenizer,
                [str(ex["prompt_text"]) for ex in batch_examples],
                [str(ex["completion_text"]) for ex in batch_examples],
                device,
            )
            loss = ce_full / max(1, int(args.gradient_accumulation_steps))
            loss.backward()
            accum_ce_sum += float(ce_full.detach().item())
            accum_count += 1
            if accum_count >= int(args.gradient_accumulation_steps):
                all_reduce_gradients()
                if float(args.max_grad_norm) > 0.0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(args.max_grad_norm))
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                accum_count = 0
                step += 1
                mean_ce = accum_ce_sum / max(1, int(args.gradient_accumulation_steps))
                accum_ce_sum = 0.0
                if step % int(args.logging_steps) == 0:
                    loss_value = average_scalar(mean_ce)
                    if is_main_process:
                        print(f"[baseline sft train step {step:05d}] loss={loss_value:.4f}", flush=True)
                        if wb_run is not None:
                            wandb.log({"train/loss": loss_value, "step": step})
                if step % int(args.eval_steps) == 0:
                    if is_distributed and dist.is_initialized():
                        dist.barrier()
                    should_stop_eval = False
                    if is_main_process:
                        ev = run_eval(args, eval_rows, model, tokenizer, device)
                        if wb_run is not None:
                            wandb.log({f"eval/{k}": float(v) for k, v in ev.items()} | {"step": step})
                        # NOTE(review): unlike the two criteria below, this one
                        # ignores min_steps_before_stop - confirm intended.
                        if (
                            args.eval_exact_set_match_stop > 0.0
                            and float(ev["exact_set_match_rate"]) >= args.eval_exact_set_match_stop
                        ):
                            save_checkpoint(model, tokenizer, args.output_dir, step)
                            should_stop_eval = True
                        if (
                            not should_stop_eval
                            and step >= int(args.min_steps_before_stop)
                            and args.eval_value_precision_stop > 0.0
                            and args.eval_value_recall_stop > 0.0
                            and float(ev["value_precision"]) >= args.eval_value_precision_stop
                            and float(ev["value_recall"]) >= args.eval_value_recall_stop
                        ):
                            save_checkpoint(model, tokenizer, args.output_dir, step)
                            should_stop_eval = True
                        if (
                            not should_stop_eval
                            and args.eval_solve_rate_stop > 0.0
                            and step >= int(args.min_steps_before_stop)
                            and float(ev["solve_rate"]) >= args.eval_solve_rate_stop
                        ):
                            save_checkpoint(model, tokenizer, args.output_dir, step)
                            should_stop_eval = True
                    should_stop_eval = sync_stop(should_stop_eval)
                    if is_distributed and dist.is_initialized():
                        dist.barrier()
                    if should_stop_eval:
                        if is_main_process and wb_run is not None:
                            wb_run.finish()
                        if is_distributed and dist.is_initialized():
                            dist.destroy_process_group()
                        return
                if step % int(args.save_steps) == 0:
                    if is_distributed and dist.is_initialized():
                        dist.barrier()
                    if is_main_process:
                        save_checkpoint(model, tokenizer, args.output_dir, step)
                    if is_distributed and dist.is_initialized():
                        dist.barrier()
                reached_limit = step >= total_steps
                exceeded_wall = bool(args.max_wall_clock_seconds) and (
                    time.monotonic() - start_time >= float(args.max_wall_clock_seconds)
                )
                should_stop = sync_stop(reached_limit or exceeded_wall)
                if should_stop:
                    break
        if sync_stop(step >= total_steps):
            break

    if is_distributed and dist.is_initialized():
        dist.barrier()
    if is_main_process:
        save_checkpoint(model, tokenizer, args.output_dir, step)
        if wb_run is not None:
            wb_run.finish()
    if is_distributed and dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
if __name__ == "__main__":
main()