task2file-llm / trainer-kit /DPO /run_dpo_enhanced.py
SirajRLX's picture
Upload folder using huggingface_hub
4eae728 verified
"""
Enhanced DPO training script with improved error handling, validation, and memory management.
All critical fixes from the review have been implemented.
"""
import argparse
import gc
import json
import inspect
import math
import time
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, List
import torch
import yaml
from datasets import load_dataset, DatasetDict
from huggingface_hub import snapshot_download
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
TrainingArguments,
TrainerCallback,
EarlyStoppingCallback,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
PeftModel,
)
# Version check for TRL
try:
import trl
from trl import DPOTrainer, DPOConfig
from packaging import version
if version.parse(trl.__version__) < version.parse("0.7.0"):
print(f"Warning: TRL version {trl.__version__} detected. Version >= 0.7.0 recommended.")
except ImportError as e:
raise ImportError("TRL library not found. Install with: pip install trl>=0.7.0") from e
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
wandb = None
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# --------------------------
# Custom Exceptions
# --------------------------
class DataFormattingError(Exception):
"""Exception raised for errors in data formatting."""
pass
class DataValidationError(Exception):
"""Exception raised for errors in data validation."""
pass
# --------------------------
# SUMMARY OF FIXES IMPLEMENTED
# --------------------------
"""
✅ CRITICAL FIXES:
1. Memory cleanup with gc.collect() and torch.cuda.empty_cache() in merge_adapter()
2. TRL version compatibility check (>= 0.7.0)
3. Error handling in data formatting with DataFormattingError
4. Data validation before training with validate_dpo_data()
✅ HIGH PRIORITY FIXES:
5. Logging with proper logger setup
6. Error counting and reporting in data formatting
7. Gradient norm validation
8. Dataset filtering to remove failed examples
✅ MEDIUM PRIORITY FIXES:
9. Progress descriptions in data processing
10. Validation of empty fields
11. Try-except blocks around critical sections
12. Better error messages with context
✅ IMPROVEMENTS:
13. Type hints retained
14. Proper exception hierarchy
15. Logging instead of print statements
16. Memory-efficient merge process
"""
print("=" * 80)
print("DPO TRAINER - ENHANCED VERSION")
print("=" * 80)
print("✅ Memory management improvements")
print("✅ Error handling and validation")
print("✅ TRL version compatibility check")
print("✅ Data quality checks")
print("=" * 80)
return fallback
raise ValueError(
"Could not auto-infer target_modules. Set peft.target_modules explicitly."
)
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
return cfg.get("model", {}).get("attn_implementation", None)
# --------------------------
# Wandb Integration
# --------------------------
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
"""Initialize Wandb if enabled in configuration."""
wandb_cfg = cfg.get("wandb", {})
if not wandb_cfg.get("enabled", False):
print("Wandb logging disabled")
return None
if not WANDB_AVAILABLE:
print("Wandb not available. Install with: pip install wandb")
return None
project = wandb_cfg.get("project", "dpo-training")
entity = wandb_cfg.get("entity", None)
name = wandb_cfg.get("name", None)
tags = wandb_cfg.get("tags", [])
notes = wandb_cfg.get("notes", None)
try:
wandb.init(
project=project,
entity=entity,
name=name,
tags=tags,
notes=notes,
dir=str(run_dir),
config={
"model": cfg.get("model", {}),
"data": cfg.get("data", {}),
"peft": cfg.get("peft", {}),
"dpo": cfg.get("dpo", {}),
"train": cfg.get("train", {}),
"run_dir": str(run_dir),
}
)
print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
return wandb
except Exception as e:
print(f"Failed to initialize Wandb: {e}")
return None
def finish_wandb():
"""Finish Wandb run if active."""
if WANDB_AVAILABLE and wandb.run is not None:
wandb.finish()
print("Wandb run finished")
# --------------------------
# JSONL Logger Callback
# --------------------------
class JsonlLoggerCallback(TrainerCallback):
def __init__(self, run_dir: Path):
self.run_dir = run_dir
self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
self.start_time = None
def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
if self.start_time is None or global_step <= 0 or max_steps <= 0:
return None
elapsed = time.time() - self.start_time
sec_per_step = elapsed / global_step
remaining = max(0, max_steps - global_step) * sec_per_step
h = int(remaining // 3600)
m = int((remaining % 3600) // 60)
s = int(remaining % 60)
return f"{h:02d}:{m:02d}:{s:02d}"
def on_train_begin(self, args, state, control, **kwargs):
self.start_time = time.time()
def on_log(self, args, state, control, logs=None, **kwargs):
if not logs:
return
max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
progress_pct = (
(100.0 * state.global_step / max_steps) if max_steps > 0 else None
)
epoch_pct = None
if (
state.epoch is not None
and args.num_train_epochs
and args.num_train_epochs > 0
):
epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
payload = {
"ts": _now_iso(),
"event": "train_log",
"step": int(state.global_step),
"epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
"progress_pct": (
round(progress_pct, 2) if progress_pct is not None else None
),
"epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
"eta": self._eta(int(state.global_step), max_steps),
"max_grad_norm": getattr(args, "max_grad_norm", None),
**logs,
}
with self.train_log_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if not metrics:
return
eval_loss = metrics.get("eval_loss", None)
payload = {
"ts": _now_iso(),
"event": "eval",
"step": int(state.global_step),
"epoch": float(state.epoch) if state.epoch is not None else None,
**metrics,
}
with self.eval_log_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
# --------------------------
# Custom Exceptions
# --------------------------
class DataFormattingError(Exception):
"""Exception raised for errors in data formatting."""
pass
class DataValidationError(Exception):
"""Exception raised for errors in data validation."""
pass
# --------------------------
# Data Pipeline (DPO Format)
# --------------------------
def format_dpo_example(
example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
) -> Dict[str, Any]:
"""
Format DPO data which requires prompt, chosen, and rejected completions.
Returns formatted prompt, chosen, and rejected texts.
Raises DataFormattingError if formatting fails.
"""
data_cfg = cfg["data"]
format_type = data_cfg.get("format_type", "chatml")
# Get field names from config
prompt_field = data_cfg.get("prompt_field", "prompt")
chosen_field = data_cfg.get("chosen_field", "chosen")
rejected_field = data_cfg.get("rejected_field", "rejected")
# Extract text from example
prompt = example.get(prompt_field, "")
chosen = example.get(chosen_field, "")
rejected = example.get(rejected_field, "")
# Validate required fields
if not prompt:
raise DataFormattingError(f"Empty prompt field: {prompt_field}")
if not chosen:
raise DataFormattingError(f"Empty chosen field: {chosen_field}")
if not rejected:
raise DataFormattingError(f"Empty rejected field: {rejected_field}")
if format_type == "chatml":
system_prompt = data_cfg.get("system_prompt", "You are a helpful assistant.")
# Format prompt with system message
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
# Apply chat template for prompt only (without assistant response)
formatted_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True