|
|
|
|
| """
|
| Domain-Adaptive Pretraining (DAPT) of xlm-roberta-large for Medieval Latin.
|
|
|
| Pipeline
|
| --------
|
| 1. Load MedBerta (baseline) + CanonBerta (already upsampled) from the HF Hub.
|
| 2. Carve a *group-aware* held-out set: whole documents are held out (by
|
| ``document_id``) so no document leaks across train / validation / test.
|
| This matters for an honest perplexity and for downstream Loci-Similes /
|
| text-reuse evaluation.
|
| 3. DOC-SENTENCES packing: paragraphs are grouped by ``document_id`` (kept in
|
| document order), tokenized, and greedily packed into fixed 512-token
|
| sequences that never cross a document boundary.
|
| 4. MLM continued pretraining with the HF Trainer, bf16, SDPA attention,
|
| in-training evaluation (loss + perplexity + masked-token accuracy),
|
| best-model selection, and rich logging for later write-up.
|
|
|
| The script is built for a single NVIDIA RTX PRO 6000 Blackwell (96 GB) and is
|
| deliberately defensive about OOM (auto batch-size search, expandable CUDA
|
| segments, optional gradient checkpointing, eval logit reduction).
|
|
|
| Authentication
|
| --------------
|
| The datasets are private. Log in once before running:
|
| huggingface-cli login # or: export HF_TOKEN=hf_...
|
|
|
| Quick check before spending GPU time (builds + caches datasets, prints stats,
|
| no training):
|
| python dapt_xlmr_pretrain.py --dry-run
|
|
|
| Typical run:
|
| python dapt_xlmr_pretrain.py \
|
| --output-dir runs/dapt_xlmr_medlatin_v1 \
|
| --num-train-epochs 3 \
|
| --per-device-train-batch-size 32 \
|
| --gradient-accumulation-steps 8
|
| """
|
|
|
| import argparse
|
| import hashlib
|
| import inspect
|
| import json
|
| import logging
|
| import math
|
| import os
|
| import platform
|
| import sys
|
| import traceback
|
| import unicodedata
|
| from collections import Counter
|
| from dataclasses import dataclass, field
|
| from datetime import datetime
|
| from typing import Dict, List, Optional
|
|
|
|
|
|
|
# Process-wide knobs, set here so they are in place before torch and
# transformers are imported further down. ``setdefault`` leaves any value the
# user already exported untouched.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
|
|
| import numpy as np
|
| import torch
|
| from datasets import Dataset, concatenate_datasets, load_dataset
|
| from transformers import (
|
| AutoModelForMaskedLM,
|
| AutoTokenizer,
|
| DataCollatorForLanguageModeling,
|
| EarlyStoppingCallback,
|
| Trainer,
|
| TrainerCallback,
|
| TrainingArguments,
|
| set_seed,
|
| )
|
| from transformers.trainer_utils import get_last_checkpoint
|
|
|
# Feature-detect the installed Trainer: record whether its constructor accepts
# a ``processing_class`` keyword. If the signature cannot be inspected for any
# reason, fall back to False.
try:
    _USE_PROCESSING_CLASS = "processing_class" in inspect.signature(Trainer.__init__).parameters
except Exception:
    _USE_PROCESSING_CLASS = False
