# SirajRLX's picture
# Upload folder using huggingface_hub
# 4eae728 verified
import argparse
import json
import inspect
import math
import gc
import time
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, List
import torch
import yaml
from datasets import load_dataset, DatasetDict
from huggingface_hub import snapshot_download
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
AutoConfig,
BitsAndBytesConfig,
TrainingArguments,
TrainerCallback,
EarlyStoppingCallback,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
PeftModel,
)
from trl import DPOTrainer, DPOConfig
# Setup logging FIRST: the version checks below emit warnings through `logger`,
# so it must exist before they run (previously this raised NameError).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Version check for TRL
try:
    from packaging import version
    import trl

    if version.parse(trl.__version__) < version.parse("0.7.0"):
        logger.warning(f"TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
except ImportError:
    logger.warning("Could not verify TRL version")
# Optional Weights & Biases dependency; WANDB_AVAILABLE gates all wandb usage.
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    wandb = None
# --------------------------
# Custom Exceptions
# --------------------------
class DataFormattingError(Exception):
    """Raised when a raw example cannot be converted to DPO format."""
class DataValidationError(Exception):
    """Raised when a formatted dataset fails structural validation."""
# --------------------------
# Helpers
# --------------------------
def _dtype_from_str(s: str) -> torch.dtype:
s = (s or "").lower()
if s in ("float16", "fp16"):
return torch.float16
if s in ("bfloat16", "bf16"):
return torch.bfloat16
if s in ("float32", "fp32"):
return torch.float32
raise ValueError(f"Unknown torch_dtype: {s}")
def _now_iso() -> str:
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
def _safe_exp(x: float) -> float:
x = min(float(x), 50.0)
return float(math.exp(x))
def _ensure_dir(p: Path) -> Path:
p.mkdir(parents=True, exist_ok=True)
return p
def _looks_like_model_dir(p: Path) -> bool:
if not p.exists() or not p.is_dir():
return False
if (p / "config.json").exists():
return True
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
return True
return False
def _infer_target_modules(model) -> List[str]:
names = set()
for n, _ in model.named_modules():
names.add(n.split(".")[-1])
for group in [
["q_proj", "k_proj", "v_proj", "o_proj"],
["Wqkv", "out_proj"],
["query_key_value", "dense"],
["c_attn", "c_proj"],
]:
if all(x in names for x in group):
return group
fallback = [
x
for x in [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"c_attn",
"c_proj",
"out_proj",
"dense",
]
if x in names
]
if fallback:
return fallback
raise ValueError(
"Could not auto-infer target_modules. Set peft.target_modules explicitly."
)
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
return cfg.get("model", {}).get("attn_implementation", None)
# --------------------------
# Wandb Integration
# --------------------------
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
    """Start a Wandb run from the `wandb` section of the config.

    Returns the wandb module on success, or None when logging is disabled,
    the package is unavailable, or initialization fails.
    """
    wandb_cfg = cfg.get("wandb", {})
    if not wandb_cfg.get("enabled", False):
        print("Wandb logging disabled")
        return None
    if not WANDB_AVAILABLE:
        print("Wandb not available. Install with: pip install wandb")
        return None
    project = wandb_cfg.get("project", "dpo-training")
    name = wandb_cfg.get("name", None)
    try:
        wandb.init(
            project=project,
            entity=wandb_cfg.get("entity", None),
            name=name,
            tags=wandb_cfg.get("tags", []),
            notes=wandb_cfg.get("notes", None),
            dir=str(run_dir),
            # Mirror the key config sections into the run for reproducibility.
            config={
                "model": cfg.get("model", {}),
                "data": cfg.get("data", {}),
                "peft": cfg.get("peft", {}),
                "dpo": cfg.get("dpo", {}),
                "train": cfg.get("train", {}),
                "run_dir": str(run_dir),
            }
        )
        print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
        return wandb
    except Exception as e:
        print(f"Failed to initialize Wandb: {e}")
        return None
def finish_wandb():
    """Close the active Wandb run, if there is one."""
    if not WANDB_AVAILABLE:
        return
    if wandb.run is not None:
        wandb.finish()
        print("Wandb run finished")
# --------------------------
# JSONL Logger Callback
# --------------------------
class JsonlLoggerCallback(TrainerCallback):
    """Trainer callback that appends metrics as JSON lines under run_dir/logs.

    Training logs go to logs/train.jsonl (augmented with timestamp, progress
    percentages and an ETA); evaluation metrics go to logs/eval.jsonl.
    """

    def __init__(self, run_dir: Path):
        self.run_dir = run_dir
        self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
        self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
        # Set in on_train_begin; used by _eta().
        self.start_time = None

    def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
        """Estimate remaining wall time as "HH:MM:SS"; None if unknown."""
        if self.start_time is None or global_step <= 0 or max_steps <= 0:
            return None
        elapsed = time.time() - self.start_time
        sec_per_step = elapsed / global_step
        remaining = max(0, max_steps - global_step) * sec_per_step
        h = int(remaining // 3600)
        m = int((remaining % 3600) // 60)
        s = int(remaining % 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append one training-log record (with progress/ETA) to train.jsonl."""
        if not logs:
            return
        max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
        progress_pct = (
            (100.0 * state.global_step / max_steps) if max_steps > 0 else None
        )
        epoch_pct = None
        if (
            state.epoch is not None
            and args.num_train_epochs
            and args.num_train_epochs > 0
        ):
            epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
        payload = {
            "ts": _now_iso(),
            "event": "train_log",
            "step": int(state.global_step),
            "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
            "progress_pct": (
                round(progress_pct, 2) if progress_pct is not None else None
            ),
            "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
            "eta": self._eta(int(state.global_step), max_steps),
            "max_grad_norm": getattr(args, "max_grad_norm", None),
            **logs,
        }
        with self.train_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Append one evaluation record to eval.jsonl.

        (An unused `eval_loss` local extracted from metrics was removed;
        all metrics are forwarded verbatim via **metrics anyway.)
        """
        if not metrics:
            return
        payload = {
            "ts": _now_iso(),
            "event": "eval",
            "step": int(state.global_step),
            "epoch": float(state.epoch) if state.epoch is not None else None,
            **metrics,
        }
        with self.eval_log_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")
# --------------------------
# (Removed: exact duplicate definitions of DataFormattingError and
# DataValidationError; the canonical definitions appear once near the
# top of this file. The duplicates merely rebound the same names to
# identical classes.)
# --------------------------
# --------------------------
# Data Pipeline (DPO Format)
# --------------------------
def format_dpo_example(
    example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
) -> Dict[str, Any]:
    """Convert one raw example into DPO fields (prompt/chosen/rejected).

    Field names are read from cfg["data"] (prompt_field / chosen_field /
    rejected_field).  For the "chatml" and "alpaca" formats the raw text is
    passed through unchanged — DPOTrainer applies chat templates itself —
    so the two branches were merged (they were previously identical).  The
    "custom" format applies data.custom_template to the prompt only.

    Args:
        example: Raw dataset row (dict of column -> value).
        cfg: Full config dict; only cfg["data"] is read.
        tokenizer: Unused; kept for interface compatibility with callers.

    Returns:
        Dict with "prompt", "chosen" and "rejected" strings.

    Raises:
        DataFormattingError: If any required field is missing or empty.
        ValueError: For an unsupported data.format_type.
    """
    data_cfg = cfg["data"]
    format_type = data_cfg.get("format_type", "chatml")
    # Get field names from config
    prompt_field = data_cfg.get("prompt_field", "prompt")
    chosen_field = data_cfg.get("chosen_field", "chosen")
    rejected_field = data_cfg.get("rejected_field", "rejected")
    # Extract text from example
    prompt = example.get(prompt_field, "")
    chosen = example.get(chosen_field, "")
    rejected = example.get(rejected_field, "")
    # Validate required fields
    if not prompt:
        raise DataFormattingError(f"Empty prompt field: {prompt_field}")
    if not chosen:
        raise DataFormattingError(f"Empty chosen field: {chosen_field}")
    if not rejected:
        raise DataFormattingError(f"Empty rejected field: {rejected_field}")
    if format_type in ("chatml", "alpaca"):
        # DPOTrainer handles templating internally; pass raw text through.
        formatted_prompt = prompt
    elif format_type == "custom":
        # Custom template applies to the prompt only.
        template = data_cfg.get("custom_template", "{prompt}")
        formatted_prompt = template.format(prompt=prompt)
    else:
        raise ValueError(f"Unsupported format_type: {format_type}")
    return {
        "prompt": formatted_prompt,
        "chosen": chosen,
        "rejected": rejected,
    }
def validate_dpo_data(dataset, stage: str = "train") -> None:
    """Check that a formatted DPO dataset has the three required columns.

    Args:
        dataset: Dataset to validate (needs .column_names, len() and [0]).
        stage: Label used in log/error messages ("train" or "eval").

    Raises:
        DataValidationError: If a required column is missing.
    """
    required_fields = ["prompt", "chosen", "rejected"]
    # Check required fields exist
    for field in required_fields:
        if field not in dataset.column_names:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {dataset.column_names}"
            )
    # Spot-check the first example for empty values (warn only).
    if len(dataset) > 0:
        first_row = dataset[0]
        for field in required_fields:
            value = first_row[field]
            if not value or len(value.strip()) == 0:
                logger.warning(f"{stage} dataset has empty {field} in first example")
    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")
def build_dpo_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
    """Load, format, filter and validate train/eval datasets for DPO.

    Expected JSONL rows: {"prompt": "...", "chosen": "...", "rejected": "..."}
    (field names configurable via cfg["data"]).  Rows that fail formatting are
    replaced with empty strings inside the batched map and filtered out after.

    Args:
        cfg: Full config; reads cfg["data"] and cfg["run"]["seed"].
        tokenizer: Tokenizer; given a pad token (EOS) if it lacks one.

    Returns:
        (train_dataset, eval_dataset); eval_dataset is None when no eval file
        is configured and eval_split_ratio is not in (0, 1).
    """
    data_cfg = cfg["data"]
    train_path = data_cfg["train_jsonl"]
    eval_path = data_cfg.get("eval_jsonl", None)
    split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
    shuffle = bool(data_cfg.get("shuffle", True))
    num_proc = int(data_cfg.get("num_proc", 4))
    # Ensure tokenizer has pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Load datasets
    ds = load_dataset("json", data_files={"train": train_path})
    if eval_path:
        ds_eval = load_dataset("json", data_files={"eval": eval_path})
        dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
    else:
        if 0.0 < split_ratio < 1.0:
            split = ds["train"].train_test_split(
                test_size=split_ratio, seed=int(cfg["run"].get("seed", 42))
            )
            dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
        else:
            # NOTE(review): a None value inside DatasetDict is unusual; it works
            # here only because the dict is read back immediately below.
            dsd = DatasetDict({"train": ds["train"], "eval": None})

    # Format DPO examples with error handling
    def format_fn(examples):
        """Batched map fn: format each row; substitute empties on failure."""
        prompts = []
        chosen_list = []
        rejected_list = []
        errors = 0
        batch_size = len(examples[list(examples.keys())[0]])
        for i in range(batch_size):
            example = {k: examples[k][i] for k in examples.keys()}
            try:
                formatted = format_dpo_example(example, cfg, tokenizer)
                prompts.append(formatted["prompt"])
                chosen_list.append(formatted["chosen"])
                rejected_list.append(formatted["rejected"])
            # Was `except (DataFormattingError, Exception)`: Exception already
            # subsumes DataFormattingError, so the tuple was redundant.
            except Exception as e:
                errors += 1
                if errors <= 5:  # Log first 5 errors
                    logger.warning(f"Failed to format example {i}: {e}")
                # Add empty placeholder to maintain batch structure
                prompts.append("")
                chosen_list.append("")
                rejected_list.append("")
        if errors > 0:
            logger.warning(f"Total formatting errors in batch: {errors}")
        return {
            "prompt": prompts,
            "chosen": chosen_list,
            "rejected": rejected_list,
        }

    logger.info("Formatting train DPO data...")
    formatted_train = dsd["train"].map(
        format_fn,
        batched=True,
        num_proc=num_proc,
        remove_columns=dsd["train"].column_names,
        desc="Formatting train DPO data",
    )
    # Filter out failed examples (empty prompts)
    formatted_train = formatted_train.filter(lambda x: len(x["prompt"]) > 0)
    logger.info(f"Train dataset after filtering: {len(formatted_train)} examples")
    # Validate formatted data
    validate_dpo_data(formatted_train, "train")
    formatted_eval = None
    if dsd["eval"] is not None:
        logger.info("Formatting eval DPO data...")
        formatted_eval = dsd["eval"].map(
            format_fn,
            batched=True,
            num_proc=num_proc,
            remove_columns=dsd["eval"].column_names,
            desc="Formatting eval DPO data",
        )
        formatted_eval = formatted_eval.filter(lambda x: len(x["prompt"]) > 0)
        logger.info(f"Eval dataset after filtering: {len(formatted_eval)} examples")
        validate_dpo_data(formatted_eval, "eval")
    if shuffle:
        formatted_train = formatted_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
    return formatted_train, formatted_eval
# --------------------------
# Model Loading + PEFT
# --------------------------
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
    """Load the tokenizer and base model from `base_dir` per cfg["model"].

    Handles optional 4-bit quantization (bitsandbytes), a configurable
    attention implementation with fallback to the default on failure, and
    the Mistral3 multimodal architecture (which needs its dedicated class).

    Args:
        cfg: Full config dict; reads cfg["model"] and the attn setting.
        base_dir: Local directory containing the model snapshot.

    Returns:
        (model, tokenizer) tuple.
    """
    model_cfg = cfg["model"]
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
    device_map = model_cfg.get("device_map", "auto")
    tokenizer = AutoTokenizer.from_pretrained(
        str(base_dir),
        use_fast=use_fast,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
    torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))
    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
            bnb_4bit_use_double_quant=bool(
                model_cfg.get("bnb_4bit_use_double_quant", True)
            ),
            bnb_4bit_compute_dtype=_dtype_from_str(
                model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
            ),
        )
    attn_impl = _choose_attn_impl(cfg)
    # First check the model type to determine loading strategy
    try:
        config = AutoConfig.from_pretrained(str(base_dir), trust_remote_code=True)
        model_type = config.model_type
        architectures = getattr(config, 'architectures', [])
        # Handle Mistral3 (multimodal) models
        if model_type == "mistral3" or (architectures and "Mistral3" in architectures[0]):
            logger.info(f"Detected Mistral3 model architecture, loading with specific class")
            from transformers.models.mistral3.modeling_mistral3 import Mistral3ForConditionalGeneration
            try:
                # torch_dtype is omitted under 4-bit: the quantization config
                # controls the compute dtype instead.
                model = Mistral3ForConditionalGeneration.from_pretrained(
                    str(base_dir),
                    config=config,
                    device_map=device_map,
                    low_cpu_mem_usage=True,
                    torch_dtype=(torch_dtype if not use_4bit else None),
                    quantization_config=quant_cfg,
                    attn_implementation=attn_impl,
                )
            except Exception as e:
                # Retry without the requested attention implementation, which
                # may be unsupported on this hardware/model combination.
                if attn_impl is not None:
                    logger.warning(f"attn_implementation='{attn_impl}' failed: {e}")
                    logger.warning("Falling back to default attention implementation.")
                    model = Mistral3ForConditionalGeneration.from_pretrained(
                        str(base_dir),
                        config=config,
                        device_map=device_map,
                        low_cpu_mem_usage=True,
                        torch_dtype=(torch_dtype if not use_4bit else None),
                        quantization_config=quant_cfg,
                    )
                else:
                    raise e
        else:
            # Standard AutoModelForCausalLM loading for other models
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    str(base_dir),
                    device_map=device_map,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                    torch_dtype=(torch_dtype if not use_4bit else None),
                    quantization_config=quant_cfg,
                    attn_implementation=attn_impl,
                )
            except Exception as e:
                # Same fallback as above: drop attn_implementation and retry.
                if attn_impl is not None:
                    logger.warning(f"attn_implementation='{attn_impl}' failed: {e}")
                    logger.warning("Falling back to default attention implementation.")
                    model = AutoModelForCausalLM.from_pretrained(
                        str(base_dir),
                        device_map=device_map,
                        trust_remote_code=True,
                        low_cpu_mem_usage=True,
                        torch_dtype=(torch_dtype if not use_4bit else None),
                        quantization_config=quant_cfg,
                    )
                else:
                    raise e
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise e
    # Ensure all parameters are off meta device
    # NOTE(review): this only *detects* meta-device params and warns; it does
    # not materialize them — confirm whether that is sufficient for callers.
    logger.info("Ensuring all parameters are materialized...")
    meta_params = []
    for name, param in model.named_parameters():
        if param.device.type == 'meta':
            meta_params.append(name)
    if meta_params:
        logger.warning(f"Found {len(meta_params)} parameters on meta device")
    # For multimodal models, freeze vision components if doing text-only training
    if hasattr(model, 'vision_tower'):
        logger.info("Freezing vision tower for text-only training")
        for param in model.vision_tower.parameters():
            param.requires_grad = False
    return model, tokenizer
