| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| from collections import Counter |
|
|
| import numpy as np |
| import torch |
| import trackio |
| from datasets import load_dataset |
| from huggingface_hub import HfApi |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix |
| from transformers import ( |
| AutoModelForSequenceClassification, |
| AutoTokenizer, |
| DataCollatorWithPadding, |
| Trainer, |
| TrainerCallback, |
| TrainingArguments, |
| set_seed, |
| ) |
|
|
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
DATASET_ID = "biglam/on_the_books"            # dataset repo id passed to load_dataset
MODEL_ID = "distilbert-base-uncased"          # base checkpoint for tokenizer and model
HUB_MODEL_ID = "evalstate/jim-crow-test2323"  # Hub repo the trained model is pushed to
PROJECT = "jim-crow-law-classifier"           # Trackio project name
RUN_NAME = "distilbert-on-the-books"          # Trackio / Trainer run name
MAX_LENGTH = 512                              # tokenizer truncation length
SEED = 42                                     # seed for set_seed, the split and the Trainer

# Seed the RNGs before any data or model work happens.
set_seed(SEED)

# Fail fast: without a token the push at the end of the script cannot succeed.
if os.environ.get("HF_TOKEN", "") == "":
    raise RuntimeError("HF_TOKEN is required so the trained model can be pushed to the Hub.")
|
|
# Static description of this experiment, stored with the Trackio run.
run_config = {
    "dataset": DATASET_ID,
    "base_model": MODEL_ID,
    "hub_model_id": HUB_MODEL_ID,
    "task": "binary sequence classification: Jim Crow law identification",
    "max_length": MAX_LENGTH,
    "seed": SEED,
}

# Open the tracking run; metrics logged later in the script attach to it.
run = trackio.init(
    project=PROJECT,
    name=RUN_NAME,
    config=run_config,
    private=False,
    auto_log_gpu=True,
)
print(f"Trackio run: {run}")
|
|
# Load the single "train" split and derive the label maps from the
# "jim_crow" feature's class names.
raw = load_dataset(DATASET_ID, split="train")
label_names = raw.features["jim_crow"].names
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}
print(raw)
print("Label distribution:", Counter(raw["jim_crow"]))

# Hold out 20% for evaluation, stratified on the label column.
splits = raw.train_test_split(test_size=0.2, seed=SEED, stratify_by_column="jim_crow")
train_ds, eval_ds = splits["train"], splits["test"]

# Record split sizes and the per-class counts of the training split.
train_label_counts = Counter(train_ds["jim_crow"])
trackio.log(
    {
        "data/train_examples": len(train_ds),
        "data/eval_examples": len(eval_ds),
        "data/train_jim_crow": train_label_counts[1],
        "data/train_no_jim_crow": train_label_counts[0],
    }
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
def make_text(example):
    """Build the single text string fed to the tokenizer for one example.

    The string is a one-line metadata header (source, type, chapter number,
    section number) followed by the chapter text and the section text, each
    under its own heading.

    Args:
        example: Mapping with the optional keys ``source``, ``type``,
            ``chapter_num``, ``section_num``, ``chapter_text`` and
            ``section_text``.

    Returns:
        The assembled string. Missing or ``None`` fields are rendered as
        empty text.
    """

    def field(key):
        # A key that is present but holds None must not leak the literal
        # text "None" into the model input. Test `is None` rather than
        # truthiness so a legitimate 0 (e.g. section number 0) is kept.
        value = example.get(key)
        return "" if value is None else value

    chapter = example.get("chapter_text") or ""
    section = example.get("section_text") or ""
    meta = (
        f"Source: {field('source')}; Type: {field('type')}; "
        f"Chapter: {field('chapter_num')}; Section: {field('section_num')}"
    )
    return meta + "\n\nChapter text:\n" + chapter + "\n\nSection text:\n" + section
|
|
def preprocess(batch):
    """Tokenize a batch of examples and attach their labels.

    Args:
        batch: Column-oriented batch (column name -> list of values); must
            contain ``section_text`` and ``jim_crow``.

    Returns:
        The tokenizer output for the per-row texts built by ``make_text``,
        with the ``jim_crow`` column stored under ``labels``.
    """
    columns = {name: batch[name] for name in batch.keys()}
    n_rows = len(columns["section_text"])

    # Turn the column-oriented batch back into one dict per row so that
    # make_text can format each example.
    texts = [
        make_text({name: values[row] for name, values in columns.items()})
        for row in range(n_rows)
    ]

    # No padding here; sequences are only truncated to MAX_LENGTH.
    enc = tokenizer(texts, truncation=True, max_length=MAX_LENGTH)
    enc["labels"] = batch["jim_crow"]
    return enc
|
|
# Tokenize both splits, dropping every original column so that only the
# tokenizer outputs and "labels" remain.
remove_cols = raw.column_names
train_tok = train_ds.map(preprocess, batched=True, remove_columns=remove_cols)
eval_tok = eval_ds.map(preprocess, batched=True, remove_columns=remove_cols)

# Balanced class weights: n_samples / (n_classes * n_samples_in_class).
# The class count is taken from the label set rather than hard-coded to 2,
# so the formula stays correct if the label set changes.
counts = Counter(train_ds["jim_crow"])
total = sum(counts.values())
num_classes = len(label_names)
empty_classes = [label_names[i] for i in range(num_classes) if counts[i] == 0]
if empty_classes:
    # Fail with a clear message instead of a bare ZeroDivisionError.
    raise ValueError(f"No training examples for class(es): {empty_classes}")
class_weights = torch.tensor(
    [total / (num_classes * counts[i]) for i in range(num_classes)],
    dtype=torch.float,
)
print("Class weights:", class_weights.tolist())

# Classification head is sized to the label set; the label maps are stored
# in the model config.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    num_labels=num_classes,
    id2label=id2label,
    label2id=label2id,
)
|
|
class WeightedTrainer(Trainer):
    """Trainer that uses class-weighted cross-entropy as its loss.

    The per-class weights come from the module-level ``class_weights``
    tensor.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute weighted cross-entropy for one batch.

        Args:
            model: The model to run; may be a wrapper around the real model.
            inputs: Batch dict; ``labels`` is removed from it before the
                forward pass.
            return_outputs: When true, also return the model outputs.
            **kwargs: Extra arguments supplied by the Trainer; unused.

        Returns:
            ``loss`` or ``(loss, outputs)`` depending on ``return_outputs``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        weights = class_weights.to(logits.device)
        loss_fct = torch.nn.CrossEntropyLoss(weight=weights)
        # Read the class count from the logits instead of `model.config`:
        # a wrapped model (e.g. nn.DataParallel) does not expose `.config`.
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
|
|
class TrackioCallback(TrainerCallback):
    """Mirror the Trainer's numeric logs and evaluation metrics to Trackio."""

    @staticmethod
    def _forward(prefix, values, step):
        # Keep only int/float entries and namespace them under `prefix`.
        payload = {
            f"{prefix}/{key}": val
            for key, val in values.items()
            if isinstance(val, (int, float))
        }
        trackio.log(payload, step=step)

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Send the Trainer's log entries under the ``trainer/`` prefix."""
        if not logs:
            return
        self._forward("trainer", logs, state.global_step)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Send evaluation metrics under the ``eval/`` prefix."""
        if not metrics:
            return
        self._forward("eval", metrics, state.global_step)
|
|
def compute_metrics(eval_pred):
    """Compute classification metrics from an evaluation prediction.

    Args:
        eval_pred: Pair ``(logits, labels)``.

    Returns:
        Dict with accuracy; precision/recall/F1 for the positive class
        (label 1); their macro-averaged counterparts; and the four cells of
        the confusion matrix over labels ``[0, 1]`` as plain ints.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)

    # Index 3 of each result (support) is not used.
    positive = precision_recall_fscore_support(
        labels, preds, average="binary", pos_label=1, zero_division=0
    )
    macro = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )

    # Row-major flattening of the 2x2 matrix gives tn, fp, fn, tp.
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()

    return {
        "accuracy": accuracy_score(labels, preds),
        "precision": positive[0],
        "recall": positive[1],
        "f1": positive[2],
        "macro_precision": macro[0],
        "macro_recall": macro[1],
        "macro_f1": macro[2],
        "tn": int(tn),
        "fp": int(fp),
        "fn": int(fn),
        "tp": int(tp),
    }
