Adive01 commited on
Commit
5efa78a
Β·
verified Β·
1 Parent(s): 8cf0774

Upload mlplo/train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mlplo/train.py +243 -0
mlplo/train.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import logging
5
+ import shutil
6
+ import tempfile
7
+ from pathlib import Path
8
+
9
+ from datasets import load_from_disk
10
+ import torch
11
+ from transformers import (
12
+ AutoModelForSeq2SeqLM,
13
+ DataCollatorForSeq2Seq,
14
+ Seq2SeqTrainer,
15
+ Seq2SeqTrainingArguments,
16
+ set_seed,
17
+ )
18
+
19
+ from .common import (
20
+ CHECKPOINT_DIR,
21
+ DEFAULT_MODEL_NAME,
22
+ DEFAULT_TARGET_MAX_LENGTH,
23
+ build_compute_metrics,
24
+ ensure_project_dirs,
25
+ load_tokenizer,
26
+ maybe_limit_split,
27
+ resolve_mixed_precision,
28
+ write_json,
29
+ )
30
+
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
+
34
+ def parse_args() -> argparse.Namespace:
35
+ parser = argparse.ArgumentParser(
36
+ description="Fine-tune BART on a prepared summarization dataset."
37
+ )
38
+ parser.add_argument(
39
+ "--dataset-dir", required=True, help="Path produced by mlplo.data_cleaning."
40
+ )
41
+ parser.add_argument("--model-name", default=DEFAULT_MODEL_NAME)
42
+ parser.add_argument("--output-dir", default=str(CHECKPOINT_DIR / "bart-large-xsum"))
43
+ parser.add_argument("--per-device-train-batch-size", type=int, default=2)
44
+ parser.add_argument("--per-device-eval-batch-size", type=int, default=2)
45
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
46
+ parser.add_argument("--learning-rate", type=float, default=3e-5) # lower LR for large model
47
+ parser.add_argument("--weight-decay", type=float, default=0.01)
48
+ parser.add_argument("--num-train-epochs", type=float, default=5.0) # more epochs + early stopping
49
+ parser.add_argument("--warmup-ratio", type=float, default=0.06)
50
+ parser.add_argument("--label-smoothing", type=float, default=0.1) # regularisation
51
+ parser.add_argument("--logging-steps", type=int, default=25)
52
+ parser.add_argument("--save-total-limit", type=int, default=2)
53
+ parser.add_argument(
54
+ "--generation-max-length", type=int, default=DEFAULT_TARGET_MAX_LENGTH
55
+ )
56
+ parser.add_argument("--generation-num-beams", type=int, default=6)
57
+ parser.add_argument("--max-train-samples", type=int, default=None)
58
+ parser.add_argument("--max-eval-samples", type=int, default=None)
59
+ parser.add_argument("--max-test-samples", type=int, default=None)
60
+ parser.add_argument("--seed", type=int, default=42)
61
+ parser.add_argument("--gradient-checkpointing", action="store_true")
62
+ parser.add_argument("--overwrite-output-dir", action="store_true")
63
+ parser.add_argument(
64
+ "--resume-from-checkpoint",
65
+ default=None,
66
+ help="Path to a checkpoint directory to resume from.",
67
+ )
68
+ parser.add_argument(
69
+ "--run-test-eval",
70
+ action="store_true",
71
+ help="Run an additional evaluation pass on the held-out test split.",
72
+ )
73
+ return parser.parse_args()
74
+
75
+
76
+ def _prepare_output_dir(output_dir: Path, overwrite: bool) -> None:
77
+ """Handle output directory creation / overwriting safely."""
78
+ if not output_dir.exists() or not any(output_dir.iterdir()):
79
+ output_dir.mkdir(parents=True, exist_ok=True)
80
+ return
81
+
82
+ if not overwrite:
83
+ raise FileExistsError(
84
+ f"Output directory '{output_dir}' is not empty. "
85
+ "Pass --overwrite-output-dir to replace it."
86
+ )
87
+
88
+ # Atomic-ish overwrite: move to a temp name, then delete
89
+ tmp = output_dir.parent / (output_dir.name + ".__tmp_delete")
90
+ try:
91
+ output_dir.rename(tmp)
92
+ shutil.rmtree(tmp)
93
+ except Exception:
94
+ # If rename failed, try in-place rmtree as fallback
95
+ if tmp.exists():
96
+ shutil.rmtree(tmp)
97
+ else:
98
+ shutil.rmtree(output_dir)
99
+ output_dir.mkdir(parents=True, exist_ok=True)
100
+
101
+
102
+ def main() -> None:
103
+ logging.basicConfig(
104
+ level=logging.INFO,
105
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
106
+ )
107
+ args = parse_args()
108
+ ensure_project_dirs()
109
+ set_seed(args.seed)
110
+
111
+ # ── Validate dataset path ─────────────────────────────────────────────────
112
+ dataset_path = Path(args.dataset_dir)
113
+ if not dataset_path.exists():
114
+ raise FileNotFoundError(
115
+ f"Prepared dataset not found: {dataset_path}\n"
116
+ "Run mlplo.data_cleaning first."
117
+ )
118
+
119
+ # ── Load dataset splits ───────────────────────────────────────────────────
120
+ LOGGER.info("Loading prepared dataset from %s", dataset_path)
121
+ tokenized_dataset = load_from_disk(str(dataset_path))
122
+
123
+ required = {"train", "validation"}
124
+ missing = required - set(tokenized_dataset.keys())
125
+ if missing:
126
+ raise KeyError(
127
+ f"Dataset at '{dataset_path}' is missing required splits: {missing}. "
128
+ "Re-run mlplo.data_cleaning to regenerate the dataset."
129
+ )
130
+
131
+ train_dataset = maybe_limit_split(tokenized_dataset["train"], args.max_train_samples)
132
+ eval_dataset = maybe_limit_split(tokenized_dataset["validation"], args.max_eval_samples)
133
+ has_test = "test" in tokenized_dataset
134
+ test_dataset = (
135
+ maybe_limit_split(tokenized_dataset["test"], args.max_test_samples)
136
+ if has_test
137
+ else None
138
+ )
139
+
140
+ # ── Validate resume-from-checkpoint ──────────────────────────────────────
141
+ resume_path = args.resume_from_checkpoint
142
+ if resume_path is not None and not Path(resume_path).exists():
143
+ raise FileNotFoundError(
144
+ f"--resume-from-checkpoint path does not exist: {resume_path}"
145
+ )
146
+
147
+ # ── Output directory ──────────────────────────────────────────────────────
148
+ output_dir = Path(args.output_dir)
149
+ _prepare_output_dir(output_dir, overwrite=args.overwrite_output_dir)
150
+ metrics_dir = output_dir / "metrics"
151
+
152
+ # ── Model + tokenizer ─────────────────────────────────────────────────────
153
+ LOGGER.info("Loading tokenizer and model '%s'…", args.model_name)
154
+ tokenizer = load_tokenizer(args.model_name)
155
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
156
+
157
+ if args.gradient_checkpointing:
158
+ if hasattr(model, "gradient_checkpointing_enable"):
159
+ model.gradient_checkpointing_enable()
160
+ else:
161
+ LOGGER.warning(
162
+ "Model '%s' does not support gradient_checkpointing_enable(); skipping.",
163
+ args.model_name,
164
+ )
165
+
166
+ precision = resolve_mixed_precision()
167
+ data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
168
+
169
+ # BERTScore is intentionally excluded from training-time compute_metrics.
170
+ # It downloads a ~400 MB model and is 10-20Γ— slower than ROUGE.
171
+ # Use mlplo.eval with --include-bertscore for BERTScore evaluation.
172
+ compute_metrics = build_compute_metrics(tokenizer, include_bertscore=False)
173
+
174
+ training_args = Seq2SeqTrainingArguments(
175
+ output_dir=str(output_dir),
176
+ learning_rate=args.learning_rate,
177
+ per_device_train_batch_size=args.per_device_train_batch_size,
178
+ per_device_eval_batch_size=args.per_device_eval_batch_size,
179
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
180
+ weight_decay=args.weight_decay,
181
+ num_train_epochs=args.num_train_epochs,
182
+ warmup_ratio=args.warmup_ratio,
183
+ label_smoothing_factor=args.label_smoothing,
184
+ logging_steps=args.logging_steps,
185
+ eval_strategy="epoch",
186
+ save_strategy="epoch",
187
+ save_total_limit=args.save_total_limit,
188
+ predict_with_generate=True,
189
+ generation_max_length=args.generation_max_length,
190
+ generation_num_beams=args.generation_num_beams,
191
+ load_best_model_at_end=True,
192
+ metric_for_best_model="rougeL",
193
+ greater_is_better=True,
194
+ fp16=precision["fp16"],
195
+ bf16=precision["bf16"],
196
+ report_to="none",
197
+ optim="adamw_torch",
198
+ remove_unused_columns=True,
199
+ )
200
+
201
+ trainer = Seq2SeqTrainer(
202
+ model=model,
203
+ args=training_args,
204
+ train_dataset=train_dataset,
205
+ eval_dataset=eval_dataset,
206
+ processing_class=tokenizer,
207
+ data_collator=data_collator,
208
+ compute_metrics=compute_metrics,
209
+ )
210
+
211
+ LOGGER.info("Starting training…")
212
+ train_result = trainer.train(resume_from_checkpoint=resume_path)
213
+ trainer.save_model()
214
+ tokenizer.save_pretrained(output_dir)
215
+ write_json(metrics_dir / "train_metrics.json", train_result.metrics)
216
+
217
+ LOGGER.info("Running final validation…")
218
+ validation_metrics = trainer.evaluate(
219
+ eval_dataset=eval_dataset, metric_key_prefix="validation"
220
+ )
221
+ write_json(metrics_dir / "validation_metrics.json", validation_metrics)
222
+
223
+ if args.run_test_eval:
224
+ if test_dataset is None:
225
+ LOGGER.warning(
226
+ "--run-test-eval requested but dataset has no 'test' split; skipping."
227
+ )
228
+ else:
229
+ LOGGER.info("Running held-out test evaluation…")
230
+ test_metrics = trainer.evaluate(
231
+ eval_dataset=test_dataset, metric_key_prefix="test"
232
+ )
233
+ write_json(metrics_dir / "test_metrics.json", test_metrics)
234
+
235
+ # Free GPU memory before any downstream process reuses the device
236
+ if torch.cuda.is_available():
237
+ torch.cuda.empty_cache()
238
+
239
+ LOGGER.info("Training complete. Outputs saved to %s", output_dir)
240
+
241
+
242
+ if __name__ == "__main__":
243
+ main()