patrickvonplaten commited on
Commit
ec3cfc6
·
1 Parent(s): 2db17eb
Files changed (3) hide show
  1. config.json +4 -4
  2. run_mlm.py +492 -0
  3. run_mlm.sh +19 -0
config.json CHANGED
@@ -8,14 +8,14 @@
8
  "gradient_checkpointing": false,
9
  "hidden_act": "gelu",
10
  "hidden_dropout_prob": 0.1,
11
- "hidden_size": 768,
12
  "initializer_range": 0.02,
13
- "intermediate_size": 3072,
14
  "layer_norm_eps": 1e-05,
15
  "max_position_embeddings": 514,
16
  "model_type": "roberta",
17
- "num_attention_heads": 12,
18
- "num_hidden_layers": 12,
19
  "pad_token_id": 1,
20
  "position_embedding_type": "absolute",
21
  "transformers_version": "4.6.0.dev0",
 
8
  "gradient_checkpointing": false,
9
  "hidden_act": "gelu",
10
  "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 1024,
12
  "initializer_range": 0.02,
13
+ "intermediate_size": 4096,
14
  "layer_norm_eps": 1e-05,
15
  "max_position_embeddings": 514,
16
  "model_type": "roberta",
17
+ "num_attention_heads": 16,
18
+ "num_hidden_layers": 24,
19
  "pad_token_id": 1,
20
  "position_embedding_type": "absolute",
21
  "transformers_version": "4.6.0.dev0",
