|
|
|
|
|
""" |
|
|
Fine-tune DistilBERT for academic paper abstract classification. |
|
|
|
|
|
This script downloads arxiv paper abstracts, preprocesses them, and fine-tunes |
|
|
a DistilBERT model for multi-class sequence classification. Supports pushing |
|
|
the trained model to the HuggingFace Hub. |
|
|
|
|
|
Author: Lorenzo Scaturchio (gr8monk3ys) |
|
|
License: MIT |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
import evaluate |
|
|
import numpy as np |
|
|
import torch |
|
|
from datasets import ClassLabel, DatasetDict, load_dataset |
|
|
from transformers import ( |
|
|
AutoModelForSequenceClassification, |
|
|
AutoTokenizer, |
|
|
EarlyStoppingCallback, |
|
|
Trainer, |
|
|
TrainingArguments, |
|
|
set_seed, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Configure root logging once at import time so all module loggers inherit
# the same format; logs go to stdout (useful in container / notebook runs).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)


# Base checkpoint to fine-tune from.
MODEL_NAME = "distilbert-base-uncased"
# Default HuggingFace Hub dataset identifier used when --dataset_name is omitted.
DEFAULT_DATASET = "ccdv/arxiv-classification"
# Where Trainer writes intermediate checkpoints.
DEFAULT_OUTPUT_DIR = "./results"
# Where the final model + tokenizer are saved.
DEFAULT_MODEL_DIR = "./model"


# Class names, in id order: index in this list is the integer label id.
# NOTE(review): these must match the label set of the dataset actually used;
# string labels not in this list are dropped by load_and_prepare_dataset,
# and integer labels are assumed to already follow this ordering — verify
# against the chosen dataset's own label names.
LABEL_NAMES = [
    "cs.AI",
    "cs.CL",
    "cs.CV",
    "cs.LG",
    "cs.NE",
    "cs.RO",
    "math.ST",
    "stat.ML",
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for training hyperparameters.

    Args:
        argv: Optional explicit argument list. ``None`` (the default) keeps
            the original behavior of reading ``sys.argv[1:]``; passing a list
            makes the function testable without mutating process state.

    Returns:
        An ``argparse.Namespace`` holding all training options.
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune DistilBERT on arxiv paper classification."
    )

    # --- Data options ----------------------------------------------------
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=DEFAULT_DATASET,
        help="HuggingFace dataset identifier (default: %(default)s).",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="Maximum token length for the tokenizer (default: %(default)s).",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help="Cap the number of training samples (useful for debugging).",
    )
    parser.add_argument(
        "--max_eval_samples",
        type=int,
        default=None,
        help="Cap the number of evaluation samples (useful for debugging).",
    )

    # --- Output locations ------------------------------------------------
    parser.add_argument(
        "--output_dir",
        type=str,
        default=DEFAULT_OUTPUT_DIR,
        help="Directory for training checkpoints (default: %(default)s).",
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=DEFAULT_MODEL_DIR,
        help="Directory where the final model is saved (default: %(default)s).",
    )

    # --- Training hyperparameters ---------------------------------------
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=5,
        help="Total number of training epochs (default: %(default)s).",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=16,
        help="Batch size per device during training (default: %(default)s).",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=32,
        help="Batch size per device during evaluation (default: %(default)s).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
        help="Peak learning rate (default: %(default)s).",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0.01,
        help="Weight decay coefficient (default: %(default)s).",
    )
    parser.add_argument(
        "--warmup_ratio",
        type=float,
        default=0.1,
        help="Fraction of total steps used for linear warmup (default: %(default)s).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: %(default)s).",
    )
    parser.add_argument(
        "--early_stopping_patience",
        type=int,
        default=3,
        help="Number of evaluations with no improvement before stopping (default: %(default)s).",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        default=False,
        help="Use mixed-precision (FP16) training.",
    )

    # --- HuggingFace Hub options -----------------------------------------
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        default=False,
        help="Push the trained model to the HuggingFace Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default="gr8monk3ys/paper-classifier-model",
        help="Repository id on the HuggingFace Hub (default: %(default)s).",
    )

    # parse_args(None) falls back to sys.argv[1:], preserving the original
    # zero-argument call behavior.
    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
def build_label_mappings(label_names: list[str]) -> tuple[dict, dict]:
    """Build the forward and reverse label/index mappings.

    Args:
        label_names: Ordered class names; list position defines the id.

    Returns:
        A ``(label2id, id2label)`` pair of dictionaries.
    """
    indexed = list(enumerate(label_names))
    # Forward map: class name -> integer id.
    forward = {name: position for position, name in indexed}
    # Reverse map: integer id -> class name.
    reverse = dict(indexed)
    return forward, reverse
|
|
|
|
|
|
|
|
def load_and_prepare_dataset(
    dataset_name: str,
    label2id: dict[str, int],
    max_train_samples: int | None = None,
    max_eval_samples: int | None = None,
) -> DatasetDict:
    """Load the dataset and normalise the label column.

    The function handles two common dataset layouts:
    1. The dataset already has train / validation / test splits and a
       numeric ``label`` column whose values match our ``label2id``.
    2. The dataset has a string ``label`` column that needs mapping.

    Args:
        dataset_name: HuggingFace Hub dataset identifier.
        label2id: Mapping from class-name string to integer id; string labels
            absent from this mapping are silently dropped.
        max_train_samples: If set, truncate the train split to this size.
        max_eval_samples: If set, truncate the validation split to this size.

    Returns a ``DatasetDict`` with ``train`` and ``validation`` splits.
    """
    logger.info("Loading dataset: %s", dataset_name)
    # trust_remote_code=True allows datasets with custom loading scripts
    # (e.g. ccdv/arxiv-classification) — only use with trusted dataset ids.
    raw = load_dataset(dataset_name, trust_remote_code=True)

    # Inspect the first split's schema to auto-detect text/label columns.
    # Assumes all splits share the same columns — TODO confirm for exotic datasets.
    sample_columns = list(next(iter(raw.values())).column_names)
    text_col = None
    for candidate in ("text", "abstract", "input", "sentence"):
        if candidate in sample_columns:
            text_col = candidate
            break
    if text_col is None:
        # Fallback heuristic: assume the first column holds the text.
        text_col = sample_columns[0]
    logger.info("Using text column: '%s'", text_col)

    label_col = None
    for candidate in ("label", "labels", "category", "class"):
        if candidate in sample_columns:
            label_col = candidate
            break
    if label_col is None:
        # Fallback heuristic: assume the last column holds the label.
        label_col = sample_columns[-1]
    logger.info("Using label column: '%s'", label_col)

    # Canonicalize to exactly two columns: "text" (str) and "label".
    def _rename(example):
        return {"text": str(example[text_col]), "label": example[label_col]}

    raw = raw.map(_rename, remove_columns=sample_columns)

    # Peek at one label to decide whether string->id mapping is needed.
    sample_label = raw[list(raw.keys())[0]][0]["label"]
    if isinstance(sample_label, str):
        logger.info("Mapping string labels to integer ids.")

        # Unknown labels are marked -1 and filtered out below, so classes
        # outside label2id are dropped rather than crashing the cast.
        def _map_label(example):
            lbl = example["label"]
            if lbl in label2id:
                example["label"] = label2id[lbl]
            else:
                example["label"] = -1
            return example

        raw = raw.map(_map_label)
        raw = raw.filter(lambda ex: ex["label"] != -1)

    # Cast to ClassLabel: required by train_test_split(stratify_by_column=...)
    # and gives the model human-readable class names.
    # NOTE(review): if the dataset already ships integer labels, this assumes
    # they are in [0, len(label2id)) and ordered like label2id — verify.
    label_feature = ClassLabel(
        num_classes=len(label2id), names=list(label2id.keys())
    )
    raw = raw.cast_column("label", label_feature)

    # Ensure a "validation" split exists: reuse "test" if present, otherwise
    # carve a stratified 10% holdout out of "train".
    if "validation" not in raw and "test" in raw:
        raw["validation"] = raw.pop("test")
    elif "validation" not in raw:
        # Fixed seed=42 here (independent of --seed) keeps the split stable
        # across hyperparameter runs.
        split = raw["train"].train_test_split(test_size=0.1, seed=42, stratify_by_column="label")
        raw = DatasetDict({"train": split["train"], "validation": split["test"]})

    # Optional truncation for quick debugging runs; min() guards against
    # caps larger than the split.
    if max_train_samples is not None:
        raw["train"] = raw["train"].select(range(min(max_train_samples, len(raw["train"]))))
    if max_eval_samples is not None:
        raw["validation"] = raw["validation"].select(
            range(min(max_eval_samples, len(raw["validation"])))
        )

    logger.info(
        "Dataset sizes -> train: %d, validation: %d",
        len(raw["train"]),
        len(raw["validation"]),
    )
    return raw
|
|
|
|
|
|
|
|
def tokenize_dataset(
    dataset: DatasetDict,
    tokenizer: AutoTokenizer,
    max_length: int,
) -> DatasetDict:
    """Encode every ``text`` entry with *tokenizer*.

    Args:
        dataset: Dataset with a ``text`` column to encode.
        tokenizer: Tokenizer callable producing input ids / attention masks.
        max_length: Sequences are padded and truncated to this many tokens.

    Returns:
        The encoded dataset, formatted as torch tensors for the Trainer.
    """
    logger.info("Tokenizing dataset (max_length=%d) ...", max_length)

    def _encode(batch):
        # Fixed-length padding keeps batch shapes uniform for the Trainer.
        return tokenizer(
            batch["text"],
            padding="max_length",
            truncation=True,
            max_length=max_length,
        )

    encoded = dataset.map(_encode, batched=True, desc="Tokenizing")
    encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    return encoded
|
|
|
|
|
|
|
|
def build_compute_metrics_fn():
    """Return a ``compute_metrics`` callable for the HF Trainer.

    Loads the ``accuracy``, ``f1``, ``precision`` and ``recall`` evaluate
    metrics once at creation time to avoid repeated disk access.
    """
    # Loaded eagerly so the closure below does no disk I/O per evaluation.
    loaded = {
        "accuracy": evaluate.load("accuracy"),
        "f1": evaluate.load("f1"),
        "precision": evaluate.load("precision"),
        "recall": evaluate.load("recall"),
    }

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        results = {}
        # Accuracy takes no averaging argument.
        results.update(
            loaded["accuracy"].compute(predictions=predictions, references=labels)
        )
        # The remaining metrics are weighted by class support.
        for metric_name in ("f1", "precision", "recall"):
            results.update(
                loaded[metric_name].compute(
                    predictions=predictions, references=labels, average="weighted"
                )
            )
        return results

    return compute_metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Run the full fine-tuning pipeline: data, model, training, saving.

    Steps: parse CLI args, seed RNGs, prepare the dataset, tokenize, build
    the model and Trainer, train with early stopping, evaluate, save the
    final model/tokenizer, and optionally push to the HuggingFace Hub.
    """
    args = parse_args()

    # Seed python / numpy / torch RNGs for reproducibility.
    set_seed(args.seed)
    logger.info("Seed set to %d", args.seed)

    # Informational only: Trainer handles device placement itself; this
    # value is logged but not passed anywhere.
    device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
    logger.info("Using device: %s", device)

    # Label id mappings are derived from the module-level LABEL_NAMES list.
    label2id, id2label = build_label_mappings(LABEL_NAMES)
    num_labels = len(LABEL_NAMES)
    logger.info("Number of labels: %d", num_labels)

    # Download/normalize the dataset into train + validation splits.
    dataset = load_and_prepare_dataset(
        dataset_name=args.dataset_name,
        label2id=label2id,
        max_train_samples=args.max_train_samples,
        max_eval_samples=args.max_eval_samples,
    )

    logger.info("Loading tokenizer: %s", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenized_dataset = tokenize_dataset(dataset, tokenizer, args.max_length)

    # Classification head is freshly initialized on top of the pretrained body.
    logger.info("Loading model: %s", MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
    )

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="linear",
        # Evaluate + checkpoint once per epoch; early stopping and
        # load_best_model_at_end both require matching strategies.
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        # Keep only the two most recent checkpoints to bound disk usage.
        save_total_limit=2,
        load_best_model_at_end=True,
        # "f1" refers to the weighted f1 produced by compute_metrics.
        metric_for_best_model="f1",
        greater_is_better=True,
        # FP16 requires CUDA; silently falls back to FP32 elsewhere (e.g. MPS).
        fp16=args.fp16 and torch.cuda.is_available(),
        report_to="none",
        seed=args.seed,
        # Hub upload is handled explicitly below, not by the Trainer.
        push_to_hub=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["validation"],
        # NOTE(review): `tokenizer=` is deprecated in newer transformers in
        # favor of `processing_class=` — confirm the pinned version accepts it.
        tokenizer=tokenizer,
        compute_metrics=build_compute_metrics_fn(),
        callbacks=[
            EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience),
        ],
    )

    logger.info("Starting training ...")
    train_result = trainer.train()
    logger.info("Training complete.")

    # Persist training metrics (written under output_dir by save_metrics).
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)

    # Final evaluation runs on the best checkpoint (load_best_model_at_end).
    logger.info("Running final evaluation ...")
    eval_metrics = trainer.evaluate()
    trainer.log_metrics("eval", eval_metrics)
    trainer.save_metrics("eval", eval_metrics)

    # Save the final model + tokenizer together so the directory is directly
    # loadable with from_pretrained().
    model_dir = Path(args.model_dir)
    model_dir.mkdir(parents=True, exist_ok=True)
    logger.info("Saving model to %s", model_dir)
    trainer.save_model(str(model_dir))
    tokenizer.save_pretrained(str(model_dir))

    # Optional Hub upload; a failed push exits non-zero so CI can detect it,
    # but the model is already saved locally at this point.
    if args.push_to_hub:
        logger.info("Pushing model to HuggingFace Hub: %s", args.hub_model_id)
        try:
            model.push_to_hub(args.hub_model_id)
            tokenizer.push_to_hub(args.hub_model_id)
            logger.info("Model pushed successfully.")
        except Exception:
            logger.exception("Failed to push model to Hub.")
            sys.exit(1)

    logger.info("All done.")
|
|
|
|
|
|
|
|
# Script entry point: run the full fine-tuning pipeline when executed directly.
if __name__ == "__main__":
    main()
|
|
|