|
|
import argparse |
|
|
import json |
|
|
import inspect |
|
|
import math |
|
|
import gc |
|
|
import time |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, Optional, Tuple, List |
|
|
|
|
|
import torch |
|
|
import yaml |
|
|
from datasets import load_dataset, DatasetDict |
|
|
from huggingface_hub import snapshot_download |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForCausalLM, |
|
|
AutoConfig, |
|
|
BitsAndBytesConfig, |
|
|
TrainingArguments, |
|
|
TrainerCallback, |
|
|
EarlyStoppingCallback, |
|
|
set_seed, |
|
|
) |
|
|
from transformers.trainer_utils import get_last_checkpoint |
|
|
from peft import ( |
|
|
LoraConfig, |
|
|
get_peft_model, |
|
|
prepare_model_for_kbit_training, |
|
|
PeftModel, |
|
|
) |
|
|
from trl import DPOTrainer, DPOConfig |
|
|
|
|
|
|
|
|
try:
    from packaging import version

    import trl

    if version.parse(trl.__version__) < version.parse("0.7.0"):
        # The module-level `logger` is only created further down this file, so
        # fetch the same named logger directly instead of referencing it here.
        logging.getLogger(__name__).warning(
            f"TRL version {trl.__version__} detected. Version >= 0.7.0 recommended."
        )
except (ImportError, ValueError):
    # ImportError: `packaging` (or trl) missing; ValueError covers
    # packaging.version.InvalidVersion for an unparseable version string.
    logging.getLogger(__name__).warning("Could not verify TRL version")
|
|
|
|
|
# wandb is optional: record whether it imported so callers can skip W&B logging.
try:
    import wandb
except ImportError:
    wandb = None
    WANDB_AVAILABLE = False
else:
    WANDB_AVAILABLE = True
|
|
|
|
|
|
|
|
# Root logging at INFO with timestamps; module code logs through `logger`.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataFormattingError(Exception):
    """Raised when a raw record cannot be turned into prompt/chosen/rejected texts."""


