#!/usr/bin/env python # -*- coding: utf-8 -*- """ Domain-Adaptive Pretraining (DAPT) of xlm-roberta-large for Medieval Latin. Pipeline -------- 1. Load MedBerta (baseline) + CanonBerta (already upsampled) from the HF Hub. 2. Carve a *group-aware* held-out set: whole documents are held out (by ``document_id``) so no document leaks across train / validation / test. This matters for an honest perplexity and for downstream Loci-Similes / text-reuse evaluation. 3. DOC-SENTENCES packing: paragraphs are grouped by ``document_id`` (kept in document order), tokenized, and greedily packed into fixed 512-token sequences that never cross a document boundary. 4. MLM continued pretraining with the HF Trainer, bf16, SDPA attention, in-training evaluation (loss + perplexity + masked-token accuracy), best-model selection, and rich logging for later write-up. The script is built for a single NVIDIA RTX PRO 6000 Blackwell (96 GB) and is deliberately defensive about OOM (auto batch-size search, expandable CUDA segments, optional gradient checkpointing, eval logit reduction). Authentication -------------- The datasets are private. Log in once before running: huggingface-cli login # or: export HF_TOKEN=hf_... Quick check before spending GPU time (builds + caches datasets, prints stats, no training): python dapt_xlmr_pretrain.py --dry-run Typical run: python dapt_xlmr_pretrain.py \ --output-dir runs/dapt_xlmr_medlatin_v1 \ --num-train-epochs 3 \ --per-device-train-batch-size 32 \ --gradient-accumulation-steps 8 """ import argparse import hashlib import inspect import json import logging import math import os import platform import sys import traceback import unicodedata from collections import Counter from dataclasses import dataclass, field from datetime import datetime from typing import Dict, List, Optional # Must be set BEFORE torch is imported to take effect: reduces fragmentation, # which is the most common cause of "phantom" OOM on long runs. 
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")  # avoid fork warnings

import numpy as np
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint

# True when this transformers version's Trainer accepts ``processing_class=``
# (the newer keyword); otherwise the legacy keyword has to be used.
try:
    _USE_PROCESSING_CLASS = "processing_class" in inspect.signature(Trainer.__init__).parameters
except Exception:  # extremely defensive; fall back to legacy name
    _USE_PROCESSING_CLASS = False

