| """ |
| ORPO (Odds Ratio Preference Optimization) training script. |
| Uses TRL 0.29.0 ORPOTrainer/ORPOConfig (trl.experimental.orpo). |
| Optimized for 8x NVIDIA B200 GPUs (183GB VRAM each, ~1.47TB total). |
| |
| Usage: |
| # Full training (8 GPU DDP) |
| torchrun --nproc_per_node=8 train/orpo.py \ |
| --config configs/korean_3b_orpo.yaml |
| |
| # Quick test (200 steps) |
| python train/orpo.py --config configs/korean_3b_orpo.yaml --max_steps 200 |
| |
| # Single GPU test |
| python train/orpo.py --config configs/korean_3b_orpo.yaml --device cuda:0 |
| |
| Prerequisites: |
| pip install trl==0.29.0 transformers accelerate peft datasets |
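
Example config (illustrative sketch; the path is hypothetical, and the keys
mirror the CLI flags because the loader copies matching YAML keys onto the
argparse namespace):

    # configs/korean_3b_orpo.yaml
    model_path: /path/to/korean-3b-sft   # hypothetical local HF checkpoint
    dataset: kuotient/orca-math-korean-dpo-pairs
    output_dir: checkpoints/korean_3b_orpo
    epochs: 3
    lr: 5.0e-6
    beta: 0.1
    batch_size: 4
    gradient_accumulation_steps: 4
    max_length: 1536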
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import datetime |
| import json |
| import logging |
| import os |
| import signal as _signal_mod |
| import sys |
| import time |
| import traceback |
| from pathlib import Path |
|
|
| import torch |
| from datasets import Dataset, load_dataset |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| EarlyStoppingCallback, |
| TrainerCallback, |
| ) |
|
|
| |
| try: |
| from trl.experimental.orpo import ORPOConfig, ORPOTrainer |
| except ImportError: |
| print("ERROR: trl not installed or outdated. Run: pip install trl==0.29.0") |
| sys.exit(1) |
|
|
| |
| try: |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| from scripts.telegram_notify import send_telegram_safe |
| HAS_TELEGRAM = True |
| except ImportError: |
| HAS_TELEGRAM = False |
| def send_telegram_safe(msg, **kw): return False |
|
|


logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("orpo")


class ORPOMonitorCallback(TrainerCallback):
    """Monitors ORPO-specific metrics and sends alerts on anomalies."""

    def __init__(self, alert_fn=send_telegram_safe):
        self.alert_fn = alert_fn
        self.start_time = None
        self.last_eval_loss = None
        self.eval_loss_increases = 0
        self.negative_margin_streak = 0

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()
        log.info("ORPO training begin -- monitoring active")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is None:
            return
        step = state.global_step
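
        # A sustained negative rewards/margins means rejected responses are
        # scoring higher than chosen ones, i.e. preference learning is failing.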
        margin = logs.get("rewards/margins")
        if margin is not None:
            if margin < 0:
                self.negative_margin_streak += 1
                if self.negative_margin_streak >= 10:
                    msg = (f"[ORPO ALERT] rewards/margins negative for "
                           f"{self.negative_margin_streak} consecutive logs at step {step} "
                           f"(margin={margin:.4f})")
                    log.warning(msg)
                    self.alert_fn(msg)
                    # Re-arm the counter so the alert fires once per episode,
                    # not on every subsequent log line.
                    self.negative_margin_streak = 0
            else:
                self.negative_margin_streak = 0

        loss = logs.get("loss")
        chosen = logs.get("rewards/chosen")
        rejected = logs.get("rewards/rejected")
        if loss is not None:
            elapsed = time.time() - self.start_time if self.start_time else 0
            log.info(
                f"step={step} loss={loss:.4f} "
                f"margin={margin if margin is not None else 'N/A'} "
                f"chosen={chosen if chosen is not None else 'N/A'} "
                f"rejected={rejected if rejected is not None else 'N/A'} "
                f"elapsed={elapsed/3600:.1f}h"
            )

        # math.isfinite catches both NaN and Inf, matching the alert message.
        if loss is not None and (not isinstance(loss, (int, float)) or not math.isfinite(loss)):
            msg = f"[ORPO CRITICAL] NaN/Inf loss detected at step {step}!"
            log.error(msg)
            self.alert_fn(msg)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics is None:
            return
        eval_loss = metrics.get("eval_loss")
        step = state.global_step

        if eval_loss is not None:
            log.info(f"[EVAL] step={step} eval_loss={eval_loss:.4f}")
            if self.last_eval_loss is not None and eval_loss > self.last_eval_loss:
                self.eval_loss_increases += 1
                log.warning(
                    f"[EVAL] eval_loss increased: {self.last_eval_loss:.4f} -> {eval_loss:.4f} "
                    f"({self.eval_loss_increases}/3 before early stop)"
                )
            else:
                self.eval_loss_increases = 0
            self.last_eval_loss = eval_loss

    def on_train_end(self, args, state, control, **kwargs):
        elapsed = time.time() - self.start_time if self.start_time else 0
        log.info(f"ORPO training ended -- total time: {elapsed/3600:.2f}h, "
                 f"total steps: {state.global_step}")

    def on_save(self, args, state, control, **kwargs):
        log.info(f"Checkpoint saved at step {state.global_step}")


class VRAMMonitorCallback(TrainerCallback):
    """Measures peak VRAM usage across all GPUs during training."""

    def on_train_begin(self, args, state, control, **kwargs):
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                torch.cuda.reset_peak_memory_stats(i)
            log.info("[VRAM] Peak memory stats reset for all GPUs")

    def on_train_end(self, args, state, control, **kwargs):
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                peak_mb = torch.cuda.max_memory_allocated(i) / (1024**2)
                log.info(f"[VRAM] GPU {i} peak: {peak_mb:.0f} MiB")


def load_hf_preference_dataset(dataset_name: str, token: str | None = None) -> Dataset:
    """Load and normalize a HuggingFace preference dataset to {prompt, chosen, rejected}."""
    ds = load_dataset(dataset_name, split="train", token=token)
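
    # question/system schema (orca-math style): fold the system prompt into the prompt.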
| if "question" in ds.column_names and "chosen" in ds.column_names: |
| def normalize(example): |
| prompt = example.get("system", "") + "\n" + example["question"] |
| return {"prompt": prompt.strip(), "chosen": example["chosen"], "rejected": example["rejected"]} |
| return ds.map(normalize, remove_columns=ds.column_names) |
|
|
| |
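
    # A/B schema: orig_preference marks which of orig_response_A/B was preferred.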
| if "orig_preference" in ds.column_names: |
| def normalize_pref(example): |
| prompt = example.get("orig_instruction", example.get("instruction", "")) |
| if example["orig_preference"] == "B": |
| return {"prompt": prompt, "chosen": example["orig_response_B"], "rejected": example["orig_response_A"]} |
| else: |
| return {"prompt": prompt, "chosen": example["orig_response_A"], "rejected": example["orig_response_B"]} |
| return ds.map(normalize_pref, remove_columns=ds.column_names) |
|
|
| |
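
    # Already normalized: pass through unchanged.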
    if all(c in ds.column_names for c in ["prompt", "chosen", "rejected"]):
        return ds

    raise ValueError(f"Unknown dataset format. Columns: {ds.column_names}")


def load_custom_jsonl(path: str) -> Dataset:
    """Load custom JSONL with {prompt, chosen, rejected} fields."""
    data = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank lines and a trailing newline
                data.append(json.loads(line))
    return Dataset.from_list(data)


def load_yaml_config(path: str) -> dict:
    """Load YAML config and return as dict."""
    import yaml
    with open(path) as f:
        return yaml.safe_load(f)


def main():
    parser = argparse.ArgumentParser(description="ORPO Training (TRL 0.29.0 -- 8xB200 optimized)")
    parser.add_argument("--config", type=str, default=None, help="YAML config file path")
    parser.add_argument("--model_path", type=str, default=None, help="HF format model path")
    parser.add_argument("--dataset", type=str, default="kuotient/orca-math-korean-dpo-pairs")
    parser.add_argument("--custom_data_path", type=str, default=None, help="Custom JSONL preference data")
    parser.add_argument("--output_dir", type=str, default="checkpoints/korean_3b_orpo")
    parser.add_argument("--hf_token", type=str, default=None)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--beta", type=float, default=0.1, help="ORPO beta (odds ratio weight)")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--max_length", type=int, default=1536)
    parser.add_argument("--bf16", action="store_true", default=True)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--eval_split_ratio", type=float, default=0.05, help="Fraction of data for eval")
    parser.add_argument("--eval_steps", type=int, default=500)
    parser.add_argument("--early_stopping_patience", type=int, default=3)
    parser.add_argument("--max_steps", type=int, default=-1, help="Override max steps (for quick test)")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save_total_limit", type=int, default=5)
    parser.add_argument("--warmup_ratio", type=float, default=0.05)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--gradient_checkpointing", action="store_true", default=True)
    parser.add_argument("--report_to", type=str, default="none")
    parser.add_argument("--dataset_num_proc", type=int, default=8,
                        help="Number of processes for parallel tokenization in ORPOTrainer")
    parser.add_argument("--dataloader_num_workers", type=int, default=4,
                        help="Number of dataloader worker processes")
    parser.add_argument("--no_load_best", action="store_true", default=False,
                        help="Disable load_best_model_at_end (for sweep/quick tests)")
    parser.add_argument("--max_samples", type=int, default=0,
                        help="Limit dataset size (0=use all, >0=subset for benchmarking)")
    parser.add_argument("--skip_filter", action="store_true", default=False,
                        help="Skip NaN-prevention filter (for benchmarking only)")
    args = parser.parse_args()

    # Values from the YAML config override argparse defaults (keys mirror flags).
    if args.config:
        cfg = load_yaml_config(args.config)
        for key, value in cfg.items():
            if hasattr(args, key):
                setattr(args, key, value)

    if not args.model_path:
        parser.error("--model_path is required (or set model_path in YAML config)")

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    is_main = local_rank == 0
    if is_main:
        log.info("=" * 70)
        log.info("ORPO Training Configuration (8xB200 optimized)")
        log.info("=" * 70)
        for k, v in sorted(vars(args).items()):
            log.info(f"  {k}: {v}")
        log.info("=" * 70)

        # Gate GPU inventory logging on is_main to avoid one copy per DDP rank.
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                mem = torch.cuda.get_device_properties(i).total_memory / 1e9
                log.info(f"  GPU {i}: {torch.cuda.get_device_name(i)} ({mem:.1f} GB)")

    if not Path(args.model_path).exists():
        raise FileNotFoundError(f"Model path not found: {args.model_path}")
    if args.custom_data_path and not Path(args.custom_data_path).exists():
        raise FileNotFoundError(f"Data path not found: {args.custom_data_path}")

    if is_main:
        log.info("--- DDP/NCCL Environment ---")
        for env_key in ["RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_ADDR", "MASTER_PORT",
                        "NCCL_IB_DISABLE", "NCCL_BUFFSIZE", "NCCL_P2P_LEVEL",
                        "OMP_NUM_THREADS", "PYTORCH_CUDA_ALLOC_CONF"]:
            log.info(f"  {env_key}={os.environ.get(env_key, '(not set)')}")
        log.info(f"  torch.distributed.is_available={torch.distributed.is_available()}")
        if torch.distributed.is_initialized():
            log.info(f"  world_size={torch.distributed.get_world_size()}, "
                     f"rank={torch.distributed.get_rank()}")
| log.info(f"Loading model from {args.model_path}...") |
| t0 = time.time() |
| try: |
| model = AutoModelForCausalLM.from_pretrained( |
| args.model_path, |
| torch_dtype=torch.bfloat16, |
| attn_implementation="flash_attention_2", |
| ) |
| except Exception as e: |
| log.error(f"Model loading failed: {e}") |
| send_telegram_safe(f"[ORPO FATAL] Model load failed: {e}") |
| raise |
| tokenizer = AutoTokenizer.from_pretrained(args.model_path) |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
| if is_main: |
| n_params = sum(p.numel() for p in model.parameters()) |
| log.info(f"Model loaded: {n_params:,} params in {time.time()-t0:.1f}s") |
| log.info(f"Tokenizer: vocab_size={tokenizer.vocab_size}, " |
| f"pad_token='{tokenizer.pad_token}', eos_token='{tokenizer.eos_token}'") |
|
|
| |

    t0 = time.time()
    try:
        if args.custom_data_path:
            log.info(f"Loading custom data from {args.custom_data_path}...")
            fsize_mb = Path(args.custom_data_path).stat().st_size / 1e6
            log.info(f"  File size: {fsize_mb:.1f} MB")
            dataset = load_custom_jsonl(args.custom_data_path)
        else:
            log.info(f"Loading dataset {args.dataset}...")
            dataset = load_hf_preference_dataset(args.dataset, token=args.hf_token)
    except Exception as e:
        log.error(f"Dataset loading failed: {e}")
        send_telegram_safe(f"[ORPO FATAL] Data load failed: {e}")
        raise

    if args.max_samples > 0 and len(dataset) > args.max_samples:
        dataset = dataset.select(range(args.max_samples))
        if is_main:
            log.info(f"[BENCH] Dataset subset: {args.max_samples:,} samples")

    if is_main:
        log.info(f"Dataset loaded: {len(dataset)} pairs in {time.time()-t0:.1f}s")

        sample = dataset[0]
        log.info(f"Sample keys: {list(sample.keys())}")
        for key in ["prompt", "chosen", "rejected"]:
            if key not in sample:
                raise ValueError(f"Dataset missing required column: {key}")
            val = sample[key]
            log.info(f"  {key}: {str(val)[:100]}...")

        # Rough char-length stats over a sample, to sanity-check max_length.
        sample_size = min(1000, len(dataset))
        lengths = [len(str(dataset[i]["prompt"])) + max(len(str(dataset[i]["chosen"])),
                   len(str(dataset[i]["rejected"]))) for i in range(sample_size)]
        avg_len = sum(lengths) / len(lengths)
        max_len = max(lengths)
        log.info(f"  Char lengths (sample {sample_size}): avg={avg_len:.0f}, max={max_len:,}")
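
    # NaN-prevention filter: drop pairs where the prompt leaves no room for a
    # response, or where TRL-side truncation could yield an empty completion
    # (a plausible source of NaN ORPO losses; the thresholds are heuristics).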
    if args.skip_filter:
        if is_main:
            log.info("[BENCH] Skipping NaN-prevention filter (--skip_filter)")
    else:
        pre_filter = len(dataset)

        def _has_response_room(example):
            prompt_tok_len = len(tokenizer.encode(example["prompt"], add_special_tokens=False))
            chosen_tok_len = len(tokenizer.encode(example["chosen"], add_special_tokens=False))
            rejected_tok_len = len(tokenizer.encode(example["rejected"], add_special_tokens=False))
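
            # Require at least 16 tokens of headroom for a response after the prompt.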
            if prompt_tok_len + 16 > args.max_length:
                return False
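
            # Heuristic guard against TRL-side truncation: prompt + completion
            # beyond 2 * max_length risks the completion being cut to empty.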
            if prompt_tok_len + chosen_tok_len + 2 > args.max_length * 2:
                return False
            if prompt_tok_len + rejected_tok_len + 2 > args.max_length * 2:
                return False
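
            # A single response that fills the whole window is also unsafe.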
            longer = max(chosen_tok_len, rejected_tok_len)
            if longer >= args.max_length:
                return False

            return True
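
        # Rank 0 filters with multiprocessing; other ranks rerun with num_proc=1
        # and should hit the HF datasets cache instead of recomputing.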
        dataset = dataset.filter(_has_response_room, num_proc=min(args.dataset_num_proc, 32) if is_main else 1)
        if is_main:
            log.info(f"Filtered: {pre_filter:,} -> {len(dataset):,} "
                     f"(removed {pre_filter - len(dataset):,} samples with prompt > max_length-16 or TRL truncation risk)")

    split = dataset.train_test_split(test_size=args.eval_split_ratio, seed=args.seed)
    train_dataset = split["train"]
    eval_dataset = split["test"]
    if is_main:
        log.info(f"Train: {len(train_dataset):,}, Eval: {len(eval_dataset):,}")

    # Effective batch spans all DDP ranks: prefer WORLD_SIZE (set by torchrun)
    # over the local device count, which undercounts multi-node runs.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    n_gpus = world_size if world_size > 1 else (torch.cuda.device_count() or 1)
    eff_batch = args.batch_size * args.gradient_accumulation_steps * n_gpus
    steps_per_epoch = len(train_dataset) // eff_batch
    total_steps = args.max_steps if args.max_steps > 0 else steps_per_epoch * args.epochs
    computed_warmup_steps = int(total_steps * args.warmup_ratio)
    if is_main:
        log.info(f"Training plan: eff_batch={eff_batch}, steps/epoch={steps_per_epoch:,}, "
                 f"total={total_steps:,}, warmup={computed_warmup_steps}")

    if world_size > 1:
        effective_num_proc = args.dataset_num_proc if is_main else 1
        if is_main:
            log.info(f"DDP tokenization: rank 0 uses num_proc={args.dataset_num_proc}, "
                     f"other {world_size-1} ranks use num_proc=1 (cache)")
    else:
        effective_num_proc = args.dataset_num_proc
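
    # ORPO is reference-model-free: beta weights the odds-ratio penalty added
    # on top of the standard NLL loss.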
    orpo_config = ORPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.lr,
        beta=args.beta,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=computed_warmup_steps,
        weight_decay=args.weight_decay,
        bf16=args.bf16,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        max_length=args.max_length,
        gradient_checkpointing=args.gradient_checkpointing,
        report_to=args.report_to,
        remove_unused_columns=False,
        eval_strategy="steps",
        eval_steps=args.eval_steps,
        metric_for_best_model="eval_loss" if not args.no_load_best else None,
        load_best_model_at_end=not args.no_load_best,
        greater_is_better=False if not args.no_load_best else None,
        max_steps=args.max_steps,
        seed=args.seed,
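        # Input-pipeline and DDP throughput knobs (single-node multi-GPU runs).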
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_pin_memory=True,
        ddp_find_unused_parameters=False,
        ddp_timeout=7200,
        dataset_num_proc=effective_num_proc,
    )
| log.info("Initializing ORPOTrainer (tokenization will happen here — may take a while)...") |
| t0 = time.time() |
| monitor = ORPOMonitorCallback() |
| try: |
| trainer = ORPOTrainer( |
| model=model, |
| args=orpo_config, |
| train_dataset=train_dataset, |
| eval_dataset=eval_dataset, |
| processing_class=tokenizer, |
| callbacks=[cb for cb in [ |
| EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience) |
| if not args.no_load_best else None, |
| monitor, |
| VRAMMonitorCallback(), |
| ] if cb is not None], |
| ) |
| except Exception as e: |
| log.error(f"ORPOTrainer init failed: {e}\n{traceback.format_exc()}") |
| send_telegram_safe(f"[ORPO FATAL] Trainer init failed: {e}") |
| raise |
| if is_main: |
| log.info(f"ORPOTrainer initialized in {time.time()-t0:.1f}s " |
| f"(dataset_num_proc={args.dataset_num_proc})") |
|
|
| |
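
    # Cluster preemptions typically arrive as SIGHUP/SIGTERM: save an emergency
    # checkpoint so the run can resume from it.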
    def _graceful_shutdown_handler(signum, frame):
        sig_name = _signal_mod.Signals(signum).name
        log.warning(f"Received {sig_name}. Saving emergency checkpoint...")
        try:
            emergency_path = os.path.join(args.output_dir, "emergency_checkpoint")
            trainer.save_model(emergency_path)
            log.info(f"Emergency checkpoint saved to {emergency_path}")
            send_telegram_safe(
                f"[ORPO] Signal {sig_name} received at step {trainer.state.global_step}. "
                f"Emergency checkpoint saved."
            )
        except Exception as e:
            log.error(f"Emergency save failed: {e}")
            send_telegram_safe(f"[ORPO] Emergency save FAILED after {sig_name}: {e}")
        sys.exit(1)

    for _sig in (_signal_mod.SIGHUP, _signal_mod.SIGTERM):
        _signal_mod.signal(_sig, _graceful_shutdown_handler)

    if is_main and torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
        alloc = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        log.info(f"Pre-train VRAM: allocated={alloc:.1f}GB, reserved={reserved:.1f}GB")

    start_msg = (
        f"[ORPO] Training started\n"
        f"  model: {args.model_path}\n"
        f"  beta: {args.beta}, lr: {args.lr}\n"
        f"  train: {len(train_dataset):,}, eval: {len(eval_dataset):,}\n"
        f"  eff_batch: {eff_batch}, steps/epoch: {steps_per_epoch:,}, total: {total_steps:,}\n"
        f"  warmup: {computed_warmup_steps} steps ({args.warmup_ratio*100:.0f}%)\n"
        f"  max_length: {args.max_length}, max_steps: {args.max_steps}\n"
        f"  dataset_num_proc: {args.dataset_num_proc}, dl_workers: {args.dataloader_num_workers}"
    )
    if is_main:
        # Notify from rank 0 only, to avoid one Telegram message per DDP rank.
        log.info(start_msg.replace("[ORPO] ", ""))
        send_telegram_safe(start_msg)

    try:
        trainer.train()

        if is_main and torch.cuda.is_available():
            peak = torch.cuda.max_memory_allocated() / 1e9
            log.info(f"Peak VRAM usage: {peak:.1f}GB")

        trainer.save_model(args.output_dir)
        log.info(f"Model saved to {args.output_dir}")

        # Walk the log history backwards for the most recent value of each metric.
        final_metrics = {}
        for entry in reversed(trainer.state.log_history):
            if "loss" in entry and "loss" not in final_metrics:
                final_metrics["loss"] = entry["loss"]
            if "eval_loss" in entry and "eval_loss" not in final_metrics:
                final_metrics["eval_loss"] = entry["eval_loss"]
            if "rewards/margins" in entry and "rewards/margins" not in final_metrics:
                final_metrics["rewards/margins"] = entry["rewards/margins"]
            if len(final_metrics) >= 3:
                break

        done_msg = (
            f"[ORPO] Training complete!\n"
            f"  output: {args.output_dir}\n"
            f"  steps: {trainer.state.global_step}\n"
            f"  final loss: {final_metrics.get('loss', 'N/A')}\n"
            f"  final eval_loss: {final_metrics.get('eval_loss', 'N/A')}\n"
            f"  final margins: {final_metrics.get('rewards/margins', 'N/A')}"
        )
        if is_main:
            log.info(done_msg.replace("[ORPO] ", ""))
            send_telegram_safe(done_msg)

    except KeyboardInterrupt:
        log.warning("Training interrupted by user (KeyboardInterrupt)")
        send_telegram_safe(f"[ORPO] Training interrupted at step {trainer.state.global_step}")
        trainer.save_model(os.path.join(args.output_dir, "interrupted_checkpoint"))
        log.info("Interrupted checkpoint saved.")

    except Exception as e:
        tb = traceback.format_exc()
        error_msg = f"[ORPO] Training FAILED at step {trainer.state.global_step}: {e}"
        log.error(f"{error_msg}\n{tb}")
        send_telegram_safe(f"{error_msg}\n{tb[:500]}")
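
        # Best effort: persist an error checkpoint before re-raising.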
        try:
            trainer.save_model(os.path.join(args.output_dir, "error_checkpoint"))
            log.info("Error checkpoint saved.")
        except Exception:
            log.error("Error checkpoint save also failed.")
        raise


if __name__ == "__main__":
    main()