SignalMod / src /models /transformer_trainer.py
Mirae Kang
feat: implement new models and improve UI, #23
46cc63a
raw
history blame
35.1 kB
"""
Transformer fine-tuning (DistilBERT, Toxic-BERT, etc.) with partial or head-only
freezing, label smoothing, gap-aware early stopping, and val threshold tuning.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainerCallback,
TrainingArguments,
set_seed,
)
from src.evaluation.threshold_tuning import predict_with_threshold, search_best_threshold
from src.utils.logger import get_logger
# Module-level logger from the project's logging helper; shared by every function below.
logger = get_logger(__name__)
def _bert_encoder_layers(model) -> list[nn.Module]:
    """Collect the stacked transformer blocks of a BERT-style or DistilBERT classifier.

    BERT / Toxic-BERT expose them as ``model.bert.encoder.layer``; DistilBERT as
    ``model.distilbert.transformer.layer``. The result is a fresh list.

    Raises:
        AttributeError: when neither layout is present.
    """
    bert = getattr(model, "bert", None)
    if bert is not None and hasattr(bert, "encoder"):
        blocks = bert.encoder.layer
    elif hasattr(model, "distilbert"):
        blocks = model.distilbert.transformer.layer
    else:
        raise AttributeError("Unsupported architecture for layer freeze")
    return list(blocks)
def _distilbert_layers(model) -> list[nn.Module]:
    """List DistilBERT's transformer blocks (6 for distilbert-base) as a new list."""
    transformer = model.distilbert.transformer
    return [blk for blk in transformer.layer]
def unfreeze_full_encoder(model) -> int:
    """Make every parameter trainable: all encoder blocks plus the classification head.

    Used for the "Final Squeeze" / full-BERT fine-tuning mode.

    Returns:
        Number of encoder blocks reported by ``_bert_encoder_layers``.
    """
    for weight in model.parameters():
        weight.requires_grad = True
    n_blocks = len(_bert_encoder_layers(model))
    sizes = [(w.numel(), w.requires_grad) for w in model.parameters()]
    total = sum(size for size, _ in sizes)
    trainable = sum(size for size, is_trainable in sizes if is_trainable)
    logger.info(
        f"Full unfreeze — all {n_blocks} encoder blocks + head — "
        f"trainable {trainable:,}/{total:,} ({100 * trainable / total:.1f}%)"
    )
    return n_blocks
def freeze_all_inference(model) -> int:
    """Disable gradients on every parameter for pretrained inference only (Golden Baseline).

    Returns:
        Number of encoder blocks reported by ``_bert_encoder_layers``.
    """
    for weight in model.parameters():
        weight.requires_grad = False
    n_blocks = len(_bert_encoder_layers(model))
    logger.info(
        f"Inference-only — all {n_blocks} encoder blocks + head frozen (zero fine-tuning)"
    )
    return n_blocks
def freeze_head_only(model) -> None:
    """Freeze the backbone and train only the classification head (Expert / Toxic-BERT).

    A parameter stays trainable iff its name contains ``classifier``,
    ``pre_classifier`` or ``pooler``; every other parameter is frozen.
    """
    head_markers = ("classifier", "pre_classifier", "pooler")
    for name, weight in model.named_parameters():
        weight.requires_grad = any(marker in name for marker in head_markers)
    sizes = [(w.numel(), w.requires_grad) for w in model.parameters()]
    total = sum(size for size, _ in sizes)
    trainable = sum(size for size, is_trainable in sizes if is_trainable)
    logger.info(
        f"Head-only freeze — trainable {trainable:,}/{total:,} "
        f"({100 * trainable / total:.2f}%)"
    )
def freeze_encoder_partial(model, freeze_first_n: int = 4) -> int:
    """Freeze the first N encoder blocks; train the remaining blocks + classification head.

    Everything is frozen first. Then the blocks after the first ``freeze_first_n``
    are unfrozen, along with any parameter whose name contains ``classifier``,
    ``pre_classifier`` or ``pooler``.

    Args:
        model: classifier supported by ``_bert_encoder_layers``.
        freeze_first_n: number of leading encoder blocks to keep frozen. Clamped
            to ``[0, n_blocks]``. Without clamping, a negative value would slice
            from the end (unfreezing only the last |n| blocks), and a too-large
            value would log a negative trainable-block count.

    Returns:
        Total number of encoder blocks.
    """
    layers = _bert_encoder_layers(model)
    n_layers = len(layers)
    n_frozen = min(max(int(freeze_first_n), 0), n_layers)
    for param in model.parameters():
        param.requires_grad = False
    for layer in layers[n_frozen:]:
        for param in layer.parameters():
            param.requires_grad = True
    for name, param in model.named_parameters():
        if any(k in name for k in ("classifier", "pre_classifier", "pooler")):
            param.requires_grad = True
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    logger.info(
        f"Partial freeze: {n_frozen}/{n_layers} blocks frozen — "
        f"training last {n_layers - n_frozen} + head — "
        f"trainable {trainable:,}/{total:,} ({100 * trainable / total:.1f}%)"
    )
    return n_layers
def freeze_distilbert_partial(model, freeze_first_n: int = 4) -> None:
    """Backward-compatible alias of ``freeze_encoder_partial`` that discards its block count."""
    _ = freeze_encoder_partial(model, freeze_first_n=freeze_first_n)
def build_head_only_optimizer(
    model,
    *,
    learning_rate: float = 2e-5,
    weight_decay: float = 0.01,
) -> torch.optim.Optimizer:
    """AdamW over the parameters that still require gradients (the head after ``freeze_head_only``)."""
    trainable = list(filter(lambda w: w.requires_grad, model.parameters()))
    return torch.optim.AdamW(trainable, weight_decay=weight_decay, lr=learning_rate)
def build_full_optimizer(
    model,
    *,
    learning_rate: float = 5e-6,
    weight_decay: float = 0.01,
) -> torch.optim.Optimizer:
    """AdamW over every parameter with ``requires_grad`` set (full fine-tuning, single LR)."""
    trainable = tuple(w for w in model.parameters() if w.requires_grad)
    return torch.optim.AdamW(trainable, lr=learning_rate, weight_decay=weight_decay)
