import argparse import json import inspect import math import time from pathlib import Path from typing import Any, Dict, Optional, Tuple, List import torch import yaml from datasets import load_dataset, DatasetDict from huggingface_hub import snapshot_download from transformers import ( AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, TrainerCallback, EarlyStoppingCallback, set_seed, ) from transformers.trainer_utils import get_last_checkpoint from peft import ( LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel, ) from trl import DPOTrainer, DPOConfig try: import wandb WANDB_AVAILABLE = True except ImportError: WANDB_AVAILABLE = False wandb = None # -------------------------- # Helpers # -------------------------- def _dtype_from_str(s: str) -> torch.dtype: s = (s or "").lower() if s in ("float16", "fp16"): return torch.float16 if s in ("bfloat16", "bf16"): return torch.bfloat16 if s in ("float32", "fp32"): return torch.float32 raise ValueError(f"Unknown torch_dtype: {s}") def _now_iso() -> str: return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()) def _safe_exp(x: float) -> float: x = min(float(x), 50.0) return float(math.exp(x)) def _ensure_dir(p: Path) -> Path: p.mkdir(parents=True, exist_ok=True) return p def _looks_like_model_dir(p: Path) -> bool: if not p.exists() or not p.is_dir(): return False if (p / "config.json").exists(): return True if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")): return True return False def _infer_target_modules(model) -> List[str]: names = set() for n, _ in model.named_modules(): names.add(n.split(".")[-1]) for group in [ ["q_proj", "k_proj", "v_proj", "o_proj"], ["Wqkv", "out_proj"], ["query_key_value", "dense"], ["c_attn", "c_proj"], ]: if all(x in names for x in group): return group fallback = [ x for x in [ "q_proj", "k_proj", "v_proj", "o_proj", "c_attn", "c_proj", "out_proj", "dense", ] if x in names ] if fallback: return fallback raise ValueError( "Could not auto-infer target_modules. Set peft.target_modules explicitly." ) def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]: return cfg.get("model", {}).get("attn_implementation", None) # -------------------------- # Wandb Integration # -------------------------- def setup_wandb(cfg: Dict[str, Any], run_dir: Path): """Initialize Wandb if enabled in configuration.""" wandb_cfg = cfg.get("wandb", {}) if not wandb_cfg.get("enabled", False): print("Wandb logging disabled") return None if not WANDB_AVAILABLE: print("Wandb not available. Install with: pip install wandb") return None project = wandb_cfg.get("project", "dpo-training") entity = wandb_cfg.get("entity", None) name = wandb_cfg.get("name", None) tags = wandb_cfg.get("tags", []) notes = wandb_cfg.get("notes", None) try: wandb.init( project=project, entity=entity, name=name, tags=tags, notes=notes, dir=str(run_dir), config={ "model": cfg.get("model", {}), "data": cfg.get("data", {}), "peft": cfg.get("peft", {}), "dpo": cfg.get("dpo", {}), "train": cfg.get("train", {}), "run_dir": str(run_dir), } ) print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'") return wandb except Exception as e: print(f"Failed to initialize Wandb: {e}") return None def finish_wandb(): """Finish Wandb run if active.""" if WANDB_AVAILABLE and wandb.run is not None: wandb.finish() print("Wandb run finished") # -------------------------- # JSONL Logger Callback # -------------------------- class JsonlLoggerCallback(TrainerCallback): def __init__(self, run_dir: Path): self.run_dir = run_dir self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl" self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl" self.start_time = None def _eta(self, global_step: int, max_steps: int) -> Optional[str]: if self.start_time is None or global_step <= 0 or max_steps <= 0: return None elapsed = time.time() - self.start_time sec_per_step = elapsed / global_step remaining = max(0, max_steps - global_step) * sec_per_step h = int(remaining // 3600) m = int((remaining % 3600) // 60) s = int(remaining % 60) return f"{h:02d}:{m:02d}:{s:02d}" def on_train_begin(self, args, state, control, **kwargs): self.start_time = time.time() def on_log(self, args, state, control, logs=None, **kwargs): if not logs: return max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0 progress_pct = ( (100.0 * state.global_step / max_steps) if max_steps > 0 else None ) epoch_pct = None if ( state.epoch is not None and args.num_train_epochs and args.num_train_epochs > 0 ): epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs)) payload = { "ts": _now_iso(), "event": "train_log", "step": int(state.global_step), "epoch": round(float(state.epoch), 4) if state.epoch is not None else None, "progress_pct": ( round(progress_pct, 2) if progress_pct is not None else None ), "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None, "eta": self._eta(int(state.global_step), max_steps), "max_grad_norm": getattr(args, "max_grad_norm", None), **logs, } with self.train_log_path.open("a", encoding="utf-8") as f: f.write(json.dumps(payload, ensure_ascii=False) + "\n") def on_evaluate(self, args, state, control, metrics=None, **kwargs): if not metrics: return eval_loss = metrics.get("eval_loss", None) payload = { "ts": _now_iso(), "event": "eval", "step": int(state.global_step), "epoch": float(state.epoch) if state.epoch is not None else None, **metrics, } with self.eval_log_path.open("a", encoding="utf-8") as f: f.write(json.dumps(payload, ensure_ascii=False) + "\n") # -------------------------- # Custom Exceptions # -------------------------- class DataFormattingError(Exception): """Exception raised for errors in data formatting.""" pass class DataValidationError(Exception): """Exception raised for errors in data validation.""" pass # -------------------------- # Data Pipeline (DPO Format) # -------------------------- def format_dpo_example( example: Dict[str, Any], cfg: Dict[str, Any], tokenizer ) -> Dict[str, Any]: """ Format DPO data which requires prompt, chosen, and rejected completions. Returns formatted prompt, chosen, and rejected texts. Raises DataFormattingError if formatting fails. """ data_cfg = cfg["data"] format_type = data_cfg.get("format_type", "chatml") # Get field names from config prompt_field = data_cfg.get("prompt_field", "prompt") chosen_field = data_cfg.get("chosen_field", "chosen") rejected_field = data_cfg.get("rejected_field", "rejected") # Extract text from example prompt = example.get(prompt_field, "") chosen = example.get(chosen_field, "") rejected = example.get(rejected_field, "") # Validate required fields if not prompt: raise DataFormattingError(f"Empty prompt field: {prompt_field}") if not chosen: raise DataFormattingError(f"Empty chosen field: {chosen_field}") if not rejected: raise DataFormattingError(f"Empty rejected field: {rejected_field}") if format_type == "chatml": system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.") # Format prompt with system message messages = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": prompt}) # Apply chat template for prompt only (without assistant response) formatted_prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Chosen and rejected are just the completions (will be added by DPOTrainer) formatted_chosen = chosen formatted_rejected = rejected # Add EOS token to completions if tokenizer.eos_token: if not formatted_chosen.endswith(tokenizer.eos_token): formatted_chosen += tokenizer.eos_token if not formatted_rejected.endswith(tokenizer.eos_token): formatted_rejected += tokenizer.eos_token elif format_type == "alpaca": # Alpaca format prefix = f"Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n" formatted_prompt = prefix formatted_chosen = chosen formatted_rejected = rejected if tokenizer.eos_token: if not formatted_chosen.endswith(tokenizer.eos_token): formatted_chosen += tokenizer.eos_token if not formatted_rejected.endswith(tokenizer.eos_token): formatted_rejected += tokenizer.eos_token elif format_type == "custom": # Custom template template = data_cfg.get("custom_template", "{prompt}") formatted_prompt = template.format(prompt=prompt) formatted_chosen = chosen formatted_rejected = rejected if tokenizer.eos_token: if not formatted_chosen.endswith(tokenizer.eos_token): formatted_chosen += tokenizer.eos_token if not formatted_rejected.endswith(tokenizer.eos_token): formatted_rejected += tokenizer.eos_token else: raise ValueError(f"Unsupported format_type: {format_type}") return { "prompt": formatted_prompt, "chosen": formatted_chosen, "rejected": formatted_rejected, } def validate_dpo_data(dataset, stage: str = "train") -> None: """ Validate DPO dataset has all required fields and proper structure. Args: dataset: Dataset to validate stage: Training stage ("train" or "eval") Raises: DataValidationError if validation fails """ required_fields = ["prompt", "chosen", "rejected"] # Check required fields exist for field in required_fields: if field not in dataset.column_names: raise DataValidationError( f"{stage} dataset missing required field: {field}. " f"Available fields: {dataset.column_names}" ) # Sample validation - check first example if len(dataset) > 0: sample = dataset[0] for field in required_fields: if not sample[field] or len(sample[field].strip()) == 0: logger.warning(f"{stage} dataset has empty {field} in first example") logger.info(f"{stage} dataset validation passed: {len(dataset)} examples") def build_dpo_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]: """ Build datasets for DPO training. Expected JSONL format: {"prompt": "...", "chosen": "...", "rejected": "..."} Or with custom field names specified in config. """ data_cfg = cfg["data"] train_path = data_cfg["train_jsonl"] eval_path = data_cfg.get("eval_jsonl", None) split_ratio = float(data_cfg.get("eval_split_ratio", 0.0)) shuffle = bool(data_cfg.get("shuffle", True)) num_proc = int(data_cfg.get("num_proc", 4)) # Ensure tokenizer has pad token if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load datasets ds = load_dataset("json", data_files={"train": train_path}) if eval_path: ds_eval = load_dataset("json", data_files={"eval": eval_path}) dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]}) else: if 0.0 < split_ratio < 1.0: split = ds["train"].train_test_split( test_size=split_ratio, seed=int(cfg["run"].get("seed", 42)) ) dsd = DatasetDict({"train": split["train"], "eval": split["test"]}) else: dsd = DatasetDict({"train": ds["train"], "eval": None}) # Format DPO examples with error handling def format_fn(examples): prompts = [] chosen_list = [] rejected_list = [] errors = 0 for i in range(len(examples[list(examples.keys())[0]])): example = {k: examples[k][i] for k in examples.keys()} try: formatted = format_dpo_example(example, cfg, tokenizer) prompts.append(formatted["prompt"]) chosen_list.append(formatted["chosen"]) rejected_list.append(formatted["rejected"]) except (DataFormattingError, Exception) as e: errors += 1 if errors <= 5: # Log first 5 errors logger.warning(f"Failed to format example {i}: {e}") # Add empty placeholder to maintain batch structure prompts.append("") chosen_list.append("") rejected_list.append("") if errors > 0: logger.warning(f"Total formatting errors in batch: {errors}") return { "prompt": prompts, "chosen": chosen_list, "rejected": rejected_list, } logger.info("Formatting train DPO data...") formatted_train = dsd["train"].map( format_fn, batched=True, num_proc=num_proc, remove_columns=dsd["train"].column_names, desc="Formatting train DPO data", ) # Filter out failed examples (empty prompts) formatted_train = formatted_train.filter(lambda x: len(x["prompt"]) > 0) logger.info(f"Train dataset after filtering: {len(formatted_train)} examples") # Validate formatted data validate_dpo_data(formatted_train, "train") formatted_eval = None if dsd["eval"] is not None: logger.info("Formatting eval DPO data...") formatted_eval = dsd["eval"].map( format_fn, batched=True, num_proc=num_proc, remove_columns=dsd["eval"].column_names, desc="Formatting eval DPO data", ) formatted_eval = formatted_eval.filter(lambda x: len(x["prompt"]) > 0) logger.info(f"Eval dataset after filtering: {len(formatted_eval)} examples") validate_dpo_data(formatted_eval, "eval") if shuffle: formatted_train = formatted_train.shuffle(seed=int(cfg["run"].get("seed", 42))) return formatted_train, formatted_eval # -------------------------- # Model Loading + PEFT # -------------------------- def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path): model_cfg = cfg["model"] trust_remote_code = bool(model_cfg.get("trust_remote_code", True)) use_fast = bool(model_cfg.get("tokenizer_use_fast", True)) device_map = model_cfg.get("device_map", "auto") tokenizer = AutoTokenizer.from_pretrained( str(base_dir), use_fast=use_fast, trust_remote_code=trust_remote_code, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16")) use_4bit = bool(model_cfg.get("use_4bit", False)) quant_cfg = None if use_4bit: quant_cfg = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")), bnb_4bit_use_double_quant=bool( model_cfg.get("bnb_4bit_use_double_quant", True) ), bnb_4bit_compute_dtype=_dtype_from_str( model_cfg.get("bnb_4bit_compute_dtype", "bfloat16") ), ) attn_impl = _choose_attn_impl(cfg) try: model = AutoModelForCausalLM.from_pretrained( str(base_dir), device_map=device_map, trust_remote_code=trust_remote_code, low_cpu_mem_usage=True, torch_dtype=(torch_dtype if not use_4bit else None), quantization_config=quant_cfg, attn_implementation=attn_impl, ) except Exception as e: if attn_impl is not None: print(f"[warn] attn_implementation='{attn_impl}' failed: {e}") print("[warn] Falling back to default attention implementation.") model = AutoModelForCausalLM.from_pretrained( str(base_dir), device_map=device_map, trust_remote_code=trust_remote_code, low_cpu_mem_usage=True, torch_dtype=(torch_dtype if not use_4bit else None), quantization_config=quant_cfg, ) return model, tokenizer def apply_peft(cfg: Dict[str, Any], model): peft_cfg = cfg["peft"] model_cfg = cfg["model"] tr_cfg = cfg["train"] if not bool(peft_cfg.get("enabled", True)): return model, None use_4bit = bool(model_cfg.get("use_4bit", False)) gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True)) if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"): model.gradient_checkpointing_enable() if hasattr(model, "config"): model.config.use_cache = False if use_4bit: model = prepare_model_for_kbit_training( model, use_gradient_checkpointing=gradient_checkpointing, ) target_modules = peft_cfg.get("target_modules", "auto") if target_modules == "auto": target_modules = _infer_target_modules(model) lora_config = LoraConfig( r=int(peft_cfg.get("r", 16)), lora_alpha=int(peft_cfg.get("lora_alpha", 32)), lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)), bias=str(peft_cfg.get("bias", "none")), task_type="CAUSAL_LM", target_modules=target_modules, ) model = get_peft_model(model, lora_config) return model, lora_config # -------------------------- # Merge Logic # -------------------------- def merge_adapter( cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path ): print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---") model_cfg = cfg["model"] merge_cfg = cfg.get("merge", {}) trust_remote_code = bool(model_cfg.get("trust_remote_code", True)) merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16")) max_shard_size = str(merge_cfg.get("max_shard_size", "2GB")) try: base = AutoModelForCausalLM.from_pretrained( str(base_dir), torch_dtype=merged_dtype, device_map="cpu", low_cpu_mem_usage=True, trust_remote_code=trust_remote_code, ) merged = PeftModel.from_pretrained(base, str(adapter_dir)) merged = merged.merge_and_unload() # Clean up base model to free memory del base gc.collect() torch.cuda.empty_cache() _ensure_dir(final_dir) merged.save_pretrained( str(final_dir), safe_serialization=True, max_shard_size=max_shard_size ) # Clean up merged model del merged gc.collect() torch.cuda.empty_cache() tok = AutoTokenizer.from_pretrained( str(base_dir), trust_remote_code=trust_remote_code ) if tok.pad_token is None: tok.pad_token = tok.eos_token tok.save_pretrained(str(final_dir)) print("--- Merge complete ---") except Exception as e: logger.error(f"Merge failed: {e}") raise # -------------------------- # Main # -------------------------- def main(): ap = argparse.ArgumentParser() ap.add_argument("--config", required=True, help="Path to YAML config") ap.add_argument( "--merge-only", action="store_true", help="Skip training, just merge adapter" ) args = ap.parse_args() with open(args.config, "r", encoding="utf-8") as f: cfg = yaml.safe_load(f) run_dir = _ensure_dir(Path(cfg["run"]["run_dir"])) _ensure_dir(run_dir / "logs") with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f: yaml.safe_dump(cfg, f, sort_keys=False) model_cfg = cfg["model"] repo_id = str(model_cfg["repo_id"]).strip() repo_path = Path(repo_id) # Local model path -> load directly if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path): base_dir = repo_path print(f"Using local model at: {base_dir}") elif repo_path.exists() and repo_path.is_dir(): raise ValueError( f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}" ) else: # HF repo_id -> download base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model")) if not _looks_like_model_dir(base_dir): print(f"Base model not found at {base_dir}, downloading from {repo_id} ...") snapshot_download( repo_id=repo_id, revision=model_cfg.get("revision", None), local_dir=str(base_dir), local_dir_use_symlinks=False, ) ckpt_dir = _ensure_dir(run_dir / "checkpoints") best_adapter_dir = _ensure_dir(run_dir / "best_adapter") merge_cfg = cfg.get("merge", {}) or {} if merge_cfg.get("output_dir"): od = Path(str(merge_cfg["output_dir"])) final_dir = od if od.is_absolute() else (run_dir / od) else: final_dir = run_dir / "final_model" # Merge-only if args.merge_only: if not _looks_like_model_dir(best_adapter_dir): raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}") merge_adapter(cfg, base_dir, best_adapter_dir, final_dir) return # Initialize Wandb wandb_run = setup_wandb(cfg, run_dir) # Training set_seed(int(cfg["run"].get("seed", 42))) model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir) model, _ = apply_peft(cfg, model) # Load reference model for DPO (if using reference model) dpo_cfg = cfg.get("dpo", {}) use_reference_model = bool(dpo_cfg.get("use_reference_model", True)) reference_free = bool(dpo_cfg.get("reference_free", False)) ref_model = None if use_reference_model and not reference_free: print("Loading reference model (frozen copy)...") ref_model, _ = load_base_model_and_tokenizer(cfg, base_dir) ref_model, _ = apply_peft(cfg, ref_model) # Freeze reference model for param in ref_model.parameters(): param.requires_grad = False ref_model.eval() print("Reference model loaded and frozen") train_ds, eval_ds = build_dpo_datasets(cfg, tokenizer) tr_cfg = cfg["train"] dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16")) use_fp16 = dtype == torch.float16 use_bf16 = dtype == torch.bfloat16 max_steps = int(tr_cfg.get("max_steps", 0)) num_train_epochs = float(tr_cfg.get("num_train_epochs", 1)) # Dynamic evaluation strategy parameter handling ta_params = inspect.signature(TrainingArguments.__init__).parameters eval_key = ( "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy" ) # Setup reporting based on wandb availability report_to = [] if wandb_run is not None: report_to.append("wandb") # Validate and adjust training parameters max_grad_norm = float(tr_cfg.get("max_grad_norm", 1.0)) if max_grad_norm <= 0: logger.warning(f"Invalid max_grad_norm={max_grad_norm}, using 1.0") max_grad_norm = 1.0 ta_kwargs = dict( output_dir=str(ckpt_dir), max_steps=max_steps if max_steps > 0 else -1, num_train_epochs=num_train_epochs, per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)), per_device_eval_batch_size=int( tr_cfg.get( "per_device_eval_batch_size", tr_cfg.get("per_device_train_batch_size", 1), ) ), gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)), learning_rate=float(tr_cfg.get("learning_rate", 5e-5)), weight_decay=float(tr_cfg.get("weight_decay", 0.0)), warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)), lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")), optim=str( tr_cfg.get( "optim", ( "paged_adamw_8bit" if bool(model_cfg.get("use_4bit", False)) else "adamw_torch" ), ) ), max_grad_norm=max_grad_norm, logging_steps=int(tr_cfg.get("logging_steps", 10)), save_strategy=str(tr_cfg.get("save_strategy", "steps")), save_steps=int(tr_cfg.get("save_steps", 200)), save_total_limit=int(tr_cfg.get("save_total_limit", 3)), eval_steps=int(tr_cfg.get("eval_steps", 50)), load_best_model_at_end=( bool(tr_cfg.get("load_best_model_at_end", True)) if eval_ds is not None else False ), metric_for_best_model="eval_loss", greater_is_better=False, fp16=use_fp16, bf16=use_bf16, report_to=report_to, remove_unused_columns=False, ) # Set the correct argument name for this transformers version ta_kwargs[eval_key] = str( tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no") ) training_args = TrainingArguments(**ta_kwargs) # Setup callbacks callbacks = [JsonlLoggerCallback(run_dir)] # Add early stopping callback if enabled early_stopping_cfg = tr_cfg.get("early_stopping", {}) if early_stopping_cfg.get("enabled", False) and eval_ds is not None: early_stopping_callback = EarlyStoppingCallback( early_stopping_patience=int(early_stopping_cfg.get("patience", 3)), early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)), ) callbacks.append(early_stopping_callback) print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, " f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}") # DPO-specific parameters beta = float(dpo_cfg.get("beta", 0.1)) label_smoothing = float(dpo_cfg.get("label_smoothing", 0.0)) loss_type = str(dpo_cfg.get("loss_type", "sigmoid")) max_length = int(cfg["data"].get("max_length", 2048)) max_prompt_length = int(cfg["data"].get("max_prompt_length", max_length // 2)) print(f"DPO Training with beta={beta}, loss_type={loss_type}") # Get evaluation strategy from config eval_strategy_val = str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")) # Create DPOConfig with all training and DPO-specific parameters dpo_config = DPOConfig( output_dir=str(run_dir), num_train_epochs=int(tr_cfg.get("num_train_epochs", 3)), per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 2)), per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", 4)), gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 4)), learning_rate=float(tr_cfg.get("learning_rate", 5e-5)), weight_decay=float(tr_cfg.get("weight_decay", 0.01)), adam_beta1=float(tr_cfg.get("adam_beta1", 0.9)), adam_beta2=float(tr_cfg.get("adam_beta2", 0.999)), adam_epsilon=float(tr_cfg.get("adam_epsilon", 1e-8)), max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)), lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "linear")), warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)), logging_steps=int(tr_cfg.get("logging_steps", 10)), save_steps=int(tr_cfg.get("save_steps", 100)), save_total_limit=int(tr_cfg.get("save_total_limit", 3)), eval_steps=int(tr_cfg.get("eval_steps", 100)) if eval_ds is not None else None, eval_strategy=eval_strategy_val, save_strategy=str(tr_cfg.get("save_strategy", "steps")), load_best_model_at_end=( bool(tr_cfg.get("load_best_model_at_end", False)) if eval_ds is not None else False ), metric_for_best_model=str(tr_cfg.get("metric_for_best_model", "eval_loss")), greater_is_better=bool(tr_cfg.get("greater_is_better", False)), fp16=use_fp16, bf16=use_bf16, report_to=report_to, remove_unused_columns=False, # DPO-specific parameters beta=beta, label_smoothing=label_smoothing, loss_type=loss_type, max_length=max_length, max_prompt_length=max_prompt_length, ) trainer = DPOTrainer( model=model, ref_model=ref_model, args=dpo_config, train_dataset=train_ds, eval_dataset=eval_ds, processing_class=tokenizer, callbacks=callbacks, ) # Resume resume_from = tr_cfg.get("resume_from_checkpoint", None) if resume_from == "auto": last = get_last_checkpoint(str(ckpt_dir)) resume_from = last if last else None if resume_from: print(f"Resuming from {resume_from}") print("Starting DPO training...") trainer.train(resume_from_checkpoint=resume_from) trainer.save_model(str(best_adapter_dir)) print(f"Saved best adapter -> {best_adapter_dir}") if eval_ds is not None: metrics = trainer.evaluate() with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f: json.dump(metrics, f, indent=2) print(f"Final metrics: {metrics}") if bool(cfg.get("merge", {}).get("enabled", False)): del trainer, model if ref_model is not None: del ref_model torch.cuda.empty_cache() merge_adapter(cfg, base_dir, best_adapter_dir, final_dir) else: print("Merge disabled. Run with --merge-only later if needed.") # Finish Wandb run finish_wandb() if __name__ == "__main__": main()