from __future__ import annotations import argparse import hashlib import json import math import os import random import sys import time from dataclasses import dataclass from datetime import timedelta from typing import Any, Dict, List import torch import torch.distributed as dist from peft import LoraConfig, PeftModel, get_peft_model from torch.optim import AdamW from torch.utils.data import DistributedSampler from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) PARENT_DIR = os.path.dirname(CURRENT_DIR) if PARENT_DIR not in sys.path: sys.path.insert(0, PARENT_DIR) from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row from checkpoint_utils import ensure_final_checkpoint_dir, save_checkpoint_and_update_final from multi_output_cell_policy.prompt_builder import build_multi_output_cell_prompt from multi_output_cell_policy.rewards import score_prediction_text from multi_output_cell_policy.shared_multi_output_policy import ( batched_completion_ce_loss, build_supervised_completion, completion_ce_loss, make_solved_grid_from_row, stage_i_consistent_values, ) try: import wandb except Exception: wandb = None @dataclass class Args: model_name: str train_jsonl: str eval_jsonl: str output_dir: str cache_dir: str init_adapter_dir: str seed: int gpu_id: int stage_i: int total_empties_hint: int per_device_train_batch_size: int gradient_accumulation_steps: int num_epochs: float learning_rate: float weight_decay: float max_grad_norm: float enable_gradient_checkpointing: bool logging_steps: int save_steps: int eval_steps: int eval_rows: int max_completion_length: int lora_r: int lora_alpha: int lora_dropout: float use_wandb: bool wandb_entity: str wandb_project: str wandb_run_name: str wandb_mode: str debug_print_limit: int limit_train_rows: int eval_exact_set_match_stop: float eval_value_precision_stop: float eval_value_recall_stop: float eval_solve_rate_stop: float min_steps_before_stop: int max_wall_clock_seconds: int max_steps: int multi_value_oversample_factor: int train_target_size_min: int train_target_size_max: int eval_target_size_min: int eval_target_size_max: int def configure_hf_cache(cache_dir: str) -> str: cache_dir = os.path.abspath(os.path.expanduser(cache_dir)) hub_dir = os.path.join(cache_dir, "hub") transformers_dir = os.path.join(cache_dir, "transformers") os.makedirs(hub_dir, exist_ok=True) os.makedirs(transformers_dir, exist_ok=True) os.environ["HF_HOME"] = cache_dir os.environ["HF_HUB_CACHE"] = hub_dir os.environ["HUGGINGFACE_HUB_CACHE"] = hub_dir os.environ["TRANSFORMERS_CACHE"] = transformers_dir os.environ.setdefault("HF_HUB_DISABLE_XET", "1") return cache_dir def configure_wandb_dirs(output_dir: str) -> None: wandb_dir = os.path.join(output_dir, "wandb_runtime") os.makedirs(wandb_dir, exist_ok=True) os.environ.setdefault("WANDB_DIR", wandb_dir) os.environ.setdefault("WANDB_CACHE_DIR", wandb_dir) os.environ.setdefault("WANDB_CONFIG_DIR", wandb_dir) def pick_dtype() -> torch.dtype: if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): return torch.bfloat16 return torch.float16 def load_jsonl_rows(path: str, limit_rows: int = 0) -> List[Dict[str, Any]]: rows: List[Dict[str, Any]] = [] with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue rows.append(json.loads(line)) if limit_rows > 0 and len(rows) >= limit_rows: break return rows def target_size_allowed(target_size: int, min_size: int, max_size: int) -> bool: if int(min_size) > 0 and int(target_size) < int(min_size): return False if int(max_size) > 0 and int(target_size) > int(max_size): return False return True def build_training_examples( rows: List[Dict[str, Any]], *, tokenizer: Any, stage_i: int, total_empties_hint: int, progress_every_rows: int = 10, progress_callback: Any = None, ) -> List[Dict[str, Any]]: examples: List[Dict[str, Any]] = [] eos_text = getattr(tokenizer, "eos_token", None) or "" for row_idx, row in enumerate(rows, start=1): solved = make_solved_grid_from_row(row) for ex in build_cell_examples_from_row(row): target_values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=stage_i) if not target_size_allowed( len(target_values), getattr(tokenizer, "_train_target_size_min", 0), getattr(tokenizer, "_train_target_size_max", 0), ): continue prompt = build_multi_output_cell_prompt( ex.grid, target_cell=ex.target_cell, stage_i=stage_i, tokenizer=tokenizer, turn_idx=ex.turn_idx, total_turns=ex.total_turns, prev_output_flag=None, total_empties_hint=total_empties_hint, ) target_text = build_supervised_completion(ex, stage_i=stage_i) if eos_text: target_text = target_text + eos_text repeat_count = max(1, int(getattr(tokenizer, "_multi_value_oversample_factor", 1))) if len(target_values) > 1 else 1 for _ in range(repeat_count): examples.append( { "prompt_text": prompt, "completion_text": target_text, "target_values": list(target_values), "grid": ex.grid, "solved": solved, "target_cell": ex.target_cell, } ) if progress_callback is not None and ( row_idx == 1 or row_idx == len(rows) or row_idx % max(1, int(progress_every_rows)) == 0 ): progress_callback(row_idx, len(rows), len(examples)) return examples def _prepared_data_dir(args: Args) -> str: path = os.path.join(PARENT_DIR, "_prepared_data", "multi_output_cell_policy") os.makedirs(path, exist_ok=True) return path def _prepared_sft_cache_path(args: Args) -> str: payload = json.dumps( { "completion_format_version": 2, "train_jsonl": os.path.abspath(args.train_jsonl), "stage_i": int(args.stage_i), "total_empties_hint": int(args.total_empties_hint), "limit_train_rows": int(args.limit_train_rows), "model_name": str(args.model_name), "multi_value_oversample_factor": int(args.multi_value_oversample_factor), "train_target_size_min": int(args.train_target_size_min), "train_target_size_max": int(args.train_target_size_max), }, sort_keys=True, ).encode("utf-8") digest = hashlib.sha1(payload).hexdigest()[:20] return os.path.join(_prepared_data_dir(args), f"sft_stage{int(args.stage_i):02d}_{digest}.jsonl") def _to_jsonable(value: Any) -> Any: if isinstance(value, dict): return {k: _to_jsonable(v) for k, v in value.items()} if isinstance(value, (list, tuple)): return [_to_jsonable(v) for v in value] if hasattr(value, "tolist"): return _to_jsonable(value.tolist()) return value def _write_jsonl(path: str, rows: List[Dict[str, Any]]) -> None: tmp_path = f"{path}.tmp" with open(tmp_path, "w", encoding="utf-8") as f: for row in rows: f.write(json.dumps(_to_jsonable(row), separators=(",", ":")) + "\n") os.replace(tmp_path, path) def _read_jsonl(path: str) -> List[Dict[str, Any]]: rows: List[Dict[str, Any]] = [] with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: rows.append(json.loads(line)) return rows def _wait_for_cache(path: str, timeout_s: float = 7200.0) -> None: start = time.time() while not os.path.exists(path): if time.time() - start > timeout_s: raise TimeoutError(f"Timed out waiting for prepared cache: {path}") time.sleep(2.0) def load_or_build_sft_examples( args: Args, *, rows: List[Dict[str, Any]], tokenizer: Any, rank: int, world_size: int, progress_callback: Any = None, ) -> List[Dict[str, Any]]: cache_path = _prepared_sft_cache_path(args) if os.path.exists(cache_path): return _read_jsonl(cache_path) if rank == 0: print(f"[dataset build][sft stage {args.stage_i}] building prepared cache: {cache_path}", flush=True) examples = build_training_examples( rows, tokenizer=tokenizer, stage_i=args.stage_i, total_empties_hint=args.total_empties_hint, progress_every_rows=10, progress_callback=progress_callback, ) _write_jsonl(cache_path, examples) return examples _wait_for_cache(cache_path) return _read_jsonl(cache_path) @torch.no_grad() def run_eval(args: Args, rows: List[Dict[str, Any]], model: torch.nn.Module, tokenizer: Any, device: torch.device): model.eval() total_cells = 0 parse_ok = 0.0 canonical_ok = 0.0 exact_set_match = 0.0 includes_gt = 0.0 precision_sum = 0.0 recall_sum = 0.0 predicted_size_sum = 0.0 good_count_sum = 0.0 bad_count_sum = 0.0 solve_ok = 0 solve_rows = 0 printed = 0 for row in rows: solved = make_solved_grid_from_row(row) row_all_exact = True row_has_eval_cell = False row_debug_lines: List[str] = [] for ex in build_cell_examples_from_row(row): target_values = stage_i_consistent_values(ex.grid, target_cell=ex.target_cell, stage_i=args.stage_i) if not target_size_allowed(len(target_values), int(args.eval_target_size_min), int(args.eval_target_size_max)): continue row_has_eval_cell = True prompt = build_multi_output_cell_prompt( ex.grid, target_cell=ex.target_cell, stage_i=args.stage_i, tokenizer=tokenizer, turn_idx=ex.turn_idx, total_turns=ex.total_turns, prev_output_flag=None, total_empties_hint=args.total_empties_hint, ) enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False) enc = {k: v.to(device) for k, v in enc.items()} out = model.generate( **enc, max_new_tokens=max(1, int(args.max_completion_length)), do_sample=False, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id, ) pred_text = tokenizer.decode(out[0][int(enc["input_ids"].shape[1]) :], skip_special_tokens=True).strip() info = score_prediction_text( text=pred_text, grid=ex.grid, solved=solved, target_cell=ex.target_cell, stage_i=args.stage_i, reward_good_value=1.0, penalty_bad_value=1.75, penalty_malformed=4.0, penalty_empty=0.5, penalty_singleton=1.5, ) total_cells += 1 parse_ok += float(info["parse_ok"]) canonical_ok += float(info["strict_canonical"]) exact_set_match += float(info["exact_set_match"]) includes_gt += float(info["includes_ground_truth"]) precision_sum += float(info["value_precision"]) recall_sum += float(info["value_recall"]) predicted_size_sum += float(info["num_predicted_values"]) good_count_sum += float(info["num_i_consistent_values"]) bad_count_sum += float(info["num_non_i_consistent_values"]) if float(info["exact_set_match"]) < 0.5: row_all_exact = False if printed < int(args.debug_print_limit): row_debug_lines.append( f"[baseline sft eval debug] true_values={info['target_values']} " f"predicted_values={info['predicted_values']} output={pred_text!r}" ) if row_has_eval_cell: if printed < int(args.debug_print_limit) and row_debug_lines: print("[baseline sft eval debug] puzzle_outputs_begin", flush=True) for line in row_debug_lines: print(line, flush=True) print("[baseline sft eval debug] puzzle_outputs_end", flush=True) printed += 1 solve_ok += int(row_all_exact) solve_rows += 1 out = { "parse_rate": float(parse_ok / max(1, total_cells)), "strict_canonical_rate": float(canonical_ok / max(1, total_cells)), "exact_set_match_rate": float(exact_set_match / max(1, total_cells)), "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)), "value_precision": float(precision_sum / max(1, total_cells)), "value_recall": float(recall_sum / max(1, total_cells)), "avg_predicted_set_size": float(predicted_size_sum / max(1, total_cells)), "avg_num_i_consistent_values": float(good_count_sum / max(1, total_cells)), "avg_num_non_i_consistent_values": float(bad_count_sum / max(1, total_cells)), "solve_rate": float(solve_ok / max(1, solve_rows)), } print( f"[baseline sft eval] parse={out['parse_rate']:.3f} canonical={out['strict_canonical_rate']:.3f} " f"exact={out['exact_set_match_rate']:.3f} precision={out['value_precision']:.3f} " f"recall={out['value_recall']:.3f} solve={out['solve_rate']:.3f}", flush=True, ) model.train() return out def parse_args() -> Args: p = argparse.ArgumentParser() p.add_argument("--model_name", type=str, default="Qwen/Qwen2.5-7B-Instruct") p.add_argument("--train_jsonl", type=str, required=True) p.add_argument("--eval_jsonl", type=str, default="") p.add_argument("--output_dir", type=str, required=True) p.add_argument("--cache_dir", type=str, default="/home/ubuntu/curriculum-CoT/.hf_cache") p.add_argument("--init_adapter_dir", type=str, default="") p.add_argument("--seed", type=int, default=0) p.add_argument("--gpu_id", type=int, default=0) p.add_argument("--stage_i", type=int, default=1) p.add_argument("--total_empties_hint", type=int, default=10) p.add_argument("--per_device_train_batch_size", type=int, default=1) p.add_argument("--gradient_accumulation_steps", type=int, default=8) p.add_argument("--num_epochs", type=float, default=1.0) p.add_argument("--learning_rate", type=float, default=2e-4) p.add_argument("--weight_decay", type=float, default=0.0) p.add_argument( "--max_grad_norm", type=float, default=1.0, help="Clip global grad norm before each optimizer step (0 disables).", ) p.add_argument("--enable_gradient_checkpointing", action="store_true") p.add_argument("--logging_steps", type=int, default=10) p.add_argument("--save_steps", type=int, default=100) p.add_argument("--eval_steps", type=int, default=100) p.add_argument("--eval_rows", type=int, default=20) p.add_argument("--max_completion_length", type=int, default=24) p.add_argument("--lora_r", type=int, default=16) p.add_argument("--lora_alpha", type=int, default=32) p.add_argument("--lora_dropout", type=float, default=0.05) p.add_argument("--use_wandb", action="store_true") p.add_argument("--wandb_entity", type=str, default="") p.add_argument("--wandb_project", type=str, default="sudoku-multi-output-sft") p.add_argument("--wandb_run_name", type=str, default="") p.add_argument("--wandb_mode", type=str, default="online") p.add_argument("--debug_print_limit", type=int, default=3) p.add_argument("--limit_train_rows", type=int, default=0) p.add_argument("--eval_exact_set_match_stop", type=float, default=0.0) p.add_argument("--eval_value_precision_stop", type=float, default=0.0) p.add_argument("--eval_value_recall_stop", type=float, default=0.0) p.add_argument("--eval_solve_rate_stop", type=float, default=0.0) p.add_argument("--min_steps_before_stop", type=int, default=0) p.add_argument("--max_wall_clock_seconds", type=int, default=0) p.add_argument("--max_steps", type=int, default=0) p.add_argument("--multi_value_oversample_factor", type=int, default=1) p.add_argument("--train_target_size_min", type=int, default=0) p.add_argument("--train_target_size_max", type=int, default=0) p.add_argument("--eval_target_size_min", type=int, default=0) p.add_argument("--eval_target_size_max", type=int, default=0) return Args(**vars(p.parse_args())) def save_checkpoint(model: torch.nn.Module, tokenizer: Any, output_dir: str, step: int) -> None: save_checkpoint_and_update_final(model, tokenizer, output_dir, f"checkpoint-step-{step:05d}") def main() -> None: args = parse_args() preset_visible_devices = str(os.environ.get("CUDA_VISIBLE_DEVICES", "")).strip() rank = int(os.environ.get("RANK", "0")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) is_distributed = world_size > 1 is_main_process = rank == 0 if preset_visible_devices: if is_main_process: print(f"Respecting pre-set CUDA_VISIBLE_DEVICES={preset_visible_devices}", flush=True) elif int(args.gpu_id) >= 0: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_id)) if is_distributed: if torch.cuda.is_available(): torch.cuda.set_device(local_rank) dist.init_process_group(backend="nccl", timeout=timedelta(hours=2)) set_seed(args.seed + rank) os.makedirs(args.output_dir, exist_ok=True) ensure_final_checkpoint_dir(args.output_dir) cache_dir = configure_hf_cache(args.cache_dir) configure_wandb_dirs(args.output_dir) wb_run = None if is_main_process and args.use_wandb and wandb is not None: init_kwargs = { "project": args.wandb_project, "name": args.wandb_run_name or None, "mode": args.wandb_mode, } if str(args.wandb_entity).strip(): init_kwargs["entity"] = args.wandb_entity wb_run = wandb.init(**init_kwargs) print(f"W&B run id: {wb_run.id}", flush=True) print(f"W&B run URL: {wb_run.url}", flush=True) wandb.log({"prep/rows_done": 0.0, "prep/examples_built": 0.0, "prep/cache_hit": 0.0}) rows = load_jsonl_rows(args.train_jsonl, limit_rows=args.limit_train_rows) eval_source = args.eval_jsonl if str(args.eval_jsonl).strip() else args.train_jsonl eval_rows = load_jsonl_rows(eval_source, limit_rows=max(1, int(args.eval_rows))) tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=cache_dir, use_fast=True) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>" tokenizer._multi_value_oversample_factor = max(1, int(args.multi_value_oversample_factor)) tokenizer._train_target_size_min = max(0, int(args.train_target_size_min)) tokenizer._train_target_size_max = max(0, int(args.train_target_size_max)) if torch.cuda.is_available(): device = torch.device(f"cuda:{local_rank}" if is_distributed else f"cuda:{max(0, int(args.gpu_id))}") else: device = torch.device("cpu") model = AutoModelForCausalLM.from_pretrained( args.model_name, cache_dir=cache_dir, torch_dtype=pick_dtype(), low_cpu_mem_usage=True ) if str(args.init_adapter_dir).strip(): model = PeftModel.from_pretrained(model, args.init_adapter_dir, is_trainable=True) else: lora = LoraConfig( r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], ) model = get_peft_model(model, lora) if args.enable_gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) if args.enable_gradient_checkpointing and hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() if hasattr(model, "config"): model.config.use_cache = False model.to(device) model.train() if torch.cuda.is_available(): torch.backends.cudnn.benchmark = True def on_prep_progress(rows_done: int, total_rows: int, examples_built: int) -> None: if is_main_process: print( f"[dataset build][sft stage {args.stage_i}] rows={rows_done}/{total_rows} examples={examples_built}", flush=True, ) if wb_run is not None: wandb.log({"prep/rows_done": float(rows_done), "prep/examples_built": float(examples_built)}) train_examples = load_or_build_sft_examples( args, rows=rows, tokenizer=tokenizer, rank=rank, world_size=world_size, progress_callback=on_prep_progress, ) if is_main_process and wb_run is not None: wandb.log( { "prep/cache_hit": float(os.path.exists(_prepared_sft_cache_path(args))), "prep/examples_final": float(len(train_examples)), } ) optimizer = AdamW((p for p in model.parameters() if p.requires_grad), lr=args.learning_rate, weight_decay=args.weight_decay) denom = max(1, int(args.gradient_accumulation_steps)) * max(1, int(args.per_device_train_batch_size)) total_steps = max(1, math.ceil(len(train_examples) * args.num_epochs / denom)) if int(args.max_steps) > 0: total_steps = min(total_steps, int(args.max_steps)) step = 0 start_time = time.time() def average_scalar(value: float) -> float: if not is_distributed or not dist.is_initialized(): return float(value) tensor = torch.tensor(float(value), device=device, dtype=torch.float32) dist.all_reduce(tensor, op=dist.ReduceOp.SUM) return float((tensor / float(world_size)).item()) def all_reduce_gradients() -> None: if not is_distributed or not dist.is_initialized(): return for param in model.parameters(): if param.grad is None: continue dist.all_reduce(param.grad, op=dist.ReduceOp.SUM) param.grad.div_(float(world_size)) def sync_stop(local_stop: bool) -> bool: if not is_distributed or not dist.is_initialized(): return bool(local_stop) tensor = torch.tensor(1 if local_stop else 0, device=device, dtype=torch.int64) dist.all_reduce(tensor, op=dist.ReduceOp.MAX) return bool(int(tensor.item()) > 0) for epoch_idx in range(max(1, int(math.ceil(args.num_epochs)))): if is_distributed: sampler = DistributedSampler( train_examples, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed, drop_last=False, ) sampler.set_epoch(epoch_idx) order = list(iter(sampler)) else: generator = torch.Generator() generator.manual_seed(args.seed + epoch_idx) order = torch.randperm(len(train_examples), generator=generator).tolist() optimizer.zero_grad(set_to_none=True) accum_count = 0 accum_ce_sum = 0.0 microbatch_size = max(1, int(args.per_device_train_batch_size)) for batch_start in range(0, len(order), microbatch_size): batch_indices = order[batch_start : batch_start + microbatch_size] batch_examples = [train_examples[ex_idx] for ex_idx in batch_indices] ce_full = batched_completion_ce_loss( model, tokenizer, [str(ex["prompt_text"]) for ex in batch_examples], [str(ex["completion_text"]) for ex in batch_examples], device, ) loss = ce_full / max(1, int(args.gradient_accumulation_steps)) loss.backward() accum_ce_sum += float(ce_full.detach().item()) accum_count += 1 if accum_count >= int(args.gradient_accumulation_steps): all_reduce_gradients() if float(args.max_grad_norm) > 0.0: torch.nn.utils.clip_grad_norm_(model.parameters(), float(args.max_grad_norm)) optimizer.step() optimizer.zero_grad(set_to_none=True) accum_count = 0 step += 1 mean_ce = accum_ce_sum / max(1, int(args.gradient_accumulation_steps)) accum_ce_sum = 0.0 if step % int(args.logging_steps) == 0: loss_value = average_scalar(mean_ce) if is_main_process: print(f"[baseline sft train step {step:05d}] loss={loss_value:.4f}", flush=True) if wb_run is not None: wandb.log({"train/loss": loss_value, "step": step}) if step % int(args.eval_steps) == 0: if is_distributed and dist.is_initialized(): dist.barrier() should_stop_eval = False if is_main_process: ev = run_eval(args, eval_rows, model, tokenizer, device) if wb_run is not None: wandb.log({f"eval/{k}": float(v) for k, v in ev.items()} | {"step": step}) if ( args.eval_exact_set_match_stop > 0.0 and float(ev["exact_set_match_rate"]) >= args.eval_exact_set_match_stop ): save_checkpoint(model, tokenizer, args.output_dir, step) should_stop_eval = True if ( not should_stop_eval and step >= int(args.min_steps_before_stop) and args.eval_value_precision_stop > 0.0 and args.eval_value_recall_stop > 0.0 and float(ev["value_precision"]) >= args.eval_value_precision_stop and float(ev["value_recall"]) >= args.eval_value_recall_stop ): save_checkpoint(model, tokenizer, args.output_dir, step) should_stop_eval = True if ( not should_stop_eval and args.eval_solve_rate_stop > 0.0 and step >= int(args.min_steps_before_stop) and float(ev["solve_rate"]) >= args.eval_solve_rate_stop ): save_checkpoint(model, tokenizer, args.output_dir, step) should_stop_eval = True should_stop_eval = sync_stop(should_stop_eval) if is_distributed and dist.is_initialized(): dist.barrier() if should_stop_eval: if is_main_process and wb_run is not None: wb_run.finish() if is_distributed and dist.is_initialized(): dist.destroy_process_group() return if step % int(args.save_steps) == 0: if is_distributed and dist.is_initialized(): dist.barrier() if is_main_process: save_checkpoint(model, tokenizer, args.output_dir, step) if is_distributed and dist.is_initialized(): dist.barrier() reached_limit = step >= total_steps exceeded_wall = bool(args.max_wall_clock_seconds) and ( time.time() - start_time >= float(args.max_wall_clock_seconds) ) should_stop = sync_stop(reached_limit or exceeded_wall) if should_stop: break if sync_stop(step >= total_steps): break if is_distributed and dist.is_initialized(): dist.barrier() if is_main_process: save_checkpoint(model, tokenizer, args.output_dir, step) if wb_run is not None: wb_run.finish() if is_distributed and dist.is_initialized(): dist.barrier() dist.destroy_process_group() if __name__ == "__main__": main()