# cipher-detective-ai / scripts/train_transformer.py
# Paul Clark
# feat(training): add early stopping, cosine LR schedule, and push-to-hub
# 711098d
from __future__ import annotations
import argparse
import json
from pathlib import Path
import numpy as np
import torch
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
EarlyStoppingCallback,
Trainer,
TrainingArguments,
)
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Blank (whitespace-only) lines are skipped. Records are separated on
    newlines only. ``str.splitlines()`` would also break on characters
    such as U+2028 and U+0085, which JSON permits unescaped inside strings,
    and would therefore corrupt records whose text contains them.

    Args:
        path: str or path-like pointing at a UTF-8 encoded ``.jsonl`` file.
            A leading UTF-8 byte-order mark is tolerated.

    Returns:
        list of the decoded JSON values, one per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Text-mode iteration splits only on \n / \r / \r\n (none of which may
    # appear raw inside a JSON string); "utf-8-sig" strips an optional BOM.
    with open(path, encoding="utf-8-sig") as fh:
        return [json.loads(line) for line in fh if line.strip()]
def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged precision/recall/F1 for ``Trainer``.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair supplied by ``Trainer``.
            ``predictions`` holds per-class logits. When the model returns
            extra outputs, it is typically a tuple whose first element is
            the logits.

    Returns:
        dict with ``accuracy``, ``macro_precision``, ``macro_recall`` and
        ``macro_f1``. ``Trainer`` reports them with an ``eval_`` prefix.
    """
    logits, labels = eval_pred
    # Models that return additional outputs (e.g. hidden states) yield a
    # tuple of predictions; the class logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    # zero_division=0: classes never predicted contribute 0 instead of warning.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )
    return {
        "accuracy": accuracy_score(labels, preds),
        "macro_precision": precision,
        "macro_recall": recall,
        "macro_f1": f1,
    }
def make_weighted_trainer(class_weights_tensor):
    """Return a Trainer subclass that uses class-weighted cross-entropy loss.

    Args:
        class_weights_tensor: tensor with one weight per class id, passed as
            ``weight=`` to ``cross_entropy``. It is moved to the logits'
            device on every call.

    Returns:
        A ``Trainer`` subclass (the class, not an instance).

    Overriding ``compute_loss`` bypasses ``Trainer``'s built-in label
    smoother. ``TrainingArguments.label_smoothing_factor`` is therefore
    forwarded to ``cross_entropy`` explicitly; otherwise it would be
    silently ignored.
    """

    class WeightedTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # Pop labels so the model's forward pass does not also compute
            # its own (unweighted) loss.
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            weights = class_weights_tensor.to(logits.device)
            loss = torch.nn.functional.cross_entropy(
                logits,
                labels,
                weight=weights,
                label_smoothing=self.args.label_smoothing_factor,
            )
            return (loss, outputs) if return_outputs else loss

    return WeightedTrainer
def make_focal_trainer(class_weights_tensor, gamma: float = 2.0):
    """Return a Trainer subclass that uses class-weighted focal loss.

    Each example's class-weighted cross-entropy is scaled by
    ``(1 - p_t) ** gamma``, where ``p_t`` is the softmax probability of the
    true class. Confidently-correct (easy) examples therefore contribute
    less, and hard ones dominate. Class weighting targets imbalance, and
    the focal term targets confusable pairs. This is recommended when the
    dataset has both.

    Args:
        class_weights_tensor: tensor with one weight per class id, passed as
            ``weight=`` to ``cross_entropy`` and moved to the logits' device.
        gamma: focusing exponent. 0 disables the focal scaling.

    Returns:
        A ``Trainer`` subclass (the class, not an instance). Its loss is the
        plain mean over the batch, not normalized by the sum of class weights.

    Overriding ``compute_loss`` bypasses ``Trainer``'s built-in label
    smoother. ``TrainingArguments.label_smoothing_factor`` is therefore
    forwarded to the cross-entropy term explicitly.
    """
    import torch.nn.functional as F

    class FocalTrainer(Trainer):
        def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
            # Pop labels so the model does not compute its own loss.
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            weights = class_weights_tensor.to(logits.device)
            # Per-example weighted (and optionally label-smoothed) cross-entropy.
            ce = F.cross_entropy(
                logits,
                labels,
                weight=weights,
                reduction="none",
                label_smoothing=self.args.label_smoothing_factor,
            )
            # Focal scaling: (1 - p_t)^gamma, with p_t = P(true class).
            probs = F.softmax(logits, dim=-1)
            pt = probs.gather(1, labels.unsqueeze(1)).squeeze(1)
            loss = (((1 - pt) ** gamma) * ce).mean()
            return (loss, outputs) if return_outputs else loss

    return FocalTrainer
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the training script."""
    ap = argparse.ArgumentParser(
        description="Fine-tune a transformer for 81-class cipher identification."
    )
    ap.add_argument("--data", default="data/cipher_examples.jsonl")
    ap.add_argument(
        "--test-data", default=None,
        help="Separate JSONL eval file (e.g. blind split). "
             "If omitted, 15%% of --data is held out.",
    )
    ap.add_argument(
        "--model", default="roberta-base",
        help="Pre-trained model ID or local path. "
             "Smaller: distilroberta-base. Larger: roberta-large.",
    )
    ap.add_argument("--out", default="cipher_model")
    ap.add_argument("--epochs", type=float, default=10.0,
                    help="Training epochs. 10+ recommended for 81-class accuracy.")
    ap.add_argument("--batch-size", type=int, default=16)
    ap.add_argument("--max-length", type=int, default=256,
                    help="Token length. 256 covers most cipher texts; raise for long ones.")
    # BooleanOptionalAction also registers --no-weighted-loss. The previous
    # store_true + default=True combination could never be switched off.
    ap.add_argument(
        "--weighted-loss", action=argparse.BooleanOptionalAction, default=True,
        help="Use class-weighted cross-entropy (disable with --no-weighted-loss). "
             "Essential given the 75:1 class-imbalance in the dataset.",
    )
    ap.add_argument(
        "--focal-loss", action="store_true",
        help="Use focal loss instead of plain weighted cross-entropy. "
             "Helps when many ciphers are statistically similar.",
    )
    ap.add_argument(
        "--focal-gamma", type=float, default=2.0,
        help="Focusing exponent for --focal-loss; larger values down-weight "
             "easy examples more (0 = no focal scaling).",
    )
    ap.add_argument(
        "--lr", type=float, default=2e-5,
        help="Peak learning rate. 2e-5 works well for roberta-base; "
             "try 3e-5 for distilroberta.",
    )
    ap.add_argument("--warmup-ratio", type=float, default=0.06,
                    help="Fraction of total steps used for linear warmup.")
    ap.add_argument("--label-smoothing", type=float, default=0.05,
                    help="Label smoothing factor (0 = off). Helps with similar-class confusion.")
    ap.add_argument("--grad-accum", type=int, default=2,
                    help="Gradient accumulation steps. Effective batch = batch-size × grad-accum.")
    ap.add_argument(
        "--early-stopping-patience", type=int, default=3,
        help="Stop training if macro_f1 doesn't improve for this many eval epochs (0 = off).",
    )
    ap.add_argument(
        "--push-to-hub", action="store_true",
        help="Push the trained model to the Hugging Face Hub after training.",
    )
    ap.add_argument(
        "--hub-model-id", default=None,
        help="Hub repo id for --push-to-hub (e.g. username/cipher-model). "
             "Required when --push-to-hub is set.",
    )
    return ap


def _load_splits(args):
    """Load examples and return ``(train_rows, test_rows, label2id, id2label)``.

    Each returned row is ``{"text": ..., "label_id": int}``. The label
    vocabulary comes from ``--data`` only. ``--test-data`` rows whose label
    is not in that vocabulary are skipped (and counted). Without
    ``--test-data``, a stratified 15% hold-out of ``--data`` is the eval split.

    Raises:
        SystemExit: if ``--test-data`` yields no usable eval rows.
    """
    from collections import Counter

    rows = load_jsonl(args.data)

    if not args.test_data:
        # Stratified splitting needs >= 2 examples per label. With a separate
        # eval file no split happens, so rare labels are kept for training.
        label_counts = Counter(r["label"] for r in rows)
        dropped = {lbl for lbl, cnt in label_counts.items() if cnt < 2}
        if dropped:
            print(f"Dropping {len(dropped)} label(s) with <2 examples: {sorted(dropped)}")
            rows = [r for r in rows if r["label"] not in dropped]

    labels = sorted({r["label"] for r in rows})
    label2id = {label: i for i, label in enumerate(labels)}
    id2label = {i: label for label, i in label2id.items()}
    print(f"Dataset: {len(rows):,} examples | {len(labels)} labels")
    print(f"Model: {args.model} | epochs: {args.epochs} | lr: {args.lr}")

    # Keep only the two columns needed for training.
    rows = [{"text": r["text"], "label_id": label2id[r["label"]]} for r in rows]

    if args.test_data:
        test_rows_raw = load_jsonl(args.test_data)
        test_rows = [
            {"text": r["text"], "label_id": label2id[r["label"]]}
            for r in test_rows_raw
            if r.get("label") in label2id
        ]
        skipped = len(test_rows_raw) - len(test_rows)
        if skipped:
            print(f"Skipping {skipped} eval example(s) whose label is not in --data")
        if not test_rows:
            raise SystemExit(f"No usable eval examples in {args.test_data}")
        print(f"Using separate test file: {len(test_rows)} eval examples")
        return rows, test_rows, label2id, id2label

    train_rows, test_rows = train_test_split(
        rows,
        test_size=0.15,
        random_state=42,
        stratify=[r["label_id"] for r in rows],
    )
    return train_rows, test_rows, label2id, id2label


def main():
    """Parse CLI arguments, fine-tune the classifier, evaluate, and save artifacts.

    The model, tokenizer, ``training_metrics.json`` and ``label_mapping.json``
    are written to ``--out``. With ``--push-to-hub``, the output directory
    is pushed afterwards, including both JSON files.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()
    if args.push_to_hub and not args.hub_model_id:
        ap.error("--hub-model-id is required when --push-to-hub is set")

    train_rows, test_rows, label2id, id2label = _load_splits(args)
    num_labels = len(label2id)

    tok = AutoTokenizer.from_pretrained(args.model)

    def tokenize(batch):
        return tok(batch["text"], truncation=True, max_length=args.max_length)

    ds_train = (
        Dataset.from_list(train_rows)
        .map(tokenize, batched=True)
        .rename_column("label_id", "labels")
    )
    ds_test = (
        Dataset.from_list(test_rows)
        .map(tokenize, batched=True)
        .rename_column("label_id", "labels")
    )

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
    )

    # "Balanced" class weights for the weighted / focal loss trainers.
    class_weights = compute_class_weight(
        class_weight="balanced",
        classes=np.arange(num_labels),
        y=[r["label_id"] for r in train_rows],
    )
    # Cap extreme weights to prevent instability on very rare classes.
    class_weights = np.clip(class_weights, 0.1, 20.0)
    weights_tensor = torch.tensor(class_weights, dtype=torch.float32)
    print(f"Class weights — min: {weights_tensor.min():.2f} max: {weights_tensor.max():.2f}")

    training_args = TrainingArguments(
        output_dir=args.out,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.lr,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        weight_decay=0.01,
        warmup_ratio=args.warmup_ratio,
        label_smoothing_factor=args.label_smoothing,
        gradient_accumulation_steps=args.grad_accum,
        lr_scheduler_type="cosine",
        logging_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        report_to="none",
        save_total_limit=2,
        # Mixed precision when a CUDA device is available.
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=2,
        # Hub push (only active when --push-to-hub is passed)
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
    )

    if args.focal_loss:
        print("Using focal loss (with class weighting)")
        TrainerClass = make_focal_trainer(weights_tensor, gamma=args.focal_gamma)
    elif args.weighted_loss:
        print("Using class-weighted cross-entropy loss")
        TrainerClass = make_weighted_trainer(weights_tensor)
    else:
        print("Using standard cross-entropy loss (no class weighting)")
        TrainerClass = Trainer

    callbacks = []
    if args.early_stopping_patience > 0:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience))
        print(f"Early stopping: patience={args.early_stopping_patience} epochs")

    trainer = TrainerClass(
        model=model,
        args=training_args,
        train_dataset=ds_train,
        eval_dataset=ds_test,
        processing_class=tok,
        data_collator=DataCollatorWithPadding(tok),
        compute_metrics=compute_metrics,
        callbacks=callbacks or None,
    )
    trainer.train()
    metrics = trainer.evaluate()
    trainer.save_model(args.out)
    tok.save_pretrained(args.out)

    # Write the metrics and label mapping *before* pushing so they are part
    # of the uploaded output directory.
    out_path = Path(args.out)
    out_path.mkdir(parents=True, exist_ok=True)
    (out_path / "training_metrics.json").write_text(
        json.dumps(metrics, indent=2), encoding="utf-8"
    )
    (out_path / "label_mapping.json").write_text(
        json.dumps({"label2id": label2id, "id2label": id2label}, indent=2),
        encoding="utf-8",
    )

    if args.push_to_hub:
        print(f"Pushing model to Hub: {args.hub_model_id}")
        trainer.push_to_hub()

    print(json.dumps(metrics, indent=2))
    print(f"\nSaved model to {args.out}")
    print(f"Accuracy: {metrics.get('eval_accuracy', 0):.3f}")
    print(f"Macro F1: {metrics.get('eval_macro_f1', 0):.3f}")


if __name__ == "__main__":
    main()