|
|
args = TrainingArguments(
    # Output location and Hub publishing.
    output_dir="jim-crow-test2323",
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_private_repo=False,
    # Batch sizes and precision.
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    fp16=torch.cuda.is_available(),
    # Optimisation schedule.
    num_train_epochs=5,
    learning_rate=2e-5,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    weight_decay=0.01,
    # Evaluate and checkpoint once per epoch; reload the best checkpoint by F1.
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    # Logging: no built-in reporters; TrackioCallback does the reporting.
    logging_steps=10,
    report_to=[],
    run_name=RUN_NAME,
    seed=SEED,
)
|
|
# Pad each batch at collation time; preprocessing only truncates.
collator = DataCollatorWithPadding(tokenizer)

trainer = WeightedTrainer(
    model=model,
    args=args,
    processing_class=tokenizer,
    data_collator=collator,
    train_dataset=train_tok,
    eval_dataset=eval_tok,
    compute_metrics=compute_metrics,
    callbacks=[TrackioCallback()],
)
|
|
# Train, then run one last evaluation pass and log its numeric metrics
# under the "final/" prefix.
trainer.train()
metrics = trainer.evaluate()
print("Final eval metrics:", metrics)
trackio.log({f"final/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))})
# Close the tracking run explicitly so pending metrics are sent even if the
# Hub upload below fails.
trackio.finish()

# Save the model and tokenizer locally before publishing.
trainer.save_model()
tokenizer.save_pretrained(args.output_dir)

# The model-card metadata is passed through push_to_hub, which forwards its
# extra keyword arguments to create_model_card. Calling create_model_card
# separately beforehand is ineffective: push_to_hub regenerates the card and
# would overwrite it without this metadata.
trainer.push_to_hub(
    commit_message="Fine-tune DistilBERT to identify Jim Crow laws",
    model_name="Jim Crow law classifier",
    dataset_tags=DATASET_ID,
    finetuned_from=MODEL_ID,
    tasks="text-classification",
    language="en",
    tags=["legal", "history", "jim-crow", "sequence-classification", "distilbert"],
)

# Upload this script next to the model.
api = HfApi(token=os.environ["HF_TOKEN"])
api.upload_file(
    path_or_fileobj=__file__,
    path_in_repo="training_script.py",
    repo_id=HUB_MODEL_ID,
    repo_type="model",
    commit_message="Add training script",
)
print(f"Pushed trained model to https://huggingface.co/{HUB_MODEL_ID}")
|
|