Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import time | |
| import pickle | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Sequence, Set | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset | |
| from transformers import ( | |
| AutoModelForSeq2SeqLM, | |
| AutoTokenizer, | |
| DataCollatorForSeq2Seq, | |
| Seq2SeqTrainer, | |
| TrainerCallback, | |
| Seq2SeqTrainingArguments, | |
| ) | |
| from transformers.trainer_utils import get_last_checkpoint | |
# Repository root: the parent of the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]

# Default locations used by the CLI, all anchored at the repository root.
DEFAULT_SPLIT_DIR = REPO_ROOT.joinpath("data", "external", "caption_emporium", "t5_rewrite_splits")
DEFAULT_BASE_MODEL = REPO_ROOT.joinpath("models", "t5-small")
DEFAULT_OUT_DIR = REPO_ROOT.joinpath("models", "finetune", "t5-rewrite")
class TokenizedListDataset(Dataset):
    """Dataset view over an in-memory list of already-tokenized records.

    Each record is a dict mapping a feature name to a list of token ids.
    """

    def __init__(self, records: List[Dict[str, List[int]]]):
        # The list is stored by reference; no copy is made.
        self.records = records

    def __len__(self) -> int:
        """Return the number of stored records."""
        return len(self.records)

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        """Return the record stored at position ``idx``."""
        return self.records[idx]
def _canon_tag(tag: str) -> str:
    """Return the canonical form of a single tag.

    Surrounding whitespace is dropped, internal whitespace runs become a single
    underscore, the text is lower-cased, and backslash-escaped parentheses are
    unescaped.  A falsy ``tag`` yields an empty string.
    """
    words = str(tag or "").split()
    normalized = "_".join(words).lower()
    return normalized.replace("\\(", "(").replace("\\)", ")")
def _parse_tag_set(text: str) -> Set[str]:
    """Split comma-separated ``text`` into a set of canonical, non-empty tags.

    A falsy ``text`` produces an empty set.
    """
    canonical = (_canon_tag(piece) for piece in (text or "").split(","))
    return {tag for tag in canonical if tag}
def _set_metrics(pred_texts: List[str], gold_texts: List[str]) -> Dict[str, float]:
    """Compute example-averaged tag-set metrics for predicted vs. gold tag strings.

    Each text is parsed into a tag set with ``_parse_tag_set``.  Precision, recall
    and F1 are computed per example and then averaged.  Two empty sets score 1.0
    on all three; exactly one empty set scores 0.0 on all three.

    Returns a dict with ``set_precision``, ``set_recall``, ``set_f1``,
    ``exact_set_match``, ``avg_pred_tags`` and ``avg_gold_tags``.  All values are
    0.0 when ``pred_texts`` is empty.
    """
    if not pred_texts:
        return dict.fromkeys(
            (
                "set_precision",
                "set_recall",
                "set_f1",
                "exact_set_match",
                "avg_pred_tags",
                "avg_gold_tags",
            ),
            0.0,
        )

    precisions: List[float] = []
    recalls: List[float] = []
    f1_scores: List[float] = []
    pred_counts: List[int] = []
    gold_counts: List[int] = []
    exact_hits = 0

    for pred_text, gold_text in zip(pred_texts, gold_texts):
        predicted = _parse_tag_set(pred_text)
        expected = _parse_tag_set(gold_text)
        pred_counts.append(len(predicted))
        gold_counts.append(len(expected))
        if predicted == expected:
            exact_hits += 1

        if predicted and expected:
            overlap = len(predicted & expected)
            precision = overlap / len(predicted)
            recall = overlap / len(expected)
            total = precision + recall
            f1 = (2 * precision * recall / total) if total > 0 else 0.0
        elif predicted or expected:
            # Exactly one side is empty: nothing can match.
            precision = recall = f1 = 0.0
        else:
            # Both sides are empty: treated as a perfect match.
            precision = recall = f1 = 1.0

        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)

    example_count = len(pred_texts)
    return {
        "set_precision": float(np.mean(precisions)),
        "set_recall": float(np.mean(recalls)),
        "set_f1": float(np.mean(f1_scores)),
        "exact_set_match": exact_hits / example_count,
        "avg_pred_tags": float(np.mean(pred_counts)),
        "avg_gold_tags": float(np.mean(gold_counts)),
    }
