"""Train a LoRA adapter with TRL's DPOTrainer from a YAML config, then optionally merge it.

Pipeline: resolve/download the base model -> load tokenizer + model (optionally 4-bit) ->
wrap with LoRA -> build prompt/chosen/rejected datasets from JSONL -> DPO training ->
save the adapter -> optionally merge the adapter into the base weights.
"""
import argparse
import json
import inspect
import math
import gc
import time
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, List

import torch
import yaml
from datasets import load_dataset, DatasetDict
from huggingface_hub import snapshot_download
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    BitsAndBytesConfig,
    TrainingArguments,
    TrainerCallback,
    EarlyStoppingCallback,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
)
from trl import DPOTrainer, DPOConfig

# Logging must be configured before anything below uses `logger`
# (the TRL version check used to reference it before it existed -> NameError).
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Version check for TRL (best-effort: only warns).
try:
    from packaging import version
    import trl

    if version.parse(trl.__version__) < version.parse("0.7.0"):
        logger.warning(f"TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
except ImportError:
    logger.warning("Could not verify TRL version")

try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    wandb = None


# --------------------------
# Custom Exceptions
# --------------------------
class DataFormattingError(Exception):
    """Exception raised for errors in data formatting."""


class DataValidationError(Exception):
    """Exception raised for errors in data validation."""


# --------------------------
# Helpers
# --------------------------
def _dtype_from_str(s: str) -> torch.dtype:
    """Map a config dtype string ("fp16"/"float16", "bf16"/"bfloat16", "fp32"/"float32").

    Raises:
        ValueError: if the string is not one of the recognised names.
    """
    s = (s or "").lower()
    if s in ("float16", "fp16"):
        return torch.float16
    if s in ("bfloat16", "bf16"):
        return torch.bfloat16
    if s in ("float32", "fp32"):
        return torch.float32
    raise ValueError(f"Unknown torch_dtype: {s}")


def _now_iso() -> str:
    """Return the current local time as an ISO-8601 string (seconds precision, no tz)."""
    return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())


def _safe_exp(x: float) -> float:
    """exp(x) with the exponent clamped at 50 to avoid overflow (e.g. for perplexity)."""
    x = min(float(x), 50.0)
    return float(math.exp(x))


def _ensure_dir(p: Path) -> Path:
    """Create directory ``p`` (and parents) if missing and return it."""
    p.mkdir(parents=True, exist_ok=True)
    return p


def _looks_like_model_dir(p: Path) -> bool:
    """Heuristic: True if ``p`` is a directory holding a config or weight files."""
    if not p.exists() or not p.is_dir():
        return False
    if (p / "config.json").exists():
        return True
    if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
        return True
    return False


def _infer_target_modules(model) -> List[str]:
    """Guess LoRA target module names from the leaf names of ``model``'s submodules.

    Tries known attention-projection naming groups first (all members must be present),
    then falls back to whichever common projection names exist.

    Raises:
        ValueError: if none of the known names are found.
    """
    names = {n.split(".")[-1] for n, _ in model.named_modules()}
    for group in [
        ["q_proj", "k_proj", "v_proj", "o_proj"],
        ["Wqkv", "out_proj"],
        ["query_key_value", "dense"],
        ["c_attn", "c_proj"],
    ]:
        if all(x in names for x in group):
            return group
    fallback = [
        x
        for x in [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "c_attn",
            "c_proj",
            "out_proj",
            "dense",
        ]
        if x in names
    ]
    if fallback:
        return fallback
    raise ValueError(
        "Could not auto-infer target_modules. Set peft.target_modules explicitly."
    )


def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
    """Return ``model.attn_implementation`` from the config, or None for the default."""
    return cfg.get("model", {}).get("attn_implementation", None)


def _is_mistral3(config) -> bool:
    """True if a loaded HF config describes a Mistral3 (multimodal) checkpoint."""
    architectures = getattr(config, "architectures", None) or []
    return config.model_type == "mistral3" or (
        bool(architectures) and "Mistral3" in architectures[0]
    )