def build_partial_optimizer(
    model,
    *,
    learning_rate: float = 1e-5,
    encoder_learning_rate: float | None = None,
    head_learning_rate: float | None = None,
    weight_decay: float = 0.01,
    freeze_first_n: int = 4,
) -> torch.optim.Optimizer:
    """
    Parameter groups: frozen layers excluded; top encoder blocks + head (optional split LRs).

    Only parameters with ``requires_grad=True`` are included:

    * head group: names containing ``classifier``, ``pre_classifier`` or ``pooler``
      (the same markers the freeze helpers unfreeze), at ``head_learning_rate``;
    * encoder group: parameters of blocks ``freeze_first_n:`` not already in the
      head group, at ``encoder_learning_rate``.

    Both LRs default to ``learning_rate``. ``freeze_first_n`` is clamped to
    ``[0, n_blocks]``. If both groups are empty, every trainable parameter forms a
    single group at ``learning_rate`` / ``weight_decay``.
    """
    enc_lr = encoder_learning_rate if encoder_learning_rate is not None else learning_rate
    head_lr = head_learning_rate if head_learning_rate is not None else learning_rate
    layers = _bert_encoder_layers(model)
    n_frozen = min(max(int(freeze_first_n), 0), len(layers))
    top_layer_ids = {
        id(p) for layer in layers[n_frozen:] for p in layer.parameters()
    }
    # "pooler" must be included: freeze_encoder_partial unfreezes it, so leaving it
    # out would make it trainable (gradients computed) but never updated.
    head_markers = ("classifier", "pre_classifier", "pooler")
    head_params = [
        p
        for n, p in model.named_parameters()
        if p.requires_grad and any(k in n for k in head_markers)
    ]
    head_ids = {id(p) for p in head_params}
    top_params = [
        p
        for p in model.parameters()
        if p.requires_grad and id(p) in top_layer_ids and id(p) not in head_ids
    ]
    groups = []
    if top_params:
        groups.append(
            {"params": top_params, "lr": enc_lr, "weight_decay": weight_decay}
        )
    if head_params:
        groups.append(
            {"params": head_params, "lr": head_lr, "weight_decay": weight_decay}
        )
    if not groups:
        # Explicit lr / weight_decay: a group without "lr" would silently fall back
        # to AdamW's own default (1e-3) instead of the configured learning rate.
        groups = [
            {
                "params": [p for p in model.parameters() if p.requires_grad],
                "lr": learning_rate,
                "weight_decay": weight_decay,
            }
        ]
    return torch.optim.AdamW(groups, lr=learning_rate, weight_decay=weight_decay)
def _average_state_dicts(state_dicts: list[dict]) -> dict:
    """Element-wise mean of compatible state dicts (Stochastic Weight Averaging).

    Floating-point tensors are summed in at least float32 precision, using float64
    when the tensor is float64. The mean is then cast back to the tensor's original
    dtype. Non-floating tensors (integer buffers such as ``position_ids``, boolean
    masks) have no meaningful average. They are copied from the first state dict
    instead of being averaged through float32 and truncated. Inputs are not mutated.

    Args:
        state_dicts: non-empty list of state dicts whose keys include all keys of
            the first one.

    Returns:
        New dict keyed like ``state_dicts[0]``.

    Raises:
        ValueError: if ``state_dicts`` is empty.
        KeyError: if a later state dict lacks a floating-point key of the first.
    """
    if not state_dicts:
        raise ValueError("state_dicts must not be empty")
    first = state_dicts[0]
    n = float(len(state_dicts))
    averaged: dict = {}
    for key, ref in first.items():
        if not torch.is_floating_point(ref):
            averaged[key] = ref.clone()
            continue
        acc_dtype = torch.promote_types(ref.dtype, torch.float32)
        total = ref.to(acc_dtype, copy=True)
        for sd in state_dicts[1:]:
            total += sd[key].to(acc_dtype)
        averaged[key] = (total / n).to(ref.dtype)
    return averaged
def build_distilbert_optimizer(
    model,
    *,
    learning_rate: float = 1e-5,
    weight_decay: float = 0.01,
    freeze_first_n: int = 4,
) -> torch.optim.Optimizer:
    """Backward-compatible alias: partial-freeze optimizer with one shared learning rate."""
    options = dict(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        freeze_first_n=freeze_first_n,
    )
    return build_partial_optimizer(model, **options)
def logits_to_toxic_prob(logits: np.ndarray | torch.Tensor) -> np.ndarray:
"""Map model logits to P(toxic): sigmoid on 'toxic' for 6-label BERT, else softmax."""
t = torch.as_tensor(logits, dtype=torch.float32)
if t.ndim == 1:
t = t.unsqueeze(0)
if t.shape[-1] >= 6:
return torch.sigmoid(t)[:, 0].numpy()
return torch.softmax(t, dim=-1)[:, 1].numpy()
def compute_hf_metrics(eval_pred) -> dict[str, float]:
    """Binary toxic-class metrics for ``Trainer(compute_metrics=...)``.

    ``eval_pred`` unpacks to ``(logits, labels)``. Hard predictions agree with
    ``logits_to_toxic_prob``:

    * 2-class heads use argmax;
    * heads with 6+ logits use ``P(toxic) >= 0.5``. Argmax there would return a
      label index in 0-5 rather than a 0/1 prediction.

    ROC-AUC is reported as NaN when ``labels`` contains a single class.
    ``roc_auc_score`` raises in that case instead of returning a value.
    """
    logits, labels = eval_pred
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    probs = logits_to_toxic_prob(logits)
    if logits.shape[-1] >= 6:
        preds = (probs >= 0.5).astype(int)
    else:
        preds = np.argmax(logits, axis=1)
    if np.unique(labels).size > 1:
        roc_auc = float(roc_auc_score(labels, probs))
    else:
        roc_auc = float("nan")
    return {
        "f1_toxic": float(f1_score(labels, preds, pos_label=1, zero_division=0)),
        "f1_weighted": float(f1_score(labels, preds, average="weighted", zero_division=0)),
        "precision": float(precision_score(labels, preds, pos_label=1, zero_division=0)),
        "recall": float(recall_score(labels, preds, pos_label=1, zero_division=0)),
        "roc_auc": roc_auc,
    }