class ProgressFileCallback(TrainerCallback):
    """Trainer callback that mirrors training progress into JSON files.

    ``progress_path`` is overwritten with the latest snapshot when training
    begins, on every log event, and when training ends.  When ``history_path``
    is given, every evaluation event is appended to it as one JSON line.
    """

    def __init__(self, progress_path: Path, history_path: Optional[Path] = None):
        self.progress_path = progress_path
        self.history_path = history_path
        # Wall-clock time at which on_train_begin fired; None until then.
        self.start_time: Optional[float] = None
        # Global step observed at on_train_begin.  Non-zero when training is
        # resumed from a checkpoint, so throughput must be measured from here.
        self.start_step: int = 0

    def _write(self, payload: Dict[str, object]) -> None:
        """Overwrite the progress file with ``payload`` as indented JSON."""
        self.progress_path.parent.mkdir(parents=True, exist_ok=True)
        with self.progress_path.open("w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    def _append_history(self, payload: Dict[str, object]) -> None:
        """Append ``payload`` as one JSON line to the history file, if configured."""
        if self.history_path is None:
            return
        self.history_path.parent.mkdir(parents=True, exist_ok=True)
        with self.history_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")

    def _pct_eta(self, global_step: int, max_steps: int) -> Dict[str, Optional[float]]:
        """Return percent complete, ETA and elapsed seconds for the given step.

        The ETA uses the rate of steps completed since ``on_train_begin`` rather
        than the absolute ``global_step``; otherwise a resumed run would count
        steps from earlier runs against this run's elapsed time and report a
        far too small ETA.  Values that cannot be computed are ``None``.
        """
        max_steps = max(0, int(max_steps))
        global_step = max(0, int(global_step))
        pct = (100.0 * global_step / max_steps) if max_steps > 0 else None
        eta = None
        elapsed = None
        if self.start_time is not None:
            elapsed = time.time() - self.start_time
            steps_this_run = global_step - self.start_step
            if max_steps > 0 and steps_this_run > 0:
                # Never report a negative ETA if the step counter overshoots.
                remaining = max(0, max_steps - global_step)
                eta = (elapsed / steps_this_run) * remaining
        return {"pct": pct, "eta_sec": eta, "elapsed_sec": elapsed}

    def _snapshot(self, status: str, state, last_log: Dict[str, object]) -> Dict[str, object]:
        """Build the progress payload shared by all progress-file writes."""
        info = self._pct_eta(state.global_step, state.max_steps)
        return {
            "status": status,
            "global_step": int(state.global_step),
            "max_steps": int(state.max_steps),
            "pct_complete": info["pct"],
            "elapsed_sec": info["elapsed_sec"],
            "eta_sec": info["eta_sec"],
            "last_log": last_log,
            "updated_at_epoch_sec": time.time(),
        }

    def on_train_begin(self, args, state, control, **kwargs):
        """Start the clock, remember the starting step, and write a first snapshot."""
        self.start_time = time.time()
        self.start_step = max(0, int(state.global_step))
        self._write(self._snapshot("running", state, {}))

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write a snapshot containing the latest logs and echo it to stdout."""
        log_view = logs or {}
        payload = self._snapshot("running", state, log_view)
        self._write(payload)
        pct = payload["pct_complete"]
        eta = payload["eta_sec"]
        pct_text = f"{pct:.1f}%" if pct is not None else "n/a"
        eta_text = f"{eta:.0f}s" if eta is not None else "n/a"
        print(f"[train] step {state.global_step}/{state.max_steps} ({pct_text}) eta={eta_text} logs={log_view}")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Append the evaluation metrics for the current step to the history file."""
        row = {
            "event": "evaluate",
            "global_step": int(state.global_step),
            "max_steps": int(state.max_steps),
            "metrics": metrics or {},
            "updated_at_epoch_sec": time.time(),
        }
        self._append_history(row)

    def on_train_end(self, args, state, control, **kwargs):
        """Write the final snapshot with status ``completed`` and a zero ETA."""
        payload = self._snapshot("completed", state, {})
        payload["eta_sec"] = 0.0
        self._write(payload)
class PeriodicTestEvalCallback(TrainerCallback):
    """Run an extra evaluation on a held-out test set every ``every_steps`` steps.

    The callback reacts to the trainer's evaluation event.  It is inactive until
    ``bind_trainer`` has been called, and a re-entrancy flag stops the test
    evaluation it launches from triggering itself again.
    """

    def __init__(self, test_dataset: Dataset, every_steps: int):
        self.test_dataset = test_dataset
        # Negative intervals are treated as 0, which disables the callback.
        self.every_steps = max(0, int(every_steps))
        self._trainer: Optional[Seq2SeqTrainer] = None
        self._in_test_eval = False

    def bind_trainer(self, trainer: Seq2SeqTrainer) -> None:
        """Attach the trainer used to run the test evaluation."""
        self._trainer = trainer

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Evaluate on the test set when the current step is on the schedule."""
        active = self.every_steps > 0 and self._trainer is not None and not self._in_test_eval
        if not active:
            return

        step = state.global_step
        on_schedule = step > 0 and step % self.every_steps == 0
        if not on_schedule:
            return

        # Skip events whose metrics already come from a test evaluation.
        if isinstance(metrics, dict) and any(key.startswith("test_") for key in metrics):
            return

        self._in_test_eval = True
        try:
            results = self._trainer.evaluate(eval_dataset=self.test_dataset, metric_key_prefix="test")
            recall = results.get("test_set_recall")
            if recall is not None:
                print(f"[test-eval] step={step} test_set_recall={float(recall):.4f}")
            else:
                print(f"[test-eval] step={step} ran periodic held-out test evaluation")
        finally:
            self._in_test_eval = False
class RecallWeightedSeq2SeqTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer with optional per-token loss weights for EOS and comma labels.

    When both weights equal 1.0 the parent's loss is used unchanged.  Otherwise
    the loss is a weighted mean of per-token cross-entropy in which label
    positions equal to ``eos_token_id`` or to any of ``comma_token_ids`` carry
    the configured weight and all other non-ignored positions carry weight 1.

    Args:
        eos_token_id: Label id of the end-of-sequence token.
        comma_token_ids: Label ids that represent a comma.
        eos_loss_weight: Loss weight applied at EOS label positions.
        comma_loss_weight: Loss weight applied at comma label positions.
        *args, **kwargs: Forwarded to ``Seq2SeqTrainer``.
    """

    def __init__(
        self,
        *args,
        eos_token_id: int,
        comma_token_ids: Sequence[int],
        eos_loss_weight: float,
        comma_loss_weight: float,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.eos_token_id = int(eos_token_id)
        self.comma_token_ids = [int(x) for x in comma_token_ids]
        self.eos_loss_weight = float(eos_loss_weight)
        self.comma_loss_weight = float(comma_loss_weight)
        self.use_weighted_loss = (self.eos_loss_weight != 1.0) or (self.comma_loss_weight != 1.0)

    def _parent_loss(self, model, inputs, return_outputs, num_items_in_batch):
        """Delegate to the parent's ``compute_loss``.

        ``num_items_in_batch`` is forwarded only when the caller supplied it, so
        the override also works with transformers releases whose
        ``compute_loss`` does not accept that argument.
        """
        extra = {} if num_items_in_batch is None else {"num_items_in_batch": num_items_in_batch}
        return super().compute_loss(model, inputs, return_outputs=return_outputs, **extra)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the (optionally weighted) loss, plus model outputs if requested."""
        labels = inputs.get("labels") if self.use_weighted_loss else None
        if labels is None:
            return self._parent_loss(model, inputs, return_outputs, num_items_in_batch)

        outputs = model(**inputs)
        logits = outputs.get("logits") if isinstance(outputs, dict) else outputs.logits
        vocab = logits.size(-1)
        # reshape (not view) so non-contiguous logits/labels do not raise.
        token_loss = F.cross_entropy(
            logits.reshape(-1, vocab),
            labels.reshape(-1),
            ignore_index=-100,
            reduction="none",
        ).reshape(labels.shape)

        # Build the weights in the loss dtype rather than the logits dtype, so
        # the normaliser is not summed in reduced precision when the logits are
        # half-precision.
        weights = torch.ones_like(token_loss)
        if self.eos_loss_weight != 1.0:
            weights = weights.masked_fill(labels == self.eos_token_id, self.eos_loss_weight)
        if self.comma_loss_weight != 1.0:
            for cid in self.comma_token_ids:
                weights = weights.masked_fill(labels == cid, self.comma_loss_weight)

        # Ignored positions (-100) contribute neither to the sum nor the normaliser.
        weights = weights * (labels != -100).to(weights.dtype)
        denom = weights.sum().clamp(min=1.0)
        loss = (token_loss * weights).sum() / denom
        return (loss, outputs) if return_outputs else loss
def _read_jsonl(path: Path) -> List[Dict[str, str]]:
    """Load ``source_text`` / ``target_text`` pairs from a UTF-8 JSONL file.

    Blank lines are skipped.  A row is dropped when either field is missing,
    JSON ``null``, or empty after stripping whitespace.  Other values are
    converted with ``str`` and stripped.

    Args:
        path: JSONL file with one JSON object per non-blank line.

    Returns:
        One ``{"source_text": ..., "target_text": ...}`` dict per kept row, in
        file order.

    Raises:
        json.JSONDecodeError: a non-blank line is not valid JSON.
    """

    def _as_text(value: object) -> str:
        # ``str(None)`` would yield the literal text "None" and slip past the
        # emptiness check below, so a JSON null is treated as a missing field.
        return "" if value is None else str(value).strip()

    rows: List[Dict[str, str]] = []
    with path.open("r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            obj = json.loads(line)
            src = _as_text(obj.get("source_text"))
            tgt = _as_text(obj.get("target_text"))
            if not src or not tgt:
                continue
            rows.append({"source_text": src, "target_text": tgt})
    return rows
def _cap(rows: List[Dict[str, str]], n: int) -> List[Dict[str, str]]:
    """Return at most the first ``n`` rows.

    When ``n`` is zero or negative the cap is disabled and the input list itself
    is returned; otherwise a new list holding the leading rows is returned.
    """
    return rows[:n] if n > 0 else rows
def _tokenize_rows(
    rows: Sequence[Dict[str, str]],
    tokenizer,
    source_max_len: int,
    target_max_len: int,
) -> TokenizedListDataset:
    """Tokenize source/target text pairs into a ``TokenizedListDataset``.

    Sources are encoded with truncation to ``source_max_len`` and targets, passed
    as ``text_target``, with truncation to ``target_max_len``.  Each resulting
    record holds ``input_ids``, ``attention_mask`` and ``labels``; no padding is
    requested here.
    """
    sources = [row["source_text"] for row in rows]
    targets = [row["target_text"] for row in rows]

    encoded_sources = tokenizer(sources, truncation=True, max_length=source_max_len)
    encoded_targets = tokenizer(text_target=targets, truncation=True, max_length=target_max_len)

    records: List[Dict[str, List[int]]] = [
        {"input_ids": ids, "attention_mask": mask, "labels": label_ids}
        for ids, mask, label_ids in zip(
            encoded_sources["input_ids"],
            encoded_sources["attention_mask"],
            encoded_targets["input_ids"],
        )
    ]
    return TokenizedListDataset(records)
def _maybe_disable_bad_rng_state(checkpoint_dir: Optional[str]) -> Optional[str]:
    """Move aside a checkpoint's ``rng_state.pth`` if ``torch.load`` cannot unpickle it.

    The file is renamed to ``rng_state.pth.unusable`` (replacing any earlier
    file of that name) so that resuming does not trip over it.  Failures while
    renaming are reported and otherwise ignored.

    Returns:
        ``checkpoint_dir`` unchanged, in every case.
    """
    if not checkpoint_dir:
        return checkpoint_dir

    rng_file = Path(checkpoint_dir) / "rng_state.pth"
    if not rng_file.is_file():
        return checkpoint_dir

    try:
        torch.load(str(rng_file))
    except pickle.UnpicklingError:
        # Newer torch defaults to weights_only=True and can reject older rng blobs.
        # Training can still resume from weights/optimizer without this file;
        # only the exact RNG stream is not restored.
        quarantine = rng_file.with_suffix(".pth.unusable")
        try:
            if quarantine.exists():
                quarantine.unlink()
            rng_file.replace(quarantine)
            print(f"Disabled unreadable RNG state file: {rng_file} -> {quarantine}")
        except Exception as ex:
            print(f"Warning: could not move unreadable RNG state file {rng_file}: {ex}")
    return checkpoint_dir
def main() -> int:
    """Fine-tune a local seq2seq model to rewrite captions into comma-separated tags.

    Parses the CLI arguments, loads and tokenizes the train/val/test JSONL
    splits, trains with ``RecallWeightedSeq2SeqTrainer``, saves the model and
    tokenizer, evaluates on the validation and test sets, and writes
    ``train_metrics.json`` plus a progress-status JSON file.

    Returns:
        0 after a successful run.

    Raises:
        RuntimeError: ``--require-cuda`` was given but CUDA is unavailable.
        ValueError: weighted loss is combined with non-zero label smoothing.
        FileNotFoundError: a split file or the base model directory is missing.
    """
    ap = argparse.ArgumentParser(description="Fine-tune local T5 for caption -> comma-separated canonical tags")
    ap.add_argument("--split-dir", type=Path, default=DEFAULT_SPLIT_DIR)
    ap.add_argument("--base-model-dir", type=Path, default=DEFAULT_BASE_MODEL)
    ap.add_argument("--output-dir", type=Path, default=DEFAULT_OUT_DIR)
    ap.add_argument("--source-max-len", type=int, default=160)
    ap.add_argument("--target-max-len", type=int, default=256)
    ap.add_argument("--num-beams", type=int, default=4)
    ap.add_argument("--generation-length-penalty", type=float, default=0.8,
                    help="<1 encourages longer outputs (recall-leaning), >1 encourages shorter outputs")
    ap.add_argument("--lr", type=float, default=2e-4)
    ap.add_argument("--weight-decay", type=float, default=0.01)
    ap.add_argument("--warmup-ratio", type=float, default=0.05)
    ap.add_argument("--label-smoothing", type=float, default=0.1)
    ap.add_argument("--eos-loss-weight", type=float, default=1.0,
                    help="When <1, penalize EOS-token mismatch less to reduce over-short outputs")
    ap.add_argument("--comma-loss-weight", type=float, default=1.0,
                    help="When <1, penalize comma-token mismatch less to focus loss on tag tokens")
    ap.add_argument("--epochs", type=float, default=1.0)
    ap.add_argument("--max-steps", type=int, default=3000, help="<=0 uses full epoch schedule")
    ap.add_argument("--train-batch-size", type=int, default=2)
    ap.add_argument("--eval-batch-size", type=int, default=2)
    ap.add_argument("--grad-accum", type=int, default=8)
    ap.add_argument("--logging-steps", type=int, default=25)
    ap.add_argument("--eval-steps", type=int, default=500, help="<=0 evaluates each epoch")
    ap.add_argument("--save-steps", type=int, default=250, help="<=0 saves each epoch")
    ap.add_argument("--max-train-samples", type=int, default=0, help="Cap train samples after loading (0 disables)")
    ap.add_argument("--max-val-samples", type=int, default=300, help="Cap validation samples for eval (0 disables)")
    ap.add_argument("--max-test-samples", type=int, default=300, help="Cap test samples for eval (0 disables)")
    ap.add_argument("--eval-during-train", action="store_true", default=False,
                    help="Enable periodic evaluation/checkpoint selection during training")
    ap.add_argument("--periodic-test-eval", action="store_true", default=False,
                    help="When eval-during-train is enabled, evaluate both val and test each eval pass")
    ap.add_argument("--test-eval-every-steps", type=int, default=0,
                    help="If >0, run held-out test eval every N global steps (after val eval)")
    ap.add_argument("--save-total-limit", type=int, default=3,
                    help="Max number of checkpoints kept on disk")
    ap.add_argument("--best-model-metric", type=str, default="recall",
                    choices=["recall", "f1", "precision", "loss"],
                    help="Metric used to select best checkpoint when load_best_model_at_end is active")
    ap.add_argument("--require-cuda", action="store_true", default=False,
                    help="Fail immediately if CUDA is not available")
    ap.add_argument("--progress-file", type=Path,
                    default=REPO_ROOT / "data" / "runtime_metrics" / "t5_rewrite_train_progress.json",
                    help="JSON file updated at each logging step with percent/ETA/progress")
    ap.add_argument("--progress-history-file", type=Path,
                    default=REPO_ROOT / "data" / "runtime_metrics" / "t5_rewrite_train_progress_history.jsonl",
                    help="JSONL history file for periodic evaluation events")
    ap.add_argument("--fp16", action="store_true", default=False, help="Enable fp16 mixed precision training")
    ap.add_argument("--bf16", action="store_true", default=False, help="Enable bf16 mixed precision training")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--report-to", type=str, default="none", help="none or tensorboard/wandb")
    ap.add_argument(
        "--resume-if-available",
        action="store_true",
        default=False,
        help="Resume from latest checkpoint in output-dir when present",
    )
    args = ap.parse_args()

    # Relative paths given on the command line are interpreted against REPO_ROOT.
    split_dir = args.split_dir if args.split_dir.is_absolute() else (REPO_ROOT / args.split_dir).resolve()
    model_dir = args.base_model_dir if args.base_model_dir.is_absolute() else (REPO_ROOT / args.base_model_dir).resolve()
    out_dir = args.output_dir if args.output_dir.is_absolute() else (REPO_ROOT / args.output_dir).resolve()
    progress_path = args.progress_file if args.progress_file.is_absolute() else (REPO_ROOT / args.progress_file).resolve()
    progress_history_path = (
        args.progress_history_file
        if args.progress_history_file.is_absolute()
        else (REPO_ROOT / args.progress_history_file).resolve()
    )
    out_dir.mkdir(parents=True, exist_ok=True)

    cuda_available = torch.cuda.is_available()
    cuda_name = torch.cuda.get_device_name(0) if cuda_available else ""
    if args.require_cuda and not cuda_available:
        raise RuntimeError(
            "CUDA is required for this run but not available. "
            "Use a CUDA-enabled PyTorch environment (GPU wheel) and retry."
        )
    print(
        f"torch={torch.__version__} cuda_available={cuda_available}"
        + (f" device='{cuda_name}'" if cuda_name else "")
    )

    # The weighted-loss path computes its own cross-entropy, so the two options
    # are rejected in combination instead of silently dropping one.
    if args.eos_loss_weight != 1.0 or args.comma_loss_weight != 1.0:
        if args.label_smoothing != 0.0:
            raise ValueError(
                "Weighted loss is enabled via eos/comma loss weights, but label smoothing is non-zero. "
                "Set --label-smoothing 0 when using weighted loss."
            )

    def _write_stage_status(status: str, extra: Optional[Dict[str, object]] = None) -> None:
        """Overwrite the progress file with a coarse stage marker plus ``extra`` fields."""
        payload: Dict[str, object] = {
            "status": status,
            "updated_at_epoch_sec": time.time(),
        }
        if extra:
            payload.update(extra)
        progress_path.parent.mkdir(parents=True, exist_ok=True)
        with progress_path.open("w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)

    train_path = split_dir / "train.jsonl"
    val_path = split_dir / "val.jsonl"
    test_path = split_dir / "test.jsonl"
    for p in (train_path, val_path, test_path):
        if not p.is_file():
            raise FileNotFoundError(f"Missing split file: {p}")
    if not model_dir.is_dir():
        raise FileNotFoundError(f"Missing base model dir: {model_dir}")

    _write_stage_status("loading_dataset")
    train_rows = _cap(_read_jsonl(train_path), args.max_train_samples)
    val_rows = _cap(_read_jsonl(val_path), args.max_val_samples)
    test_rows = _cap(_read_jsonl(test_path), args.max_test_samples)
    print("dataset_rows:", {"train": len(train_rows), "validation": len(val_rows), "test": len(test_rows)})

    _write_stage_status(
        "loading_model",
        {
            "dataset_rows": {
                "train": len(train_rows),
                "validation": len(val_rows),
                "test": len(test_rows),
            }
        },
    )
    # One clamped beam count is used for both the generation config and the
    # training arguments so the two cannot disagree.
    num_beams = max(1, int(args.num_beams))
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir), local_files_only=True, use_fast=False)
    model = AutoModelForSeq2SeqLM.from_pretrained(str(model_dir), local_files_only=True)
    model.generation_config.length_penalty = float(args.generation_length_penalty)
    model.generation_config.num_beams = num_beams

    _write_stage_status("tokenizing")
    train_ds = _tokenize_rows(train_rows, tokenizer, args.source_max_len, args.target_max_len)
    val_ds = _tokenize_rows(val_rows, tokenizer, args.source_max_len, args.target_max_len)
    test_ds = _tokenize_rows(test_rows, tokenizer, args.source_max_len, args.target_max_len)
    collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    report_to = [] if args.report_to == "none" else [args.report_to]

    if args.eval_during_train:
        eval_strategy = "steps" if args.eval_steps > 0 else "epoch"
        save_strategy = "steps" if args.save_steps > 0 else "epoch"
        load_best = not (args.periodic_test_eval and args.test_eval_every_steps <= 0)
    else:
        eval_strategy = "no"
        save_strategy = "no"
        load_best = False
    eval_steps = args.eval_steps if eval_strategy == "steps" else None
    save_steps = args.save_steps if save_strategy == "steps" else None

    if load_best:
        # load_best_model_at_end requires the save and eval strategies to match
        # and, for step schedules, save_steps to be a round multiple of
        # eval_steps.  The script defaults (save 250 / eval 500) violate that,
        # so the save schedule is aligned here instead of failing later.
        if save_strategy != eval_strategy:
            print(
                f"[config] save strategy '{save_strategy}' changed to '{eval_strategy}' "
                "to match evaluation (required for best-checkpoint selection)"
            )
            save_strategy = eval_strategy
            save_steps = eval_steps
        elif eval_strategy == "steps" and save_steps % eval_steps != 0:
            aligned_save_steps = ((save_steps + eval_steps - 1) // eval_steps) * eval_steps
            print(
                f"[config] save_steps {save_steps} rounded up to {aligned_save_steps} "
                f"(multiple of eval_steps {eval_steps}, required for best-checkpoint selection)"
            )
            save_steps = aligned_save_steps

    metric_for_best_map = {
        "recall": "eval_set_recall",
        "f1": "eval_set_f1",
        "precision": "eval_set_precision",
        "loss": "eval_loss",
    }
    metric_for_best_model = metric_for_best_map[args.best_model_metric]
    greater_is_better = args.best_model_metric != "loss"

    # `evaluation_strategy` is the deprecated spelling; use `eval_strategy`
    # whenever the installed TrainingArguments declares it.
    declared_fields = getattr(Seq2SeqTrainingArguments, "__dataclass_fields__", {})
    eval_strategy_key = "eval_strategy" if "eval_strategy" in declared_fields else "evaluation_strategy"

    targs = Seq2SeqTrainingArguments(
        output_dir=str(out_dir),
        learning_rate=args.lr,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        label_smoothing_factor=args.label_smoothing,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.grad_accum,
        predict_with_generate=True,
        generation_num_beams=num_beams,
        generation_max_length=args.target_max_len,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps if args.max_steps > 0 else -1,
        eval_steps=eval_steps,
        save_strategy=save_strategy,
        save_steps=save_steps,
        logging_steps=args.logging_steps,
        logging_strategy="steps",
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=load_best,
        metric_for_best_model=metric_for_best_model if load_best else None,
        greater_is_better=greater_is_better if load_best else None,
        seed=args.seed,
        dataloader_num_workers=0,
        report_to=report_to,
        fp16=args.fp16,
        bf16=args.bf16,
        **{eval_strategy_key: eval_strategy},
    )

    def _compute_metrics(eval_pred):
        """Decode predictions and labels, then score them as tag sets."""
        preds, labels = eval_pred
        if isinstance(preds, tuple):
            preds = preds[0]
        preds = np.asarray(preds)
        if preds.ndim == 3:
            # A rank-3 array is treated as per-token scores; take the argmax id.
            preds = np.argmax(preds, axis=-1)
        # Ids outside the vocabulary (e.g. -100 padding) are mapped to the pad
        # token so that decoding does not fail.
        preds = np.where(
            (preds < 0) | (preds >= tokenizer.vocab_size),
            tokenizer.pad_token_id,
            preds,
        )
        labels = np.asarray(labels)
        labels = np.where(
            (labels < 0) | (labels >= tokenizer.vocab_size),
            tokenizer.pad_token_id,
            labels,
        )
        pred_texts = tokenizer.batch_decode(preds.tolist(), skip_special_tokens=True)
        gold_texts = tokenizer.batch_decode(labels.tolist(), skip_special_tokens=True)
        return _set_metrics(pred_texts, gold_texts)

    eval_dataset_obj = (
        {"val": val_ds, "test": test_ds}
        if (args.eval_during_train and args.periodic_test_eval and args.test_eval_every_steps <= 0)
        else val_ds
    )
    comma_token_ids = tokenizer.encode(",", add_special_tokens=False)
    callbacks: List[TrainerCallback] = [ProgressFileCallback(progress_path, history_path=progress_history_path)]
    periodic_test_cb: Optional[PeriodicTestEvalCallback] = None
    if args.eval_during_train and args.test_eval_every_steps > 0:
        periodic_test_cb = PeriodicTestEvalCallback(test_dataset=test_ds, every_steps=args.test_eval_every_steps)
        callbacks.append(periodic_test_cb)

    trainer = RecallWeightedSeq2SeqTrainer(
        model=model,
        args=targs,
        train_dataset=train_ds,
        eval_dataset=eval_dataset_obj,
        tokenizer=tokenizer,
        data_collator=collator,
        compute_metrics=_compute_metrics,
        callbacks=callbacks,
        eos_token_id=tokenizer.eos_token_id,
        comma_token_ids=comma_token_ids,
        eos_loss_weight=args.eos_loss_weight,
        comma_loss_weight=args.comma_loss_weight,
    )
    if periodic_test_cb is not None:
        periodic_test_cb.bind_trainer(trainer)

    resume_checkpoint = None
    if args.resume_if_available:
        resume_checkpoint = get_last_checkpoint(str(out_dir))
        if resume_checkpoint:
            print(f"Resuming from checkpoint: {resume_checkpoint}")
            resume_checkpoint = _maybe_disable_bad_rng_state(resume_checkpoint)

    _write_stage_status(
        "training_starting",
        {
            "dataset_rows_tokenized": {
                "train": len(train_ds),
                "validation": len(val_ds),
                "test": len(test_ds),
            },
            "max_steps": int(targs.max_steps),
            "num_train_epochs": float(targs.num_train_epochs),
            "resume_from_checkpoint": resume_checkpoint,
        },
    )
    train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
    trainer.save_model(str(out_dir))
    tokenizer.save_pretrained(str(out_dir))

    val_metrics = trainer.evaluate(eval_dataset=val_ds)
    test_metrics = trainer.evaluate(eval_dataset=test_ds, metric_key_prefix="test")
    metrics = {
        "train": train_result.metrics,
        "val": val_metrics,
        "test": test_metrics,
        "config": {
            "split_dir": str(split_dir),
            "base_model_dir": str(model_dir),
            "output_dir": str(out_dir),
            "source_max_len": args.source_max_len,
            "target_max_len": args.target_max_len,
            "num_beams": args.num_beams,
            "generation_length_penalty": args.generation_length_penalty,
            "learning_rate": args.lr,
            "weight_decay": args.weight_decay,
            "warmup_ratio": args.warmup_ratio,
            "label_smoothing": args.label_smoothing,
            "eos_loss_weight": args.eos_loss_weight,
            "comma_loss_weight": args.comma_loss_weight,
            "epochs": args.epochs,
            "max_steps": args.max_steps,
            "train_batch_size": args.train_batch_size,
            "eval_batch_size": args.eval_batch_size,
            "grad_accum": args.grad_accum,
            "seed": args.seed,
            "max_train_samples": args.max_train_samples,
            "max_val_samples": args.max_val_samples,
            "max_test_samples": args.max_test_samples,
            "eval_during_train": args.eval_during_train,
            "periodic_test_eval": args.periodic_test_eval,
            "test_eval_every_steps": args.test_eval_every_steps,
            "save_total_limit": args.save_total_limit,
            "best_model_metric": args.best_model_metric,
            "require_cuda": args.require_cuda,
            "progress_file": str(progress_path),
            "progress_history_file": str(progress_history_path),
            "cuda_available": cuda_available,
            "cuda_device": cuda_name,
            "fp16": args.fp16,
            "bf16": args.bf16,
            "resume_if_available": args.resume_if_available,
        },
    }
    _write_stage_status(
        "completed",
        {
            # The step count comes from the training result and the resolved
            # step budget from the trainer state; the metrics dict has no
            # "global_step" entry and targs.max_steps is -1 on an epoch schedule.
            "global_step": int(train_result.global_step),
            "max_steps": int(trainer.state.max_steps),
            "pct_complete": 100.0,
            "elapsed_sec": train_result.metrics.get("train_runtime"),
            "eta_sec": 0.0,
            "last_log": {
                "train_loss": train_result.metrics.get("train_loss"),
                "eval_set_f1": val_metrics.get("eval_set_f1"),
                "test_set_f1": test_metrics.get("test_set_f1"),
            },
        },
    )
    with (out_dir / "train_metrics.json").open("w", encoding="utf-8") as f:
        json.dump(metrics, f, ensure_ascii=False, indent=2)
    print(json.dumps(metrics, ensure_ascii=False, indent=2))
    return 0
if __name__ == "__main__":
    # Switch the working directory to the repository root before running, then
    # hand main()'s return value to the interpreter as the exit status.
    os.chdir(REPO_ROOT)
    exit_code = main()
    raise SystemExit(exit_code)