|
|
|
# Root logging: INFO and above, timestamped, written to stdout. The script
# itself logs through the "dapt" logger.
logging.basicConfig(
    format="%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("dapt")
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Config:
    """Settings for one DAPT run, grouped by concern."""

    # --- Data sources ------------------------------------------------------
    # Dataset identifiers for the two corpora and the split to read.
    med_dataset: str = "mschonhardt/MedBerta"
    canon_dataset: str = "mschonhardt/CanonBerta"
    dataset_split: str = "train"
    # Explicit column overrides; None means auto-detect from the
    # TEXT_/DOC_/ORDER_CANDIDATES lists (see ``_pick``).
    text_column: Optional[str] = None
    doc_id_column: Optional[str] = None
    order_column: Optional[str] = None

    # --- Model / sequence shape --------------------------------------------
    model_name: str = "xlm-roberta-large"
    # Length of a packed sequence, including the bos and eos tokens.
    max_seq_length: int = 512
    # mlm_probability / attn_implementation are not read in the visible part
    # of the file; presumably consumed by the collator / model loading —
    # TODO confirm.
    mlm_probability: float = 0.15
    attn_implementation: str = "sdpa"
    # Document tails shorter than this many tokens are dropped when packing.
    min_chunk_tokens: int = 64

    # --- Held-out split ----------------------------------------------------
    # Fractions of distinct documents held out for validation / test.
    val_doc_fraction: float = 0.01
    test_doc_fraction: float = 0.0
    # Upper bound applied separately to the validation and test doc counts.
    max_eval_docs_per_source: int = 400
    # Column used to stratify the held-out sample; None disables stratification.
    stratify_column: Optional[str] = "category"
    split_seed: int = 13

    # --- Training corpus selection -----------------------------------------
    # One of "combined", "med", "canon".
    train_corpus: str = "combined"
    # When True, exact (document_id, order, text) duplicates are removed from
    # the training paragraphs as well (they are always removed from eval).
    dedup_train: bool = False

    # --- Tokenizer diagnostics ---------------------------------------------
    diagnose_tokenizer: bool = True
    # Maximum number of paragraphs / sequences sampled for diagnostics.
    tokenizer_sample_size: int = 20000
    # Apply Unicode NFC normalization to text before tokenizing.
    normalize_nfc: bool = False
    # Write diagnostics, then exit before packing / training.
    diagnose_only: bool = False

    # --- Optimisation ------------------------------------------------------
    # Not read in the visible part of the file; presumably forwarded to
    # TrainingArguments / EarlyStoppingCallback — TODO confirm.
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.98
    adam_epsilon: float = 1e-6
    max_grad_norm: float = 1.0
    warmup_ratio: float = 0.06
    lr_scheduler_type: str = "linear"
    num_train_epochs: float = 10.0
    max_steps: int = -1
    optim: str = "adamw_torch_fused"
    early_stopping_patience: int = 5
    early_stopping_threshold: float = 1e-4

    # --- Throughput / memory (same caveat as the optimisation group) --------
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    gradient_accumulation_steps: int = 16
    gradient_checkpointing: bool = False
    auto_find_batch_size: bool = True
    torch_compile: bool = False
    dataloader_num_workers: int = 8
    eval_accumulation_steps: int = 50

    # --- Evaluation / logging / checkpoint cadence (same caveat) -----------
    eval_steps: int = 500
    logging_steps: int = 50
    save_steps: int = 500
    save_total_limit: int = 3

    # --- Run I/O -----------------------------------------------------------
    # Diagnostics JSON files are written under output_dir.
    output_dir: str = "runs/dapt_xlmr_medlatin"
    # Root for packed-dataset caches (one ``packed_<hash>`` directory per setup).
    cache_dir: str = ".cache_packed"
    seed: int = 42
    report_to_wandb: bool = False
    # Worker processes for the tokenization ``map`` calls.
    preprocess_num_proc: int = 16
    resume: bool = True
    dry_run: bool = False
|
|
|
|
|
|
|
|
|
|
|
# Column names probed, in order, when the matching Config override is None;
# the first candidate present in the dataset is used (see ``_pick``).
TEXT_CANDIDATES = ["text", "paragraph", "content", "sentence", "passage"]
DOC_CANDIDATES = ["document_id", "doc_id", "docid", "document", "work_id", "work"]
ORDER_CANDIDATES = ["order", "paragraph_id", "par_id", "index", "idx", "n", "position", "seq"]
|
|
|
|
|
def _pick(columns: List[str], candidates: List[str], explicit: Optional[str], kind: str) -> Optional[str]:
    """Resolve which dataset column to use for a given role.

    Args:
        columns: Column names available in the dataset.
        candidates: Names to try, in priority order, when nothing is forced.
        explicit: Column requested by the user, or None to auto-detect.
        kind: Human-readable role name, used only in the error message.

    Returns:
        ``explicit`` when given and present; otherwise the first candidate
        found in ``columns``; None when no candidate matches.

    Raises:
        ValueError: if ``explicit`` is given but not present in ``columns``.
    """
    if explicit is None:
        # Auto-detect: first candidate that exists wins.
        return next((name for name in candidates if name in columns), None)
    if explicit in columns:
        return explicit
    raise ValueError(f"Requested {kind} column '{explicit}' not in dataset columns {columns}")
|
|
|
|
|
|
|
|
|
|
|
def load_source(repo: str, split: str, source_tag: str, cfg: Config):
    """Load one raw-text corpus and normalise it to the canonical schema.

    The returned dataset has the columns ``text``, ``document_id``, ``order``
    and ``source`` (plus ``stratum`` when ``cfg.stratify_column`` is set).
    Document ids are prefixed with ``source_tag`` so that ids coming from
    different corpora cannot collide.

    Raises:
        ValueError: if the text or document-id column cannot be resolved, or
            if the requested stratify column is not in the dataset.
    """
    logger.info("Loading %s (split=%s) ...", repo, split)
    ds = load_dataset(repo, split=split, token=True)
    available = ds.column_names

    text_col = _pick(available, TEXT_CANDIDATES, cfg.text_column, "text")
    doc_col = _pick(available, DOC_CANDIDATES, cfg.doc_id_column, "document_id")
    order_col = _pick(available, ORDER_CANDIDATES, cfg.order_column, "order")

    if text_col is None or doc_col is None:
        raise ValueError(
            f"Could not resolve text/doc_id columns for {repo}. "
            f"Available columns: {available}. "
            f"Pass --text-column / --doc-id-column explicitly."
        )
    logger.info(" %s -> text='%s', doc_id='%s', order='%s', rows=%d",
                source_tag, text_col, doc_col, order_col, len(ds))

    # Map the resolved columns onto their canonical names in one rename.
    renames = {text_col: "text", doc_col: "document_id"}
    if order_col:
        renames[order_col] = "order"
    stratify = cfg.stratify_column
    if stratify is not None:
        if stratify not in available:
            raise ValueError(f"--stratify-column '{stratify}' not in {repo} columns {available}")
        renames[stratify] = "stratum"
    ds = ds.rename_columns(renames)

    # Everything outside the canonical schema is discarded.
    extras = [name for name in ds.column_names
              if name not in {"text", "document_id", "order", "stratum"}]
    if extras:
        ds = ds.remove_columns(extras)

    # Without an order column, fall back to the original row position.
    if "order" not in ds.column_names:
        ds = ds.add_column("order", list(range(len(ds))))

    if "source" in ds.column_names:
        ds = ds.remove_columns(["source"])
    ds = ds.add_column("source", [source_tag] * len(ds))

    def _namespace(batch):
        return {"document_id": [f"{source_tag}:{doc}" for doc in batch["document_id"]]}

    return ds.map(_namespace, batched=True, desc=f"namespace doc ids ({source_tag})")
|
|
|
|
|
def choose_heldout_docs(ds, cfg: Config):
    """Select whole documents for the validation and test sets.

    The number of validation (resp. test) documents is the rounded fraction of
    distinct document ids, capped at ``cfg.max_eval_docs_per_source``. The
    draw is seeded with ``cfg.split_seed`` and therefore deterministic.

    When ``ds`` has a ``stratum`` column, each stratum contributes a share
    proportional to its document count (at least one document), and the pooled
    picks are shuffled and cut back to the overall budget. Otherwise documents
    are drawn uniformly at random.

    Returns:
        ``(val_docs, test_docs)`` as two disjoint sets of document ids.
    """
    rng = np.random.default_rng(cfg.split_seed)
    cap = cfg.max_eval_docs_per_source
    distinct = len(set(ds["document_id"]))
    n_val = min(int(round(distinct * cfg.val_doc_fraction)), cap)
    n_test = min(int(round(distinct * cfg.test_doc_fraction)), cap)
    budget = n_val + n_test

    if "stratum" not in ds.column_names:
        # Sort first so the shuffle starts from a reproducible order.
        pool = sorted(set(ds["document_id"]))
        rng.shuffle(pool)
        chosen = pool[:budget]
    else:
        # A document's stratum is the one seen on its first row.
        stratum_of: Dict[str, str] = {}
        for doc, label in zip(ds["document_id"], ds["stratum"]):
            if doc not in stratum_of:
                stratum_of[doc] = str(label)
        members: Dict[str, List[str]] = {}
        for doc, label in stratum_of.items():
            members.setdefault(label, []).append(doc)

        chosen = []
        for label in sorted(members):
            group = sorted(members[label])
            rng.shuffle(group)
            quota = max(1, int(round(budget * len(group) / len(stratum_of))))
            chosen.extend(group[:quota])
        rng.shuffle(chosen)
        chosen = chosen[:budget]

    return set(chosen[:n_val]), set(chosen[n_val:n_val + n_test])
|
|
|
|
|
def representativeness_report(train_ds, val_ds, cfg: Config, out_dir: str):
    """Compare the held-out validation set against train on cheap signals.

    Two signals are reported: the paragraph character-length distribution
    (mean and selected percentiles) and, when both datasets carry a
    ``stratum`` column, per-stratum row proportions together with their total
    variation distance. The report is written to
    ``<out_dir>/eval_representativeness.json`` (UTF-8) and logged; warnings
    are emitted when strata are missing from validation or the distance
    exceeds 0.15.

    Returns:
        The report as a dict.
    """
    def char_lengths(ds):
        out: List[int] = []
        for batch in ds.iter(batch_size=10000):
            out.extend(len(t) for t in batch["text"])
        # A single 0 keeps mean/percentiles defined for an empty dataset.
        return np.asarray(out, dtype=np.int64) if out else np.array([0], dtype=np.int64)

    def pct(a):
        if len(a) == 0:
            return {}
        return {f"p{p}": float(np.percentile(a, p)) for p in (5, 25, 50, 75, 95)}

    tr_len, va_len = char_lengths(train_ds), char_lengths(val_ds)
    report = {
        "train_paragraphs": len(train_ds),
        "val_paragraphs": len(val_ds),
        "char_len_train": {"mean": float(tr_len.mean()), **pct(tr_len)},
        "char_len_val": {"mean": float(va_len.mean()), **pct(va_len)},
    }
    if "stratum" in train_ds.column_names and "stratum" in val_ds.column_names:
        def props(ds):
            c = Counter(str(s) for s in ds["stratum"])
            tot = sum(c.values()) or 1
            return {k: v / tot for k, v in c.items()}

        tr_p, va_p = props(train_ds), props(val_ds)
        keys = sorted(set(tr_p) | set(va_p))

        # Total variation distance between the two stratum distributions.
        tvd = 0.5 * sum(abs(tr_p.get(k, 0.0) - va_p.get(k, 0.0)) for k in keys)
        report["stratum_proportions_train"] = {k: round(tr_p.get(k, 0.0), 4) for k in keys}
        report["stratum_proportions_val"] = {k: round(va_p.get(k, 0.0), 4) for k in keys}
        report["stratum_total_variation_distance"] = round(tvd, 4)
        missing = [k for k in tr_p if k not in va_p]
        if missing:
            report["strata_absent_from_val"] = missing

    os.makedirs(out_dir, exist_ok=True)
    # ensure_ascii=False emits stratum names verbatim, so pin the file
    # encoding instead of relying on the platform default.
    with open(os.path.join(out_dir, "eval_representativeness.json"), "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)
    logger.info("Eval representativeness:\n%s", json.dumps(report, indent=2, ensure_ascii=False))
    if report.get("strata_absent_from_val"):
        logger.warning("Strata present in train but MISSING from validation: %s",
                       report["strata_absent_from_val"])
    if report.get("stratum_total_variation_distance", 0.0) > 0.15:
        logger.warning("Validation stratum distribution differs notably from train "
                       "(TVD=%.3f); consider --stratify-column or a larger eval cap.",
                       report["stratum_total_variation_distance"])
    return report
|
|
|
|
|
def dedup_rows(ds):
    """Keep only the first occurrence of each (document_id, order, text) row.

    Rows are compared through a 16-byte BLAKE2b digest of the three fields
    joined with the unit-separator character, so only digests (not the texts)
    are held in the ``seen`` set.

    Returns:
        ``ds.select(...)`` over the indices of the first occurrences, in
        their original order.
    """
    seen_digests = set()
    first_rows = []
    triples = zip(ds["document_id"], ds["order"], ds["text"])
    for row, (doc_id, position, text) in enumerate(triples):
        payload = f"{doc_id}\x1f{position}\x1f{text}".encode("utf-8")
        digest = hashlib.blake2b(payload, digest_size=16).digest()
        if digest in seen_digests:
            continue
        seen_digests.add(digest)
        first_rows.append(row)
    return ds.select(first_rows)
|
|
|
|
|
|
|
|
|
|
|
def maybe_nfc(text: str, cfg: Config) -> str:
    """Return ``text`` NFC-normalized when ``cfg.normalize_nfc`` is set, else unchanged."""
    if not cfg.normalize_nfc:
        return text
    return unicodedata.normalize("NFC", text)
|
|
|
|
|
def diagnose_tokenizer(texts: List[str], tokenizer, cfg: Config, out_dir: str):
    """Pre-flight check of tokenizer behaviour on a sample of raw paragraphs.

    At most ``cfg.tokenizer_sample_size`` paragraphs are sampled (seeded with
    ``cfg.split_seed``), optionally NFC-normalized, and tokenized. The report
    covers the unknown-token rate, tokens per word / per character, the
    tokens-per-paragraph distribution, individual characters that tokenize to
    the unknown token (from the first 5000 sampled texts), and up to 25
    paragraphs with more than 1.5 tokens per character.

    The report is written to ``<out_dir>/tokenizer_diagnostics.json`` (UTF-8)
    and returned as a dict.
    """
    unk = tokenizer.unk_token
    unk_id = tokenizer.unk_token_id
    rng = np.random.default_rng(cfg.split_seed)
    if len(texts) > cfg.tokenizer_sample_size:
        idx = rng.choice(len(texts), size=cfg.tokenizer_sample_size, replace=False)
        texts = [texts[i] for i in idx]
    # Normalize once up front; both passes below work on the same strings.
    texts = [maybe_nfc(t, cfg) for t in texts]

    total_tokens = 0
    total_unk = 0
    total_chars = 0
    total_words = 0
    tok_lengths: List[int] = []
    pathological: List[dict] = []

    for t in texts:
        toks = tokenizer.tokenize(t)
        n = len(toks)
        total_tokens += n
        total_unk += sum(1 for x in toks if x == unk)
        total_chars += len(t)
        total_words += max(1, len(t.split()))
        tok_lengths.append(n)
        # Flag paragraphs that explode into far more tokens than characters.
        if len(t) >= 20 and n / max(1, len(t)) > 1.5 and len(pathological) < 25:
            pathological.append({"chars": len(t), "tokens": n, "preview": t[:80]})

    # Probe every distinct non-space character on its own for <unk>.
    uniq_chars = set()
    for t in texts[:5000]:
        uniq_chars.update(t)
    risky_chars = {}
    for ch in uniq_chars:
        if ch.isspace():
            continue
        ids = tokenizer(ch, add_special_tokens=False)["input_ids"]
        if unk_id is not None and unk_id in ids:
            name = unicodedata.name(ch, "UNNAMED")
            risky_chars[ch] = {"codepoint": f"U+{ord(ch):04X}", "name": name}

    tok_lengths_arr = np.asarray(tok_lengths) if tok_lengths else np.array([0])
    report = {
        "sampled_paragraphs": len(texts),
        "normalize_nfc": cfg.normalize_nfc,
        "unk_token": unk,
        "unk_rate": round(total_unk / max(1, total_tokens), 6),
        "total_unk_tokens": total_unk,
        "fertility_tokens_per_word": round(total_tokens / max(1, total_words), 3),
        "fertility_tokens_per_char": round(total_tokens / max(1, total_chars), 3),
        "tokens_per_paragraph": {
            "mean": float(tok_lengths_arr.mean()),
            "p50": float(np.percentile(tok_lengths_arr, 50)),
            "p95": float(np.percentile(tok_lengths_arr, 95)),
            "max": int(tok_lengths_arr.max()),
            "share_over_max_seq_len": round(
                float((tok_lengths_arr > cfg.max_seq_length).mean()), 4),
        },
        "num_risky_unk_characters": len(risky_chars),
        "risky_unk_characters": dict(sorted(risky_chars.items())[:100]),
        "pathological_examples": pathological,
    }
    os.makedirs(out_dir, exist_ok=True)
    # The report contains raw glyphs (ensure_ascii=False), so the file
    # encoding must not depend on the platform default.
    with open(os.path.join(out_dir, "tokenizer_diagnostics.json"), "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)

    logger.info("Tokenizer diagnostics: unk_rate=%.4f%% fertility=%.2f tok/word "
                "risky_chars=%d", report["unk_rate"] * 100,
                report["fertility_tokens_per_word"], report["num_risky_unk_characters"])
    if report["unk_rate"] > 0.005:
        logger.warning("High <unk> rate (%.3f%%). Inspect risky_unk_characters in "
                       "tokenizer_diagnostics.json; consider --normalize-nfc or a "
                       "transliteration/cleanup pass for medieval glyphs.",
                       report["unk_rate"] * 100)
    if risky_chars:
        sample = ", ".join(f"{c} ({m['codepoint']})" for c, m in list(risky_chars.items())[:15])
        logger.warning("Characters mapping to <unk> (sample): %s", sample)
    return report
|
|
|
|
|
def tokenize_paragraphs(ds, tokenizer, cfg: Config):
    """Tokenize the ``text`` column into ``input_ids``.

    No special tokens are added and nothing is truncated; the ``text`` column
    is removed from the result. Text is NFC-normalized first when
    ``cfg.normalize_nfc`` is set. The map runs batched across
    ``cfg.preprocess_num_proc`` processes.
    """
    def _encode_batch(batch):
        paragraphs = batch["text"]
        if cfg.normalize_nfc:
            paragraphs = [maybe_nfc(p, cfg) for p in paragraphs]
        encoded = tokenizer(
            paragraphs,
            add_special_tokens=False,
            truncation=False,
            return_attention_mask=False,
        )
        return {"input_ids": encoded["input_ids"]}

    map_options = {
        "batched": True,
        "num_proc": cfg.preprocess_num_proc,
        "remove_columns": ["text"],
        "desc": "tokenize paragraphs",
    }
    return ds.map(_encode_batch, **map_options)
|
|
|
|
|
def pack_doc_sentences(ds, tokenizer, cfg: Config, desc: str) -> Dataset:
    """Greedily pack tokenized paragraphs into fixed-length sequences per document.

    Rows are sorted by ``document_id`` (then ``order`` when present) and their
    token ids are concatenated per document. Every ``max_seq_length - 2``
    tokens are emitted as one sequence wrapped in bos/eos; a sequence never
    spans two documents. The remainder at the end of a document is emitted
    only if it has at least ``cfg.min_chunk_tokens`` tokens, otherwise it is
    dropped.

    NOTE(review): rows sharing the same (document_id, order) — e.g. upsampled
    duplicates — sort next to each other and are therefore packed
    back-to-back; confirm this is intended when ``dedup_train`` is off.

    Raises:
        ValueError: if the tokenizer exposes neither bos/cls nor eos/sep ids.
    """
    bos = tokenizer.bos_token_id
    if bos is None:
        bos = tokenizer.cls_token_id
    eos = tokenizer.eos_token_id
    if eos is None:
        eos = tokenizer.sep_token_id
    if bos is None or eos is None:
        raise ValueError("Tokenizer has no bos/cls or eos/sep token; cannot pack.")
    max_content = cfg.max_seq_length - 2  # room left once bos and eos are added

    if "order" in ds.column_names:
        ds = ds.sort(["document_id", "order"])
    else:
        ds = ds.sort(["document_id"])

    def generator():
        pending: List[int] = []
        doc_now = None
        src_now = None

        def wrap(tokens):
            # Reads doc_now/src_now at call time, i.e. the document the
            # tokens were accumulated for.
            return {"input_ids": [bos] + tokens + [eos],
                    "document_id": doc_now, "source": src_now}

        for batch in ds.iter(batch_size=2000):
            rows = zip(batch["input_ids"], batch["document_id"], batch["source"])
            for ids, doc, src in rows:
                if doc != doc_now:
                    # Document boundary: emit or drop the previous tail.
                    if pending and len(pending) >= cfg.min_chunk_tokens:
                        yield wrap(pending)
                    pending = []
                    doc_now, src_now = doc, src
                pending.extend(ids)
                while len(pending) >= max_content:
                    head, pending = pending[:max_content], pending[max_content:]
                    yield wrap(head)

        if pending and len(pending) >= cfg.min_chunk_tokens:
            yield wrap(pending)

    packed = Dataset.from_generator(generator, cache_dir=cfg.cache_dir)
    logger.info(" packed %s -> %d sequences (max_len=%d)", desc, len(packed), cfg.max_seq_length)
    return packed
|
|
|
|
|
def diagnose_pretokenized(ds, tokenizer, cfg: Config, out_dir: str):
    """Sanity-check a sample of already-tokenized sequences.

    Up to ``cfg.tokenizer_sample_size`` sequences are sampled (seeded with
    ``cfg.split_seed``). The report covers the unknown-token rate, the length
    distribution, the share of sequences longer than ``cfg.max_seq_length``
    and how often a sequence starts with bos/cls and ends with eos/sep. It is
    written to ``<out_dir>/tokenizer_diagnostics.json`` and returned.
    """
    unk_id = tokenizer.unk_token_id
    bos_id = tokenizer.bos_token_id
    if bos_id is None:
        bos_id = tokenizer.cls_token_id
    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        eos_id = tokenizer.sep_token_id

    total = len(ds)
    n_sampled = min(cfg.tokenizer_sample_size, total)
    picks = np.random.default_rng(cfg.split_seed).choice(total, size=n_sampled, replace=False)
    subset = ds.select(picks)

    seq_lengths: List[int] = []
    token_count = 0
    unk_count = 0
    with_bos = 0
    with_eos = 0
    too_long = 0
    for ids in subset["input_ids"]:
        size = len(ids)
        seq_lengths.append(size)
        token_count += size
        if unk_id is not None:
            unk_count += ids.count(unk_id)
        if size:
            if ids[0] == bos_id:
                with_bos += 1
            if ids[-1] == eos_id:
                with_eos += 1
        if size > cfg.max_seq_length:
            too_long += 1

    # A single 0 keeps the statistics defined when nothing was sampled.
    arr = np.asarray(seq_lengths) if seq_lengths else np.array([0])
    denom = max(1, n_sampled)
    report = {
        "mode": "pretokenized",
        "sampled_sequences": int(n_sampled),
        "total_sequences": int(total),
        "unk_rate": round(unk_count / max(1, token_count), 6),
        "total_unk_tokens": int(unk_count),
        "seq_len": {
            "mean": float(arr.mean()),
            "p50": float(np.percentile(arr, 50)),
            "p95": float(np.percentile(arr, 95)),
            "max": int(arr.max()),
        },
        "share_over_max_seq_len": round(float(too_long) / denom, 4),
        "starts_with_bos_rate": round(with_bos / denom, 4),
        "ends_with_eos_rate": round(with_eos / denom, 4),
    }

    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "tokenizer_diagnostics.json"), "w") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)

    logger.info("Pretokenized diagnostics: unk_rate=%.4f%% mean_len=%.1f "
                "bos=%.0f%% eos=%.0f%%", report["unk_rate"] * 100, report["seq_len"]["mean"],
                report["starts_with_bos_rate"] * 100, report["ends_with_eos_rate"] * 100)
    if report["unk_rate"] > 0.005:
        logger.warning("High <unk> rate (%.3f%%) in the pre-tokenized data — verify it was "
                       "tokenized with the SAME tokenizer (%s).", report["unk_rate"] * 100, cfg.model_name)
    lacks_boundaries = report["starts_with_bos_rate"] < 0.5 or report["ends_with_eos_rate"] < 0.5
    if lacks_boundaries:
        logger.warning("Many sequences lack bos/eos boundaries; if your packing omitted special "
                       "tokens, MLM quality may suffer.")
    if report["share_over_max_seq_len"] > 0:
        logger.warning("%.2f%% of sequences exceed max_seq_length=%d and will be truncated.",
                       report["share_over_max_seq_len"] * 100, cfg.max_seq_length)
    return report
