igorithm committed on
Commit
4e026d9
·
1 Parent(s): 9bddf5f

Add train script

Browse files
Files changed (1) hide show
  1. train.py +109 -0
train.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from sklearn.metrics import f1_score, accuracy_score
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForSequenceClassification,
6
+ Trainer,
7
+ TrainingArguments,
8
+ )
9
+ from bert_paper_classifier_model import SciBertPaperClassifier
10
+
11
+
def encode_labels(example, mapping=None):
    """Attach an integer ``labels`` field derived from ``category``.

    Args:
        example: A single dataset record with a ``category`` string key.
        mapping: Optional category -> id dict. Defaults to the module-level
            ``label2id`` built in the ``__main__`` block, preserving the
            original ``Dataset.map(encode_labels)`` call pattern.

    Returns:
        The same record, mutated in place with an integer ``labels`` key.
    """
    # Fall back to the script-level vocabulary when no mapping is supplied,
    # so existing callers (Dataset.map with a single argument) are unchanged.
    table = label2id if mapping is None else mapping
    example["labels"] = table[example["category"]]
    return example
15
+
16
+
def preprocess_function(examples):
    """Tokenize a batch of papers into fixed-length model inputs.

    Builds one ``AUTHORS: ... TITLE: ... ABSTRACT: ...`` string per record
    and runs the batch through the module-level SciBERT ``tokenizer`` with
    truncation and padding to 256 tokens.
    """
    merged = []
    for author, title, abstract in zip(
        examples["authors"], examples["title"], examples["abstract"]
    ):
        # Authors may arrive as a list of names or a single string.
        author_text = " ".join(author) if isinstance(author, list) else author
        merged.append(f"AUTHORS: {author_text} TITLE: {title} ABSTRACT: {abstract}")
    return tokenizer(merged, truncation=True, padding="max_length", max_length=256)
25
+
26
+
def compute_metrics(pred):
    """Return accuracy and weighted F1 for a Trainer ``EvalPrediction``.

    ``pred.predictions`` holds per-class logits; the predicted class is the
    argmax over the last axis.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred, average="weighted"),
    }
35
+
36
+
if __name__ == "__main__":
    print("DOWNLOADING DATASET...")
    # Local JSON exports of arXiv metadata. NOTE(review): the "test" split is
    # loaded here but never used below — validation is carved out of "train".
    data_files = {"train": "arxiv_train.json", "test": "arxiv_test.json"}
    dataset = load_dataset("json", data_files=data_files)

    # Cap training data at a fixed-seed shuffle of 100k examples to bound
    # training time; raises if the train split has fewer than 100k rows.
    dataset["train"] = dataset["train"].shuffle(seed=42).select(range(100000))
    print(f"DATA IS READY. TRAIN: {len(dataset['train'])}")

    print("LABELING...")
    # Build the category <-> id vocabulary from the sampled training split
    # only; sorted() makes the id assignment deterministic across runs.
    unique_labels = sorted(set(example["category"] for example in dataset["train"]))
    # label2id is read as a module-level global by encode_labels above.
    label2id = {label: idx for idx, label in enumerate(unique_labels)}
    id2label = {idx: label for label, idx in label2id.items()}

    dataset["train"] = dataset["train"].map(encode_labels)

    # 90/10 train/validation split with a fixed seed for reproducibility.
    split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset["train"]
    valid_dataset = split_dataset["test"]
    print(f"TRAIN SET: {len(train_dataset)}, VALIDATION SET: {len(valid_dataset)}")

    print("TOKENIZATION...")
    model_name = "allenai/scibert_scivocab_uncased"
    # tokenizer is read as a module-level global by preprocess_function above.
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    encoded_train = train_dataset.map(preprocess_function, batched=True, batch_size=32)
    encoded_valid = valid_dataset.map(preprocess_function, batched=True, batch_size=32)
    # Keep only the tensor columns the model consumes; paired with
    # remove_unused_columns=False below so Trainer does not drop them itself.
    encoded_train.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    encoded_valid.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    print("TOKENIZATION COMPLETED")

    print("DOWNLOADING MODEL...")
    # Fresh classification head sized to the label vocabulary; the id/label
    # maps are baked into the model config for later inference.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(unique_labels),
        id2label=id2label,
        label2id=label2id,
    )

    training_args = TrainingArguments(
        output_dir="./dataset_output",
        report_to="none",
        eval_strategy="steps",
        eval_steps=100,
        logging_steps=200,
        disable_tqdm=True,
        learning_rate=3e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=2,
        save_steps=200,
        fp16=True,  # mixed precision — assumes a CUDA-capable GPU
        remove_unused_columns=False,
    )

    print("LEARNING...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=encoded_train,
        eval_dataset=encoded_valid,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    print("LEARNING COMPLETED")

    # Persist final weights and tokenizer together so the directory can be
    # reloaded with from_pretrained("trained_model").
    model.save_pretrained("trained_model")
    tokenizer.save_pretrained("trained_model")

    print("EVALUATION...")
    # Evaluates on the held-out validation split (eval_dataset above).
    final_metrics = trainer.evaluate()
    print("METRICS:")
    for key, value in final_metrics.items():
        print(f"{key}: {value:.4f}")