# paper-classifier / train.py
# Uploaded by gr8monk3ys via huggingface_hub (commit 2be4558, verified)
#!/usr/bin/env python3
"""
Fine-tune DistilBERT for academic paper abstract classification.
This script downloads arxiv paper abstracts, preprocesses them, and fine-tunes
a DistilBERT model for multi-class sequence classification. Supports pushing
the trained model to the HuggingFace Hub.
Author: Lorenzo Scaturchio (gr8monk3ys)
License: MIT
"""
import argparse
import logging
import os
import sys
from pathlib import Path
import evaluate
import numpy as np
import torch
from datasets import ClassLabel, DatasetDict, load_dataset
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
EarlyStoppingCallback,
Trainer,
TrainingArguments,
set_seed,
)
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
# Configure root logging once at import time; all records go to stdout so the
# output is captured cleanly by containers / notebooks / HF job runners.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)  # module-level logger (PEP 282 style)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
MODEL_NAME = "distilbert-base-uncased"  # base checkpoint to fine-tune
DEFAULT_DATASET = "ccdv/arxiv-classification"  # default HF dataset identifier
DEFAULT_OUTPUT_DIR = "./results"  # where training checkpoints land
DEFAULT_MODEL_DIR = "./model"  # where the final model/tokenizer are exported

