Spaces:
Sleeping
Sleeping
Upload mlplo/train.py with huggingface_hub
Browse files- mlplo/train.py +243 -0
mlplo/train.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import logging
|
| 5 |
+
import shutil
|
| 6 |
+
import tempfile
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from datasets import load_from_disk
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import (
|
| 12 |
+
AutoModelForSeq2SeqLM,
|
| 13 |
+
DataCollatorForSeq2Seq,
|
| 14 |
+
Seq2SeqTrainer,
|
| 15 |
+
Seq2SeqTrainingArguments,
|
| 16 |
+
set_seed,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
from .common import (
|
| 20 |
+
CHECKPOINT_DIR,
|
| 21 |
+
DEFAULT_MODEL_NAME,
|
| 22 |
+
DEFAULT_TARGET_MAX_LENGTH,
|
| 23 |
+
build_compute_metrics,
|
| 24 |
+
ensure_project_dirs,
|
| 25 |
+
load_tokenizer,
|
| 26 |
+
maybe_limit_split,
|
| 27 |
+
resolve_mixed_precision,
|
| 28 |
+
write_json,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# Module-level logger; its level and record format are set by the
# logging.basicConfig() call at the start of main().
LOGGER = logging.getLogger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for a fine-tuning run.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point relies
            on; an explicit list allows programmatic use.

    Returns:
        Namespace with one attribute per option (dashes become underscores,
        e.g. ``--dataset-dir`` -> ``dataset_dir``).
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune BART on a prepared summarization dataset."
    )

    # Data / model / output locations.
    parser.add_argument(
        "--dataset-dir", required=True, help="Path produced by mlplo.data_cleaning."
    )
    parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME)
    parser.add_argument("--output-dir", default=str(CHECKPOINT_DIR / "bart-large-xsum"))

    # Optimisation hyper-parameters.
    parser.add_argument("--per-device-train-batch-size", type=int, default=2)
    parser.add_argument("--per-device-eval-batch-size", type=int, default=2)
    parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
    parser.add_argument("--learning-rate", type=float, default=3e-5)  # lower LR for large model
    parser.add_argument("--weight-decay", type=float, default=0.01)
    # More epochs than a quick run would use. NOTE: no early-stopping callback
    # is registered in this module, so all requested epochs are trained.
    parser.add_argument("--num-train-epochs", type=float, default=5.0)
    parser.add_argument("--warmup-ratio", type=float, default=0.06)
    parser.add_argument("--label-smoothing", type=float, default=0.1)  # regularisation

    # Logging / checkpoint retention.
    parser.add_argument("--logging-steps", type=int, default=25)
    parser.add_argument("--save-total-limit", type=int, default=2)

    # Generation settings used for evaluation.
    parser.add_argument(
        "--generation-max-length", type=int, default=DEFAULT_TARGET_MAX_LENGTH
    )
    parser.add_argument("--generation-num-beams", type=int, default=6)

    # Optional per-split sample caps (None means "use the whole split").
    parser.add_argument("--max-train-samples", type=int, default=None)
    parser.add_argument("--max-eval-samples", type=int, default=None)
    parser.add_argument("--max-test-samples", type=int, default=None)

    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--gradient-checkpointing", action="store_true")
    parser.add_argument("--overwrite-output-dir", action="store_true")
    parser.add_argument(
        "--resume-from-checkpoint",
        default=None,
        help="Path to a checkpoint directory to resume from.",
    )
    parser.add_argument(
        "--run-test-eval",
        action="store_true",
        help="Run an additional evaluation pass on the held-out test split.",
    )
    return parser.parse_args(argv)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _prepare_output_dir(output_dir: Path, overwrite: bool) -> None:
