# /// script
# dependencies = [
#     "torch",
#     "transformers>=4.51.0",
#     "datasets>=3.0.0",
#     "accelerate>=1.0.0",
#     "scikit-learn>=1.4.0",
#     "trackio>=0.25.0",
#     "huggingface_hub>=0.30.0",
# ]
# ///
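"""Fine-tune DistilBERT on biglam/on_the_books to classify statute sections as Jim Crow laws.

Metrics are logged to Trackio and the trained model is pushed to the Hugging Face Hub.
"""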
import os
from collections import Counter
import numpy as np
import torch
import trackio
from datasets import load_dataset
from huggingface_hub import HfApi
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainerCallback,
TrainingArguments,
set_seed,
)
DATASET_ID = "biglam/on_the_books"
MODEL_ID = "distilbert-base-uncased"
HUB_MODEL_ID = "evalstate/jim-crow-test2323"
PROJECT = "jim-crow-law-classifier"
RUN_NAME = "distilbert-on-the-books"
MAX_LENGTH = 512
SEED = 42
set_seed(SEED)
if not os.environ.get("HF_TOKEN"):
raise RuntimeError("HF_TOKEN is required so the trained model can be pushed to the Hub.")
run = trackio.init(
project=PROJECT,
name=RUN_NAME,
config={
"dataset": DATASET_ID,
"base_model": MODEL_ID,
"hub_model_id": HUB_MODEL_ID,
"task": "binary sequence classification: Jim Crow law identification",
"max_length": MAX_LENGTH,
"seed": SEED,
},
private=False,
auto_log_gpu=True,
)
print(f"Trackio run: {run}")
raw = load_dataset(DATASET_ID, split="train")
label_names = raw.features["jim_crow"].names
id2label = {i: name for i, name in enumerate(label_names)}
label2id = {name: i for i, name in id2label.items()}
print(raw)
print("Label distribution:", Counter(raw["jim_crow"]))
# Stratified split because the dataset has only one split and a modest class imbalance.
splits = raw.train_test_split(test_size=0.2, seed=SEED, stratify_by_column="jim_crow")
train_ds = splits["train"]
eval_ds = splits["test"]
trackio.log({
"data/train_examples": len(train_ds),
"data/eval_examples": len(eval_ds),
"data/train_jim_crow": Counter(train_ds["jim_crow"])[1],
"data/train_no_jim_crow": Counter(train_ds["jim_crow"])[0],
})
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
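
# Build a single input string from the statute metadata plus the chapter and section text.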
def make_text(example):
chapter = example.get("chapter_text") or ""
section = example.get("section_text") or ""
meta = f"Source: {example.get('source','')}; Type: {example.get('type','')}; Chapter: {example.get('chapter_num','')}; Section: {example.get('section_num','')}"
return meta + "\n\nChapter text:\n" + chapter + "\n\nSection text:\n" + section
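
# Tokenize a batch of examples with truncation; padding is deferred to DataCollatorWithPadding at batch time.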
def preprocess(batch):
texts = []
for i in range(len(batch["section_text"])):
ex = {k: batch[k][i] for k in batch.keys()}
texts.append(make_text(ex))
enc = tokenizer(texts, truncation=True, max_length=MAX_LENGTH)
enc["labels"] = batch["jim_crow"]
return enc
remove_cols = raw.column_names
train_tok = train_ds.map(preprocess, batched=True, remove_columns=remove_cols)
eval_tok = eval_ds.map(preprocess, batched=True, remove_columns=remove_cols)
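
# Inverse-frequency class weights so the rarer class is not under-weighted in the loss.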
counts = Counter(train_ds["jim_crow"])
total = sum(counts.values())
class_weights = torch.tensor([total / (2 * counts[i]) for i in range(len(label_names))], dtype=torch.float)
print("Class weights:", class_weights.tolist())
model = AutoModelForSequenceClassification.from_pretrained(
MODEL_ID,
num_labels=len(label_names),
id2label=id2label,
label2id=label2id,
)
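
# Trainer subclass that applies the class weights via a weighted cross-entropy loss.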
class WeightedTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
labels = inputs.pop("labels")
outputs = model(**inputs)
weights = class_weights.to(outputs.logits.device)
loss_fct = torch.nn.CrossEntropyLoss(weight=weights)
loss = loss_fct(outputs.logits.view(-1, model.config.num_labels), labels.view(-1))
return (loss, outputs) if return_outputs else loss
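
# Forward Trainer training logs and evaluation metrics to the Trackio run.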
class TrackioCallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
if logs:
trackio.log({f"trainer/{k}": v for k, v in logs.items() if isinstance(v, (int, float))}, step=state.global_step)
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
if metrics:
trackio.log({f"eval/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))}, step=state.global_step)
def compute_metrics(eval_pred):
logits, labels = eval_pred
preds = np.argmax(logits, axis=-1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary", pos_label=1, zero_division=0)
macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(labels, preds, average="macro", zero_division=0)
acc = accuracy_score(labels, preds)
cm = confusion_matrix(labels, preds, labels=[0, 1])
return {
"accuracy": acc,
"precision": precision,
"recall": recall,
"f1": f1,
"macro_precision": macro_precision,
"macro_recall": macro_recall,
"macro_f1": macro_f1,
"tn": int(cm[0, 0]),
"fp": int(cm[0, 1]),
"fn": int(cm[1, 0]),
"tp": int(cm[1, 1]),
}
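
# The best checkpoint is selected by positive-class F1, reloaded at the end, and pushed to the Hub.
# Built-in reporting integrations are disabled (report_to=[]); TrackioCallback handles metric logging.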
args = TrainingArguments(
output_dir="jim-crow-test2323",
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=32,
gradient_accumulation_steps=1,
num_train_epochs=5,
weight_decay=0.01,
warmup_ratio=0.1,
lr_scheduler_type="linear",
eval_strategy="epoch",
save_strategy="epoch",
logging_steps=10,
load_best_model_at_end=True,
metric_for_best_model="f1",
greater_is_better=True,
save_total_limit=2,
fp16=torch.cuda.is_available(),
push_to_hub=True,
hub_model_id=HUB_MODEL_ID,
hub_private_repo=False,
report_to=[],
run_name=RUN_NAME,
seed=SEED,
)
trainer = WeightedTrainer(
model=model,
args=args,
train_dataset=train_tok,
eval_dataset=eval_tok,
processing_class=tokenizer,
data_collator=DataCollatorWithPadding(tokenizer),
compute_metrics=compute_metrics,
callbacks=[TrackioCallback()],
)
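
# Train, then evaluate the best checkpoint (reloaded via load_best_model_at_end) and log final metrics.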
trainer.train()
metrics = trainer.evaluate()
print("Final eval metrics:", metrics)
trackio.log({f"final/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))})
# Ensure useful metadata and a model card are present on the final Hub repo.
trainer.save_model()
tokenizer.save_pretrained(args.output_dir)
trainer.create_model_card(
model_name="Jim Crow law classifier",
dataset_tags=DATASET_ID,
finetuned_from=MODEL_ID,
tasks="text-classification",
language="en",
tags=["legal", "history", "jim-crow", "sequence-classification", "distilbert"],
)
trainer.push_to_hub(commit_message="Fine-tune DistilBERT to identify Jim Crow laws")
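# Upload this script itself to the model repo alongside the weights.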
api = HfApi(token=os.environ["HF_TOKEN"])
api.upload_file(
path_or_fileobj=__file__,
path_in_repo="training_script.py",
repo_id=HUB_MODEL_ID,
repo_type="model",
commit_message="Add training script",
)
print(f"Pushed trained model to https://huggingface.co/{HUB_MODEL_ID}")