logging.basicConfig(
    format="%(asctime)s | %(levelname)-7s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("dapt")


# --------------------------------------------------------------------------- #
# Config
# --------------------------------------------------------------------------- #
@dataclass
class Config:
    """All run settings; ``parse_args`` turns each field into a CLI option."""

    # Data
    med_dataset: str = "mschonhardt/MedBerta"
    canon_dataset: str = "mschonhardt/CanonBerta"
    dataset_split: str = "train"
    text_column: Optional[str] = None  # auto-detected if None
    doc_id_column: Optional[str] = None  # auto-detected if None
    order_column: Optional[str] = None  # optional within-doc ordering key

    # Model / tokenizer
    model_name: str = "xlm-roberta-large"
    max_seq_length: int = 512
    mlm_probability: float = 0.15
    attn_implementation: str = "sdpa"  # robust on Blackwell; "flash_attention_2" optional
    min_chunk_tokens: int = 64  # drop doc-final fragments shorter than this

    # Held-out construction (group-aware, by document_id)
    val_doc_fraction: float = 0.01
    test_doc_fraction: float = 0.0  # set >0 to also hold out a test set
    max_eval_docs_per_source: int = 400  # cap eval size to keep eval fast
    # Column used to stratify the held-out documents (e.g. "genre"/"subcorpus").
    # None or "" (--stratify-column "") disables stratification; when enabled,
    # raw-text corpora must contain this column.
    stratify_column: Optional[str] = "category"
    split_seed: int = 13

    # Which corpus to train ON. "combined" = ONE model on MedBerta + upsampled
    # CanonBerta (the default setup). "med"/"canon" = a single-corpus model,
    # e.g. for an ablation or two separate deliverables. Held-out eval is still
    # reported per-source where both are present.
    train_corpus: str = "combined"  # {"combined", "med", "canon"}
    dedup_train: bool = False  # collapse upsampling duplicates in TRAIN
    #                            (recommended for a standalone canon model)

    # Pre-flight diagnostics & text handling
    diagnose_tokenizer: bool = True  # run UNK / fertility report before training
    tokenizer_sample_size: int = 20000  # paragraphs sampled for the diagnostic
    normalize_nfc: bool = False  # apply Unicode NFC before tokenizing
    diagnose_only: bool = False  # run diagnostics + reports, then exit

    # Optimisation (RoBERTa-style defaults, tuned for *continued* pretraining)
    learning_rate: float = 1e-4  # peak LR for the combined DAPT corpus
    weight_decay: float = 0.01
    adam_beta1: float = 0.9
    adam_beta2: float = 0.98
    adam_epsilon: float = 1e-6
    max_grad_norm: float = 1.0
    warmup_ratio: float = 0.06
    lr_scheduler_type: str = "linear"
    num_train_epochs: float = 10.0  # ceiling; early stopping ends training
    max_steps: int = -1
    optim: str = "adamw_torch_fused"  # fused AdamW: faster on Blackwell
    early_stopping_patience: int = 5  # eval rounds w/o improvement before stop
    early_stopping_threshold: float = 1e-4  # min eval_loss delta to count as improvement

    # Throughput / memory (96 GB GDDR7 -> large effective batch)
    per_device_train_batch_size: int = 64
    per_device_eval_batch_size: int = 64
    gradient_accumulation_steps: int = 16  # effective batch ~1024 sequences
    gradient_checkpointing: bool = False
    auto_find_batch_size: bool = True
    torch_compile: bool = False  # enable for the final fixed-batch run
    dataloader_num_workers: int = 8
    eval_accumulation_steps: int = 50  # offload eval tensors to CPU periodically

    # Schedule of eval / logging / checkpointing
    eval_steps: int = 500
    logging_steps: int = 50
    save_steps: int = 500
    save_total_limit: int = 3

    # Bookkeeping
    output_dir: str = "runs/dapt_xlmr_medlatin"
    cache_dir: str = ".cache_packed"
    seed: int = 42
    report_to_wandb: bool = False
    preprocess_num_proc: int = 16
    resume: bool = True
    dry_run: bool = False


# --------------------------------------------------------------------------- #
# Column auto-detection
# --------------------------------------------------------------------------- #
TEXT_CANDIDATES = ["text", "paragraph", "content", "sentence", "passage"]
DOC_CANDIDATES = ["document_id", "doc_id", "docid", "document", "work_id", "work"]
ORDER_CANDIDATES = ["order", "paragraph_id", "par_id", "index", "idx", "n", "position", "seq"]


def _pick(columns: List[str], candidates: List[str], explicit: Optional[str], kind: str) -> Optional[str]:
    """Resolve which dataset column plays the role ``kind``.

    An explicitly requested column wins and must exist. Otherwise the first
    entry of ``candidates`` that appears in ``columns`` is returned, or
    ``None`` when no candidate matches.

    Raises:
        ValueError: if ``explicit`` is given but is not one of ``columns``.
    """
    if explicit is not None:
        if explicit not in columns:
            raise ValueError(f"Requested {kind} column '{explicit}' not in dataset columns {columns}")
        return explicit
    for c in candidates:
        if c in columns:
            return c
    return None


# --------------------------------------------------------------------------- #
# Dataset loading & group-aware split
# --------------------------------------------------------------------------- #
def load_source(repo: str, split: str, source_tag: str, cfg: Config):
    """Load one raw-text corpus from the Hub and normalise it to a canonical schema.

    The result has exactly these columns:
      * ``text``
      * ``document_id``, namespaced as ``"<source_tag>:<original id>"``
      * ``order``, taken from the data or falling back to the row index
      * ``stratum``, present only when ``cfg.stratify_column`` is set
      * ``source``, always equal to ``source_tag``

    Raises:
        ValueError: if the text / document-id columns cannot be resolved, or if
            a requested column (explicit or stratify) is missing.
    """
    logger.info("Loading %s (split=%s) ...", repo, split)
    ds = load_dataset(repo, split=split, token=True)
    cols = ds.column_names

    text_col = _pick(cols, TEXT_CANDIDATES, cfg.text_column, "text")
    doc_col = _pick(cols, DOC_CANDIDATES, cfg.doc_id_column, "document_id")
    order_col = _pick(cols, ORDER_CANDIDATES, cfg.order_column, "order")  # may be None
    if text_col is None or doc_col is None:
        raise ValueError(
            f"Could not resolve text/doc_id columns for {repo}. "
            f"Available columns: {cols}. "
            f"Pass --text-column / --doc-id-column explicitly."
        )
    logger.info("  %s -> text='%s', doc_id='%s', order='%s', rows=%d",
                source_tag, text_col, doc_col, order_col, len(ds))

    # Map source column -> canonical name.
    keep = {text_col: "text", doc_col: "document_id"}
    if order_col:
        keep[order_col] = "order"
    strat_col = cfg.stratify_column
    if strat_col:  # None or "" disables stratification (same rule as the pretokenized path)
        if strat_col not in cols:
            raise ValueError(
                f"--stratify-column '{strat_col}' not in {repo} columns {cols} "
                f"(pass --stratify-column '' to disable stratification)")
        keep[strat_col] = "stratum"

    # Drop everything we do not keep BEFORE renaming. This covers the shipped
    # `source` column, which is re-stamped authoritatively below. Dropping first
    # means an unrelated column that already carries a canonical name (e.g. a
    # raw "stratum" when stratification is off, or a raw "text" when
    # --text-column points elsewhere) can neither collide with a rename target
    # nor silently survive into the split logic.
    drop = [c for c in cols if c not in keep]
    if drop:
        ds = ds.remove_columns(drop)
    renames = {old: new for old, new in keep.items() if old != new}
    if renames:
        ds = ds.rename_columns(renames)

    if "order" not in ds.column_names:
        # Fallback when no ordering column exists: keep the on-disk row order.
        ds = ds.add_column("order", list(range(len(ds))))
    if "source" in ds.column_names:  # safety: never collide with add_column
        ds = ds.remove_columns(["source"])
    ds = ds.add_column("source", [source_tag] * len(ds))

    # document_id is namespaced per source so identical ids in both corpora
    # never collide during splitting/packing.
    ds = ds.map(lambda b: {"document_id": [f"{source_tag}:{d}" for d in b["document_id"]]},
                batched=True, desc=f"namespace doc ids ({source_tag})")
    return ds


def _n_heldout(n_docs: int, fraction: float, cap: int) -> int:
    """Return how many documents one held-out split should receive.

    The count is ``round(n_docs * fraction)``, capped at ``cap``. A positive
    fraction never rounds down to 0 when the source has more than one document,
    so small sources still get a held-out document.
    """
    if fraction <= 0 or cap <= 0:
        return 0
    n = int(round(n_docs * fraction))
    if n == 0 and n_docs > 1:
        n = 1
    return min(n, cap)


def choose_heldout_docs(ds, cfg: Config):
    """Pick whole documents to hold out for validation / test, deterministically.

    This is called once per source, so ``cfg.max_eval_docs_per_source`` caps
    each split per source. At least one document is always left for training.

    If a ``stratum`` column is present, each document gets its majority
    stratum (ties go to the first-seen value). Documents are then drawn
    proportionally per stratum, with at least one candidate per stratum. The
    shuffled candidates are cut to the requested total, so when there are more
    strata than held-out slots some strata can still be absent. Without a
    ``stratum`` column, sampling is uniform-random over document ids.

    Returns:
        ``(val_docs, test_docs)``: two disjoint sets of document ids.
    """
    rng = np.random.default_rng(cfg.split_seed)
    all_docs = set(ds["document_id"])
    n_total_docs = len(all_docs)
    cap = cfg.max_eval_docs_per_source
    n_val = _n_heldout(n_total_docs, cfg.val_doc_fraction, cap)
    n_test = _n_heldout(n_total_docs, cfg.test_doc_fraction, cap)
    # Never hold out every document of a source; validation has priority
    # because it drives early stopping / best-model selection.
    budget = max(0, n_total_docs - 1)
    n_val = min(n_val, budget)
    n_test = min(n_test, budget - n_val)
    n_pick = n_val + n_test
    if n_pick == 0:
        return set(), set()

    if "stratum" in ds.column_names:
        per_doc: Dict[str, Counter] = {}
        for d, s in zip(ds["document_id"], ds["stratum"]):
            per_doc.setdefault(d, Counter())[str(s)] += 1
        # most_common(1) returns the first-encountered stratum on ties.
        doc_stratum = {d: c.most_common(1)[0][0] for d, c in per_doc.items()}

        by_stratum: Dict[str, List[str]] = {}
        for d, s in doc_stratum.items():
            by_stratum.setdefault(s, []).append(d)

        picked: List[str] = []
        for _, docs in sorted(by_stratum.items()):
            docs = sorted(docs)  # sort first so the shuffle is reproducible
            rng.shuffle(docs)
            # proportional allocation, at least 1 doc per non-empty stratum
            k = max(1, int(round(n_pick * len(docs) / len(doc_stratum))))
            picked.extend(docs[:k])
        rng.shuffle(picked)
        picked = picked[:n_pick]
    else:
        doc_ids = sorted(all_docs)
        rng.shuffle(doc_ids)
        picked = doc_ids[:n_pick]

    return set(picked[:n_val]), set(picked[n_val:n_pick])


def representativeness_report(train_ds, val_ds, cfg: Config, out_dir: str,
                              tvd_warn_threshold: float = 0.15):
    """Compare the held-out validation set against train on cheap signals.

    The signals are:
      * paragraph character-length distribution;
      * stratum proportions, if both sets have a ``stratum`` column.

    A skewed eval set is therefore caught before training. The report is
    written to ``<out_dir>/eval_representativeness.json`` (UTF-8) and
    returned. A warning is logged when the stratum total-variation distance
    exceeds ``tvd_warn_threshold``.
    """
    def char_lengths(ds):
        out: List[int] = []
        for batch in ds.iter(batch_size=10000):
            out.extend(len(t) for t in batch["text"])
        # A single 0 keeps mean/percentiles defined for an empty split.
        return np.asarray(out, dtype=np.int64) if out else np.array([0], dtype=np.int64)

    def pct(a):
        if len(a) == 0:
            return {}
        return {f"p{p}": float(np.percentile(a, p)) for p in (5, 25, 50, 75, 95)}

    tr_len, va_len = char_lengths(train_ds), char_lengths(val_ds)
    report = {
        "train_paragraphs": len(train_ds),
        "val_paragraphs": len(val_ds),
        "char_len_train": {"mean": float(tr_len.mean()), **pct(tr_len)},
        "char_len_val": {"mean": float(va_len.mean()), **pct(va_len)},
    }

    if "stratum" in train_ds.column_names and "stratum" in val_ds.column_names:
        def props(ds):
            c = Counter(str(s) for s in ds["stratum"])
            tot = sum(c.values()) or 1
            return {k: v / tot for k, v in c.items()}

        tr_p, va_p = props(train_ds), props(val_ds)
        keys = sorted(set(tr_p) | set(va_p))
        # total variation distance between the two stratum distributions
        tvd = 0.5 * sum(abs(tr_p.get(k, 0.0) - va_p.get(k, 0.0)) for k in keys)
        report["stratum_proportions_train"] = {k: round(tr_p.get(k, 0.0), 4) for k in keys}
        report["stratum_proportions_val"] = {k: round(va_p.get(k, 0.0), 4) for k in keys}
        report["stratum_total_variation_distance"] = round(tvd, 4)
        missing = [k for k in tr_p if k not in va_p]
        if missing:
            report["strata_absent_from_val"] = missing

    os.makedirs(out_dir, exist_ok=True)
    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII stratum names, which
    # the locale-default encoding may not be able to represent.
    with open(os.path.join(out_dir, "eval_representativeness.json"), "w", encoding="utf-8") as fh:
        json.dump(report, fh, indent=2, ensure_ascii=False)
    logger.info("Eval representativeness:\n%s", json.dumps(report, indent=2, ensure_ascii=False))

    if report.get("strata_absent_from_val"):
        logger.warning("Strata present in train but MISSING from validation: %s",
                       report["strata_absent_from_val"])
    if report.get("stratum_total_variation_distance", 0.0) > tvd_warn_threshold:
        logger.warning("Validation stratum distribution differs notably from train "
                       "(TVD=%.3f); consider --stratify-column or a larger eval cap.",
                       report["stratum_total_variation_distance"])
    return report


def dedup_rows(ds):
    """Drop exact ``(document_id, order, text)`` duplicates, keeping first occurrences.

    Upsampling repeats paragraphs; collapsing them ensures each held-out (or,
    with --dedup-train, training) paragraph is counted once. Requires the
    ``document_id``, ``order`` and ``text`` columns produced by ``load_source``.
    """
    seen = set()
    keep_idx = []
    for i, (d, o, t) in enumerate(zip(ds["document_id"], ds["order"], ds["text"])):
        # 16-byte digest keeps the `seen` set small even for long paragraphs.
        key = hashlib.blake2b(f"{d}\x1f{o}\x1f{t}".encode("utf-8"), digest_size=16).digest()
        if key not in seen:
            seen.add(key)
            keep_idx.append(i)
    return ds.select(keep_idx)


# --------------------------------------------------------------------------- #
# Tokenization + DOC-SENTENCES packing
#
--------------------------------------------------------------------------- # def maybe_nfc(text: str, cfg: Config) -> str: return unicodedata.normalize("NFC", text) if cfg.normalize_nfc else text def diagnose_tokenizer(texts: List[str], tokenizer, cfg: Config, out_dir: str): """Pre-flight check for SentencePiece behaviour on historical text.""" unk = tokenizer.unk_token unk_id = tokenizer.unk_token_id rng = np.random.default_rng(cfg.split_seed) if len(texts) > cfg.tokenizer_sample_size: idx = rng.choice(len(texts), size=cfg.tokenizer_sample_size, replace=False) texts = [texts[i] for i in idx] total_tokens = 0 total_unk = 0 total_chars = 0 total_words = 0 tok_lengths: List[int] = [] unk_char_counter: Counter = Counter() pathological: List[dict] = [] for t in texts: t = maybe_nfc(t, cfg) toks = tokenizer.tokenize(t) n = len(toks) n_unk = sum(1 for x in toks if x == unk) total_tokens += n total_unk += n_unk total_chars += len(t) total_words += max(1, len(t.split())) tok_lengths.append(n) if len(t) >= 20 and n / max(1, len(t)) > 1.5: if len(pathological) < 25: pathological.append({"chars": len(t), "tokens": n, "preview": t[:80]}) uniq_chars = set() for t in texts[:5000]: uniq_chars.update(maybe_nfc(t, cfg)) risky_chars = {} for ch in uniq_chars: if ch.isspace(): continue ids = tokenizer(ch, add_special_tokens=False)["input_ids"] if unk_id is not None and unk_id in ids: name = unicodedata.name(ch, "UNNAMED") risky_chars[ch] = {"codepoint": f"U+{ord(ch):04X}", "name": name} tok_lengths_arr = np.asarray(tok_lengths) if tok_lengths else np.array([0]) report = { "sampled_paragraphs": len(texts), "normalize_nfc": cfg.normalize_nfc, "unk_token": unk, "unk_rate": round(total_unk / max(1, total_tokens), 6), "total_unk_tokens": total_unk, "fertility_tokens_per_word": round(total_tokens / max(1, total_words), 3), "fertility_tokens_per_char": round(total_tokens / max(1, total_chars), 3), "tokens_per_paragraph": { "mean": float(tok_lengths_arr.mean()), "p50": 
float(np.percentile(tok_lengths_arr, 50)), "p95": float(np.percentile(tok_lengths_arr, 95)), "max": int(tok_lengths_arr.max()), "share_over_max_seq_len": round( float((tok_lengths_arr > cfg.max_seq_length).mean()), 4), }, "num_risky_unk_characters": len(risky_chars), "risky_unk_characters": dict(sorted(risky_chars.items())[:100]), "pathological_examples": pathological, } os.makedirs(out_dir, exist_ok=True) with open(os.path.join(out_dir, "tokenizer_diagnostics.json"), "w") as fh: json.dump(report, fh, indent=2, ensure_ascii=False) logger.info("Tokenizer diagnostics: unk_rate=%.4f%% fertility=%.2f tok/word " "risky_chars=%d", report["unk_rate"] * 100, report["fertility_tokens_per_word"], report["num_risky_unk_characters"]) if report["unk_rate"] > 0.005: logger.warning("High rate (%.3f%%). Inspect risky_unk_characters in " "tokenizer_diagnostics.json; consider --normalize-nfc or a " "transliteration/cleanup pass for medieval glyphs.", report["unk_rate"] * 100) if risky_chars: sample = ", ".join(f"{c} ({m['codepoint']})" for c, m in list(risky_chars.items())[:15]) logger.warning("Characters mapping to (sample): %s", sample) return report def tokenize_paragraphs(ds, tokenizer, cfg: Config): def _tok(batch): texts = [maybe_nfc(t, cfg) for t in batch["text"]] if cfg.normalize_nfc else batch["text"] enc = tokenizer(texts, add_special_tokens=False, truncation=False, return_attention_mask=False) return {"input_ids": enc["input_ids"]} return ds.map( _tok, batched=True, num_proc=cfg.preprocess_num_proc, remove_columns=["text"], desc="tokenize paragraphs", ) def pack_doc_sentences(ds, tokenizer, cfg: Config, desc: str) -> Dataset: bos = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id eos = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id if bos is None or eos is None: raise ValueError("Tokenizer has no bos/cls or eos/sep token; cannot pack.") max_content = cfg.max_seq_length - 2 sort_keys = 
["document_id", "order"] if "order" in ds.column_names else ["document_id"] ds = ds.sort(sort_keys) def generator(): buffer: List[int] = [] cur_doc = None cur_src = None def flush_full(): nonlocal buffer while len(buffer) >= max_content: chunk = buffer[:max_content] buffer = buffer[max_content:] yield {"input_ids": [bos] + chunk + [eos], "document_id": cur_doc, "source": cur_src} for batch in ds.iter(batch_size=2000): ids_col = batch["input_ids"] doc_col = batch["document_id"] src_col = batch["source"] for ids, doc, src in zip(ids_col, doc_col, src_col): if doc != cur_doc: if buffer and len(buffer) >= cfg.min_chunk_tokens: yield {"input_ids": [bos] + buffer + [eos], "document_id": cur_doc, "source": cur_src} buffer = [] cur_doc, cur_src = doc, src buffer.extend(ids) yield from flush_full() if buffer and len(buffer) >= cfg.min_chunk_tokens: yield {"input_ids": [bos] + buffer + [eos], "document_id": cur_doc, "source": cur_src} packed = Dataset.from_generator(generator, cache_dir=cfg.cache_dir) logger.info(" packed %s -> %d sequences (max_len=%d)", desc, len(packed), cfg.max_seq_length) return packed def diagnose_pretokenized(ds, tokenizer, cfg: Config, out_dir: str): unk_id = tokenizer.unk_token_id bos_id = tokenizer.bos_token_id if tokenizer.bos_token_id is not None else tokenizer.cls_token_id eos_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.sep_token_id n = len(ds) m = min(cfg.tokenizer_sample_size, n) idx = np.random.default_rng(cfg.split_seed).choice(n, size=m, replace=False) sample = ds.select(idx) total_tokens = 0 total_unk = 0 lengths: List[int] = [] bos_ok = 0 eos_ok = 0 over_len = 0 for ids in sample["input_ids"]: L = len(ids) lengths.append(L) total_tokens += L if unk_id is not None: total_unk += ids.count(unk_id) if L and ids[0] == bos_id: bos_ok += 1 if L and ids[-1] == eos_id: eos_ok += 1 if L > cfg.max_seq_length: over_len += 1 arr = np.asarray(lengths) if lengths else np.array([0]) report = { "mode": "pretokenized", 
"sampled_sequences": int(m), "total_sequences": int(n), "unk_rate": round(total_unk / max(1, total_tokens), 6), "total_unk_tokens": int(total_unk), "seq_len": {"mean": float(arr.mean()), "p50": float(np.percentile(arr, 50)), "p95": float(np.percentile(arr, 95)), "max": int(arr.max())}, "share_over_max_seq_len": round(float(over_len) / max(1, m), 4), "starts_with_bos_rate": round(bos_ok / max(1, m), 4), "ends_with_eos_rate": round(eos_ok / max(1, m), 4), } os.makedirs(out_dir, exist_ok=True) with open(os.path.join(out_dir, "tokenizer_diagnostics.json"), "w") as fh: json.dump(report, fh, indent=2, ensure_ascii=False) logger.info("Pretokenized diagnostics: unk_rate=%.4f%% mean_len=%.1f " "bos=%.0f%% eos=%.0f%%", report["unk_rate"] * 100, report["seq_len"]["mean"], report["starts_with_bos_rate"] * 100, report["ends_with_eos_rate"] * 100) if report["unk_rate"] > 0.005: logger.warning("High rate (%.3f%%) in the pre-tokenized data — verify it was " "tokenized with the SAME tokenizer (%s).", report["unk_rate"] * 100, cfg.model_name) if report["starts_with_bos_rate"] < 0.5 or report["ends_with_eos_rate"] < 0.5: logger.warning("Many sequences lack bos/eos boundaries; if your packing omitted special " "tokens, MLM quality may suffer.") if report["share_over_max_seq_len"] > 0: logger.warning("%.2f%% of sequences exceed max_seq_length=%d and will be truncated.", report["share_over_max_seq_len"] * 100, cfg.max_seq_length) return report def _load_dataset_any(name: str, split: str, token: bool): if os.path.isdir(name): from datasets import load_from_disk d = load_from_disk(name) if hasattr(d, "keys") and not hasattr(d, "column_names"): # DatasetDict d = d[split] return d return load_dataset(name, split=split, token=token) def _ids_hash(ids) -> bytes: return hashlib.blake2b(np.asarray(ids, dtype=np.int32).tobytes(), digest_size=12).digest() def build_pretokenized_bundle(cfg: Config, tokenizer, selected, cache_root: str): max_len = cfg.max_seq_length DOCID_CANDS = ("doc_id", 
"document_id", "docid") parts = [] for repo, tag in selected: ds = _load_dataset_any(repo, cfg.dataset_split, True) docid_col = next((c for c in DOCID_CANDS if c in ds.column_names), None) # ----- WISSENSCHAFTLICHE KORREKTUR: Metadaten beibehalten ----- keep = ["input_ids"] if "attention_mask" in ds.column_names: keep.append("attention_mask") if docid_col: keep.append(docid_col) strat_col = cfg.stratify_column if strat_col and strat_col in ds.column_names: keep.append(strat_col) ds = ds.remove_columns([c for c in ds.column_names if c not in keep]) if docid_col and docid_col != "document_id": ds = ds.rename_column(docid_col, "document_id") if strat_col and strat_col in ds.column_names and strat_col != "stratum": ds = ds.rename_column(strat_col, "stratum") # -------------------------------------------------------------- ds = ds.add_column("source", [tag] * len(ds)) parts.append(ds) logger.info(" loaded pretokenized %s -> %d sequences (doc_id=%s)", tag, len(ds), docid_col) full = concatenate_datasets(parts) if len(parts) > 1 else parts[0] has_doc = "document_id" in full.column_names if has_doc: full = full.map( lambda b: {"document_id": [f"{s}:{d}" for s, d in zip(b["source"], b["document_id"])]}, batched=True, desc="namespace doc ids") has_mask = "attention_mask" in full.column_names def _fix(batch): out = [] masks = batch["attention_mask"] if has_mask else [None] * len(batch["input_ids"]) for ids, am in zip(batch["input_ids"], masks): if am is not None: ids = [t for t, a in zip(ids, am) if a == 1] if len(ids) > max_len: ids = ids[:max_len] out.append(ids) return {"input_ids": out} full = full.map(_fix, batched=True, num_proc=cfg.preprocess_num_proc, remove_columns=(["attention_mask"] if has_mask else []), desc="normalize pretokenized") out_dir = cfg.output_dir if cfg.diagnose_tokenizer: diagnose_pretokenized(full, tokenizer, cfg, out_dir) rng = np.random.default_rng(cfg.split_seed) train_idx: List[int] = [] val_idx: List[int] = [] test_idx: List[int] = [] if 
has_doc: docs = full["document_id"] unique_docs = sorted(set(docs)) n_val = min(int(round(len(unique_docs) * cfg.val_doc_fraction)), cfg.max_eval_docs_per_source) n_test = min(int(round(len(unique_docs) * cfg.test_doc_fraction)), cfg.max_eval_docs_per_source) n_pick = n_val + n_test # ----- WISSENSCHAFTLICHE KORREKTUR: Stratified Split ----- if "stratum" in full.column_names: logger.info(f"Applying STRATIFIED document split based on column '{cfg.stratify_column}'.") strata = full["stratum"] doc_stratum: Dict[str, str] = {} for d, s in zip(docs, strata): doc_stratum.setdefault(d, str(s)) by_stratum: Dict[str, List[str]] = {} for d, s in doc_stratum.items(): by_stratum.setdefault(s, []).append(d) picked: List[str] = [] for s, d_list in sorted(by_stratum.items()): d_list = sorted(d_list) rng.shuffle(d_list) k = max(1, int(round(n_pick * len(d_list) / len(doc_stratum)))) picked.extend(d_list[:min(k, len(d_list))]) rng.shuffle(picked) val_docs = set(picked[:n_val]) test_docs = set(picked[n_val:n_val + n_test]) split_kind = "document-level STRATIFIED group-aware" else: rng.shuffle(unique_docs) val_docs = set(unique_docs[:n_val]) test_docs = set(unique_docs[n_val:n_val + n_test]) split_kind = "document-level group-aware" # --------------------------------------------------------- val_cand, test_cand = [], [] for i, d in enumerate(docs): if d in val_docs: val_cand.append((i, d)) elif d in test_docs: test_cand.append((i, d)) else: train_idx.append(i) for cand, dst in ((val_cand, val_idx), (test_cand, test_idx)): seen = set() for i, d in cand: key = (d, _ids_hash(full[i]["input_ids"])) if key not in seen: seen.add(key) dst.append(i) else: logger.warning("No doc_id/document_id column found in the pre-tokenized data; " "falling back to a weaker sequence-level holdout.") hashes = [_ids_hash(ids) for ids in full["input_ids"]] uniq = sorted(set(hashes)) rng.shuffle(uniq) n_val = min(int(round(len(uniq) * cfg.val_doc_fraction)), cfg.max_eval_docs_per_source * 4) n_test = 
min(int(round(len(uniq) * cfg.test_doc_fraction)), cfg.max_eval_docs_per_source * 4) val_h = set(uniq[:n_val]) test_h = set(uniq[n_val:n_val + n_test]) seen_v, seen_t = set(), set() for i, h in enumerate(hashes): if h in val_h: if h not in seen_v: seen_v.add(h); val_idx.append(i) elif h in test_h: if h not in seen_t: seen_t.add(h); test_idx.append(i) else: train_idx.append(i) split_kind = "sequence-level dedup" bundle: Dict[str, Dataset] = {"train": full.select(train_idx)} if val_idx: bundle["validation"] = full.select(val_idx) if len(set(bundle["validation"]["source"])) > 1: for tag in ("med", "canon"): sub = bundle["validation"].filter(lambda b: [s == tag for s in b["source"]], batched=True) if len(sub): bundle[f"validation_{tag}"] = sub if test_idx: bundle["test"] = full.select(test_idx) logger.info("Pretokenized split (%s) -> train=%d val=%d test=%d", split_kind, len(bundle["train"]), len(val_idx), len(test_idx)) def len_stats(d): ls = np.asarray([len(x) for x in d["input_ids"]]) if len(d) else np.array([0]) return {"mean": float(ls.mean()), "p50": float(np.percentile(ls, 50)), "p95": float(np.percentile(ls, 95))} rep = {"split_kind": split_kind, "train_sequences": len(bundle["train"]), "val_sequences": len(val_idx), "len_train": len_stats(bundle["train"])} if val_idx: rep["len_val"] = len_stats(bundle["validation"]) with open(os.path.join(out_dir, "eval_representativeness.json"), "w") as fh: json.dump(rep, fh, indent=2) if cfg.diagnose_only: logger.info("--diagnose-only set: wrote diagnostics to %s, exiting before training.", out_dir) sys.exit(0) os.makedirs(cache_root, exist_ok=True) for name, d in bundle.items(): d.save_to_disk(os.path.join(cache_root, name)) logger.info("Saved pretokenized datasets to %s", cache_root) return bundle def build_or_load_packed(cfg: Config, tokenizer): sig = json.dumps({ "med": cfg.med_dataset, "canon": cfg.canon_dataset, "split": cfg.dataset_split, "model": cfg.model_name, "max_len": cfg.max_seq_length, "min_chunk": 
cfg.min_chunk_tokens, "val_frac": cfg.val_doc_fraction, "test_frac": cfg.test_doc_fraction, "max_eval": cfg.max_eval_docs_per_source, "split_seed": cfg.split_seed, "text_col": cfg.text_column, "doc_col": cfg.doc_id_column, "order_col": cfg.order_column, "stratify": cfg.stratify_column, "nfc": cfg.normalize_nfc, "train_corpus": cfg.train_corpus, "dedup_train": cfg.dedup_train, }, sort_keys=True) key = hashlib.blake2b(sig.encode("utf-8"), digest_size=12).hexdigest() cache_root = os.path.join(cfg.cache_dir, f"packed_{key}") if os.path.isdir(cache_root) and not cfg.diagnose_only: from datasets import load_from_disk logger.info("Loading cached packed datasets from %s", cache_root) bundle = {name: load_from_disk(os.path.join(cache_root, name)) for name in os.listdir(cache_root) if os.path.isdir(os.path.join(cache_root, name))} return bundle if cfg.train_corpus not in {"combined", "med", "canon"}: raise ValueError(f"--train-corpus must be combined|med|canon, got {cfg.train_corpus!r}") selected = [] if cfg.train_corpus in {"combined", "med"}: selected.append((cfg.med_dataset, "med")) if cfg.train_corpus in {"combined", "canon"}: selected.append((cfg.canon_dataset, "canon")) probe = _load_dataset_any(selected[0][0], cfg.dataset_split, True) if "input_ids" in probe.column_names: logger.warning( "Dataset %s is PRE-TOKENIZED (columns=%s). 
Skipping raw-text tokenization " "and DOC-SENTENCES packing.", selected[0][0], probe.column_names) del probe return build_pretokenized_bundle(cfg, tokenizer, selected, cache_root) del probe sources = [] if cfg.train_corpus in {"combined", "med"}: sources.append(load_source(cfg.med_dataset, cfg.dataset_split, "med", cfg)) if cfg.train_corpus in {"combined", "canon"}: sources.append(load_source(cfg.canon_dataset, cfg.dataset_split, "canon", cfg)) logger.info("Training corpus = %s (%d source dataset(s))", cfg.train_corpus, len(sources)) train_parts, val_parts, test_parts = [], [], [] for src in sources: tag = src["source"][0] val_docs, test_docs = choose_heldout_docs(src, cfg) in_val = src.filter(lambda b: [d in val_docs for d in b["document_id"]], batched=True, desc=f"select val ({tag})") in_test = src.filter(lambda b: [d in test_docs for d in b["document_id"]], batched=True, desc=f"select test ({tag})") held = val_docs | test_docs in_train = src.filter(lambda b: [d not in held for d in b["document_id"]], batched=True, desc=f"select train ({tag})") train_parts.append(in_train) if len(in_val): val_parts.append(dedup_rows(in_val)) if len(in_test): test_parts.append(dedup_rows(in_test)) train_raw = concatenate_datasets(train_parts) if cfg.dedup_train: before = len(train_raw) train_raw = dedup_rows(train_raw) logger.info("dedup_train: collapsed %d -> %d training paragraphs " "(upsampling duplicates removed).", before, len(train_raw)) val_raw = concatenate_datasets(val_parts) if val_parts else None logger.info("Raw paragraph counts -> train=%d val=%d test=%d", len(train_raw), sum(len(p) for p in val_parts), sum(len(p) for p in test_parts)) out_dir = cfg.output_dir if cfg.diagnose_tokenizer: n = len(train_raw) m = min(cfg.tokenizer_sample_size, n) idx = np.random.default_rng(cfg.split_seed).choice(n, size=m, replace=False) sample_texts = train_raw.select(idx)["text"] diagnose_tokenizer(sample_texts, tokenizer, cfg, out_dir) if val_raw is not None: 
representativeness_report(train_raw, val_raw, cfg, out_dir) if cfg.diagnose_only: logger.info("--diagnose-only set: wrote diagnostics to %s, exiting before " "packing/training.", out_dir) sys.exit(0) bundle: Dict[str, Dataset] = {} bundle["train"] = pack_doc_sentences(tokenize_paragraphs(train_raw, tokenizer, cfg), tokenizer, cfg, "train") if val_raw is not None: bundle["validation"] = pack_doc_sentences(tokenize_paragraphs(val_raw, tokenizer, cfg), tokenizer, cfg, "validation") if len(sources) > 1: for tag in ("med", "canon"): sub = val_raw.filter(lambda b: [s == tag for s in b["source"]], batched=True) if len(sub): bundle[f"validation_{tag}"] = pack_doc_sentences( tokenize_paragraphs(sub, tokenizer, cfg), tokenizer, cfg, f"validation_{tag}") if test_parts: test_raw = concatenate_datasets(test_parts) bundle["test"] = pack_doc_sentences(tokenize_paragraphs(test_raw, tokenizer, cfg), tokenizer, cfg, "test") os.makedirs(cache_root, exist_ok=True) for name, d in bundle.items(): d.save_to_disk(os.path.join(cache_root, name)) logger.info("Saved packed datasets to %s", cache_root) return bundle # --------------------------------------------------------------------------- # # Metrics # --------------------------------------------------------------------------- # def preprocess_logits_for_metrics(logits, labels): if isinstance(logits, tuple): logits = logits[0] return logits.argmax(dim=-1) def compute_metrics(eval_pred): preds, labels = eval_pred labels = labels.reshape(-1) preds = preds.reshape(-1) mask = labels != -100 if mask.sum() == 0: return {"masked_accuracy": 0.0} correct = (preds[mask] == labels[mask]).sum() return {"masked_accuracy": float(correct) / float(mask.sum())} class PerplexityCallback(TrainerCallback): def on_evaluate(self, args, state, control, metrics=None, **kwargs): if not metrics: return for k in list(metrics.keys()): if k.endswith("loss") and ("eval" in k): try: metrics[k.replace("loss", "perplexity")] = math.exp(min(metrics[k], 20)) except 
OverflowError: metrics[k.replace("loss", "perplexity")] = float("inf") class JsonlLoggingCallback(TrainerCallback): def __init__(self, path: str): self.path = path os.makedirs(os.path.dirname(path), exist_ok=True) def on_log(self, args, state, control, logs=None, **kwargs): if not logs: return record = dict(logs) record["step"] = state.global_step record["epoch"] = state.epoch record["wall_time"] = datetime.now().isoformat(timespec="seconds") with open(self.path, "a", encoding="utf-8") as fh: fh.write(json.dumps(record, ensure_ascii=False) + "\n") fh.flush() os.fsync(fh.fileno()) # --------------------------------------------------------------------------- # # Environment / run metadata # --------------------------------------------------------------------------- # def capture_environment(cfg: Config) -> dict: import transformers import datasets as ds_lib info = { "timestamp": datetime.now().isoformat(timespec="seconds"), "python": platform.python_version(), "platform": platform.platform(), "torch": torch.__version__, "transformers": transformers.__version__, "datasets": ds_lib.__version__, "cuda_available": torch.cuda.is_available(), } if torch.cuda.is_available(): info["cuda"] = torch.version.cuda info["gpu_name"] = torch.cuda.get_device_name(0) props = torch.cuda.get_device_properties(0) info["gpu_total_memory_gb"] = round(props.total_memory / 1024**3, 1) cap = f"{props.major}.{props.minor}" info["gpu_capability"] = cap info["bf16_supported"] = torch.cuda.is_bf16_supported() try: arch_list = torch.cuda.get_arch_list() except Exception: arch_list = [] info["torch_arch_list"] = arch_list sm_tag = f"sm_{props.major}{props.minor}" info["gpu_arch_supported_by_torch"] = any(sm_tag == a for a in arch_list) return info # --------------------------------------------------------------------------- # # Argparse # --------------------------------------------------------------------------- # def parse_args() -> Config: cfg = Config() p = 
argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter) for f in cfg.__dataclass_fields__.values(): name = "--" + f.name.replace("_", "-") default = getattr(cfg, f.name) if isinstance(default, bool): if default: p.add_argument("--no-" + f.name.replace("_", "-"), dest=f.name, action="store_false") else: p.add_argument(name, dest=f.name, action="store_true") elif default is None: p.add_argument(name, dest=f.name, default=None, type=str) else: p.add_argument(name, dest=f.name, default=default, type=type(default)) args = p.parse_args() return Config(**vars(args)) # --------------------------------------------------------------------------- # # Main # --------------------------------------------------------------------------- # def main(): cfg = parse_args() set_seed(cfg.seed) torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.set_float32_matmul_precision("high") os.makedirs(cfg.output_dir, exist_ok=True) file_handler = logging.FileHandler(os.path.join(cfg.output_dir, "train.log")) file_handler.setFormatter(logging.Formatter( "%(asctime)s | %(levelname)-7s | %(name)s | %(message)s", "%Y-%m-%d %H:%M:%S")) logging.getLogger().addHandler(file_handler) env = capture_environment(cfg) logger.info("Environment:\n%s", json.dumps(env, indent=2)) with open(os.path.join(cfg.output_dir, "run_metadata.json"), "w") as fh: json.dump({"config": cfg.__dict__, "environment": env}, fh, indent=2) if not env["cuda_available"]: logger.warning("CUDA not available — this script is intended to run on a GPU server.") else: if not env.get("bf16_supported", False): logger.warning("bf16 not reported as supported on this GPU; consider fp16.") if env.get("torch_arch_list") and not env.get("gpu_arch_supported_by_torch", True): logger.error( "This torch build (%s, arch_list=%s) has NO kernels for your GPU " "(capability %s). 
On Blackwell (RTX PRO 6000) install a cu128+ build, e.g.:\n" " pip install --upgrade torch --index-url https://download.pytorch.org/whl/cu128\n" "CUDA ops will fail until this matches.", env["torch"], env.get("torch_arch_list"), env.get("gpu_capability")) if cfg.save_steps % cfg.eval_steps != 0: new_save = max(cfg.eval_steps, (cfg.save_steps // cfg.eval_steps) * cfg.eval_steps) logger.warning("save_steps (%d) not a multiple of eval_steps (%d); adjusting to %d.", cfg.save_steps, cfg.eval_steps, new_save) cfg.save_steps = new_save tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True) tokenizer.model_max_length = cfg.max_seq_length bundle = build_or_load_packed(cfg, tokenizer) train_ds = bundle["train"] eval_ds = bundle.get("validation") def tok_count(d): lengths = d.map(lambda b: {"len": [len(x) for x in b["input_ids"]]}, batched=True, remove_columns=d.column_names, desc="count tokens")["len"] return int(sum(lengths)), float(np.mean(lengths)) stats = {} for name, d in bundle.items(): n_tok, mean_len = tok_count(d) stats[name] = {"sequences": len(d), "tokens": n_tok, "mean_seq_len": round(mean_len, 1)} logger.info("Packed dataset statistics:\n%s", json.dumps(stats, indent=2)) with open(os.path.join(cfg.output_dir, "dataset_stats.json"), "w") as fh: json.dump(stats, fh, indent=2) if cfg.dry_run: logger.info("--dry-run set: datasets built and cached, exiting before training.") return train_ds = train_ds.remove_columns([c for c in train_ds.column_names if c != "input_ids"]) if eval_ds is not None: eval_ds = eval_ds.remove_columns([c for c in eval_ds.column_names if c != "input_ids"]) model = AutoModelForMaskedLM.from_pretrained( cfg.model_name, attn_implementation=cfg.attn_implementation, ) if cfg.gradient_checkpointing: model.config.use_cache = False logger.info("Model parameters: %.1fM", sum(p.numel() for p in model.parameters()) / 1e6) collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=True, mlm_probability=cfg.mlm_probability, 
pad_to_multiple_of=8, ) report_to = ["tensorboard"] if cfg.report_to_wandb: report_to.append("wandb") do_eval = eval_ds is not None args = TrainingArguments( output_dir=cfg.output_dir, seed=cfg.seed, data_seed=cfg.seed, num_train_epochs=cfg.num_train_epochs, max_steps=cfg.max_steps, learning_rate=cfg.learning_rate, weight_decay=cfg.weight_decay, adam_beta1=cfg.adam_beta1, adam_beta2=cfg.adam_beta2, adam_epsilon=cfg.adam_epsilon, max_grad_norm=cfg.max_grad_norm, warmup_ratio=cfg.warmup_ratio, lr_scheduler_type=cfg.lr_scheduler_type, optim=cfg.optim, bf16=bool(env.get("bf16_supported")), per_device_train_batch_size=cfg.per_device_train_batch_size, per_device_eval_batch_size=cfg.per_device_eval_batch_size, gradient_accumulation_steps=cfg.gradient_accumulation_steps, gradient_checkpointing=cfg.gradient_checkpointing, auto_find_batch_size=cfg.auto_find_batch_size, torch_compile=cfg.torch_compile, dataloader_num_workers=cfg.dataloader_num_workers, dataloader_pin_memory=True, eval_accumulation_steps=cfg.eval_accumulation_steps, eval_strategy="steps" if do_eval else "no", eval_steps=cfg.eval_steps if do_eval else None, logging_strategy="steps", logging_steps=cfg.logging_steps, logging_first_step=True, save_strategy="steps", save_steps=cfg.save_steps, save_total_limit=cfg.save_total_limit, load_best_model_at_end=do_eval, metric_for_best_model="eval_loss" if do_eval else None, greater_is_better=False, report_to=report_to, run_name=os.path.basename(cfg.output_dir.rstrip("/")), include_num_input_tokens_seen=True, logging_dir=os.path.join(cfg.output_dir, "tb"), ) callbacks = [PerplexityCallback(), JsonlLoggingCallback(os.path.join(cfg.output_dir, "training_log.jsonl"))] if do_eval and cfg.early_stopping_patience > 0: callbacks.append(EarlyStoppingCallback( early_stopping_patience=cfg.early_stopping_patience, early_stopping_threshold=cfg.early_stopping_threshold)) tok_kwarg = {"processing_class": tokenizer} if _USE_PROCESSING_CLASS else {"tokenizer": tokenizer} trainer = 
Trainer( model=model, args=args, train_dataset=train_ds, eval_dataset=eval_ds, data_collator=collator, compute_metrics=compute_metrics if do_eval else None, preprocess_logits_for_metrics=preprocess_logits_for_metrics if do_eval else None, callbacks=callbacks, **tok_kwarg, ) resume_from = None if cfg.resume: last = get_last_checkpoint(cfg.output_dir) if last: logger.info("Resuming from checkpoint %s", last) resume_from = last try: train_result = trainer.train(resume_from_checkpoint=resume_from) except BaseException as exc: tb_path = os.path.join(cfg.output_dir, "crash_traceback.txt") with open(tb_path, "w", encoding="utf-8") as fh: fh.write(f"Crashed at {datetime.now().isoformat()}\n\n") traceback.print_exc(file=fh) logger.error("Training crashed (%s: %s). Traceback written to %s", type(exc).__name__, exc, tb_path) try: emergency = os.path.join(cfg.output_dir, "emergency_checkpoint") trainer.save_model(emergency) trainer.save_state() logger.error("Emergency model state saved to %s", emergency) except Exception as save_exc: logger.error("Emergency save also failed: %s", save_exc) raise trainer.save_model() tokenizer.save_pretrained(cfg.output_dir) trainer.save_state() metrics = train_result.metrics trainer.log_metrics("train", metrics) trainer.save_metrics("train", metrics) final = {} if eval_ds is not None: m = trainer.evaluate(eval_dataset=eval_ds, metric_key_prefix="final_val") m["final_val_perplexity"] = math.exp(min(m["final_val_loss"], 20)) final.update(m) for tag in ("med", "canon"): sub = bundle.get(f"validation_{tag}") if sub is not None: sub = sub.remove_columns([c for c in sub.column_names if c != "input_ids"]) m = trainer.evaluate(eval_dataset=sub, metric_key_prefix=f"final_val_{tag}") m[f"final_val_{tag}_perplexity"] = math.exp(min(m[f"final_val_{tag}_loss"], 20)) final.update(m) if "test" in bundle: test_ds = bundle["test"].remove_columns( [c for c in bundle["test"].column_names if c != "input_ids"]) m = trainer.evaluate(eval_dataset=test_ds, 
metric_key_prefix="final_test") m["final_test_perplexity"] = math.exp(min(m["final_test_loss"], 20)) final.update(m) if final: logger.info("Final evaluation:\n%s", json.dumps(final, indent=2)) trainer.save_metrics("final_eval", final) logger.info("Done. Best model + tokenizer saved to %s", cfg.output_dir) if __name__ == "__main__": main()