|
|
|
|
|
def _load_dataset_any(name: str, split: str, token: bool):
    """Load a dataset from a local ``save_to_disk`` directory or via ``load_dataset``.

    Args:
        name: Path of an existing local directory, or an identifier passed to
            ``load_dataset``.
        split: Split to return.
        token: Forwarded to ``load_dataset`` for non-directory names.

    Returns:
        The dataset for ``split``. When the directory holds a dict of splits,
        ``split`` is selected from it (``KeyError`` if absent); when it holds
        a single dataset, that dataset is returned as-is.
    """
    if not os.path.isdir(name):
        return load_dataset(name, split=split, token=token)

    from datasets import load_from_disk
    loaded = load_from_disk(name)
    # DatasetDict subclasses dict. It also exposes ``column_names``, so
    # testing for that attribute cannot tell it apart from a single Dataset.
    if isinstance(loaded, dict):
        return loaded[split]
    return loaded
|
|
|
|
|
def _ids_hash(ids) -> bytes:
    """Return a 12-byte BLAKE2b digest of a token-id sequence, encoded as int32."""
    raw = np.asarray(ids, dtype=np.int32).tobytes()
    hasher = hashlib.blake2b(digest_size=12)
    hasher.update(raw)
    return hasher.digest()
|
|
|
|
|
def build_pretokenized_bundle(cfg: Config, tokenizer, selected, cache_root: str):
    """Build train/validation/test datasets from already-tokenized corpora.

    Steps:
      1. Load each ``(repo, tag)`` in ``selected``, keeping only ``input_ids``
         and, when present, the attention mask, a document-id column and the
         stratify column; tag every row with its ``source``.
      2. Strip masked-out positions and cut sequences to ``cfg.max_seq_length``.
      3. Split. With a document-id column, whole documents are held out
         (stratified when a ``stratum`` column exists) and held-out sequences
         are de-duplicated per document; without one, distinct sequences are
         held out instead (weaker: no document-level leakage protection).
      4. Write ``eval_representativeness.json`` to ``cfg.output_dir`` and save
         the bundle under ``cache_root``.

    Note that, unlike the raw-text path, the eval-document cap is applied to
    the concatenation of all selected sources, not per source.

    Returns:
        Dict with ``train`` and, when non-empty, ``validation``,
        ``validation_med`` / ``validation_canon`` and ``test``.

    Exits the process (status 0) before saving when ``cfg.diagnose_only`` is set.
    """
    import shutil  # local import: only needed for the atomic cache write

    max_len = cfg.max_seq_length
    DOCID_CANDS = ("doc_id", "document_id", "docid")
    parts = []

    for repo, tag in selected:
        ds = _load_dataset_any(repo, cfg.dataset_split, True)
        docid_col = next((c for c in DOCID_CANDS if c in ds.column_names), None)

        keep = ["input_ids"]
        if "attention_mask" in ds.column_names:
            keep.append("attention_mask")
        if docid_col:
            keep.append(docid_col)

        strat_col = cfg.stratify_column
        if strat_col and strat_col in ds.column_names:
            keep.append(strat_col)

        ds = ds.remove_columns([c for c in ds.column_names if c not in keep])

        if docid_col and docid_col != "document_id":
            ds = ds.rename_column(docid_col, "document_id")
        if strat_col and strat_col in ds.column_names and strat_col != "stratum":
            ds = ds.rename_column(strat_col, "stratum")

        ds = ds.add_column("source", [tag] * len(ds))
        parts.append(ds)
        logger.info(" loaded pretokenized %s -> %d sequences (doc_id=%s)",
                    tag, len(ds), docid_col)

    full = concatenate_datasets(parts) if len(parts) > 1 else parts[0]
    has_doc = "document_id" in full.column_names

    if has_doc:
        # Prefix ids with the source tag so ids from different corpora cannot collide.
        full = full.map(
            lambda b: {"document_id": [f"{s}:{d}" for s, d in zip(b["source"], b["document_id"])]},
            batched=True, desc="namespace doc ids")

    has_mask = "attention_mask" in full.column_names

    def _fix(batch):
        out = []
        masks = batch["attention_mask"] if has_mask else [None] * len(batch["input_ids"])
        for ids, am in zip(batch["input_ids"], masks):
            if am is not None:
                # Drop positions the mask marks as padding.
                ids = [t for t, a in zip(ids, am) if a == 1]
            if len(ids) > max_len:
                ids = ids[:max_len]
            out.append(ids)
        return {"input_ids": out}

    full = full.map(_fix, batched=True, num_proc=cfg.preprocess_num_proc,
                    remove_columns=(["attention_mask"] if has_mask else []),
                    desc="normalize pretokenized")

    out_dir = cfg.output_dir
    if cfg.diagnose_tokenizer:
        diagnose_pretokenized(full, tokenizer, cfg, out_dir)

    rng = np.random.default_rng(cfg.split_seed)
    train_idx: List[int] = []
    val_idx: List[int] = []
    test_idx: List[int] = []

    if has_doc:
        docs = full["document_id"]
        unique_docs = sorted(set(docs))
        n_val = min(int(round(len(unique_docs) * cfg.val_doc_fraction)), cfg.max_eval_docs_per_source)
        n_test = min(int(round(len(unique_docs) * cfg.test_doc_fraction)), cfg.max_eval_docs_per_source)
        n_pick = n_val + n_test

        if "stratum" in full.column_names:
            logger.info("Applying STRATIFIED document split based on column '%s'.",
                        cfg.stratify_column)
            strata = full["stratum"]
            # A document's stratum is the one seen on its first row.
            doc_stratum: Dict[str, str] = {}
            for d, s in zip(docs, strata):
                doc_stratum.setdefault(d, str(s))
            by_stratum: Dict[str, List[str]] = {}
            for d, s in doc_stratum.items():
                by_stratum.setdefault(s, []).append(d)

            picked: List[str] = []
            for s, d_list in sorted(by_stratum.items()):
                d_list = sorted(d_list)
                rng.shuffle(d_list)
                k = max(1, int(round(n_pick * len(d_list) / len(doc_stratum))))
                picked.extend(d_list[:min(k, len(d_list))])
            rng.shuffle(picked)
            val_docs = set(picked[:n_val])
            test_docs = set(picked[n_val:n_val + n_test])
            split_kind = "document-level STRATIFIED group-aware"
        else:
            rng.shuffle(unique_docs)
            val_docs = set(unique_docs[:n_val])
            test_docs = set(unique_docs[n_val:n_val + n_test])
            split_kind = "document-level group-aware"

        val_cand, test_cand = [], []
        for i, d in enumerate(docs):
            if d in val_docs:
                val_cand.append((i, d))
            elif d in test_docs:
                test_cand.append((i, d))
            else:
                train_idx.append(i)

        # Held-out sets keep each distinct sequence once per document.
        for cand, dst in ((val_cand, val_idx), (test_cand, test_idx)):
            seen = set()
            for i, d in cand:
                key = (d, _ids_hash(full[i]["input_ids"]))
                if key not in seen:
                    seen.add(key)
                    dst.append(i)
    else:
        logger.warning("No doc_id/document_id column found in the pre-tokenized data; "
                       "falling back to a weaker sequence-level holdout.")
        hashes = [_ids_hash(ids) for ids in full["input_ids"]]
        uniq = sorted(set(hashes))
        rng.shuffle(uniq)
        n_val = min(int(round(len(uniq) * cfg.val_doc_fraction)), cfg.max_eval_docs_per_source * 4)
        n_test = min(int(round(len(uniq) * cfg.test_doc_fraction)), cfg.max_eval_docs_per_source * 4)
        val_h = set(uniq[:n_val])
        test_h = set(uniq[n_val:n_val + n_test])
        seen_v, seen_t = set(), set()
        for i, h in enumerate(hashes):
            if h in val_h:
                if h not in seen_v:
                    seen_v.add(h)
                    val_idx.append(i)
            elif h in test_h:
                if h not in seen_t:
                    seen_t.add(h)
                    test_idx.append(i)
            else:
                train_idx.append(i)
        split_kind = "sequence-level dedup"

    bundle: Dict[str, Dataset] = {"train": full.select(train_idx)}
    if val_idx:
        bundle["validation"] = full.select(val_idx)
        if len(set(bundle["validation"]["source"])) > 1:
            for tag in ("med", "canon"):
                sub = bundle["validation"].filter(lambda b: [s == tag for s in b["source"]], batched=True)
                if len(sub):
                    bundle[f"validation_{tag}"] = sub
    if test_idx:
        bundle["test"] = full.select(test_idx)

    logger.info("Pretokenized split (%s) -> train=%d val=%d test=%d",
                split_kind, len(bundle["train"]), len(val_idx), len(test_idx))

    def len_stats(d):
        ls = np.asarray([len(x) for x in d["input_ids"]]) if len(d) else np.array([0])
        return {"mean": float(ls.mean()), "p50": float(np.percentile(ls, 50)),
                "p95": float(np.percentile(ls, 95))}

    rep = {"split_kind": split_kind,
           "train_sequences": len(bundle["train"]), "val_sequences": len(val_idx),
           "len_train": len_stats(bundle["train"])}
    if val_idx:
        rep["len_val"] = len_stats(bundle["validation"])
    # out_dir is otherwise only created as a side effect of the optional
    # tokenizer diagnostics, so create it here before writing.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "eval_representativeness.json"), "w", encoding="utf-8") as fh:
        json.dump(rep, fh, indent=2)

    if cfg.diagnose_only:
        logger.info("--diagnose-only set: wrote diagnostics to %s, exiting before training.", out_dir)
        sys.exit(0)

    # Write into a staging directory and rename at the end, so an interrupted
    # save never leaves a half-written cache that a later run would load.
    staging_root = cache_root.rstrip(os.sep) + ".partial"
    shutil.rmtree(staging_root, ignore_errors=True)
    os.makedirs(staging_root, exist_ok=True)
    for name, d in bundle.items():
        d.save_to_disk(os.path.join(staging_root, name))
    shutil.rmtree(cache_root, ignore_errors=True)
    os.replace(staging_root, cache_root)
    logger.info("Saved pretokenized datasets to %s", cache_root)
    return bundle
