evalstate HF Staff commited on
Commit
0b83afd
·
verified ·
1 Parent(s): 0f6274e

Add training script

Browse files
Files changed (1) hide show
  1. training_script.py +214 -0
training_script.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "torch",
4
+ # "transformers>=4.51.0",
5
+ # "datasets>=3.0.0",
6
+ # "accelerate>=1.0.0",
7
+ # "scikit-learn>=1.4.0",
8
+ # "trackio>=0.25.0",
9
+ # "huggingface_hub>=0.30.0",
10
+ # ]
11
+ # ///
12
+
13
+ import os
14
+ from collections import Counter
15
+
16
+ import numpy as np
17
+ import torch
18
+ import trackio
19
+ from datasets import load_dataset
20
+ from huggingface_hub import HfApi
21
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
22
+ from transformers import (
23
+ AutoModelForSequenceClassification,
24
+ AutoTokenizer,
25
+ DataCollatorWithPadding,
26
+ Trainer,
27
+ TrainerCallback,
28
+ TrainingArguments,
29
+ set_seed,
30
+ )
31
+
32
+ DATASET_ID = "biglam/on_the_books"
33
+ MODEL_ID = "distilbert-base-uncased"
34
+ HUB_MODEL_ID = "evalstate/jim-crow-test2323"
35
+ PROJECT = "jim-crow-law-classifier"
36
+ RUN_NAME = "distilbert-on-the-books"
37
+ MAX_LENGTH = 512
38
+ SEED = 42
39
+
40
+ set_seed(SEED)
41
+
42
+ if not os.environ.get("HF_TOKEN"):
43
+ raise RuntimeError("HF_TOKEN is required so the trained model can be pushed to the Hub.")
44
+
45
+ run = trackio.init(
46
+ project=PROJECT,
47
+ name=RUN_NAME,
48
+ config={
49
+ "dataset": DATASET_ID,
50
+ "base_model": MODEL_ID,
51
+ "hub_model_id": HUB_MODEL_ID,
52
+ "task": "binary sequence classification: Jim Crow law identification",
53
+ "max_length": MAX_LENGTH,
54
+ "seed": SEED,
55
+ },
56
+ private=False,
57
+ auto_log_gpu=True,
58
+ )
59
+ print(f"Trackio run: {run}")
60
+
61
+ raw = load_dataset(DATASET_ID, split="train")
62
+ label_names = raw.features["jim_crow"].names
63
+ id2label = {i: name for i, name in enumerate(label_names)}
64
+ label2id = {name: i for i, name in id2label.items()}
65
+ print(raw)
66
+ print("Label distribution:", Counter(raw["jim_crow"]))
67
+
68
+ # Stratified split because the dataset has only one split and a modest class imbalance.
69
+ splits = raw.train_test_split(test_size=0.2, seed=SEED, stratify_by_column="jim_crow")
70
+ train_ds = splits["train"]
71
+ eval_ds = splits["test"]
72
+
73
+ trackio.log({
74
+ "data/train_examples": len(train_ds),
75
+ "data/eval_examples": len(eval_ds),
76
+ "data/train_jim_crow": Counter(train_ds["jim_crow"])[1],
77
+ "data/train_no_jim_crow": Counter(train_ds["jim_crow"])[0],
78
+ })
79
+
80
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
81
+
82
+ def make_text(example):
83
+ chapter = example.get("chapter_text") or ""
84
+ section = example.get("section_text") or ""
85
+ meta = f"Source: {example.get('source','')}; Type: {example.get('type','')}; Chapter: {example.get('chapter_num','')}; Section: {example.get('section_num','')}"
86
+ return meta + "\n\nChapter text:\n" + chapter + "\n\nSection text:\n" + section
87
+
88
+ def preprocess(batch):
89
+ texts = []
90
+ for i in range(len(batch["section_text"])):
91
+ ex = {k: batch[k][i] for k in batch.keys()}
92
+ texts.append(make_text(ex))
93
+ enc = tokenizer(texts, truncation=True, max_length=MAX_LENGTH)
94
+ enc["labels"] = batch["jim_crow"]
95
+ return enc
96
+
97
+ remove_cols = raw.column_names
98
+ train_tok = train_ds.map(preprocess, batched=True, remove_columns=remove_cols)
99
+ eval_tok = eval_ds.map(preprocess, batched=True, remove_columns=remove_cols)
100
+
101
+ counts = Counter(train_ds["jim_crow"])
102
+ total = sum(counts.values())
103
+ class_weights = torch.tensor([total / (2 * counts[i]) for i in range(len(label_names))], dtype=torch.float)
104
+ print("Class weights:", class_weights.tolist())
105
+
106
+ model = AutoModelForSequenceClassification.from_pretrained(
107
+ MODEL_ID,
108
+ num_labels=len(label_names),
109
+ id2label=id2label,
110
+ label2id=label2id,
111
+ )
112
+
113
+ class WeightedTrainer(Trainer):
114
+ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
115
+ labels = inputs.pop("labels")
116
+ outputs = model(**inputs)
117
+ weights = class_weights.to(outputs.logits.device)
118
+ loss_fct = torch.nn.CrossEntropyLoss(weight=weights)
119
+ loss = loss_fct(outputs.logits.view(-1, model.config.num_labels), labels.view(-1))
120
+ return (loss, outputs) if return_outputs else loss
121
+
122
+ class TrackioCallback(TrainerCallback):
123
+ def on_log(self, args, state, control, logs=None, **kwargs):
124
+ if logs:
125
+ trackio.log({f"trainer/{k}": v for k, v in logs.items() if isinstance(v, (int, float))}, step=state.global_step)
126
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
127
+ if metrics:
128
+ trackio.log({f"eval/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))}, step=state.global_step)
129
+
130
+ def compute_metrics(eval_pred):
131
+ logits, labels = eval_pred
132
+ preds = np.argmax(logits, axis=-1)
133
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary", pos_label=1, zero_division=0)
134
+ macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(labels, preds, average="macro", zero_division=0)
135
+ acc = accuracy_score(labels, preds)
136
+ cm = confusion_matrix(labels, preds, labels=[0, 1])
137
+ return {
138
+ "accuracy": acc,
139
+ "precision": precision,
140
+ "recall": recall,
141
+ "f1": f1,
142
+ "macro_precision": macro_precision,
143
+ "macro_recall": macro_recall,
144
+ "macro_f1": macro_f1,
145
+ "tn": int(cm[0, 0]),
146
+ "fp": int(cm[0, 1]),
147
+ "fn": int(cm[1, 0]),
148
+ "tp": int(cm[1, 1]),
149
+ }
150
+
151
+ args = TrainingArguments(
152
+ output_dir="jim-crow-test2323",
153
+ learning_rate=2e-5,
154
+ per_device_train_batch_size=16,
155
+ per_device_eval_batch_size=32,
156
+ gradient_accumulation_steps=1,
157
+ num_train_epochs=5,
158
+ weight_decay=0.01,
159
+ warmup_ratio=0.1,
160
+ lr_scheduler_type="linear",
161
+ eval_strategy="epoch",
162
+ save_strategy="epoch",
163
+ logging_steps=10,
164
+ load_best_model_at_end=True,
165
+ metric_for_best_model="f1",
166
+ greater_is_better=True,
167
+ save_total_limit=2,
168
+ fp16=torch.cuda.is_available(),
169
+ push_to_hub=True,
170
+ hub_model_id=HUB_MODEL_ID,
171
+ hub_private_repo=False,
172
+ report_to=[],
173
+ run_name=RUN_NAME,
174
+ seed=SEED,
175
+ )
176
+
177
+ trainer = WeightedTrainer(
178
+ model=model,
179
+ args=args,
180
+ train_dataset=train_tok,
181
+ eval_dataset=eval_tok,
182
+ processing_class=tokenizer,
183
+ data_collator=DataCollatorWithPadding(tokenizer),
184
+ compute_metrics=compute_metrics,
185
+ callbacks=[TrackioCallback()],
186
+ )
187
+
188
+ trainer.train()
189
+ metrics = trainer.evaluate()
190
+ print("Final eval metrics:", metrics)
191
+ trackio.log({f"final/{k}": v for k, v in metrics.items() if isinstance(v, (int, float))})
192
+
193
+ # Ensure useful metadata and a model card are present on the final Hub repo.
194
+ trainer.save_model()
195
+ tokenizer.save_pretrained(args.output_dir)
196
+ trainer.create_model_card(
197
+ model_name="Jim Crow law classifier",
198
+ dataset_tags=DATASET_ID,
199
+ finetuned_from=MODEL_ID,
200
+ tasks="text-classification",
201
+ language="en",
202
+ tags=["legal", "history", "jim-crow", "sequence-classification", "distilbert"],
203
+ )
204
+ trainer.push_to_hub(commit_message="Fine-tune DistilBERT to identify Jim Crow laws")
205
+
206
+ api = HfApi(token=os.environ["HF_TOKEN"])
207
+ api.upload_file(
208
+ path_or_fileobj=__file__,
209
+ path_in_repo="training_script.py",
210
+ repo_id=HUB_MODEL_ID,
211
+ repo_type="model",
212
+ commit_message="Add training script",
213
+ )
214
+ print(f"Pushed trained model to https://huggingface.co/{HUB_MODEL_ID}")