def _mistral3_class():
    """Import Mistral3ForConditionalGeneration lazily (only needed for Mistral3 checkpoints)."""
    from transformers.models.mistral3.modeling_mistral3 import (
        Mistral3ForConditionalGeneration,
    )

    return Mistral3ForConditionalGeneration


def _from_pretrained_with_attn_fallback(
    model_cls, base_dir: Path, attn_impl: Optional[str], **kwargs
):
    """Call ``model_cls.from_pretrained`` with ``attn_implementation=attn_impl``.

    If a non-default attention implementation was requested and loading fails, log
    the error and retry once without it. Errors with the default implementation
    propagate unchanged.
    """
    try:
        return model_cls.from_pretrained(
            str(base_dir), attn_implementation=attn_impl, **kwargs
        )
    except Exception as e:
        if attn_impl is None:
            raise
        logger.warning(f"attn_implementation='{attn_impl}' failed: {e}")
        logger.warning("Falling back to default attention implementation.")
        return model_cls.from_pretrained(str(base_dir), **kwargs)


# --------------------------
# Wandb Integration
# --------------------------
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
    """Initialize Wandb if enabled in configuration.

    Returns:
        The ``wandb`` module when a run was started, otherwise None (disabled,
        not installed, or ``wandb.init`` failed).
    """
    wandb_cfg = cfg.get("wandb", {}) or {}
    if not wandb_cfg.get("enabled", False):
        logger.info("Wandb logging disabled")
        return None
    if not WANDB_AVAILABLE:
        logger.warning("Wandb not available. Install with: pip install wandb")
        return None

    project = wandb_cfg.get("project", "dpo-training")
    entity = wandb_cfg.get("entity", None)
    name = wandb_cfg.get("name", None)
    tags = wandb_cfg.get("tags", [])
    notes = wandb_cfg.get("notes", None)

    try:
        wandb.init(
            project=project,
            entity=entity,
            name=name,
            tags=tags,
            notes=notes,
            dir=str(run_dir),
            config={
                "model": cfg.get("model", {}),
                "data": cfg.get("data", {}),
                "peft": cfg.get("peft", {}),
                "dpo": cfg.get("dpo", {}),
                "train": cfg.get("train", {}),
                "run_dir": str(run_dir),
            },
        )
        logger.info(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
        return wandb
    except Exception as e:
        # Best-effort: training continues without wandb.
        logger.warning(f"Failed to initialize Wandb: {e}")
        return None


def finish_wandb():
    """Finish Wandb run if active."""
    if WANDB_AVAILABLE and wandb.run is not None:
        wandb.finish()
        logger.info("Wandb run finished")


# --------------------------
# JSONL Logger Callback
# --------------------------
class JsonlLoggerCallback(TrainerCallback):
    """Append trainer log/eval events as JSON lines under ``run_dir/logs``.

    ``train.jsonl`` gets one line per ``on_log`` call (with progress and ETA);
    ``eval.jsonl`` gets one line per ``on_evaluate`` call.
    """

    def __init__(self, run_dir: Path):
        self.run_dir = run_dir
        log_dir = _ensure_dir(run_dir / "logs")
        self.train_log_path = log_dir / "train.jsonl"
        self.eval_log_path = log_dir / "eval.jsonl"
        self.start_time: Optional[float] = None
        # Global step when this process began training; non-zero when resuming,
        # so the ETA is based only on steps actually timed in this process.
        self.start_step = 0

    def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
        """Estimate remaining time as "HH:MM:SS" from the average time per step so far."""
        if self.start_time is None or max_steps <= 0:
            return None
        steps_done = global_step - self.start_step
        if steps_done <= 0:
            return None
        elapsed = time.time() - self.start_time
        remaining = max(0, max_steps - global_step) * (elapsed / steps_done)
        h, rem = divmod(int(remaining), 3600)
        m, s = divmod(rem, 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()
        self.start_step = int(state.global_step or 0)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
        progress_pct = (100.0 * state.global_step / max_steps) if max_steps > 0 else None
        epoch_pct = None
        if (
            state.epoch is not None
            and args.num_train_epochs
            and args.num_train_epochs > 0
        ):
            epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
        payload = {
            "ts": _now_iso(),
            "event": "train_log",
            "step": int(state.global_step),
            "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
            "progress_pct": round(progress_pct, 2) if progress_pct is not None else None,
            "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
            "eta": self._eta(int(state.global_step), max_steps),
            "max_grad_norm": getattr(args, "max_grad_norm", None),
            # Trainer-provided keys come last and override same-named keys above.
            **logs,
        }
        with self.train_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics:
            return
        payload = {
            "ts": _now_iso(),
            "event": "eval",
            "step": int(state.global_step),
            "epoch": float(state.epoch) if state.epoch is not None else None,
            **metrics,
        }
        with self.eval_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")


# --------------------------
# Data Pipeline (DPO Format)
# --------------------------
def format_dpo_example(
    example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
) -> Dict[str, Any]:
    """
    Format DPO data which requires prompt, chosen, and rejected completions.

    Field names come from ``data.prompt_field`` / ``chosen_field`` / ``rejected_field``.
    For ``format_type`` "chatml" and "alpaca" the raw text is passed through unchanged;
    for "custom" the prompt is rendered with ``data.custom_template`` (``{prompt}``).
    ``tokenizer`` is currently unused.

    Returns:
        Dict with "prompt", "chosen" and "rejected" keys.

    Raises:
        DataFormattingError: if any of the three fields is missing/empty.
        ValueError: if ``format_type`` is not supported (a config error).
    """
    data_cfg = cfg["data"]
    format_type = data_cfg.get("format_type", "chatml")

    prompt_field = data_cfg.get("prompt_field", "prompt")
    chosen_field = data_cfg.get("chosen_field", "chosen")
    rejected_field = data_cfg.get("rejected_field", "rejected")

    prompt = example.get(prompt_field, "")
    chosen = example.get(chosen_field, "")
    rejected = example.get(rejected_field, "")

    if not prompt:
        raise DataFormattingError(f"Empty prompt field: {prompt_field}")
    if not chosen:
        raise DataFormattingError(f"Empty chosen field: {chosen_field}")
    if not rejected:
        raise DataFormattingError(f"Empty rejected field: {rejected_field}")

    if format_type in ("chatml", "alpaca"):
        # DPOTrainer handles templating internally; pass raw text through.
        formatted_prompt = prompt
    elif format_type == "custom":
        template = data_cfg.get("custom_template", "{prompt}")
        formatted_prompt = template.format(prompt=prompt)
    else:
        raise ValueError(f"Unsupported format_type: {format_type}")

    return {
        "prompt": formatted_prompt,
        "chosen": chosen,
        "rejected": rejected,
    }


def validate_dpo_data(dataset, stage: str = "train") -> None:
    """
    Validate DPO dataset has all required fields and proper structure.

    Args:
        dataset: Dataset to validate
        stage: Training stage ("train" or "eval")

    Raises:
        DataValidationError: if a required column is missing or the dataset is empty.
    """
    required_fields = ["prompt", "chosen", "rejected"]

    for field in required_fields:
        if field not in dataset.column_names:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {dataset.column_names}"
            )

    # An empty dataset cannot be trained/evaluated on; fail loudly instead of later.
    if len(dataset) == 0:
        raise DataValidationError(
            f"{stage} dataset is empty (all examples missing or failed formatting)"
        )

    # Spot-check the first example only (cheap).
    sample = dataset[0]
    for field in required_fields:
        value = sample[field]
        if not value or (isinstance(value, str) and not value.strip()):
            logger.warning(f"{stage} dataset has empty {field} in first example")

    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")


def build_dpo_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
    """
    Build datasets for DPO training.

    Expected JSONL format:
    {"prompt": "...", "chosen": "...", "rejected": "..."}

    Or with custom field names specified in config. The eval split comes from
    ``data.eval_jsonl`` if set, else from ``data.eval_split_ratio`` in (0, 1), else none.
    Rows that fail formatting (missing/empty fields) are dropped with a warning.

    Returns:
        (train_dataset, eval_dataset_or_None)
    """
    data_cfg = cfg["data"]
    train_path = data_cfg["train_jsonl"]
    eval_path = data_cfg.get("eval_jsonl", None)
    split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
    shuffle = bool(data_cfg.get("shuffle", True))
    num_proc = int(data_cfg.get("num_proc", 4))
    seed = int(cfg["run"].get("seed", 42))

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    raw_train = load_dataset("json", data_files={"train": train_path})["train"]
    raw_eval = None
    if eval_path:
        raw_eval = load_dataset("json", data_files={"eval": eval_path})["eval"]
    elif 0.0 < split_ratio < 1.0:
        split = raw_train.train_test_split(test_size=split_ratio, seed=seed)
        raw_train, raw_eval = split["train"], split["test"]

    def format_fn(examples, indices):
        """Batched formatter; failed rows become empty placeholders (filtered later)."""
        prompts: List[Any] = []
        chosen_list: List[Any] = []
        rejected_list: List[Any] = []
        errors = 0
        for i, row_idx in enumerate(indices):
            example = {k: v[i] for k, v in examples.items()}
            try:
                formatted = format_dpo_example(example, cfg, tokenizer)
            except DataFormattingError as e:
                # Only per-row data problems are skipped; config errors (ValueError,
                # bad template) propagate instead of silently emptying the dataset.
                errors += 1
                if errors <= 5:
                    logger.warning(f"Failed to format example {row_idx}: {e}")
                formatted = {"prompt": "", "chosen": "", "rejected": ""}
            prompts.append(formatted["prompt"])
            chosen_list.append(formatted["chosen"])
            rejected_list.append(formatted["rejected"])
        if errors > 0:
            logger.warning(f"Total formatting errors in batch: {errors}")
        return {"prompt": prompts, "chosen": chosen_list, "rejected": rejected_list}

    def _format_split(ds, stage: str):
        """Format, drop failed rows, and validate one split."""
        logger.info(f"Formatting {stage} DPO data...")
        formatted = ds.map(
            format_fn,
            batched=True,
            with_indices=True,
            num_proc=num_proc,
            remove_columns=ds.column_names,
            desc=f"Formatting {stage} DPO data",
        )
        formatted = formatted.filter(lambda x: len(x["prompt"]) > 0)
        logger.info(f"{stage.capitalize()} dataset after filtering: {len(formatted)} examples")
        validate_dpo_data(formatted, stage)
        return formatted

    formatted_train = _format_split(raw_train, "train")
    formatted_eval = _format_split(raw_eval, "eval") if raw_eval is not None else None

    if shuffle:
        formatted_train = formatted_train.shuffle(seed=seed)

    return formatted_train, formatted_eval


# --------------------------
# Model Loading + PEFT
# --------------------------
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
    """Load tokenizer and base model from ``base_dir`` according to ``cfg["model"]``.

    Mistral3 checkpoints are loaded with Mistral3ForConditionalGeneration; everything
    else with AutoModelForCausalLM. Supports optional 4-bit bitsandbytes quantization
    and an optional attention implementation (with fallback to the default).

    Returns:
        (model, tokenizer) — tokenizer.pad_token is set to eos if it was missing.
    """
    model_cfg = cfg["model"]
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
    device_map = model_cfg.get("device_map", "auto")

    tokenizer = AutoTokenizer.from_pretrained(
        str(base_dir),
        use_fast=use_fast,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))

    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
            bnb_4bit_use_double_quant=bool(
                model_cfg.get("bnb_4bit_use_double_quant", True)
            ),
            bnb_4bit_compute_dtype=_dtype_from_str(
                model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
            ),
        )

    attn_impl = _choose_attn_impl(cfg)
    common_kwargs = dict(
        device_map=device_map,
        low_cpu_mem_usage=True,
        # With 4-bit quantization the dtype is governed by the bnb config.
        torch_dtype=(torch_dtype if not use_4bit else None),
        quantization_config=quant_cfg,
    )

    try:
        config = AutoConfig.from_pretrained(str(base_dir), trust_remote_code=trust_remote_code)
        if _is_mistral3(config):
            logger.info("Detected Mistral3 model architecture, loading with specific class")
            model = _from_pretrained_with_attn_fallback(
                _mistral3_class(), base_dir, attn_impl, config=config, **common_kwargs
            )
        else:
            model = _from_pretrained_with_attn_fallback(
                AutoModelForCausalLM,
                base_dir,
                attn_impl,
                trust_remote_code=trust_remote_code,
                **common_kwargs,
            )
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

    logger.info("Ensuring all parameters are materialized...")
    meta_params = [n for n, p in model.named_parameters() if p.device.type == "meta"]
    if meta_params:
        logger.warning(f"Found {len(meta_params)} parameters on meta device")

    # For multimodal models, freeze vision components (text-only training).
    if hasattr(model, "vision_tower"):
        logger.info("Freezing vision tower for text-only training")
        for param in model.vision_tower.parameters():
            param.requires_grad = False

    return model, tokenizer