|
|
|
|
|
def build_or_load_packed(cfg: Config, tokenizer):
    """Return the packed train/validation/test datasets, building them if needed.

    The result is cached under ``<cfg.cache_dir>/packed_<hash>``, where the
    hash covers every setting that influences the data (sources, model name,
    sequence length, split parameters, column overrides, normalization,
    corpus selection, train de-duplication). A cached bundle is reused unless
    ``cfg.diagnose_only`` is set or it lacks a ``train`` split.

    When the first selected dataset already has an ``input_ids`` column the
    work is delegated to ``build_pretokenized_bundle``. Otherwise each source
    is loaded, whole documents are held out per source, held-out rows are
    de-duplicated, diagnostics are written to ``cfg.output_dir``, and every
    split is tokenized and packed.

    Returns:
        Dict with ``train`` and, when non-empty, ``validation``,
        ``validation_med`` / ``validation_canon`` and ``test``.

    Raises:
        ValueError: if ``cfg.train_corpus`` is not combined|med|canon.

    Exits the process (status 0) after diagnostics when ``cfg.diagnose_only``
    is set.
    """
    import shutil  # local import: only needed for the atomic cache write

    sig = json.dumps({
        "med": cfg.med_dataset, "canon": cfg.canon_dataset, "split": cfg.dataset_split,
        "model": cfg.model_name, "max_len": cfg.max_seq_length, "min_chunk": cfg.min_chunk_tokens,
        "val_frac": cfg.val_doc_fraction, "test_frac": cfg.test_doc_fraction,
        "max_eval": cfg.max_eval_docs_per_source, "split_seed": cfg.split_seed,
        "text_col": cfg.text_column, "doc_col": cfg.doc_id_column, "order_col": cfg.order_column,
        "stratify": cfg.stratify_column, "nfc": cfg.normalize_nfc,
        "train_corpus": cfg.train_corpus, "dedup_train": cfg.dedup_train,
    }, sort_keys=True)
    key = hashlib.blake2b(sig.encode("utf-8"), digest_size=12).hexdigest()
    cache_root = os.path.join(cfg.cache_dir, f"packed_{key}")

    if os.path.isdir(cache_root) and not cfg.diagnose_only:
        from datasets import load_from_disk
        logger.info("Loading cached packed datasets from %s", cache_root)
        bundle = {name: load_from_disk(os.path.join(cache_root, name))
                  for name in os.listdir(cache_root)
                  if os.path.isdir(os.path.join(cache_root, name))}
        if "train" in bundle:
            return bundle
        # Left behind by an interrupted save; fall through and rebuild.
        logger.warning("Cached bundle at %s has no 'train' split; rebuilding.", cache_root)

    if cfg.train_corpus not in {"combined", "med", "canon"}:
        raise ValueError(f"--train-corpus must be combined|med|canon, got {cfg.train_corpus!r}")

    selected = []
    if cfg.train_corpus in {"combined", "med"}:
        selected.append((cfg.med_dataset, "med"))
    if cfg.train_corpus in {"combined", "canon"}:
        selected.append((cfg.canon_dataset, "canon"))

    # Peek at the first source to decide between the raw-text and the
    # pre-tokenized pipeline.
    probe = _load_dataset_any(selected[0][0], cfg.dataset_split, True)
    if "input_ids" in probe.column_names:
        logger.warning(
            "Dataset %s is PRE-TOKENIZED (columns=%s). Skipping raw-text tokenization "
            "and DOC-SENTENCES packing.", selected[0][0], probe.column_names)
        del probe
        return build_pretokenized_bundle(cfg, tokenizer, selected, cache_root)
    del probe

    sources = []
    if cfg.train_corpus in {"combined", "med"}:
        sources.append(load_source(cfg.med_dataset, cfg.dataset_split, "med", cfg))
    if cfg.train_corpus in {"combined", "canon"}:
        sources.append(load_source(cfg.canon_dataset, cfg.dataset_split, "canon", cfg))
    logger.info("Training corpus = %s (%d source dataset(s))", cfg.train_corpus, len(sources))

    train_parts, val_parts, test_parts = [], [], []
    for src in sources:
        tag = src["source"][0]
        val_docs, test_docs = choose_heldout_docs(src, cfg)
        in_val = src.filter(lambda b: [d in val_docs for d in b["document_id"]],
                            batched=True, desc=f"select val ({tag})")
        in_test = src.filter(lambda b: [d in test_docs for d in b["document_id"]],
                             batched=True, desc=f"select test ({tag})")
        held = val_docs | test_docs
        in_train = src.filter(lambda b: [d not in held for d in b["document_id"]],
                              batched=True, desc=f"select train ({tag})")
        train_parts.append(in_train)
        # Held-out rows are always de-duplicated so each paragraph counts once.
        if len(in_val):
            val_parts.append(dedup_rows(in_val))
        if len(in_test):
            test_parts.append(dedup_rows(in_test))

    train_raw = concatenate_datasets(train_parts)
    if cfg.dedup_train:
        before = len(train_raw)
        train_raw = dedup_rows(train_raw)
        logger.info("dedup_train: collapsed %d -> %d training paragraphs "
                    "(upsampling duplicates removed).", before, len(train_raw))
    val_raw = concatenate_datasets(val_parts) if val_parts else None
    logger.info("Raw paragraph counts -> train=%d val=%d test=%d",
                len(train_raw),
                sum(len(p) for p in val_parts),
                sum(len(p) for p in test_parts))

    out_dir = cfg.output_dir
    if cfg.diagnose_tokenizer:
        n = len(train_raw)
        m = min(cfg.tokenizer_sample_size, n)
        idx = np.random.default_rng(cfg.split_seed).choice(n, size=m, replace=False)
        sample_texts = train_raw.select(idx)["text"]
        diagnose_tokenizer(sample_texts, tokenizer, cfg, out_dir)
    if val_raw is not None:
        representativeness_report(train_raw, val_raw, cfg, out_dir)

    if cfg.diagnose_only:
        logger.info("--diagnose-only set: wrote diagnostics to %s, exiting before "
                    "packing/training.", out_dir)
        sys.exit(0)

    bundle: Dict[str, Dataset] = {}
    bundle["train"] = pack_doc_sentences(tokenize_paragraphs(train_raw, tokenizer, cfg),
                                         tokenizer, cfg, "train")
    if val_raw is not None:
        bundle["validation"] = pack_doc_sentences(tokenize_paragraphs(val_raw, tokenizer, cfg),
                                                  tokenizer, cfg, "validation")
        if len(sources) > 1:
            # Per-source validation sets, packed separately.
            for tag in ("med", "canon"):
                sub = val_raw.filter(lambda b: [s == tag for s in b["source"]], batched=True)
                if len(sub):
                    bundle[f"validation_{tag}"] = pack_doc_sentences(
                        tokenize_paragraphs(sub, tokenizer, cfg), tokenizer, cfg, f"validation_{tag}")
    if test_parts:
        test_raw = concatenate_datasets(test_parts)
        bundle["test"] = pack_doc_sentences(tokenize_paragraphs(test_raw, tokenizer, cfg),
                                            tokenizer, cfg, "test")

    # Write into a staging directory and rename at the end, so an interrupted
    # save never leaves a partial cache that the loader above would trust.
    staging_root = cache_root + ".partial"
    shutil.rmtree(staging_root, ignore_errors=True)
    os.makedirs(staging_root, exist_ok=True)
    for name, d in bundle.items():
        d.save_to_disk(os.path.join(staging_root, name))
    shutil.rmtree(cache_root, ignore_errors=True)
    os.replace(staging_root, cache_root)
    logger.info("Saved packed datasets to %s", cache_root)
    return bundle
