DelaliScratchwerk committed on
Commit
9504c80
·
verified ·
1 Parent(s): 86588d1

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +152 -0
train.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List
5
+
6
+ import numpy as np
7
+ from datasets import load_dataset
8
+ import evaluate
9
+
10
+ from transformers import (
11
+ AutoTokenizer,
12
+ AutoModelForSequenceClassification,
13
+ DataCollatorWithPadding,
14
+ TrainingArguments,
15
+ Trainer,
16
+ )
17
+
18
+ # ======================
19
+ # LABEL SCHEMA
20
+ # ======================
21
+
22
# Ordered time-period classes; a label's position in this list is its class id.
LABELS: List[str] = [
    "pre-1900",
    "1900-1945",
    "1946-1968",
    "1969-1979",
    "1980s",
    "1990s",
    "2000-2008",
    "2009-2015",
    "2016-2018",
    "2019-2021",
    "2022-present",
]

# Lookup tables between class ids and label strings, both derived from LABELS.
id2label: Dict[int, str] = dict(enumerate(LABELS))
label2id: Dict[str, int] = {name: idx for idx, name in id2label.items()}
38
+
39
+ # Base model to fine-tune
40
+ BASE_MODEL = os.environ.get("BASE_MODEL", "distilroberta-base")
41
+
42
+ # Hugging Face hub repo where the fine-tuned model will be pushed
43
+ HUB_MODEL_ID = "DelaliScratchwerk/time-period-classifier-bert"
44
+
45
+ # ======================
46
+ # LOAD DATA
47
+ # ======================
48
+
49
+ # Expects CSVs at data/train.csv and data/val.csv, each with 'text' and 'label' columns
50
+ dataset = load_dataset(
51
+ "csv",
52
+ data_files={
53
+ "train": "data/train.csv",
54
+ "validation": "data/val.csv",
55
+ },
56
+ )
57
+
58
+ print("Raw dataset:", dataset)
59
+
60
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
61
+
62
+
63
def encode_batch(batch):
    """Tokenize a batch of texts and attach integer class ids.

    Args:
        batch: Mapping with a "text" entry, which is passed to the tokenizer,
            and a "label" entry holding one label string per example.

    Returns:
        The tokenizer output for ``batch["text"]`` with an added "labels"
        entry containing one integer id per example.

    Raises:
        KeyError: If a label, after stripping surrounding whitespace, is not
            a key of ``label2id``. The message names the offending value and
            the accepted labels.
    """
    # Truncate over-long texts; this call does not request any padding.
    enc = tokenizer(batch["text"], truncation=True)

    # Map string labels -> integer ids. Stripping tolerates stray leading or
    # trailing spaces in the CSV.
    label_ids = []
    for raw_label in batch["label"]:
        name = raw_label.strip()
        if name not in label2id:
            # Same exception type as a plain dict lookup, but with enough
            # context to locate and fix the bad row.
            raise KeyError(
                f"Unknown label {raw_label!r}; expected one of {list(label2id)}"
            )
        label_ids.append(label2id[name])

    enc["labels"] = label_ids
    return enc
70
+
71
+
72
+ # IMPORTANT: drop every original column (including 'text' and 'label') so the Trainer only sees the tokenizer outputs and 'labels'
73
+ encoded = dataset.map(
74
+ encode_batch,
75
+ batched=True,
76
+ remove_columns=dataset["train"].column_names,
77
+ )
78
+
79
+ print(encoded)
80
+ print("Encoded train sample keys:", encoded["train"][0].keys())
81
+ # should be: dict_keys(['input_ids', 'attention_mask', 'labels'])
82
+
83
+ # ======================
84
+ # MODEL
85
+ # ======================
86
+
87
+ model = AutoModelForSequenceClassification.from_pretrained(
88
+ BASE_MODEL,
89
+ num_labels=len(LABELS),
90
+ id2label=id2label,
91
+ label2id=label2id,
92
+ )
93
+
94
+ # ======================
95
+ # METRICS
96
+ # ======================
97
+
98
+ accuracy = evaluate.load("accuracy")
99
+ f1_macro = evaluate.load("f1")
100
+
101
+
102
def compute_metrics(eval_pred):
    """Compute accuracy and macro-averaged F1 for one evaluation pass.

    Args:
        eval_pred: A pair ``(logits, labels)``. The predicted class for each
            example is the argmax of the logits over the last axis.

    Returns:
        Dict with the keys "accuracy" and "f1_macro".
    """
    scores, references = eval_pred
    predicted = np.argmax(scores, axis=-1)

    acc_result = accuracy.compute(predictions=predicted, references=references)
    f1_result = f1_macro.compute(
        predictions=predicted, references=references, average="macro"
    )
    return {
        "accuracy": acc_result["accuracy"],
        "f1_macro": f1_result["f1"],
    }
111
+
112
+
113
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
114
+
115
+ # ======================
116
+ # TRAINING ARGS
117
+ # ======================
118
+
119
+ training_args = TrainingArguments(
120
+ output_dir="out",
121
+ per_device_train_batch_size=16,
122
+ per_device_eval_batch_size=32,
123
+ learning_rate=2e-5,
124
+ num_train_epochs=4,
125
+ eval_strategy="epoch",
126
+ save_strategy="no",
127
+ load_best_model_at_end=False,
128
+ logging_steps=50,
129
+ push_to_hub=True,
130
+ hub_model_id=HUB_MODEL_ID,
131
+ hub_private_repo=False,
132
+ )
133
+
134
+ # ======================
135
+ # TRAINER
136
+ # ======================
137
+
138
+ trainer = Trainer(
139
+ model=model,
140
+ args=training_args,
141
+ train_dataset=encoded["train"],
142
+ eval_dataset=encoded["validation"],
143
+ tokenizer=tokenizer,
144
+ data_collator=data_collator,
145
+ compute_metrics=compute_metrics,
146
+ )
147
+
148
+ if __name__ == "__main__":
149
+ trainer.train()
150
+ # push the final model (no best-checkpoint selection is configured) + tokenizer to the Hub
151
+ trainer.push_to_hub()
152
+ tokenizer.push_to_hub(HUB_MODEL_ID)