import argparse
import gc
import inspect
import json
import logging
import math
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
import yaml
from datasets import load_dataset, DatasetDict
from huggingface_hub import snapshot_download
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainerCallback,
    EarlyStoppingCallback,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
)
from trl import DPOTrainer, DPOConfig

try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    wandb = None

# Module-level logger used by the data and merge helpers below.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
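# Illustrative config sketch (values are examples, not defaults enforced here);
# the keys mirror the cfg[...] lookups throughout this script:
#
#   run:
#     run_dir: runs/dpo-demo
#     seed: 42
#   model:
#     repo_id: org/model-name          # HF repo id or a local model directory
#     torch_dtype: bfloat16
#     use_4bit: true
#   data:
#     train_jsonl: data/train.jsonl    # rows: {"prompt", "chosen", "rejected"}
#     eval_split_ratio: 0.05
#     format_type: chatml
#     max_length: 2048
#   peft:
#     enabled: true
#     r: 16
#     lora_alpha: 32
#     target_modules: auto
#   dpo:
#     beta: 0.1
#     loss_type: sigmoid
#   train:
#     num_train_epochs: 3
#     per_device_train_batch_size: 2
#     gradient_accumulation_steps: 4
#   merge:
#     enabled: false
#   wandb:
#     enabled: false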


def _dtype_from_str(s: str) -> torch.dtype:
    s = (s or "").lower()
    if s in ("float16", "fp16"):
        return torch.float16
    if s in ("bfloat16", "bf16"):
        return torch.bfloat16
    if s in ("float32", "fp32"):
        return torch.float32
    raise ValueError(f"Unknown torch_dtype: {s}")


def _now_iso() -> str:
    return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())


def _safe_exp(x: float) -> float:
    # Clamp the exponent so metric post-processing cannot overflow to inf.
    x = min(float(x), 50.0)
    return float(math.exp(x))


def _ensure_dir(p: Path) -> Path:
    p.mkdir(parents=True, exist_ok=True)
    return p


def _looks_like_model_dir(p: Path) -> bool:
    if not p.exists() or not p.is_dir():
        return False
    if (p / "config.json").exists():
        return True
    if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
        return True
    return False


def _infer_target_modules(model) -> List[str]:
    names = set()
    for n, _ in model.named_modules():
        names.add(n.split(".")[-1])

    # Common attention-projection groups: LLaMA/Mistral-style, MPT/Phi-style,
    # GPT-NeoX-style, and GPT-2-style, in that order.
    for group in [
        ["q_proj", "k_proj", "v_proj", "o_proj"],
        ["Wqkv", "out_proj"],
        ["query_key_value", "dense"],
        ["c_attn", "c_proj"],
    ]:
        if all(x in names for x in group):
            return group

    fallback = [
        x
        for x in [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "c_attn",
            "c_proj",
            "out_proj",
            "dense",
        ]
        if x in names
    ]
    if fallback:
        return fallback

    raise ValueError(
        "Could not auto-infer target_modules. Set peft.target_modules explicitly."
    )
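# For a LLaMA/Mistral-style checkpoint the helper above returns
# ["q_proj", "k_proj", "v_proj", "o_proj"]; for a GPT-2-style one, ["c_attn", "c_proj"].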


def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
    return cfg.get("model", {}).get("attn_implementation", None)


def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
    """Initialize Weights & Biases if enabled in the configuration."""
    wandb_cfg = cfg.get("wandb", {})

    if not wandb_cfg.get("enabled", False):
        print("Wandb logging disabled")
        return None

    if not WANDB_AVAILABLE:
        print("Wandb not available. Install with: pip install wandb")
        return None

    project = wandb_cfg.get("project", "dpo-training")
    entity = wandb_cfg.get("entity", None)
    name = wandb_cfg.get("name", None)
    tags = wandb_cfg.get("tags", [])
    notes = wandb_cfg.get("notes", None)

    try:
        wandb.init(
            project=project,
            entity=entity,
            name=name,
            tags=tags,
            notes=notes,
            dir=str(run_dir),
            config={
                "model": cfg.get("model", {}),
                "data": cfg.get("data", {}),
                "peft": cfg.get("peft", {}),
                "dpo": cfg.get("dpo", {}),
                "train": cfg.get("train", {}),
                "run_dir": str(run_dir),
            },
        )
        print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
        return wandb
    except Exception as e:
        print(f"Failed to initialize Wandb: {e}")
        return None


def finish_wandb():
    """Finish the Wandb run if one is active."""
    if WANDB_AVAILABLE and wandb.run is not None:
        wandb.finish()
        print("Wandb run finished")


class JsonlLoggerCallback(TrainerCallback):
    """Mirrors trainer logs and eval metrics to JSONL files under run_dir/logs."""

    def __init__(self, run_dir: Path):
        self.run_dir = run_dir
        self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
        self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
        self.start_time = None

    def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
        if self.start_time is None or global_step <= 0 or max_steps <= 0:
            return None
        elapsed = time.time() - self.start_time
        sec_per_step = elapsed / global_step
        remaining = max(0, max_steps - global_step) * sec_per_step
        h = int(remaining // 3600)
        m = int((remaining % 3600) // 60)
        s = int(remaining % 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return

        max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
        progress_pct = (
            (100.0 * state.global_step / max_steps) if max_steps > 0 else None
        )
        epoch_pct = None
        if (
            state.epoch is not None
            and args.num_train_epochs
            and args.num_train_epochs > 0
        ):
            epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))

        payload = {
            "ts": _now_iso(),
            "event": "train_log",
            "step": int(state.global_step),
            "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
            "progress_pct": (
                round(progress_pct, 2) if progress_pct is not None else None
            ),
            "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
            "eta": self._eta(int(state.global_step), max_steps),
            "max_grad_norm": getattr(args, "max_grad_norm", None),
            **logs,
        }

        with self.train_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics:
            return

        payload = {
            "ts": _now_iso(),
            "event": "eval",
            "step": int(state.global_step),
            "epoch": float(state.epoch) if state.epoch is not None else None,
            **metrics,
        }
        with self.eval_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")
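# A record in logs/train.jsonl looks roughly like this (keys beyond the fixed
# ones depend on what the trainer reports in `logs`):
#   {"ts": "2025-01-01T12:00:00", "event": "train_log", "step": 120,
#    "epoch": 0.46, "progress_pct": 12.0, "eta": "01:23:45", "loss": 0.61}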


class DataFormattingError(Exception):
    """Exception raised for errors in data formatting."""
    pass


class DataValidationError(Exception):
    """Exception raised for errors in data validation."""
    pass
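# Expected raw training rows (field names are configurable via data.prompt_field,
# data.chosen_field, and data.rejected_field), for example:
#   {"prompt": "What is DPO?", "chosen": "...preferred answer...", "rejected": "...worse answer..."}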


def format_dpo_example(
    example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
) -> Dict[str, Any]:
    """
    Format a DPO example, which requires a prompt plus chosen and rejected completions.
    Returns the formatted prompt, chosen, and rejected texts.
    Raises DataFormattingError if formatting fails.
    """
    data_cfg = cfg["data"]
    format_type = data_cfg.get("format_type", "chatml")

    prompt_field = data_cfg.get("prompt_field", "prompt")
    chosen_field = data_cfg.get("chosen_field", "chosen")
    rejected_field = data_cfg.get("rejected_field", "rejected")

    prompt = example.get(prompt_field, "")
    chosen = example.get(chosen_field, "")
    rejected = example.get(rejected_field, "")

    if not prompt:
        raise DataFormattingError(f"Empty prompt field: {prompt_field}")
    if not chosen:
        raise DataFormattingError(f"Empty chosen field: {chosen_field}")
    if not rejected:
        raise DataFormattingError(f"Empty rejected field: {rejected_field}")

    if format_type == "chatml":
        system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        formatted_prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    elif format_type == "alpaca":
        formatted_prompt = (
            "Below is an instruction that describes a task. Write a response that "
            f"appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:\n"
        )
    elif format_type == "custom":
        template = data_cfg.get("custom_template", "{prompt}")
        formatted_prompt = template.format(prompt=prompt)
    else:
        raise ValueError(f"Unsupported format_type: {format_type}")

    # The completions are handled identically across formats: append EOS once so
    # the model learns where responses end.
    formatted_chosen = chosen
    formatted_rejected = rejected
    if tokenizer.eos_token:
        if not formatted_chosen.endswith(tokenizer.eos_token):
            formatted_chosen += tokenizer.eos_token
        if not formatted_rejected.endswith(tokenizer.eos_token):
            formatted_rejected += tokenizer.eos_token

    return {
        "prompt": formatted_prompt,
        "chosen": formatted_chosen,
        "rejected": formatted_rejected,
    }
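# With format_type="chatml" the exact prompt string depends on the tokenizer's
# chat template; for a ChatML-style template it would look roughly like:
#   <|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n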


def validate_dpo_data(dataset, stage: str = "train") -> None:
    """
    Validate that a DPO dataset has all required fields and proper structure.

    Args:
        dataset: Dataset to validate.
        stage: Training stage ("train" or "eval").

    Raises:
        DataValidationError: if validation fails.
    """
    required_fields = ["prompt", "chosen", "rejected"]

    for field in required_fields:
        if field not in dataset.column_names:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {dataset.column_names}"
            )

    if len(dataset) > 0:
        sample = dataset[0]
        for field in required_fields:
            if not sample[field] or len(sample[field].strip()) == 0:
                logger.warning(f"{stage} dataset has empty {field} in first example")

    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")


def build_dpo_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
    """
    Build datasets for DPO training.
    Expected JSONL format: {"prompt": "...", "chosen": "...", "rejected": "..."},
    or with custom field names specified in the config.
    """
    data_cfg = cfg["data"]
    train_path = data_cfg["train_jsonl"]
    eval_path = data_cfg.get("eval_jsonl", None)
    split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
    shuffle = bool(data_cfg.get("shuffle", True))
    num_proc = int(data_cfg.get("num_proc", 4))

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ds = load_dataset("json", data_files={"train": train_path})

    if eval_path:
        ds_eval = load_dataset("json", data_files={"eval": eval_path})
        dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
    else:
        if 0.0 < split_ratio < 1.0:
            split = ds["train"].train_test_split(
                test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
            )
            dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
        else:
            dsd = DatasetDict({"train": ds["train"], "eval": None})

    def format_fn(examples):
        prompts = []
        chosen_list = []
        rejected_list = []
        errors = 0

        num_rows = len(next(iter(examples.values())))
        for i in range(num_rows):
            example = {k: examples[k][i] for k in examples.keys()}
            try:
                formatted = format_dpo_example(example, cfg, tokenizer)
                prompts.append(formatted["prompt"])
                chosen_list.append(formatted["chosen"])
                rejected_list.append(formatted["rejected"])
            except Exception as e:
                errors += 1
                if errors <= 5:
                    logger.warning(f"Failed to format example {i}: {e}")
                # Keep row alignment; empty rows are filtered out below.
                prompts.append("")
                chosen_list.append("")
                rejected_list.append("")

        if errors > 0:
            logger.warning(f"Total formatting errors in batch: {errors}")

        return {
            "prompt": prompts,
            "chosen": chosen_list,
            "rejected": rejected_list,
        }

    logger.info("Formatting train DPO data...")
    formatted_train = dsd["train"].map(
        format_fn,
        batched=True,
        num_proc=num_proc,
        remove_columns=dsd["train"].column_names,
        desc="Formatting train DPO data",
    )

    # Drop rows that failed formatting.
    formatted_train = formatted_train.filter(lambda x: len(x["prompt"]) > 0)
    logger.info(f"Train dataset after filtering: {len(formatted_train)} examples")

    validate_dpo_data(formatted_train, "train")

    formatted_eval = None
    if dsd["eval"] is not None:
        logger.info("Formatting eval DPO data...")
        formatted_eval = dsd["eval"].map(
            format_fn,
            batched=True,
            num_proc=num_proc,
            remove_columns=dsd["eval"].column_names,
            desc="Formatting eval DPO data",
        )
        formatted_eval = formatted_eval.filter(lambda x: len(x["prompt"]) > 0)
        logger.info(f"Eval dataset after filtering: {len(formatted_eval)} examples")
        validate_dpo_data(formatted_eval, "eval")

    if shuffle:
        formatted_train = formatted_train.shuffle(seed=int(cfg["run"].get("seed", 42)))

    return formatted_train, formatted_eval


def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
    model_cfg = cfg["model"]
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
    device_map = model_cfg.get("device_map", "auto")

    tokenizer = AutoTokenizer.from_pretrained(
        str(base_dir),
        use_fast=use_fast,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))

    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
            bnb_4bit_use_double_quant=bool(
                model_cfg.get("bnb_4bit_use_double_quant", True)
            ),
            bnb_4bit_compute_dtype=_dtype_from_str(
                model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
            ),
        )

    attn_impl = _choose_attn_impl(cfg)

    try:
        model = AutoModelForCausalLM.from_pretrained(
            str(base_dir),
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
            torch_dtype=(torch_dtype if not use_4bit else None),
            quantization_config=quant_cfg,
            attn_implementation=attn_impl,
        )
    except Exception as e:
        if attn_impl is not None:
            print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
            print("[warn] Falling back to default attention implementation.")
        model = AutoModelForCausalLM.from_pretrained(
            str(base_dir),
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
            torch_dtype=(torch_dtype if not use_4bit else None),
            quantization_config=quant_cfg,
        )

    return model, tokenizer


def apply_peft(cfg: Dict[str, Any], model):
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]

    if not bool(peft_cfg.get("enabled", True)):
        return model, None

    use_4bit = bool(model_cfg.get("use_4bit", False))
    gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))

    if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        if hasattr(model, "config"):
            # use_cache is incompatible with gradient checkpointing.
            model.config.use_cache = False

    if use_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )

    target_modules = peft_cfg.get("target_modules", "auto")
    if target_modules == "auto":
        target_modules = _infer_target_modules(model)

    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, lora_config)
    return model, lora_config
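# Optional sanity check after wrapping (PeftModel API); output values are illustrative:
#   model.print_trainable_parameters()
#   -> e.g. "trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06"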


def merge_adapter(
    cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
):
    print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")

    model_cfg = cfg["model"]
    merge_cfg = cfg.get("merge", {})
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))

    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))

    try:
        base = AutoModelForCausalLM.from_pretrained(
            str(base_dir),
            torch_dtype=merged_dtype,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=trust_remote_code,
        )

        merged = PeftModel.from_pretrained(base, str(adapter_dir))
        merged = merged.merge_and_unload()

        # Free the unmerged base reference before serializing.
        del base
        gc.collect()
        torch.cuda.empty_cache()

        _ensure_dir(final_dir)
        merged.save_pretrained(
            str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
        )

        del merged
        gc.collect()
        torch.cuda.empty_cache()

        tok = AutoTokenizer.from_pretrained(
            str(base_dir), trust_remote_code=trust_remote_code
        )
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        tok.save_pretrained(str(final_dir))

        print("--- Merge complete ---")
    except Exception as e:
        logger.error(f"Merge failed: {e}")
        raise


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument(
        "--merge-only", action="store_true", help="Skip training, just merge adapter"
    )
    args = ap.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")

    # Snapshot the resolved config alongside the run for reproducibility.
    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)

    model_cfg = cfg["model"]
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)

    # Resolve the base model: a valid local directory wins; otherwise download.
    if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
        base_dir = repo_path
        print(f"Using local model at: {base_dir}")
    elif repo_path.exists() and repo_path.is_dir():
        raise ValueError(
            f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {repo_path}"
        )
    else:
        base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
        if not _looks_like_model_dir(base_dir):
            print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
            snapshot_download(
                repo_id=repo_id,
                revision=model_cfg.get("revision", None),
                local_dir=str(base_dir),
                local_dir_use_symlinks=False,
            )

    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")

    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"

    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return

    # Weights & Biases (no-op if disabled or not installed).
    wandb_run = setup_wandb(cfg, run_dir)

    set_seed(int(cfg["run"].get("seed", 42)))

    model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
    model, _ = apply_peft(cfg, model)
    dpo_cfg = cfg.get("dpo", {})
    use_reference_model = bool(dpo_cfg.get("use_reference_model", True))
    reference_free = bool(dpo_cfg.get("reference_free", False))

    ref_model = None
    if use_reference_model and not reference_free:
        # Frozen copy of the base model serves as the DPO reference. No LoRA
        # wrapping is needed: a freshly initialized adapter is an identity
        # transform, and the reference is never trained. (With a PEFT policy
        # model, TRL can also run with ref_model=None by disabling adapters.)
        print("Loading reference model (frozen copy)...")
        ref_model, _ = load_base_model_and_tokenizer(cfg, base_dir)
        for param in ref_model.parameters():
            param.requires_grad = False
        ref_model.eval()
        print("Reference model loaded and frozen")

    train_ds, eval_ds = build_dpo_datasets(cfg, tokenizer)

    tr_cfg = cfg["train"]

    dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_fp16 = dtype == torch.float16
    use_bf16 = dtype == torch.bfloat16

    # max_steps > 0 overrides num_train_epochs.
    max_steps = int(tr_cfg.get("max_steps", 0))
    num_train_epochs = float(tr_cfg.get("num_train_epochs", 3))

    # transformers renamed `evaluation_strategy` to `eval_strategy`; DPOConfig
    # subclasses TrainingArguments, so detect which keyword this version accepts.
    dpo_params = inspect.signature(DPOConfig.__init__).parameters
    eval_key = "eval_strategy" if "eval_strategy" in dpo_params else "evaluation_strategy"

    report_to = []
    if wandb_run is not None:
        report_to.append("wandb")

    # Validate gradient clipping.
    max_grad_norm = float(tr_cfg.get("max_grad_norm", 1.0))
    if max_grad_norm <= 0:
        logger.warning(f"Invalid max_grad_norm={max_grad_norm}, using 1.0")
        max_grad_norm = 1.0

    # Callbacks: JSONL mirror of trainer logs, plus optional early stopping.
    callbacks = [JsonlLoggerCallback(run_dir)]

    early_stopping_cfg = tr_cfg.get("early_stopping", {})
    if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
            early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
        )
        callbacks.append(early_stopping_callback)
        print(
            f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
            f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}"
        )

    # DPO-specific hyperparameters.
    beta = float(dpo_cfg.get("beta", 0.1))
    label_smoothing = float(dpo_cfg.get("label_smoothing", 0.0))
    loss_type = str(dpo_cfg.get("loss_type", "sigmoid"))
    max_length = int(cfg["data"].get("max_length", 2048))
    max_prompt_length = int(cfg["data"].get("max_prompt_length", max_length // 2))

    print(f"DPO Training with beta={beta}, loss_type={loss_type}")

    eval_strategy_val = str(
        tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
    )

    # DPOConfig (a TrainingArguments subclass) is the single source of trainer
    # settings. Checkpoints go to ckpt_dir so `resume_from_checkpoint: auto`
    # below can find them.
    dpo_kwargs = dict(
        output_dir=str(ckpt_dir),
        max_steps=max_steps if max_steps > 0 else -1,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 2)),
        per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", 4)),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 4)),
        learning_rate=float(tr_cfg.get("learning_rate", 5e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.01)),
        adam_beta1=float(tr_cfg.get("adam_beta1", 0.9)),
        adam_beta2=float(tr_cfg.get("adam_beta2", 0.999)),
        adam_epsilon=float(tr_cfg.get("adam_epsilon", 1e-8)),
        max_grad_norm=max_grad_norm,
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "linear")),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        logging_steps=int(tr_cfg.get("logging_steps", 10)),
        save_steps=int(tr_cfg.get("save_steps", 100)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
        eval_steps=int(tr_cfg.get("eval_steps", 100)) if eval_ds is not None else None,
        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        load_best_model_at_end=(
            bool(tr_cfg.get("load_best_model_at_end", False))
            if eval_ds is not None
            else False
        ),
        metric_for_best_model=str(tr_cfg.get("metric_for_best_model", "eval_loss")),
        greater_is_better=bool(tr_cfg.get("greater_is_better", False)),
        fp16=use_fp16,
        bf16=use_bf16,
        report_to=report_to,
        remove_unused_columns=False,
        # DPO-specific fields.
        beta=beta,
        label_smoothing=label_smoothing,
        loss_type=loss_type,
        max_length=max_length,
        max_prompt_length=max_prompt_length,
    )
    dpo_kwargs[eval_key] = eval_strategy_val
    dpo_config = DPOConfig(**dpo_kwargs)

    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=dpo_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        processing_class=tokenizer,
        callbacks=callbacks,
    )

    # Resume logic: "auto" picks the latest checkpoint in ckpt_dir, if any.
    resume_from = tr_cfg.get("resume_from_checkpoint", None)
    if resume_from == "auto":
        last = get_last_checkpoint(str(ckpt_dir))
        resume_from = last if last else None
    if resume_from:
        print(f"Resuming from {resume_from}")

    print("Starting DPO training...")
    trainer.train(resume_from_checkpoint=resume_from)

    trainer.save_model(str(best_adapter_dir))
    print(f"Saved best adapter -> {best_adapter_dir}")

    if eval_ds is not None:
        metrics = trainer.evaluate()
        with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print(f"Final metrics: {metrics}")

    if bool(cfg.get("merge", {}).get("enabled", False)):
        # Release training references before loading the merge copies.
        del trainer, model
        if ref_model is not None:
            del ref_model
        gc.collect()
        torch.cuda.empty_cache()
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
    else:
        print("Merge disabled. Run with --merge-only later if needed.")

    finish_wandb()
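# Typical invocations (file names and paths are illustrative):
#   python train_dpo.py --config configs/dpo.yaml
#   python train_dpo.py --config configs/dpo.yaml --merge-only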


if __name__ == "__main__":
    main()