cnut1648 commited on
Commit
0b5de44
·
1 Parent(s): 1bd61ae

Create mnli.py

Browse files
Files changed (1) hide show
  1. mnli.py +647 -0
mnli.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ Finetuning the library models for sequence classification on GLUE."""
17
+ # You can also adapt this script on your own text classification task. Pointers for this are left as comments.
18
+ import logging
19
+ import os
20
+ import random
21
+ import sys
22
+ from dataclasses import dataclass, field
23
+ from typing import Optional
24
+
25
+ import datasets
26
+ import numpy as np
27
+ from datasets import load_dataset, concatenate_datasets
28
+
29
+ import evaluate
30
+ import transformers
31
+ from transformers import (
32
+ AutoConfig,
33
+ AutoModelForSequenceClassification,
34
+ AutoTokenizer,
35
+ DataCollatorWithPadding,
36
+ EvalPrediction,
37
+ HfArgumentParser,
38
+ PretrainedConfig,
39
+ Trainer,
40
+ TrainingArguments,
41
+ default_data_collator,
42
+ set_seed,
43
+ )
44
+ from transformers.trainer_utils import get_last_checkpoint
45
+ from transformers.utils import check_min_version, send_example_telemetry
46
+ from transformers.utils.versions import require_version
47
+
48
+
49
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
50
+ check_min_version("4.22.2")
51
+
52
+ require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")
53
+
54
# Mapping from GLUE task name to the dataset column name(s) holding the input
# text: (sentence1_key, sentence2_key). Single-sentence tasks use None for the
# second key (see `preprocess_function` in `main`).
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

# Module-level logger; verbosity is configured in `main` from the training args.
logger = logging.getLogger(__name__)
67
+
68
+
69
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.

    Exactly one data source must be given: a GLUE `task_name`, a Hub
    `dataset_name`, or local `train_file`/`validation_file` csv/json files.
    Validation happens in `__post_init__` and raises `ValueError` on bad input.
    """

    task_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
    )
    # Extension over stock run_glue.py: optionally mix SNLI into train/val/test.
    add_snli: bool = field(
        default=False, metadata={
            "help": "if set, add snli in train / val / test"
        }
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the training data."}
    )
    validation_file: Optional[str] = field(
        default=None, metadata={"help": "A csv or a json file containing the validation data."}
    )
    test_file: Optional[str] = field(default=None, metadata={"help": "A csv or a json file containing the test data."})

    def __post_init__(self):
        """Normalize and validate the chosen data source.

        Raises:
            ValueError: if the task name is unknown, if no data source was
                given, or if the local files are not csv/json with matching
                extensions.
        """
        if self.task_name is not None:
            # GLUE task names are matched case-insensitively.
            self.task_name = self.task_name.lower()
            if self.task_name not in task_to_keys:
                raise ValueError("Unknown task, you should pick one in " + ",".join(task_to_keys.keys()))
        elif self.dataset_name is not None:
            pass
        elif self.train_file is None or self.validation_file is None:
            raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            # Bug fix: these checks used `assert`, which is silently stripped
            # under `python -O`; user-input validation must raise explicitly.
            train_extension = self.train_file.split(".")[-1]
            if train_extension not in ["csv", "json"]:
                raise ValueError("`train_file` should be a csv or a json file.")
            validation_extension = self.validation_file.split(".")[-1]
            if validation_extension != train_extension:
                raise ValueError("`validation_file` should have the same extension (csv or json) as `train_file`.")
