Spaces:
No application file
No application file
File size: 4,012 Bytes
c31edbd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the "regular" configuration of the SWAG dataset.
swag = load_dataset("swag", "regular")
# Notebook-style peek at the first training example; the value is discarded
# when this runs as a plain script.
swag["train"][0]

# Tokenizer matching the "bert-base-uncased" checkpoint used for the models below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Column names that hold the four candidate endings of every example.
ending_names = [f"ending{index}" for index in range(4)]
def preprocess_function(examples):
    """Tokenize a batch of examples as (context, header + ending) pairs.

    Relies on the module-level ``tokenizer`` and ``ending_names``.

    Args:
        examples: A batch in column form: a mapping whose ``"sent1"`` and
            ``"sent2"`` entries, and each entry named in ``ending_names``,
            are equally long sequences of strings.

    Returns:
        A dict with one entry per key produced by the tokenizer. Each value
        is a list with one item per input example, and each item is the list
        of per-choice encodings for that example (``len(ending_names)`` of
        them).
    """
    num_choices = len(ending_names)

    # Build both sides of the sentence pairs already flattened, example by
    # example and choice by choice. This avoids the quadratic
    # ``sum(list_of_lists, [])`` flattening idiom.
    first_sentences = [
        context for context in examples["sent1"] for _ in range(num_choices)
    ]
    second_sentences = [
        f"{header} {examples[end][i]}"
        for i, header in enumerate(examples["sent2"])
        for end in ending_names
    ]

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)

    # Regroup the flat encodings so each example owns ``num_choices`` entries.
    return {
        k: [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
        for k, v in tokenized_examples.items()
    }
# Run preprocess_function over `swag` in batches (batched=True), so each call
# receives whole columns rather than a single example.
tokenized_swag = swag.map(preprocess_function, batched=True)
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch
@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.

    Attributes:
        tokenizer: Object whose ``pad`` method is used to pad the flattened
            per-choice encodings.
        padding: Forwarded to ``tokenizer.pad`` as ``padding``.
        max_length: Forwarded to ``tokenizer.pad`` as ``max_length``.
        pad_to_multiple_of: Forwarded to ``tokenizer.pad`` as
            ``pad_to_multiple_of``.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate a list of multiple-choice features into one padded batch.

        Args:
            features: Non-empty list of dicts. Each dict holds a label under
                ``"label"`` or ``"labels"`` (the key is chosen from the first
                feature) and, for every other key, a sequence with one entry
                per choice. ``"input_ids"`` must be present; its length gives
                the number of choices.

        Returns:
            A dict in which every padded tensor is reshaped to
            ``(batch_size, num_choices, -1)``, plus ``"labels"`` as an int64
            tensor with one entry per feature.

        The input dicts are left unmodified, so the same feature list can be
        collated more than once.
        """
        label_name = "label" if "label" in features[0] else "labels"
        # Read the labels without popping them: the caller's dicts stay intact.
        labels = [feature[label_name] for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # One flat entry per (example, choice) pair, in example-major order,
        # built in a single linear pass (no quadratic ``sum(..., [])``).
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Restore the (example, choice) structure that was flattened for padding.
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

# ---- PyTorch fine-tuning with the Trainer API ----

# Multiple-choice model initialised from the "bert-base-uncased" checkpoint.
model = AutoModelForMultipleChoice.from_pretrained("bert-base-uncased")

# Training configuration: evaluate once per epoch, write outputs to ./results.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=5e-5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
)

# Batches are assembled by the multiple-choice collator defined in this file.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    train_dataset=tokenized_swag["train"],
    eval_dataset=tokenized_swag["validation"],
)

trainer.train()
# ---- TensorFlow / Keras fine-tuning ----
# `tf` is used below for the Keras loss, so it has to be imported here.
import tensorflow as tf
from transformers import TFAutoModelForMultipleChoice, create_optimizer

# Hyperparameters come first: the tf.data pipelines below read `batch_size`,
# which was previously referenced before it was assigned.
batch_size = 16
num_train_epochs = 2

# NOTE(review): the collator defined in this file asks the tokenizer for
# PyTorch tensors (return_tensors="pt") — confirm to_tf_dataset accepts that
# output, or switch to a collator that returns NumPy/TensorFlow tensors.
data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)

tf_train_set = tokenized_swag["train"].to_tf_dataset(
    columns=["attention_mask", "input_ids"],
    label_cols=["labels"],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator,
)
tf_validation_set = tokenized_swag["validation"].to_tf_dataset(
    columns=["attention_mask", "input_ids"],
    label_cols=["labels"],
    shuffle=False,
    batch_size=batch_size,
    collate_fn=data_collator,
)

# Total optimizer steps: full batches per epoch times the number of epochs.
total_train_steps = (len(tokenized_swag["train"]) // batch_size) * num_train_epochs
optimizer, schedule = create_optimizer(
    init_lr=5e-5,
    num_warmup_steps=0,
    num_train_steps=total_train_steps,
)

# Rebinds `model` to the TensorFlow variant of the multiple-choice model.
model = TFAutoModelForMultipleChoice.from_pretrained("bert-base-uncased")
model.compile(
    optimizer=optimizer,
    # from_logits=True: the loss is given raw scores rather than probabilities.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

# Use the same epoch count that sized the optimizer schedule above.
model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=num_train_epochs)
|