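"""Continued-pretraining (CPT) LoRA/QLoRA fine-tuning driver.

Reads a single YAML config, resolves a base model (local directory or Hugging Face
repo_id downloaded via snapshot_download), packs JSONL text documents into fixed-size
blocks with EOS delimiters, trains a PEFT/LoRA adapter with transformers' Trainer
(optionally 4-bit via bitsandbytes), appends JSONL train/eval logs under
<run_dir>/logs, and can merge the trained adapter back into the base weights.

Usage (script filename is a placeholder):
    python <this_script>.py --config path/to/config.yaml [--merge-only]
"""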
import argparse
import json
import inspect  # used to introspect TrainingArguments/Trainer signatures across transformers versions
import math
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, List
import torch
import yaml
from datasets import load_dataset, DatasetDict
from huggingface_hub import snapshot_download
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedTokenizerFast,
TrainingArguments,
Trainer,
TrainerCallback,
default_data_collator,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training,
PeftModel,
)
try:
from transformers import BitsAndBytesConfig
except ImportError: # older transformers
BitsAndBytesConfig = None
# --------------------------
# Helpers
# --------------------------
def _dtype_from_str(s: str) -> torch.dtype:
s = (s or "").lower()
if s in ("float16", "fp16"):
return torch.float16
if s in ("bfloat16", "bf16"):
return torch.bfloat16
if s in ("float32", "fp32"):
return torch.float32
raise ValueError(f"Unknown torch_dtype: {s}")
def _now_iso() -> str:
return time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
def _safe_exp(x: float) -> float:
x = min(float(x), 50.0)
return float(math.exp(x))
def _ensure_dir(p: Path) -> Path:
p.mkdir(parents=True, exist_ok=True)
return p
def _looks_like_model_dir(p: Path) -> bool:
if not p.exists() or not p.is_dir():
return False
if (p / "config.json").exists():
return True
if any(p.glob("*.safetensors")) or any(p.glob("pytorch_model*.bin")):
return True
return False
def _detect_text_field(example: Dict[str, Any]) -> Optional[str]:
for k, v in example.items():
if isinstance(v, str) and v.strip():
return k
return None
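# Loads the tokenizer via AutoTokenizer; if that fails with a "TokenizersBackend"
# ValueError (a tokenizer class the installed transformers cannot resolve), falls back
# to building a PreTrainedTokenizerFast directly from tokenizer.json, copying special
# tokens and model_max_length from tokenizer_config.json when present.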
def _load_tokenizer(base_dir: Path, use_fast: bool, trust_remote_code: bool):
try:
return AutoTokenizer.from_pretrained(
str(base_dir),
use_fast=use_fast,
trust_remote_code=trust_remote_code,
)
except ValueError as e:
if "TokenizersBackend" not in str(e):
raise
tok_file = base_dir / "tokenizer.json"
tok_cfg_path = base_dir / "tokenizer_config.json"
if not tok_file.exists():
raise
tok_kwargs: Dict[str, Any] = {}
if tok_cfg_path.exists():
with tok_cfg_path.open("r", encoding="utf-8") as f:
tok_cfg = json.load(f)
for key in ("bos_token", "eos_token", "pad_token", "unk_token", "model_max_length"):
if tok_cfg.get(key) is not None:
tok_kwargs[key] = tok_cfg[key]
extra = tok_cfg.get("additional_special_tokens") or tok_cfg.get("extra_special_tokens")
if extra:
tok_kwargs["additional_special_tokens"] = extra
return PreTrainedTokenizerFast(tokenizer_file=str(tok_file), **tok_kwargs)
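# Guesses LoRA target_modules by scanning the model's leaf module names for common
# attention projection layers (e.g. q/k/v/o_proj in LLaMA-style models, Wqkv/out_proj,
# query_key_value/dense, or c_attn/c_proj in GPT-2-style models). Raises if nothing
# recognizable is found, in which case peft.target_modules must be set explicitly.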
def _infer_target_modules(model) -> List[str]:
names = set()
for n, _ in model.named_modules():
names.add(n.split(".")[-1])
for group in [
["q_proj", "k_proj", "v_proj", "o_proj"],
["Wqkv", "out_proj"],
["query_key_value", "dense"],
["c_attn", "c_proj"],
]:
if all(x in names for x in group):
return group
fallback = [x for x in ["q_proj", "k_proj", "v_proj", "o_proj", "c_attn", "c_proj", "out_proj", "dense"] if x in names]
if fallback:
return fallback
raise ValueError("Could not auto-infer target_modules. Set peft.target_modules explicitly.")
def _choose_attn_impl(cfg: Dict[str, Any]) -> Optional[str]:
return cfg.get("model", {}).get("attn_implementation", None)
# --------------------------
# JSONL Logger Callback
# --------------------------
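# Appends one JSON object per Trainer logging/eval event to <run_dir>/logs/train.jsonl
# and <run_dir>/logs/eval.jsonl, adding step/epoch progress percentages, an HH:MM:SS
# ETA estimate, and (for eval) perplexity derived from eval_loss.
# Illustrative train.jsonl line (values are made up; extra keys come from Trainer logs):
#   {"ts": "2024-01-01T12:00:00", "event": "train_log", "step": 100, "epoch": 0.25,
#    "progress_pct": 12.5, "eta": "01:30:00", "loss": 2.13, "learning_rate": 1.8e-05}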
class JsonlLoggerCallback(TrainerCallback):
def __init__(self, run_dir: Path):
self.run_dir = run_dir
self.train_log_path = _ensure_dir(run_dir / "logs") / "train.jsonl"
self.eval_log_path = _ensure_dir(run_dir / "logs") / "eval.jsonl"
self.start_time = None
def _eta(self, global_step: int, max_steps: int) -> Optional[str]:
if self.start_time is None or global_step <= 0 or max_steps <= 0:
return None
elapsed = time.time() - self.start_time
sec_per_step = elapsed / global_step
remaining = max(0, max_steps - global_step) * sec_per_step
h = int(remaining // 3600)
m = int((remaining % 3600) // 60)
s = int(remaining % 60)
return f"{h:02d}:{m:02d}:{s:02d}"
def on_train_begin(self, args, state, control, **kwargs):
self.start_time = time.time()
def on_log(self, args, state, control, logs=None, **kwargs):
if not logs:
return
max_steps = int(state.max_steps) if getattr(state, "max_steps", None) else 0
progress_pct = (100.0 * state.global_step / max_steps) if max_steps > 0 else None
epoch_pct = None
if state.epoch is not None and args.num_train_epochs and args.num_train_epochs > 0:
epoch_pct = 100.0 * (float(state.epoch) / float(args.num_train_epochs))
payload = {
"ts": _now_iso(),
"event": "train_log",
"step": int(state.global_step),
"epoch": round(float(state.epoch), 4) if state.epoch is not None else None,
"progress_pct": round(progress_pct, 2) if progress_pct is not None else None,
"epoch_pct": round(epoch_pct, 2) if epoch_pct is not None else None,
"eta": self._eta(int(state.global_step), max_steps),
"max_grad_norm": getattr(args, "max_grad_norm", None),
**logs,
}
with self.train_log_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if not metrics:
return
eval_loss = metrics.get("eval_loss", None)
ppl = _safe_exp(eval_loss) if eval_loss is not None else None
payload = {
"ts": _now_iso(),
"event": "eval",
"step": int(state.global_step),
"epoch": float(state.epoch) if state.epoch is not None else None,
**metrics,
"perplexity": ppl,
}
with self.eval_log_path.open("a", encoding="utf-8") as f:
f.write(json.dumps(payload, ensure_ascii=False) + "\n")
# --------------------------
# Data Pipeline (EOS + Packing)
# --------------------------
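# Each record in the train/eval JSONL files is tokenized from its text field without
# truncation, an EOS token is appended per document, and all tokens are concatenated
# and split into block_size chunks (labels = input_ids). The trailing remainder is
# dropped (pack_mode="drop") or padded with pad tokens whose labels are masked to -100
# (pack_mode="pad"). Illustrative input line (field name follows data.text_field):
#   {"text": "Raw document text used for continued pretraining ..."}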
def build_datasets(cfg: Dict[str, Any], tokenizer) -> Tuple[Any, Any]:
data_cfg = cfg["data"]
train_path = data_cfg["train_jsonl"]
eval_path = data_cfg.get("eval_jsonl", None)
split_ratio = float(data_cfg.get("eval_split_ratio", 0.0))
text_field = data_cfg.get("text_field", "text")
block_size = int(data_cfg.get("block_size", 2048))
shuffle = bool(data_cfg.get("shuffle", True))
num_proc = int(data_cfg.get("num_proc", 4))
pack_mode = str(data_cfg.get("pack_mode", "drop")).lower().strip()
if pack_mode not in ("drop", "pad"):
raise ValueError(f"data.pack_mode must be 'drop' or 'pad', got: {pack_mode}")
eos_id = tokenizer.eos_token_id
if eos_id is None:
raise ValueError("Tokenizer has no eos_token_id; CPT packing needs an EOS delimiter.")
if tokenizer.pad_token_id is None:
# safe default for many causal LMs
tokenizer.pad_token = tokenizer.eos_token
pad_id = tokenizer.pad_token_id
ds = load_dataset("json", data_files={"train": train_path})
if eval_path:
ds_eval = load_dataset("json", data_files={"eval": eval_path})
dsd = DatasetDict({"train": ds["train"], "eval": ds_eval["eval"]})
else:
if 0.0 < split_ratio < 1.0:
split = ds["train"].train_test_split(test_size=split_ratio, seed=int(cfg["run"].get("seed", 42)))
dsd = DatasetDict({"train": split["train"], "eval": split["test"]})
else:
dsd = DatasetDict({"train": ds["train"], "eval": None})
if text_field not in dsd["train"].column_names:
auto_field = _detect_text_field(dsd["train"][0])
if not auto_field:
raise ValueError(f"Could not find text field. Columns: {dsd['train'].column_names}")
text_field = auto_field
def tokenize_fn(examples):
out = tokenizer(
examples[text_field],
add_special_tokens=False,
truncation=False,
padding=False,
)
if "token_type_ids" in out:
del out["token_type_ids"]
# Add EOS between docs
out["input_ids"] = [ids + [eos_id] for ids in out["input_ids"]]
out["attention_mask"] = [m + [1] for m in out["attention_mask"]]
return out
tokenized_train = dsd["train"].map(
tokenize_fn,
batched=True,
num_proc=num_proc,
remove_columns=dsd["train"].column_names,
desc="Tokenizing train",
)
tokenized_eval = None
if dsd["eval"] is not None:
tokenized_eval = dsd["eval"].map(
tokenize_fn,
batched=True,
num_proc=num_proc,
remove_columns=dsd["eval"].column_names,
desc="Tokenizing eval",
)
def group_texts(examples):
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
total_length = len(concatenated["input_ids"])
if total_length == 0:
return {"input_ids": [], "attention_mask": [], "labels": []}
full_len = (total_length // block_size) * block_size
blocks_input, blocks_attn, blocks_labels = [], [], []
# full blocks
for i in range(0, full_len, block_size):
chunk = concatenated["input_ids"][i:i + block_size]
attn = concatenated["attention_mask"][i:i + block_size]
blocks_input.append(chunk)
blocks_attn.append(attn)
blocks_labels.append(chunk.copy())
# remainder
remainder = total_length - full_len
if remainder > 0 and pack_mode == "pad":
chunk = concatenated["input_ids"][full_len:full_len + remainder]
attn = concatenated["attention_mask"][full_len:full_len + remainder]
pad_len = block_size - remainder
chunk_padded = chunk + [pad_id] * pad_len
attn_padded = attn + [0] * pad_len
labels = chunk_padded.copy()
labels[-pad_len:] = [-100] * pad_len # loss mask
blocks_input.append(chunk_padded)
blocks_attn.append(attn_padded)
blocks_labels.append(labels)
return {
"input_ids": blocks_input,
"attention_mask": blocks_attn,
"labels": blocks_labels,
}
tokenized_train = tokenized_train.map(
group_texts,
batched=True,
num_proc=num_proc,
desc=f"Packing train blocks (mode={pack_mode})",
)
if tokenized_eval is not None:
tokenized_eval = tokenized_eval.map(
group_texts,
batched=True,
num_proc=num_proc,
desc=f"Packing eval blocks (mode={pack_mode})",
)
if len(tokenized_train) == 0:
raise ValueError(
"Train dataset is empty after packing. "
"Either increase data, reduce block_size, or set data.pack_mode='pad'."
)
if shuffle:
tokenized_train = tokenized_train.shuffle(seed=int(cfg["run"].get("seed", 42)))
return tokenized_train, tokenized_eval
# --------------------------
# Model Loading + PEFT
# --------------------------
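# Reads config.json's "architectures" entry to decide how the base model is loaded:
# *ForConditionalGeneration architectures go through their explicit transformers class,
# everything else through AutoModelForCausalLM.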
def _select_model_loader(base_dir: Path):
cfg_path = base_dir / "config.json"
if not cfg_path.exists():
return {"kind": "causal", "arch": None}
with cfg_path.open("r", encoding="utf-8") as f:
cfg = json.load(f)
arch = cfg.get("architectures") or []
arch_name = arch[0] if arch else None
if any("ForConditionalGeneration" in a for a in arch):
return {"kind": "conditional", "arch": arch_name}
return {"kind": "causal", "arch": arch_name}
def _resolve_model_class(arch_name: str):
import transformers
cls = getattr(transformers, arch_name, None)
if cls is None:
raise ValueError(f"Model class '{arch_name}' is not available in installed transformers.")
return cls
def load_base_model_and_tokenizer(cfg: Dict[str, Any], base_dir: Path):
model_cfg = cfg["model"]
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
device_map = model_cfg.get("device_map", "auto")
tokenizer = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
torch_dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
use_4bit = bool(model_cfg.get("use_4bit", False))
quant_cfg = None
if use_4bit:
if BitsAndBytesConfig is None:
raise ImportError("BitsAndBytesConfig is not available in this transformers version.")
quant_cfg = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type=str(model_cfg.get("bnb_4bit_quant_type", "nf4")),
bnb_4bit_use_double_quant=bool(model_cfg.get("bnb_4bit_use_double_quant", True)),
bnb_4bit_compute_dtype=_dtype_from_str(model_cfg.get("bnb_4bit_compute_dtype", "bfloat16")),
)
attn_impl = _choose_attn_impl(cfg)
model_meta = _select_model_loader(base_dir)
try:
if model_meta["kind"] == "conditional":
model_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
if model_cls is None:
raise ValueError("Conditional model architecture not specified in config.json.")
model = model_cls.from_pretrained(
str(base_dir),
device_map=device_map,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
torch_dtype=(torch_dtype if not use_4bit else None),
quantization_config=quant_cfg,
attn_implementation=attn_impl,
)
else:
model = AutoModelForCausalLM.from_pretrained(
str(base_dir),
device_map=device_map,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
torch_dtype=(torch_dtype if not use_4bit else None),
quantization_config=quant_cfg,
attn_implementation=attn_impl,
)
except Exception as e:
if attn_impl is not None:
print(f"[warn] attn_implementation='{attn_impl}' failed: {e}")
print("[warn] Falling back to default attention implementation.")
if model_meta["kind"] == "conditional":
model_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
if model_cls is None:
raise ValueError("Conditional model architecture not specified in config.json.")
model = model_cls.from_pretrained(
str(base_dir),
device_map=device_map,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
torch_dtype=(torch_dtype if not use_4bit else None),
quantization_config=quant_cfg,
)
else:
model = AutoModelForCausalLM.from_pretrained(
str(base_dir),
device_map=device_map,
trust_remote_code=trust_remote_code,
low_cpu_mem_usage=True,
torch_dtype=(torch_dtype if not use_4bit else None),
quantization_config=quant_cfg,
)
return model, tokenizer
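# Wraps the base model for LoRA training: enables gradient checkpointing, disables the
# KV cache, runs prepare_model_for_kbit_training when 4-bit quantization is active, and
# attaches a LoraConfig whose target_modules come from the config or are auto-inferred.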
def apply_peft(cfg: Dict[str, Any], model):
peft_cfg = cfg["peft"]
model_cfg = cfg["model"]
tr_cfg = cfg["train"]
if not bool(peft_cfg.get("enabled", True)):
return model, None
use_4bit = bool(model_cfg.get("use_4bit", False))
gradient_checkpointing = bool(tr_cfg.get("gradient_checkpointing", True))
if gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
model.gradient_checkpointing_enable()
if hasattr(model, "config"):
model.config.use_cache = False
if use_4bit:
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=gradient_checkpointing,
)
target_modules = peft_cfg.get("target_modules", "auto")
if target_modules == "auto":
target_modules = _infer_target_modules(model)
lora_config = LoraConfig(
r=int(peft_cfg.get("r", 16)),
lora_alpha=int(peft_cfg.get("lora_alpha", 32)),
lora_dropout=float(peft_cfg.get("lora_dropout", 0.05)),
bias=str(peft_cfg.get("bias", "none")),
task_type="CAUSAL_LM",
target_modules=target_modules,
)
model = get_peft_model(model, lora_config)
return model, lora_config
# --------------------------
# Merge Logic
# --------------------------
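# Reloads the base model on CPU in merge.merged_dtype, applies the trained adapter with
# PeftModel.from_pretrained, folds the LoRA weights into the base via merge_and_unload(),
# and saves sharded safetensors plus the tokenizer to final_dir.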
def merge_adapter(cfg: Dict[str, Any], base_dir: Path, adapter_dir: Path, final_dir: Path):
print(f"--- Merge: {adapter_dir} + {base_dir} -> {final_dir} ---")
model_cfg = cfg["model"]
merge_cfg = cfg.get("merge", {})
trust_remote_code = bool(model_cfg.get("trust_remote_code", True))
use_fast = bool(model_cfg.get("tokenizer_use_fast", True))
merged_dtype = _dtype_from_str(merge_cfg.get("merged_dtype", "float16"))
max_shard_size = str(merge_cfg.get("max_shard_size", "2GB"))
model_meta = _select_model_loader(base_dir)
if model_meta["kind"] == "conditional":
base_cls = _resolve_model_class(model_meta["arch"]) if model_meta["arch"] else None
if base_cls is None:
raise ValueError("Conditional model architecture not specified in config.json.")
base = base_cls.from_pretrained(
str(base_dir),
torch_dtype=merged_dtype,
device_map="cpu",
low_cpu_mem_usage=True,
trust_remote_code=trust_remote_code,
)
else:
base = AutoModelForCausalLM.from_pretrained(
str(base_dir),
torch_dtype=merged_dtype,
device_map="cpu",
low_cpu_mem_usage=True,
trust_remote_code=trust_remote_code,
)
merged = PeftModel.from_pretrained(base, str(adapter_dir))
merged = merged.merge_and_unload()
_ensure_dir(final_dir)
# Fix for transformers weight conversion bug with quantized models
# Clear weight conversions to avoid NotImplementedError in reverse_transform
if hasattr(merged, '_weight_conversions'):
merged._weight_conversions = []
merged.save_pretrained(
str(final_dir),
safe_serialization=True,
max_shard_size=max_shard_size
)
tok = _load_tokenizer(base_dir, use_fast=use_fast, trust_remote_code=trust_remote_code)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
tok.save_pretrained(str(final_dir))
print("--- Merge complete ---")
# --------------------------
# Main
# --------------------------
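# Illustrative config for --config (values and paths are placeholders; any key omitted
# here falls back to the defaults read in the code below):
#
#   run:
#     run_dir: runs/cpt_lora_demo
#     seed: 42
#   model:
#     repo_id: meta-llama/Llama-3.2-1B        # or a local model directory
#     torch_dtype: bfloat16
#     use_4bit: true
#     attn_implementation: flash_attention_2  # optional
#   data:
#     train_jsonl: data/train.jsonl
#     eval_split_ratio: 0.02
#     text_field: text
#     block_size: 2048
#     pack_mode: drop
#   peft:
#     enabled: true
#     r: 16
#     lora_alpha: 32
#     lora_dropout: 0.05
#     target_modules: auto
#   train:
#     num_train_epochs: 1
#     per_device_train_batch_size: 1
#     gradient_accumulation_steps: 8
#     learning_rate: 2.0e-5
#     save_steps: 200
#     eval_steps: 200
#     resume_from_checkpoint: auto
#   merge:
#     enabled: true
#     merged_dtype: float16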
def main():
ap = argparse.ArgumentParser()
ap.add_argument("--config", required=True, help="Path to YAML config")
ap.add_argument("--merge-only", action="store_true", help="Skip training, just merge adapter")
args = ap.parse_args()
with open(args.config, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f)
run_dir = _ensure_dir(Path(cfg["run"]["run_dir"]))
_ensure_dir(run_dir / "logs")
with (run_dir / "config_resolved.yaml").open("w", encoding="utf-8") as f:
yaml.safe_dump(cfg, f, sort_keys=False)
model_cfg = cfg["model"]
repo_id = str(model_cfg["repo_id"]).strip()
repo_path = Path(repo_id)
# ✅ Local model path -> load directly; no download
if repo_path.exists() and repo_path.is_dir():
base_dir = repo_path
if not _looks_like_model_dir(base_dir):
raise ValueError(f"model.repo_id points to a directory, but it doesn't look like a HF model dir: {base_dir}")
else:
# HF repo_id -> download into run_dir/base_local_dir
base_dir = _ensure_dir(run_dir / model_cfg.get("base_local_dir", "base_model"))
if not _looks_like_model_dir(base_dir):
print(f"Base model not found at {base_dir}, downloading from {repo_id} ...")
snapshot_download(
repo_id=repo_id,
revision=model_cfg.get("revision", None),
local_dir=str(base_dir),
local_dir_use_symlinks=False,
)
ckpt_dir = _ensure_dir(run_dir / "checkpoints")
best_adapter_dir = _ensure_dir(run_dir / "best_adapter")
merge_cfg = cfg.get("merge", {}) or {}
if merge_cfg.get("output_dir"):
od = Path(str(merge_cfg["output_dir"]))
final_dir = od if od.is_absolute() else (run_dir / od)
else:
final_dir = run_dir / "final_model"
# Merge-only
if args.merge_only:
if not _looks_like_model_dir(best_adapter_dir):
raise FileNotFoundError(f"Adapter not found at {best_adapter_dir}")
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
return
# Training
set_seed(int(cfg["run"].get("seed", 42)))
model, tokenizer = load_base_model_and_tokenizer(cfg, base_dir)
model, _ = apply_peft(cfg, model)
train_ds, eval_ds = build_datasets(cfg, tokenizer)
tr_cfg = cfg["train"]
dtype = _dtype_from_str(model_cfg.get("torch_dtype", "bfloat16"))
use_fp16 = (dtype == torch.float16)
use_bf16 = (dtype == torch.bfloat16)
max_steps = int(tr_cfg.get("max_steps", 0))
num_train_epochs = float(tr_cfg.get("num_train_epochs", 1))
# --- Dynamic evaluation strategy parameter handling ---
ta_params = inspect.signature(TrainingArguments.__init__).parameters
eval_key = "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
desired_ta_kwargs = dict(
output_dir=str(ckpt_dir),
max_steps=max_steps if max_steps > 0 else -1,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=int(tr_cfg.get("per_device_train_batch_size", 1)),
per_device_eval_batch_size=int(tr_cfg.get("per_device_eval_batch_size", tr_cfg.get("per_device_train_batch_size", 1))),
gradient_accumulation_steps=int(tr_cfg.get("gradient_accumulation_steps", 1)),
learning_rate=float(tr_cfg.get("learning_rate", 2e-5)),
weight_decay=float(tr_cfg.get("weight_decay", 0.0)),
warmup_ratio=float(tr_cfg.get("warmup_ratio", 0.0)),
lr_scheduler_type=str(tr_cfg.get("lr_scheduler_type", "cosine")),
optim=str(tr_cfg.get("optim", "paged_adamw_8bit" if bool(model_cfg.get("use_4bit", False)) else "adamw_torch")),
max_grad_norm=float(tr_cfg.get("max_grad_norm", 1.0)),
logging_steps=int(tr_cfg.get("logging_steps", 10)),
save_strategy=str(tr_cfg.get("save_strategy", "steps")),
save_steps=int(tr_cfg.get("save_steps", 200)),
save_total_limit=int(tr_cfg.get("save_total_limit", 3)),
eval_steps=int(tr_cfg.get("eval_steps", 200)),
load_best_model_at_end=bool(tr_cfg.get("load_best_model_at_end", True)) if eval_ds is not None else False,
metric_for_best_model="eval_loss",
greater_is_better=False,
fp16=use_fp16,
bf16=use_bf16,
report_to=[],
remove_unused_columns=False,
save_safetensors=True,
overwrite_output_dir=False,
)
# Set the correct argument name for this transformers version
desired_ta_kwargs[eval_key] = str(tr_cfg.get("evaluation_strategy", "steps" if eval_ds is not None else "no"))
ta_kwargs = {k: v for k, v in desired_ta_kwargs.items() if k in ta_params}
training_args = TrainingArguments(**ta_kwargs)
trainer_params = inspect.signature(Trainer.__init__).parameters
desired_trainer_kwargs = dict(
model=model,
args=training_args,
train_dataset=train_ds,
eval_dataset=eval_ds,
tokenizer=tokenizer,
processing_class=tokenizer,
data_collator=default_data_collator,
callbacks=[JsonlLoggerCallback(run_dir)],
)
trainer_kwargs = {k: v for k, v in desired_trainer_kwargs.items() if k in trainer_params}
trainer = Trainer(**trainer_kwargs)
# Resume
resume_from = tr_cfg.get("resume_from_checkpoint", None)
if resume_from == "auto":
last = get_last_checkpoint(str(ckpt_dir))
resume_from = last if last else None
if resume_from:
print(f"Resuming from {resume_from}")
print("Starting training...")
trainer.train(resume_from_checkpoint=resume_from)
trainer.save_model(str(best_adapter_dir))
print(f"Saved best adapter -> {best_adapter_dir}")
if eval_ds is not None:
metrics = trainer.evaluate()
eval_loss = metrics.get("eval_loss", None)
metrics["perplexity"] = _safe_exp(eval_loss) if eval_loss is not None else None
with (run_dir / "eval_final.json").open("w", encoding="utf-8") as f:
json.dump(metrics, f, indent=2)
print(f"Final eval_loss={eval_loss}, ppl={metrics['perplexity']}")
if bool(cfg.get("merge", {}).get("enabled", False)):
del trainer, model
torch.cuda.empty_cache()
merge_adapter(cfg, base_dir, best_adapter_dir, final_dir)
else:
print("Merge disabled. Run with --merge-only later if needed.")
if __name__ == "__main__":
main()