run_mlm.py ADDED
@@ -0,0 +1,492 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2020 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) on a text file or a dataset.
18
+
19
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20
+ https://huggingface.co/models?filter=masked-lm
21
+ """
22
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
23
+
24
+ import logging
25
+ import math
26
+ import os
27
+ import sys
28
+ from dataclasses import dataclass, field
29
+ from typing import Optional
30
+
31
+ from datasets import load_dataset
32
+
33
+ import transformers
34
+ from transformers import (
35
+ CONFIG_MAPPING,
36
+ MODEL_FOR_MASKED_LM_MAPPING,
37
+ AutoConfig,
38
+ AutoModelForMaskedLM,
39
+ AutoTokenizer,
40
+ DataCollatorForLanguageModeling,
41
+ HfArgumentParser,
42
+ Trainer,
43
+ TrainingArguments,
44
+ set_seed,
45
+ )
46
+ from transformers.trainer_utils import get_last_checkpoint, is_main_process
47
+ from transformers.utils import check_min_version
48
+
49
+
50
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
51
+ check_min_version("4.6.0.dev0")
52
+
53
+ logger = logging.getLogger(__name__)
54
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
55
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
56
+
57
+
58
@dataclass
class ModelArguments:
    """CLI options selecting the model, config and tokenizer to start from.

    Leaving ``model_name_or_path`` unset trains a brand-new model from scratch
    (in which case ``model_type`` picks the architecture).
    """

    # Checkpoint to warm-start from; None means "train from scratch".
    model_name_or_path: Optional[str] = field(
        metadata={
            "help": "The model checkpoint for weights initialization."
            "Don't set if you want to train a model from scratch."
        },
        default=None,
    )
    # Architecture to instantiate when no checkpoint is given.
    model_type: Optional[str] = field(
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
        default=None,
    )
    config_name: Optional[str] = field(
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
        default=None,
    )
    tokenizer_name: Optional[str] = field(
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
        default=None,
    )
    # Download cache directory for hub artifacts.
    cache_dir: Optional[str] = field(
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
        default=None,
    )
    use_fast_tokenizer: bool = field(
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
        default=True,
    )
    # Branch name, tag name or commit id on the hub.
    model_revision: str = field(
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
        default="main",
    )
    use_auth_token: bool = field(
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
        default=False,
    )
100
+
101
+
102
@dataclass
class DataTrainingArguments:
    """CLI options describing the training/evaluation data and its preprocessing.

    Either a hub ``dataset_name`` or local ``train_file``/``validation_file``
    must be supplied; ``__post_init__`` enforces this.
    """

    # Hub dataset identifier (via the `datasets` library).
    dataset_name: Optional[str] = field(
        metadata={"help": "The name of the dataset to use (via the datasets library)."},
        default=None,
    )
    dataset_config_name: Optional[str] = field(
        metadata={"help": "The configuration name of the dataset to use (via the datasets library)."},
        default=None,
    )
    # Local files, used only when no hub dataset is named.
    train_file: Optional[str] = field(
        metadata={"help": "The input training data file (a text file)."},
        default=None,
    )
    validation_file: Optional[str] = field(
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
        default=None,
    )
    overwrite_cache: bool = field(
        metadata={"help": "Overwrite the cached training and evaluation sets"},
        default=False,
    )
    # Used to carve a validation split out of `train` when the dataset has none.
    validation_split_percentage: Optional[int] = field(
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
        default=5,
    )
    max_seq_length: Optional[int] = field(
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated."
        },
        default=None,
    )
    preprocessing_num_workers: Optional[int] = field(
        metadata={"help": "The number of processes to use for the preprocessing."},
        default=None,
    )
    mlm_probability: float = field(
        metadata={"help": "Ratio of tokens to mask for masked language modeling loss"},
        default=0.15,
    )
    line_by_line: bool = field(
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
        default=False,
    )
    pad_to_max_length: bool = field(
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
        default=False,
    )
    # Debug helpers: cap the number of examples used.
    max_train_samples: Optional[int] = field(
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
        default=None,
    )
    max_val_samples: Optional[int] = field(
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
            "value if set."
        },
        default=None,
    )

    def __post_init__(self):
        """Check that a data source was given and that any local files have a supported extension."""
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
        if self.validation_file is not None:
            extension = self.validation_file.split(".")[-1]
            assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
178
+
179
+
180
def main():
    """Pretrain or fine-tune a masked language model on a text dataset.

    Arguments are parsed from the command line, or from a single JSON file when
    that file's path is the only CLI argument. Runs training and/or evaluation
    depending on ``--do_train`` / ``--do_eval``.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    # BUGFIX: the original concatenated the two f-strings with no separator, printing
    # e.g. "n_gpu: 1distributed training: ..."; a ", " is inserted between them.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
    # behavior (see below)
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
        if "validation" not in datasets.keys():
            # No predefined validation split: carve one out of the head of `train`.
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # BUGFIX: infer the loader type from whichever file was actually provided; the
        # original always read `train_file` and crashed with AttributeError when only
        # a validation file was passed.
        reference_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        extension = reference_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # BUGFIX: this debug truncation (originally marked "TODO: delete after") was
    # unconditional, so `select(range(None))` raised a TypeError whenever the sample
    # limits were left unset, and it assumed both splits exist. The early
    # (pre-tokenization) truncation is kept, but only when a limit was actually given.
    if data_args.max_train_samples is not None and "train" in datasets:
        datasets["train"] = datasets["train"].select(range(data_args.max_train_samples))
    if data_args.max_val_samples is not None and "validation" in datasets:
        datasets["validation"] = datasets["validation"].select(range(data_args.max_val_samples))

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = AutoModelForMaskedLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelForMaskedLM.from_config(config)

    # Make sure the embedding matrix covers the tokenizer's vocabulary.
    model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    if data_args.max_seq_length is None:
        max_seq_length = tokenizer.model_max_length
        if max_seq_length > 1024:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
            )
            max_seq_length = 1024
    else:
        if data_args.max_seq_length > tokenizer.model_max_length:
            logger.warning(
                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
            )
        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    if data_args.line_by_line:
        # When using line_by_line, we just tokenize each nonempty line.
        padding = "max_length" if data_args.pad_to_max_length else False

        def tokenize_function(examples):
            # BUGFIX: use `text_column_name` (matching `remove_columns` below) instead
            # of a hard-coded "text" key, which raised KeyError on datasets whose text
            # column has a different name.
            # Remove empty lines
            examples[text_column_name] = [
                line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()
            ]
            return tokenizer(
                examples[text_column_name],
                padding=padding,
                truncation=True,
                max_length=max_seq_length,
                # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
                # receives the `special_tokens_mask`.
                return_special_tokens_mask=True,
            )

        tokenized_datasets = datasets.map(
            tokenize_function,
            batch_size=1600,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=[text_column_name],
            load_from_cache_file=not data_args.overwrite_cache,
        )
    else:
        # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
        # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
        # efficient when it receives the `special_tokens_mask`.
        def tokenize_function(examples):
            return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

        tokenized_datasets = datasets.map(
            tokenize_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

        # Main data processing function that will concatenate all texts from our dataset and generate chunks of
        # max_seq_length.
        def group_texts(examples):
            # Concatenate all texts.
            concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
            # customize this part to your needs.
            total_length = (total_length // max_seq_length) * max_seq_length
            # Split by chunks of max_len.
            result = {
                k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
                for k, t in concatenated_examples.items()
            }
            return result

        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
        # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
        # might be slower to preprocess.
        #
        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

        tokenized_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = tokenized_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = tokenized_datasets["validation"]
        if data_args.max_val_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_val_samples))

    # Data collator
    # This one will take care of randomly masking the tokens.
    # Padding to a multiple of 8 only matters for fp16 tensor cores, and only when
    # padding is dynamic (line_by_line without pad_to_max_length).
    pad_to_multiple_of_8 = data_args.line_by_line and training_args.fp16 and not data_args.pad_to_max_length
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm_probability=data_args.mlm_probability,
        pad_to_multiple_of=8 if pad_to_multiple_of_8 else None,
    )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    if training_args.do_train:
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        elif model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path):
            checkpoint = model_args.model_name_or_path
        else:
            checkpoint = None
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload
        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_val_samples = data_args.max_val_samples if data_args.max_val_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
        # ROBUSTNESS: a large eval loss overflows exp(); report infinity instead of crashing.
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
484
+
485
+
486
def _mp_fn(index):
    """Per-process entry point for `xla_spawn` (TPUs); `index` is supplied by the spawner and unused here."""
    main()


if __name__ == "__main__":
    main()
run_mlm.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# Launch masked-language-model training with run_mlm.py on German OSCAR.
# --max_train_samples / --max_val_samples keep the run small for debugging.
#
# BUGFIX: use --per_device_train_batch_size; --per_gpu_train_batch_size is the
# deprecated spelling of the same TrainingArguments option.
python ./run_mlm.py \
    --output_dir="./runs" \
    --model_type="roberta" \
    --config_name="patrickvonplaten/german-roberta-base" \
    --tokenizer_name="patrickvonplaten/german-roberta-base" \
    --dataset_name="oscar" \
    --dataset_config_name="unshuffled_deduplicated_de" \
    --max_seq_length="128" \
    --per_device_train_batch_size="64" \
    --learning_rate="1e-4" \
    --warmup_steps="1000" \
    --logging_steps="10" \
    --max_train_samples=5000 \
    --max_val_samples=500 \
    --do_train \
    --do_eval \
    --fp16
    # --preprocessing_num_workers="16" \