tomerz14's picture
Upload train.py
226b20d verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
train.py — Binary text classification with 🤗 Transformers
---------------------------------------------------------
- Reads a CSV with columns: text, label (0/1)
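- Supports chunking long docs into overlapping token windows (default 512 tokens, stride 128); with chunking on, reported eval metrics are per chunk, not per document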
- Uses `evaluate` for metrics (accuracy, f1, roc_auc)
- MPS-safe (no accidental fp16)
- Optional push to Hugging Face Hub
Usage (local, no Hub):
python train.py --csv_path data.csv --model_name bert-base-uncased --output_dir bert-binclass
Usage (push to Hub at end):
python train.py --csv_path data.csv --push_to_hub --hub_model_id your-username/bert-binclass
Notes:
- Login once with: `huggingface-cli login` or set env HF_TOKEN.
- Chunking is enabled by default (disable with --no_chunking); it matters for long texts (e.g. 300-800 words) that exceed the --max_length token limit.
"""
import os
import argparse
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
EarlyStoppingCallback
)
import evaluate
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the data, model, tokenization, training,
        Hub and logging options defined below.

    Exits via ``ArgumentParser.error`` (status 2) when values are
    inconsistent: ``--test_size`` outside (0, 1), non-positive
    ``--max_length``, or (with chunking on) ``--stride`` outside
    [0, max_length). Without these checks such values only fail later,
    deep inside datasets/tokenizers, with much less specific errors.
    """
    p = argparse.ArgumentParser(description="Train a binary text classifier with Hugging Face Transformers.")
    # Data
    p.add_argument("--csv_path", type=str, required=True, help="Path to CSV with 'text' and 'label' columns.")
    p.add_argument("--test_size", type=float, default=0.2, help="Validation split size.")
    p.add_argument("--seed", type=int, default=14, help="Random seed.")
    # Model
    p.add_argument("--model_name", type=str, default="bert-base-uncased", help="Base model checkpoint.")
    p.add_argument("--num_labels", type=int, default=2, help="Number of labels (binary=2).")
    # Tokenization / Chunking
    # --use_chunking is redundant with default=True but kept for CLI compatibility.
    p.add_argument("--use_chunking", action="store_true", default=True, help="Enable chunking of long docs (default True).")
    p.add_argument("--no_chunking", dest="use_chunking", action="store_false", help="Disable chunking.")
    p.add_argument("--max_length", type=int, default=512, help="Max tokens per chunk/sequence.")
    p.add_argument("--stride", type=int, default=128, help="Overlap between chunks when chunking.")
    # Training
    p.add_argument("--output_dir", type=str, default="bert-binclass", help="Output directory.")
    p.add_argument("--epochs", type=int, default=3, help="Number of training epochs.")
    p.add_argument("--train_bs", type=int, default=16, help="Per-device train batch size.")
    p.add_argument("--eval_bs", type=int, default=32, help="Per-device eval batch size.")
    p.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate.")
    p.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
    p.add_argument("--lr_scheduler_type", type=str, default="cosine",
                   choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant",
                            "constant_with_warmup", "inverse_sqrt", "reduce_lr_on_plateau"],
                   help="LR scheduler type.")
    p.add_argument("--warmup_ratio", type=float, default=0.1, help="Warmup ratio.")
    p.add_argument("--early_stopping_patience", type=int, default=3,
                   help="Stop after this many evals without improvement.")
    p.add_argument("--logging_steps", type=int, default=50, help="Logging steps.")
    p.add_argument("--save_total_limit", type=int, default=2, help="Keep only last N checkpoints.")
    # Hub
    p.add_argument("--push_to_hub", action="store_true", help="Push final model to Hugging Face Hub.")
    p.add_argument("--hub_model_id", type=str, default=None, help="Repository name on Hub (e.g., user/model).")
    p.add_argument("--hub_private_repo", action="store_true", help="Create a private repo on the Hub.")
    # Misc
    p.add_argument("--report_to", type=str, default="tensorboard", help="Logging backend: 'none', 'tensorboard', 'wandb', 'mlflow'.")

    args = p.parse_args(argv)

    # Basic sanity checks; the tokenizer may additionally require the stride
    # to leave room for special tokens, which it reports itself.
    if not 0.0 < args.test_size < 1.0:
        p.error(f"--test_size must be in the open interval (0, 1), got {args.test_size}")
    if args.max_length <= 0:
        p.error(f"--max_length must be positive, got {args.max_length}")
    if args.use_chunking and not 0 <= args.stride < args.max_length:
        p.error(f"--stride must be in [0, max_length={args.max_length}) when chunking, got {args.stride}")
    return args
def set_env_for_mps():
    """Set Accelerate environment variables before the Trainer is created.

    ``ACCELERATE_MIXED_PRECISION`` is always forced to ``"no"`` so no fp16/bf16
    autocast gets enabled implicitly (notably on Apple MPS).
    ``ACCELERATE_USE_MPS_DEVICE`` is set only when PyTorch reports an
    available MPS backend. Setting it unconditionally is at best a no-op on
    CUDA/CPU-only machines, and with older Accelerate versions it could make
    device setup fail there. An existing value of that variable is never
    removed.
    """
    os.environ["ACCELERATE_MIXED_PRECISION"] = "no"
    if torch.backends.mps.is_available():
        os.environ["ACCELERATE_USE_MPS_DEVICE"] = "true"
def load_and_split(csv_path: str, seed: int, test_size: float):
    """Load a labelled CSV and return a stratified (train, validation) split.

    The CSV must have a ``text`` column and an integer-like ``label`` column.
    Labels are cast to ``int`` and then class-encoded, because
    ``stratify_by_column`` requires a ``ClassLabel`` feature.

    Args:
        csv_path: Path to the CSV file.
        seed: Seed for the shuffled split.
        test_size: Fraction of rows that go to the validation split.

    Returns:
        ``(train, validation)`` datasets.

    Raises:
        ValueError: if a required column is missing, or any row has an empty
            ``text`` or ``label`` cell. Otherwise these would only fail later,
            inside ``int()`` or the tokenizer, with a less specific error.
    """
    raw = load_dataset("csv", data_files={"full": csv_path})
    full = raw["full"]

    missing = [col for col in ("text", "label") if col not in full.column_names]
    if missing:
        raise ValueError(
            f"{csv_path}: missing required column(s) {missing}; found {full.column_names}"
        )

    # Empty CSV cells arrive as None.
    n_empty = sum(
        1 for text, label in zip(full["text"], full["label"]) if text is None or label is None
    )
    if n_empty:
        raise ValueError(f"{csv_path}: {n_empty} row(s) have an empty 'text' or 'label' value")

    # Ensure labels are ints
    def _coerce_label(ex):
        ex["label"] = int(ex["label"])
        return ex

    full = full.map(_coerce_label)
    # Class-encode to enable stratified split
    full = full.class_encode_column("label")
    splits = full.train_test_split(test_size=test_size, seed=seed, stratify_by_column="label")
    return splits["train"], splits["test"]
def build_tokenized_datasets(ds_train, ds_val, tokenizer, use_chunking=True, max_length=512, stride=128):
    """Tokenize the train and validation splits for sequence classification.

    With ``use_chunking`` each text is cut into windows of at most
    ``max_length`` tokens that overlap by ``stride`` tokens, so one input row
    can produce several output rows. Every window copies the label of its
    source row and stores that row's index within its split in ``doc_id``.
    This lets chunk-level predictions be grouped back per document. Without
    chunking, texts are simply truncated to ``max_length`` tokens.
    No padding is applied here (``padding=False``).

    Args:
        ds_train: Dataset with ``text`` and ``label`` columns.
        ds_val: Dataset with ``text`` and ``label`` columns.
        tokenizer: Hugging Face tokenizer; must be a fast tokenizer when
            ``use_chunking`` is true.
        use_chunking: Split long texts into overlapping windows.
        max_length: Maximum number of tokens per window / sequence.
        stride: Token overlap between consecutive windows (chunking only).

    Returns:
        ``(ds_train, ds_val)`` restricted to whichever of ``input_ids``,
        ``attention_mask``, ``token_type_ids``, ``label`` and ``doc_id`` exist.

    Raises:
        ValueError: if chunking is requested with a non-fast tokenizer.
    """
    if use_chunking:
        # `overflow_to_sample_mapping` is only returned by fast (Rust-backed)
        # tokenizers; fail early with a clear message instead of a KeyError.
        if not getattr(tokenizer, "is_fast", True):
            raise ValueError(
                "Chunking needs a fast tokenizer (it relies on overflow_to_sample_mapping); "
                "use a model that ships a fast tokenizer or disable chunking."
            )

        def tokenize_with_overflow(batch, indices):
            enc = tokenizer(
                batch["text"],
                truncation=True,
                padding=False,
                max_length=max_length,
                return_overflowing_tokens=True,
                stride=stride,
            )
            # The mapping points into the *current map() batch*. Translate it
            # through `indices` so doc_id is the row's position in the whole
            # split, instead of restarting at 0 for every batch.
            mapping = enc.pop("overflow_to_sample_mapping")
            enc["label"] = [int(batch["label"][i]) for i in mapping]
            enc["doc_id"] = [int(indices[i]) for i in mapping]
            return enc

        ds_train = ds_train.map(
            tokenize_with_overflow,
            batched=True,
            with_indices=True,
            remove_columns=ds_train.column_names,
        )
        ds_val = ds_val.map(
            tokenize_with_overflow,
            batched=True,
            with_indices=True,
            remove_columns=ds_val.column_names,
        )
    else:
        def tokenize_simple(batch):
            return tokenizer(
                batch["text"],
                truncation=True,
                padding=False,
                max_length=max_length,
            )

        ds_train = ds_train.map(tokenize_simple, batched=True)
        ds_val = ds_val.map(tokenize_simple, batched=True)

    # Keep only inputs expected by the model (+ optional doc_id). Both splits
    # went through the same tokenization, so ds_train's columns decide.
    keep_cols = [c for c in ["input_ids", "attention_mask", "token_type_ids", "label", "doc_id"]
                 if c in ds_train.column_names]
    ds_train = ds_train.remove_columns([c for c in ds_train.column_names if c not in keep_cols])
    ds_val = ds_val.remove_columns([c for c in ds_val.column_names if c not in keep_cols])
    return ds_train, ds_val
def build_metrics():
    """Build the ``compute_metrics`` callback passed to ``Trainer``.

    The callback reports accuracy, binary F1 and ROC-AUC, all via the
    ``evaluate`` library. Class predictions are the argmax of the logits. The
    ROC-AUC score uses the softmax probability of class 1, or of the only
    column when the model emits a single logit. ROC-AUC is reported as NaN
    whenever ``evaluate`` raises ``ValueError`` for it (e.g. presumably when
    the evaluated labels contain only one class).
    """
    acc = evaluate.load("accuracy")
    f1 = evaluate.load("f1")
    auc = evaluate.load("roc_auc")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        # When a model returns extra outputs, Trainer hands over a tuple whose
        # first element is the logits.
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        logits = np.asarray(logits, dtype=np.float64)

        preds = np.argmax(logits, axis=-1)
        # Numerically stable softmax: subtracting the row maximum leaves the
        # result unchanged but prevents exp() overflow (inf/inf -> NaN) on
        # large logits.
        exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs = exp / exp.sum(axis=-1, keepdims=True)
        pos = probs[:, 1] if probs.shape[1] > 1 else probs[:, 0]

        res = {}
        res.update(acc.compute(predictions=preds, references=labels))
        res.update(f1.compute(predictions=preds, references=labels, average="binary"))
        try:
            res.update(auc.compute(prediction_scores=pos, references=labels))
        except ValueError:
            res["roc_auc"] = float("nan")
        return res

    return compute_metrics
def _build_training_args(args):
    """Translate the parsed CLI options into ``TrainingArguments``."""
    hub_kwargs = {"push_to_hub": args.push_to_hub}
    if args.push_to_hub:
        hub_kwargs.update(
            hub_model_id=args.hub_model_id,
            hub_private_repo=args.hub_private_repo,
        )
    return TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_bs,
        per_device_eval_batch_size=args.eval_bs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_ratio=args.warmup_ratio,
        eval_strategy="steps",
        eval_steps=args.logging_steps,
        save_strategy="steps",
        # load_best_model_at_end requires save_steps to be a round multiple of
        # eval_steps. The library default (500) breaks that for e.g.
        # --logging_steps 30. Saving at every evaluation also guarantees that
        # the best-scoring step exists on disk when early stopping ends
        # training. save_total_limit bounds the number of kept checkpoints.
        save_steps=args.logging_steps,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        logging_strategy="steps",
        logging_steps=args.logging_steps,
        report_to=args.report_to,
        fp16=False, bf16=False, fp16_full_eval=False,
        seed=args.seed,
        save_total_limit=args.save_total_limit,
        **hub_kwargs,
    )


def _save_and_maybe_push(trainer, tok, args):
    """Save the final model and tokenizer to ``args.output_dir``; push it if requested."""
    if args.push_to_hub:
        print("Pushing to the Hugging Face Hub...")
        # Trainer.push_to_hub() first saves the model (and the processing
        # class) to args.output_dir, then uploads that folder. Calling
        # trainer.save_model() beforehand would upload twice, because
        # save_model() itself pushes when TrainingArguments.push_to_hub is True.
        trainer.push_to_hub(commit_message="Add trained binary classifier")
        print(f"Saved model & tokenizer to: {args.output_dir}")
        print("Pushed to Hub:", args.hub_model_id or "(auto repo)")
    else:
        trainer.save_model(args.output_dir)
        tok.save_pretrained(args.output_dir)
        print(f"Saved model & tokenizer to: {args.output_dir}")
        print("Hub push disabled. To enable, pass --push_to_hub and --hub_model_id.")


def main():
    """Run the pipeline: parse args, prepare data and model, train, evaluate, save/push."""
    args = parse_args()
    set_env_for_mps()

    # Seed before the model is created so the classifier head init is reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Tokenizer & data
    tok = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    ds_train, ds_val = load_and_split(args.csv_path, args.seed, args.test_size)
    ds_train, ds_val = build_tokenized_datasets(
        ds_train, ds_val, tok,
        use_chunking=args.use_chunking,
        max_length=args.max_length,
        stride=args.stride,
    )

    # Model
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=args.num_labels,
        torch_dtype=torch.float32,  # explicit to avoid half on MPS
    )

    trainer = Trainer(
        model=model,
        args=_build_training_args(args),
        train_dataset=ds_train,
        eval_dataset=ds_val,
        processing_class=tok,
        data_collator=DataCollatorWithPadding(tokenizer=tok),
        compute_metrics=build_metrics(),
        callbacks=[EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience)],
    )

    train_out = trainer.train()
    print(train_out)

    # Runs on the validation set as tokenized above, so with chunking enabled
    # these metrics are per chunk, not per document.
    eval_out = trainer.evaluate()
    print("Eval:", eval_out)

    _save_and_maybe_push(trainer, tok, args)


if __name__ == "__main__":
    main()