MediBerta / train.py
mschonhardt's picture
Add files using upload-large-folder tool
6b4288c verified
Raw
History Blame Contribute Delete
52.2 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Domain-Adaptive Pretraining (DAPT) of xlm-roberta-large for Medieval Latin.
Pipeline
--------
1. Load MedBerta (baseline) + CanonBerta (already upsampled) from the HF Hub.
2. Carve a *group-aware* held-out set: whole documents are held out (by
``document_id``) so no document leaks across train / validation / test.
This matters for an honest perplexity and for downstream Loci-Similes /
text-reuse evaluation.
3. DOC-SENTENCES packing: paragraphs are grouped by ``document_id`` (kept in
document order), tokenized, and greedily packed into fixed 512-token
sequences that never cross a document boundary.
4. MLM continued pretraining with the HF Trainer, bf16, SDPA attention,
in-training evaluation (loss + perplexity + masked-token accuracy),
best-model selection, and rich logging for later write-up.
The script is built for a single NVIDIA RTX PRO 6000 Blackwell (96 GB) and is
deliberately defensive about OOM (auto batch-size search, expandable CUDA
segments, optional gradient checkpointing, eval logit reduction).
Authentication
--------------
The datasets are private. Log in once before running:
huggingface-cli login # or: export HF_TOKEN=hf_...
Quick check before spending GPU time (builds + caches datasets, prints stats,
no training):
python dapt_xlmr_pretrain.py --dry-run
Typical run:
python dapt_xlmr_pretrain.py \
--output-dir runs/dapt_xlmr_medlatin_v1 \
--num-train-epochs 3 \
--per-device-train-batch-size 32 \
--gradient-accumulation-steps 8
"""
import argparse
import hashlib
import inspect
import json
import logging
import math
import os
import platform
import sys
import traceback
import unicodedata
from collections import Counter
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional
# Environment knobs that only take effect when present before torch and the
# tokenizers library are imported (those imports follow right below).
# setdefault keeps any value the user already exported.
#   PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True -> less allocator
#       fragmentation, a common cause of "phantom" OOM on long runs.
#   TOKENIZERS_PARALLELISM=false -> avoids fork warnings from the tokenizers
#       library when worker processes fork.
for _env_key, _env_default in (
    ("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"),
    ("TOKENIZERS_PARALLELISM", "false"),
):
    os.environ.setdefault(_env_key, _env_default)
del _env_key, _env_default
import numpy as np
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from transformers import (
AutoModelForMaskedLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
EarlyStoppingCallback,
Trainer,
TrainerCallback,
TrainingArguments,
set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
# Newer transformers releases accept `processing_class=` in Trainer.__init__;
# older ones only know the legacy keyword (presumably `tokenizer=`; the Trainer
# call site is outside this view). Probe the signature once at import time.
try:
    _trainer_init_params = inspect.signature(Trainer.__init__).parameters
except Exception:  # extremely defensive; fall back to legacy name
    _trainer_init_params = {}
_USE_PROCESSING_CLASS = "processing_class" in _trainer_init_params
del _trainer_init_params

# Root logging goes to stdout with timestamped, level-aligned lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("dapt")
# --------------------------------------------------------------------------- #
# Config
# --------------------------------------------------------------------------- #
@dataclass
class Config:
    """Every setting for data preparation, DAPT training and run bookkeeping.

    Field names, order and defaults are the interface other code relies on
    (presumably populated from the argparse flags defined elsewhere in this
    file -- TODO confirm against the CLI definition).
    """
    # Data
    med_dataset: str = "mschonhardt/MedBerta"  # HF dataset repo id; rows tagged source="med"
    canon_dataset: str = "mschonhardt/CanonBerta"  # HF dataset repo id; rows tagged source="canon"
    dataset_split: str = "train"  # split loaded from each corpus; val/test are carved out of it
    text_column: Optional[str] = None  # auto-detected from TEXT_CANDIDATES if None
    doc_id_column: Optional[str] = None  # auto-detected from DOC_CANDIDATES if None
    order_column: Optional[str] = None  # optional within-doc ordering key; auto-detected, else row index
    # Model / tokenizer
    model_name: str = "xlm-roberta-large"
    max_seq_length: int = 512  # packed sequence length including the two boundary tokens
    mlm_probability: float = 0.15  # masking rate, presumably for the MLM collator (used outside this view)
    attn_implementation: str = "sdpa"  # robust on Blackwell; "flash_attention_2" optional
    min_chunk_tokens: int = 64  # packing drops doc remainders (and whole docs) shorter than this
    # Held-out construction (group-aware, by document_id)
    val_doc_fraction: float = 0.01  # share of documents held out for validation
    test_doc_fraction: float = 0.0  # set >0 to also hold out a test set
    max_eval_docs_per_source: int = 400  # cap eval size to keep eval fast
    # Column used to stratify the held-out document sample (e.g. "genre" /
    # "subcorpus"); None disables stratification. Raw-text corpora must contain
    # it (load_source raises otherwise); pre-tokenized corpora use it if present.
    stratify_column: Optional[str] = "category"
    split_seed: int = 13  # seeds held-out sampling and diagnostic subsampling
    # Which corpus to train ON. "combined" = ONE model on MedBerta + upsampled
    # CanonBerta (the main methodology). "med"/"canon" = a single-corpus
    # model, e.g. for an ablation or two separate deliverables. Held-out eval is
    # still reported per-source where both are present.
    train_corpus: str = "combined"  # {"combined", "med", "canon"}
    dedup_train: bool = False  # collapse exact (document_id, order, text) duplicates in TRAIN
    # (recommended for a standalone canon model)
    # Pre-flight diagnostics & text handling
    diagnose_tokenizer: bool = True  # run UNK / fertility report before training
    tokenizer_sample_size: int = 20000  # paragraphs (or sequences) sampled for the diagnostic
    normalize_nfc: bool = False  # apply Unicode NFC before tokenizing
    diagnose_only: bool = False  # run diagnostics + reports, then exit
    # Optimisation (RoBERTa-style defaults, tuned for *continued* pretraining)
    learning_rate: float = 1e-4  # peak LR for the combined DAPT corpus
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.98
    adam_epsilon: float = 1e-6
    max_grad_norm: float = 1.0
    warmup_ratio: float = 0.06
    lr_scheduler_type: str = "linear"
    num_train_epochs: float = 10.0  # ceiling; early stopping ends training
    max_steps: int = -1
    optim: str = "adamw_torch_fused"  # fused AdamW: faster on Blackwell
    early_stopping_patience: int = 5  # eval rounds w/o improvement before stop
    early_stopping_threshold: float = 1e-4  # min eval_loss delta to count as improvement
    # Throughput / memory (96 GB GDDR7 -> large effective batch)
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    gradient_accumulation_steps: int = 16  # effective batch ~1024 sequences
    gradient_checkpointing: bool = False
    auto_find_batch_size: bool = True
    torch_compile: bool = False  # enable for the final fixed-batch run
    dataloader_num_workers: int = 8
    eval_accumulation_steps: int = 50  # offload eval tensors to CPU periodically
    # Schedule of eval / logging / checkpointing
    eval_steps: int = 500
    logging_steps: int = 50
    save_steps: int = 500
    save_total_limit: int = 3
    # Bookkeeping
    output_dir: str = "runs/dapt_xlmr_medlatin"  # run dir; diagnostics JSON files are written here too
    cache_dir: str = ".cache_packed"  # holds packed_<hash>/ bundles and Dataset.from_generator caches
    seed: int = 42  # presumably the training seed (used outside this view)
    report_to_wandb: bool = False
    preprocess_num_proc: int = 16  # worker processes for Dataset.map during tokenization/normalisation
    resume: bool = True  # presumably resume from the last checkpoint (used outside this view)
    dry_run: bool = False  # per the module docstring: build + cache datasets, no training
# --------------------------------------------------------------------------- #
# Column auto-detection
# --------------------------------------------------------------------------- #
# Column-name candidates for raw-text corpora, tried in priority order by
# _pick when the corresponding Config.*_column is None; the first name present
# in the dataset wins. (The pre-tokenized path uses its own doc-id list.)
TEXT_CANDIDATES = ["text", "paragraph", "content", "sentence", "passage"]
DOC_CANDIDATES = ["document_id", "doc_id", "docid", "document", "work_id", "work"]
ORDER_CANDIDATES = ["order", "paragraph_id", "par_id", "index", "idx", "n", "position", "seq"]
def _pick(columns: List[str], candidates: List[str], explicit: Optional[str], kind: str) -> Optional[str]:
if explicit is not None:
if explicit not in columns:
raise ValueError(f"Requested {kind} column '{explicit}' not in dataset columns {columns}")
return explicit
for c in candidates:
if c in columns:
return c
return None
# --------------------------------------------------------------------------- #
# Dataset loading & group-aware split
# --------------------------------------------------------------------------- #
def load_source(repo: str, split: str, source_tag: str, cfg: Config):
    """Load one raw-text corpus and reduce it to the canonical paragraph schema.

    Output columns: ``text``, ``document_id`` (prefixed ``"<source_tag>:"``),
    ``order``, ``source`` (= ``source_tag``) and, when ``cfg.stratify_column``
    is set, ``stratum``. Every other incoming column is dropped.

    Args:
        repo: HF Hub dataset repo id, loaded with ``token=True``.
        split: split name passed to ``load_dataset``.
        source_tag: provenance tag ("med" / "canon") stamped on every row.
        cfg: supplies explicit column names and ``stratify_column``.

    Raises:
        ValueError: if the text or document-id column cannot be resolved, an
            explicitly requested column is missing, or the stratify column is
            absent from this corpus.
    """
    logger.info("Loading %s (split=%s) ...", repo, split)
    ds = load_dataset(repo, split=split, token=True)
    cols = ds.column_names
    # Explicit cfg.*_column values win; otherwise the first matching candidate.
    text_col = _pick(cols, TEXT_CANDIDATES, cfg.text_column, "text")
    doc_col = _pick(cols, DOC_CANDIDATES, cfg.doc_id_column, "document_id")
    order_col = _pick(cols, ORDER_CANDIDATES, cfg.order_column, "order")  # may be None
    if text_col is None or doc_col is None:
        raise ValueError(
            f"Could not resolve text/doc_id columns for {repo}. "
            f"Available columns: {cols}. "
            f"Pass --text-column / --doc-id-column explicitly."
        )
    logger.info(" %s -> text='%s', doc_id='%s', order='%s', rows=%d",
                source_tag, text_col, doc_col, order_col, len(ds))
    # Normalise to a small canonical schema and stamp the source.
    # `keep` maps incoming column name -> canonical name for rename_columns.
    # NOTE(review): if stratify_column names the same column as text/doc/order,
    # the later assignment overwrites that mapping entry.
    keep = {text_col: "text", doc_col: "document_id"}
    if order_col:
        keep[order_col] = "order"
    strat_col = cfg.stratify_column
    if strat_col is not None:
        if strat_col not in cols:
            raise ValueError(f"--stratify-column '{strat_col}' not in {repo} columns {cols}")
        keep[strat_col] = "stratum"
    ds = ds.rename_columns(keep)
    # The corpora already ship a `source` column ("med"/"canon"). We deliberately
    # drop it here and re-stamp an authoritative per-dataset tag below, so
    # provenance is guaranteed consistent regardless of the incoming values.
    canonical = {"text", "document_id", "order", "stratum"}
    drop = [c for c in ds.column_names if c not in canonical]
    if drop:
        ds = ds.remove_columns(drop)
    if "order" not in ds.column_names:
        # Fallback only (the expected schema has `order`): the row index keeps
        # the incoming row order stable within each document.
        ds = ds.add_column("order", list(range(len(ds))))
    if "source" in ds.column_names:  # safety: never collide with add_column
        ds = ds.remove_columns(["source"])
    ds = ds.add_column("source", [source_tag] * len(ds))
    # document_id is namespaced per source so identical ids in both corpora
    # never collide during splitting/packing.
    ds = ds.map(lambda b: {"document_id": [f"{source_tag}:{d}" for d in b["document_id"]]},
                batched=True, desc=f"namespace doc ids ({source_tag})")
    return ds
def choose_heldout_docs(ds, cfg: Config):
    """Pick whole documents to hold out, capped per call, deterministic.

    ``n_val`` / ``n_test`` are ``round(#docs * fraction)`` for the documents in
    ``ds``, each capped at ``cfg.max_eval_docs_per_source``. The raw-text path
    calls this once per source corpus, so the cap applies per source there.

    If a ``stratum`` column is present (via --stratify-column), each document
    is assigned its majority stratum (ties go to the stratum seen first) and
    documents are sampled proportionally per stratum, with at least one
    document per non-empty stratum so rare genres are not lost. If those
    per-stratum minimums overshoot the requested total, the excess is trimmed
    from the strata holding the most picks rather than dropped at random.
    Only when there are more strata than requested documents can a stratum
    still be left out. Without a ``stratum`` column, sampling is uniform-random
    over document ids.

    Args:
        ds: dataset with a ``document_id`` column and optionally ``stratum``.
        cfg: supplies val/test fractions, the eval cap and ``split_seed``.

    Returns:
        ``(val_docs, test_docs)``: two disjoint sets of document ids.
    """
    rng = np.random.default_rng(cfg.split_seed)
    doc_ids = list(ds["document_id"])
    n_total_docs = len(set(doc_ids))
    n_val = min(int(round(n_total_docs * cfg.val_doc_fraction)), cfg.max_eval_docs_per_source)
    n_test = min(int(round(n_total_docs * cfg.test_doc_fraction)), cfg.max_eval_docs_per_source)
    n_pick = n_val + n_test
    if "stratum" in ds.column_names:
        # Map each document to its majority stratum. Counter.most_common keeps
        # first-encountered order among equal counts, so a tie resolves to the
        # stratum seen first.
        votes: Dict[str, Counter] = {}
        for d, s in zip(doc_ids, ds["stratum"]):
            votes.setdefault(d, Counter())[str(s)] += 1
        doc_stratum: Dict[str, str] = {d: c.most_common(1)[0][0] for d, c in votes.items()}
        by_stratum: Dict[str, List[str]] = {}
        for d, s in doc_stratum.items():
            by_stratum.setdefault(s, []).append(d)
        allocation: Dict[str, List[str]] = {}
        for s, docs in sorted(by_stratum.items()):
            docs = sorted(docs)
            rng.shuffle(docs)
            # proportional allocation, at least 1 doc per non-empty stratum
            k = max(1, int(round(n_pick * len(docs) / len(doc_stratum))))
            allocation[s] = docs[:min(k, len(docs))]
        # Trim any overshoot from the largest allocations first (ties -> the
        # alphabetically first stratum), never below one doc per stratum.
        overshoot = sum(len(v) for v in allocation.values()) - n_pick
        while overshoot > 0:
            largest = max(sorted(allocation), key=lambda name: len(allocation[name]))
            if len(allocation[largest]) <= 1:
                break  # more strata than requested docs; truncation below decides
            allocation[largest].pop()
            overshoot -= 1
        # Concatenate in sorted-stratum order, so without overshoot this
        # reproduces the previous selection exactly.
        picked: List[str] = [d for s in sorted(allocation) for d in allocation[s]]
        rng.shuffle(picked)
        picked = picked[:n_pick]
    else:
        shuffled = sorted(set(doc_ids))
        rng.shuffle(shuffled)
        picked = shuffled[:n_pick]
    val_docs = set(picked[:n_val])
    test_docs = set(picked[n_val:n_val + n_test])
    return val_docs, test_docs
def representativeness_report(train_ds, val_ds, cfg: Config, out_dir: str):
    """Compare the held-out validation set against train on cheap, honest
    signals (paragraph char-length distribution, and stratum proportions if
    available) so a skewed eval set is caught before training.

    The report is written as UTF-8 JSON to
    ``<out_dir>/eval_representativeness.json`` and logged. Warnings are emitted
    when a train stratum is missing from validation, or when the total
    variation distance between the stratum distributions exceeds 0.15.

    Args:
        train_ds: training paragraphs (``text`` column, optional ``stratum``).
        val_ds: validation paragraphs with the same columns.
        cfg: accepted for interface symmetry; not read here.
        out_dir: directory for the JSON report (created if missing).

    Returns:
        The report dict.
    """
    def char_lengths(ds):
        out: List[int] = []
        for batch in ds.iter(batch_size=10000):
            out.extend(len(t) for t in batch["text"])
        # A single 0 keeps mean/percentile defined for an empty split.
        return np.asarray(out, dtype=np.int64) if out else np.array([0], dtype=np.int64)

    def pct(a):
        if len(a) == 0:
            return {}
        return {f"p{p}": float(np.percentile(a, p)) for p in (5, 25, 50, 75, 95)}

    tr_len, va_len = char_lengths(train_ds), char_lengths(val_ds)
    report = {
        "train_paragraphs": len(train_ds),
        "val_paragraphs": len(val_ds),
        "char_len_train": {"mean": float(tr_len.mean()), **pct(tr_len)},
        "char_len_val": {"mean": float(va_len.mean()), **pct(va_len)},
    }
    if "stratum" in train_ds.column_names and "stratum" in val_ds.column_names:
        def props(ds):
            c = Counter(str(s) for s in ds["stratum"])
            tot = sum(c.values()) or 1
            return {k: v / tot for k, v in c.items()}

        tr_p, va_p = props(train_ds), props(val_ds)
        keys = sorted(set(tr_p) | set(va_p))
        # total variation distance between the two stratum distributions
        tvd = 0.5 * sum(abs(tr_p.get(k, 0.0) - va_p.get(k, 0.0)) for k in keys)
        report["stratum_proportions_train"] = {k: round(tr_p.get(k, 0.0), 4) for k in keys}
        report["stratum_proportions_val"] = {k: round(va_p.get(k, 0.0), 4) for k in keys}
        report["stratum_total_variation_distance"] = round(tvd, 4)
        missing = [k for k in tr_p if k not in va_p]
        if missing:
            report["strata_absent_from_val"] = missing
    os.makedirs(out_dir, exist_ok=True)
    # Explicit UTF-8: with ensure_ascii=False, non-ASCII stratum names are
    # written verbatim, and open()'s default encoding is locale-dependent.
    with open(os.path.join(out_dir, "eval_representativeness.json"), "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)
    logger.info("Eval representativeness:\n%s", json.dumps(report, indent=2, ensure_ascii=False))
    if report.get("strata_absent_from_val"):
        logger.warning("Strata present in train but MISSING from validation: %s",
                       report["strata_absent_from_val"])
    if report.get("stratum_total_variation_distance", 0.0) > 0.15:
        logger.warning("Validation stratum distribution differs notably from train "
                       "(TVD=%.3f); consider --stratify-column or a larger eval cap.",
                       report["stratum_total_variation_distance"])
    return report
def dedup_rows(ds):
    """Keep only the first occurrence of each (document_id, order, text) row.

    Used for eval honesty: upsampling introduces exact duplicates, and a
    held-out paragraph should be counted once. Rows are fingerprinted with a
    16-byte BLAKE2b digest of the three fields joined by U+001F, and row order
    is preserved via ``ds.select``.
    """
    fingerprints = set()
    survivors = []
    columns = (ds["document_id"], ds["order"], ds["text"])
    for row_idx, (doc, pos, txt) in enumerate(zip(*columns)):
        digest = hashlib.blake2b(f"{doc}\x1f{pos}\x1f{txt}".encode("utf-8"), digest_size=16).digest()
        if digest in fingerprints:
            continue
        fingerprints.add(digest)
        survivors.append(row_idx)
    return ds.select(survivors)
# --------------------------------------------------------------------------- #
# Tokenization + DOC-SENTENCES packing
# --------------------------------------------------------------------------- #
def maybe_nfc(text: str, cfg: Config) -> str:
    """Return ``text`` in Unicode NFC form if ``cfg.normalize_nfc`` is set, else unchanged."""
    if not cfg.normalize_nfc:
        return text
    return unicodedata.normalize("NFC", text)
def diagnose_tokenizer(texts: List[str], tokenizer, cfg: Config, out_dir: str):
    """Pre-flight check for SentencePiece behaviour on historical text.

    Tokenizes a sample of raw paragraphs and reports:
      * the ``<unk>`` rate,
      * fertility (tokens per whitespace word and per character),
      * the token-length distribution per paragraph,
      * characters that map to ``<unk>`` when tokenized on their own,
      * paragraphs with more than 1.5 tokens per character.

    The report is written as UTF-8 JSON to
    ``<out_dir>/tokenizer_diagnostics.json`` and summarised in the log.

    Args:
        texts: raw paragraphs. Any sized iterable is accepted. When longer than
            ``cfg.tokenizer_sample_size``, it is subsampled with a
            ``cfg.split_seed``-seeded RNG.
        tokenizer: Hugging Face tokenizer under test.
        cfg: supplies sampling, NFC and ``max_seq_length`` settings.
        out_dir: directory for the JSON report (created if missing).

    Returns:
        The report dict that was written to disk.
    """
    # Materialise once: callers may pass a datasets column object, and the code
    # below needs len(), indexing, slicing and repeated iteration.
    texts = list(texts)
    unk = tokenizer.unk_token
    unk_id = tokenizer.unk_token_id
    rng = np.random.default_rng(cfg.split_seed)
    if len(texts) > cfg.tokenizer_sample_size:
        idx = rng.choice(len(texts), size=cfg.tokenizer_sample_size, replace=False)
        texts = [texts[i] for i in idx]
    total_tokens = 0
    total_unk = 0
    total_chars = 0
    total_words = 0
    tok_lengths: List[int] = []
    pathological: List[dict] = []
    for t in texts:
        t = maybe_nfc(t, cfg)
        toks = tokenizer.tokenize(t)
        n = len(toks)
        total_tokens += n
        total_unk += sum(1 for x in toks if x == unk)
        total_chars += len(t)
        # An empty paragraph still counts as one word so fertility stays finite.
        total_words += max(1, len(t.split()))
        tok_lengths.append(n)
        # More than 1.5 tokens per character means near character-level splitting.
        if len(t) >= 20 and n / max(1, len(t)) > 1.5 and len(pathological) < 25:
            pathological.append({"chars": len(t), "tokens": n, "preview": t[:80]})
    # Probe each distinct character on its own to find glyphs the vocabulary
    # cannot represent.
    uniq_chars = set()
    for t in texts[:5000]:
        uniq_chars.update(maybe_nfc(t, cfg))
    risky_chars = {}
    for ch in uniq_chars:
        if ch.isspace():
            continue
        ids = tokenizer(ch, add_special_tokens=False)["input_ids"]
        if unk_id is not None and unk_id in ids:
            name = unicodedata.name(ch, "UNNAMED")
            risky_chars[ch] = {"codepoint": f"U+{ord(ch):04X}", "name": name}
    # Sort once so both the JSON report and the logged sample are deterministic
    # (set iteration order varies between interpreter runs).
    risky_sorted = sorted(risky_chars.items())
    tok_lengths_arr = np.asarray(tok_lengths) if tok_lengths else np.array([0])
    report = {
        "sampled_paragraphs": len(texts),
        "normalize_nfc": cfg.normalize_nfc,
        "unk_token": unk,
        "unk_rate": round(total_unk / max(1, total_tokens), 6),
        "total_unk_tokens": total_unk,
        "fertility_tokens_per_word": round(total_tokens / max(1, total_words), 3),
        "fertility_tokens_per_char": round(total_tokens / max(1, total_chars), 3),
        "tokens_per_paragraph": {
            "mean": float(tok_lengths_arr.mean()),
            "p50": float(np.percentile(tok_lengths_arr, 50)),
            "p95": float(np.percentile(tok_lengths_arr, 95)),
            "max": int(tok_lengths_arr.max()),
            "share_over_max_seq_len": round(
                float((tok_lengths_arr > cfg.max_seq_length).mean()), 4),
        },
        "num_risky_unk_characters": len(risky_chars),
        "risky_unk_characters": dict(risky_sorted[:100]),
        "pathological_examples": pathological,
    }
    os.makedirs(out_dir, exist_ok=True)
    # Explicit UTF-8: the report holds non-ASCII glyphs verbatim
    # (ensure_ascii=False), and open()'s default encoding is locale-dependent.
    with open(os.path.join(out_dir, "tokenizer_diagnostics.json"), "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)
    logger.info("Tokenizer diagnostics: unk_rate=%.4f%% fertility=%.2f tok/word "
                "risky_chars=%d", report["unk_rate"] * 100,
                report["fertility_tokens_per_word"], report["num_risky_unk_characters"])
    if report["unk_rate"] > 0.005:
        logger.warning("High <unk> rate (%.3f%%). Inspect risky_unk_characters in "
                       "tokenizer_diagnostics.json; consider --normalize-nfc or a "
                       "transliteration/cleanup pass for medieval glyphs.",
                       report["unk_rate"] * 100)
    if risky_sorted:
        sample = ", ".join(f"{c} ({m['codepoint']})" for c, m in risky_sorted[:15])
        logger.warning("Characters mapping to <unk> (sample): %s", sample)
    return report
def tokenize_paragraphs(ds, tokenizer, cfg: Config):
    """Replace each paragraph's ``text`` column with its ``input_ids``.

    No special tokens are added and nothing is truncated; boundary tokens and
    the length limit are applied when the paragraphs are packed. When
    ``cfg.normalize_nfc`` is set, text is NFC-normalised first. This runs as a
    batched ``Dataset.map`` with ``cfg.preprocess_num_proc`` processes, and all
    non-text columns are carried through.
    """
    def _encode_batch(batch):
        paragraphs = batch["text"]
        if cfg.normalize_nfc:
            paragraphs = [maybe_nfc(p, cfg) for p in paragraphs]
        encoded = tokenizer(
            paragraphs,
            add_special_tokens=False,
            truncation=False,
            return_attention_mask=False,
        )
        return {"input_ids": encoded["input_ids"]}

    return ds.map(
        _encode_batch,
        batched=True,
        num_proc=cfg.preprocess_num_proc,
        remove_columns=["text"],
        desc="tokenize paragraphs",
    )
def pack_doc_sentences(ds, tokenizer, cfg: Config, desc: str) -> Dataset:
    """Pack tokenized paragraphs into MLM sequences that never cross documents.

    Rows are sorted by ``document_id`` (then ``order`` when present), and their
    ``input_ids`` are concatenated into a per-document buffer.
      * Whenever the buffer holds at least ``max_seq_length - 2`` tokens, a full
        chunk is emitted as ``[bos] + chunk + [eos]``. Bos/eos fall back to
        cls/sep.
      * At each document boundary, and at the very end, the remaining buffer is
        emitted only if it has at least ``cfg.min_chunk_tokens`` tokens.
        Shorter remainders are dropped, and so are whole documents shorter
        than that.

    Args:
        ds: tokenized paragraphs with ``input_ids``, ``document_id``, ``source``
            and optionally ``order`` columns.
        tokenizer: supplies the boundary token ids.
        cfg: supplies ``max_seq_length``, ``min_chunk_tokens`` and ``cache_dir``.
        desc: label used only in the log line.

    Returns:
        A Dataset with ``input_ids``, ``document_id`` and ``source`` columns,
        built via ``Dataset.from_generator`` with ``cache_dir=cfg.cache_dir``.

    Raises:
        ValueError: if the tokenizer has neither bos/cls nor eos/sep.
    """
    bos = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id
    eos = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
    if bos is None or eos is None:
        raise ValueError("Tokenizer has no bos/cls or eos/sep token; cannot pack.")
    max_content = cfg.max_seq_length - 2  # room for the bos + eos wrapper
    # NOTE(review): rows sharing (document_id, order) -- e.g. upsampled copies, if
    # they keep their ids -- end up adjacent and are packed back-to-back. Confirm
    # this is intended.
    sort_keys = ["document_id", "order"] if "order" in ds.column_names else ["document_id"]
    ds = ds.sort(sort_keys)

    def generator():
        buffer: List[int] = []
        cur_doc = None
        cur_src = None

        def flush_full():
            # Emit as many full-length chunks as the buffer holds and keep the
            # remainder for the next paragraph of the same document.
            nonlocal buffer
            while len(buffer) >= max_content:
                chunk = buffer[:max_content]
                buffer = buffer[max_content:]
                yield {"input_ids": [bos] + chunk + [eos],
                       "document_id": cur_doc, "source": cur_src}

        for batch in ds.iter(batch_size=2000):
            ids_col = batch["input_ids"]
            doc_col = batch["document_id"]
            src_col = batch["source"]
            for ids, doc, src in zip(ids_col, doc_col, src_col):
                if doc != cur_doc:
                    # Document boundary: flush the previous doc's remainder (if
                    # long enough) and always start from an empty buffer.
                    if buffer and len(buffer) >= cfg.min_chunk_tokens:
                        yield {"input_ids": [bos] + buffer + [eos],
                               "document_id": cur_doc, "source": cur_src}
                    buffer = []
                    cur_doc, cur_src = doc, src
                buffer.extend(ids)
                yield from flush_full()
        # Remainder of the last document.
        if buffer and len(buffer) >= cfg.min_chunk_tokens:
            yield {"input_ids": [bos] + buffer + [eos],
                   "document_id": cur_doc, "source": cur_src}

    packed = Dataset.from_generator(generator, cache_dir=cfg.cache_dir)
    logger.info(" packed %s -> %d sequences (max_len=%d)", desc, len(packed), cfg.max_seq_length)
    return packed
def diagnose_pretokenized(ds, tokenizer, cfg: Config, out_dir: str):
    """Sanity-check already tokenized sequences before training on them.

    Draws up to ``cfg.tokenizer_sample_size`` rows (RNG seeded with
    ``cfg.split_seed``) and measures:
      * the ``<unk>`` rate,
      * the sequence-length distribution,
      * the share of rows longer than ``cfg.max_seq_length``,
      * how often rows start with bos (cls fallback) and end with eos (sep
        fallback).

    The report is written to ``<out_dir>/tokenizer_diagnostics.json`` and
    summarised in the log, with warnings for suspicious values.

    Returns:
        The report dict.
    """
    unk_id = tokenizer.unk_token_id
    bos_id = tokenizer.cls_token_id if tokenizer.bos_token_id is None else tokenizer.bos_token_id
    eos_id = tokenizer.sep_token_id if tokenizer.eos_token_id is None else tokenizer.eos_token_id

    n_rows = len(ds)
    n_sampled = min(cfg.tokenizer_sample_size, n_rows)
    sample_rows = np.random.default_rng(cfg.split_seed).choice(n_rows, size=n_sampled, replace=False)
    sample = ds.select(sample_rows)

    seq_lengths: List[int] = []
    n_tokens = n_unk = n_with_bos = n_with_eos = n_too_long = 0
    for seq in sample["input_ids"]:
        size = len(seq)
        seq_lengths.append(size)
        n_tokens += size
        if unk_id is not None:
            n_unk += seq.count(unk_id)
        if size:
            n_with_bos += int(seq[0] == bos_id)
            n_with_eos += int(seq[-1] == eos_id)
        if size > cfg.max_seq_length:
            n_too_long += 1

    lengths_arr = np.asarray(seq_lengths) if seq_lengths else np.array([0])
    denom = max(1, n_sampled)
    report = {
        "mode": "pretokenized",
        "sampled_sequences": int(n_sampled),
        "total_sequences": int(n_rows),
        "unk_rate": round(n_unk / max(1, n_tokens), 6),
        "total_unk_tokens": int(n_unk),
        "seq_len": {
            "mean": float(lengths_arr.mean()),
            "p50": float(np.percentile(lengths_arr, 50)),
            "p95": float(np.percentile(lengths_arr, 95)),
            "max": int(lengths_arr.max()),
        },
        "share_over_max_seq_len": round(float(n_too_long) / denom, 4),
        "starts_with_bos_rate": round(n_with_bos / denom, 4),
        "ends_with_eos_rate": round(n_with_eos / denom, 4),
    }

    os.makedirs(out_dir, exist_ok=True)
    report_path = os.path.join(out_dir, "tokenizer_diagnostics.json")
    with open(report_path, "w") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)

    unk_rate = report["unk_rate"]
    bos_rate = report["starts_with_bos_rate"]
    eos_rate = report["ends_with_eos_rate"]
    over_share = report["share_over_max_seq_len"]
    logger.info("Pretokenized diagnostics: unk_rate=%.4f%% mean_len=%.1f "
                "bos=%.0f%% eos=%.0f%%", unk_rate * 100, report["seq_len"]["mean"],
                bos_rate * 100, eos_rate * 100)
    if unk_rate > 0.005:
        logger.warning("High <unk> rate (%.3f%%) in the pre-tokenized data — verify it was "
                       "tokenized with the SAME tokenizer (%s).", unk_rate * 100, cfg.model_name)
    if bos_rate < 0.5 or eos_rate < 0.5:
        logger.warning("Many sequences lack bos/eos boundaries; if your packing omitted special "
                       "tokens, MLM quality may suffer.")
    if over_share > 0:
        logger.warning("%.2f%% of sequences exceed max_seq_length=%d and will be truncated.",
                       over_share * 100, cfg.max_seq_length)
    return report
def _load_dataset_any(name: str, split: str, token: bool):
    """Load one split of a dataset from a local ``save_to_disk`` dir or the HF Hub.

    Args:
        name: a local directory (checked first via ``os.path.isdir``) or a Hub
            repo id.
        split: split to return. For a local ``DatasetDict`` it selects the
            split; for the Hub it is passed to ``load_dataset``.
        token: forwarded to ``load_dataset``; ignored for local directories.

    Returns:
        A single ``Dataset`` (never a ``DatasetDict``).

    Raises:
        KeyError: if a local ``DatasetDict`` has no split named ``split``.
    """
    if os.path.isdir(name):
        from datasets import DatasetDict, load_from_disk
        d = load_from_disk(name)
        # DatasetDict also exposes `column_names` (a per-split dict), so a
        # hasattr-based test cannot tell it apart from a Dataset; check the type.
        if isinstance(d, DatasetDict):
            if split not in d:
                raise KeyError(f"Split {split!r} not found in {name!r}; available splits: {list(d)}")
            d = d[split]
        return d
    return load_dataset(name, split=split, token=token)
def _ids_hash(ids) -> bytes:
return hashlib.blake2b(np.asarray(ids, dtype=np.int32).tobytes(), digest_size=12).digest()
def build_pretokenized_bundle(cfg: Config, tokenizer, selected, cache_root: str):
    """Build the train/eval bundle from corpora that already contain ``input_ids``.

    Steps:
      1. Load each ``(repo_or_dir, tag)`` in ``selected``. Keep ``input_ids``,
         ``attention_mask``, the first doc-id column found (doc_id /
         document_id / docid) and ``cfg.stratify_column`` when present. Rename
         them to ``document_id`` / ``stratum`` and stamp ``source=tag``.
      2. Concatenate and namespace ``document_id`` as ``"<source>:<id>"``.
         Strip padding via ``attention_mask`` and truncate to
         ``cfg.max_seq_length``, keeping a trailing eos/sep token when the
         sequence ended with one.
      3. Optionally write tokenizer diagnostics to ``cfg.output_dir``.
      4. Split. With document ids, whole documents are held out via
         ``choose_heldout_docs``; the caps apply to the combined corpus here,
         not to each source. Without document ids, a weaker holdout over unique
         sequence hashes is used. Held-out rows are de-duplicated.
      5. Write ``eval_representativeness.json``. Exit if ``cfg.diagnose_only``;
         otherwise save every split under ``cache_root``.

    Returns:
        Dict of split name -> Dataset ("train", optional "validation",
        "validation_med"/"validation_canon" and "test").
    """
    max_len = cfg.max_seq_length
    DOCID_CANDS = ("doc_id", "document_id", "docid")
    eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id
    parts = []
    for repo, tag in selected:
        ds = _load_dataset_any(repo, cfg.dataset_split, True)
        docid_col = next((c for c in DOCID_CANDS if c in ds.column_names), None)
        # Keep the metadata needed for a group-aware (and optionally stratified) split.
        keep = ["input_ids"]
        if "attention_mask" in ds.column_names:
            keep.append("attention_mask")
        if docid_col:
            keep.append(docid_col)
        strat_col = cfg.stratify_column
        if strat_col and strat_col in ds.column_names:
            keep.append(strat_col)
        ds = ds.remove_columns([c for c in ds.column_names if c not in keep])
        if docid_col and docid_col != "document_id":
            ds = ds.rename_column(docid_col, "document_id")
        if strat_col and strat_col in ds.column_names and strat_col != "stratum":
            ds = ds.rename_column(strat_col, "stratum")
        ds = ds.add_column("source", [tag] * len(ds))
        parts.append(ds)
        logger.info(" loaded pretokenized %s -> %d sequences (doc_id=%s)",
                    tag, len(ds), docid_col)
    full = concatenate_datasets(parts) if len(parts) > 1 else parts[0]
    has_doc = "document_id" in full.column_names
    if has_doc:
        full = full.map(
            lambda b: {"document_id": [f"{s}:{d}" for s, d in zip(b["source"], b["document_id"])]},
            batched=True, desc="namespace doc ids")
    has_mask = "attention_mask" in full.column_names

    def _fix(batch):
        # Strip padding (attention_mask != 1) and cap the length at max_len.
        out = []
        masks = batch["attention_mask"] if has_mask else [None] * len(batch["input_ids"])
        for ids, am in zip(batch["input_ids"], masks):
            if am is not None:
                ids = [t for t, a in zip(ids, am) if a == 1]
            if len(ids) > max_len:
                # Plain slicing would cut off the closing eos; keep it so
                # truncated sequences are still properly terminated.
                if eos_id is not None and ids[-1] == eos_id:
                    ids = list(ids[:max_len - 1]) + [eos_id]
                else:
                    ids = ids[:max_len]
            out.append(ids)
        return {"input_ids": out}

    full = full.map(_fix, batched=True, num_proc=cfg.preprocess_num_proc,
                    remove_columns=(["attention_mask"] if has_mask else []),
                    desc="normalize pretokenized")
    out_dir = cfg.output_dir
    if cfg.diagnose_tokenizer:
        diagnose_pretokenized(full, tokenizer, cfg, out_dir)
    rng = np.random.default_rng(cfg.split_seed)
    train_idx: List[int] = []
    val_idx: List[int] = []
    test_idx: List[int] = []
    if has_doc:
        docs = list(full["document_id"])
        if "stratum" in full.column_names:
            logger.info("Applying STRATIFIED document split based on column '%s'.",
                        cfg.stratify_column)
            split_kind = "document-level STRATIFIED group-aware"
        else:
            split_kind = "document-level group-aware"
        # Same selection routine as the raw-text path (fresh RNG seeded with
        # cfg.split_seed), applied here to the combined corpus.
        val_docs, test_docs = choose_heldout_docs(full, cfg)
        val_cand, test_cand = [], []
        for i, d in enumerate(docs):
            if d in val_docs:
                val_cand.append((i, d))
            elif d in test_docs:
                test_cand.append((i, d))
            else:
                train_idx.append(i)
        # De-duplicate held-out rows per (document, token sequence) so upsampled
        # copies are scored once. The ids are fetched with one select() rather
        # than one row lookup per candidate.
        for cand, dst in ((val_cand, val_idx), (test_cand, test_idx)):
            if not cand:
                continue
            cand_ids = full.select([i for i, _ in cand])["input_ids"]
            seen = set()
            for (i, d), ids in zip(cand, cand_ids):
                key = (d, _ids_hash(ids))
                if key not in seen:
                    seen.add(key)
                    dst.append(i)
    else:
        logger.warning("No doc_id/document_id column found in the pre-tokenized data; "
                       "falling back to a weaker sequence-level holdout.")
        hashes = [_ids_hash(ids) for ids in full["input_ids"]]
        uniq = sorted(set(hashes))
        rng.shuffle(uniq)
        n_val = min(int(round(len(uniq) * cfg.val_doc_fraction)), cfg.max_eval_docs_per_source * 4)
        n_test = min(int(round(len(uniq) * cfg.test_doc_fraction)), cfg.max_eval_docs_per_source * 4)
        val_h = set(uniq[:n_val])
        test_h = set(uniq[n_val:n_val + n_test])
        seen_v, seen_t = set(), set()
        for i, h in enumerate(hashes):
            if h in val_h:
                if h not in seen_v:
                    seen_v.add(h)
                    val_idx.append(i)
            elif h in test_h:
                if h not in seen_t:
                    seen_t.add(h)
                    test_idx.append(i)
            else:
                train_idx.append(i)
        split_kind = "sequence-level dedup"
    bundle: Dict[str, Dataset] = {"train": full.select(train_idx)}
    if val_idx:
        bundle["validation"] = full.select(val_idx)
        if len(set(bundle["validation"]["source"])) > 1:
            for tag in ("med", "canon"):
                sub = bundle["validation"].filter(lambda b: [s == tag for s in b["source"]], batched=True)
                if len(sub):
                    bundle[f"validation_{tag}"] = sub
    if test_idx:
        bundle["test"] = full.select(test_idx)
    logger.info("Pretokenized split (%s) -> train=%d val=%d test=%d",
                split_kind, len(bundle["train"]), len(val_idx), len(test_idx))

    def len_stats(d):
        ls = np.asarray([len(x) for x in d["input_ids"]]) if len(d) else np.array([0])
        return {"mean": float(ls.mean()), "p50": float(np.percentile(ls, 50)),
                "p95": float(np.percentile(ls, 95))}

    rep = {"split_kind": split_kind,
           "train_sequences": len(bundle["train"]), "val_sequences": len(val_idx),
           "len_train": len_stats(bundle["train"])}
    if val_idx:
        rep["len_val"] = len_stats(bundle["validation"])
    # out_dir only exists at this point if the diagnostics above created it.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "eval_representativeness.json"), "w", encoding="utf-8") as fh:
        json.dump(rep, fh, indent=2)
    if cfg.diagnose_only:
        logger.info("--diagnose-only set: wrote diagnostics to %s, exiting before training.", out_dir)
        sys.exit(0)
    os.makedirs(cache_root, exist_ok=True)
    for name, d in bundle.items():
        d.save_to_disk(os.path.join(cache_root, name))
    logger.info("Saved pretokenized datasets to %s", cache_root)
    return bundle
def build_or_load_packed(cfg: Config, tokenizer):
    """Return the train/eval dataset bundle, loading it from cache when possible.

    The bundle maps split names to Datasets of ``input_ids`` sequences:
    ``"train"``, plus ``"validation"``, ``"validation_med"`` /
    ``"validation_canon"`` (only when both corpora contribute held-out data)
    and ``"test"`` when those splits are non-empty. It is saved under
    ``<cfg.cache_dir>/packed_<hash>``, where the hash covers the config fields
    listed in ``sig`` below.

    There are two build paths. If the first selected corpus already has an
    ``input_ids`` column, it is handled by ``build_pretokenized_bundle``.
    Otherwise raw paragraphs are split by held-out documents, tokenized and
    DOC-SENTENCES packed.

    With ``cfg.diagnose_only``, the cache is ignored, diagnostics are written
    to ``cfg.output_dir``, and the process exits before packing/saving.

    Raises:
        ValueError: if ``cfg.train_corpus`` is not combined|med|canon (and for
            unresolvable columns, raised while loading sources).
    """
    # Validate before touching any cache so a typo fails fast.
    if cfg.train_corpus not in {"combined", "med", "canon"}:
        raise ValueError(f"--train-corpus must be combined|med|canon, got {cfg.train_corpus!r}")
    sig = json.dumps({
        "med": cfg.med_dataset, "canon": cfg.canon_dataset, "split": cfg.dataset_split,
        "model": cfg.model_name, "max_len": cfg.max_seq_length, "min_chunk": cfg.min_chunk_tokens,
        "val_frac": cfg.val_doc_fraction, "test_frac": cfg.test_doc_fraction,
        "max_eval": cfg.max_eval_docs_per_source, "split_seed": cfg.split_seed,
        "text_col": cfg.text_column, "doc_col": cfg.doc_id_column, "order_col": cfg.order_column,
        "stratify": cfg.stratify_column, "nfc": cfg.normalize_nfc,
        "train_corpus": cfg.train_corpus, "dedup_train": cfg.dedup_train,
    }, sort_keys=True)
    key = hashlib.blake2b(sig.encode("utf-8"), digest_size=12).hexdigest()
    cache_root = os.path.join(cfg.cache_dir, f"packed_{key}")
    if os.path.isdir(cache_root) and not cfg.diagnose_only:
        # A run killed during saving can leave a cache dir without "train";
        # loading it would fail later, so rebuild instead. (A partially written
        # "train" itself cannot be detected here.)
        if os.path.isdir(os.path.join(cache_root, "train")):
            from datasets import load_from_disk
            logger.info("Loading cached packed datasets from %s", cache_root)
            bundle = {name: load_from_disk(os.path.join(cache_root, name))
                      for name in os.listdir(cache_root)
                      if os.path.isdir(os.path.join(cache_root, name))}
            return bundle
        logger.warning("Cached bundle at %s has no 'train' split (interrupted save?); "
                       "rebuilding it.", cache_root)
    selected = []  # (repo, source tag) pairs, med first
    if cfg.train_corpus in {"combined", "med"}:
        selected.append((cfg.med_dataset, "med"))
    if cfg.train_corpus in {"combined", "canon"}:
        selected.append((cfg.canon_dataset, "canon"))
    # Only the first selected corpus is probed; all are assumed to share a format.
    probe = _load_dataset_any(selected[0][0], cfg.dataset_split, True)
    if "input_ids" in probe.column_names:
        logger.warning(
            "Dataset %s is PRE-TOKENIZED (columns=%s). Skipping raw-text tokenization "
            "and DOC-SENTENCES packing.", selected[0][0], probe.column_names)
        del probe
        return build_pretokenized_bundle(cfg, tokenizer, selected, cache_root)
    del probe
    # Keep each source's tag next to its dataset instead of re-reading it from
    # the (possibly huge) `source` column.
    sources = [(tag, load_source(repo, cfg.dataset_split, tag, cfg)) for repo, tag in selected]
    logger.info("Training corpus = %s (%d source dataset(s))", cfg.train_corpus, len(sources))
    train_parts, val_parts, test_parts = [], [], []
    for tag, src in sources:
        # Held-out documents are chosen per source, so the eval cap is per source.
        val_docs, test_docs = choose_heldout_docs(src, cfg)
        in_val = src.filter(lambda b: [d in val_docs for d in b["document_id"]],
                            batched=True, desc=f"select val ({tag})")
        in_test = src.filter(lambda b: [d in test_docs for d in b["document_id"]],
                             batched=True, desc=f"select test ({tag})")
        held = val_docs | test_docs
        in_train = src.filter(lambda b: [d not in held for d in b["document_id"]],
                              batched=True, desc=f"select train ({tag})")
        train_parts.append(in_train)
        if len(in_val):
            val_parts.append(dedup_rows(in_val))
        if len(in_test):
            test_parts.append(dedup_rows(in_test))
    train_raw = concatenate_datasets(train_parts)
    if cfg.dedup_train:
        before = len(train_raw)
        train_raw = dedup_rows(train_raw)
        logger.info("dedup_train: collapsed %d -> %d training paragraphs "
                    "(upsampling duplicates removed).", before, len(train_raw))
    val_raw = concatenate_datasets(val_parts) if val_parts else None
    logger.info("Raw paragraph counts -> train=%d val=%d test=%d",
                len(train_raw),
                sum(len(p) for p in val_parts),
                sum(len(p) for p in test_parts))
    out_dir = cfg.output_dir
    if cfg.diagnose_tokenizer:
        n = len(train_raw)
        m = min(cfg.tokenizer_sample_size, n)
        idx = np.random.default_rng(cfg.split_seed).choice(n, size=m, replace=False)
        sample_texts = train_raw.select(idx)["text"]
        diagnose_tokenizer(sample_texts, tokenizer, cfg, out_dir)
    if val_raw is not None:
        representativeness_report(train_raw, val_raw, cfg, out_dir)
    if cfg.diagnose_only:
        logger.info("--diagnose-only set: wrote diagnostics to %s, exiting before "
                    "packing/training.", out_dir)
        sys.exit(0)
    bundle: Dict[str, Dataset] = {}
    bundle["train"] = pack_doc_sentences(tokenize_paragraphs(train_raw, tokenizer, cfg),
                                         tokenizer, cfg, "train")
    if val_raw is not None:
        bundle["validation"] = pack_doc_sentences(tokenize_paragraphs(val_raw, tokenizer, cfg),
                                                  tokenizer, cfg, "validation")
        if len(sources) > 1:
            for tag in ("med", "canon"):
                sub = val_raw.filter(lambda b: [s == tag for s in b["source"]], batched=True)
                if len(sub):
                    bundle[f"validation_{tag}"] = pack_doc_sentences(
                        tokenize_paragraphs(sub, tokenizer, cfg), tokenizer, cfg, f"validation_{tag}")
    if test_parts:
        test_raw = concatenate_datasets(test_parts)
        bundle["test"] = pack_doc_sentences(tokenize_paragraphs(test_raw, tokenizer, cfg),
                                            tokenizer, cfg, "test")
    os.makedirs(cache_root, exist_ok=True)
    for name, d in bundle.items():
        d.save_to_disk(os.path.join(cache_root, name))
    logger.info("Saved packed datasets to %s", cache_root)
    return bundle
# --------------------------------------------------------------------------- #
# Metrics
# --------------------------------------------------------------------------- #
def preprocess_logits_for_metrics(logits, labels):
    """Reduce MLM logits to predicted token ids before the Trainer accumulates them.

    Keeping only the argmax ids instead of the full vocabulary-sized logits keeps
    evaluation memory small.

    Args:
        logits: model output logits, or a tuple whose first element is the logits.
        labels: unused; part of the Trainer hook signature.

    Returns:
        The argmax over the last dimension of the logits.
    """
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)
def compute_metrics(eval_pred):
    """Compute accuracy over the masked (scored) positions only.

    Args:
        eval_pred: a ``(predictions, label_ids)`` pair of arrays. Positions whose
            label is -100 are ignored.

    Returns:
        ``{"masked_accuracy": float}``. The value is 0.0 when no position is scored.
    """
    predictions, label_ids = eval_pred
    flat_labels = label_ids.reshape(-1)
    flat_preds = predictions.reshape(-1)
    scored = flat_labels != -100
    n_scored = scored.sum()
    if n_scored == 0:
        return {"masked_accuracy": 0.0}
    hits = (flat_preds[scored] == flat_labels[scored]).sum()
    return {"masked_accuracy": float(hits) / float(n_scored)}
class PerplexityCallback(TrainerCallback):
    """Add ``*_perplexity = exp(*_loss)`` next to every evaluation loss metric.

    Only metric keys that end in ``"loss"`` and contain ``"eval"`` are handled,
    e.g. ``eval_loss`` -> ``eval_perplexity``. The exponent is clamped at
    ``_MAX_LOG_PERPLEXITY``, so a diverged loss yields a large but finite value.

    The derivation runs on both ``on_log`` and ``on_evaluate``.

    NOTE(review): in the transformers versions this script targets,
    ``Trainer.evaluate`` calls ``self.log(metrics)`` *before* it fires
    ``on_evaluate``. Adding the key only in ``on_evaluate`` therefore kept it
    out of every log stream.

    Mutating ``logs`` in ``on_log`` makes the key visible to callbacks that run
    after this one, such as a JSONL logger listed later. Reporting integrations
    registered earlier (TensorBoard/W&B) still will not see it. Verify against
    the installed transformers version.
    """

    # exp(20) is about 4.9e8: a clamp large enough to be "obviously bad" while
    # guaranteeing math.exp can never overflow.
    _MAX_LOG_PERPLEXITY = 20.0

    @classmethod
    def _add_perplexities(cls, metrics: dict) -> None:
        # Snapshot the keys, because new entries are inserted while iterating.
        for key in list(metrics):
            if not (key.endswith("loss") and "eval" in key):
                continue
            ppl_key = key[: -len("loss")] + "perplexity"
            metrics[ppl_key] = math.exp(min(metrics[key], cls._MAX_LOG_PERPLEXITY))

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            self._add_perplexities(logs)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics:
            self._add_perplexities(metrics)
class JsonlLoggingCallback(TrainerCallback):
    """Append every Trainer log event to ``path`` as one JSON object per line.

    Each record is the Trainer's ``logs`` dict plus ``step`` (global step),
    ``epoch`` and a local ``wall_time`` timestamp. The file is flushed and
    fsync'ed after every write, so records already written survive a crash.
    """

    def __init__(self, path: str):
        self.path = path
        parent = os.path.dirname(path)
        # os.makedirs("") raises FileNotFoundError, so only create a parent
        # directory when the path actually has one.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        record = dict(logs)
        record["step"] = state.global_step
        record["epoch"] = state.epoch
        record["wall_time"] = datetime.now().isoformat(timespec="seconds")
        with open(self.path, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(record, ensure_ascii=False) + "\n")
            fh.flush()
            os.fsync(fh.fileno())
# --------------------------------------------------------------------------- #
# Environment / run metadata
# --------------------------------------------------------------------------- #
def capture_environment(cfg: Config) -> dict:
import transformers
import datasets as ds_lib
info = {
"timestamp": datetime.now().isoformat(timespec="seconds"),
"python": platform.python_version(),
"platform": platform.platform(),
"torch": torch.__version__,
"transformers": transformers.__version__,
"datasets": ds_lib.__version__,
"cuda_available": torch.cuda.is_available(),
}
if torch.cuda.is_available():
info["cuda"] = torch.version.cuda
info["gpu_name"] = torch.cuda.get_device_name(0)
props = torch.cuda.get_device_properties(0)
info["gpu_total_memory_gb"] = round(props.total_memory / 1024**3, 1)
cap = f"{props.major}.{props.minor}"
info["gpu_capability"] = cap
info["bf16_supported"] = torch.cuda.is_bf16_supported()
try:
arch_list = torch.cuda.get_arch_list()
except Exception:
arch_list = []
info["torch_arch_list"] = arch_list
sm_tag = f"sm_{props.major}{props.minor}"
info["gpu_arch_supported_by_torch"] = any(sm_tag == a for a in arch_list)
return info
# --------------------------------------------------------------------------- #
# Argparse
# --------------------------------------------------------------------------- #
def parse_args() -> Config:
    """Expose every ``Config`` dataclass field as a ``--kebab-case`` CLI option.

    Mapping rules:

    * A bool field defaulting to True becomes an opt-out ``--no-<name>`` flag.
    * A bool field defaulting to False becomes an opt-in ``--<name>`` flag.
    * A field defaulting to None accepts a raw string.
    * Any other field is parsed with the type of its default value.

    Returns:
        A ``Config`` populated from the parsed command line.
    """
    defaults = Config()
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    for fld in defaults.__dataclass_fields__.values():
        key = fld.name
        flag = key.replace("_", "-")
        value = getattr(defaults, key)
        if isinstance(value, bool):
            if value:
                parser.add_argument(f"--no-{flag}", dest=key, action="store_false")
            else:
                parser.add_argument(f"--{flag}", dest=key, action="store_true")
            continue
        parser.add_argument(f"--{flag}", dest=key, default=value,
                            type=str if value is None else type(value))
    parsed = parser.parse_args()
    return Config(**vars(parsed))
# --------------------------------------------------------------------------- #
# Main
# --------------------------------------------------------------------------- #
def main() -> None:
    """Run the full DAPT pipeline: data preparation, MLM training and final evaluation.

    Steps, in order:

    1. Parse the CLI config, seed RNGs and enable TF32.
    2. Set up file logging and record run metadata.
    3. Sanity-check the GPU/torch build. This only logs and never aborts.
    4. Build or load the packed datasets and write token statistics. Exit early
       here on ``--dry-run``.
    5. Train with the HF Trainer, dumping a traceback and attempting an emergency
       save on crash.
    6. Save the model, tokenizer and trainer state.
    7. Evaluate on the validation split(s) and, when present, the test split.
    """
    cfg = parse_args()
    set_seed(cfg.seed)
    # Permit TF32 tensor-core math for fp32 matmuls / cuDNN ops.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.set_float32_matmul_precision("high")
    os.makedirs(cfg.output_dir, exist_ok=True)
    # Attach a file handler to the root logger so that records propagating to it
    # are also written to <output_dir>/train.log.
    file_handler = logging.FileHandler(os.path.join(cfg.output_dir, "train.log"))
    file_handler.setFormatter(logging.Formatter(
        "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s", "%Y-%m-%d %H:%M:%S"))
    logging.getLogger().addHandler(file_handler)
    # Record software/hardware versions next to the full config for reproducibility.
    env = capture_environment(cfg)
    logger.info("Environment:\n%s", json.dumps(env, indent=2))
    with open(os.path.join(cfg.output_dir, "run_metadata.json"), "w") as fh:
        json.dump({"config": cfg.__dict__, "environment": env}, fh, indent=2)
    # Hardware sanity checks: these only log; they never abort the run.
    if not env["cuda_available"]:
        logger.warning("CUDA not available — this script is intended to run on a GPU server.")
    else:
        if not env.get("bf16_supported", False):
            logger.warning("bf16 not reported as supported on this GPU; consider fp16.")
        if env.get("torch_arch_list") and not env.get("gpu_arch_supported_by_torch", True):
            logger.error(
                "This torch build (%s, arch_list=%s) has NO kernels for your GPU "
                "(capability %s). On Blackwell (RTX PRO 6000) install a cu128+ build, e.g.:\n"
                " pip install --upgrade torch --index-url https://download.pytorch.org/whl/cu128\n"
                "CUDA ops will fail until this matches.",
                env["torch"], env.get("torch_arch_list"), env.get("gpu_capability"))
    # Round save_steps down to a multiple of eval_steps (but never below
    # eval_steps), so that checkpoints coincide with evaluations. Best-model
    # selection via load_best_model_at_end (below) relies on that alignment.
    # NOTE(review): this also runs when there is no validation split, and
    # eval_steps == 0 would raise ZeroDivisionError here.
    if cfg.save_steps % cfg.eval_steps != 0:
        new_save = max(cfg.eval_steps, (cfg.save_steps // cfg.eval_steps) * cfg.eval_steps)
        logger.warning("save_steps (%d) not a multiple of eval_steps (%d); adjusting to %d.",
                       cfg.save_steps, cfg.eval_steps, new_save)
        cfg.save_steps = new_save
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
    # Align the tokenizer's length limit with the packed sequence length.
    tokenizer.model_max_length = cfg.max_seq_length
    # Maps split name -> packed Dataset. Keys include "train" and optionally
    # "validation", "validation_med"/"validation_canon" and "test". The helper
    # presumably reuses an on-disk cache when one exists (TODO confirm).
    bundle = build_or_load_packed(cfg, tokenizer)
    train_ds = bundle["train"]
    eval_ds = bundle.get("validation")

    def tok_count(d):
        # Return (total tokens, mean tokens per packed sequence) for one split.
        lengths = d.map(lambda b: {"len": [len(x) for x in b["input_ids"]]},
                        batched=True, remove_columns=d.column_names, desc="count tokens")["len"]
        return int(sum(lengths)), float(np.mean(lengths))
    stats = {}
    for name, d in bundle.items():
        n_tok, mean_len = tok_count(d)
        stats[name] = {"sequences": len(d), "tokens": n_tok, "mean_seq_len": round(mean_len, 1)}
    logger.info("Packed dataset statistics:\n%s", json.dumps(stats, indent=2))
    with open(os.path.join(cfg.output_dir, "dataset_stats.json"), "w") as fh:
        json.dump(stats, fh, indent=2)
    if cfg.dry_run:
        logger.info("--dry-run set: datasets built and cached, exiting before training.")
        return
    # Keep only input_ids. The MLM collator builds the masked inputs, labels and
    # padding from them.
    train_ds = train_ds.remove_columns([c for c in train_ds.column_names if c != "input_ids"])
    if eval_ds is not None:
        eval_ds = eval_ds.remove_columns([c for c in eval_ds.column_names if c != "input_ids"])
    model = AutoModelForMaskedLM.from_pretrained(
        cfg.model_name,
        attn_implementation=cfg.attn_implementation,
    )
    if cfg.gradient_checkpointing:
        model.config.use_cache = False
    logger.info("Model parameters: %.1fM", sum(p.numel() for p in model.parameters()) / 1e6)
    # Masking is applied at collation time with probability mlm_probability.
    # Batch lengths are padded up to a multiple of 8.
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=cfg.mlm_probability,
        pad_to_multiple_of=8,
    )
    report_to = ["tensorboard"]
    if cfg.report_to_wandb:
        report_to.append("wandb")
    # Without a validation split: no in-training eval, no best-model selection,
    # and no metrics hooks.
    do_eval = eval_ds is not None
    args = TrainingArguments(
        output_dir=cfg.output_dir,
        seed=cfg.seed,
        data_seed=cfg.seed,
        num_train_epochs=cfg.num_train_epochs,
        max_steps=cfg.max_steps,
        learning_rate=cfg.learning_rate,
        weight_decay=cfg.weight_decay,
        adam_beta1=cfg.adam_beta1,
        adam_beta2=cfg.adam_beta2,
        adam_epsilon=cfg.adam_epsilon,
        max_grad_norm=cfg.max_grad_norm,
        warmup_ratio=cfg.warmup_ratio,
        lr_scheduler_type=cfg.lr_scheduler_type,
        optim=cfg.optim,
        # bf16 only when the GPU reports support (always False without CUDA).
        bf16=bool(env.get("bf16_supported")),
        per_device_train_batch_size=cfg.per_device_train_batch_size,
        per_device_eval_batch_size=cfg.per_device_eval_batch_size,
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        gradient_checkpointing=cfg.gradient_checkpointing,
        auto_find_batch_size=cfg.auto_find_batch_size,
        torch_compile=cfg.torch_compile,
        dataloader_num_workers=cfg.dataloader_num_workers,
        dataloader_pin_memory=True,
        eval_accumulation_steps=cfg.eval_accumulation_steps,
        eval_strategy="steps" if do_eval else "no",
        eval_steps=cfg.eval_steps if do_eval else None,
        logging_strategy="steps",
        logging_steps=cfg.logging_steps,
        logging_first_step=True,
        save_strategy="steps",
        save_steps=cfg.save_steps,
        save_total_limit=cfg.save_total_limit,
        # Best checkpoint = lowest validation loss.
        load_best_model_at_end=do_eval,
        metric_for_best_model="eval_loss" if do_eval else None,
        greater_is_better=False,
        report_to=report_to,
        run_name=os.path.basename(cfg.output_dir.rstrip("/")),
        include_num_input_tokens_seen=True,
        logging_dir=os.path.join(cfg.output_dir, "tb"),
    )
    # PerplexityCallback is listed before the JSONL logger on purpose, so any keys
    # it adds to a shared metrics/logs dict are already present when the logger runs.
    callbacks = [PerplexityCallback(),
                 JsonlLoggingCallback(os.path.join(cfg.output_dir, "training_log.jsonl"))]
    if do_eval and cfg.early_stopping_patience > 0:
        callbacks.append(EarlyStoppingCallback(
            early_stopping_patience=cfg.early_stopping_patience,
            early_stopping_threshold=cfg.early_stopping_threshold))
    # Pass the tokenizer under whichever Trainer kwarg the installed transformers
    # expects. _USE_PROCESSING_CLASS is defined elsewhere in this file
    # (presumably from a signature check; verify).
    tok_kwarg = {"processing_class": tokenizer} if _USE_PROCESSING_CLASS else {"tokenizer": tokenizer}
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        data_collator=collator,
        compute_metrics=compute_metrics if do_eval else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics if do_eval else None,
        callbacks=callbacks,
        **tok_kwarg,
    )
    # With --resume, continue from the latest checkpoint in output_dir if one exists;
    # otherwise start fresh.
    resume_from = None
    if cfg.resume:
        last = get_last_checkpoint(cfg.output_dir)
        if last:
            logger.info("Resuming from checkpoint %s", last)
            resume_from = last
    try:
        train_result = trainer.train(resume_from_checkpoint=resume_from)
    except BaseException as exc:
        # BaseException deliberately includes KeyboardInterrupt/SystemExit.
        # Every failure gets a traceback file and a best-effort emergency save,
        # and the original exception is always re-raised.
        tb_path = os.path.join(cfg.output_dir, "crash_traceback.txt")
        with open(tb_path, "w", encoding="utf-8") as fh:
            fh.write(f"Crashed at {datetime.now().isoformat()}\n\n")
            traceback.print_exc(file=fh)
        logger.error("Training crashed (%s: %s). Traceback written to %s",
                     type(exc).__name__, exc, tb_path)
        try:
            emergency = os.path.join(cfg.output_dir, "emergency_checkpoint")
            trainer.save_model(emergency)
            # NOTE(review): save_state() presumably writes trainer_state.json to
            # args.output_dir, not to the emergency directory — confirm.
            trainer.save_state()
            logger.error("Emergency model state saved to %s", emergency)
        except Exception as save_exc:
            logger.error("Emergency save also failed: %s", save_exc)
        raise
    # Save the in-memory model (the best checkpoint when load_best_model_at_end
    # was enabled), tokenizer and trainer state to output_dir.
    trainer.save_model()
    tokenizer.save_pretrained(cfg.output_dir)
    trainer.save_state()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    # Final evaluation. Perplexity is computed explicitly here because these
    # metric keys ("final_val_loss", ...) do not contain "eval", so
    # PerplexityCallback's key filter does not match them. The loss is clamped
    # at 20 before exp, as in the callback.
    final = {}
    if eval_ds is not None:
        m = trainer.evaluate(eval_dataset=eval_ds, metric_key_prefix="final_val")
        m["final_val_perplexity"] = math.exp(min(m["final_val_loss"], 20))
        final.update(m)
        # Per-source validation subsets exist only when several sources were loaded.
        for tag in ("med", "canon"):
            sub = bundle.get(f"validation_{tag}")
            if sub is not None:
                sub = sub.remove_columns([c for c in sub.column_names if c != "input_ids"])
                m = trainer.evaluate(eval_dataset=sub, metric_key_prefix=f"final_val_{tag}")
                m[f"final_val_{tag}_perplexity"] = math.exp(min(m[f"final_val_{tag}_loss"], 20))
                final.update(m)
    if "test" in bundle:
        test_ds = bundle["test"].remove_columns(
            [c for c in bundle["test"].column_names if c != "input_ids"])
        m = trainer.evaluate(eval_dataset=test_ds, metric_key_prefix="final_test")
        m["final_test_perplexity"] = math.exp(min(m["final_test_loss"], 20))
        final.update(m)
    if final:
        logger.info("Final evaluation:\n%s", json.dumps(final, indent=2))
        trainer.save_metrics("final_eval", final)
    logger.info("Done. Best model + tokenizer saved to %s", cfg.output_dir)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, never on import.
    main()