def apply_peft(cfg: Dict[str, Any], model):
    """Wrap ``model`` with LoRA per ``cfg["peft"]``.

    Enables gradient checkpointing (and disables the KV cache) when configured, and
    runs ``prepare_model_for_kbit_training`` for 4-bit models.

    Returns:
        (model, lora_config) — or (model, None) unchanged when PEFT is disabled.
    """
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]

    if not bool(peft_cfg.get("enabled", True)):
        return model, None

    use_4bit = bool(model_cfg.get("use_4bit", False))
    gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))

    if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        if hasattr(model, "config"):
            # KV cache is incompatible with gradient checkpointing during training.
            model.config.use_cache = False

    if use_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )

    target_modules = peft_cfg.get("target_modules", "auto")
    if target_modules == "auto":
        target_modules = _infer_target_modules(model)

    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, lora_config)
    return model, lora_config


# --------------------------
# Merge Logic
# --------------------------
def merge_adapter(
    cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
):
    """Merge the LoRA adapter in ``adapter_dir`` into the base model and save to ``final_dir``.

    The base is loaded on CPU in ``merge.merged_dtype`` (default float16), using the same
    model class as training (Mistral3 vs AutoModelForCausalLM). The base tokenizer is
    saved alongside the merged weights.
    """
    logger.info(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
    model_cfg = cfg["model"]
    merge_cfg = cfg.get("merge", {}) or {}
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))

    try:
        base_config = AutoConfig.from_pretrained(
            str(base_dir), trust_remote_code=trust_remote_code
        )
        load_kwargs = dict(
            torch_dtype=merged_dtype, device_map="cpu", low_cpu_mem_usage=True
        )
        if _is_mistral3(base_config):
            base = _mistral3_class().from_pretrained(
                str(base_dir), config=base_config, **load_kwargs
            )
        else:
            base = AutoModelForCausalLM.from_pretrained(
                str(base_dir), trust_remote_code=trust_remote_code, **load_kwargs
            )
        merged = PeftModel.from_pretrained(base, str(adapter_dir))
        merged = merged.merge_and_unload()

        del base
        gc.collect()
        torch.cuda.empty_cache()

        _ensure_dir(final_dir)
        merged.save_pretrained(
            str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
        )

        del merged
        gc.collect()
        torch.cuda.empty_cache()

        tok = AutoTokenizer.from_pretrained(
            str(base_dir), trust_remote_code=trust_remote_code
        )
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        tok.save_pretrained(str(final_dir))
        logger.info("--- Merge complete ---")
    except Exception as e:
        logger.error(f"Merge failed: {e}")
        raise