# Canonical label order so the id<->label mapping is deterministic.
# NOTE(review): 8 categories are listed here, but the default dataset may
# define a different (larger) label set — verify these names match the
# dataset actually used (see load_and_prepare_dataset).
LABEL_NAMES = [
    "cs.AI",
    "cs.CL",
    "cs.CV",
    "cs.LG",
    "cs.NE",
    "cs.RO",
    "math.ST",
    "stat.ML",
]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for training hyperparameters.

    Returns:
        argparse.Namespace with data, training, and Hub options.
    """
    ap = argparse.ArgumentParser(
        description="Fine-tune DistilBERT on arxiv paper classification."
    )

    # Data ---------------------------------------------------------------
    ap.add_argument("--dataset_name", type=str, default=DEFAULT_DATASET,
                    help="HuggingFace dataset identifier (default: %(default)s).")
    ap.add_argument("--max_length", type=int, default=512,
                    help="Maximum token length for the tokenizer (default: %(default)s).")
    ap.add_argument("--max_train_samples", type=int, default=None,
                    help="Cap the number of training samples (useful for debugging).")
    ap.add_argument("--max_eval_samples", type=int, default=None,
                    help="Cap the number of evaluation samples (useful for debugging).")

    # Training -----------------------------------------------------------
    ap.add_argument("--output_dir", type=str, default=DEFAULT_OUTPUT_DIR,
                    help="Directory for training checkpoints (default: %(default)s).")
    ap.add_argument("--model_dir", type=str, default=DEFAULT_MODEL_DIR,
                    help="Directory where the final model is saved (default: %(default)s).")
    ap.add_argument("--num_train_epochs", type=int, default=5,
                    help="Total number of training epochs (default: %(default)s).")
    ap.add_argument("--per_device_train_batch_size", type=int, default=16,
                    help="Batch size per device during training (default: %(default)s).")
    ap.add_argument("--per_device_eval_batch_size", type=int, default=32,
                    help="Batch size per device during evaluation (default: %(default)s).")
    ap.add_argument("--learning_rate", type=float, default=2e-5,
                    help="Peak learning rate (default: %(default)s).")
    ap.add_argument("--weight_decay", type=float, default=0.01,
                    help="Weight decay coefficient (default: %(default)s).")
    ap.add_argument("--warmup_ratio", type=float, default=0.1,
                    help="Fraction of total steps used for linear warmup (default: %(default)s).")
    ap.add_argument("--seed", type=int, default=42,
                    help="Random seed for reproducibility (default: %(default)s).")
    ap.add_argument("--early_stopping_patience", type=int, default=3,
                    help="Number of evaluations with no improvement before stopping (default: %(default)s).")
    ap.add_argument("--fp16", action="store_true", default=False,
                    help="Use mixed-precision (FP16) training.")

    # Hub ----------------------------------------------------------------
    ap.add_argument("--push_to_hub", action="store_true", default=False,
                    help="Push the trained model to the HuggingFace Hub.")
    ap.add_argument("--hub_model_id", type=str,
                    default="gr8monk3ys/paper-classifier-model",
                    help="Repository id on the HuggingFace Hub (default: %(default)s).")

    return ap.parse_args()
def build_label_mappings(label_names: list[str]) -> tuple[dict, dict]:
    """Return (label2id, id2label) dicts for the given label names.

    The index of each name in ``label_names`` becomes its integer id, so the
    mapping is deterministic for a fixed label order.
    """
    indexed = list(enumerate(label_names))
    id2label = dict(indexed)
    label2id = {name: idx for idx, name in indexed}
    return label2id, id2label
def load_and_prepare_dataset(
    dataset_name: str,
    label2id: dict[str, int],
    max_train_samples: int | None = None,
    max_eval_samples: int | None = None,
) -> DatasetDict:
    """Load the dataset and normalise the label column.

    The function handles two common dataset layouts:
    1. The dataset already has train / validation / test splits and a
       numeric ``label`` column whose values match our ``label2id``.
    2. The dataset has a string ``label`` column that needs mapping.

    Args:
        dataset_name: HuggingFace Hub dataset identifier.
        label2id: Mapping of canonical label names to integer ids.
        max_train_samples: Optional cap on training examples (debugging).
        max_eval_samples: Optional cap on validation examples (debugging).

    Returns:
        A ``DatasetDict`` with ``train`` and ``validation`` splits, a ``text``
        column, and an integer ``label`` column cast to ``ClassLabel``.
    """
    logger.info("Loading dataset: %s", dataset_name)
    # trust_remote_code: some datasets ship loading scripts that execute code.
    raw = load_dataset(dataset_name, trust_remote_code=True)

    # Determine the text and label column names --------------------------
    # Column names are probed from the FIRST split only.
    # NOTE(review): assumes all splits share identical columns — verify for
    # datasets whose splits differ.
    sample_columns = list(next(iter(raw.values())).column_names)
    text_col = None
    for candidate in ("text", "abstract", "input", "sentence"):
        if candidate in sample_columns:
            text_col = candidate
            break
    if text_col is None:
        # Fall back to the first string-typed column
        # NOTE(review): no type check is actually performed here — the first
        # column is used regardless of its dtype.
        text_col = sample_columns[0]
    logger.info("Using text column: '%s'", text_col)

    label_col = None
    for candidate in ("label", "labels", "category", "class"):
        if candidate in sample_columns:
            label_col = candidate
            break
    if label_col is None:
        label_col = sample_columns[-1]  # last column as a heuristic fallback
    logger.info("Using label column: '%s'", label_col)

    # Rename columns so downstream code can rely on 'text' and 'label' ---
    def _rename(example):
        # str() coerces non-string text fields (e.g. lists) into plain text.
        return {"text": str(example[text_col]), "label": example[label_col]}
    raw = raw.map(_rename, remove_columns=sample_columns)

    # If labels are strings, map them to ints using label2id -------------
    sample_label = raw[list(raw.keys())[0]][0]["label"]
    if isinstance(sample_label, str):
        logger.info("Mapping string labels to integer ids.")
        def _map_label(example):
            lbl = example["label"]
            if lbl in label2id:
                example["label"] = label2id[lbl]
            else:
                example["label"] = -1  # will be filtered out
            return example
        raw = raw.map(_map_label)
        # Drop examples whose label is not in our canonical label set.
        raw = raw.filter(lambda ex: ex["label"] != -1)

    # Ensure we have a ClassLabel feature --------------------------------
    # A ClassLabel feature is required for stratify_by_column below.
    # NOTE(review): if the dataset's integer labels exceed len(label2id)-1
    # (i.e. it has more classes than LABEL_NAMES), this cast will fail —
    # confirm the dataset's label set matches our canonical labels.
    label_feature = ClassLabel(
        num_classes=len(label2id), names=list(label2id.keys())
    )
    raw = raw.cast_column("label", label_feature)

    # Build train / validation splits ------------------------------------
    if "validation" not in raw and "test" in raw:
        # Reuse the test split as validation when no validation split exists.
        raw["validation"] = raw.pop("test")
    elif "validation" not in raw:
        # No held-out split at all: carve 10% off train, stratified by label.
        split = raw["train"].train_test_split(test_size=0.1, seed=42, stratify_by_column="label")
        raw = DatasetDict({"train": split["train"], "validation": split["test"]})

    # Subsample if requested ---------------------------------------------
    # min() guards against requesting more samples than exist.
    if max_train_samples is not None:
        raw["train"] = raw["train"].select(range(min(max_train_samples, len(raw["train"]))))
    if max_eval_samples is not None:
        raw["validation"] = raw["validation"].select(
            range(min(max_eval_samples, len(raw["validation"])))
        )

    logger.info(
        "Dataset sizes -> train: %d, validation: %d",
        len(raw["train"]),
        len(raw["validation"]),
    )
    return raw
def tokenize_dataset(
    dataset: DatasetDict,
    tokenizer: AutoTokenizer,
    max_length: int,
) -> DatasetDict:
    """Tokenize the ``text`` column using the supplied tokenizer.

    Every example is padded/truncated to ``max_length`` tokens, and the
    result is formatted as torch tensors for the Trainer.
    """
    def _encode(batch):
        return tokenizer(
            batch["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

    logger.info("Tokenizing dataset (max_length=%d) ...", max_length)
    encoded = dataset.map(_encode, batched=True, desc="Tokenizing")
    encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    return encoded
def build_compute_metrics_fn():
    """Return a ``compute_metrics`` callable for the HF Trainer.

    Loads the ``accuracy``, ``f1``, ``precision`` and ``recall`` evaluate
    metrics once at creation time to avoid repeated disk access.
    """
    # (metric, extra kwargs) pairs, in the order their results are merged.
    loaded = [
        (evaluate.load("accuracy"), {}),
        (evaluate.load("f1"), {"average": "weighted"}),
        (evaluate.load("precision"), {"average": "weighted"}),
        (evaluate.load("recall"), {"average": "weighted"}),
    ]

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        merged = {}
        for metric, extra in loaded:
            merged.update(
                metric.compute(predictions=preds, references=labels, **extra)
            )
        return merged

    return compute_metrics
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Train, evaluate, save, and optionally push the classifier.

    Orchestrates the full pipeline: seed setup, dataset preparation,
    tokenization, model fine-tuning with early stopping, final evaluation,
    local export, and optional push to the HuggingFace Hub.
    """
    args = parse_args()

    # Reproducibility: seeds python / numpy / torch RNGs via transformers.
    set_seed(args.seed)
    logger.info("Seed set to %d", args.seed)

    # Device info (informational only — the Trainer picks its own device).
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    logger.info("Using device: %s", device)

    # Label mappings derived from the canonical module-level label order.
    label2id, id2label = build_label_mappings(LABEL_NAMES)
    num_labels = len(LABEL_NAMES)
    logger.info("Number of labels: %d", num_labels)

    # Dataset: normalised train/validation splits with integer labels.
    dataset = load_and_prepare_dataset(
        dataset_name=args.dataset_name,
        label2id=label2id,
        max_train_samples=args.max_train_samples,
        max_eval_samples=args.max_eval_samples,
    )

    # Tokenizer
    logger.info("Loading tokenizer: %s", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenized_dataset = tokenize_dataset(dataset, tokenizer, args.max_length)

    # Model: classification head is freshly initialised on top of DistilBERT.
    logger.info("Loading model: %s", MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
    )

    # Training arguments
    # NOTE(review): `eval_strategy` is the post-rename name of
    # `evaluation_strategy` — confirm the pinned transformers version
    # supports it.
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="linear",
        eval_strategy="epoch",
        save_strategy="epoch",  # must match eval_strategy for best-model loading
        logging_strategy="steps",
        logging_steps=50,
        save_total_limit=2,  # keep disk usage bounded
        load_best_model_at_end=True,  # required by EarlyStoppingCallback
        metric_for_best_model="f1",
        greater_is_better=True,
        # fp16 only makes sense on CUDA; silently disabled elsewhere.
        fp16=args.fp16 and torch.cuda.is_available(),
        report_to="none",  # disable wandb/tensorboard auto-reporting
        seed=args.seed,
        push_to_hub=False,  # we push manually after training
    )

    # Trainer
    # NOTE(review): `tokenizer=` is deprecated in newer transformers in
    # favour of `processing_class=` — verify against the pinned version.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        tokenizer=tokenizer,
        compute_metrics=build_compute_metrics_fn(),
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience),
        ],
    )

    # Train
    logger.info("Starting training ...")
    train_result = trainer.train()
    logger.info("Training complete.")

    # Log final training metrics (also written to JSON in output_dir).
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)

    # Evaluate (best checkpoint, since load_best_model_at_end=True).
    logger.info("Running final evaluation ...")
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)
    trainer.save_metrics("eval", eval_metrics)

    # Save model + tokenizer side by side so the dir is directly loadable
    # with from_pretrained().
    model_dir = Path(args.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Saving model to %s", model_dir)
    trainer.save_model(str(model_dir))
    tokenizer.save_pretrained(str(model_dir))

    # Push to Hub (best-effort; a failure exits non-zero but the local
    # model has already been saved above).
    if args.push_to_hub:
        logger.info("Pushing model to HuggingFace Hub: %s", args.hub_model_id)
        try:
            model.push_to_hub(args.hub_model_id)
            tokenizer.push_to_hub(args.hub_model_id)
            logger.info("Model pushed successfully.")
        except Exception:
            logger.exception("Failed to push model to Hub.")
            sys.exit(1)

    logger.info("All done.")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()