| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Finetuning multi-lingual models on XNLI (e.g. Bert, DistilBERT, XLM). |
| Adapted from `examples/text-classification/run_glue.py`""" |
|
|
| import logging |
| import os |
| import random |
| import sys |
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
| import datasets |
| import evaluate |
| import numpy as np |
| from datasets import load_dataset |
|
|
| import transformers |
| from transformers import ( |
| AutoConfig, |
| AutoModelForSequenceClassification, |
| AutoTokenizer, |
| DataCollatorWithPadding, |
| EvalPrediction, |
| HfArgumentParser, |
| Trainer, |
| TrainingArguments, |
| default_data_collator, |
| set_seed, |
| ) |
| from transformers.trainer_utils import get_last_checkpoint |
| from transformers.utils import check_min_version, send_example_telemetry |
| from transformers.utils.versions import require_version |
|
|
|
|
| |
# Import-time guard on the installed `transformers` version (presumably raises when it is older
# than the version this example was written against — see `transformers.utils.check_min_version`).
check_min_version("4.52.0.dev0")

# Same kind of guard for `datasets`; the second argument is the remediation hint surfaced on failure
# (NOTE(review): hint semantics inferred from the message text — confirm against `require_version`).
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

# Module-level logger; `main()` sets its level from `TrainingArguments.get_process_log_level()`.
logger = logging.getLogger(__name__)
|
|
|
|
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line. Each field's `metadata["help"]` is the help text shown for it.

    Field order matters: it defines both the dataclass `__init__` order and the order of the
    generated command-line arguments.
    """

    # Passed to the tokenizer as `max_length` (with `truncation=True`) in `main`'s `preprocess_function`.
    max_seq_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    # When True, `main` calls `Dataset.map(..., load_from_cache_file=False)` so tokenization is redone.
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    # True: tokenize with `padding="max_length"` and batch with `default_data_collator`.
    # False: no padding at tokenization time; `main` uses `DataCollatorWithPadding(pad_to_multiple_of=8)`
    # under fp16, otherwise passes `data_collator=None` and lets `Trainer` choose its default.
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to pad all samples to `max_seq_length`. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch."
            )
        },
    )
    # If set, `main` keeps only the first N training rows (`select(range(...))`) before tokenizing.
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    # If set, `main` keeps only the first N validation rows before tokenizing.
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    # If set, `main` keeps only the first N test rows before tokenizing.
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
|
|
|
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Parsed from the command line by `HfArgumentParser`; each field's `metadata["help"]` is its
    help text. Fields whose default is `None` are annotated `Optional[...]` so the type hints match
    the values the dataclass can actually hold (the parser unwraps `Optional` for these fields, so
    the command-line interface is unaffected).
    """

    # Hub id or local path; also the fallback for `config_name` / `tokenizer_name`. `main` requires it.
    model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # XNLI language config used for validation/test, and for train unless `train_language` is set.
    language: Optional[str] = field(
        default=None, metadata={"help": "Evaluation language. Also train language if `train_language` is set to None."}
    )
    # Optional XNLI language config for the train split only.
    train_language: Optional[str] = field(
        default=None, metadata={"help": "Train language if it is different from the evaluation language."}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    # Forwarded as `cache_dir` to the dataset, metric, config, tokenizer and model loaders.
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    do_lower_case: Optional[bool] = field(
        default=False,
        metadata={"help": "arg to indicate if tokenizer should do lower case in AutoTokenizer.from_pretrained()"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    # Forwarded as `token` to the dataset, config, tokenizer and model loaders.
    token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
    )
|
|
|
|
def _validate_args(model_args, training_args):
    """Fail fast on argument combinations that would otherwise crash later with an obscure error.

    Raises:
        ValueError: if no `--do_*` action is requested, if `--model_name_or_path` is missing, or if
            `--language` is missing while some split needs it.
    """
    if not (training_args.do_train or training_args.do_eval or training_args.do_predict):
        # Without any loaded split there is no label list to build the model config from.
        raise ValueError("At least one of `--do_train`, `--do_eval` or `--do_predict` must be set.")
    if model_args.model_name_or_path is None:
        raise ValueError("`--model_name_or_path` must be set.")
    # Validation/test always use `language`; train uses it unless `train_language` is given.
    needs_language = (
        training_args.do_eval
        or training_args.do_predict
        or (training_args.do_train and model_args.train_language is None)
    )
    if needs_language and model_args.language is None:
        raise ValueError(
            "`--language` must be set: it is the evaluation language, and also the train language "
            "unless `--train_language` is given."
        )


def _detect_last_checkpoint(training_args):
    """Return the checkpoint in `output_dir` to resume training from, or None.

    A checkpoint is only looked for when training into an existing `output_dir` without
    `--overwrite_output_dir`.

    Raises:
        ValueError: if that directory is non-empty but contains no checkpoint.
    """
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )
    return last_checkpoint


def _load_xnli_split(language, split, model_args):
    """Load one split ("train", "validation" or "test") of the `xnli` dataset for a language config.

    `cache_dir` and the auth `token` are taken from `model_args`.
    """
    return load_dataset(
        "xnli",
        language,
        split=split,
        cache_dir=model_args.cache_dir,
        token=model_args.token,
    )


def _select_and_tokenize(dataset, max_samples, split_name, preprocess_function, training_args, data_args):
    """Optionally keep only the first `max_samples` rows of `dataset`, then tokenize it in batches.

    `split_name` ("train", "validation" or "prediction") only appears in progress descriptions.
    """
    if max_samples is not None:
        dataset = dataset.select(range(min(len(dataset), max_samples)))
    with training_args.main_process_first(desc=f"{split_name} dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc=f"Running tokenizer on {split_name} dataset",
        )
    return dataset


def _write_predictions(predictions, label_list, output_dir):
    """Write `output_dir/predictions.txt`: a header, then one `index<TAB>label_name` line per prediction.

    Args:
        predictions: iterable of integer class ids (indices into `label_list`).
        label_list: class names indexed by class id.
        output_dir: directory the file is written into.
    """
    output_predict_file = os.path.join(output_dir, "predictions.txt")
    with open(output_predict_file, "w", encoding="utf-8") as writer:
        writer.write("index\tprediction\n")
        for index, label_id in enumerate(predictions):
            writer.write(f"{index}\t{label_list[label_id]}\n")


def main():
    """Fine-tune and/or evaluate a sequence-classification model on XNLI.

    Parses `ModelArguments`, `DataTrainingArguments` and `TrainingArguments` with `HfArgumentParser`.
    It then loads the requested XNLI splits and tokenizes them. Finally it trains, evaluates and/or
    predicts with a `Trainer`, according to `--do_train` / `--do_eval` / `--do_predict`.
    Predictions are written to `<output_dir>/predictions.txt` by the main process.

    Raises:
        ValueError: for invalid argument combinations (see `_validate_args`), or when `output_dir`
            is non-empty, has no checkpoint and `--overwrite_output_dir` is not set.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    send_example_telemetry("run_xnli", model_args)

    # ---- Logging ----------------------------------------------------------------
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    if training_args.should_log:
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    _validate_args(model_args, training_args)
    last_checkpoint = _detect_last_checkpoint(training_args)

    set_seed(training_args.seed)

    # ---- Datasets -----------------------------------------------------------------
    train_dataset = eval_dataset = predict_dataset = None
    if training_args.do_train:
        train_language = model_args.language if model_args.train_language is None else model_args.train_language
        train_dataset = _load_xnli_split(train_language, "train", model_args)
        label_list = train_dataset.features["label"].names
    if training_args.do_eval:
        eval_dataset = _load_xnli_split(model_args.language, "validation", model_args)
        label_list = eval_dataset.features["label"].names
    if training_args.do_predict:
        predict_dataset = _load_xnli_split(model_args.language, "test", model_args)
        label_list = predict_dataset.features["label"].names

    # `_validate_args` guarantees at least one split was loaded, so `label_list` is bound here.
    num_labels = len(label_list)

    # ---- Model, config, tokenizer -------------------------------------------------
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        id2label={str(i): label for i, label in enumerate(label_list)},
        label2id={label: i for i, label in enumerate(label_list)},
        finetuning_task="xnli",
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        do_lower_case=model_args.do_lower_case,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=model_args.token,
        trust_remote_code=model_args.trust_remote_code,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )

    # ---- Preprocessing ---------------------------------------------------------------
    # Pad everything to `max_seq_length` now, or leave padding to the batch collator.
    padding = "max_length" if data_args.pad_to_max_length else False

    def preprocess_function(examples):
        """Tokenize (premise, hypothesis) pairs, truncating to `max_seq_length`."""
        return tokenizer(
            examples["premise"],
            examples["hypothesis"],
            padding=padding,
            max_length=data_args.max_seq_length,
            truncation=True,
        )

    if training_args.do_train:
        train_dataset = _select_and_tokenize(
            train_dataset, data_args.max_train_samples, "train", preprocess_function, training_args, data_args
        )
        # Log a few random training samples; never ask for more than the (possibly truncated) set holds.
        for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
    if training_args.do_eval:
        eval_dataset = _select_and_tokenize(
            eval_dataset, data_args.max_eval_samples, "validation", preprocess_function, training_args, data_args
        )
    if training_args.do_predict:
        predict_dataset = _select_and_tokenize(
            predict_dataset, data_args.max_predict_samples, "prediction", preprocess_function, training_args, data_args
        )

    # ---- Metric -------------------------------------------------------------------------
    metric = evaluate.load("xnli", cache_dir=model_args.cache_dir)

    def compute_metrics(p: EvalPrediction):
        """Turn logits into class ids (argmax over axis 1) and score them with the XNLI metric."""
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.argmax(preds, axis=1)
        return metric.compute(predictions=preds, references=p.label_ids)

    # Samples already padded to max length need no padding collator; fp16 pads to a multiple of 8;
    # otherwise `None` lets `Trainer` pick its default (presumably tokenizer-based dynamic padding).
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,  # None unless --do_train
        eval_dataset=eval_dataset,  # None unless --do_eval
        compute_metrics=compute_metrics,
        processing_class=tokenizer,
        data_collator=data_collator,
    )

    # ---- Training -------------------------------------------------------------------------
    if training_args.do_train:
        # An explicit --resume_from_checkpoint wins over the auto-detected one.
        checkpoint = (
            training_args.resume_from_checkpoint
            if training_args.resume_from_checkpoint is not None
            else last_checkpoint
        )
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        # `train_dataset` was already truncated to `max_train_samples`, so its length is the count used.
        metrics["train_samples"] = len(train_dataset)

        trainer.save_model()

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # ---- Evaluation -----------------------------------------------------------------------
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(eval_dataset=eval_dataset)
        metrics["eval_samples"] = len(eval_dataset)

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # ---- Prediction -----------------------------------------------------------------------
    if training_args.do_predict:
        logger.info("*** Predict ***")
        predictions, _, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
        metrics["predict_samples"] = len(predict_dataset)

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        # Same tuple handling as `compute_metrics` before taking the argmax.
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        predictions = np.argmax(predictions, axis=1)
        if trainer.is_world_process_zero():
            _write_predictions(predictions, label_list, training_args.output_dir)
|
|
|
|
# Run the example only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|
|