Wonder-Griffin committed · Commit 9615cd6 · verified · 1 Parent(s): 82791c6

Create DLME

Files changed (1)
  1. DLME +358 -0
DLME ADDED
@@ -0,0 +1,358 @@
"""
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, CTRL, BERT, RoBERTa, XLNet).
GPT, GPT-2 and CTRL are fine-tuned using a causal language modeling (CLM) loss. BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss. XLNet is fine-tuned using a permutation language modeling (PLM) loss.
"""

import logging
import math
import os
from dataclasses import dataclass, field
from glob import glob
from typing import Optional

from torch.utils.data import ConcatDataset

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_WITH_LM_HEAD_MAPPING,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorForPermutationLanguageModeling,
    DataCollatorForWholeWordMask,
    HfArgumentParser,
    LineByLineTextDataset,
    LineByLineWithRefDataset,
    PreTrainedTokenizer,
    TextDataset,
    Trainer,
    TrainingArguments,
    set_seed,
)
from transformers.trainer_utils import is_main_process


logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Leave None if you want to train a model from"
                " scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    train_data_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
    train_data_files: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The input training data files (multiple files in glob format). "
                "Splitting large files into smaller ones can often prevent the tokenizer from running out of memory."
            )
        },
    )
    eval_data_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    train_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input train ref data file for whole word mask in Chinese."},
    )
    eval_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input eval ref data file for whole word mask in Chinese."},
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
    )

    mlm: bool = field(
        default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
    )
    whole_word_mask: bool = field(default=False, metadata={"help": "Whether or not to use whole word masking."})
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
    plm_probability: float = field(
        default=1 / 6,
        metadata={
            "help": (
                "Ratio of length of a span of masked tokens to surrounding context length for permutation language"
                " modeling."
            )
        },
    )
    max_span_length: int = field(
        default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."}
    )

    block_size: int = field(
        default=-1,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in blocks of this size for training. "
                "Defaults to the model max input length for single sentence inputs (taking into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )


def get_dataset(
    args: DataTrainingArguments,
    tokenizer: PreTrainedTokenizer,
    evaluate: bool = False,
    cache_dir: Optional[str] = None,
):
    def _dataset(file_path, ref_path=None):
        if args.line_by_line:
            if ref_path is not None:
                if not args.whole_word_mask or not args.mlm:
                    raise ValueError("You need to set whole word masking and mlm to True for Chinese Whole Word Mask")
                return LineByLineWithRefDataset(
                    tokenizer=tokenizer,
                    file_path=file_path,
                    block_size=args.block_size,
                    ref_path=ref_path,
                )

            return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
        else:
            return TextDataset(
                tokenizer=tokenizer,
                file_path=file_path,
                block_size=args.block_size,
                overwrite_cache=args.overwrite_cache,
                cache_dir=cache_dir,
            )

    if evaluate:
        return _dataset(args.eval_data_file, args.eval_ref_file)
    elif args.train_data_files:
        return ConcatDataset([_dataset(f) for f in glob(args.train_data_files)])
    else:
        return _dataset(args.train_data_file, args.train_ref_file)


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if data_args.eval_data_file is None and training_args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use"
            " --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another"
            " script, save it, and load it from here, using --tokenizer_name"
        )

    if model_args.model_name_or_path:
        model = AutoModelWithLMHead.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )
    else:
        logger.info("Training new model from scratch")
        model = AutoModelWithLMHead.from_config(config)

    model.resize_token_embeddings(len(tokenizer))

    if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
        raise ValueError(
            "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the "
            "--mlm flag (masked language modeling)."
        )

    if data_args.block_size <= 0:
        data_args.block_size = tokenizer.max_len
        # Our input block size will be the max possible for the model
    else:
        data_args.block_size = min(data_args.block_size, tokenizer.max_len)

    # Get datasets

    train_dataset = (
        get_dataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
    )
    eval_dataset = (
        get_dataset(data_args, tokenizer=tokenizer, evaluate=True, cache_dir=model_args.cache_dir)
        if training_args.do_eval
        else None
    )
    if config.model_type == "xlnet":
        data_collator = DataCollatorForPermutationLanguageModeling(
            tokenizer=tokenizer,
            plm_probability=data_args.plm_probability,
            max_span_length=data_args.max_span_length,
        )
    else:
        if data_args.mlm and data_args.whole_word_mask:
            data_collator = DataCollatorForWholeWordMask(
                tokenizer=tokenizer, mlm_probability=data_args.mlm_probability
            )
        else:
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
            )

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        prediction_loss_only=True,
    )

    # Training
    if training_args.do_train:
        model_path = (
            model_args.model_name_or_path
            if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path)
            else None
        )
        trainer.train(model_path=model_path)
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        eval_output = trainer.evaluate()

        perplexity = math.exp(eval_output["eval_loss"])
        result = {"perplexity": perplexity}

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt")
        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))

        results.update(result)

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()
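# A hedged usage sketch for TPU runs: _mp_fn is the entry point expected by the
# transformers examples' xla_spawn.py launcher. Paths and flags below are
# placeholders, and this assumes that launcher script is available locally:
#
#   python xla_spawn.py --num_cores 8 DLME --model_name_or_path gpt2 --train_data_file train.txt --do_train --output_dir ./lm-output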


if __name__ == "__main__":
    main()