| | import logging |
| | import math |
| | import os |
| | import sys |
| | from dataclasses import dataclass, field |
| | from typing import Optional, Union, List, Dict, Tuple |
| | import torch |
| | import collections |
| | import random |
| |
|
| | from datasets import load_dataset |
| |
|
| | import transformers |
| | from transformers import ( |
| | CONFIG_MAPPING, |
| | MODEL_FOR_MASKED_LM_MAPPING, |
| | AutoConfig, |
| | AutoModelForMaskedLM, |
| | AutoModelForSequenceClassification, |
| | AutoTokenizer, |
| | DataCollatorForLanguageModeling, |
| | DataCollatorWithPadding, |
| | HfArgumentParser, |
| | Trainer, |
| | TrainingArguments, |
| | default_data_collator, |
| | set_seed, |
| | EvalPrediction, |
| | BertModel, |
| | BertForPreTraining, |
| | RobertaModel |
| | ) |
| | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTrainedTokenizerBase |
| | from transformers.trainer_utils import is_main_process |
| | from transformers.data.data_collator import DataCollatorForLanguageModeling |
| | from transformers.file_utils import cached_property, torch_required, is_torch_available, is_torch_tpu_available |
| | from simcse.models import RobertaForCL, BertForCL |
| | from simcse.trainers import CLTrainer |
| |
|
# Module-level logger for this training script.
logger = logging.getLogger(__name__)
# Config classes that support a masked-LM head; used only to build the list of
# valid values advertised for --model_type.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
| |
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    The first group mirrors the standard HuggingFace model-selection arguments;
    the second group holds the SimCSE-specific contrastive-learning options.
    """

    # --- HuggingFace model/config/tokenizer selection -----------------------
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."},
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script with private models)."},
    )

    # --- SimCSE contrastive-learning options --------------------------------
    temp: float = field(
        default=0.05,
        metadata={"help": "Temperature for softmax."},
    )
    pooler_type: str = field(
        default="cls",
        metadata={"help": "What kind of pooler to use (cls, cls_before_pooler, avg, avg_top2, avg_first_last)."},
    )
    hard_negative_weight: float = field(
        default=0,
        metadata={"help": "The **logit** of weight for hard negatives (only effective if hard negatives are used)."},
    )
    do_mlm: bool = field(
        default=False,
        metadata={"help": "Whether to use MLM auxiliary objective."},
    )
    mlm_weight: float = field(
        default=0.1,
        metadata={"help": "Weight for MLM auxiliary objective (only effective if --do_mlm)."},
    )
    mlp_only_train: bool = field(
        default=False,
        metadata={"help": "Use MLP only during training"},
    )
| |
|
| |
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Raises:
        ValueError: in ``__post_init__`` when neither ``dataset_name`` nor
            ``train_file`` is supplied, or when ``train_file`` has an
            unsupported extension.
    """

    # --- HuggingFace datasets options ---------------------------------------
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    # --- SimCSE data options -------------------------------------------------
    train_file: Optional[str] = field(
        default=None,
        metadata={"help": "The training data file (.txt or .csv)."}
    )
    max_seq_length: Optional[int] = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated."
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    mlm_probability: float = field(
        default=0.15,
        metadata={"help": "Ratio of tokens to mask for MLM (only effective if --do_mlm)"}
    )

    def __post_init__(self):
        # BUGFIX: the original condition also tested `self.validation_file`,
        # a field this dataclass never declares — so the "no data" path raised
        # AttributeError instead of the intended ValueError.
        if self.dataset_name is None and self.train_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        if self.train_file is not None:
            extension = self.train_file.split(".")[-1]
            # Raise ValueError rather than using `assert`, which is stripped
            # when Python runs with -O.
            if extension not in ["csv", "json", "txt"]:
                raise ValueError("`train_file` should be a csv, a json or a txt file.")