class DataValidationError(Exception):
    """Raised when a formatted dataset lacks the structure DPO training needs."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _dtype_from_str(s: str) -> torch.dtype: |
|
|
s = (s or "").lower() |
|
|
if s in ("float16", "fp16"): |
|
|
return torch.float16 |
|
|
if s in ("bfloat16", "bf16"): |
|
|
return torch.bfloat16 |
|
|
if s in ("float32", "fp32"): |
|
|
return torch.float32 |
|
|
raise ValueError(f"Unknown torch_dtype: {s}") |
|
|
|
|
|
|
|
|
def _now_iso() -> str: |
|
|
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime()) |
|
|
|
|
|
|
|
|
def _safe_exp(x: float) -> float: |
|
|
x = min(float(x), 50.0) |
|
|
return float(math.exp(x)) |
|
|
|
|
|
|
|
|
def _ensure_dir(p: Path) -> Path: |
|
|
p.mkdir(parents=True, exist_ok=True) |
|
|
return p |
|
|
|
|
|
|
|
|
def _looks_like_model_dir(p: Path) -> bool: |
|
|
if not p.exists() or not p.is_dir(): |
|
|
return False |
|
|
if (p / "config.json").exists(): |
|
|
return True |
|
|
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")): |
|
|
return True |
|
|
return False |
|
|
|
|
|
|
|
|
def _infer_target_modules(model) -> List[str]: |
|
|
names = set() |
|
|
for n, _ in model.named_modules(): |
|
|
names.add(n.split(".")[-1]) |
|
|
|
|
|
for group in [ |
|
|
["q_proj", "k_proj", "v_proj", "o_proj"], |
|
|
["Wqkv", "out_proj"], |
|
|
["query_key_value", "dense"], |
|
|
["c_attn", "c_proj"], |
|
|
]: |
|
|
if all(x in names for x in group): |
|
|
return group |
|
|
|
|
|
fallback = [ |
|
|
x |
|
|
for x in [ |
|
|
"q_proj", |
|
|
"k_proj", |
|
|
"v_proj", |
|
|
"o_proj", |
|
|
"c_attn", |
|
|
"c_proj", |
|
|
"out_proj", |
|
|
"dense", |
|
|
] |
|
|
if x in names |
|
|
] |
|
|
if fallback: |
|
|
return fallback |
|
|
|
|
|
raise ValueError( |
|
|
"Could not auto-infer target_modules. Set peft.target_modules explicitly." |
|
|
) |
|
|
|
|
|
|
|
|
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]: |
|
|
return cfg.get("model", {}).get("attn_implementation", None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def setup_wandb(cfg: Dict[str, Any], run_dir: Path):
    """Start a Weights & Biases run when ``wandb.enabled`` is true in *cfg*.

    The run records the model/data/peft/dpo/train config sections plus
    *run_dir*, and stores its files under *run_dir*.

    Returns:
        The ``wandb`` module on success; None when W&B is disabled, not
        installed, or ``wandb.init`` raised (the error is printed).
    """
    settings = cfg.get("wandb", {})

    if not settings.get("enabled", False):
        print("Wandb logging disabled")
        return None
    if not WANDB_AVAILABLE:
        print("Wandb not available. Install with: pip install wandb")
        return None

    project = settings.get("project", "dpo-training")
    entity = settings.get("entity", None)
    name = settings.get("name", None)
    tags = settings.get("tags", [])
    notes = settings.get("notes", None)

    try:
        tracked = {
            section: cfg.get(section, {})
            for section in ("model", "data", "peft", "dpo", "train")
        }
        tracked["run_dir"] = str(run_dir)
        wandb.init(
            project=project,
            entity=entity,
            name=name,
            tags=tags,
            notes=notes,
            dir=str(run_dir),
            config=tracked,
        )
        print(f"Wandb initialized: project='{project}', name='{name or 'auto-generated'}'")
        return wandb
    except Exception as e:
        print(f"Failed to initialize Wandb: {e}")
        return None
|
|
|
|
|
|
|
|
def finish_wandb():
    """Close the current W&B run if wandb is installed and a run is active."""
    if not WANDB_AVAILABLE or wandb.run is None:
        return
    wandb.finish()
    print("Wandb run finished")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class JsonlLoggerCallback(TrainerCallback):
    """Append Trainer log and evaluation events to JSON-lines files.

    Training logs go to ``<run_dir>/logs/train.jsonl``, enriched with step,
    epoch, progress percentage and an ETA. Evaluation metrics go to
    ``<run_dir>/logs/eval.jsonl``. Only the world-process-zero rank writes, so
    multi-process (e.g. DDP) runs do not produce one duplicate line per rank.
    """

    def __init__(self, run_dir: Path):
        self.run_dir = run_dir
        logs_dir = _ensure_dir(run_dir / "logs")
        self.train_log_path = logs_dir / "train.jsonl"
        self.eval_log_path = logs_dir / "eval.jsonl"
        # Wall-clock time and global step at which this train() call began.
        # The ETA is measured from both, so a run resumed at step N does not
        # divide elapsed time by steps it never executed.
        self.start_time = None
        self.start_step = 0

    @staticmethod
    def _is_main_process(state) -> bool:
        """True on the rank that should write files (assumed True if unknown)."""
        return bool(getattr(state, "is_world_process_zero", True))

    @staticmethod
    def _append(path: Path, payload: Dict[str, Any]) -> None:
        """Append *payload* to *path* as one JSON line."""
        with path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
        """Estimate remaining time as ``HH:MM:SS``, or None if not computable yet."""
        steps_done = global_step - self.start_step
        if self.start_time is None or steps_done <= 0 or max_steps <= 0:
            return None
        sec_per_step = (time.time() - self.start_time) / steps_done
        remaining = max(0, max_steps - global_step) * sec_per_step
        h, rem = divmod(int(remaining), 3600)
        m, s = divmod(rem, 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()
        self.start_step = int(getattr(state, "global_step", 0) or 0)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs or not self._is_main_process(state):
            return

        max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
        progress_pct = (
            (100.0 * state.global_step / max_steps) if max_steps > 0 else None
        )
        epoch_pct = None
        if (
            state.epoch is not None
            and args.num_train_epochs
            and args.num_train_epochs > 0
        ):
            epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))

        payload = {
            "ts": _now_iso(),
            "event": "train_log",
            "step": int(state.global_step),
            "epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
            "progress_pct": (
                round(progress_pct, 2) if progress_pct is not None else None
            ),
            "epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
            "eta": self._eta(int(state.global_step), max_steps),
            "max_grad_norm": getattr(args, "max_grad_norm", None),
            # Trainer-provided values (loss, learning_rate, ...) come last and
            # therefore win on key collisions.
            **logs,
        }
        self._append(self.train_log_path, payload)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if not metrics or not self._is_main_process(state):
            return

        payload = {
            "ts": _now_iso(),
            "event": "eval",
            "step": int(state.global_step),
            "epoch": float(state.epoch) if state.epoch is not None else None,
            **metrics,
        }
        self._append(self.eval_log_path, payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# DataFormattingError and DataValidationError are defined once, near the top of
# this module. Declaring them a second time here would rebind both names to new,
# unrelated class objects, so anything that captured the earlier classes would
# no longer match in an `except` clause or `isinstance` check.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_dpo_example(
    example: Dict[str, Any], cfg: Dict[str, Any], tokenizer
) -> Dict[str, Any]:
    """
    Turn one raw record into the ``prompt``/``chosen``/``rejected`` triple.

    Column names come from ``data.prompt_field``, ``data.chosen_field`` and
    ``data.rejected_field`` (default: the same names). With ``format_type``
    "chatml" or "alpaca" the texts pass through unchanged. With "custom", the
    prompt is substituted into ``data.custom_template`` via ``{prompt}``.
    ``tokenizer`` is accepted for interface compatibility and is not used.

    Raises:
        DataFormattingError: if any of the three values is missing or empty.
        ValueError: for an unsupported ``format_type``.
    """
    data_cfg = cfg["data"]
    format_type = data_cfg.get("format_type", "chatml")

    columns = {
        "prompt": data_cfg.get("prompt_field", "prompt"),
        "chosen": data_cfg.get("chosen_field", "chosen"),
        "rejected": data_cfg.get("rejected_field", "rejected"),
    }
    texts = {role: example.get(column, "") for role, column in columns.items()}

    for role, column in columns.items():
        if not texts[role]:
            raise DataFormattingError(f"Empty {role} field: {column}")

    if format_type in ("chatml", "alpaca"):
        # Both formats currently pass the raw texts through untouched.
        return texts

    if format_type == "custom":
        template = data_cfg.get("custom_template", "{prompt}")
        texts["prompt"] = template.format(prompt=texts["prompt"])
        return texts

    raise ValueError(f"Unsupported format_type: {format_type}")
|
|
|
|
|
|
|
|
def validate_dpo_data(dataset, stage: str = "train") -> None:
    """
    Validate that a formatted DPO dataset can be handed to the trainer.

    Checks that the ``prompt``, ``chosen`` and ``rejected`` columns exist and
    that the dataset has at least one row. It then spot-checks the first row
    for blank values; a blank value only produces a warning.

    Args:
        dataset: Formatted dataset exposing ``column_names``, ``len()`` and
            integer indexing (e.g. a ``datasets.Dataset``).
        stage: Label used in messages ("train" or "eval").

    Raises:
        DataValidationError: if a required column is missing or the dataset
            is empty.
    """
    required_fields = ["prompt", "chosen", "rejected"]
    columns = dataset.column_names

    for field in required_fields:
        if field not in columns:
            raise DataValidationError(
                f"{stage} dataset missing required field: {field}. "
                f"Available fields: {columns}"
            )

    if len(dataset) == 0:
        # Without this check an empty split passes validation and only fails
        # much later inside the trainer (or silently evaluates nothing).
        raise DataValidationError(
            f"{stage} dataset is empty; nothing to train or evaluate on"
        )

    sample = dataset[0]
    for field in required_fields:
        value = sample[field]
        # Non-string values (e.g. a list of chat messages) have no .strip(),
        # so only strings are checked for whitespace-only content.
        is_blank = not value or (isinstance(value, str) and not value.strip())
        if is_blank:
            logger.warning(f"{stage} dataset has empty {field} in first example")

    logger.info(f"{stage} dataset validation passed: {len(dataset)} examples")
|
|
|
|
|
|
|
|
def build_dpo_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
    """
    Build the train (and optional eval) datasets for DPO training.

    Expected JSONL format: {"prompt": "...", "chosen": "...", "rejected": "..."}
    or custom field names given in the ``data`` config section.

    The eval split comes from ``data.eval_jsonl`` if set. Otherwise, when
    ``0 < data.eval_split_ratio < 1``, it is carved out of the training file.
    Records that fail formatting are dropped (the first few are logged). The
    train split is shuffled when ``data.shuffle`` is true.

    Returns:
        (train_dataset, eval_dataset_or_None), each with exactly the columns
        ``prompt``, ``chosen`` and ``rejected``.

    Raises:
        ValueError: if ``data.format_type`` is not supported.
        DataValidationError: if a split ends up empty or malformed.
    """
    data_cfg = cfg["data"]
    train_path = data_cfg["train_jsonl"]
    eval_path = data_cfg.get("eval_jsonl", None)
    split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
    shuffle = bool(data_cfg.get("shuffle", True))
    num_proc = int(data_cfg.get("num_proc", 4))
    seed = int((cfg.get("run") or {}).get("seed", 42))

    # Fail fast on a misconfigured format_type. Otherwise format_dpo_example
    # raises ValueError for every row, the per-row handler below swallows it,
    # and the whole dataset is silently filtered away.
    format_type = data_cfg.get("format_type", "chatml")
    if format_type not in ("chatml", "alpaca", "custom"):
        raise ValueError(f"Unsupported format_type: {format_type}")

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    train_raw = load_dataset("json", data_files={"train": train_path})["train"]
    eval_raw = None
    if eval_path:
        eval_raw = load_dataset("json", data_files={"eval": eval_path})["eval"]
    elif 0.0 < split_ratio < 1.0:
        split = train_raw.train_test_split(test_size=split_ratio, seed=seed)
        train_raw, eval_raw = split["train"], split["test"]

    def format_fn(examples, indices):
        """Batched map fn: format each row; failed rows become empty strings."""
        columns: Dict[str, List[Any]] = {"prompt": [], "chosen": [], "rejected": []}
        errors = 0
        keys = list(examples.keys())
        for pos, idx in enumerate(indices):
            example = {k: examples[k][pos] for k in keys}
            try:
                formatted = format_dpo_example(example, cfg, tokenizer)
            except Exception as e:
                # Best-effort per row: the bad row is marked with an empty
                # prompt and removed by the filter step afterwards.
                errors += 1
                if errors <= 5:
                    logger.warning(f"Failed to format example {idx}: {e}")
                formatted = {"prompt": "", "chosen": "", "rejected": ""}
            for key, values in columns.items():
                values.append(formatted[key])

        if errors > 0:
            logger.warning(f"Total formatting errors in batch: {errors}")
        return columns

    def _format_split(raw, stage: str):
        """Format *raw*, drop rows that failed, and validate the result."""
        logger.info(f"Formatting {stage} DPO data...")
        # datasets rejects more worker processes than rows; use None (single
        # process) for tiny inputs.
        workers = min(num_proc, len(raw))
        formatted = raw.map(
            format_fn,
            batched=True,
            with_indices=True,
            num_proc=workers if workers > 1 else None,
            remove_columns=raw.column_names,
            desc=f"Formatting {stage} DPO data",
        )
        total = len(formatted)
        formatted = formatted.filter(lambda x: len(x["prompt"]) > 0)
        kept = len(formatted)
        logger.info(f"{stage.capitalize()} dataset after filtering: {kept} examples")
        if kept < total:
            logger.warning(
                f"Dropped {total - kept}/{total} {stage} examples that failed formatting"
            )
        if kept == 0:
            raise DataValidationError(
                f"{stage} dataset is empty after formatting: "
                f"all {total} examples were dropped"
            )
        validate_dpo_data(formatted, stage)
        return formatted

    formatted_train = _format_split(train_raw, "train")
    formatted_eval = _format_split(eval_raw, "eval") if eval_raw is not None else None

    if shuffle:
        formatted_train = formatted_train.shuffle(seed=seed)

    return formatted_train, formatted_eval
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
    """Load the tokenizer and base model from *base_dir* per ``cfg["model"]``.

    Honors these ``model`` settings:
        - ``trust_remote_code`` (now also for config and model loading)
        - ``tokenizer_use_fast``
        - ``device_map`` and ``torch_dtype``
        - optional 4-bit bitsandbytes quantization (``use_4bit`` / ``bnb_4bit_*``)
        - ``attn_implementation``; if loading with it fails, loading is retried
          once without it.

    Mistral3 checkpoints are loaded with ``Mistral3ForConditionalGeneration``.
    Any ``vision_tower`` submodule is frozen for text-only training.

    Returns:
        (model, tokenizer). The tokenizer's pad token falls back to EOS.
    """
    model_cfg = cfg["model"]
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
    device_map = model_cfg.get("device_map", "auto")

    tokenizer = AutoTokenizer.from_pretrained(
        str(base_dir),
        use_fast=use_fast,
        trust_remote_code=trust_remote_code,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
    use_4bit = bool(model_cfg.get("use_4bit", False))

    quant_cfg = None
    if use_4bit:
        quant_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
            bnb_4bit_use_double_quant=bool(
                model_cfg.get("bnb_4bit_use_double_quant", True)
            ),
            bnb_4bit_compute_dtype=_dtype_from_str(
                model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")
            ),
        )

    attn_impl = _choose_attn_impl(cfg)

    # Keyword arguments shared by every from_pretrained call below.
    common_kwargs: Dict[str, Any] = dict(
        device_map=device_map,
        low_cpu_mem_usage=True,
        # With 4-bit quantization, no torch_dtype is passed; the compute dtype
        # is set via quantization_config.
        torch_dtype=(torch_dtype if not use_4bit else None),
        quantization_config=quant_cfg,
    )

    def _from_pretrained(model_cls, **extra_kwargs):
        """Load *model_cls*, retrying once without attn_implementation on failure."""
        try:
            return model_cls.from_pretrained(
                str(base_dir),
                attn_implementation=attn_impl,
                **common_kwargs,
                **extra_kwargs,
            )
        except Exception as e:
            if attn_impl is None:
                raise
            logger.warning(f"attn_implementation='{attn_impl}' failed: {e}")
            logger.warning("Falling back to default attention implementation.")
            return model_cls.from_pretrained(
                str(base_dir), **common_kwargs, **extra_kwargs
            )

    try:
        config = AutoConfig.from_pretrained(
            str(base_dir), trust_remote_code=trust_remote_code
        )
        architectures = getattr(config, "architectures", None) or []

        if config.model_type == "mistral3" or (
            architectures and "Mistral3" in architectures[0]
        ):
            logger.info("Detected Mistral3 model architecture, loading with specific class")
            from transformers.models.mistral3.modeling_mistral3 import (
                Mistral3ForConditionalGeneration,
            )

            model = _from_pretrained(Mistral3ForConditionalGeneration, config=config)
        else:
            model = _from_pretrained(
                AutoModelForCausalLM, trust_remote_code=trust_remote_code
            )
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise

    logger.info("Ensuring all parameters are materialized...")
    meta_params = [
        name for name, param in model.named_parameters() if param.device.type == "meta"
    ]
    if meta_params:
        logger.warning(
            f"Found {len(meta_params)} parameters on meta device "
            f"(first: {meta_params[:5]})"
        )

    if hasattr(model, "vision_tower"):
        logger.info("Freezing vision tower for text-only training")
        for param in model.vision_tower.parameters():
            param.requires_grad = False

    return model, tokenizer
|
|
|
|
|
|
|
|
def apply_peft(cfg: Dict[str, Any], model):
    """Wrap *model* with a LoRA adapter as configured in ``cfg["peft"]``.

    Enables gradient checkpointing when ``train.gradient_checkpointing`` is true
    (the default). 4-bit models are additionally prepared with
    ``prepare_model_for_kbit_training``. ``peft.target_modules: auto`` is
    resolved with ``_infer_target_modules``.

    Returns:
        (model, lora_config). When ``peft.enabled`` is false, the model is
        returned untouched and ``lora_config`` is None.
    """
    peft_cfg = cfg["peft"]
    model_cfg = cfg["model"]
    tr_cfg = cfg["train"]

    if not bool(peft_cfg.get("enabled", True)):
        return model, None

    use_4bit = bool(model_cfg.get("use_4bit", False))
    gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))

    if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
        if hasattr(model, "config"):
            # transformers treats use_cache=True as incompatible with gradient
            # checkpointing during training.
            model.config.use_cache = False

    if use_4bit:
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=gradient_checkpointing,
        )
    elif gradient_checkpointing and hasattr(model, "enable_input_require_grads"):
        # All base weights are frozen under LoRA, so the embedding output does
        # not require grad. Checkpointed blocks would then yield no gradients
        # for the adapter weights. prepare_model_for_kbit_training already does
        # this for the 4-bit path; do it here for full-precision models.
        model.enable_input_require_grads()

    target_modules = peft_cfg.get("target_modules", "auto")
    if target_modules == "auto":
        target_modules = _infer_target_modules(model)

    lora_config = LoraConfig(
        r=int(peft_cfg.get("r", 16)),
        lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
        lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
        bias=str(peft_cfg.get("bias", "none")),
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, lora_config)
    return model, lora_config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def merge_adapter(
    cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path
):
    """Merge the LoRA adapter in *adapter_dir* into the base model and save it.

    Steps:
        1. Load the base model on CPU in ``merge.merged_dtype`` (default float16).
        2. Apply the adapter and fold it in with ``merge_and_unload()``.
        3. Save the result to *final_dir* as safetensors shards of at most
           ``merge.max_shard_size`` (default "2GB").
        4. Save the base tokenizer alongside, with its pad token defaulting
           to EOS.

    A missing or null ``merge`` section falls back to those defaults.

    Raises:
        Whatever loading, merging or saving raised, after logging it.
    """
    logger.info(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")

    model_cfg = cfg["model"]
    merge_cfg = cfg.get("merge") or {}
    trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
    use_fast = bool(model_cfg.get("tokenizer_use_fast", True))

    merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
    max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))

    try:
        base = AutoModelForCausalLM.from_pretrained(
            str(base_dir),
            torch_dtype=merged_dtype,
            device_map="cpu",
            low_cpu_mem_usage=True,
            trust_remote_code=trust_remote_code,
        )
        merged = PeftModel.from_pretrained(base, str(adapter_dir)).merge_and_unload()

        # Drop our extra reference before serializing to keep peak memory down.
        del base
        gc.collect()
        torch.cuda.empty_cache()

        _ensure_dir(final_dir)
        merged.save_pretrained(
            str(final_dir), safe_serialization=True, max_shard_size=max_shard_size
        )

        del merged
        gc.collect()
        torch.cuda.empty_cache()

        tok = AutoTokenizer.from_pretrained(
            str(base_dir), use_fast=use_fast, trust_remote_code=trust_remote_code
        )
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        tok.save_pretrained(str(final_dir))

        logger.info("--- Merge complete ---")
    except Exception as e:
        logger.error(f"Merge failed: {e}")
        raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_base_dir(model_cfg: Dict[str, Any], run_dir: Path) -> Path:
    """Return a local directory holding the base model weights.

    If ``model.repo_id`` is a local model directory, it is used in place.
    Otherwise it is treated as a Hub repo id and downloaded once into
    ``<run_dir>/<model.base_local_dir>``.

    Raises:
        ValueError: if ``repo_id`` is a local directory that does not look like
            a Hugging Face model directory.
    """
    repo_id = str(model_cfg["repo_id"]).strip()
    repo_path = Path(repo_id)

    if repo_path.is_dir():
        if not _looks_like_model_dir(repo_path):
            # Report repo_path: no base_dir exists yet on this branch.
            raise ValueError(
                f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {repo_path}"
            )
        logger.info(f"Using local model at: {repo_path}")
        return repo_path

    base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
    if not _looks_like_model_dir(base_dir):
        print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
        snapshot_download(
            repo_id=repo_id,
            revision=model_cfg.get("revision", None),
            local_dir=str(base_dir),
            local_dir_use_symlinks=False,
        )
    return base_dir


def _build_dpo_config(
    cfg: Dict[str, Any],
    ckpt_dir: Path,
    has_eval: bool,
    report_to: List[str],
    load_best_default: bool,
) -> DPOConfig:
    """Translate the ``train``/``dpo``/``data``/``model`` config sections into a DPOConfig.

    Args:
        cfg: Full parsed YAML config.
        ckpt_dir: Trainer output directory (checkpoints land here).
        has_eval: Whether an eval dataset exists; controls eval-related defaults.
        report_to: Trainer integrations to report to (e.g. ["wandb"]).
        load_best_default: Default for ``load_best_model_at_end`` when the config
            does not set it (True when early stopping is enabled).
    """
    tr_cfg = cfg["train"]
    model_cfg = cfg["model"]
    data_cfg = cfg["data"]
    dpo_cfg = cfg.get("dpo", {}) or {}

    dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))

    max_steps = int(tr_cfg.get("max_steps", 0))
    max_grad_norm = float(tr_cfg.get("max_grad_norm", 1.0))
    if max_grad_norm <= 0:
        logger.warning(f"Invalid max_grad_norm={max_grad_norm}, using 1.0")
        max_grad_norm = 1.0

    beta = float(dpo_cfg.get("beta", 0.1))
    loss_type = str(dpo_cfg.get("loss_type", "sigmoid"))
    max_length = int(data_cfg.get("max_length", 2048))
    logger.info(f"DPO Training with beta={beta}, loss_type={loss_type}")

    # model_init_kwargs is deliberately not set: DPOTrainer rejects it when the
    # model is already instantiated, as it is here.
    kwargs: Dict[str, Any] = dict(
        # Checkpoints must go to ckpt_dir; resume_from_checkpoint="auto" looks there.
        output_dir=str(ckpt_dir),
        max_steps=max_steps if max_steps > 0 else -1,
        # float(): fractional epochs (e.g. 1.5) are valid and must not truncate.
        num_train_epochs=float(tr_cfg.get("num_train_epochs", 3)),
        per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 2)),
        per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", 4)),
        gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 4)),
        learning_rate=float(tr_cfg.get("learning_rate", 5e-5)),
        weight_decay=float(tr_cfg.get("weight_decay", 0.01)),
        adam_beta1=float(tr_cfg.get("adam_beta1", 0.9)),
        adam_beta2=float(tr_cfg.get("adam_beta2", 0.999)),
        adam_epsilon=float(tr_cfg.get("adam_epsilon", 1e-8)),
        max_grad_norm=max_grad_norm,
        lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "linear")),
        warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
        logging_steps=int(tr_cfg.get("logging_steps", 10)),
        save_steps=int(tr_cfg.get("save_steps", 100)),
        save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
        eval_steps=int(tr_cfg.get("eval_steps", 100)) if has_eval else None,
        save_strategy=str(tr_cfg.get("save_strategy", "steps")),
        load_best_model_at_end=(
            bool(tr_cfg.get("load_best_model_at_end", load_best_default))
            if has_eval
            else False
        ),
        metric_for_best_model=str(tr_cfg.get("metric_for_best_model", "eval_loss")),
        greater_is_better=bool(tr_cfg.get("greater_is_better", False)),
        fp16=dtype == torch.float16,
        bf16=dtype == torch.bfloat16,
        report_to=report_to,
        remove_unused_columns=False,
        beta=beta,
        label_smoothing=float(dpo_cfg.get("label_smoothing", 0.0)),
        loss_type=loss_type,
        max_length=max_length,
        max_prompt_length=int(data_cfg.get("max_prompt_length", max_length // 2)),
    )

    # Honor an explicit optimizer. Default to the paged 8-bit AdamW for 4-bit
    # models; otherwise keep the library default.
    optim = tr_cfg.get("optim")
    if optim is None and bool(model_cfg.get("use_4bit", False)):
        optim = "paged_adamw_8bit"
    if optim is not None:
        kwargs["optim"] = str(optim)

    params = inspect.signature(DPOConfig.__init__).parameters
    # Older transformers releases name this field evaluation_strategy.
    eval_key = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"
    kwargs[eval_key] = str(
        tr_cfg.get("evaluation_strategy", "steps" if has_eval else "no")
    )
    if "reference_free" in params:
        kwargs["reference_free"] = bool(dpo_cfg.get("reference_free", False))

    return DPOConfig(**kwargs)


def main():
    """CLI entry point.

    Reads a YAML config, DPO-trains the model with TRL (with a LoRA adapter
    when PEFT is enabled), saves the result, and optionally merges the adapter
    into the base weights.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, help="Path to YAML config")
    ap.add_argument(
        "--merge-only", action="store_true", help="Skip training, just merge adapter"
    )
    args = ap.parse_args()

    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
    _ensure_dir(run_dir / "logs")

    with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)

    base_dir = _resolve_base_dir(cfg["model"], run_dir)
    ckpt_dir = _ensure_dir(run_dir / "checkpoints")
    best_adapter_dir = _ensure_dir(run_dir / "best_adapter")

    merge_cfg = cfg.get("merge", {}) or {}
    if merge_cfg.get("output_dir"):
        od = Path(str(merge_cfg["output_dir"]))
        final_dir = od if od.is_absolute() else (run_dir / od)
    else:
        final_dir = run_dir / "final_model"

    if args.merge_only:
        if not _looks_like_model_dir(best_adapter_dir):
            raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
        merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        return

    wandb_run = setup_wandb(cfg, run_dir)
    try:
        set_seed(int(cfg["run"].get("seed", 42)))

        model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
        model, lora_config = apply_peft(cfg, model)

        dpo_cfg = cfg.get("dpo", {}) or {}
        use_reference_model = bool(dpo_cfg.get("use_reference_model", True))
        reference_free = bool(dpo_cfg.get("reference_free", False))

        ref_model = None
        if use_reference_model and not reference_free:
            if lora_config is not None:
                # With a PEFT policy and ref_model=None, DPOTrainer uses the
                # adapter-disabled policy (i.e. the frozen base) as the
                # reference. That matches a freshly initialised LoRA copy
                # without loading the weights a second time.
                logger.info(
                    "PEFT enabled: using the adapter-disabled policy as the reference model"
                )
            else:
                print("Loading reference model (frozen copy)...")
                ref_model, _ = load_base_model_and_tokenizer(cfg, base_dir)
                for param in ref_model.parameters():
                    param.requires_grad = False
                ref_model.eval()
                print("Reference model loaded and frozen")

        train_ds, eval_ds = build_dpo_datasets(cfg, tokenizer)

        tr_cfg = cfg["train"]
        report_to = ["wandb"] if wandb_run is not None else []

        callbacks = [JsonlLoggerCallback(run_dir)]
        early_stopping_cfg = tr_cfg.get("early_stopping", {}) or {}
        early_stopping_enabled = (
            bool(early_stopping_cfg.get("enabled", False)) and eval_ds is not None
        )
        if early_stopping_enabled:
            callbacks.append(
                EarlyStoppingCallback(
                    early_stopping_patience=int(early_stopping_cfg.get("patience", 3)),
                    early_stopping_threshold=float(
                        early_stopping_cfg.get("min_delta", 0.001)
                    ),
                )
            )
            print(f"Early stopping enabled: patience={early_stopping_cfg.get('patience', 3)}, "
                  f"min_delta={early_stopping_cfg.get('min_delta', 0.001)}")

        # Early stopping tracks the best checkpoint, so it defaults
        # load_best_model_at_end to True (the config can still override).
        dpo_config = _build_dpo_config(
            cfg,
            ckpt_dir,
            has_eval=eval_ds is not None,
            report_to=report_to,
            load_best_default=early_stopping_enabled,
        )

        trainer_kwargs: Dict[str, Any] = dict(
            model=model,
            ref_model=ref_model,
            args=dpo_config,
            train_dataset=train_ds,
            eval_dataset=eval_ds,
            callbacks=callbacks,
        )
        # Pass the configured tokenizer under whichever name this TRL release
        # uses; without it TRL either errors or reloads a processor itself.
        trainer_params = inspect.signature(DPOTrainer.__init__).parameters
        if "processing_class" in trainer_params:
            trainer_kwargs["processing_class"] = tokenizer
        elif "tokenizer" in trainer_params:
            trainer_kwargs["tokenizer"] = tokenizer
        trainer = DPOTrainer(**trainer_kwargs)

        resume_from = tr_cfg.get("resume_from_checkpoint", None)
        if resume_from == "auto":
            resume_from = get_last_checkpoint(str(ckpt_dir)) or None
        if resume_from:
            logger.info(f"Resuming from {resume_from}")

        logger.info("Starting DPO training...")
        trainer.train(resume_from_checkpoint=resume_from)

        trainer.save_model(str(best_adapter_dir))
        logger.info(f"Saved best adapter -> {best_adapter_dir}")

        if eval_ds is not None:
            metrics = trainer.evaluate()
            with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=2)
            print(f"Final metrics: {metrics}")

        if bool(merge_cfg.get("enabled", False)):
            if lora_config is None:
                # Without PEFT the saved model is already a full model; there is
                # no adapter for PeftModel.from_pretrained to load.
                logger.warning(
                    f"PEFT disabled: nothing to merge; full model saved at {best_adapter_dir}"
                )
            else:
                del trainer, model, ref_model
                gc.collect()
                torch.cuda.empty_cache()
                merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
        else:
            print("Merge disabled. Run with --merge-only later if needed.")
    finally:
        # Close the W&B run even when training or merging fails.
        finish_wandb()


if __name__ == "__main__":
    main()
|
|
|