|
|
|
|
|
|
|
|
""" |
|
|
train.py — Binary text classification with 🤗 Transformers |
|
|
--------------------------------------------------------- |
|
|
- Reads a CSV with columns: text, label (0/1) |
|
|
- Supports chunking long docs into 512-token windows with overlap |
|
|
- Uses `evaluate` for metrics (accuracy, f1, roc_auc) |
|
|
- MPS-safe (no accidental fp16) |
|
|
- Optional push to Hugging Face Hub |
|
|
|
|
|
Usage (local, no Hub): |
|
|
python train.py --csv_path data.csv --model_name bert-base-uncased --output_dir bert-binclass |
|
|
|
|
|
Usage (push to Hub at end): |
|
|
python train.py --csv_path data.csv --push_to_hub --hub_model_id your-username/bert-binclass |
|
|
|
|
|
Notes: |
|
|
- Login once with: `huggingface-cli login` or set env HF_TOKEN. |
|
|
- If your texts are long (300-800 words), consider chunking (default: enabled). |
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from datasets import load_dataset |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSequenceClassification, |
|
|
DataCollatorWithPadding, |
|
|
Trainer, |
|
|
TrainingArguments, |
|
|
EarlyStoppingCallback |
|
|
) |
|
|
import evaluate |
|
|
|
|
|
|
|
|
def parse_args():
    """Parse command-line options for the binary-classification trainer.

    Returns:
        argparse.Namespace carrying data paths, model choice, chunking
        settings, optimization hyperparameters, and Hub-push options.
    """
    parser = argparse.ArgumentParser(description="Train a binary text classifier with Hugging Face Transformers.")

    # --- Data ---
    parser.add_argument("--csv_path", type=str, required=True, help="Path to CSV with 'text' and 'label' columns.")
    parser.add_argument("--test_size", type=float, default=0.2, help="Validation split size.")
    parser.add_argument("--seed", type=int, default=14, help="Random seed.")

    # --- Model ---
    parser.add_argument("--model_name", type=str, default="bert-base-uncased", help="Base model checkpoint.")
    parser.add_argument("--num_labels", type=int, default=2, help="Number of labels (binary=2).")

    # --- Chunking of long documents (on by default; --no_chunking turns it off) ---
    parser.add_argument("--use_chunking", action="store_true", default=True, help="Enable chunking of long docs (default True).")
    parser.add_argument("--no_chunking", dest="use_chunking", action="store_false", help="Disable chunking.")
    parser.add_argument("--max_length", type=int, default=512, help="Max tokens per chunk/sequence.")
    parser.add_argument("--stride", type=int, default=128, help="Overlap between chunks when chunking.")

    # --- Training loop / optimization ---
    parser.add_argument("--output_dir", type=str, default="bert-binclass", help="Output directory.")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs.")
    parser.add_argument("--train_bs", type=int, default=16, help="Per-device train batch size.")
    parser.add_argument("--eval_bs", type=int, default=32, help="Per-device eval batch size.")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate.")
    parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
    scheduler_choices = ["linear","cosine","cosine_with_restarts","polynomial","constant","constant_with_warmup","inverse_sqrt","reduce_lr_on_plateau"]
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine",
                        choices=scheduler_choices,
                        help="LR scheduler type.")
    parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Warmup ratio.")
    parser.add_argument("--early_stopping_patience", type=int, default=3,
                        help="Stop after this many evals without improvement.")
    parser.add_argument("--logging_steps", type=int, default=50, help="Logging steps.")
    parser.add_argument("--save_total_limit", type=int, default=2, help="Keep only last N checkpoints.")

    # --- Hugging Face Hub ---
    parser.add_argument("--push_to_hub", action="store_true", help="Push final model to Hugging Face Hub.")
    parser.add_argument("--hub_model_id", type=str, default=None, help="Repository name on Hub (e.g., user/model).")
    parser.add_argument("--hub_private_repo", action="store_true", help="Create a private repo on the Hub.")

    # --- Experiment tracking ---
    parser.add_argument("--report_to", type=str, default="tensorboard", help="Logging backend: 'none', 'tensorboard', 'wandb', 'mlflow'.")
    return parser.parse_args()
|
|
|
|
|
|
|
|
def set_env_for_mps():
    """Make Accelerate precision/device settings safe for Apple-Silicon (MPS).

    Mixed precision is disabled unconditionally (fp16 on MPS is unreliable;
    the script also passes fp16=False/bf16=False to TrainingArguments as a
    belt-and-braces measure).

    Fix over the original: ``ACCELERATE_USE_MPS_DEVICE`` was previously set
    to "true" on every machine, which could steer Accelerate toward a
    nonexistent MPS device on CUDA/CPU hosts. It is now only set when an
    MPS backend is actually available.
    """
    # Never let Accelerate auto-select fp16/bf16.
    os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
    # Only hint MPS usage on machines that actually have it.
    if torch.backends.mps.is_available():
        os.environ["ACCELERATE_USE_MPS_DEVICE"] = "true"
|
|
|
|
|
|
|
|
def load_and_split(csv_path: str, seed: int, test_size: float):
    """Load a CSV of (text, label) rows and return a stratified train/val split.

    Args:
        csv_path: Path to a CSV file with 'text' and 'label' columns.
        seed: RNG seed controlling the split.
        test_size: Fraction of rows held out for validation.

    Returns:
        Tuple ``(train_dataset, val_dataset)`` of ``datasets.Dataset`` objects.
    """
    dataset = load_dataset("csv", data_files={"full": csv_path})["full"]

    # CSV labels may arrive as strings or floats; coerce each one to int first.
    dataset = dataset.map(lambda row: {"label": int(row["label"])})

    # ClassLabel encoding is required for stratified splitting below.
    dataset = dataset.class_encode_column("label")
    split = dataset.train_test_split(test_size=test_size, seed=seed, stratify_by_column="label")
    return split["train"], split["test"]
|
|
|
|
|
|
|
|
def build_tokenized_datasets(ds_train, ds_val, tokenizer, use_chunking=True, max_length=512, stride=128):
    """Tokenize the train/val datasets, optionally windowing long documents.

    With ``use_chunking`` enabled, documents longer than ``max_length`` tokens
    are split into overlapping windows (overlap = ``stride``); every window
    inherits its source row's label and records the source index as ``doc_id``.
    Without chunking, texts are simply truncated at ``max_length``.

    Args:
        ds_train: Training ``datasets.Dataset`` with 'text' and 'label' columns.
        ds_val: Validation dataset with the same columns.
        tokenizer: A Hugging Face tokenizer.
        use_chunking: Whether to window long documents.
        max_length: Max tokens per window/sequence.
        stride: Token overlap between consecutive windows.

    Returns:
        Tuple ``(ds_train, ds_val)`` containing only model-input columns
        (plus 'label' and, when chunking, 'doc_id').
    """
    if use_chunking:
        def _encode(batch):
            encoded = tokenizer(
                batch["text"],
                truncation=True,
                padding=False,
                max_length=max_length,
                return_overflowing_tokens=True,
                stride=stride,
            )
            # Each overflow window points back at its source row; copy that
            # row's label and keep the source index around as doc_id.
            sample_map = encoded.pop("overflow_to_sample_mapping")
            encoded["label"] = [int(batch["label"][src]) for src in sample_map]
            encoded["doc_id"] = [int(src) for src in sample_map]
            return encoded

        ds_train = ds_train.map(_encode, batched=True, remove_columns=ds_train.column_names)
        ds_val = ds_val.map(_encode, batched=True, remove_columns=ds_val.column_names)
    else:
        def _encode(batch):
            return tokenizer(batch["text"], truncation=True, padding=False, max_length=max_length)

        ds_train = ds_train.map(_encode, batched=True)
        ds_val = ds_val.map(_encode, batched=True)

    # Keep only the model inputs (plus label/doc_id bookkeeping) that actually
    # appear in the train split; drop everything else (e.g. raw 'text').
    wanted = [c for c in ("input_ids", "attention_mask", "token_type_ids", "label", "doc_id")
              if c in ds_train.column_names]
    ds_train = ds_train.remove_columns([c for c in ds_train.column_names if c not in wanted])
    ds_val = ds_val.remove_columns([c for c in ds_val.column_names if c not in wanted])
    return ds_train, ds_val
|
|
|
|
|
|
|
|
def build_metrics():
    """Build a ``compute_metrics`` closure reporting accuracy, binary F1, ROC-AUC.

    Returns:
        A callable suitable for ``Trainer(compute_metrics=...)`` that maps an
        ``(logits, labels)`` eval tuple to a dict of metric values.
    """
    acc = evaluate.load("accuracy")
    f1 = evaluate.load("f1")
    auc = evaluate.load("roc_auc")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        # Numerically stable softmax: subtract the per-row max before exp.
        # The original `np.exp(logits)` overflows to inf for large logits,
        # yielding NaN probabilities and a meaningless ROC-AUC.
        shifted = logits - logits.max(axis=-1, keepdims=True)
        exp = np.exp(shifted)
        probs = exp / exp.sum(axis=-1, keepdims=True)
        # Positive-class probability for ROC-AUC; fall back to the only
        # column if the head ever emits a single logit.
        pos = probs[:, 1] if probs.shape[1] > 1 else probs[:, 0]

        res = {}
        res.update(acc.compute(predictions=preds, references=labels))
        res.update(f1.compute(predictions=preds, references=labels, average="binary"))
        try:
            res.update(auc.compute(prediction_scores=pos, references=labels))
        except ValueError:
            # ROC-AUC is undefined when `labels` contains a single class.
            res["roc_auc"] = float("nan")
        return res

    return compute_metrics
|
|
|
|
|
|
|
|
def main():
    """End-to-end training entry point.

    Pipeline: parse CLI args -> set MPS-safe env vars -> seed RNGs ->
    load/split/tokenize the CSV data -> fine-tune the classifier with
    early stopping -> evaluate -> save locally -> optionally push to the
    Hugging Face Hub.
    """
    args = parse_args()
    # Must run before any Accelerate/Trainer machinery touches the device.
    set_env_for_mps()

    # Seed NumPy and PyTorch up front so the data split (and everything
    # downstream) is reproducible; TrainingArguments(seed=...) below covers
    # Trainer-internal seeding as well.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    ds_train, ds_val = load_and_split(args.csv_path, args.seed, args.test_size)
    ds_train, ds_val = build_tokenized_datasets(
        ds_train, ds_val, tok,
        use_chunking=args.use_chunking,
        max_length=args.max_length,
        stride=args.stride,
    )

    # Explicit float32 load: avoids accidental reduced-precision weights
    # (keeps the run MPS-safe, per the module docstring).
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=args.num_labels,
        torch_dtype=torch.float32,
    )

    # Hub-related TrainingArguments are assembled separately so they are
    # only passed when pushing is requested.
    hf_hub_kwargs = {}
    if args.push_to_hub:
        hf_hub_kwargs.update(dict(
            push_to_hub=True,
            hub_model_id=args.hub_model_id,
            hub_private_repo=args.hub_private_repo,
        ))
    else:
        hf_hub_kwargs.update(dict(push_to_hub=False))

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_bs,
        per_device_eval_batch_size=args.eval_bs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_ratio=args.warmup_ratio,
        # Step-based eval/save so EarlyStoppingCallback gets frequent
        # checkpoints to compare; best checkpoint (by eval f1) is reloaded
        # at the end of training.
        eval_strategy="steps",
        save_strategy="steps",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_strategy="steps",
        logging_steps=args.logging_steps,
        report_to=args.report_to,
        # Force full fp32 everywhere — no accidental mixed precision on MPS.
        fp16=False, bf16=False, fp16_full_eval=False,
        seed=args.seed,
        save_total_limit=args.save_total_limit,
        **hf_hub_kwargs
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ds_train,
        eval_dataset=ds_val,
        processing_class=tok,
        # Dynamic padding per batch (datasets were tokenized with padding=False).
        data_collator=DataCollatorWithPadding(tokenizer=tok),
        compute_metrics=build_metrics(),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience)],
    )

    train_out = trainer.train()
    print(train_out)

    # NOTE: with chunking enabled, eval metrics are per-chunk, not per-document.
    eval_out = trainer.evaluate()
    print("Eval:", eval_out)

    # Persist both model and tokenizer so output_dir is directly loadable
    # via from_pretrained.
    trainer.save_model(args.output_dir)
    tok.save_pretrained(args.output_dir)
    print(f"Saved model & tokenizer to: {args.output_dir}")

    if args.push_to_hub:
        print("Pushing to the Hugging Face Hub...")

        # Requires prior `huggingface-cli login` or HF_TOKEN in the env.
        trainer.push_to_hub(commit_message="Add trained binary classifier")
        print("Pushed to Hub:", args.hub_model_id or "(auto repo)")
    else:
        print("Hub push disabled. To enable, pass --push_to_hub and --hub_model_id.")
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|