# /// script
# dependencies = [
#     "torch",
#     "transformers>=4.51.0",
#     "datasets>=3.0.0",
#     "accelerate>=1.0.0",
#     "scikit-learn>=1.4.0",
#     "trackio>=0.25.0",
#     "huggingface_hub>=0.30.0",
# ]
# ///
import os
from collections import Counter

import numpy as np
import torch
import trackio
from datasets import load_dataset
from huggingface_hub import HfApi
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    set_seed,
)

# Experiment configuration: dataset, base checkpoint, and Hub/tracking targets.
DATASET_ID = "biglam/on_the_books"
MODEL_ID = "distilbert-base-uncased"
HUB_MODEL_ID = "evalstate/jim-crow-test2323"
PROJECT = "jim-crow-law-classifier"
RUN_NAME = "distilbert-on-the-books"
MAX_LENGTH = 512
SEED = 42

set_seed(SEED)

# Fail fast: without a token the final push_to_hub step would die after hours of training.
if not os.environ.get("HF_TOKEN"):
    raise RuntimeError("HF_TOKEN is required so the trained model can be pushed to the Hub.")

# Start the experiment-tracking run and record the hyperparameter context up front.
run = trackio.init(
    project=PROJECT,
    name=RUN_NAME,
    config={
        "dataset": DATASET_ID,
        "base_model": MODEL_ID,
        "hub_model_id": HUB_MODEL_ID,
        "task": "binary sequence classification: Jim Crow law identification",
        "max_length": MAX_LENGTH,
        "seed": SEED,
    },
    private=False,
    auto_log_gpu=True,
)
print(f"Trackio run: {run}")

# The dataset ships a single "train" split; we carve out our own eval split later.
raw = load_dataset(DATASET_ID, split="train")
label_names = raw.features["jim_crow"].names
id2label = dict(enumerate(label_names))
label2id = {v: k for k, v in id2label.items()}
print(raw)
print("Label distribution:", Counter(raw["jim_crow"]))

# Stratified split because the dataset has only one split and a modest class imbalance.
splits = raw.train_test_split(test_size=0.2, seed=SEED, stratify_by_column="jim_crow")
train_ds = splits["train"]
eval_ds = splits["test"]

# Compute the per-class counts once and reuse them (previously Counter ran twice here).
train_counts = Counter(train_ds["jim_crow"])
trackio.log({
    "data/train_examples": len(train_ds),
    "data/eval_examples": len(eval_ds),
    "data/train_jim_crow": train_counts[1],
    "data/train_no_jim_crow": train_counts[0],
})

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def make_text(example):
    """Build the classification input for one example.

    Concatenates source/type/chapter/section metadata with the chapter and
    section texts so the model sees both context and the statute itself.
    Missing fields fall back to empty strings.
    """
    chapter = example.get("chapter_text") or ""
    section = example.get("section_text") or ""
    meta = f"Source: {example.get('source','')}; Type: {example.get('type','')}; Chapter: {example.get('chapter_num','')}; Section: {example.get('section_num','')}"
    return meta + "\n\nChapter text:\n" + chapter + "\n\nSection text:\n" + section


def preprocess(batch):
    """Tokenize a batch (dict of column -> list) and attach integer labels."""
    texts = []
    for i in range(len(batch["section_text"])):
        # Re-assemble the i-th row as a plain dict so make_text can use .get().
        ex = {k: batch[k][i] for k in batch}
        texts.append(make_text(ex))
    enc = tokenizer(texts, truncation=True, max_length=MAX_LENGTH)
    enc["labels"] = batch["jim_crow"]
    return enc


remove_cols = raw.column_names
train_tok = train_ds.map(preprocess, batched=True, remove_columns=remove_cols)
eval_tok = eval_ds.map(preprocess, batched=True, remove_columns=remove_cols)

# Inverse-frequency class weights: total / (num_classes * count).  The previous
# version hardcoded the class count as 2; using len(label_names) keeps the
# formula correct if the label set ever changes.  max(..., 1) guards against a
# zero count for a class absent from the (stratified) training split.
total = sum(train_counts.values())
num_classes = len(label_names)
class_weights = torch.tensor(
    [total / (num_classes * max(train_counts[i], 1)) for i in range(num_classes)],
    dtype=torch.float,
)
print("Class weights:", class_weights.tolist())

model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    num_labels=num_classes,
    id2label=id2label,
    label2id=label2id,
)


class WeightedTrainer(Trainer):
    """Trainer that applies class weights in the cross-entropy loss to counter imbalance."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs newer Trainer arguments such as num_items_in_batch.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # Move weights to the logits' device lazily so this works on CPU or GPU.
        weights = class_weights.to(outputs.logits.device)
        loss_fct = torch.nn.CrossEntropyLoss(weight=weights)
        loss = loss_fct(outputs.logits.view(-1, model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


class TrackioCallback(TrainerCallback):
    """Mirror Trainer logs and eval metrics into the trackio run."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            trackio.log(
                {f"trainer/{k}": v for k, v in logs.items() if isinstance(v, (int, float))},
                step=state.global_step,
            )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics:
            trackio.log(
                {f"eval/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))},
                step=state.global_step,
            )


def compute_metrics(eval_pred):
    """Binary + macro precision/recall/F1, accuracy, and confusion-matrix cells."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="binary", pos_label=1, zero_division=0
    )
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )
    acc = accuracy_score(labels, preds)
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "macro_precision": macro_precision,
        "macro_recall": macro_recall,
        "macro_f1": macro_f1,
        "tn": int(cm[0, 0]),
        "fp": int(cm[0, 1]),
        "fn": int(cm[1, 0]),
        "tp": int(cm[1, 1]),
    }


args = TrainingArguments(
    output_dir="jim-crow-test2323",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    num_train_epochs=5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    save_total_limit=2,
    fp16=torch.cuda.is_available(),
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_private_repo=False,
    report_to=[],  # trackio handles logging via TrackioCallback
    run_name=RUN_NAME,
    seed=SEED,
)

trainer = WeightedTrainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    processing_class=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer),
    compute_metrics=compute_metrics,
    callbacks=[TrackioCallback()],
)

trainer.train()
metrics = trainer.evaluate()
print("Final eval metrics:", metrics)
trackio.log({f"final/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))})

# Ensure useful metadata and a model card are present on the final Hub repo.
trainer.save_model()
tokenizer.save_pretrained(args.output_dir)
trainer.create_model_card(
    model_name="Jim Crow law classifier",
    dataset_tags=DATASET_ID,
    finetuned_from=MODEL_ID,
    tasks="text-classification",
    language="en",
    tags=["legal", "history", "jim-crow", "sequence-classification", "distilbert"],
)
trainer.push_to_hub(commit_message="Fine-tune DistilBERT to identify Jim Crow laws")

# Archive this script alongside the model for reproducibility.
api = HfApi(token=os.environ["HF_TOKEN"])
api.upload_file(
    path_or_fileobj=__file__,
    path_in_repo="training_script.py",
    repo_id=HUB_MODEL_ID,
    repo_type="model",
    commit_message="Add training script",
)
print(f"Pushed trained model to https://huggingface.co/{HUB_MODEL_ID}")

# Close the tracking run so it is marked finished (previously never called).
trackio.finish()