def _symmetric_kl(logits_a: torch.Tensor, logits_b: torch.Tensor) -> torch.Tensor:
    """Mean of KL(p||q) and KL(q||p) between the softmax distributions of two logit tensors.

    Serves as the R-Drop consistency term between two stochastic forward passes.
    Softmax is over the last dimension; each direction uses ``batchmean`` reduction.
    """
    log_a = nn.functional.log_softmax(logits_a, dim=-1)
    log_b = nn.functional.log_softmax(logits_b, dim=-1)
    directions = ((log_a, log_b.exp()), (log_b, log_a.exp()))
    total = sum(
        nn.functional.kl_div(inp, target, reduction="batchmean", log_target=False)
        for inp, target in directions
    )
    return total / 2.0
class LabelSmoothingTrainer(Trainer):
    """Cross-entropy with label smoothing for the classification head.

    Also owns the freeze-aware ``create_optimizer``, so this class and every
    subclass (including ``RDropTrainer``) build parameter groups from the private
    attributes ``_freeze_mode``, ``_freeze_first_n``, ``_encoder_lr`` and
    ``_head_lr``. Callers may set these on the instance after construction. When
    only the subclass defined this method, the plain label-smoothing path fell
    back to the default ``Trainer`` optimizer and ignored the configured
    encoder/head learning rates.
    """

    def __init__(self, *args, label_smoothing: float = 0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_smoothing = label_smoothing

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Label-smoothed cross-entropy on ``outputs.logits``; pops ``labels`` from ``inputs``."""
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss = nn.functional.cross_entropy(
            outputs.logits,
            labels,
            label_smoothing=self.label_smoothing,
        )
        return (loss, outputs) if return_outputs else loss

    def create_optimizer(self):
        """Build the optimizer once, matching the freeze mode set on this trainer.

        Mode dispatch:

        * ``head_only`` -> ``build_head_only_optimizer``;
        * ``full_unfreeze`` -> ``build_full_optimizer``;
        * anything else (default ``partial``) -> ``build_partial_optimizer``, with
          the optional encoder/head learning-rate split.
        """
        if self.optimizer is None:
            cfg = self.args
            freeze_mode = getattr(self, "_freeze_mode", "partial")
            if freeze_mode == "head_only":
                self.optimizer = build_head_only_optimizer(
                    self.model,
                    learning_rate=cfg.learning_rate,
                    weight_decay=cfg.weight_decay,
                )
            elif freeze_mode == "full_unfreeze":
                self.optimizer = build_full_optimizer(
                    self.model,
                    learning_rate=cfg.learning_rate,
                    weight_decay=cfg.weight_decay,
                )
            else:
                freeze_n = getattr(self, "_freeze_first_n", 4)
                enc_lr = getattr(self, "_encoder_lr", None)
                head_lr = getattr(self, "_head_lr", None)
                self.optimizer = build_partial_optimizer(
                    self.model,
                    learning_rate=cfg.learning_rate,
                    encoder_learning_rate=enc_lr,
                    head_learning_rate=head_lr,
                    weight_decay=cfg.weight_decay,
                    freeze_first_n=freeze_n,
                )
        return self.optimizer


class RDropTrainer(LabelSmoothingTrainer):
    """R-Drop: dual forward passes + symmetric KL to limit overfitting (Performance Squeeze).

    Loss = mean of the two label-smoothed CE terms + ``rdrop_alpha`` * symmetric KL
    between the two passes' logits. The optimizer comes from
    ``LabelSmoothingTrainer.create_optimizer``.
    """

    def __init__(self, *args, rdrop_alpha: float = 0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.rdrop_alpha = rdrop_alpha

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Two forward passes on the same batch. The passes can differ only through
        stochastic layers such as dropout, which are active in train mode.
        ``outputs1`` is returned as the outputs."""
        labels = inputs.pop("labels")
        outputs1 = model(**inputs)
        outputs2 = model(**inputs)
        ce = (
            nn.functional.cross_entropy(
                outputs1.logits, labels, label_smoothing=self.label_smoothing
            )
            + nn.functional.cross_entropy(
                outputs2.logits, labels, label_smoothing=self.label_smoothing
            )
        ) / 2.0
        kl = _symmetric_kl(outputs1.logits, outputs2.logits)
        loss = ce + self.rdrop_alpha * kl
        return (loss, outputs1) if return_outputs else loss
class SWACallback(TrainerCallback):
    """
    Average model weights over the last N completed epochs (Stochastic Weight Averaging).
    Applied in ``on_train_end`` after ``load_best_model_at_end`` so inference uses SWA weights.

    Snapshots are CPU copies of ``model.state_dict()`` taken at each epoch end.
    Only the most recent ``last_n_epochs`` are retained, because older ones can
    never enter the average. Memory therefore stays bounded by N model copies
    however many epochs run.

    Raises:
        ValueError: if ``last_n_epochs < 1``. With 0, ``snapshots[-0:]`` silently
            averaged every snapshot.
    """

    def __init__(self, last_n_epochs: int = 5):
        if last_n_epochs < 1:
            raise ValueError(f"last_n_epochs must be >= 1, got {last_n_epochs}")
        self.last_n_epochs = last_n_epochs
        self._snapshots: list[dict[str, torch.Tensor]] = []
        # Total epochs observed, kept separately because _snapshots is trimmed.
        self._epochs_seen = 0

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        if model is None:
            return control
        self._snapshots.append(
            {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        )
        self._epochs_seen += 1
        if len(self._snapshots) > self.last_n_epochs:
            del self._snapshots[: -self.last_n_epochs]
        return control

    def on_train_end(self, args, state, control, model=None, **kwargs):
        if model is None or not self._snapshots:
            return control
        n_use = min(self.last_n_epochs, len(self._snapshots))
        averaged = _average_state_dicts(self._snapshots[-n_use:])
        model.load_state_dict(averaged, strict=True)
        logger.info(
            f"SWA applied — averaged last {n_use}/{self._epochs_seen} epoch checkpoints"
        )
        return control
class GapEarlyStoppingCallback(TrainerCallback):
    """
    Stop when validation F1 plateaus OR |train_f1 - val_f1| exceeds max_gap.
    Train F1 is computed on ``train_eval_dataset`` each evaluation step.
    Attach ``trainer`` after constructing the Trainer (see train_transformer_stable).

    Args:
        train_eval_dataset: dataset passed to ``trainer.evaluate`` to obtain the
            train-side metric at every evaluation.
        patience: number of consecutive evaluations without a strict improvement
            of the validation metric before ``should_training_stop`` is set.
        max_gap: absolute train/val metric gap above which training is stopped
            (only when ``gap_stop_enabled``).
        metric: metric name. Validation metrics are read as ``eval_<metric>``
            (falling back to ``<metric>``, then 0.0). Train metrics are read as
            ``train_<metric>`` (falling back to 0.0).
        gap_check_min_epoch: the gap check is skipped while ``int(state.epoch)``
            is below this value.
        gap_stop_enabled: gap-based stopping is off by default. The gap is still
            computed and logged.
    """
    def __init__(
        self,
        train_eval_dataset,
        *,
        patience: int = 3,
        max_gap: float = 0.045,
        metric: str = "f1_toxic",
        gap_check_min_epoch: int = 2,
        gap_stop_enabled: bool = False,
    ):
        self.train_eval_dataset = train_eval_dataset
        self.patience = patience
        self.max_gap = max_gap
        self.metric = metric
        self.gap_check_min_epoch = gap_check_min_epoch
        self.gap_stop_enabled = gap_stop_enabled
        # -1.0 is below any F1, so the first evaluation always counts as improvement.
        self.best_metric = -1.0
        # Counts evaluations (not literal epochs) without improvement.
        self.bad_epochs = 0
        # Set by the caller once the Trainer exists; until then on_evaluate is a no-op.
        self.trainer: Trainer | None = None
        # Re-entrancy guard: self.trainer.evaluate() below fires on_evaluate again.
        self._checking = False

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log the train/val gap, then decide whether to request a training stop."""
        if metrics is None or self.trainer is None or self._checking:
            return control
        val_f1 = metrics.get(f"eval_{self.metric}", metrics.get(self.metric, 0.0))
        # Nested evaluation on the train set. Its own on_evaluate callback returns
        # immediately because _checking is True; `finally` resets the guard even if
        # evaluation raises.
        self._checking = True
        try:
            train_metrics = self.trainer.evaluate(
                eval_dataset=self.train_eval_dataset,
                metric_key_prefix="train",
            )
        finally:
            self._checking = False
        train_f1 = train_metrics.get(f"train_{self.metric}", 0.0)
        gap = abs(train_f1 - val_f1)
        logger.info(
            f"Gap monitor — train_f1={train_f1:.4f} val_f1={val_f1:.4f} gap={gap:.4f}"
        )
        epoch = int(state.epoch or 0)
        if (
            self.gap_stop_enabled
            and epoch >= self.gap_check_min_epoch
            and gap > self.max_gap
        ):
            # NOTE(review): this only requests a stop; reverting to the best
            # checkpoint relies on TrainingArguments(load_best_model_at_end=True).
            logger.warning(
                f"Gap defense — train-val gap {gap:.4f} > {self.max_gap}; "
                "stopping and reverting to best checkpoint"
            )
            control.should_training_stop = True
            # Returns before the patience bookkeeping: best_metric / bad_epochs are
            # not updated for this evaluation.
            return control
        if val_f1 > self.best_metric:
            self.best_metric = val_f1
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs >= self.patience:
            logger.info(
                f"Early stop: no {self.metric} improvement for {self.patience} epochs"
            )
            control.should_training_stop = True
        return control
def _predict_probs_from_dataset(
    trainer: Trainer,
    tokenized_dataset,
) -> np.ndarray:
    """Run ``trainer.predict`` without the ``label`` column and return P(toxic) per row."""
    inputs = (
        tokenized_dataset.remove_columns(["label"])
        if "label" in tokenized_dataset.column_names
        else tokenized_dataset
    )
    predictions = trainer.predict(inputs).predictions
    return logits_to_toxic_prob(predictions)
def predict_with_tta(
    trainer: Trainer,
    tokenizer,
    texts: list[str],
    labels_placeholder: list[int] | None,
    *,
    max_length: int,
    aug_cfg: dict,
) -> np.ndarray:
    """
    Test-time augmentation: average P(toxic) from original and back-translated texts.

    The original texts are always scored. When ``aug_cfg["enabled"]`` is falsy,
    those probabilities are returned as-is. Otherwise the texts are
    back-translated via the pivot language, scored again, and the two probability
    vectors are averaged element-wise. ``labels_placeholder`` only fills the
    ``label`` column; it defaults to zeros.
    """
    from datasets import Dataset
    from src.features.augmentation import back_translate_texts

    placeholder = [0] * len(texts) if labels_placeholder is None else labels_placeholder
    keep_cols = ("input_ids", "attention_mask", "label")

    def _encode(batch):
        return tokenizer(batch["text"], truncation=True, max_length=max_length)

    def _to_model_inputs(raw_texts: list[str]):
        raw = Dataset.from_dict({"text": raw_texts, "label": placeholder[: len(raw_texts)]})
        encoded = raw.map(_encode, batched=True)
        extra_cols = [col for col in encoded.column_names if col not in keep_cols]
        if extra_cols:
            encoded = encoded.remove_columns(extra_cols)
        encoded.set_format("torch")
        return encoded

    def _score(raw_texts: list[str]) -> np.ndarray:
        return _predict_probs_from_dataset(trainer, _to_model_inputs(raw_texts))

    base_probs = _score(texts)
    if not aug_cfg.get("enabled", False):
        return base_probs
    logger.info(f"TTA — back-translating {len(texts)} test comments")
    translated = back_translate_texts(
        texts,
        source_lang=aug_cfg.get("source_lang", "en"),
        pivot_lang=aug_cfg.get("pivot_lang", "de"),
        max_words=int(aug_cfg.get("max_words", 60)),
        rate_limit_every=int(aug_cfg.get("rate_limit_every", 50)),
        rate_limit_sleep_sec=float(aug_cfg.get("rate_limit_sleep_sec", 1.0)),
    )
    blended = (base_probs + _score(translated)) / 2.0
    logger.info("TTA — averaged original and back-translated probabilities")
    return blended
def _transformer_cfg(cfg: dict) -> dict:
    """Pick the ``transformer`` config section, or the legacy ``distilbert`` one if absent.

    Raises KeyError when neither key exists.
    """
    section = "transformer" if "transformer" in cfg else "distilbert"
    return cfg[section]
def _apply_model_freeze(model, bert_cfg: dict) -> tuple[str, int]:
    """Apply ``bert_cfg["freeze_mode"]`` to ``model``; return ``(mode_label, freeze_first_n)``.

    Modes:

    * ``head_only`` -> ``("head_only", 0)``;
    * ``full`` / ``full_unfreeze`` / ``all_layers`` -> ``("full_unfreeze", 0)``;
    * ``inference_only`` / ``frozen`` / ``golden_baseline`` ->
      ``("inference_only", n_blocks)``;
    * ``last_n_layers`` / ``train_last_n`` keeps the last ``train_last_n_layers``
      blocks trainable;
    * any other value freezes the first ``freeze_first_n_layers`` blocks.

    In the two partial modes the frozen count is clamped to ``[0, n_blocks]``. The
    label ``partial_last_<k>`` therefore never reports a negative ``k``.
    """
    freeze_mode = bert_cfg.get("freeze_mode", "partial")
    if freeze_mode == "head_only":
        freeze_head_only(model)
        return "head_only", 0
    if freeze_mode in ("full", "full_unfreeze", "all_layers"):
        unfreeze_full_encoder(model)
        return "full_unfreeze", 0
    if freeze_mode in ("inference_only", "frozen", "golden_baseline"):
        n_layers = freeze_all_inference(model)
        return "inference_only", n_layers
    n_layers = len(_bert_encoder_layers(model))
    if freeze_mode in ("last_n_layers", "train_last_n"):
        train_last_n = int(bert_cfg.get("train_last_n_layers", 4))
        freeze_first_n = n_layers - train_last_n
    else:
        freeze_first_n = int(bert_cfg.get("freeze_first_n_layers", 4))
    freeze_first_n = min(max(freeze_first_n, 0), n_layers)
    freeze_encoder_partial(model, freeze_first_n=freeze_first_n)
    return f"partial_last_{n_layers - freeze_first_n}", freeze_first_n
def _bert_metrics_from_probs(
    *,
    model_label: str,
    model_id: str,
    freeze_mode: str,
    y_train: np.ndarray,
    y_test: np.ndarray,
    train_probs: np.ndarray,
    test_probs: np.ndarray,
    threshold: float,
    val_f1_at_threshold: float | None = None,
    extra: dict | None = None,
) -> dict[str, Any]:
    """Build the standard metrics dict from train/test P(toxic) at a fixed threshold.

    Hard predictions come from ``predict_with_threshold``. The dict reports:

    * weighted and toxic-class F1 on both test and train;
    * their absolute train-test gaps, also in percentage points;
    * ``gap_ok`` as ``train_test_gap < 0.05``;
    * test ROC-AUC on the raw probabilities;
    * test false-positive and false-negative counts.

    Entries in ``extra`` are merged last and override computed keys.
    """
    train_preds = predict_with_threshold(train_probs, threshold)
    test_preds = predict_with_threshold(test_probs, threshold)
    f1_train = float(f1_score(y_train, train_preds, average="weighted", zero_division=0))
    f1_test = float(f1_score(y_test, test_preds, average="weighted", zero_division=0))
    f1_toxic_test = float(f1_score(y_test, test_preds, pos_label=1, zero_division=0))
    f1_toxic_train = float(f1_score(y_train, train_preds, pos_label=1, zero_division=0))
    gap_weighted = abs(f1_train - f1_test)
    gap_toxic = abs(f1_toxic_train - f1_toxic_test)
    metrics = {
        "model": model_label,
        "model_id": model_id,
        "freeze_mode": freeze_mode,
        "threshold": round(threshold, 4),
        # `is not None` so a legitimate 0.0 validation score is kept, not reported as None.
        "val_f1_at_threshold": (
            round(val_f1_at_threshold, 4) if val_f1_at_threshold is not None else None
        ),
        "f1_weighted": round(f1_test, 4),
        "f1_toxic": round(f1_toxic_test, 4),
        "f1_toxic_train": round(f1_toxic_train, 4),
        "train_test_gap_toxic": round(gap_toxic, 4),
        "train_test_gap_toxic_pp": round(gap_toxic * 100, 2),
        "f1_train": round(f1_train, 4),
        "train_test_gap": round(gap_weighted, 4),
        "train_test_gap_pp": round(gap_weighted * 100, 2),
        "gap_ok": gap_weighted < 0.05,
        "roc_auc": round(float(roc_auc_score(y_test, test_probs)), 4),
        "fp": int(((y_test == 0) & (test_preds == 1)).sum()),
        "fn": int(((y_test == 1) & (test_preds == 0)).sum()),
    }
    if extra:
        metrics.update(extra)
    return metrics
def evaluate_pretrained_bert_baseline(
    hf_train,
    hf_val,
    hf_test,
    y_train: np.ndarray,
    y_test: np.ndarray,
    y_val: np.ndarray,
    cfg: dict,
    *,
    seed: int = 42,
    model_label: str = "Golden-Baseline-Toxic-BERT",
) -> dict[str, Any]:
    """
    Step 1 — pretrained Toxic-BERT with all layers frozen (no fine-tuning on project data).

    Uses ``cfg["baseline"]`` when present, otherwise the transformer section. The
    steps are:

    1. Load the pretrained classifier unchanged and freeze every parameter.
    2. Score the val, train and test splits.
    3. Optionally tune the decision threshold on validation probabilities.
       Tuning is on unless ``threshold_tuning.enabled`` is false.
    4. Report metrics at that threshold. Here ``gap_ok`` uses the stricter
       ``train_test_gap < 0.01``.

    Returns:
        Dict with ``metrics``, ``trainer``, ``tokenizer``, ``test_probs``,
        ``val_probs``, ``train_probs`` and ``threshold``.
    """
    # Resolve lazily. `cfg.get("baseline", _transformer_cfg(cfg))` evaluates the
    # fallback eagerly and raises KeyError when only a "baseline" section exists.
    bert_cfg = cfg["baseline"] if "baseline" in cfg else _transformer_cfg(cfg)
    model_id = bert_cfg["model_id"]
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    max_len = int(bert_cfg.get("max_length", 128))

    def _tokenize(ds):
        return tokenizer(ds["text"], truncation=True, max_length=max_len)

    def _prep(ds):
        tok = ds.map(_tokenize, batched=True)
        drop_cols = [
            c for c in tok.column_names if c not in ("input_ids", "attention_mask", "label")
        ]
        if drop_cols:
            tok = tok.remove_columns(drop_cols)
        tok.set_format("torch")
        return tok

    tok_train = _prep(hf_train)
    tok_val = _prep(hf_val)
    tok_test = _prep(hf_test)
    # The checkpoint's own head (e.g. 6 labels for Toxic-BERT) is kept as-is.
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    model.to(device)
    _apply_model_freeze(model, {**bert_cfg, "freeze_mode": "inference_only"})
    model.eval()
    args = TrainingArguments(
        output_dir="/tmp/golden_baseline_eval",
        per_device_eval_batch_size=int(bert_cfg.get("batch_size", 8)),
        report_to="none",
        seed=seed,
    )
    trainer = Trainer(
        model=model,
        args=args,
        data_collator=DataCollatorWithPadding(tokenizer),
        compute_metrics=compute_hf_metrics,
    )
    logger.info(f"Golden Baseline — {model_id} (inference only, no training)")
    # _predict_probs_from_dataset drops the "label" column before predicting.
    val_probs = _predict_probs_from_dataset(trainer, tok_val)
    train_probs = _predict_probs_from_dataset(trainer, tok_train)
    test_probs = _predict_probs_from_dataset(trainer, tok_test)
    y_val_arr = np.asarray(y_val).astype(int)
    y_train_arr = np.asarray(y_train).astype(int)
    y_test_arr = np.asarray(y_test).astype(int)
    th_cfg = bert_cfg.get("threshold_tuning", {})
    threshold = 0.5
    val_f1_at_threshold = None
    if th_cfg.get("enabled", True):
        threshold, val_f1_at_threshold = search_best_threshold(
            y_val_arr,
            val_probs,
            metric=th_cfg.get("metric", "f1_weighted"),
            min_threshold=float(th_cfg.get("min_threshold", 0.05)),
            max_threshold=float(th_cfg.get("max_threshold", 0.95)),
            step=float(th_cfg.get("step", 0.01)),
        )
    metrics = _bert_metrics_from_probs(
        model_label=model_label,
        model_id=model_id,
        freeze_mode="inference_only",
        y_train=y_train_arr,
        y_test=y_test_arr,
        train_probs=train_probs,
        test_probs=test_probs,
        threshold=threshold,
        val_f1_at_threshold=val_f1_at_threshold,
        extra={
            "trained": False,
            "num_labels": int(model.config.num_labels),
            "prob_mode": "sigmoid_toxic" if model.config.num_labels >= 6 else "softmax",
            "esencial_compliant_gap": True,
        },
    )
    # Stricter gap criterion for the untrained baseline than the default 0.05.
    metrics["gap_ok"] = metrics["train_test_gap"] < 0.01
    return {
        "metrics": metrics,
        "trainer": trainer,
        "tokenizer": tokenizer,
        "test_probs": test_probs,
        "val_probs": val_probs,
        "train_probs": train_probs,
        "threshold": threshold,
    }
def train_transformer_stable(
    hf_train,
    hf_val,
    hf_test,
    y_test: np.ndarray,
    y_val: np.ndarray,
    cfg: dict,
    output_dir: Path,
    *,
    seed: int = 42,
    model_label: str = "Transformer-stable",
) -> dict[str, Any]:
    """
    Fine-tune a sequence classifier with stability-focused regularization.
    Returns metrics, trainer, tokenizer, test probabilities, and optimal threshold.

    Pipeline:

    1. Tokenize the train/val/test splits.
    2. Load a 2-label classifier from ``model_id`` and apply the configured
       freeze mode.
    3. Train with label smoothing, optionally adding R-Drop and SWA, under
       gap-aware early stopping.
    4. Optionally tune the decision threshold on validation probabilities.
    5. Score test (optionally with back-translation TTA) and train at that
       threshold.
    6. Save the model and tokenizer to ``output_dir``.

    Args:
        hf_train, hf_val, hf_test: datasets with ``text`` and ``label`` columns.
            They are ``.map``-ed and indexed by column name, presumably
            ``datasets.Dataset``.
        y_test, y_val: ground-truth labels (cast to ``int`` arrays).
        cfg: config whose ``transformer`` (or legacy ``distilbert``) section is used.
        output_dir: checkpoint directory; the final model and tokenizer are saved here.
        seed: passed to ``set_seed`` and ``TrainingArguments``.
        model_label: stored under ``metrics["model"]``.

    Returns:
        Dict with ``metrics``, ``trainer``, ``tokenizer``, ``test_probs``,
        ``test_preds``, ``threshold`` and ``val_probs``.
    """
    bert_cfg = _transformer_cfg(cfg)
    model_id = bert_cfg["model_id"]
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    max_len = int(bert_cfg.get("max_length", 128))

    def _tokenize(ds):
        return tokenizer(ds["text"], truncation=True, max_length=max_len)

    def _prep(ds):
        tok = ds.map(_tokenize, batched=True)
        drop_cols = [
            c for c in tok.column_names if c not in ("input_ids", "attention_mask", "label")
        ]
        if drop_cols:
            tok = tok.remove_columns(drop_cols)
        tok.set_format("torch")
        return tok

    tok_train = _prep(hf_train)
    tok_val = _prep(hf_val)
    tok_test = _prep(hf_test)
    # Replace any pretrained head (e.g. Toxic-BERT's 6 labels) with a 2-class one.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        num_labels=2,
        ignore_mismatched_sizes=True,
    )
    model.config.problem_type = "single_label_classification"
    model.config.num_labels = 2
    dropout_p = float(bert_cfg.get("head_dropout", 0.5))
    # NOTE(review): editing config dropout after construction does not rebuild the
    # encoder's existing Dropout modules; only the `model.dropout` swap below changes
    # the live model. The config values are persisted by save_model.
    if hasattr(model.config, "hidden_dropout_prob"):
        model.config.hidden_dropout_prob = dropout_p
    if hasattr(model.config, "seq_classif_dropout"):
        model.config.seq_classif_dropout = dropout_p
    if hasattr(model, "dropout"):
        model.dropout = nn.Dropout(dropout_p)
    model.to(device)
    freeze_mode, freeze_first_n = _apply_model_freeze(model, bert_cfg)
    es_cfg = bert_cfg.get("early_stopping", {})
    es_metric = es_cfg.get("metric", bert_cfg.get("metric_for_best", "f1_toxic"))
    gap_cb = GapEarlyStoppingCallback(
        tok_train,
        patience=int(es_cfg.get("patience", 3)),
        max_gap=float(es_cfg.get("max_train_val_gap", 0.045)),
        metric=es_metric,
        gap_check_min_epoch=int(es_cfg.get("gap_check_min_epoch", 2)),
        gap_stop_enabled=bool(es_cfg.get("gap_stop_enabled", False)),
    )
    callbacks: list[TrainerCallback] = [gap_cb]
    swa_cfg = bert_cfg.get("swa", {})
    swa_cb: SWACallback | None = None
    if swa_cfg.get("enabled", False):
        swa_cb = SWACallback(last_n_epochs=int(swa_cfg.get("last_n_epochs", 5)))
        callbacks.append(swa_cb)
    args = TrainingArguments(
        output_dir=str(output_dir),
        learning_rate=float(bert_cfg.get("learning_rate", 1e-5)),
        num_train_epochs=int(bert_cfg.get("max_epochs", 8)),
        per_device_train_batch_size=int(bert_cfg.get("batch_size", 8)),
        per_device_eval_batch_size=int(bert_cfg.get("batch_size", 8)),
        weight_decay=float(bert_cfg.get("weight_decay", 0.01)),
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model=bert_cfg.get("metric_for_best", "f1_toxic"),
        greater_is_better=True,
        warmup_ratio=float(bert_cfg.get("warmup_ratio", 0.1)),
        logging_steps=20,
        save_total_limit=2,
        fp16=torch.cuda.is_available(),
        report_to="none",
        seed=seed,
    )
    rdrop_cfg = bert_cfg.get("rdrop", {})
    label_smooth = float(bert_cfg.get("label_smoothing", 0.1))
    if rdrop_cfg.get("enabled", False):
        trainer = RDropTrainer(
            model=model,
            args=args,
            train_dataset=tok_train,
            eval_dataset=tok_val,
            data_collator=DataCollatorWithPadding(tokenizer),
            compute_metrics=compute_hf_metrics,
            callbacks=callbacks,
            label_smoothing=label_smooth,
            rdrop_alpha=float(rdrop_cfg.get("alpha", 0.5)),
        )
        rdrop_note = f", R-Drop α={rdrop_cfg.get('alpha', 0.5)}"
    else:
        trainer = LabelSmoothingTrainer(
            model=model,
            args=args,
            train_dataset=tok_train,
            eval_dataset=tok_val,
            data_collator=DataCollatorWithPadding(tokenizer),
            compute_metrics=compute_hf_metrics,
            callbacks=callbacks,
            label_smoothing=label_smooth,
        )
        rdrop_note = ""
    # The gap callback needs the trainer to run its nested train-set evaluation.
    gap_cb.trainer = trainer
    # Read by the trainer's create_optimizer to build freeze-aware parameter groups.
    trainer._freeze_mode = freeze_mode  # noqa: SLF001
    trainer._freeze_first_n = freeze_first_n  # noqa: SLF001
    if bert_cfg.get("encoder_learning_rate") is not None:
        trainer._encoder_lr = float(bert_cfg["encoder_learning_rate"])  # noqa: SLF001
    if bert_cfg.get("head_learning_rate") is not None:
        trainer._head_lr = float(bert_cfg["head_learning_rate"])  # noqa: SLF001
    enc_lr = bert_cfg.get("encoder_learning_rate", bert_cfg.get("learning_rate"))
    head_lr = bert_cfg.get("head_learning_rate", bert_cfg.get("learning_rate"))
    logger.info(
        f"Training {model_id} ({freeze_mode} freeze, enc_lr={enc_lr}, head_lr={head_lr}"
        f"{rdrop_note}"
        f"{', SWA last ' + str(swa_cfg.get('last_n_epochs', 5)) + ' epochs' if swa_cb else ''})..."
    )
    trainer.train()
    val_out = trainer.predict(tok_val)
    val_probs = logits_to_toxic_prob(val_out.predictions)
    y_val_arr = np.asarray(y_val).astype(int)
    th_cfg = bert_cfg.get("threshold_tuning", {})
    threshold = 0.5
    val_f1_at_threshold = None
    if th_cfg.get("enabled", False):
        metric = th_cfg.get("metric", "f1_toxic")
        threshold, val_f1_at_threshold = search_best_threshold(
            y_val_arr,
            val_probs,
            metric=metric,
            min_threshold=float(th_cfg.get("min_threshold", 0.05)),
            max_threshold=float(th_cfg.get("max_threshold", 0.95)),
            step=float(th_cfg.get("step", 0.01)),
        )
        th_step = float(th_cfg.get("step", 0.01))
        logger.info(
            f"Val threshold tuning — best_t={threshold:.3f} "
            f"val_{metric}={val_f1_at_threshold:.4f} (step={th_step})"
        )
    tta_cfg = bert_cfg.get("test_time_augmentation", {})
    test_texts = list(hf_test["text"])
    if tta_cfg.get("enabled", False):
        probs = predict_with_tta(
            trainer,
            tokenizer,
            test_texts,
            list(hf_test["label"]),
            max_length=max_len,
            aug_cfg=tta_cfg,
        )
        preds_default = (probs >= 0.5).astype(int)
    else:
        output = trainer.predict(tok_test)
        probs = logits_to_toxic_prob(output.predictions)
        preds_default = np.argmax(output.predictions, axis=1)
    preds = predict_with_threshold(probs, threshold)
    train_out = trainer.predict(tok_train)
    train_probs = logits_to_toxic_prob(train_out.predictions)
    train_preds = predict_with_threshold(train_probs, threshold)
    train_labels = np.asarray(hf_train["label"]).astype(int)
    y_test_arr = np.asarray(y_test).astype(int)
    f1_train = float(f1_score(train_labels, train_preds, average="weighted", zero_division=0))
    f1_test = float(f1_score(y_test_arr, preds, average="weighted", zero_division=0))
    f1_toxic_test = float(f1_score(y_test_arr, preds, pos_label=1, zero_division=0))
    f1_toxic_train = float(
        f1_score(train_labels, train_preds, pos_label=1, zero_division=0)
    )
    gap_weighted = abs(f1_train - f1_test)
    gap_toxic = abs(f1_toxic_train - f1_toxic_test)
    metrics = {
        "model": model_label,
        "model_id": model_id,
        "freeze_mode": freeze_mode,
        "rdrop_enabled": bool(rdrop_cfg.get("enabled", False)),
        "tta_enabled": bool(tta_cfg.get("enabled", False)),
        "swa_enabled": bool(swa_cfg.get("enabled", False)),
        "swa_epochs_averaged": (
            min(int(swa_cfg.get("last_n_epochs", 5)), len(swa_cb._snapshots))
            if swa_cb and swa_cb._snapshots
            else 0
        ),
        "threshold": round(threshold, 4),
        "threshold_step": float(th_cfg.get("step", 0.01)) if th_cfg.get("enabled") else None,
        # `is not None` so a legitimate 0.0 validation score is kept, not reported as None.
        "val_f1_at_threshold": (
            round(val_f1_at_threshold, 4) if val_f1_at_threshold is not None else None
        ),
        "f1_weighted": round(f1_test, 4),
        "f1_toxic": round(f1_toxic_test, 4),
        "f1_toxic_train": round(f1_toxic_train, 4),
        "train_test_gap_toxic": round(gap_toxic, 4),
        "train_test_gap_toxic_pp": round(gap_toxic * 100, 2),
        "gap_toxic_ok": gap_toxic < 0.05,
        "f1_train": round(f1_train, 4),
        "train_test_gap": round(gap_weighted, 4),
        "train_test_gap_pp": round(gap_weighted * 100, 2),
        "f1_weighted_default_thresh": round(
            float(f1_score(y_test_arr, preds_default, average="weighted", zero_division=0)), 4
        ),
        "roc_auc": round(float(roc_auc_score(y_test_arr, probs)), 4),
        "fp": int(((y_test_arr == 0) & (preds == 1)).sum()),
        "fn": int(((y_test_arr == 1) & (preds == 0)).sum()),
    }
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))
    return {
        "metrics": metrics,
        "trainer": trainer,
        "tokenizer": tokenizer,
        "test_probs": probs,
        "test_preds": preds,
        "threshold": threshold,
        "val_probs": val_probs,
    }
def infer_transformer_probs(
    model_dir: Path,
    texts,
    *,
    max_length: int = 128,
    batch_size: int = 16,
) -> np.ndarray:
    """Load a saved classifier and return P(toxic) for each text.

    Logits are mapped with ``logits_to_toxic_prob``, the same rule used during
    training and evaluation. Under that rule a 6-label checkpoint is scored by
    sigmoid on its first ('toxic') logit. Previously this function took index 1
    of a softmax over all six labels.

    Args:
        model_dir: directory holding the saved model and tokenizer.
        texts: iterable of strings; materialized into a list.
        max_length: tokenizer truncation length.
        batch_size: number of texts per forward pass.

    Returns:
        float64 array with one probability per input text (empty for no texts).
    """
    model_dir = Path(model_dir)
    text_list = list(texts)
    tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
    model = AutoModelForSequenceClassification.from_pretrained(str(model_dir))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    probs: list[float] = []
    with torch.no_grad():
        for start in range(0, len(text_list), batch_size):
            batch = text_list[start : start + batch_size]
            enc = tokenizer(
                batch,
                truncation=True,
                max_length=max_length,
                padding=True,
                return_tensors="pt",
            )
            enc = {k: v.to(device) for k, v in enc.items()}
            logits = model(**enc).logits
            # Move to CPU before the numpy conversion inside logits_to_toxic_prob.
            probs.extend(logits_to_toxic_prob(logits.cpu()).tolist())
    return np.array(probs)
def train_distilbert_stable(
    hf_train,
    hf_val,
    hf_test,
    y_test: np.ndarray,
    cfg: dict,
    output_dir: Path,
    *,
    seed: int = 42,
) -> dict[str, Any]:
    """Backward-compatible wrapper for the stable DistilBERT pipeline.

    Validation labels are taken from ``hf_val["label"]``. Everything else is
    forwarded to ``train_transformer_stable`` with the label ``DistilBERT-stable``.
    """
    val_labels = np.asarray(hf_val["label"]).astype(int)
    return train_transformer_stable(
        hf_train,
        hf_val,
        hf_test,
        y_test=y_test,
        y_val=val_labels,
        cfg=cfg,
        output_dir=output_dir,
        seed=seed,
        model_label="DistilBERT-stable",
    )