mumble-cleanup / src /cleanup /train /trainer.py
adikuma's picture
initial upload: cleanup code and 688-pair seed dataset
fd0b01f verified
Raw
History Blame Contribute Delete
4.4 kB
# wire up trl's SFTTrainer with completion-only loss masking and autodetected
# mixed precision. saves the lora adapter, tokenizer, training history, and a
# run summary under runs/<run-id>/.
import json
from pathlib import Path
from typing import Optional
import torch
from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer
from cleanup.config import TrainConfig
from cleanup.data.download import load_pairs
from cleanup.data.tokenize import RESPONSE_TEMPLATE, to_dataset
from cleanup.train.model import load_base_and_tokenizer, wrap_with_lora
def _detect_precision_flags(cfg: "TrainConfig") -> dict:
    """Resolve the bf16/fp16/tf32 flags to pass to SFTConfig on this machine.

    Without CUDA every flag is off. On a GPU the choice is:

    - bf16, when ``cfg.bf16`` is set and torch reports bf16 support;
    - otherwise fp16, when ``cfg.fp16`` is set;
    - otherwise no mixed precision.

    tf32 follows ``cfg.tf32``, but a truthy request is dropped to False on
    devices below compute capability 8.0 (pre-Ampere). A ``None``/``False``
    setting is passed through unchanged.

    Args:
        cfg: training config; only its ``bf16``, ``fp16`` and ``tf32``
            attributes are read.

    Returns:
        dict with ``bf16``, ``fp16`` and ``tf32`` keys, meant to be
        ``**``-unpacked into SFTConfig.
    """
    if not torch.cuda.is_available():
        return {"bf16": False, "fp16": False, "tf32": False}
    # tf32 matmuls only exist on sm_80+ GPUs; transformers' TrainingArguments
    # raises ValueError if tf32=True is requested on an older architecture,
    # which would abort the run before training starts.
    major, _minor = torch.cuda.get_device_capability()
    tf32 = cfg.tf32 and major >= 8  # keeps None/False as-is
    if cfg.bf16 and torch.cuda.is_bf16_supported():
        return {"bf16": True, "fp16": False, "tf32": tf32}
    if cfg.fp16:
        return {"bf16": False, "fp16": True, "tf32": tf32}
    return {"bf16": False, "fp16": False, "tf32": tf32}
def train(
    cfg: TrainConfig,
    run_dir: Path,
    smoke: bool = False,
    epochs_override: Optional[int] = None,
    lr_override: Optional[float] = None,
) -> dict:
    """Fine-tune the LoRA-wrapped base model with TRL and write the run artifacts.

    Args:
        cfg: training configuration (data, optimizer, schedule, eval/save and
            precision settings).
        run_dir: output directory for checkpoints and artifacts; created
            (with parents) if missing.
        smoke: when True, use small fixed sizes instead of the config values:
            200 train / 40 val rows, batch size 2, no gradient accumulation,
            1 epoch, eval/save every 25 steps, logging every 5 steps. Gradient
            checkpointing is also turned off.
        epochs_override: replaces ``cfg.num_epochs`` when given; smoke mode
            still forces 1 epoch.
        lr_override: replaces ``cfg.learning_rate`` when given.

    Returns:
        The run summary, also written to ``run_dir/train_summary.json``: epochs,
        learning rate, train/val row counts, and the trainer's best metric and
        best checkpoint path.

    Side effects:
        The trainer writes checkpoints under ``run_dir``. The adapter and
        tokenizer are saved to ``run_dir/model``, and the trainer log history
        to ``run_dir/training_history.json``.
    """
    out = Path(run_dir)
    out.mkdir(parents=True, exist_ok=True)

    model, tokenizer = load_base_and_tokenizer(cfg)
    model = wrap_with_lora(model, cfg)

    # resolve every size/frequency knob up front: fixed small values in smoke
    # mode, otherwise straight from the config
    if smoke:
        n_train, n_val = 200, 40
        train_bs = eval_bs = 2
        grad_accum = 1
        eval_every = save_every = 25
        log_every = 5
    else:
        n_train, n_val = cfg.train_rows, cfg.val_rows
        train_bs, eval_bs = cfg.train_batch_size, cfg.eval_batch_size
        grad_accum = cfg.gradient_accumulation_steps
        eval_every, save_every = cfg.eval_steps, cfg.save_steps
        log_every = cfg.logging_steps

    train_rows = load_pairs(cfg.data_dir, "train", n_train)
    val_rows = load_pairs(cfg.data_dir, "val", n_val)
    train_ds, val_ds = to_dataset(train_rows), to_dataset(val_rows)

    # the collator scans each tokenized example for the response template and
    # masks everything up to and including it with -100 in the labels, so the
    # cross-entropy loss only covers assistant tokens.
    collator = DataCollatorForCompletionOnlyLM(
        response_template=RESPONSE_TEMPLATE,
        tokenizer=tokenizer,
    )

    epochs = cfg.num_epochs if epochs_override is None else epochs_override
    if smoke:
        epochs = 1  # smoke runs are always a single pass, override or not
    lr = cfg.learning_rate if lr_override is None else lr_override
    mixed_precision = _detect_precision_flags(cfg)

    # NOTE(review): `tokenizer=` (here and on SFTTrainer), `max_seq_length=` and
    # DataCollatorForCompletionOnlyLM are older trl spellings. Newer trl
    # releases rename or remove them (e.g. `processing_class`, `max_length`).
    # Confirm against the pinned trl version.
    args = SFTConfig(
        output_dir=str(out),
        seed=cfg.seed,
        # schedule and batching
        num_train_epochs=epochs,
        per_device_train_batch_size=train_bs,
        per_device_eval_batch_size=eval_bs,
        gradient_accumulation_steps=grad_accum,
        # optimizer
        learning_rate=lr,
        weight_decay=cfg.weight_decay,
        warmup_ratio=cfg.warmup_ratio,
        max_grad_norm=cfg.max_grad_norm,
        adam_beta1=cfg.adam_beta1,
        adam_beta2=cfg.adam_beta2,
        adam_epsilon=cfg.adam_epsilon,
        lr_scheduler_type=cfg.lr_scheduler_type,
        # evaluation, checkpointing and best-model selection
        eval_strategy="steps",
        eval_steps=eval_every,
        save_strategy="steps",
        save_steps=save_every,
        save_total_limit=cfg.save_total_limit,
        metric_for_best_model=cfg.metric_for_best_model,
        greater_is_better=cfg.greater_is_better,
        load_best_model_at_end=True,
        # logging (no external reporters)
        logging_steps=log_every,
        report_to=[],
        # data pipeline
        dataloader_num_workers=cfg.dataloader_num_workers,
        max_seq_length=cfg.max_seq_length,
        packing=False,
        # memory/precision
        gradient_checkpointing=not smoke,
        **mixed_precision,
    )

    trainer = SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        data_collator=collator,
    )
    trainer.train()

    # lora adapter + tokenizer land in runs/<id>/model
    adapter_dir = out / "model"
    adapter_dir.mkdir(exist_ok=True)
    trainer.model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)

    state = trainer.state
    # full log history, so the report builder can chart it
    (out / "training_history.json").write_text(
        json.dumps(state.log_history, indent=2)
    )

    summary = dict(
        epochs=epochs,
        learning_rate=lr,
        train_rows=len(train_rows),
        val_rows=len(val_rows),
        best_metric=state.best_metric,
        best_model_checkpoint=state.best_model_checkpoint,
    )
    (out / "train_summary.json").write_text(json.dumps(summary, indent=2))
    return summary