import math
import random
import string
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.optim import Optimizer
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

@dataclass
class DataCollatorForT5MLM:
    """
    [Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py]
    Data collator used for T5 span-masked language modeling.
    It is made sure that after masking the inputs are of length `data_args.max_seq_length` and targets are also of a fixed length.
    For more information on how T5 span-masked language modeling works, one can take a look
    at the `official paper <https://arxiv.org/pdf/1910.10683.pdf>`__
    or the `official code for preprocessing <https://github.com/google-research/text-to-text-transfer-transformer/blob/master/t5/data/preprocessors.py>`__ .

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        noise_density (:obj:`float`):
            The probability with which to (randomly) mask tokens in the input.
        mean_noise_span_length (:obj:`float`):
            The average span length of the masked tokens.
        input_length (:obj:`int`):
            The expected input length after masking.
        target_length (:obj:`int`):
            The expected target length after masking.
        pad_token_id (:obj:`int`):
            The pad token id of the model.
    """

    tokenizer: AutoTokenizer
    noise_density: float
    mean_noise_span_length: float
    input_length: int
    target_length: int
    pad_token_id: int
    def __call__(self, examples: List[Dict[str, np.ndarray]]) -> BatchEncoding:
        # Stack the per-example arrays into batched numpy arrays.
        batch = BatchEncoding(
            {
                k: np.array([examples[i][k] for i in range(len(examples))])
                for k in examples[0].keys()
            }
        )

        input_ids = batch["input_ids"]
        batch_size, expanded_input_length = input_ids.shape

        # Sample a boolean noise mask per example; its complement marks the target tokens.
        mask_indices = np.asarray(
            [
                self.random_spans_noise_mask(expanded_input_length)
                for _ in range(batch_size)
            ]
        )
        labels_mask = ~mask_indices

        input_ids_sentinel = self.create_sentinel_ids(mask_indices.astype(np.int8))
        labels_sentinel = self.create_sentinel_ids(labels_mask.astype(np.int8))

        batch["input_ids"] = self.filter_input_ids(input_ids, input_ids_sentinel)
        batch["labels"] = self.filter_input_ids(input_ids, labels_sentinel)

        if batch["input_ids"].shape[-1] != self.input_length:
            raise ValueError(
                f"`input_ids` are incorrectly preprocessed. `input_ids` length is {batch['input_ids'].shape[-1]}, but"
                f" should be {self.input_length}."
            )

        if batch["labels"].shape[-1] != self.target_length:
            raise ValueError(
                f"`labels` are incorrectly preprocessed. `labels` length is {batch['labels'].shape[-1]}, but should be"
                f" {self.target_length}."
            )

        batch = {k: torch.from_numpy(v) for k, v in batch.items()}
        return batch
    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
        start_indices[:, 0] = mask_indices[:, 0]

        sentinel_ids = np.where(
            start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices
        )
        # Sentinel tokens occupy the last ids of the vocabulary, in decreasing order.
        sentinel_ids = np.where(
            sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0
        )
        sentinel_ids -= mask_indices - start_indices

        return sentinel_ids
    def filter_input_ids(self, input_ids, sentinel_ids):
        """
        Puts sentinel mask on `input_ids` and fuses consecutive mask tokens into a single mask token by deleting.
        This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
        """
        batch_size = input_ids.shape[0]

        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)

        # Drop the `-1` placeholders (non-initial masked tokens) and append EOS.
        input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
        input_ids = np.concatenate(
            [
                input_ids,
                np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32),
            ],
            axis=-1,
        )
        return input_ids
    def random_spans_noise_mask(self, length):
        """This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .

        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.

        Args:
            length: an int32 scalar (length of the incoming token sequence)
            noise_density: a float - approximate density of output mask
            mean_noise_span_length: a number

        Returns:
            a boolean tensor with shape [length]
        """
        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # Avoid degenerate cases: at least one noise token, at least one non-noise token.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))

        # Avoid the degenerate case of zero spans.
        num_noise_spans = max(num_noise_spans, 1)
        num_nonnoise_tokens = length - num_noise_tokens

        def _random_segmentation(num_items, num_segments):
            """Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add
                up to num_items
            """
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)

            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length

        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(
            num_nonnoise_tokens, num_noise_spans
        )

        # Interleave [non-noise, noise, non-noise, noise, ...] span lengths.
        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1),
            [num_noise_spans * 2],
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]

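# A minimal sanity-check sketch (an addition, not part of the original module):
# it builds a collator around the standard `t5-small` checkpoint and inspects a
# single noise mask. The hyperparameter values (0.15 density, mean span 3.0,
# 512/114 lengths, expanded length 568) are illustrative assumptions.
def _demo_random_spans_noise_mask():
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    collator = DataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=0.15,
        mean_noise_span_length=3.0,
        input_length=512,
        target_length=114,
        pad_token_id=tokenizer.pad_token_id,
    )
    mask = collator.random_spans_noise_mask(568)
    # Roughly 15% of the 568 positions should be noise, in spans of ~3 tokens.
    print(f"{mask.sum()} / {mask.size} positions masked")
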
def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
    """This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .

    [Copied from https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py]
    Training parameters to avoid padding with random_spans_noise_mask.
    When training a model with random_spans_noise_mask, we would like to set the other
    training hyperparameters in a way that avoids padding.
    This function helps us compute these hyperparameters.
    We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
    and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
    This function tells us the required number of tokens in the raw example (for split_tokens())
    as well as the length of the encoded targets. Note that this function assumes
    the inputs and targets will have EOS appended and includes that in the reported length.

    Args:
        inputs_length: an integer - desired length of the tokenized inputs sequence
        noise_density: a float
        mean_noise_span_length: a float
    Returns:
        tokens_length: length of original text in tokens
        targets_length: an integer - length in tokens of encoded targets sequence
    """

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # Inputs contain all non-noise tokens, one sentinel per noise span, and one EOS token;
        # targets contain all noise tokens, one sentinel per noise span, and one EOS token.
        _input_length = num_nonnoise_tokens + num_noise_spans + 1
        _output_length = num_noise_tokens + num_noise_spans + 1
        return _input_length, _output_length

    tokens_length = inputs_length

    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1

    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)

    # Minor hack: with 50% noise density, make the targets length equal to the
    # inputs length, which is more likely to be a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    return tokens_length, targets_length

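# A minimal end-to-end sketch (an addition): derives the expanded input length
# and target length for a 512-token encoder input, then collates one fake batch.
# It assumes a SentencePiece T5 tokenizer whose 100 `extra_ids` act as sentinels;
# the random ids below stand in for real tokenized text.
def _demo_t5_mlm_collation():
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    expanded_length, target_length = compute_input_and_target_lengths(
        inputs_length=512, noise_density=0.15, mean_noise_span_length=3.0
    )
    collator = DataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=0.15,
        mean_noise_span_length=3.0,
        input_length=512,
        target_length=target_length,
        pad_token_id=tokenizer.pad_token_id,
    )
    examples = [
        {"input_ids": np.random.randint(3, 1000, size=(expanded_length,))}
        for _ in range(2)
    ]
    batch = collator(examples)
    # Expected: input_ids -> (2, 512), labels -> (2, target_length).
    print(batch["input_ids"].shape, batch["labels"].shape)
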
class AdamWScale(Optimizer):
    """
    This AdamW implementation is copied from Hugging Face.
    We modified it with Adagrad scaling by the RMS of each weight tensor.

    Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay
    Regularization](https://arxiv.org/abs/1711.05101).

    Parameters:
        params (`Iterable[nn.parameter.Parameter]`):
            Iterable of parameters to optimize or dictionaries defining parameter groups.
        lr (`float`, *optional*, defaults to 1e-3):
            The learning rate to use.
        betas (`Tuple[float, float]`, *optional*, defaults to (0.9, 0.999)):
            Adam's betas parameters (b1, b2).
        eps (`float`, *optional*, defaults to 1e-6):
            Adam's epsilon for numerical stability.
        weight_decay (`float`, *optional*, defaults to 0):
            Decoupled weight decay to apply.
        correct_bias (`bool`, *optional*, defaults to `True`):
            Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`).
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    @staticmethod
    def _rms(tensor):
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def step(self, closure=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]
                beta1, beta2 = group["betas"]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1

                # Decay the first and second moment running averages in place.
                exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group["eps"])

                step_size = group["lr"]
                if group["correct_bias"]:  # No bias correction for Bert.
                    bias_correction1 = 1.0 - beta1 ** state["step"]
                    bias_correction2 = 1.0 - beta2 ** state["step"]
                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                # Adafactor-style scaling: multiply the step by the RMS of the weight tensor.
                step_size = step_size * max(1e-3, self._rms(p.data))

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Decoupled weight decay, applied with the unscaled learning rate.
                if group["weight_decay"] > 0.0:
                    p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"]))

        return loss

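# A short usage sketch (an addition): AdamWScale is a drop-in replacement for
# torch.optim.AdamW, so it slots into an ordinary training step. The toy linear
# model and the hyperparameter values are illustrative assumptions.
def _demo_adamw_scale():
    model = nn.Linear(16, 4)
    optimizer = AdamWScale(model.parameters(), lr=2e-2, weight_decay=1e-2)
    inputs, targets = torch.randn(8, 16), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
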
def tokenize_function(examples, tokenizer, in_length):
    tokenizer_out = tokenizer(
        text=examples["text"],
        return_attention_mask=False,
    )

    input_ids = tokenizer_out["input_ids"]

    # Concatenate all documents into one token stream and split it into chunks
    # of `in_length` tokens, dropping the remainder so every row is full length.
    concatenated_ids = np.concatenate(input_ids)

    total_length = concatenated_ids.shape[0]
    total_length = (total_length // in_length) * in_length

    concatenated_ids = concatenated_ids[:total_length].reshape(-1, in_length)
    result = {"input_ids": concatenated_ids}

    return result

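# A hedged sketch (an addition) of how tokenize_function is typically wired into
# `datasets.map`. The dataset name, its column names, and the chunk length 568
# are assumptions; the original module does not fix any of them.
def _demo_tokenize_dataset():
    from datasets import load_dataset

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    dataset = load_dataset("allenai/c4", "en", split="train", streaming=True)
    return dataset.map(
        lambda examples: tokenize_function(examples, tokenizer, in_length=568),
        batched=True,
        remove_columns=["text", "timestamp", "url"],
    )
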
@dataclass
class DataCollatorForNI:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_source_length: Optional[int] = None
    max_target_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"
    add_task_name: bool = False
    add_task_definition: bool = True
    num_pos_examples: int = 0
    num_neg_examples: int = 0
    add_explanation: bool = False
    tk_instruct: bool = False
    text_only: bool = False
    def __call__(self, batch, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        sources = []
        for instance in batch:
            if self.tk_instruct:
                # Sample one of the valid Tk-Instruct encoding schemas per instance.
                all_valid_encodings = [
                    # Instruction only.
                    {
                        "add_task_name": False,
                        "add_task_definition": True,
                        "num_pos_examples": 0,
                        "num_neg_examples": 0,
                        "add_explanation": False,
                    },
                    # Positive examples only.
                    {
                        "add_task_name": False,
                        "add_task_definition": False,
                        "num_pos_examples": 2,
                        "num_neg_examples": 0,
                        "add_explanation": False,
                    },
                    # Instruction + positive examples.
                    {
                        "add_task_name": False,
                        "add_task_definition": True,
                        "num_pos_examples": 2,
                        "num_neg_examples": 0,
                        "add_explanation": False,
                    },
                    # Instruction + positive and negative examples.
                    {
                        "add_task_name": False,
                        "add_task_definition": True,
                        "num_pos_examples": 2,
                        "num_neg_examples": 2,
                        "add_explanation": False,
                    },
                    # Instruction + positive examples with explanations.
                    {
                        "add_task_name": False,
                        "add_task_definition": True,
                        "num_pos_examples": 2,
                        "num_neg_examples": 0,
                        "add_explanation": True,
                    },
                ]
                encoding_schema = random.choice(all_valid_encodings)
                add_task_name = encoding_schema["add_task_name"]
                add_task_definition = encoding_schema["add_task_definition"]
                num_pos_examples = encoding_schema["num_pos_examples"]
                num_neg_examples = encoding_schema["num_neg_examples"]
                add_explanation = encoding_schema["add_explanation"]
            else:
                add_task_name = self.add_task_name
                add_task_definition = self.add_task_definition
                num_pos_examples = self.num_pos_examples
                num_neg_examples = self.num_neg_examples
                add_explanation = self.add_explanation
            task_input = ""
            task_input += "Now complete the following example -\n"
            task_input += f"Input: {instance['Instance']['input'].strip()}"
            if task_input[-1] not in string.punctuation:
                task_input += "."
            task_input += "\n"
            task_input += "Output: "

            task_name = ""
            if add_task_name:
                task_name += instance["Task"] + ". "

            definition = ""
            if add_task_definition:
                if isinstance(instance["Definition"], list):
                    definition = (
                        "Definition: " + instance["Definition"][0].strip()
                    )
                else:
                    definition = "Definition: " + instance["Definition"].strip()
                if definition[-1] not in string.punctuation:
                    definition += "."
                definition += "\n\n"
            # Add positive examples one at a time, stopping as soon as the
            # assembled source would exceed max_source_length.
            pos_examples = []
            for idx, pos_example in enumerate(
                instance["Positive Examples"][:num_pos_examples]
            ):
                pos_example_str = f" Positive Example {idx + 1} -\n"
                pos_example_str += f"Input: {pos_example['input'].strip()}"
                if pos_example_str[-1] not in string.punctuation:
                    pos_example_str += "."
                pos_example_str += "\n"
                pos_example_str += f" Output: {pos_example['output'].strip()}"
                if pos_example_str[-1] not in string.punctuation:
                    pos_example_str += "."
                pos_example_str += "\n"
                if add_explanation and "explanation" in pos_example:
                    pos_example_str += (
                        f" Explanation: {pos_example['explanation'].strip()}"
                    )
                    if pos_example_str[-1] not in string.punctuation:
                        pos_example_str += "."
                    pos_example_str += "\n"
                pos_example_str += "\n"
                if (
                    len(
                        self.tokenizer(
                            definition
                            + " ".join(pos_examples)
                            + pos_example_str
                            + task_input
                        )["input_ids"]
                    )
                    <= self.max_source_length
                ):
                    pos_examples.append(pos_example_str)
                else:
                    break
            neg_examples = []
            for idx, neg_example in enumerate(
                instance["Negative Examples"][:num_neg_examples]
            ):
                neg_example_str = f" Negative Example {idx + 1} -\n"
                neg_example_str += f"Input: {neg_example['input'].strip()}"
                if neg_example_str[-1] not in string.punctuation:
                    neg_example_str += "."
                neg_example_str += "\n"
                neg_example_str += f" Output: {neg_example['output'].strip()}"
                if neg_example_str[-1] not in string.punctuation:
                    neg_example_str += "."
                neg_example_str += "\n"
                if add_explanation and "explanation" in neg_example:
                    neg_example_str += (
                        f" Explanation: {neg_example['explanation'].strip()}"
                    )
                    if neg_example_str[-1] not in string.punctuation:
                        neg_example_str += "."
                    neg_example_str += "\n"
                neg_example_str += "\n"
                if (
                    len(
                        self.tokenizer(
                            definition
                            + " ".join(pos_examples)
                            + " ".join(neg_examples)
                            + neg_example_str
                            + task_input
                        )["input_ids"]
                    )
                    <= self.max_source_length
                ):
                    neg_examples.append(neg_example_str)
                else:
                    break
            source = (
                task_name
                + definition
                + "".join(pos_examples)
                + "".join(neg_examples)
                + task_input
            )
            tokenized_source = self.tokenizer(source)["input_ids"]
            if len(tokenized_source) <= self.max_source_length:
                sources.append(source)
            else:
                # Truncate in token space, then decode back to a valid string.
                sources.append(
                    self.tokenizer.decode(
                        tokenized_source[: self.max_source_length],
                        skip_special_tokens=True,
                    )
                )

        if self.text_only:
            model_inputs = {"inputs": sources}
        else:
            model_inputs = self.tokenizer(
                sources,
                max_length=self.max_source_length,
                padding=self.padding,
                return_tensors=return_tensors,
                truncation=True,
                pad_to_multiple_of=self.pad_to_multiple_of,
            )
| | if "output" in batch[0]["Instance"] and batch[0]["Instance"]["output"]:
|
| |
|
| | labels = [random.choice(ex["Instance"]["output"]) for ex in batch]
|
| | if self.text_only:
|
| | model_inputs["labels"] = labels
|
| | else:
|
| | labels = self.tokenizer(
|
| | labels,
|
| | max_length=self.max_target_length,
|
| | padding=self.padding,
|
| | return_tensors=self.return_tensors,
|
| | truncation=True,
|
| | pad_to_multiple_of=self.pad_to_multiple_of,
|
| | )
|
| | label_mask = labels["attention_mask"].bool()
|
| | model_inputs["labels"] = labels["input_ids"].masked_fill(
|
| | ~label_mask, self.label_pad_token_id
|
| | )
|
| | else:
|
| | model_inputs["labels"] = None
|
| |
|
| | return model_inputs
|
| |
|
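# A hedged usage sketch (an addition): DataCollatorForNI expects
# Super-NaturalInstructions style examples (dicts with "Task", "Definition",
# "Instance", "Positive Examples", "Negative Examples"). The toy instance below
# is an illustrative assumption, not data shipped with this module.
def _demo_ni_collation():
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    collator = DataCollatorForNI(
        tokenizer=tokenizer,
        max_source_length=1024,
        max_target_length=128,
        num_pos_examples=1,
    )
    instance = {
        "Task": "task000_toy_arithmetic",
        "Definition": ["Answer the arithmetic question."],
        "Positive Examples": [
            {"input": "2 + 2?", "output": "4", "explanation": "Basic addition."}
        ],
        "Negative Examples": [],
        "Instance": {"input": "3 + 5?", "output": ["8"]},
    }
    batch = collator([instance])
    print(batch["input_ids"].shape, batch["labels"].shape)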