|
|
|
|
|
|
|
|
|
|
|
def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted ids (argmax over the last dimension).

    When ``logits`` is a tuple, only its first element is used. ``labels`` is
    accepted but not used.
    """
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)
|
|
|
|
|
def compute_metrics(eval_pred):
    """Compute accuracy over the positions whose label is not ``-100``.

    Args:
        eval_pred: A ``(predictions, labels)`` pair; both are flattened with
            ``reshape(-1)`` before comparison.

    Returns:
        ``{"masked_accuracy": value}`` where ``value`` is the fraction of
        scored positions predicted correctly, or ``0.0`` when no position is
        scored.
    """
    predictions, references = eval_pred
    flat_refs = references.reshape(-1)
    flat_preds = predictions.reshape(-1)

    # Positions labelled -100 are excluded from scoring.
    scored = flat_refs != -100
    n_scored = scored.sum()
    if n_scored == 0:
        return {"masked_accuracy": 0.0}

    n_hits = (flat_preds[scored] == flat_refs[scored]).sum()
    return {"masked_accuracy": float(n_hits) / float(n_scored)}
|
|
|
|
|
class PerplexityCallback(TrainerCallback):
    """Add a perplexity entry next to every evaluation loss in ``metrics``.

    For each metric key that ends in ``loss`` and contains ``eval``, a sibling
    key is written in which only the trailing ``loss`` is swapped for
    ``perplexity`` (``eval_loss`` -> ``eval_perplexity``).
    """

    # The loss is clamped to this value before exponentiation, so the reported
    # perplexity saturates at exp(20) and math.exp can never overflow.
    _MAX_LOSS = 20
    _LOSS_SUFFIX = "loss"

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Insert ``*perplexity`` values into ``metrics`` in place."""
        if not metrics:
            return
        # Snapshot the keys: the dict is extended while we walk over it.
        for key in list(metrics):
            if not (key.endswith(self._LOSS_SUFFIX) and "eval" in key):
                continue
            # Replace only the suffix; str.replace would also rewrite any
            # earlier "loss" inside the key (e.g. "eval_glossary_loss").
            ppl_key = key[: -len(self._LOSS_SUFFIX)] + "perplexity"
            metrics[ppl_key] = math.exp(min(metrics[key], self._MAX_LOSS))