165
+
166
+
167
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    All fields are forwarded to the corresponding `from_pretrained` calls in
    `main` (config, tokenizer and model loading).
    """

    # Only required field: the model checkpoint to fine-tune.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    # Useful when the checkpoint's classification head has a different number
    # of labels than this run (e.g. re-finetuning on a new task).
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )
207
+
208
+
209
def main():
    """Fine-tune a sequence-classification model on GLUE (optionally mixing in SNLI).

    Parses `ModelArguments`, `DataTrainingArguments` and `TrainingArguments`
    from the command line (or a single JSON file), loads the data and model,
    then runs training, evaluation and/or prediction as requested.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_glue", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary.
    # Bug fix: the two f-strings were concatenated without a separator,
    # producing e.g. "n_gpu: 1distributed training: ...".
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
    # label if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below)
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            "glue",
            data_args.task_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": data_args.train_file, "validation": data_args.validation_file}

        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if training_args.do_predict:
            if data_args.test_file is not None:
                train_extension = data_args.train_file.split(".")[-1]
                test_extension = data_args.test_file.split(".")[-1]
                # Bug fix: was an `assert`, which is stripped under `python -O`.
                if test_extension != train_extension:
                    raise ValueError("`test_file` should have the same extension (csv or json) as `train_file`.")
                data_files["test"] = data_args.test_file
            else:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")

        for key in data_files.keys():
            logger.info(f"load a local file for {key}: {data_files[key]}")

        if data_args.train_file.endswith(".csv"):
            # Loading a dataset from local csv files
            raw_datasets = load_dataset(
                "csv",
                data_files=data_files,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
        else:
            # Loading a dataset from local json files
            raw_datasets = load_dataset(
                "json",
                data_files=data_files,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        # stsb is the only regression task in GLUE.
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )

    # Preprocessing the raw_datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if sorted(label_name_to_id.keys()) == sorted(label_list):
            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
        else:
            # Bug fix: the message parts were passed as separate positional
            # arguments to logger.warning, so the second part was treated as a
            # %-format argument and never rendered.
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None and not is_regression:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    if label_to_id is not None:
        model.config.label2id = label_to_id
        model.config.id2label = {id: label for label, id in config.label2id.items()}
    elif data_args.task_name is not None and not is_regression:
        model.config.label2id = {l: i for i, l in enumerate(label_list)}
        model.config.id2label = {id: label for label, id in config.label2id.items()}

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def preprocess_function(examples):
        """Tokenize a batch and remap string labels to ids when needed."""
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)

        # Map labels to IDs (not necessary for GLUE tasks)
        if label_to_id is not None and "label" in examples:
            result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
        return result

    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    if data_args.add_snli:
        # Optionally mix SNLI in. Unlabeled examples (label == -1) are dropped
        # before tokenizing. NOTE(review): this assumes SNLI's label ids are
        # compatible with the main task's (e.g. MNLI) — confirm before reuse.
        snli = load_dataset("snli")
        snli = snli.filter(lambda x: x["label"] != -1)
        snli = snli.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on snli",
        )
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        if data_args.add_snli:
            train_dataset = concatenate_datasets([raw_datasets["train"], snli["train"]])
        else:
            train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in raw_datasets and "validation_matched" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

    if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None:
        if "test" not in raw_datasets and "test_matched" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test_matched" if data_args.task_name == "mnli" else "test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))

    # Log a few random samples from the training set:
    if training_args.do_train:
        for index in random.sample(range(len(train_dataset)), 3):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

    # Get the metric function
    if data_args.task_name is not None:
        metric = evaluate.load("glue", data_args.task_name)
    else:
        metric = evaluate.load("accuracy")

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        if data_args.task_name is not None:
            result = metric.compute(predictions=preds, references=p.label_ids)
            if len(result) > 1:
                result["combined_score"] = np.mean(list(result.values())).item()
            return result
        elif is_regression:
            return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
        else:
            return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

    # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
    # we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.save_model()  # Saves the tokenizer too for easy upload

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        eval_datasets = [eval_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            valid_mm_dataset = raw_datasets["validation_mismatched"]
            if data_args.max_eval_samples is not None:
                max_eval_samples = min(len(valid_mm_dataset), data_args.max_eval_samples)
                valid_mm_dataset = valid_mm_dataset.select(range(max_eval_samples))
            eval_datasets.append(valid_mm_dataset)
        combined = {}

        if data_args.add_snli:
            eval_datasets.append(snli["validation"])
            tasks.append("snli-val")
            eval_datasets.append(snli["test"])
            tasks.append("snli-test")

        for eval_dataset, task in zip(eval_datasets, tasks):
            metrics = trainer.evaluate(eval_dataset=eval_dataset)

            max_eval_samples = (
                data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
            )
            metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

            # Prefix every metric with its task so results from the different
            # splits don't clobber each other in `combined`.
            metrics = {f"{task}_{k}": v for k, v in metrics.items()}
            combined.update(metrics)

            trainer.log_metrics("eval", metrics)
            trainer.save_metrics("eval", combined if task is not None and "mnli" in task else metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        tasks = [data_args.task_name]
        predict_datasets = [predict_dataset]
        if data_args.task_name == "mnli":
            tasks.append("mnli-mm")
            predict_datasets.append(raw_datasets["test_mismatched"])
        if data_args.add_snli:
            # Bug fix: this used to append to `eval_datasets`, so SNLI was
            # silently skipped at predict time — and raised NameError when
            # --do_eval was not set.
            predict_datasets.append(snli["test"])
            tasks.append("snli-test")

        for predict_dataset, task in zip(predict_datasets, tasks):
            # Removing the `label` columns because it contains -1 and Trainer won't like that.
            predict_dataset = predict_dataset.remove_columns("label")
            predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
            predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)

            output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.txt")
            if trainer.is_world_process_zero():
                with open(output_predict_file, "w") as writer:
                    logger.info(f"***** Predict results {task} *****")
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if is_regression:
                            writer.write(f"{index}\t{item:3.3f}\n")
                        else:
                            item = label_list[item]
                            writer.write(f"{index}\t{item}\n")

    # Metadata for the auto-generated model card / Hub upload.
    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-classification"}
    if data_args.task_name is not None:
        kwargs["language"] = "en"
        kwargs["dataset_tags"] = "glue"
        kwargs["dataset_args"] = data_args.task_name
        kwargs["dataset"] = f"GLUE {data_args.task_name.upper()}"

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)
639
+
640
+
641
def _mp_fn(index):
    """Per-process entry point for `xla_spawn` (TPUs); `index` is supplied by the spawner and unused here."""
    # For xla_spawn (TPUs)
    main()
644
+
645
+
646
+ if __name__ == "__main__":
647
+ main()