Ire-O committed
Commit c31edbd · verified · 1 Parent(s): 9ba546b

Create test_mult_choice.py

Files changed (1): test_mult_choice.py +127 -0
test_mult_choice.py ADDED
from datasets import load_dataset

# Load the SWAG dataset (regular configuration).
swag = load_dataset("swag", "regular")

# Inspect one training example.
print(swag["train"][0])

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# The four candidate endings for each example live in these columns.
ending_names = ["ending0", "ending1", "ending2", "ending3"]
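# Each SWAG example pairs a context sentence (`sent1`) and the start of a
# follow-up (`sent2`) with the four candidate endings above; the `label`
# column holds the index of the correct ending.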

def preprocess_function(examples):
    # Repeat each context four times so it can be paired with each ending.
    first_sentences = [[context] * 4 for context in examples["sent1"]]
    question_headers = examples["sent2"]
    # Combine the question header with each of the four candidate endings.
    second_sentences = [
        [f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers)
    ]

    # Flatten so the tokenizer sees one long list of (context, continuation) pairs.
    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)
    # Un-flatten: regroup every four consecutive sequences back into one example.
    return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}


tokenized_swag = swag.map(preprocess_function, batched=True)
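# Quick sanity check (our addition, assuming the map above succeeded): every
# example should now carry four candidate sequences, one per ending.
sample = tokenized_swag["train"][0]
print(len(sample["input_ids"]))                  # 4
print(tokenizer.decode(sample["input_ids"][0]))  # "[CLS] context [SEP] header ending0 [SEP]"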

from dataclasses import dataclass
from typing import Optional, Union

import torch
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads the inputs for multiple choice.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        # Flatten (example, choice) pairs into one list so they can be padded together.
        flattened_features = [
            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
        ]
        flattened_features = sum(flattened_features, [])

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Un-flatten back to (batch_size, num_choices, seq_len).
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch
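# Collator smoke test (our addition): pad two examples and check the batch
# shape. The raw text columns are dropped first, mirroring the column pruning
# Trainer performs before batches reach the collator.
_keep = ("input_ids", "attention_mask", "token_type_ids", "label")
_features = [
    {k: v for k, v in tokenized_swag["train"][i].items() if k in _keep} for i in range(2)
]
_batch = DataCollatorForMultipleChoice(tokenizer=tokenizer)(_features)
print(_batch["input_ids"].shape)  # torch.Size([2, 4, padded_len])
print(_batch["labels"])           # tensor of the two gold ending indices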

# --- Fine-tune with the PyTorch Trainer ---
from transformers import AutoModelForMultipleChoice, Trainer, TrainingArguments

model = AutoModelForMultipleChoice.from_pretrained("bert-base-uncased")

training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_swag["train"],
    eval_dataset=tokenized_swag["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
)

trainer.train()
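# --- TF-compatible collator (not in the original file) ---
# `to_tf_dataset` below needs batches as NumPy/TF tensors, while the collator
# above returns PyTorch tensors. This is a minimal sketch of a TF variant
# following the same flatten/pad/un-flatten steps; the class name
# `DataCollatorForMultipleChoiceTF` is our own, not part of transformers.
import tensorflow as tf


class DataCollatorForMultipleChoiceTF(DataCollatorForMultipleChoice):
    def __call__(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        # Flatten (example, choice) pairs, pad them together, then un-flatten.
        flattened_features = sum(
            [[{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features], []
        )
        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="tf",
        )
        batch = {k: tf.reshape(v, (batch_size, num_choices, -1)) for k, v in batch.items()}
        batch["labels"] = tf.convert_to_tensor(labels, dtype=tf.int64)
        return batch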

# --- The same fine-tuning, in TensorFlow ---
# Batch size shared by the TF datasets below and the optimizer schedule that
# follows; it must be defined before the `to_tf_dataset` calls use it.
batch_size = 16

data_collator = DataCollatorForMultipleChoiceTF(tokenizer=tokenizer)
tf_train_set = tokenized_swag["train"].to_tf_dataset(
    columns=["attention_mask", "input_ids"],
    label_cols=["labels"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator,
)

tf_validation_set = tokenized_swag["validation"].to_tf_dataset(
    columns=["attention_mask", "input_ids"],
    label_cols=["labels"],
    shuffle=False,
    batch_size=batch_size,
    collate_fn=data_collator,
)

from transformers import create_optimizer

num_train_epochs = 2
total_train_steps = (len(tokenized_swag["train"]) // batch_size) * num_train_epochs
optimizer, schedule = create_optimizer(init_lr=5e-5, num_warmup_steps=0, num_train_steps=total_train_steps)

from transformers import TFAutoModelForMultipleChoice

model = TFAutoModelForMultipleChoice.from_pretrained("bert-base-uncased")

model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=num_train_epochs)
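
# Inference sketch (our addition; the prompt and endings are made up): score
# one context against four candidate endings and pick the highest logit.
prompt = "A person puts a pan of food into the oven."
candidates = [
    "They take out the trash.",
    "They set a timer and wait.",
    "They paint the fence blue.",
    "They swim across the lake.",
]
inputs = tokenizer([prompt] * 4, candidates, return_tensors="tf", padding=True)
# The model expects (batch_size, num_choices, seq_len), so add a batch axis.
outputs = model({k: tf.expand_dims(v, 0) for k, v in inputs.items()})
print(int(tf.argmax(outputs.logits, axis=-1)[0]))  # index of the predicted ending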