import math
import torch
import torch.nn.functional as F
import transformers
import peft
from pathlib import Path
from typing import List, Mapping, NewType, Optional, Tuple, Union
from tqdm import tqdm

from transformers import BatchEncoding

from lm_eval import utils
from lm_eval.base import BaseLM

TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]

_DeviceMapping = NewType("DeviceMapping", Mapping[str, Union[int, str, torch.device]])


def _get_accelerate_args(
    device_map_option: Optional[str] = "auto",
    max_memory_per_gpu: Optional[Union[int, str]] = None,
    max_cpu_memory: Optional[Union[int, str]] = None,
    offload_folder: Optional[str] = "./offload",
) -> dict:
    """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
    max_memory = {}
    if max_memory_per_gpu is not None:
        max_memory_per_gpu_map = {
            device_idx: max_memory_per_gpu
            for device_idx in range(torch.cuda.device_count())
        }
        max_memory.update(max_memory_per_gpu_map)
    if max_cpu_memory is not None:
        max_memory["cpu"] = max_cpu_memory

    args = {}
    if max_memory:
        args["max_memory"] = max_memory
    args["device_map"] = device_map_option
    args["offload_folder"] = offload_folder
    return args
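
# A minimal sketch of what `_get_accelerate_args` returns (assuming a machine
# with two visible GPUs; the per-GPU cap is repeated for each device index):
#
#     _get_accelerate_args("auto", "20GiB", "48GiB", "./offload")
#     # -> {
#     #        "max_memory": {0: "20GiB", 1: "20GiB", "cpu": "48GiB"},
#     #        "device_map": "auto",
#     #        "offload_folder": "./offload",
#     #    }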


def _get_dtype(
    dtype: Union[str, torch.dtype], config: Optional[transformers.AutoConfig] = None
) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible."""
    if dtype is None and config is not None:
        _torch_dtype = config.torch_dtype
    elif isinstance(dtype, str) and dtype != "auto":
        # Convert `str` dtype args to their torch equivalent: "float16" -> torch.float16.
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype
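
# For instance (a sketch; `config` here stands for any loaded AutoConfig):
#
#     _get_dtype("bfloat16")        # -> torch.bfloat16
#     _get_dtype("auto")            # -> "auto" (passed through to `from_pretrained`)
#     _get_dtype(None, config)      # -> config.torch_dtype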


class HuggingFaceAutoLM(BaseLM):
    AUTO_CONFIG_CLASS: transformers.AutoConfig = transformers.AutoConfig
    AUTO_TOKENIZER_CLASS: transformers.AutoTokenizer = transformers.AutoTokenizer
    AUTO_MODEL_CLASS: transformers.AutoModel = None
    AUTO_PEFT_CLASS: peft.PeftModel = None

    # Fallback max sequence length for when no `max_length` is provided and no
    # max length setting is found in the model config or tokenizer.
    _DEFAULT_MAX_LENGTH: int = 2048

    def __init__(
        self,
        pretrained: str,
        quantized: Optional[Union[bool, str]] = None,
        tokenizer: Optional[str] = None,
        subfolder: Optional[str] = None,
        revision: Optional[str] = "main",
        batch_size: Optional[int] = 1,
        max_gen_toks: Optional[int] = 256,
        max_length: Optional[int] = None,
        add_special_tokens: Optional[bool] = None,
        use_accelerate: Optional[bool] = False,
        device_map_option: Optional[str] = "auto",
        max_memory_per_gpu: Optional[Union[int, str]] = None,
        max_cpu_memory: Optional[Union[int, str]] = None,
        offload_folder: Optional[str] = "./offload",
        dtype: Optional[Union[str, torch.dtype]] = None,
        device: Optional[Union[int, str]] = "cuda",
        peft: Optional[str] = None,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        use_fast: Optional[bool] = True,
        gptq_use_triton: Optional[bool] = False,
    ):
| | """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation. |
| | Args: |
| | pretrained (str): |
| | The HuggingFace Hub model ID name or the path to a pre-trained |
| | model to load. This is effectively the `pretrained_model_name_or_path` |
| | argument of `from_pretrained` in the HuggingFace `transformers` API. |
| | quantized (str or True, optional, defaults to None): |
| | File name of a GPTQ quantized model to load. Set to `True` to use the |
| | default name of the quantized model. |
| | add_special_tokens (bool, optional, defaults to True): |
| | Whether to add special tokens to the input sequences. If `None`, the |
| | default value will be set to `True` for seq2seq models (e.g. T5) and |
| | `False` for causal models. |
| | WARNING: Evaluating causal models with `add_special_tokens=True` is |
| | currently __not__ supported. |
| | > Large model loading `accelerate` arguments |
| | use_accelerate (bool, optional, defaults to False): |
| | If True, uses the `accelerate` library to load a large model across |
| | multiple devices. |
| | device_map_option (str, optional, defaults to "auto"): |
| | The device map option to use when loading the model with |
| | `accelerate`. |
| | Options: |
| | "auto", "balanced", "balanced_low_0", "sequential" |
| | See the `accelerate` docs for more details on these options: |
| | https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.device_map |
| | max_memory_per_gpu (Union[int, str], optional, defaults to None): |
| | The maximum memory available for each GPU in bytes as `int` or in |
| | the format f"{significand}{unit_symbol}" where {unit_symbol} is |
| | any of ["GB", "MB", "GIB", "MIB"]. Refer to the `max_memory` arg in |
| | the "Parameters for big model inference" section of the following |
| | docs: |
| | https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.max_memory |
| | max_cpu_memory (Union[int, str], optional, defaults to None): |
| | The maximum available CPU RAM in bytes as `int` or in the format |
| | f"{significand}{unit_symbol}" where {unit_symbol} is any of |
| | ["GB", "MB", "GIB", "MIB"]. Refer to the `max_memory` arg in the |
| | "Parameters for big model inference" section of the following docs: |
| | https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.max_memory |
| | offload_folder (str, optional, defaults to "./offload"): |
| | The folder to offload weights into if `device_map` contains any |
| | "disk" value. |
| | dtype (Union[str, torch.dtype], optional, defaults to None):): |
| | Converts the model weights to `dtype`, if specified. Strings get |
| | converted to `torch.dtype` objects (e.g. `float16` -> `torch.float16`). |
| | Use `dtype="auto"` to derive the type from the model’s weights. |
| | peft (str, optional, defaults to None): |
| | Path of the adapter weights to load from Huggingface. This will usually |
| | include a directory that includes the files `adapter_config.json` and |
| | `adapter_model.bin`. Compatible with [PEFT](https://github.com/huggingface/peft) |
| | load_in_8bit (bool, optional, defaults to False): |
| | If True, will convert the loaded model into mixed-8bit quantized model. See: |
| | https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.load_in_8bit |
| | trust_remote_code (bool, optional, defaults to False): |
| | If True, will trust the remote code when loading the model. |
| | use_fast (bool, optional, defaults to True): |
| | If True, will use the fast tokenizer when loading the model. |
| | gptq_use_triton (bool, optional, defaults to False): |
| | Use Triton for GPTQ inference. |
| | """ |
        super().__init__()

        assert isinstance(pretrained, str)
        assert isinstance(device, str)
        assert isinstance(batch_size, int)
        if (
            add_special_tokens is not None
            and self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM
        ):
            # TODO: Support evaluating causal models with special tokens. This is
            # currently not possible because the causal-LM loglikelihood path
            # tokenizes contexts and continuations separately without special
            # tokens and then concatenates them, which assumes no special tokens
            # are inserted in between.
            assert (
                not add_special_tokens
            ), "Evaluating causal models with `add_special_tokens=True` is currently not supported."

        self._batch_size = batch_size
        self._max_gen_toks = max_gen_toks
        self._max_length = max_length
        self._config = self.AUTO_CONFIG_CLASS.from_pretrained(
            pretrained,
            trust_remote_code=trust_remote_code,
            revision=revision + ("/" + subfolder if subfolder is not None else ""),
        )

        self._add_special_tokens = add_special_tokens
        self.tokenizer = self._create_auto_tokenizer(
            pretrained=pretrained,
            revision=revision,
            subfolder=subfolder,
            tokenizer=tokenizer,
            use_fast=use_fast,
        )
        self.tokenizer.model_max_length = self.max_length

        model_kwargs = {}
        if use_accelerate:
            model_kwargs = _get_accelerate_args(
                device_map_option,
                max_memory_per_gpu,
                max_cpu_memory,
                offload_folder,
            )
        model_kwargs["load_in_8bit"] = load_in_8bit
        self.model = self._create_auto_model(
            pretrained=pretrained,
            quantized=quantized,
            trust_remote_code=trust_remote_code,
            revision=revision,
            subfolder=subfolder,
            torch_dtype=_get_dtype(dtype, self._config),
            gptq_use_triton=gptq_use_triton,
            **model_kwargs,
        )
        # Note: the PEFT adapter path can differ from the pretrained model path.
        if peft is not None:
            self.model = self._create_auto_model_peft(
                model=self.model,
                peft=peft,
                revision=revision,
                subfolder=subfolder,
                torch_dtype=_get_dtype(dtype, self._config),
                **model_kwargs,
            )
        self.model.eval()
        torch.set_grad_enabled(False)

        self._device = device
        if use_accelerate and "lm_head" in self.model.hf_device_map:
            # `accelerate` can place `lm_head` weights on a device other than the
            # user-specified one, so force `self._device` to match `lm_head`'s.
            self._device = self.model.hf_device_map["lm_head"]
        if not use_accelerate and not load_in_8bit:
            try:
                self.model.to(self._device)
            except Exception:
                print(
                    "Failed to place model onto specified device. This may be "
                    "because the model is quantized via `bitsandbytes`. If the "
                    "desired GPU is being used, this message is safe to ignore."
                )

    def _create_auto_model(
        self,
        *,
        pretrained: str,
        quantized: Optional[Union[bool, str]] = None,
        revision: str,
        subfolder: str,
        device_map: Optional[Union[str, _DeviceMapping]] = None,
        max_memory: Optional[dict] = None,
        offload_folder: Optional[str] = None,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        gptq_use_triton: Optional[bool] = False,
    ) -> transformers.AutoModel:
        """Returns a pre-trained pytorch model from a pre-trained model configuration."""
        if quantized is None:
            model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision + ("/" + subfolder if subfolder is not None else ""),
                device_map=device_map,
                max_memory=max_memory,
                offload_folder=offload_folder,
                load_in_8bit=load_in_8bit,
                trust_remote_code=trust_remote_code,
                torch_dtype=torch_dtype,
            )
        else:
            from auto_gptq import AutoGPTQForCausalLM

            model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                model_basename=None if quantized is True else Path(quantized).stem,
                device_map=device_map,
                max_memory=max_memory,
                trust_remote_code=trust_remote_code,
                use_safetensors=True
                if quantized is True
                else quantized.endswith(".safetensors"),
                use_triton=gptq_use_triton,
                warmup_triton=gptq_use_triton,
            )
        return model

    def _create_auto_model_peft(
        self,
        *,
        model: transformers.PreTrainedModel,
        peft: str,
        revision: str,
        subfolder: str,
        device_map: Optional[Union[str, _DeviceMapping]] = None,
        max_memory: Optional[dict] = None,
        offload_folder: Optional[str] = None,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
    ):
        """Wraps `model` with the PEFT adapter weights loaded from `peft`."""
        model = self.AUTO_PEFT_CLASS.from_pretrained(
            model,
            peft,
            revision=revision + ("/" + subfolder if subfolder is not None else ""),
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
        )
        return model

    def _create_auto_tokenizer(
        self,
        *,
        pretrained: str,
        revision: str,
        subfolder: str,
        tokenizer: Optional[str] = None,
        use_fast: Optional[bool] = True,
    ) -> transformers.PreTrainedTokenizer:
        """Returns a pre-trained tokenizer from a pre-trained tokenizer configuration."""
        tokenizer = self.AUTO_TOKENIZER_CLASS.from_pretrained(
            pretrained if tokenizer is None else tokenizer,
            revision=revision + ("/" + subfolder if subfolder is not None else ""),
            use_fast=use_fast,
        )
        # Use the EOS token for padding when the model defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    @property
    def add_special_tokens(self) -> bool:
        """Whether to include special tokens in encoded text. This should be
        determined by whether or not the model was trained with special tokens.
        TODO: Remove these conditionals once HuggingFace supports a way to
        check whether or not an arbitrary model was trained with special tokens.
        """
        if self._add_special_tokens is not None:
            return self._add_special_tokens
        elif self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM:
            return False
        elif self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM:
            return True
        else:
            raise ValueError(
                "Could not determine `add_special_tokens` value from the model "
                "class. Set to `True` or `False` depending on whether the model "
                "was pre-trained with special tokens."
            )

    @property
    def eot_token(self) -> str:
        return self.tokenizer.eos_token

    @property
    def eot_token_id(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    def max_gen_toks(self) -> int:
        return self._max_gen_toks

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model.
        NOTE: Different model configurations have different max sequence length
        attribute names.
            - n_positions: (CTRLConfig)
            - max_position_embeddings: (BartConfig, RoFormerConfig)
            - n_ctx: (GPT2Config)
        NOTE: For relative position encoded models you should specify the max
        sequence length of the model in the constructor via `max_length`.
        """
        if self._max_length is not None:
            return self._max_length
        # Try to get the sequence length from the model config.
        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        for attr in seqlen_config_attrs:
            if hasattr(self._config, attr):
                return getattr(self._config, attr)
        if hasattr(self.tokenizer, "model_max_length"):
            return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH
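
    # Resolution order, in short: an explicit `max_length` constructor argument
    # wins; otherwise the first of `n_positions`, `max_position_embeddings`, or
    # `n_ctx` found on the config is used (a GPT-2 config, for example, resolves
    # through `n_positions`); then the tokenizer's `model_max_length`; and
    # finally `_DEFAULT_MAX_LENGTH`.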

    @property
    def batch_size(self) -> int:
        # TODO: Add adaptive batch size.
        return self._batch_size

    @property
    def device(self) -> Union[int, str, torch.device]:
        return self._device

    def tok_encode(self, string: str) -> TokenSequence:
        # TODO: Merge `tok_encode_batch` here.
        return self.tokenizer.encode(string, add_special_tokens=self.add_special_tokens)

    def tok_encode_batch(self, strings: List[str]) -> TokenSequence:
        return self.tokenizer(
            strings,
            padding=True,
            add_special_tokens=self.add_special_tokens,
            return_tensors="pt",
        )

    def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
        return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)

    def greedy_until(
        self, requests: List[Tuple[str, Union[List[str], str]]]
    ) -> List[str]:
        def _collate(x):
            tokens = self.tok_encode(x[0])
            return len(tokens), x[0]

        results = []
        reorder = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(reorder.get_reordered(), disable=False), self.batch_size
        ):
            context = [c[0] for c in chunk]
            request_args = chunk[0][1]
            stop_sequences = (
                request_args if isinstance(request_args, list) else [request_args]
            )
            max_generation_length = self._max_gen_toks

            assert (
                isinstance(max_generation_length, int) or max_generation_length is None
            )
            assert isinstance(stop_sequences, list) or stop_sequences is None

            # TODO: Find a better way to handle stop sequences for 0-shot.
            if stop_sequences is None:
                until = [self.eot_token]
            else:
                until = stop_sequences + [self.eot_token]

            if max_generation_length is None:
                max_tokens = self.max_gen_toks
            else:
                max_tokens = max_generation_length

            token_context = self.tok_encode_batch(context)

            responses = self._model_generate(
                inputs=token_context,
                max_tokens=max_tokens,
                stop=until,
            )
            responses = self.tok_decode(responses.tolist())

            for response in responses:
                # Ensure the generated responses do not contain the stop sequences.
                for term in until:
                    response = response.split(term)[0]
                # Partial caching.
                self.cache_hook.add_partial("greedy_until", (context, until), response)
                results.append(response)
        return reorder.get_original(results)
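
    # Request format sketch (hypothetical values): each request pairs a context
    # string with a stop sequence or list of stop sequences, e.g.
    #
    #     lm.greedy_until([("Q: What is 2 + 2?\nA:", ["\n"])])
    #     # -> [" 4"]  (greedy completion, truncated at the first stop sequence)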


class AutoCausalLM(HuggingFaceAutoLM):
    """Causal language modeling.
    You can find a set of supported models in the HF documentation:
    https://huggingface.co/docs/transformers/main/model_doc/auto#transformers.AutoModelForCausalLM
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
    AUTO_PEFT_CLASS = peft.PeftModel

    def _create_auto_tokenizer(
        self,
        *,
        pretrained: str,
        revision: str,
        subfolder: str,
        tokenizer: Optional[str] = None,
        use_fast: Optional[bool] = True,
    ) -> transformers.PreTrainedTokenizer:
        tokenizer = super()._create_auto_tokenizer(
            pretrained=pretrained,
            revision=revision,
            subfolder=subfolder,
            tokenizer=tokenizer,
            use_fast=use_fast,
        )
        # Left-pad so that, for every row in a batch, the generation immediately
        # follows a contiguous context.
        tokenizer.padding_side = "left"
        return tokenizer

    def _model_call(
        self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
    ) -> TokenSequence:
        return self.model(inputs)["logits"]

    def _model_generate(
        self,
        inputs: transformers.BatchEncoding,
        max_tokens: int,
        stop: Optional[List[str]] = None,
    ) -> TokenSequence:
        # Ensure that the context does not encroach into the space
        # reserved for the generation.
        input_ids = inputs["input_ids"][:, self.max_gen_toks - self.max_length :]
        attention_mask = inputs["attention_mask"][
            :, self.max_gen_toks - self.max_length :
        ]
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0]
        )

        generations = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # GPT-style models require the `generate` `max_length` arg to include
            # the context length, so we instead set `max_new_tokens`, which counts
            # only the newly generated tokens.
            max_new_tokens=max_tokens,
            stopping_criteria=stopping_criteria,
            do_sample=False,
        )
        return utils.select_continuation_from_batch_left_padding(
            generations, max_context_size=inputs["input_ids"].size(1)
        )
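
# A minimal usage sketch (assumes a small Hub model such as "gpt2" fits on the
# target device; illustrative, not part of the evaluation API):
#
#     lm = AutoCausalLM(pretrained="gpt2", device="cpu", batch_size=4)
#     print(lm.greedy_until([("The quick brown fox", ["\n"])]))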


class AutoSeq2SeqLM(HuggingFaceAutoLM):
    """Seq2Seq language modeling.
    You can find a set of supported models in the following documentation:
    https://huggingface.co/docs/transformers/main/model_doc/auto#transformers.AutoModelForSeq2SeqLM
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
    AUTO_PEFT_CLASS = peft.PeftModel

    @property
    def max_length(self) -> int:
        """Return the maximum sequence length of the model.
        TODO: Currently only works for relative position encoded Seq2Seq models.
        """
        if self._max_length is not None:
            return self._max_length
        return self._DEFAULT_MAX_LENGTH

    def loglikelihood(
        self, requests: List[Tuple[str, str]]
    ) -> List[Tuple[float, bool]]:
        new_requests = []
        for chunk in utils.chunks(requests, self.batch_size):
            context, continuation = zip(*chunk)

            # Fill empty contexts with the EOT token.
            context = [
                f"{self.eot_token}" if len(text) == 0 else text for text in context
            ]
            context_enc = self.tok_encode_batch(context)
            for key in context_enc:
                context_enc[key] = context_enc[key][:, -self.max_length :]

            # Remove leading whitespace introduced by the default
            # `text_target_separator`, since the context and continuation
            # will not be concatenated as a single (decoder) input.
            continuation = [text.lstrip() for text in continuation]
            continuation_enc = self.tok_encode_batch(list(continuation))
            for key in continuation_enc:
                continuation_enc[key] = continuation_enc[key][:, -self.max_length :]

            new_requests.append(
                ((context, continuation), context_enc, continuation_enc)
            )
        return self._loglikelihood_tokens(new_requests)
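
    # Request format sketch (illustrative numbers only): each request is a
    # (context, continuation) pair, and each result is a
    # (log-likelihood, is-greedy) tuple for the continuation given the context:
    #
    #     lm.loglikelihood([("The capital of France is", " Paris")])
    #     # -> [(-1.23, True)]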

    def loglikelihood_rolling(self, requests: List[Tuple[str, str]]) -> List[float]:
        loglikelihoods = []
        for (string,) in tqdm(requests):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )
            contexts, conts = utils.split_and_pad_windows(
                rolling_token_windows,
                pad_token_id=self.eot_token_id,
                max_seq_len=self.max_length,
            )
            # Manually create BatchEncoding tensors with attention masks as
            # expected by `self._model_call` in `self._loglikelihood_tokens`.
            contexts_enc = torch.Tensor(contexts).long()
            contexts_enc = transformers.tokenization_utils_base.BatchEncoding(
                {
                    "input_ids": contexts_enc,
                    "attention_mask": (contexts_enc != self.eot_token_id).long(),
                }
            )
            conts_enc = torch.Tensor(conts).long()
            conts_enc = transformers.tokenization_utils_base.BatchEncoding(
                {
                    "input_ids": conts_enc,
                    "attention_mask": (conts_enc != self.eot_token_id).long(),
                }
            )
            # Score each (context, continuation) window, then sum the per-window
            # log-likelihoods to get the log-likelihood of the full string.
            rolling_token_windows_request = [
                ((contexts, conts), contexts_enc, conts_enc)
            ]
            string_nll = self._loglikelihood_tokens(
                rolling_token_windows_request, disable_tqdm=True
            )
            string_nll = [x[0] for x in string_nll]
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
        return loglikelihoods

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], TokenSequence, TokenSequence]],
        disable_tqdm: Optional[bool] = False,
    ) -> List[Tuple[float, bool]]:
        results = []
        for chunk in tqdm(requests, total=len(requests), disable=disable_tqdm):
            cache_keys, inputs_tokens, targets_tokens = chunk
            inputs_tokens = inputs_tokens.to(self.device)
            targets_tokens = targets_tokens.to(self.device)
            outputs = self._model_call(inputs=inputs_tokens, labels=targets_tokens)
            log_softmaxes = F.log_softmax(outputs.logits, dim=-1)

            output_iterator = zip(
                zip(cache_keys[0], cache_keys[1]),
                log_softmaxes,
                targets_tokens["input_ids"],
                targets_tokens["attention_mask"],
            )
            for cache_key, log_softmax, target_tokens, target_mask in output_iterator:
                length = target_mask.sum()
                log_softmax = log_softmax[:length]
                target_tokens = target_tokens[:length]
                greedy_tokens = log_softmax.argmax(dim=-1)
                max_equal = (greedy_tokens == target_tokens).all()
                target_logits = torch.gather(
                    log_softmax, 1, target_tokens.unsqueeze(-1)
                ).squeeze(-1)
                answer = (float(target_logits.sum()), bool(max_equal))
                results.append(answer)
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return results

    def _model_call(
        self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
    ) -> TokenSequence:
        return self.model(**inputs, labels=labels["input_ids"])

    def _model_generate(
        self,
        inputs: transformers.BatchEncoding,
        max_tokens: int,
        stop: Optional[List[str]] = None,
    ) -> TokenSequence:
        input_ids = inputs["input_ids"][:, -self.max_length :].to(self.device)
        attention_mask = inputs["attention_mask"][:, -self.max_length :].to(self.device)

        # Assume the decoder input always starts with a single start token, so
        # `initial_decoder_input_length` is 1; this holds for existing HF seq2seq
        # models. (One could verify by generating a single token and checking how
        # many tokens `generate` prepends to `decoder_input_ids`.)
        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, 1, input_ids.shape[0]
        )

        generations = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            stopping_criteria=stopping_criteria,
            do_sample=False,
        )
        return generations


class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ):
        self.initial_decoder_input_length = initial_decoder_input_length
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_id_len = len(self.sequence_ids)
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, compare only the last n generated tokens, where n is
        # the token length of the stop sequence.
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
            :, -self.sequence_id_len :
        ]

        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)

        for i, done in enumerate(self.done_tracker):
            if not done:
                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
        return False not in self.done_tracker


def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    return transformers.StoppingCriteriaList(
        [
            MultiTokenEOSCriteria(
                sequence, tokenizer, initial_decoder_input_length, batch_size
            )
            for sequence in stop_sequences
        ]
    )
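
# Usage sketch (hypothetical `tokenizer`, `model`, and `input_ids`): `generate`
# halts once one of the stop strings has appeared in every row of the batch.
#
#     criteria = stop_sequences_criteria(
#         tokenizer, ["\n\n"], input_ids.shape[1], input_ids.shape[0]
#     )
#     model.generate(input_ids=input_ids, stopping_criteria=criteria, max_new_tokens=64)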