|
| 77 |
+
"""Handle output directory creation / overwriting safely."""
|
| 78 |
+
if not output_dir.exists() or not any(output_dir.iterdir()):
|
| 79 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 80 |
+
return
|
| 81 |
+
|
| 82 |
+
if not overwrite:
|
| 83 |
+
raise FileExistsError(
|
| 84 |
+
f"Output directory '{output_dir}' is not empty. "
|
| 85 |
+
"Pass --overwrite-output-dir to replace it."
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Atomic-ish overwrite: move to a temp name, then delete
|
| 89 |
+
tmp = output_dir.parent / (output_dir.name + ".__tmp_delete")
|
| 90 |
+
try:
|
| 91 |
+
output_dir.rename(tmp)
|
| 92 |
+
shutil.rmtree(tmp)
|
| 93 |
+
except Exception:
|
| 94 |
+
# If rename failed, try in-place rmtree as fallback
|
| 95 |
+
if tmp.exists():
|
| 96 |
+
shutil.rmtree(tmp)
|
| 97 |
+
else:
|
| 98 |
+
shutil.rmtree(output_dir)
|
| 99 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def main() -> None:
    """Fine-tune a seq2seq model on a prepared dataset and save its metrics.

    Parses the CLI options, loads the tokenized dataset from disk, prepares
    the output directory, trains with ``Seq2SeqTrainer``, saves the model and
    tokenizer, then writes train / validation (and optionally test) metrics
    as JSON files under ``<output-dir>/metrics``.

    Raises:
        FileNotFoundError: The dataset directory or the
            ``--resume-from-checkpoint`` path does not exist.
        KeyError: The dataset lacks a ``train`` or ``validation`` split.
        FileExistsError: The output directory is non-empty and
            ``--overwrite-output-dir`` was not given (raised by
            ``_prepare_output_dir``).
        ValueError: ``--overwrite-output-dir`` would delete the checkpoint
            passed to ``--resume-from-checkpoint``.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    )
    args = parse_args()
    ensure_project_dirs()
    set_seed(args.seed)

    # ── Validate dataset path ────────────────────────────────────────────────
    dataset_path = Path(args.dataset_dir)
    if not dataset_path.exists():
        raise FileNotFoundError(
            f"Prepared dataset not found: {dataset_path}\n"
            "Run mlplo.data_cleaning first."
        )

    # ── Load dataset splits ──────────────────────────────────────────────────
    LOGGER.info("Loading prepared dataset from %s", dataset_path)
    tokenized_dataset = load_from_disk(str(dataset_path))

    required = {"train", "validation"}
    missing = required - set(tokenized_dataset.keys())
    if missing:
        raise KeyError(
            f"Dataset at '{dataset_path}' is missing required splits: {missing}. "
            "Re-run mlplo.data_cleaning to regenerate the dataset."
        )

    train_dataset = maybe_limit_split(tokenized_dataset["train"], args.max_train_samples)
    eval_dataset = maybe_limit_split(tokenized_dataset["validation"], args.max_eval_samples)
    has_test = "test" in tokenized_dataset
    test_dataset = (
        maybe_limit_split(tokenized_dataset["test"], args.max_test_samples)
        if has_test
        else None
    )

    # ── Validate resume-from-checkpoint ──────────────────────────────────────
    resume_path = args.resume_from_checkpoint
    if resume_path is not None and not Path(resume_path).exists():
        raise FileNotFoundError(
            f"--resume-from-checkpoint path does not exist: {resume_path}"
        )

    # ── Output directory ─────────────────────────────────────────────────────
    output_dir = Path(args.output_dir)

    # A checkpoint that lives inside the output directory must survive the
    # directory preparation, otherwise there is nothing left to resume from.
    resuming_in_place = False
    if resume_path is not None:
        resolved_output = output_dir.resolve()
        resolved_resume = Path(resume_path).resolve()
        resuming_in_place = (
            resolved_resume == resolved_output
            or resolved_output in resolved_resume.parents
        )

    if resuming_in_place:
        if args.overwrite_output_dir:
            raise ValueError(
                f"--overwrite-output-dir would delete the checkpoint '{resume_path}' "
                f"inside '{output_dir}'. Drop one of the two options or resume "
                "from a checkpoint stored outside the output directory."
            )
        # Keep the existing contents; a non-empty directory is expected here.
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        _prepare_output_dir(output_dir, overwrite=args.overwrite_output_dir)
    metrics_dir = output_dir / "metrics"

    # ── Model + tokenizer ────────────────────────────────────────────────────
    LOGGER.info("Loading tokenizer and model '%s'…", args.model_name)
    tokenizer = load_tokenizer(args.model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)

    if args.gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable()
        else:
            LOGGER.warning(
                "Model '%s' does not support gradient_checkpointing_enable(); skipping.",
                args.model_name,
            )

    # resolve_mixed_precision() is defined in .common; only its "fp16" and
    # "bf16" entries are read below.
    precision = resolve_mixed_precision()
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    # BERTScore is intentionally excluded from training-time compute_metrics.
    # It downloads a ~400 MB model and is 10-20× slower than ROUGE.
    # Use mlplo.eval with --include-bertscore for BERTScore evaluation.
    compute_metrics = build_compute_metrics(tokenizer, include_bertscore=False)

    training_args = Seq2SeqTrainingArguments(
        output_dir=str(output_dir),
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        weight_decay=args.weight_decay,
        num_train_epochs=args.num_train_epochs,
        warmup_ratio=args.warmup_ratio,
        label_smoothing_factor=args.label_smoothing,
        logging_steps=args.logging_steps,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=args.save_total_limit,
        predict_with_generate=True,
        generation_max_length=args.generation_max_length,
        generation_num_beams=args.generation_num_beams,
        load_best_model_at_end=True,
        metric_for_best_model="rougeL",
        greater_is_better=True,
        fp16=precision["fp16"],
        bf16=precision["bf16"],
        report_to="none",
        optim="adamw_torch",
        remove_unused_columns=True,
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    LOGGER.info("Starting training…")
    train_result = trainer.train(resume_from_checkpoint=resume_path)
    trainer.save_model()
    tokenizer.save_pretrained(output_dir)

    # write_json() is defined in .common; create the directory here so the
    # metric dumps do not depend on that helper creating missing parents.
    metrics_dir.mkdir(parents=True, exist_ok=True)
    write_json(metrics_dir / "train_metrics.json", train_result.metrics)

    LOGGER.info("Running final validation…")
    validation_metrics = trainer.evaluate(
        eval_dataset=eval_dataset, metric_key_prefix="validation"
    )
    write_json(metrics_dir / "validation_metrics.json", validation_metrics)

    if args.run_test_eval:
        if test_dataset is None:
            LOGGER.warning(
                "--run-test-eval requested but dataset has no 'test' split; skipping."
            )
        else:
            LOGGER.info("Running held-out test evaluation…")
            test_metrics = trainer.evaluate(
                eval_dataset=test_dataset, metric_key_prefix="test"
            )
            write_json(metrics_dir / "test_metrics.json", test_metrics)

    # Free GPU memory before any downstream process reuses the device
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    LOGGER.info("Training complete. Outputs saved to %s", output_dir)
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# Script entry point. The module uses a relative import (`from .common import ...`),
# so it has to be run as part of its package, e.g. `python -m mlplo.train`.
if __name__ == "__main__":
    main()
|