pere commited on
Commit
a7ab6f8
·
1 Parent(s): 3a8c330

more data

Browse files
.gitattributes CHANGED
@@ -31,3 +31,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
31
  *.zip filter=lfs diff=lfs merge=lfs -text
32
  *.zst filter=lfs diff=lfs merge=lfs -text
33
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
31
  *.zip filter=lfs diff=lfs merge=lfs -text
32
  *.zst filter=lfs diff=lfs merge=lfs -text
33
  *tfevents* filter=lfs diff=lfs merge=lfs -text
34
+ *.txt filter=lfs diff=lfs merge=lfs -text
35
+ *.csv filter=lfs diff=lfs merge=lfs -text
data/._data_text_default-d50be04fe2b594a9_0.0.0_21a506d1b2b34316b1e82d0bd79066905d846e5d7e619823c0dd338d6f1fa6ad.lock ADDED
File without changes
data/mnli_no_for_simcse.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3a3d7fb9905ec57c1bc62a578c88002dc9a20f3ebcbd2c675c1a549172aadf8
3
+ size 28826475
data/nli_for_simcse.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0747687ec3594fa449d2004fd3757a56c24bf5f7428976fb5b67176775a68d48
3
+ size 48978197
data/nor_news_1998_2019_sentences_1M.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a11b0487ea13419b7902c6890723469a3c540ec0143053e82312642509adfaae
3
+ size 91417493
data/text/default-d50be04fe2b594a9/0.0.0/21a506d1b2b34316b1e82d0bd79066905d846e5d7e619823c0dd338d6f1fa6ad.incomplete_info.lock ADDED
File without changes
data/text/default-d50be04fe2b594a9/0.0.0/21a506d1b2b34316b1e82d0bd79066905d846e5d7e619823c0dd338d6f1fa6ad/cache-9a4514157724681b.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:901eb82696fe0cb5119b6898332be1df2c3035af313e4cd8fff9cd855cd6b1bf
3
+ size 325905944
data/text/default-d50be04fe2b594a9/0.0.0/21a506d1b2b34316b1e82d0bd79066905d846e5d7e619823c0dd338d6f1fa6ad/dataset_info.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"description": "", "citation": "", "homepage": "", "license": "", "features": {"text": {"dtype": "string", "_type": "Value"}}, "builder_name": "text", "config_name": "default", "version": {"version_str": "0.0.0", "major": 0, "minor": 0, "patch": 0}, "splits": {"train": {"name": "train", "num_bytes": 123038621, "num_examples": 1000000, "dataset_name": "text"}}, "download_checksums": {"/home/perk/models/SimCSE-test/data/wiki1m_for_simcse.txt": {"num_bytes": 120038621, "checksum": "7b1825863a99aa76479b0456f7c210539dfaeeb69598b41fb4de4f524dd5a706"}}, "download_size": 120038621, "dataset_size": 123038621, "size_in_bytes": 243077242}
data/text/default-d50be04fe2b594a9/0.0.0/21a506d1b2b34316b1e82d0bd79066905d846e5d7e619823c0dd338d6f1fa6ad/text-train.arrow ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4cd7d749ccccf59a58dc1f2c4349440ee844c40554214559b7a1f91638f6051
3
+ size 123059952
data/text/default-d50be04fe2b594a9/0.0.0/21a506d1b2b34316b1e82d0bd79066905d846e5d7e619823c0dd338d6f1fa6ad_builder.lock ADDED
File without changes
data/wiki1m_for_simcse.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b1825863a99aa76479b0456f7c210539dfaeeb69598b41fb4de4f524dd5a706
3
+ size 120038621
run_sup_example.sh ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # In this example, we show how to train SimCSE using multiple GPU cards and PyTorch's distributed data parallel on supervised NLI dataset.
4
+ # Set how many GPUs to use
5
+
6
+ NUM_GPU=4
7
+
8
+ # Randomly set a port number
9
+ # If you encounter "address already used" error, just run again or manually set an available port id.
10
+ PORT_ID=$(expr $RANDOM + 1000)
11
+
12
+ # Allow multiple threads
13
+ export OMP_NUM_THREADS=8
14
+
15
+ # Use distributed data parallel
16
+ # If you only want to use one card, uncomment the following line and comment the line with "torch.distributed.launch"
17
+ # python train.py \
18
+ python -m torch.distributed.launch --nproc_per_node $NUM_GPU --master_port $PORT_ID train.py \
19
+ --model_name_or_path bert-base-uncased \
20
+ --train_file data/nli_for_simcse.csv \
21
+ --output_dir result/my-sup-simcse-bert-base-uncased \
22
+ --num_train_epochs 3 \
23
+ --per_device_train_batch_size 128 \
24
+ --learning_rate 5e-5 \
25
+ --max_seq_length 32 \
26
+ --evaluation_strategy steps \
27
+ --metric_for_best_model stsb_spearman \
28
+ --load_best_model_at_end \
29
+ --eval_steps 125 \
30
+ --pooler_type cls \
31
+ --overwrite_output_dir \
32
+ --temp 0.05 \
33
+ --do_train \
34
+ --do_eval \
35
+ --fp16 \
36
+ "$@"
run_unsup_example.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # In this example, we show how to train SimCSE on unsupervised Wikipedia data.
4
+ # If you want to train it with multiple GPU cards, see "run_sup_example.sh"
5
+ # about how to use PyTorch's distributed data parallel.
6
+
7
+ python3 ../../SimCSE/train.py \
8
+ --model_name_or_path NbAiLab/nb-bert-base \
9
+ --train_file data/wiki1m_for_simcse.txt \
10
+ --output_dir result/unsup-simcse-nb-bert-bert-base \
11
+ --num_train_epochs 1 \
12
+ --per_device_train_batch_size 64 \
13
+ --learning_rate 3e-5 \
14
+ --max_seq_length 32 \
15
+ --evaluation_strategy steps \
16
+ --metric_for_best_model stsb_spearman \
17
+ --load_best_model_at_end \
18
+ --eval_steps 125 \
19
+ --pooler_type cls \
20
+ --mlp_only_train \
21
+ --overwrite_output_dir \
22
+ --temp 0.05 \
23
+ --do_train \
24
+ --do_eval \
25
+ "$@"
26
+
runs/Oct21_12-19-17_t1v-n-ca292eb3-w-0/1666354862.6816092/events.out.tfevents.1666354862.t1v-n-ca292eb3-w-0.70028.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:16a4f0f653d3a3285569eeba72f8b6dc920644a279d2e6482dcda7034909ef7b
3
+ size 3160
runs/Oct21_12-19-17_t1v-n-ca292eb3-w-0/events.out.tfevents.1666354862.t1v-n-ca292eb3-w-0.70028.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f2916e775c808f53dfe16e50595afb9f42859bcd51257d828e227c417ed45e5b
3
+ size 2523
train.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import os
4
+ import sys
5
+ from dataclasses import dataclass, field
6
+ from typing import Optional, Union, List, Dict, Tuple
7
+ import torch
8
+ import collections
9
+ import random
10
+
11
+ from datasets import load_dataset
12
+
13
+ import transformers
14
+ from transformers import (
15
+ CONFIG_MAPPING,
16
+ MODEL_FOR_MASKED_LM_MAPPING,
17
+ AutoConfig,
18
+ AutoModelForMaskedLM,
19
+ AutoModelForSequenceClassification,
20
+ AutoTokenizer,
21
+ DataCollatorForLanguageModeling,
22
+ DataCollatorWithPadding,
23
+ HfArgumentParser,
24
+ Trainer,
25
+ TrainingArguments,
26
+ default_data_collator,
27
+ set_seed,
28
+ EvalPrediction,
29
+ BertModel,
30
+ BertForPreTraining,
31
+ RobertaModel
32
+ )
33
+ from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase
34
+ from transformers.trainer_utils import is_main_process
35
+ from transformers.data.data_collator import DataCollatorForLanguageModeling
36
+ from transformers.file_utils import cached_property, torch_required, is_torch_available, is_torch_tpu_available
37
+ from simcse.models import RobertaForCL, BertForCL
38
+ from simcse.trainers import CLTrainer
39
+
40
+ logger = logging.getLogger(__name__)
41
+ MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
42
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
43
+
44
+ @dataclass
45
+ class ModelArguments:
46
+ """
47
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
48
+ """
49
+
50
+ # Huggingface's original arguments
51
+ model_name_or_path: Optional[str] = field(
52
+ default=None,
53
+ metadata={
54
+ "help": "The model checkpoint for weights initialization."
55
+ "Don't set if you want to train a model from scratch."
56
+ },
57
+ )
58
+ model_type: Optional[str] = field(
59
+ default=None,
60
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
61
+ )
62
+ config_name: Optional[str] = field(
63
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
64
+ )
65
+ tokenizer_name: Optional[str] = field(
66
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
67
+ )
68
+ cache_dir: Optional[str] = field(
69
+ default=None,
70
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
71
+ )
72
+ use_fast_tokenizer: bool = field(
73
+ default=True,
74
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
75
+ )
76
+ model_revision: str = field(
77
+ default="main",
78
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
79
+ )
80
+ use_auth_token: bool = field(
81
+ default=False,
82
+ metadata={
83
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
84
+ "with private models)."
85
+ },
86
+ )
87
+
88
+ # SimCSE's arguments
89
+ temp: float = field(
90
+ default=0.05,
91
+ metadata={
92
+ "help": "Temperature for softmax."
93
+ }
94
+ )
95
+ pooler_type: str = field(
96
+ default="cls",
97
+ metadata={
98
+ "help": "What kind of pooler to use (cls, cls_before_pooler, avg, avg_top2, avg_first_last)."
99
+ }
100
+ )
101
+ hard_negative_weight: float = field(
102
+ default=0,
103
+ metadata={
104
+ "help": "The **logit** of weight for hard negatives (only effective if hard negatives are used)."
105
+ }
106
+ )
107
+ do_mlm: bool = field(
108
+ default=False,
109
+ metadata={
110
+ "help": "Whether to use MLM auxiliary objective."
111
+ }
112
+ )
113
+ mlm_weight: float = field(
114
+ default=0.1,
115
+ metadata={
116
+ "help": "Weight for MLM auxiliary objective (only effective if --do_mlm)."
117
+ }
118
+ )
119
+ mlp_only_train: bool = field(
120
+ default=False,
121
+ metadata={
122
+ "help": "Use MLP only during training"
123
+ }
124
+ )
125
+
126
+
127
+ @dataclass
128
+ class DataTrainingArguments:
129
+ """
130
+ Arguments pertaining to what data we are going to input our model for training and eval.
131
+ """
132
+
133
+ # Huggingface's original arguments.
134
+ dataset_name: Optional[str] = field(
135
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
136
+ )
137
+ dataset_config_name: Optional[str] = field(
138
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
139
+ )
140
+ overwrite_cache: bool = field(
141
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
142
+ )
143
+ validation_split_percentage: Optional[int] = field(
144
+ default=5,
145
+ metadata={
146
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
147
+ },
148
+ )
149
+ preprocessing_num_workers: Optional[int] = field(
150
+ default=None,
151
+ metadata={"help": "The number of processes to use for the preprocessing."},
152
+ )
153
+
154
+ # SimCSE's arguments
155
+ train_file: Optional[str] = field(
156
+ default=None,
157
+ metadata={"help": "The training data file (.txt or .csv)."}
158
+ )
159
+ max_seq_length: Optional[int] = field(
160
+ default=32,
161
+ metadata={
162
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
163
+ "than this will be truncated."
164
+ },
165
+ )
166
+ pad_to_max_length: bool = field(
167
+ default=False,
168
+ metadata={
169
+ "help": "Whether to pad all samples to `max_seq_length`. "
170
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
171
+ },
172
+ )
173
+ mlm_probability: float = field(
174
+ default=0.15,
175
+ metadata={"help": "Ratio of tokens to mask for MLM (only effective if --do_mlm)"}
176
+ )
177
+
178
+ def __post_init__(self):
179
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
180
+ raise ValueError("Need either a dataset name or a training/validation file.")
181
+ else:
182
+ if self.train_file is not None:
183
+ extension = self.train_file.split(".")[-1]
184
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
185
+
186
+
187
+ @dataclass
188
+ class OurTrainingArguments(TrainingArguments):
189
+ # Evaluation
190
+ ## By default, we evaluate STS (dev) during training (for selecting best checkpoints) and evaluate
191
+ ## both STS and transfer tasks (dev) at the end of training. Using --eval_transfer will allow evaluating
192
+ ## both STS and transfer tasks (dev) during training.
193
+ eval_transfer: bool = field(
194
+ default=False,
195
+ metadata={"help": "Evaluate transfer task dev sets (in validation)."}
196
+ )
197
+
198
+ @cached_property
199
+ @torch_required
200
+ def _setup_devices(self) -> "torch.device":
201
+ logger.info("PyTorch: setting up devices")
202
+ if self.no_cuda:
203
+ device = torch.device("cpu")
204
+ self._n_gpu = 0
205
+ elif is_torch_tpu_available():
206
+ import torch_xla.core.xla_model as xm
207
+ device = xm.xla_device()
208
+ self._n_gpu = 0
209
+ elif self.local_rank == -1:
210
+ # if n_gpu is > 1 we'll use nn.DataParallel.
211
+ # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
212
+ # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
213
+ # trigger an error that a device index is missing. Index 0 takes into account the
214
+ # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
215
+ # will use the first GPU in that env, i.e. GPU#1
216
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
217
+ # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
218
+ # the default value.
219
+ self._n_gpu = torch.cuda.device_count()
220
+ else:
221
+ # Here, we'll use torch.distributed.
222
+ # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
223
+ #
224
+ # deepspeed performs its own DDP internally, and requires the program to be started with:
225
+ # deepspeed ./program.py
226
+ # rather than:
227
+ # python -m torch.distributed.launch --nproc_per_node=2 ./program.py
228
+ if self.deepspeed:
229
+ from .integrations import is_deepspeed_available
230
+
231
+ if not is_deepspeed_available():
232
+ raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
233
+ import deepspeed
234
+
235
+ deepspeed.init_distributed()
236
+ else:
237
+ torch.distributed.init_process_group(backend="nccl")
238
+ device = torch.device("cuda", self.local_rank)
239
+ self._n_gpu = 1
240
+
241
+ if device.type == "cuda":
242
+ torch.cuda.set_device(device)
243
+
244
+ return device
245
+
246
+
247
+ def main():
248
+ # See all possible arguments in src/transformers/training_args.py
249
+ # or by passing the --help flag to this script.
250
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
251
+
252
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, OurTrainingArguments))
253
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
254
+ # If we pass only one argument to the script and it's the path to a json file,
255
+ # let's parse it to get our arguments.
256
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
257
+ else:
258
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
259
+
260
+ if (
261
+ os.path.exists(training_args.output_dir)
262
+ and os.listdir(training_args.output_dir)
263
+ and training_args.do_train
264
+ and not training_args.overwrite_output_dir
265
+ ):
266
+ raise ValueError(
267
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
268
+ "Use --overwrite_output_dir to overcome."
269
+ )
270
+
271
+ # Setup logging
272
+ logging.basicConfig(
273
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
274
+ datefmt="%m/%d/%Y %H:%M:%S",
275
+ level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
276
+ )
277
+
278
+ # Log on each process the small summary:
279
+ logger.warning(
280
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
281
+ + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
282
+ )
283
+ # Set the verbosity to info of the Transformers logger (on main process only):
284
+ if is_main_process(training_args.local_rank):
285
+ transformers.utils.logging.set_verbosity_info()
286
+ transformers.utils.logging.enable_default_handler()
287
+ transformers.utils.logging.enable_explicit_format()
288
+ logger.info("Training/evaluation parameters %s", training_args)
289
+
290
+ # Set seed before initializing model.
291
+ set_seed(training_args.seed)
292
+
293
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
294
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
295
+ # (the dataset will be downloaded automatically from the datasets Hub
296
+ #
297
+ # For CSV/JSON files, this script will use the column called 'text' or the first column. You can easily tweak this
298
+ # behavior (see below)
299
+ #
300
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
301
+ # download the dataset.
302
+ data_files = {}
303
+ if data_args.train_file is not None:
304
+ data_files["train"] = data_args.train_file
305
+ extension = data_args.train_file.split(".")[-1]
306
+ if extension == "txt":
307
+ extension = "text"
308
+ if extension == "csv":
309
+ datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/", delimiter="\t" if "tsv" in data_args.train_file else ",")
310
+ else:
311
+ datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/")
312
+
313
+ # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
314
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
315
+
316
+ # Load pretrained model and tokenizer
317
+ #
318
+ # Distributed training:
319
+ # The .from_pretrained methods guarantee that only one local process can concurrently
320
+ # download model & vocab.
321
+ config_kwargs = {
322
+ "cache_dir": model_args.cache_dir,
323
+ "revision": model_args.model_revision,
324
+ "use_auth_token": True if model_args.use_auth_token else None,
325
+ }
326
+ if model_args.config_name:
327
+ config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
328
+ elif model_args.model_name_or_path:
329
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
330
+ else:
331
+ config = CONFIG_MAPPING[model_args.model_type]()
332
+ logger.warning("You are instantiating a new config instance from scratch.")
333
+
334
+ tokenizer_kwargs = {
335
+ "cache_dir": model_args.cache_dir,
336
+ "use_fast": model_args.use_fast_tokenizer,
337
+ "revision": model_args.model_revision,
338
+ "use_auth_token": True if model_args.use_auth_token else None,
339
+ }
340
+ if model_args.tokenizer_name:
341
+ tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
342
+ elif model_args.model_name_or_path:
343
+ tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
344
+ else:
345
+ raise ValueError(
346
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
347
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
348
+ )
349
+
350
+ if model_args.model_name_or_path:
351
+ if 'roberta' in model_args.model_name_or_path:
352
+ model = RobertaForCL.from_pretrained(
353
+ model_args.model_name_or_path,
354
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
355
+ config=config,
356
+ cache_dir=model_args.cache_dir,
357
+ revision=model_args.model_revision,
358
+ use_auth_token=True if model_args.use_auth_token else None,
359
+ model_args=model_args
360
+ )
361
+ elif 'bert' in model_args.model_name_or_path:
362
+ model = BertForCL.from_pretrained(
363
+ model_args.model_name_or_path,
364
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
365
+ config=config,
366
+ cache_dir=model_args.cache_dir,
367
+ revision=model_args.model_revision,
368
+ use_auth_token=True if model_args.use_auth_token else None,
369
+ model_args=model_args
370
+ )
371
+ if model_args.do_mlm:
372
+ pretrained_model = BertForPreTraining.from_pretrained(model_args.model_name_or_path)
373
+ model.lm_head.load_state_dict(pretrained_model.cls.predictions.state_dict())
374
+ else:
375
+ raise NotImplementedError
376
+ else:
377
+ raise NotImplementedError
378
+ logger.info("Training new model from scratch")
379
+ model = AutoModelForMaskedLM.from_config(config)
380
+
381
+ model.resize_token_embeddings(len(tokenizer))
382
+
383
+ # Prepare features
384
+ column_names = datasets["train"].column_names
385
+ sent2_cname = None
386
+ if len(column_names) == 2:
387
+ # Pair datasets
388
+ sent0_cname = column_names[0]
389
+ sent1_cname = column_names[1]
390
+ elif len(column_names) == 3:
391
+ # Pair datasets with hard negatives
392
+ sent0_cname = column_names[0]
393
+ sent1_cname = column_names[1]
394
+ sent2_cname = column_names[2]
395
+ elif len(column_names) == 1:
396
+ # Unsupervised datasets
397
+ sent0_cname = column_names[0]
398
+ sent1_cname = column_names[0]
399
+ else:
400
+ raise NotImplementedError
401
+
402
+ def prepare_features(examples):
403
+ # padding = longest (default)
404
+ # If no sentence in the batch exceed the max length, then use
405
+ # the max sentence length in the batch, otherwise use the
406
+ # max sentence length in the argument and truncate those that
407
+ # exceed the max length.
408
+ # padding = max_length (when pad_to_max_length, for pressure test)
409
+ # All sentences are padded/truncated to data_args.max_seq_length.
410
+ total = len(examples[sent0_cname])
411
+
412
+ # Avoid "None" fields
413
+ for idx in range(total):
414
+ if examples[sent0_cname][idx] is None:
415
+ examples[sent0_cname][idx] = " "
416
+ if examples[sent1_cname][idx] is None:
417
+ examples[sent1_cname][idx] = " "
418
+
419
+ sentences = examples[sent0_cname] + examples[sent1_cname]
420
+
421
+ # If hard negative exists
422
+ if sent2_cname is not None:
423
+ for idx in range(total):
424
+ if examples[sent2_cname][idx] is None:
425
+ examples[sent2_cname][idx] = " "
426
+ sentences += examples[sent2_cname]
427
+
428
+ sent_features = tokenizer(
429
+ sentences,
430
+ max_length=data_args.max_seq_length,
431
+ truncation=True,
432
+ padding="max_length" if data_args.pad_to_max_length else False,
433
+ )
434
+
435
+ features = {}
436
+ if sent2_cname is not None:
437
+ for key in sent_features:
438
+ features[key] = [[sent_features[key][i], sent_features[key][i+total], sent_features[key][i+total*2]] for i in range(total)]
439
+ else:
440
+ for key in sent_features:
441
+ features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)]
442
+
443
+ return features
444
+
445
+ if training_args.do_train:
446
+ train_dataset = datasets["train"].map(
447
+ prepare_features,
448
+ batched=True,
449
+ num_proc=data_args.preprocessing_num_workers,
450
+ remove_columns=column_names,
451
+ load_from_cache_file=not data_args.overwrite_cache,
452
+ )
453
+
454
+ # Data collator
455
+ @dataclass
456
+ class OurDataCollatorWithPadding:
457
+
458
+ tokenizer: PreTrainedTokenizerBase
459
+ padding: Union[bool, str, PaddingStrategy] = True
460
+ max_length: Optional[int] = None
461
+ pad_to_multiple_of: Optional[int] = None
462
+ mlm: bool = True
463
+ mlm_probability: float = data_args.mlm_probability
464
+
465
+ def __call__(self, features: List[Dict[str, Union[List[int], List[List[int]], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
466
+ special_keys = ['input_ids', 'attention_mask', 'token_type_ids', 'mlm_input_ids', 'mlm_labels']
467
+ bs = len(features)
468
+ if bs > 0:
469
+ num_sent = len(features[0]['input_ids'])
470
+ else:
471
+ return
472
+ flat_features = []
473
+ for feature in features:
474
+ for i in range(num_sent):
475
+ flat_features.append({k: feature[k][i] if k in special_keys else feature[k] for k in feature})
476
+
477
+ batch = self.tokenizer.pad(
478
+ flat_features,
479
+ padding=self.padding,
480
+ max_length=self.max_length,
481
+ pad_to_multiple_of=self.pad_to_multiple_of,
482
+ return_tensors="pt",
483
+ )
484
+ if model_args.do_mlm:
485
+ batch["mlm_input_ids"], batch["mlm_labels"] = self.mask_tokens(batch["input_ids"])
486
+
487
+ batch = {k: batch[k].view(bs, num_sent, -1) if k in special_keys else batch[k].view(bs, num_sent, -1)[:, 0] for k in batch}
488
+
489
+ if "label" in batch:
490
+ batch["labels"] = batch["label"]
491
+ del batch["label"]
492
+ if "label_ids" in batch:
493
+ batch["labels"] = batch["label_ids"]
494
+ del batch["label_ids"]
495
+
496
+ return batch
497
+
498
+ def mask_tokens(
499
+ self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
500
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
501
+ """
502
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
503
+ """
504
+ inputs = inputs.clone()
505
+ labels = inputs.clone()
506
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
507
+ probability_matrix = torch.full(labels.shape, self.mlm_probability)
508
+ if special_tokens_mask is None:
509
+ special_tokens_mask = [
510
+ self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
511
+ ]
512
+ special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
513
+ else:
514
+ special_tokens_mask = special_tokens_mask.bool()
515
+
516
+ probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
517
+ masked_indices = torch.bernoulli(probability_matrix).bool()
518
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
519
+
520
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
521
+ indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
522
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
523
+
524
+ # 10% of the time, we replace masked input tokens with random word
525
+ indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
526
+ random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
527
+ inputs[indices_random] = random_words[indices_random]
528
+
529
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
530
+ return inputs, labels
531
+
532
+ data_collator = default_data_collator if data_args.pad_to_max_length else OurDataCollatorWithPadding(tokenizer)
533
+
534
+ trainer = CLTrainer(
535
+ model=model,
536
+ args=training_args,
537
+ train_dataset=train_dataset if training_args.do_train else None,
538
+ tokenizer=tokenizer,
539
+ data_collator=data_collator,
540
+ )
541
+ trainer.model_args = model_args
542
+
543
+ # Training
544
+ if training_args.do_train:
545
+ model_path = (
546
+ model_args.model_name_or_path
547
+ if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
548
+ else None
549
+ )
550
+ train_result = trainer.train(model_path=model_path)
551
+ trainer.save_model() # Saves the tokenizer too for easy upload
552
+
553
+ output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
554
+ if trainer.is_world_process_zero():
555
+ with open(output_train_file, "w") as writer:
556
+ logger.info("***** Train results *****")
557
+ for key, value in sorted(train_result.metrics.items()):
558
+ logger.info(f" {key} = {value}")
559
+ writer.write(f"{key} = {value}\n")
560
+
561
+ # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
562
+ trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
563
+
564
+ # Evaluation
565
+ results = {}
566
+ if training_args.do_eval:
567
+ logger.info("*** Evaluate ***")
568
+ results = trainer.evaluate(eval_senteval_transfer=True)
569
+
570
+ output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
571
+ if trainer.is_world_process_zero():
572
+ with open(output_eval_file, "w") as writer:
573
+ logger.info("***** Eval results *****")
574
+ for key, value in sorted(results.items()):
575
+ logger.info(f" {key} = {value}")
576
+ writer.write(f"{key} = {value}\n")
577
+
578
+ return results
579
+
580
+ def _mp_fn(index):
581
+ # For xla_spawn (TPUs)
582
+ main()
583
+
584
+
585
+ if __name__ == "__main__":
586
+ main()