def apply_peft(cfg: Dict[str, Any], model):
    """Wrap `model` with a LoRA adapter according to cfg["peft"].

    Also enables gradient checkpointing and (for 4-bit models) prepares the
    model for k-bit training.

    Returns:
        (model, lora_config); lora_config is None when PEFT is disabled.
    """
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]
    if not bool(peft_cfg.get("enabled", True)):
        # PEFT disabled: return the model untouched.
        return model, None
    quantized = bool(model_cfg.get("use_4bit", False))
    checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
    if checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        if hasattr(model, "config"):
            # KV cache is incompatible with gradient checkpointing.
            model.config.use_cache = False
    if quantized:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=checkpointing,
        )
    targets = peft_cfg.get("target_modules", "auto")
    if targets == "auto":
        targets = _infer_target_modules(model)
    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=targets,
    )
    return get_peft_model(model, lora_config), lora_config
# --------------------------
# Merge Logic
# --------------------------
def merge_adapter(
    cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
):
    """Merge a LoRA adapter into its base model and save a standalone model.

    Loads the base weights on CPU in merge.merged_dtype, applies the adapter,
    merges it into the weights, then writes sharded safetensors plus the
    tokenizer to `final_dir`.

    Args:
        cfg: Full config; reads cfg["model"] and cfg["merge"].
        base_dir: Directory holding the base model.
        adapter_dir: Directory holding the trained PEFT adapter.
        final_dir: Output directory for the merged model (created if needed).

    Raises:
        Exception: re-raised after logging if any step fails.
    """
    logger.info(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
    model_cfg = cfg["model"]
    merge_cfg = cfg.get("merge", {})
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
    try:
        # Load on CPU so merging does not compete with training for GPU memory.
        base = AutoModelForCausalLM.from_pretrained(
            str(base_dir),
            torch_dtype=merged_dtype,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=trust_remote_code,
        )
        merged = PeftModel.from_pretrained(base, str(adapter_dir))
        merged = merged.merge_and_unload()
        # Clean up base model to free memory
        del base
        gc.collect()
        torch.cuda.empty_cache()
        _ensure_dir(final_dir)
        merged.save_pretrained(
            str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
        )
        # Clean up merged model
        del merged
        gc.collect()
        torch.cuda.empty_cache()
        # Save the tokenizer next to the weights so final_dir is self-contained.
        tok = AutoTokenizer.from_pretrained(
            str(base_dir), trust_remote_code=trust_remote_code
        )
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        tok.save_pretrained(str(final_dir))
        logger.info("--- Merge complete ---")
    except Exception as e:
        logger.error(f"Merge failed: {e}")
        raise
# --------------------------
# Main
# --------------------------
def main():
    """CLI entry point: DPO-train a model from a YAML config, then optionally merge.

    Workflow: resolve config -> locate/download base model -> (unless
    --merge-only) train with LoRA + DPO -> save best adapter -> final eval ->
    optionally merge the adapter into the base model.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument(
        "--merge-only", action="store_true", help="Skip training, just merge adapter"
    )
    args = ap.parse_args()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")
    # Persist the resolved config alongside the run for reproducibility.
    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)
    model_cfg = cfg["model"]
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)
    # Local model path -> load directly
    if repo_path.exists() and repo_path.is_dir() and _looks_like_model_dir(repo_path):
        base_dir = repo_path
        logger.info(f"Using local model at: {base_dir}")
    elif repo_path.exists() and repo_path.is_dir():
        # BUGFIX: this branch previously interpolated `base_dir`, which is
        # unbound here and raised NameError instead of the intended ValueError.
        raise ValueError(
            f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {repo_path}"
        )
    else:
        # HF repo_id -> download
        base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
        if not _looks_like_model_dir(base_dir):
            print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
            snapshot_download(
                repo_id=repo_id,
                revision=model_cfg.get("revision", None),
                local_dir=str(base_dir),
                local_dir_use_symlinks=False,
            )
    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        # Relative output dirs resolve under the run directory.
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"
    # Merge-only
    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return
    # Initialize Wandb
    wandb_run = setup_wandb(cfg, run_dir)
    # Training
    set_seed(int(cfg["run"].get("seed", 42)))
    model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
    model, _ = apply_peft(cfg, model)
    # Load reference model for DPO (if using reference model)
    dpo_cfg = cfg.get("dpo", {})
    use_reference_model = bool(dpo_cfg.get("use_reference_model", True))
    reference_free = bool(dpo_cfg.get("reference_free", False))
    ref_model = None
    if use_reference_model and not reference_free:
        print("Loading reference model (frozen copy)...")
        ref_model, _ = load_base_model_and_tokenizer(cfg, base_dir)
        ref_model, _ = apply_peft(cfg, ref_model)
        # Freeze reference model
        for param in ref_model.parameters():
            param.requires_grad = False
        ref_model.eval()
        print("Reference model loaded and frozen")
    train_ds, eval_ds = build_dpo_datasets(cfg, tokenizer)
    tr_cfg = cfg["train"]
    dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_fp16 = dtype == torch.float16
    use_bf16 = dtype == torch.bfloat16
    max_steps = int(tr_cfg.get("max_steps", 0))
    num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
    # Transformers renamed `evaluation_strategy` -> `eval_strategy`; detect
    # which keyword this installed version accepts.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = (
        "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
    )
    # Setup reporting based on wandb availability
    report_to = []
    if wandb_run is not None:
        report_to.append("wandb")
    # Validate and adjust training parameters
    max_grad_norm = float(tr_cfg.get("max_grad_norm", 1.0))
    if max_grad_norm <= 0:
        logger.warning(f"Invalid max_grad_norm={max_grad_norm}, using 1.0")
        max_grad_norm = 1.0
    ta_kwargs = dict(
        output_dir=str(ckpt_dir),
        max_steps=max_steps if max_steps > 0 else -1,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
        per_device_eval_batch_size=int(
            tr_cfg.get(
                "per_device_eval_batch_size",
                tr_cfg.get("per_device_train_batch_size", 1),
            )
        ),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
        learning_rate=float(tr_cfg.get("learning_rate", 5e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
        optim=str(
            tr_cfg.get(
                "optim",
                (
                    # Paged optimizer pairs with 4-bit quantized training.
                    "paged_adamw_8bit"
                    if bool(model_cfg.get("use_4bit", False))
                    else "adamw_torch"
                ),
            )
        ),
        max_grad_norm=max_grad_norm,
        logging_steps=int(tr_cfg.get("logging_steps", 10)),
        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        save_steps=int(tr_cfg.get("save_steps", 200)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
        eval_steps=int(tr_cfg.get("eval_steps", 50)),
        load_best_model_at_end=(
            bool(tr_cfg.get("load_best_model_at_end", True))
            if eval_ds is not None
            else False
        ),
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        fp16=use_fp16,
        bf16=use_bf16,
        report_to=report_to,
        remove_unused_columns=False,
    )
    # Set the correct argument name for this transformers version
    ta_kwargs[eval_key] = str(
        tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no")
    )
    # NOTE(review): `training_args` is constructed (validating settings and
    # creating ckpt_dir as a side effect) but is NOT passed to DPOTrainer —
    # `dpo_config` below is what the trainer actually uses.  In particular
    # `max_steps` only reaches the unused object.  Confirm intent before
    # removing or consolidating.
    training_args = TrainingArguments(**ta_kwargs)
    # Setup callbacks
    callbacks = [JsonlLoggerCallback(run_dir)]
    # Add early stopping callback if enabled
    early_stopping_cfg = tr_cfg.get("early_stopping", {})
    if early_stopping_cfg.get("enabled", False) and eval_ds is not None:
        early_stopping_callback = EarlyStoppingCallback(
            early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
            early_stopping_threshold=float(early_stopping_cfg.get("min_delta", 0.001)),
        )
        callbacks.append(early_stopping_callback)
        print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
              f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")
    # DPO-specific parameters
    beta = float(dpo_cfg.get("beta", 0.1))
    label_smoothing = float(dpo_cfg.get("label_smoothing", 0.0))
    loss_type = str(dpo_cfg.get("loss_type", "sigmoid"))
    max_length = int(cfg["data"].get("max_length", 2048))
    max_prompt_length = int(cfg["data"].get("max_prompt_length", max_length // 2))
    logger.info(f"DPO Training with beta={beta}, loss_type={loss_type}")
    # Get evaluation strategy from config
    eval_strategy_val = str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
    # Create DPOConfig with all training and DPO-specific parameters
    dpo_config = DPOConfig(
        output_dir=str(run_dir),
        model_init_kwargs={"trust_remote_code": True},
        num_train_epochs=int(tr_cfg.get("num_train_epochs", 3)),
        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 2)),
        per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", 4)),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 4)),
        learning_rate=float(tr_cfg.get("learning_rate", 5e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.01)),
        adam_beta1=float(tr_cfg.get("adam_beta1", 0.9)),
        adam_beta2=float(tr_cfg.get("adam_beta2", 0.999)),
        adam_epsilon=float(tr_cfg.get("adam_epsilon", 1e-8)),
        max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "linear")),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        logging_steps=int(tr_cfg.get("logging_steps", 10)),
        save_steps=int(tr_cfg.get("save_steps", 100)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
        eval_steps=int(tr_cfg.get("eval_steps", 100)) if eval_ds is not None else None,
        eval_strategy=eval_strategy_val,
        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        load_best_model_at_end=(
            bool(tr_cfg.get("load_best_model_at_end", False))
            if eval_ds is not None
            else False
        ),
        metric_for_best_model=str(tr_cfg.get("metric_for_best_model", "eval_loss")),
        greater_is_better=bool(tr_cfg.get("greater_is_better", False)),
        fp16=use_fp16,
        bf16=use_bf16,
        report_to=report_to,
        remove_unused_columns=False,
        # DPO-specific parameters
        beta=beta,
        label_smoothing=label_smoothing,
        loss_type=loss_type,
        max_length=max_length,
        max_prompt_length=max_prompt_length,
    )
    # DPOTrainer
    # For text-only models, don't pass processing_class - let DPOTrainer handle it
    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=dpo_config,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        callbacks=callbacks,
    )
    # Resume: "auto" picks the newest checkpoint under ckpt_dir, if any.
    resume_from = tr_cfg.get("resume_from_checkpoint", None)
    if resume_from == "auto":
        last = get_last_checkpoint(str(ckpt_dir))
        resume_from = last if last else None
    if resume_from:
        logger.info(f"Resuming from {resume_from}")
    logger.info("Starting DPO training...")
    trainer.train(resume_from_checkpoint=resume_from)
    trainer.save_model(str(best_adapter_dir))
    logger.info(f"Saved best adapter -> {best_adapter_dir}")
    if eval_ds is not None:
        metrics = trainer.evaluate()
        with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
            json.dump(metrics, f, indent=2)
        print(f"Final metrics: {metrics}")
    if bool(cfg.get("merge", {}).get("enabled", False)):
        # Free training memory before reloading the base model for merging.
        del trainer, model
        if ref_model is not None:
            del ref_model
        torch.cuda.empty_cache()
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
    else:
        print("Merge disabled. Run with --merge-only later if needed.")
    # Finish Wandb run
    finish_wandb()
if __name__ == "__main__":
    # Script entry point: run training/merge driven by the --config YAML.
    main()