# --------------------------
# Main
# --------------------------
def _resolve_base_dir(model_cfg: Dict[str, Any], run_dir: Path) -> Path:
    """Return a local directory with the base model, downloading from the Hub if needed.

    Raises:
        ValueError: if ``model.repo_id`` is an existing directory that is not a model dir.
    """
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)

    if repo_path.is_dir():
        if not _looks_like_model_dir(repo_path):
            raise ValueError(
                f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {repo_path}"
            )
        logger.info(f"Using local model at: {repo_path}")
        return repo_path

    base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
    if not _looks_like_model_dir(base_dir):
        logger.info(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
        snapshot_download(
            repo_id=repo_id,
            revision=model_cfg.get("revision", None),
            local_dir=str(base_dir),
            local_dir_use_symlinks=False,
        )
    return base_dir


def _build_dpo_config(
    cfg: Dict[str, Any], ckpt_dir: Path, has_eval: bool, report_to: List[str]
) -> DPOConfig:
    """Build the single DPOConfig (training + DPO hyper-parameters) used by the trainer.

    Checkpoints go to ``ckpt_dir`` so ``resume_from_checkpoint: auto`` can find them.
    """
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]
    data_cfg = cfg["data"]
    dpo_cfg = cfg.get("dpo", {}) or {}

    dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))
    max_steps = int(tr_cfg.get("max_steps", 0))

    max_grad_norm = float(tr_cfg.get("max_grad_norm", 1.0))
    if max_grad_norm <= 0:
        logger.warning(f"Invalid max_grad_norm={max_grad_norm}, using 1.0")
        max_grad_norm = 1.0

    eval_strategy = str(
        tr_cfg.get("evaluation_strategy", "steps" if has_eval else "no")
    )
    if not has_eval and eval_strategy != "no":
        # Trainer refuses an eval strategy without an eval dataset.
        logger.warning(f"evaluation_strategy='{eval_strategy}' but no eval data; using 'no'")
        eval_strategy = "no"

    beta = float(dpo_cfg.get("beta", 0.1))
    loss_type = str(dpo_cfg.get("loss_type", "sigmoid"))
    max_length = int(data_cfg.get("max_length", 2048))
    logger.info(f"DPO Training with beta={beta}, loss_type={loss_type}")

    kwargs: Dict[str, Any] = dict(
        output_dir=str(ckpt_dir),
        max_steps=max_steps if max_steps > 0 else -1,
        num_train_epochs=float(tr_cfg.get("num_train_epochs", 3)),
        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 2)),
        per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", 4)),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 4)),
        learning_rate=float(tr_cfg.get("learning_rate", 5e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.01)),
        adam_beta1=float(tr_cfg.get("adam_beta1", 0.9)),
        adam_beta2=float(tr_cfg.get("adam_beta2", 0.999)),
        adam_epsilon=float(tr_cfg.get("adam_epsilon", 1e-8)),
        max_grad_norm=max_grad_norm,
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "linear")),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        optim=str(
            tr_cfg.get("optim", "paged_adamw_8bit" if use_4bit else "adamw_torch")
        ),
        logging_steps=int(tr_cfg.get("logging_steps", 10)),
        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        save_steps=int(tr_cfg.get("save_steps", 100)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
        eval_steps=int(tr_cfg.get("eval_steps", 100)) if has_eval else None,
        load_best_model_at_end=(
            bool(tr_cfg.get("load_best_model_at_end", False)) if has_eval else False
        ),
        metric_for_best_model=str(tr_cfg.get("metric_for_best_model", "eval_loss")),
        greater_is_better=bool(tr_cfg.get("greater_is_better", False)),
        fp16=dtype == torch.float16,
        bf16=dtype == torch.bfloat16,
        report_to=report_to,
        remove_unused_columns=False,
        # DPO-specific parameters
        beta=beta,
        label_smoothing=float(dpo_cfg.get("label_smoothing", 0.0)),
        loss_type=loss_type,
        max_length=max_length,
        max_prompt_length=int(data_cfg.get("max_prompt_length", max_length // 2)),
    )
    if bool(dpo_cfg.get("reference_free", False)):
        kwargs["reference_free"] = True

    # The eval-strategy argument was renamed across transformers versions.
    cfg_params = inspect.signature(DPOConfig.__init__).parameters
    eval_key = "eval_strategy" if "eval_strategy" in cfg_params else "evaluation_strategy"
    kwargs[eval_key] = eval_strategy

    return DPOConfig(**kwargs)


def _run_training(
    cfg: Dict[str, Any],
    run_dir: Path,
    base_dir: Path,
    ckpt_dir: Path,
    best_adapter_dir: Path,
    final_dir: Path,
    wandb_run,
) -> None:
    """Load models and data, run DPO training, save/evaluate, and optionally merge."""
    tr_cfg = cfg["train"]
    dpo_cfg = cfg.get("dpo", {}) or {}

    set_seed(int(cfg["run"].get("seed", 42)))
    model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
    model, _ = apply_peft(cfg, model)

    use_reference_model = bool(dpo_cfg.get("use_reference_model", True))
    reference_free = bool(dpo_cfg.get("reference_free", False))
    ref_model = None
    if use_reference_model and not reference_free:
        logger.info("Loading reference model (frozen copy)...")
        ref_model, _ = load_base_model_and_tokenizer(cfg, base_dir)
        # Wrapped the same way as the policy so its forward pass matches the policy's
        # starting point (presumably the fresh LoRA delta is zero with PEFT's default init).
        ref_model, _ = apply_peft(cfg, ref_model)
        for param in ref_model.parameters():
            param.requires_grad = False
        ref_model.eval()
        logger.info("Reference model loaded and frozen")

    train_ds, eval_ds = build_dpo_datasets(cfg, tokenizer)

    report_to = ["wandb"] if wandb_run is not None else []
    dpo_config = _build_dpo_config(cfg, ckpt_dir, eval_ds is not None, report_to)

    callbacks = [JsonlLoggerCallback(run_dir)]
    early_stopping_cfg = tr_cfg.get("early_stopping", {}) or {}
    if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
        if not dpo_config.load_best_model_at_end:
            logger.warning(
                "Early stopping requires train.load_best_model_at_end=true; "
                "the trainer will reject this configuration."
            )
        callbacks.append(
            EarlyStoppingCallback(
                early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
                early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
            )
        )
        logger.info(
            f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
            f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}"
        )

    trainer_kwargs: Dict[str, Any] = dict(
        model=model,
        ref_model=ref_model,
        args=dpo_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        callbacks=callbacks,
    )
    # Pass the tokenizer we configured (pad token set); the kwarg name depends on TRL version.
    trainer_params = inspect.signature(DPOTrainer.__init__).parameters
    if "processing_class" in trainer_params:
        trainer_kwargs["processing_class"] = tokenizer
    elif "tokenizer" in trainer_params:
        trainer_kwargs["tokenizer"] = tokenizer
    trainer = DPOTrainer(**trainer_kwargs)

    resume_from = tr_cfg.get("resume_from_checkpoint", None)
    if resume_from == "auto":
        resume_from = get_last_checkpoint(str(ckpt_dir)) or None
    if resume_from:
        logger.info(f"Resuming from {resume_from}")

    logger.info("Starting DPO training...")
    trainer.train(resume_from_checkpoint=resume_from)

    trainer.save_model(str(best_adapter_dir))
    logger.info(f"Saved best adapter -> {best_adapter_dir}")

    if eval_ds is not None:
        metrics = trainer.evaluate()
        with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        logger.info(f"Final metrics: {metrics}")

    if bool((cfg.get("merge", {}) or {}).get("enabled", False)):
        # Drop GPU-resident models before loading the CPU merge copy.
        del trainer, model, ref_model
        gc.collect()
        torch.cuda.empty_cache()
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
    else:
        logger.info("Merge disabled. Run with --merge-only later if needed.")


def main():
    """CLI entry point: ``--config path.yaml [--merge-only]``."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument(
        "--merge-only", action="store_true", help="Skip training, just merge adapter"
    )
    args = ap.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")
    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)

    base_dir = _resolve_base_dir(cfg["model"], run_dir)

    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")

    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"

    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return

    wandb_run = setup_wandb(cfg, run_dir)
    try:
        _run_training(
            cfg, run_dir, base_dir, ckpt_dir, best_adapter_dir, final_dir, wandb_run
        )
    finally:
        # Close the wandb run even if training or merging raised.
        finish_wandb()


if __name__ == "__main__":
    main()