| |
|
| |
|
@dataclass
class OurTrainingArguments(TrainingArguments):
    """TrainingArguments extended with SimCSE evaluation options and a custom
    device-setup routine (CPU / TPU / single-GPU / distributed / DeepSpeed)."""

    # Evaluate transfer tasks on top of the default STS validation.
    eval_transfer: bool = field(
        default=False,
        metadata={"help": "Evaluate transfer task dev sets (in validation)."}
    )

    @cached_property
    @torch_required
    def _setup_devices(self) -> "torch.device":
        """Pick the torch device for this process and set ``self._n_gpu``.

        Cached: runs once per TrainingArguments instance.
        """
        logger.info("PyTorch: setting up devices")
        if self.no_cuda:
            device = torch.device("cpu")
            self._n_gpu = 0
        elif is_torch_tpu_available():
            import torch_xla.core.xla_model as xm
            device = xm.xla_device()
            self._n_gpu = 0
        elif self.local_rank == -1:
            # Single-process run: use GPU 0 if available; DataParallel over
            # all visible GPUs is handled by the Trainer via _n_gpu.
            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
            self._n_gpu = torch.cuda.device_count()
        else:
            # Distributed run: one process per GPU (_n_gpu == 1).
            if self.deepspeed:
                # BUGFIX: the original `from .integrations import ...` is a
                # relative import copied from transformers internals; in this
                # standalone script it raises ImportError whenever --deepspeed
                # is set. Import from the transformers package instead.
                from transformers.integrations import is_deepspeed_available

                if not is_deepspeed_available():
                    raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
                import deepspeed

                deepspeed.init_distributed()
            else:
                torch.distributed.init_process_group(backend="nccl")
            device = torch.device("cuda", self.local_rank)
            self._n_gpu = 1

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device
| |
|
| |
|
def main():
    """CLI entry point: parse arguments, load data/tokenizer/model, then run
    SimCSE contrastive training and (optionally) evaluation.

    Returns:
        dict: evaluation results (empty if ``--do_eval`` was not set).
    """
    # See all the possible arguments in transformers.TrainingArguments;
    # SimCSE-specific ones live in ModelArguments / DataTrainingArguments /
    # OurTrainingArguments defined above.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, OurTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If a single .json file is the only CLI argument, read every
        # argument from it instead of the command line.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Refuse to clobber a non-empty output directory unless explicitly allowed.
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging: only the main process logs at INFO level.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log a short per-process summary (WARNING so every rank prints it).
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity of the transformers logger to INFO on the main process only.
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model for reproducibility.
    set_seed(training_args.seed)

    # Load the training data. `.txt` files are one sentence per line
    # (unsupervised SimCSE); `.csv`/`.tsv` carry sentence pairs or
    # (sent0, sent1, hard_negative) triplets (supervised SimCSE).
    # NOTE(review): this loader assumes --train_file is provided; a
    # dataset_name-only invocation would fail on the split() below — confirm
    # whether dataset_name runs are meant to be supported here.
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
    extension = data_args.train_file.split(".")[-1]
    if extension == "txt":
        extension = "text"
    if extension == "csv":
        datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/", delimiter="\t" if "tsv" in data_args.train_file else ",")
    else:
        datasets = load_dataset(extension, data_files=data_files, cache_dir="./data/")

    # Load pretrained config and tokenizer. The config can come from an
    # explicit --config_name, from the model checkpoint, or (from-scratch)
    # from the --model_type registry.
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script."
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    # Instantiate the SimCSE model (contrastive-learning head on top of
    # BERT/RoBERTa). Training from scratch is not supported.
    if model_args.model_name_or_path:
        if 'roberta' in model_args.model_name_or_path:
            model = RobertaForCL.from_pretrained(
                model_args.model_name_or_path,
                from_tf=bool(".ckpt" in model_args.model_name_or_path),
                config=config,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
                model_args=model_args
            )
        elif 'bert' in model_args.model_name_or_path:
            model = BertForCL.from_pretrained(
                model_args.model_name_or_path,
                from_tf=bool(".ckpt" in model_args.model_name_or_path),
                config=config,
                cache_dir=model_args.cache_dir,
                revision=model_args.model_revision,
                use_auth_token=True if model_args.use_auth_token else None,
                model_args=model_args
            )
            if model_args.do_mlm:
                # Initialize the auxiliary MLM head from the pretrained
                # BERT prediction head so it starts from sensible weights.
                pretrained_model = BertForPreTraining.from_pretrained(model_args.model_name_or_path)
                model.lm_head.load_state_dict(pretrained_model.cls.predictions.state_dict())
        else:
            raise NotImplementedError
    else:
        # BUGFIX: the original had unreachable from-scratch initialization
        # code after this raise; removed as dead code.
        raise NotImplementedError

    model.resize_token_embeddings(len(tokenizer))

    # Infer the dataset layout from its number of columns:
    #   1 column  -> unsupervised (sent0 == sent1, dropout is the augmentation)
    #   2 columns -> supervised pairs (sent0, sent1)
    #   3 columns -> supervised pairs with a hard negative (sent0, sent1, sent2)
    column_names = datasets["train"].column_names
    sent2_cname = None
    if len(column_names) == 2:
        sent0_cname = column_names[0]
        sent1_cname = column_names[1]
    elif len(column_names) == 3:
        sent0_cname = column_names[0]
        sent1_cname = column_names[1]
        sent2_cname = column_names[2]
    elif len(column_names) == 1:
        sent0_cname = column_names[0]
        sent1_cname = column_names[0]
    else:
        raise NotImplementedError

    def prepare_features(examples):
        """Tokenize a batch of examples into grouped sentence features.

        All sentences (sent0 ++ sent1 [++ sent2]) are tokenized in a single
        pass, then regrouped so that example i holds a list of 2 (or 3)
        encodings per key: [sent0_i, sent1_i(, sent2_i)].
        """
        total = len(examples[sent0_cname])

        # Replace None entries (empty cells in the csv) with a blank string
        # so tokenization does not crash.
        for idx in range(total):
            if examples[sent0_cname][idx] is None:
                examples[sent0_cname][idx] = " "
            if examples[sent1_cname][idx] is None:
                examples[sent1_cname][idx] = " "

        sentences = examples[sent0_cname] + examples[sent1_cname]

        # Append hard negatives, if the dataset provides them.
        if sent2_cname is not None:
            for idx in range(total):
                if examples[sent2_cname][idx] is None:
                    examples[sent2_cname][idx] = " "
            sentences += examples[sent2_cname]

        sent_features = tokenizer(
            sentences,
            max_length=data_args.max_seq_length,
            truncation=True,
            padding="max_length" if data_args.pad_to_max_length else False,
        )

        # Regroup the flat tokenization output back per-example.
        features = {}
        if sent2_cname is not None:
            for key in sent_features:
                features[key] = [[sent_features[key][i], sent_features[key][i+total], sent_features[key][i+total*2]] for i in range(total)]
        else:
            for key in sent_features:
                features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)]

        return features

    if training_args.do_train:
        train_dataset = datasets["train"].map(
            prepare_features,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
        )

    # Data collator: flattens the per-example sentence groups, pads them as
    # one batch, then reshapes back to (batch, num_sent, seq_len).
    @dataclass
    class OurDataCollatorWithPadding:

        tokenizer: PreTrainedTokenizerBase
        padding: Union[bool, str, PaddingStrategy] = True
        max_length: Optional[int] = None
        pad_to_multiple_of: Optional[int] = None
        mlm: bool = True
        mlm_probability: float = data_args.mlm_probability

        def __call__(self, features: List[Dict[str, Union[List[int], List[List[int]], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
            # Keys whose values are lists of per-sentence encodings.
            special_keys = ['input_ids', 'attention_mask', 'token_type_ids', 'mlm_input_ids', 'mlm_labels']
            bs = len(features)
            if bs > 0:
                num_sent = len(features[0]['input_ids'])
            else:
                # Empty batch: nothing to collate (returns None).
                return
            # Flatten: each example contributes num_sent "rows" so the
            # tokenizer can pad them all together.
            flat_features = []
            for feature in features:
                for i in range(num_sent):
                    flat_features.append({k: feature[k][i] if k in special_keys else feature[k] for k in feature})

            batch = self.tokenizer.pad(
                flat_features,
                padding=self.padding,
                max_length=self.max_length,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors="pt",
            )
            if model_args.do_mlm:
                batch["mlm_input_ids"], batch["mlm_labels"] = self.mask_tokens(batch["input_ids"])

            # Un-flatten: sentence-level keys keep (bs, num_sent, -1);
            # example-level keys take the first sentence's copy.
            batch = {k: batch[k].view(bs, num_sent, -1) if k in special_keys else batch[k].view(bs, num_sent, -1)[:, 0] for k in batch}

            if "label" in batch:
                batch["labels"] = batch["label"]
                del batch["label"]
            if "label_ids" in batch:
                batch["labels"] = batch["label_ids"]
                del batch["label_ids"]

            return batch

        def mask_tokens(
            self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            """
            Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
            """
            inputs = inputs.clone()
            labels = inputs.clone()
            # Sample tokens with probability `mlm_probability`, never masking
            # special tokens ([CLS], [SEP], padding, ...).
            probability_matrix = torch.full(labels.shape, self.mlm_probability)
            if special_tokens_mask is None:
                special_tokens_mask = [
                    self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
                ]
                special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
            else:
                special_tokens_mask = special_tokens_mask.bool()

            probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
            masked_indices = torch.bernoulli(probability_matrix).bool()
            labels[~masked_indices] = -100  # only compute loss on masked tokens

            # 80% of masked tokens -> [MASK]
            indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
            inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

            # 10% of masked tokens (half of the remaining 20%) -> random word
            indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
            random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
            inputs[indices_random] = random_words[indices_random]

            # Remaining 10% keep the original token.
            return inputs, labels

    data_collator = default_data_collator if data_args.pad_to_max_length else OurDataCollatorWithPadding(tokenizer)

    trainer = CLTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.model_args = model_args

    # --- Training -----------------------------------------------------------
    if training_args.do_train:
        model_path = (
            model_args.model_name_or_path
            if (model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path))
            else None
        )
        train_result = trainer.train(model_path=model_path)
        trainer.save_model()  # also saves the tokenizer for easy upload

        output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
        if trainer.is_world_process_zero():
            with open(output_train_file, "w") as writer:
                logger.info("***** Train results *****")
                for key, value in sorted(train_result.metrics.items()):
                    logger.info(f"  {key} = {value}")
                    writer.write(f"{key} = {value}\n")

            # Save the Trainer state, since Trainer.save_model saves only the model/tokenizer.
            trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))

    # --- Evaluation ----------------------------------------------------------
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        results = trainer.evaluate(eval_senteval_transfer=True)

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        if trainer.is_world_process_zero():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in sorted(results.items()):
                    logger.info(f"  {key} = {value}")
                    writer.write(f"{key} = {value}\n")

    return results
| |
|
def _mp_fn(index):
    # Entry point for TPU multiprocessing (e.g. xla_spawn); `index` is the
    # process ordinal supplied by the spawner and is unused here.
    main()
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|