|
|
|
|
|
class JsonlLoggingCallback(TrainerCallback):
    """Append every log record to a JSON-lines file.

    Each record is the logged dict plus ``step``, ``epoch`` and a
    second-resolution ``wall_time``. The file is flushed and fsync'ed after
    every line.
    """

    def __init__(self, path: str):
        """Remember ``path`` and create its parent directory if it has one."""
        self.path = path
        parent = os.path.dirname(path)
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError; the current directory needs no creation.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Write one JSON line for ``logs``; do nothing when it is empty."""
        if not logs:
            return
        record = dict(logs)
        record["step"] = state.global_step
        record["epoch"] = state.epoch
        record["wall_time"] = datetime.now().isoformat(timespec="seconds")
        with open(self.path, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(record, ensure_ascii=False) + "\n")
            # Push the line to disk immediately rather than leaving it buffered.
            fh.flush()
            os.fsync(fh.fileno())
|
|
|
|
|
|
|
|
|
|
|
def capture_environment(cfg: Config) -> dict:
    """Collect a snapshot of the software and GPU environment.

    Args:
        cfg: Run configuration; accepted but not read by this function.

    Returns:
        A dict that always holds ``timestamp``, ``python``, ``platform``,
        ``torch``, ``transformers``, ``datasets`` and ``cuda_available``.
        When CUDA is available it also holds ``cuda``, ``gpu_name``,
        ``gpu_total_memory_gb``, ``gpu_capability``, ``bf16_supported``,
        ``torch_arch_list`` and ``gpu_arch_supported_by_torch`` for device 0.
    """
    import transformers
    import datasets as ds_lib

    snapshot = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "python": platform.python_version(),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "transformers": transformers.__version__,
        "datasets": ds_lib.__version__,
        "cuda_available": torch.cuda.is_available(),
    }
    if not snapshot["cuda_available"]:
        return snapshot

    snapshot["cuda"] = torch.version.cuda
    snapshot["gpu_name"] = torch.cuda.get_device_name(0)
    device = torch.cuda.get_device_properties(0)
    snapshot["gpu_total_memory_gb"] = round(device.total_memory / 1024**3, 1)
    snapshot["gpu_capability"] = f"{device.major}.{device.minor}"
    snapshot["bf16_supported"] = torch.cuda.is_bf16_supported()

    # Fall back to an empty list if the arch list cannot be queried.
    try:
        built_archs = torch.cuda.get_arch_list()
    except Exception:
        built_archs = []
    snapshot["torch_arch_list"] = built_archs

    # True only when the exact "sm_<major><minor>" tag is in the arch list.
    snapshot["gpu_arch_supported_by_torch"] = (
        f"sm_{device.major}{device.minor}" in built_archs
    )
    return snapshot
|
|
|
|
|
|
|
|
|
|
|
def parse_args() -> Config:
    """Build a command-line parser from the ``Config`` fields and parse argv.

    Every dataclass field becomes a ``--dashed-name`` option whose default is
    the value on a freshly constructed ``Config``:

    * ``bool`` defaulting to ``True``  -> ``--no-<name>`` (store_false)
    * ``bool`` defaulting to ``False`` -> ``--<name>`` (store_true)
    * ``None`` default                 -> ``--<name>`` parsed as ``str``
    * anything else                    -> ``--<name>`` parsed with the type
      of the default value

    Returns:
        A ``Config`` constructed from the parsed arguments.
    """
    defaults = Config()
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    for field_name in defaults.__dataclass_fields__:
        dashed = field_name.replace("_", "-")
        value = getattr(defaults, field_name)

        if isinstance(value, bool):
            # Boolean switches only ever flip the default.
            if value:
                parser.add_argument("--no-" + dashed, dest=field_name, action="store_false")
            else:
                parser.add_argument("--" + dashed, dest=field_name, action="store_true")
            continue

        converter = str if value is None else type(value)
        parser.add_argument("--" + dashed, dest=field_name, default=value, type=converter)

    parsed = parser.parse_args()
    return Config(**vars(parsed))
|
|
|
|
|
|
|
|
|
|
|
def _keep_only_input_ids(ds):
    """Return ``ds`` with every column except ``input_ids`` removed."""
    return ds.remove_columns([c for c in ds.column_names if c != "input_ids"])


def _evaluate_with_perplexity(trainer, ds, prefix):
    """Evaluate ``ds`` and add ``<prefix>_perplexity`` to the returned metrics.

    The perplexity is ``exp`` of ``<prefix>_loss`` with the loss clamped to 20.
    """
    metrics = trainer.evaluate(eval_dataset=ds, metric_key_prefix=prefix)
    metrics[f"{prefix}_perplexity"] = math.exp(min(metrics[f"{prefix}_loss"], 20))
    return metrics


def main():
    """Run the pipeline end to end from the command line.

    Steps, in order: parse the config and seed; enable TF32; create the
    output directory and attach a ``train.log`` handler; align ``save_steps``
    with ``eval_steps``; write ``run_metadata.json``; build or load the packed
    datasets and write ``dataset_stats.json``; return early on ``--dry-run``;
    otherwise train with ``Trainer``, save model / tokenizer / state, and run
    the final evaluations (validation, per-source validation, test).

    Raises:
        BaseException: whatever ``trainer.train`` raised, re-raised after the
            traceback was written and an emergency save was attempted.
    """
    cfg = parse_args()
    set_seed(cfg.seed)

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")

    os.makedirs(cfg.output_dir, exist_ok=True)

    # Explicit UTF-8: log messages contain non-ASCII characters (e.g. the em
    # dash below), which a non-UTF-8 locale encoding could fail to write.
    file_handler = logging.FileHandler(
        os.path.join(cfg.output_dir, "train.log"), encoding="utf-8")
    file_handler.setFormatter(logging.Formatter(
        "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s", "%Y-%m-%d %H:%M:%S"))
    logging.getLogger().addHandler(file_handler)

    # Align save_steps with eval_steps BEFORE the config is written to
    # run_metadata.json, so the metadata records the value actually used.
    # A non-positive eval_steps is skipped here (it would raise
    # ZeroDivisionError) and left to TrainingArguments to accept or reject.
    if cfg.eval_steps > 0 and cfg.save_steps % cfg.eval_steps != 0:
        new_save = max(cfg.eval_steps, (cfg.save_steps // cfg.eval_steps) * cfg.eval_steps)
        logger.warning("save_steps (%d) not a multiple of eval_steps (%d); adjusting to %d.",
                       cfg.save_steps, cfg.eval_steps, new_save)
        cfg.save_steps = new_save

    env = capture_environment(cfg)
    logger.info("Environment:\n%s", json.dumps(env, indent=2))
    with open(os.path.join(cfg.output_dir, "run_metadata.json"), "w", encoding="utf-8") as fh:
        json.dump({"config": cfg.__dict__, "environment": env}, fh, indent=2)

    if not env["cuda_available"]:
        logger.warning("CUDA not available — this script is intended to run on a GPU server.")
    else:
        if not env.get("bf16_supported", False):
            logger.warning("bf16 not reported as supported on this GPU; consider fp16.")
        if env.get("torch_arch_list") and not env.get("gpu_arch_supported_by_torch", True):
            logger.error(
                "This torch build (%s, arch_list=%s) has NO kernels for your GPU "
                "(capability %s). On Blackwell (RTX PRO 6000) install a cu128+ build, e.g.:\n"
                " pip install --upgrade torch --index-url https://download.pytorch.org/whl/cu128\n"
                "CUDA ops will fail until this matches.",
                env["torch"], env.get("torch_arch_list"), env.get("gpu_capability"))

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
    tokenizer.model_max_length = cfg.max_seq_length

    bundle = build_or_load_packed(cfg, tokenizer)
    train_ds = bundle["train"]
    eval_ds = bundle.get("validation")

    def tok_count(d):
        """Return ``(total_tokens, mean_sequence_length)`` for dataset ``d``."""
        # An empty split has nothing to average; report zeros instead of NaN.
        if len(d) == 0:
            return 0, 0.0
        lengths = d.map(lambda b: {"len": [len(x) for x in b["input_ids"]]},
                        batched=True, remove_columns=d.column_names, desc="count tokens")["len"]
        return int(sum(lengths)), float(np.mean(lengths))

    stats = {}
    for name, d in bundle.items():
        n_tok, mean_len = tok_count(d)
        stats[name] = {"sequences": len(d), "tokens": n_tok, "mean_seq_len": round(mean_len, 1)}
    logger.info("Packed dataset statistics:\n%s", json.dumps(stats, indent=2))
    with open(os.path.join(cfg.output_dir, "dataset_stats.json"), "w", encoding="utf-8") as fh:
        json.dump(stats, fh, indent=2)

    if cfg.dry_run:
        logger.info("--dry-run set: datasets built and cached, exiting before training.")
        return

    # Only input_ids is handed to the Trainer; every other column is dropped.
    train_ds = _keep_only_input_ids(train_ds)
    if eval_ds is not None:
        eval_ds = _keep_only_input_ids(eval_ds)

    model = AutoModelForMaskedLM.from_pretrained(
        cfg.model_name,
        attn_implementation=cfg.attn_implementation,
    )
    if cfg.gradient_checkpointing:
        model.config.use_cache = False
    logger.info("Model parameters: %.1fM", sum(p.numel() for p in model.parameters()) / 1e6)

    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=cfg.mlm_probability,
        pad_to_multiple_of=8,
    )

    report_to = ["tensorboard"]
    if cfg.report_to_wandb:
        report_to.append("wandb")

    # Evaluation, best-model selection and early stopping all require a
    # validation split.
    do_eval = eval_ds is not None
    args = TrainingArguments(
        output_dir=cfg.output_dir,
        seed=cfg.seed,
        data_seed=cfg.seed,
        num_train_epochs=cfg.num_train_epochs,
        max_steps=cfg.max_steps,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        adam_beta1=cfg.adam_beta1,
        adam_beta2=cfg.adam_beta2,
        adam_epsilon=cfg.adam_epsilon,
        max_grad_norm=cfg.max_grad_norm,
        warmup_ratio=cfg.warmup_ratio,
        lr_scheduler_type=cfg.lr_scheduler_type,
        optim=cfg.optim,
        bf16=bool(env.get("bf16_supported")),
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        per_device_eval_batch_size=cfg.per_device_eval_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        gradient_checkpointing=cfg.gradient_checkpointing,
        auto_find_batch_size=cfg.auto_find_batch_size,
        torch_compile=cfg.torch_compile,
        dataloader_num_workers=cfg.dataloader_num_workers,
        dataloader_pin_memory=True,
        eval_accumulation_steps=cfg.eval_accumulation_steps,
        eval_strategy="steps" if do_eval else "no",
        eval_steps=cfg.eval_steps if do_eval else None,
        logging_strategy="steps",
        logging_steps=cfg.logging_steps,
        logging_first_step=True,
        save_strategy="steps",
        save_steps=cfg.save_steps,
        save_total_limit=cfg.save_total_limit,
        load_best_model_at_end=do_eval,
        metric_for_best_model="eval_loss" if do_eval else None,
        greater_is_better=False,
        report_to=report_to,
        run_name=os.path.basename(cfg.output_dir.rstrip("/")),
        include_num_input_tokens_seen=True,
        logging_dir=os.path.join(cfg.output_dir, "tb"),
    )

    callbacks = [PerplexityCallback(),
                 JsonlLoggingCallback(os.path.join(cfg.output_dir, "training_log.jsonl"))]
    if do_eval and cfg.early_stopping_patience > 0:
        callbacks.append(EarlyStoppingCallback(
            early_stopping_patience=cfg.early_stopping_patience,
            early_stopping_threshold=cfg.early_stopping_threshold))

    # The keyword used to pass the tokenizer depends on the module-level
    # _USE_PROCESSING_CLASS flag.
    tok_kwarg = {"processing_class": tokenizer} if _USE_PROCESSING_CLASS else {"tokenizer": tokenizer}
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=collator,
        compute_metrics=compute_metrics if do_eval else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics if do_eval else None,
        callbacks=callbacks,
        **tok_kwarg,
    )

    resume_from = None
    if cfg.resume:
        last = get_last_checkpoint(cfg.output_dir)
        if last:
            logger.info("Resuming from checkpoint %s", last)
            resume_from = last

    try:
        train_result = trainer.train(resume_from_checkpoint=resume_from)
    except BaseException as exc:
        # BaseException (not Exception) so that Ctrl-C / SystemExit also leave
        # a traceback and an emergency save; the error is always re-raised.
        tb_path = os.path.join(cfg.output_dir, "crash_traceback.txt")
        with open(tb_path, "w", encoding="utf-8") as fh:
            fh.write(f"Crashed at {datetime.now().isoformat()}\n\n")
            traceback.print_exc(file=fh)
        logger.error("Training crashed (%s: %s). Traceback written to %s",
                     type(exc).__name__, exc, tb_path)
        try:
            emergency = os.path.join(cfg.output_dir, "emergency_checkpoint")
            trainer.save_model(emergency)
            trainer.save_state()
            logger.error("Emergency model state saved to %s", emergency)
        except Exception as save_exc:
            # Best effort only: a failed save must not mask the original error.
            logger.error("Emergency save also failed: %s", save_exc)
        raise

    trainer.save_model()
    tokenizer.save_pretrained(cfg.output_dir)
    trainer.save_state()

    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)

    final = {}
    if eval_ds is not None:
        final.update(_evaluate_with_perplexity(trainer, eval_ds, "final_val"))
        # Per-source validation subsets, when the bundle contains them.
        for tag in ("med", "canon"):
            sub = bundle.get(f"validation_{tag}")
            if sub is not None:
                final.update(_evaluate_with_perplexity(
                    trainer, _keep_only_input_ids(sub), f"final_val_{tag}"))
    if "test" in bundle:
        final.update(_evaluate_with_perplexity(
            trainer, _keep_only_input_ids(bundle["test"]), "final_test"))

    if final:
        logger.info("Final evaluation:\n%s", json.dumps(final, indent=2))
        trainer.save_metrics("final_eval", final)
    logger.info("Done. Best model + tokenizer saved to %s", cfg.output_dir)
|
|
|
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()