| import copy |
| import os |
| from datetime import timedelta |
| from pathlib import Path |
| from typing import Dict, List, Literal, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import transformers |
| from accelerate import ( |
| Accelerator, |
| InitProcessGroupKwargs, |
| find_executable_batch_size, |
| ) |
| from accelerate.utils import get_max_memory |
| from huggingface_hub import HfApi |
| from packaging import version |
| from peft import PeftModel |
| from peft import __version__ as PEFT_VERSION |
| from tqdm import tqdm |
| from transformers.models.auto.modeling_auto import ( |
| MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, |
| MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, |
| ) |
|
|
| from lm_eval import utils |
| from lm_eval.api.instance import Instance |
| from lm_eval.api.model import TemplateLM |
| from lm_eval.api.registry import register_model |
| from lm_eval.models.utils import ( |
| Collator, |
| clear_torch_cache, |
| configure_pad_token, |
| get_dtype, |
| pad_and_concat, |
| stop_sequences_criteria, |
| ) |
|
|
|
|
| eval_logger = utils.eval_logger |
|
|
|
|
| @register_model("hf-auto", "hf", "huggingface") |
| class HFLM(TemplateLM): |
| """ |
| An abstracted Hugging Face model class. Supports models loadable via both the |
| `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes. |
| |
| Supports data-parallel multi-GPU with HF Accelerate. |
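| |
| Example (a minimal sketch; any HF hub model id works in place of "gpt2"): |
| |
| >>> lm = HFLM(pretrained="gpt2", device="cuda", batch_size=8) |
| >>> lm.tok_encode("lm-evaluation-harness") |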
| """ |
|
|
| AUTO_MODEL_CLASS = None |
| _DEFAULT_MAX_LENGTH = 2048 |
|
|
| def __init__( |
| self, |
| pretrained: Union[str, transformers.PreTrainedModel], |
| backend: Optional[Literal["default", "causal", "seq2seq"]] = "default", |
| |
| revision: Optional[str] = "main", |
| subfolder: Optional[str] = None, |
| tokenizer: Optional[ |
| Union[ |
| str, |
| transformers.PreTrainedTokenizer, |
| transformers.PreTrainedTokenizerFast, |
| ] |
| ] = None, |
| truncation: Optional[bool] = False, |
| logits_cache: bool = True, |
| max_length: Optional[int] = None, |
| device: Optional[str] = "cuda", |
| dtype: Optional[Union[str, torch.dtype]] = "auto", |
| batch_size: Optional[Union[int, str]] = 1, |
| max_batch_size: Optional[int] = 64, |
| trust_remote_code: Optional[bool] = False, |
| use_fast_tokenizer: Optional[bool] = True, |
| add_bos_token: Optional[bool] = False, |
| prefix_token_id: Optional[int] = None, |
| |
| |
| parallelize: Optional[bool] = False, |
| max_memory_per_gpu: Optional[Union[int, str]] = None, |
| max_cpu_memory: Optional[Union[int, str]] = None, |
| offload_folder: Optional[Union[str, os.PathLike]] = "./offload", |
| |
| peft: Optional[str] = None, |
| delta: Optional[str] = None, |
| autogptq: Optional[Union[bool, str]] = False, |
| **kwargs, |
| ) -> None: |
| super().__init__() |
|
|
| |
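| # optionally: take in an already-initialized transformers.PreTrainedModel |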
| if not isinstance(pretrained, str): |
| eval_logger.warning( |
| "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way." |
| ) |
| assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`" |
| self._model = pretrained |
| self._device = self._model.device |
| self._config = self._model.config |
| gpus = 0 |
|
|
| else: |
| assert isinstance(device, str) |
| assert isinstance(pretrained, str) |
| assert isinstance(batch_size, (int, str)) |
|
|
| gpus = torch.cuda.device_count() |
| accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) |
| accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) |
| if accelerator.num_processes > 1: |
| self.accelerator = accelerator |
|
|
| if "npu" in accelerator.device.type: |
| gpus = torch.npu.device_count() |
|
|
| |
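| # single-process run with no model parallelism: place the model on one requested device |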
| if not (parallelize or accelerator.num_processes > 1): |
| |
| device_list = set( |
| ["cuda", "cpu"] |
| + [f"cuda:{i}" for i in range(gpus)] |
| + ["mps", "mps:0"] |
| + [f"npu:{i}" for i in range(gpus)] |
| ) |
| if device and device in device_list: |
| self._device = torch.device(device) |
| eval_logger.info(f"Using device '{device}'") |
| if device in ("mps", "mps:0") and version.parse( |
| torch.__version__ |
| ) < version.parse("2.1"): |
| raise RuntimeError( |
| f"mps requires torch >= 2.1. You have {torch.__version__}" |
| ) |
| else: |
| eval_logger.info("Device not specified") |
| eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}") |
| self._device = ( |
| torch.device("cuda") |
| if torch.cuda.is_available() |
| else torch.device("cpu") |
| ) |
| else: |
| if device != "cuda": |
| eval_logger.info( |
| f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model." |
| ) |
| |
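| # when launched via accelerate (or with parallelize=True), bind to the device accelerate assigned this process |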
| self._device = ( |
| self.accelerator.device |
| if hasattr(self, "accelerator") |
| else torch.device(device) |
| ) |
|
|
| revision = str(revision) |
| |
| revision = revision + ("/" + subfolder if subfolder is not None else "") |
|
|
| self._get_config( |
| pretrained, |
| revision=revision, |
| trust_remote_code=trust_remote_code, |
| ) |
|
|
| |
| self._get_backend( |
| config=self.config, backend=backend, trust_remote_code=trust_remote_code |
| ) |
|
|
| |
| self._create_tokenizer( |
| pretrained, |
| tokenizer, |
| revision=revision, |
| trust_remote_code=trust_remote_code, |
| use_fast_tokenizer=use_fast_tokenizer, |
| ) |
|
|
| |
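| # if a string was passed, initialize the model ourselves; a passed-in PreTrainedModel was already stored above |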
| if isinstance(pretrained, str): |
| self._create_model( |
| pretrained=pretrained, |
| revision=revision, |
| dtype=dtype, |
| trust_remote_code=trust_remote_code, |
| parallelize=parallelize, |
| gpus=gpus, |
| max_memory_per_gpu=max_memory_per_gpu, |
| max_cpu_memory=max_cpu_memory, |
| offload_folder=offload_folder, |
| peft=peft, |
| delta=delta, |
| autogptq=autogptq, |
| **kwargs, |
| ) |
|
|
| |
| if isinstance(self.model, torch.nn.Module): |
| self.model.eval() |
| self.model.tie_weights() |
|
|
| self.truncation = truncation |
| self.logits_cache = logits_cache |
| self.vocab_size = self.tokenizer.vocab_size |
| |
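| # select (or, for some models, create) a pad token to use for batched padding |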
| self.tokenizer = configure_pad_token(self.tokenizer, model_config=self.config) |
|
|
| self.add_bos_token = add_bos_token |
| if "gemma" in getattr(self.config, "model_type", ""): |
| self.add_bos_token = True |
| eval_logger.info( |
| f"Model type is '{self.config.model_type}', part of the Gemma family--a BOS token will be used as Gemma underperforms without it." |
| ) |
|
|
| self._max_length = max_length |
| self.pretrained = pretrained |
| self.delta = delta |
| self.peft = peft |
| self.revision = revision |
| self.batch_schedule = 1 |
| self.batch_sizes = {} |
| self.max_batch_size = max_batch_size |
|
|
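| # batch_size may be "auto" or "auto:N", where N is how many times the batch size is re-detected over the run |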
| if str(batch_size).startswith("auto"): |
| batch_size = batch_size.split(":") |
| self.batch_size_per_gpu = batch_size[0] |
| self.batch_schedule = float(batch_size[1]) if len(batch_size) > 1 else 1 |
| else: |
| self.batch_size_per_gpu = int(batch_size) |
|
|
| if isinstance(pretrained, str): |
| if gpus >= 1 or str(self.device) == "mps": |
| |
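| # move the model onto the requested device, unless HF Accelerate / a device_map is managing placement |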
| if not (parallelize or autogptq or hasattr(self, "accelerator")): |
| |
| |
| |
| try: |
| self.model.to(self.device) |
| except ValueError: |
| eval_logger.debug( |
| "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore." |
| ) |
| |
| if gpus > 1: |
| if accelerator.num_processes > 1: |
| if parallelize: |
| eval_logger.warning( |
| "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available." |
| ) |
| elif gpus > accelerator.num_processes: |
| eval_logger.warning( |
| "WARNING: The number of total system GPUs does not match the number of spawned processes. " |
| "If you would like to use data parallelism, please launch the script " |
| "with 'accelerate launch *script*'. " |
| f"Current run will proceed with {accelerator.num_processes} devices." |
| ) |
| if self.accelerator.is_local_main_process: |
| eval_logger.info( |
| f"Using {gpus} devices with data parallelism" |
| ) |
|
|
| self._device = torch.device(f"{accelerator.device}") |
| self.accelerator = accelerator |
|
|
| self._rank = self.accelerator.local_process_index |
| self._world_size = self.accelerator.num_processes |
| else: |
| |
| self._rank = 0 |
| self._world_size = 1 |
| else: |
| |
| eval_logger.warning( |
| "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration" |
| ) |
| self._rank = 0 |
| self._world_size = 1 |
|
|
| self.custom_prefix_token_id = prefix_token_id |
| if prefix_token_id is not None: |
| eval_logger.info( |
| f"Loglikelihood prefix token id used in evaluation: {self.prefix_token_id}" |
| ) |
|
|
| def _get_accelerate_args( |
| self, |
| parallelize: Optional[bool] = None, |
| device_map: Optional[str] = "auto", |
| max_memory_per_gpu: Optional[Union[int, str]] = None, |
| max_cpu_memory: Optional[Union[int, str]] = None, |
| offload_folder: Optional[str] = "./offload", |
| gpus: Optional[int] = None, |
| ) -> dict: |
| """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`.""" |
| num_local_processes = int(os.environ.get("LOCAL_WORLD_SIZE", 1)) |
| num_machines = int(os.environ.get("WORLD_SIZE", 0)) // num_local_processes |
| if ( |
| num_machines == 0 |
| and hasattr(self, "accelerator") |
| and self.accelerator is not None |
| ): |
| eval_logger.info( |
| "We are not in a distributed setting for accelerate. Setting model_parallel to False." |
| ) |
| parallelize = False |
|
|
| if parallelize is None: |
| |
| |
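| # if parallelism is unset by the user, enable model parallelism automatically when there are spare local GPUs |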
| max_memory_all_gpus = get_max_memory() |
| |
| if "cpu" in max_memory_all_gpus: |
| del max_memory_all_gpus["cpu"] |
| parallelize = bool(num_local_processes < len(max_memory_all_gpus)) |
| eval_logger.info( |
| f"Setting model parallel to {parallelize} since " |
| f"the number of local processes is {num_local_processes} " |
| f"and the number of GPUs is {len(max_memory_all_gpus)}" |
| ) |
|
|
| args = {} |
| if parallelize: |
| if max_memory_per_gpu is not None: |
| max_memory_per_gpu_map = { |
| device_idx: max_memory_per_gpu for device_idx in range(gpus) |
| } |
| else: |
| max_memory_all_gpus = get_max_memory() |
| if "cpu" in max_memory_all_gpus: |
| del max_memory_all_gpus["cpu"] |
| if not hasattr(self, "accelerator"): |
| max_memory_per_gpu_map = { |
| k: v for k, v in max_memory_all_gpus.items() |
| } |
| else: |
| |
| max_memory_per_gpu_map = { |
| k: v |
| for k, v in max_memory_all_gpus.items() |
| if k % num_local_processes |
| == (self.accelerator.process_index % num_local_processes) |
| } |
| args["max_memory"] = max_memory_per_gpu_map |
| args["device_map"] = "auto" |
| eval_logger.info( |
| f"Model parallel was set to True, setting max memory per GPU to {max_memory_per_gpu_map} and device map to 'auto'" |
| ) |
|
|
| if max_cpu_memory is not None: |
| max_memory["cpu"] = max_cpu_memory |
|
|
| args["offload_folder"] = offload_folder |
| elif device_map is None: |
| if hasattr(self, "accelerator"): |
| device_map = {"": f"{self.accelerator.device}"} |
| else: |
| device_map = {"": str(self.device)} |
| args["max_memory"] = None |
| args["device_map"] = device_map |
| eval_logger.info( |
| f"Model parallel was set to False, max memory was not set, and device map was set to {device_map}" |
| ) |
| else: |
| args["max_memory"] = None |
| args["device_map"] = None |
| eval_logger.info("Model parallel was set to False.") |
|
|
| return args |
|
|
| @property |
| def config(self): |
| |
| return self._config |
|
|
| @property |
| def model(self): |
| |
| if hasattr(self, "accelerator"): |
| return self.accelerator.unwrap_model(self._model) |
| else: |
| return self._model |
|
|
| @property |
| def eot_token_id(self): |
| |
| return self.tokenizer.eos_token_id |
|
|
| @property |
| def prefix_token_id(self): |
| |
| if self.custom_prefix_token_id is not None: |
| return self.custom_prefix_token_id |
| if self.tokenizer.bos_token_id is not None: |
| return self.tokenizer.bos_token_id |
| return self.tokenizer.eos_token_id |
|
|
| @property |
| def max_length(self): |
| if self._max_length: |
| return self._max_length |
| seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx") |
| for attr in seqlen_config_attrs: |
| if hasattr(self.model.config, attr): |
| return getattr(self.model.config, attr) |
| if hasattr(self.tokenizer, "model_max_length"): |
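| # Hugging Face sets model_max_length to this very large sentinel when the tokenizer has no real limit |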
| if self.tokenizer.model_max_length == 1000000000000000019884624838656: |
| return self._DEFAULT_MAX_LENGTH |
| return self.tokenizer.model_max_length |
| return self._DEFAULT_MAX_LENGTH |
|
|
| @property |
| def max_gen_toks(self) -> int: |
| return 256 |
|
|
| @property |
| def batch_size(self): |
| return self.batch_size_per_gpu |
|
|
| @property |
| def device(self): |
| return self._device |
|
|
| @property |
| def rank(self): |
| return self._rank |
|
|
| @property |
| def world_size(self): |
| return self._world_size |
|
|
| @property |
| def tokenizer_name(self) -> str: |
| return self.tokenizer.name_or_path.replace("/", "__") |
|
|
| @property |
| def chat_template(self) -> str: |
| if self.tokenizer.chat_template is not None: |
| return self.tokenizer.chat_template |
| return self.tokenizer.default_chat_template |
|
|
| def _get_backend( |
| self, |
| config: Union[transformers.PretrainedConfig, transformers.AutoConfig], |
| backend: Optional[Literal["default", "causal", "seq2seq"]] = "default", |
| trust_remote_code: Optional[bool] = False, |
| ) -> None: |
| """ |
| Helper method during initialization. |
| Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) |
| model type to be used. |
| """ |
| assert backend in ["default", "causal", "seq2seq"] |
|
|
| if backend != "default": |
| |
| if backend == "causal": |
| self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM |
| elif backend == "seq2seq": |
| self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM |
| eval_logger.info( |
| f"Overrode HF model backend type, and using type '{backend}'" |
| ) |
| else: |
| |
| if ( |
| getattr(config, "model_type") |
| in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES |
| ): |
| |
| |
| |
| self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM |
| elif ( |
| getattr(config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES |
| ): |
| self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM |
| else: |
| if not trust_remote_code: |
| eval_logger.warning( |
| "HF model type is neither marked as CausalLM or Seq2SeqLM. \ |
| This is expected if your model requires `trust_remote_code=True` but may be an error otherwise." |
| ) |
| |
| |
| self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM |
|
|
| assert self.AUTO_MODEL_CLASS in [ |
| transformers.AutoModelForCausalLM, |
| transformers.AutoModelForSeq2SeqLM, |
| ] |
| return None |
|
|
| def _get_config( |
| self, |
| pretrained: str, |
| revision: str = "main", |
| trust_remote_code: bool = False, |
| ) -> None: |
| self._config = transformers.AutoConfig.from_pretrained( |
| pretrained, |
| revision=revision, |
| trust_remote_code=trust_remote_code, |
| ) |
|
|
| def _create_model( |
| self, |
| pretrained: str, |
| revision: Optional[str] = "main", |
| dtype: Optional[Union[str, torch.dtype]] = "auto", |
| trust_remote_code: Optional[bool] = False, |
| |
| |
| |
| parallelize: Optional[bool] = False, |
| gpus: Optional[int] = None, |
| max_memory_per_gpu: Optional[Union[int, str]] = None, |
| max_cpu_memory: Optional[Union[int, str]] = None, |
| offload_folder: Optional[str] = "./offload", |
| |
| peft: Optional[str] = None, |
| delta: Optional[str] = None, |
| autogptq: Optional[Union[bool, str]] = False, |
| **kwargs, |
| ) -> None: |
| """ |
| Initializes an HF or HF-compatible PreTrainedModel from scratch |
| inside HFLM, using the kwargs passed into self.__init__(). |
| |
| Also handles functionality such as AutoGPTQ usage and PEFT wrapping. |
| |
| For future extensions similar to AutoGPTQ that are not core to HF's ecosystem |
| (such as PyTorch models that nearly, but not quite, mirror the public HF |
| interface this HFLM class relies on), please consider subclassing HFLM |
| and overriding this and other methods as needed. |
| """ |
|
|
| model_kwargs = kwargs if kwargs else {} |
|
|
| model_kwargs.update( |
| self._get_accelerate_args( |
| parallelize=parallelize, |
| device_map=kwargs.get("device_map", None), |
| max_memory_per_gpu=max_memory_per_gpu, |
| max_cpu_memory=max_cpu_memory, |
| offload_folder=offload_folder, |
| gpus=gpus, |
| ) |
| ) |
|
|
| if not autogptq: |
| if model_kwargs.get("load_in_4bit", None): |
| assert ( |
| transformers.__version__ >= "4.30.0" |
| ), "load_in_4bit requires transformers >= 4.30.0" |
| if transformers.__version__ >= "4.30.0": |
| if model_kwargs.get("load_in_4bit", None): |
| if model_kwargs.get("bnb_4bit_compute_dtype", None): |
| model_kwargs["bnb_4bit_compute_dtype"] = get_dtype( |
| model_kwargs["bnb_4bit_compute_dtype"] |
| ) |
|
|
| self._model = self.AUTO_MODEL_CLASS.from_pretrained( |
| pretrained, |
| revision=revision, |
| torch_dtype=get_dtype(dtype), |
| trust_remote_code=trust_remote_code, |
| **model_kwargs, |
| ) |
| else: |
| try: |
| from auto_gptq import AutoGPTQForCausalLM |
| except ModuleNotFoundError: |
| raise Exception( |
| "Tried to load auto_gptq, but auto-gptq is not installed ", |
| "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]", |
| ) |
|
|
| self._model = AutoGPTQForCausalLM.from_quantized( |
| pretrained, |
| trust_remote_code=trust_remote_code, |
| model_basename=None if autogptq is True else Path(autogptq).stem, |
| use_safetensors=True |
| if autogptq is True |
| else autogptq.endswith(".safetensors"), |
| **model_kwargs, |
| ) |
|
|
| if peft and delta: |
| raise ValueError( |
| "Cannot use both 'peft' and 'delta' options at the same time." |
| ) |
|
|
| if peft: |
| if model_kwargs.get("load_in_4bit", None): |
| if version.parse(PEFT_VERSION) < version.parse("0.4.0"): |
| raise AssertionError("load_in_4bit requires peft >= 0.4.0") |
| if self._model.config.vocab_size != len(self.tokenizer): |
| # log before resizing, so the message reports the original vocab size |
| eval_logger.info( |
| f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..." |
| ) |
| # resize embeddings for e.g. LoRA adapters trained with added tokens |
| self._model.resize_token_embeddings(len(self.tokenizer)) |
| self._model = PeftModel.from_pretrained( |
| self._model, peft, revision=revision |
| ) |
| elif delta: |
| if autogptq: |
| eval_logger.warning( |
| "Delta weights might trigger unexpected behavior when used with AutoGPTQ." |
| ) |
| _model_delta = self.AUTO_MODEL_CLASS.from_pretrained( |
| delta, |
| revision=revision, |
| torch_dtype=get_dtype(dtype), |
| trust_remote_code=trust_remote_code, |
| **model_kwargs, |
| ) |
| for name, param in self._model.state_dict().items(): |
| try: |
| param.data += _model_delta.state_dict()[name] |
| except KeyError: |
| raise KeyError(f"Delta model is missing weights for layer: {name}") |
| except Exception as e: |
| raise RuntimeError( |
| f"Failed to add delta weights to layer {name}. Error: {e}" |
| ) |
|
|
| del _model_delta |
|
|
| return None |
|
|
| def _create_tokenizer( |
| self, |
| pretrained: Union[str, transformers.PreTrainedModel], |
| tokenizer: Optional[ |
| Union[ |
| str, |
| transformers.PreTrainedTokenizer, |
| transformers.PreTrainedTokenizerFast, |
| ] |
| ], |
| revision: Optional[str] = "main", |
| trust_remote_code: Optional[bool] = False, |
| use_fast_tokenizer: Optional[bool] = True, |
| ) -> None: |
| """ |
| Helper method during initialization. |
| |
| Create a tokenizer object matching the model given by `pretrained`, |
| or use the pre-initialized tokenizer if one was passed. |
| """ |
|
|
| if tokenizer: |
| if isinstance(tokenizer, str): |
| self.tokenizer = transformers.AutoTokenizer.from_pretrained( |
| tokenizer, |
| revision=revision, |
| trust_remote_code=trust_remote_code, |
| use_fast=use_fast_tokenizer, |
| ) |
| else: |
| assert isinstance( |
| tokenizer, |
| (transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerFast), |
| ) |
| self.tokenizer = tokenizer |
| else: |
| |
| if isinstance(pretrained, str): |
| model_name = pretrained |
| else: |
| |
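| # get the HF hub name via accessor on the passed-in model object |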
| model_name = self.model.name_or_path |
| self.tokenizer = transformers.AutoTokenizer.from_pretrained( |
| model_name, |
| revision=revision, |
| trust_remote_code=trust_remote_code, |
| use_fast=use_fast_tokenizer, |
| ) |
| return None |
|
|
| def _detect_batch_size(self, requests=None, pos: int = 0): |
| if requests: |
| _, context_enc, continuation_enc = requests[pos] |
| max_length = len( |
| (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1] |
| ) |
| max_context_enc = len(context_enc[-(self.max_length + 1) :]) |
| max_cont_enc = len(continuation_enc[-(self.max_length + 1) :]) |
| else: |
| max_length = self.max_length |
| max_context_enc = max_length |
| max_cont_enc = max_length |
|
|
| |
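| # find_executable_batch_size halves the batch size and retries whenever the wrapped fn hits CUDA OOM |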
| @find_executable_batch_size(starting_batch_size=self.max_batch_size) |
| def forward_batch(batch_size): |
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: |
| length = max(max_context_enc, max_cont_enc) |
| batched_conts = torch.ones( |
| (batch_size, length), device=self.device |
| ).long() |
| test_batch = torch.ones((batch_size, length), device=self.device).long() |
| call_kwargs = { |
| "attn_mask": test_batch, |
| "labels": batched_conts, |
| } |
| else: |
| call_kwargs = {} |
| test_batch = torch.ones( |
| (batch_size, max_length), device=self.device |
| ).long() |
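| # run the forward pass a few times at this batch size before accepting it |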
| for _ in range(5): |
| out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)  # noqa: F841 |
|
|
| return batch_size |
|
|
| try: |
| batch_size = forward_batch() |
| except RuntimeError as e: |
| if "No executable batch size found" in str(e): |
| batch_size = 1 |
| else: |
| raise |
|
|
| if self.world_size > 1: |
| |
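| # if multi-GPU, use the minimum batch size found across all ranks |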
| max_rnk_bs = torch.tensor([batch_size], device=self.device) |
| gathered = ( |
| self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist() |
| ) |
| batch_size = min(gathered) |
| clear_torch_cache() |
| return batch_size |
|
|
| clear_torch_cache() |
| return batch_size |
|
|
| def tok_encode( |
| self, string: str, left_truncate_len=None, add_special_tokens=None |
| ) -> List[int]: |
| """ """ |
| |
| |
| special_tokens_kwargs = {} |
|
|
| |
| if add_special_tokens is None: |
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: |
| # by default for causal models: add special tokens only when a BOS token is requested |
| special_tokens_kwargs = {"add_special_tokens": self.add_bos_token} |
| |
| else: |
| special_tokens_kwargs = {"add_special_tokens": add_special_tokens} |
|
|
| encoding = self.tokenizer.encode(string, **special_tokens_kwargs) |
|
|
| |
| if left_truncate_len: |
| encoding = encoding[-left_truncate_len:] |
|
|
| return encoding |
|
|
| def tok_batch_encode( |
| self, |
| strings: List[str], |
| padding_side: str = "left", |
| left_truncate_len: int = None, |
| truncation: bool = False, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| |
| old_padding_side = self.tokenizer.padding_side |
| self.tokenizer.padding_side = padding_side |
|
|
| add_special_tokens = {} |
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: |
| add_special_tokens = {"add_special_tokens": False or self.add_bos_token} |
|
|
| encoding = self.tokenizer( |
| strings, |
| truncation=truncation, |
| padding="longest", |
| return_tensors="pt", |
| **add_special_tokens, |
| ) |
| if left_truncate_len: |
| encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:] |
| encoding["attention_mask"] = encoding["attention_mask"][ |
| :, -left_truncate_len: |
| ] |
| self.tokenizer.padding_side = old_padding_side |
|
|
| return encoding["input_ids"], encoding["attention_mask"] |
|
|
| def tok_decode(self, tokens, skip_special_tokens=True): |
| return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens) |
|
|
| def _model_call(self, inps, attn_mask=None, labels=None): |
| """ |
| :param inps: torch.Tensor |
| A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape |
| [batch, sequence_ctx]. the size of sequence may vary from call to call |
| :param attn_mask: torch.Tensor, optional |
| A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed |
| (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM |
| :param labels: torch.Tensor, optional |
| A torch tensor of shape [batch, (sequence_ctx + sequence_cont)]. Only passed |
| (and must be passed) if self.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM |
| :return |
| A torch tensor of shape [batch, sequence, vocab] with the |
| logits returned from the model's decoder |
| """ |
| with torch.no_grad(): |
| if attn_mask is not None or labels is not None: |
| assert attn_mask is not None and labels is not None |
| assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM |
| return self.model( |
| input_ids=inps, attention_mask=attn_mask, labels=labels |
| ).logits |
| else: |
| assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM |
| return self.model(inps).logits |
|
|
| def _model_generate(self, context, max_length, stop, **generation_kwargs): |
| |
| |
| |
| |
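| # default temperature to 0.0 if not set; when temperature == 0.0 and do_sample is unset, |
| # force greedy decoding (do_sample=False) and drop temperature to avoid a HF warning |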
| generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0) |
| do_sample = generation_kwargs.get("do_sample", None) |
|
|
| |
| if generation_kwargs.get("temperature") == 0.0 and do_sample is None: |
| generation_kwargs["do_sample"] = do_sample = False |
|
|
| if do_sample is False and generation_kwargs.get("temperature") == 0.0: |
| generation_kwargs.pop("temperature") |
| |
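| # build stopping criteria from the stop sequences, sized to this batch |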
| stopping_criteria = stop_sequences_criteria( |
| self.tokenizer, stop, context.shape[1], context.shape[0] |
| ) |
| return self.model.generate( |
| input_ids=context, |
| max_length=max_length, |
| stopping_criteria=stopping_criteria, |
| pad_token_id=self.tokenizer.pad_token_id, |
| use_cache=True, |
| **generation_kwargs, |
| ) |
|
|
| def _select_cont_toks( |
| self, logits: torch.Tensor, contlen: int = None, inplen: int = None |
| ) -> torch.Tensor: |
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: |
| assert ( |
| contlen and inplen |
| ), "Must pass input len and cont. len to select scored logits for causal LM" |
| |
| |
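| # causal LM: keep the logit positions that score the continuation tokens, discarding context and right-padding |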
| logits = logits[inplen - contlen : inplen] |
| elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: |
| assert ( |
| contlen and not inplen |
| ), "Selecting scored logits for Seq2SeqLM requires only cont. len" |
| |
| |
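| # seq2seq: decoder-side logits only contain the continuation, so just discard right-padding |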
| logits = logits[:contlen] |
|
|
| return logits |
|
|
| def loglikelihood_rolling( |
| self, requests: List[Instance], disable_tqdm: bool = False |
| ) -> List[float]: |
| loglikelihoods = [] |
|
|
| adaptive_batch_size = None |
| if self.batch_size == "auto": |
| |
| print("Passed argument batch_size = auto. Detecting largest batch size") |
| batch_size = self._detect_batch_size() |
| print(f"Determined Largest batch size: {batch_size}") |
| adaptive_batch_size = batch_size |
|
|
| for (string,) in tqdm( |
| [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0)) |
| ): |
| rolling_token_windows = list( |
| map( |
| utils.make_disjoint_window, |
| utils.get_rolling_token_windows( |
| token_list=self.tok_encode(string), |
| prefix_token=self.prefix_token_id, |
| max_seq_len=self.max_length, |
| context_len=1, |
| ), |
| ) |
| ) |
|
|
| |
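| # prepend None to each (context, continuation) window so it matches the (request_str, ctx, cont) tuple shape |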
| rolling_token_windows = [(None,) + x for x in rolling_token_windows] |
|
|
| pad_amnt = 0 |
| if self.world_size > 1: |
| |
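| # pad the number of rolling windows on this rank so every rank runs the same number of batches |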
| mytensor = torch.tensor(len(rolling_token_windows), device=self.device) |
| gathered = ( |
| self.accelerator.gather(mytensor).cpu().detach().numpy().tolist() |
| ) |
|
|
| pad_amnt = max(gathered) - gathered[self.rank] |
| if pad_amnt > 0: |
| rolling_token_windows += pad_amnt * [rolling_token_windows[0]] |
|
|
| string_nll = self._loglikelihood_tokens( |
| requests=rolling_token_windows, |
| disable_tqdm=True, |
| override_bs=adaptive_batch_size, |
| ) |
|
|
| if (self.world_size > 1) and (pad_amnt > 0): |
| string_nll = [x[0] for x in string_nll[:-pad_amnt]] |
| else: |
| |
| string_nll = [x[0] for x in string_nll] |
|
|
| string_nll = sum(string_nll) |
| loglikelihoods.append(string_nll) |
|
|
| return loglikelihoods |
|
|
| def _batch_scheduler(self, pos, n_reordered_requests): |
| sched = pos // int(len(n_reordered_requests) / self.batch_schedule) |
| if sched in self.batch_sizes: |
| return self.batch_sizes[sched] |
| if (len(self.batch_sizes) > 1) and ( |
| self.batch_sizes[sched - 1] == self.max_batch_size |
| ): |
| |
| self.batch_sizes[sched] = self.max_batch_size |
| return self.batch_sizes[sched] |
| print( |
| f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size" |
| ) |
| self.batch_sizes[sched] = self._detect_batch_size(n_reordered_requests, pos) |
| print(f"Determined largest batch size: {self.batch_sizes[sched]}") |
| return self.batch_sizes[sched] |
|
|
| def _loglikelihood_tokens( |
| self, |
| requests: List[Tuple[Tuple[str, str], List[int], List[int]]], |
| disable_tqdm: bool = False, |
| override_bs: int = None, |
| ) -> List[Tuple[float, bool]]: |
| |
| res = [] |
|
|
| def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]): |
| """Defines the key for the sorted method""" |
| |
| |
| |
| |
| |
| |
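| # sort by descending total length: time estimates are then overestimates, and any OOM |
| # from the padded batch size surfaces on the first batches rather than near the end |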
|
|
| toks = req[1] + req[2] |
| return -len(toks), tuple(toks) |
|
|
| def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): |
| """Defines the key to group and lookup one-token continuations""" |
| |
| |
| |
| |
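| # requests sharing a context and differing only in the final continuation token can reuse one forward pass |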
| return req[-2] + req[-1][:-1] |
|
|
| re_ord = Collator( |
| requests, |
| sort_fn=_collate, |
| group_by="contexts" |
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM |
| and self.logits_cache |
| else None, |
| group_fn=_lookup_one_token_cont, |
| ) |
|
|
| |
| |
| n_reordered_requests = len(re_ord) |
| batch_size = ( |
| self.batch_size |
| if self.batch_size != "auto" |
| else override_bs |
| if override_bs is not None |
| else 0 |
| ) |
| batch_fn = ( |
| self._batch_scheduler |
| if self.batch_size == "auto" |
| and n_reordered_requests > 0 |
| and not override_bs |
| else None |
| ) |
|
|
| chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn) |
| pbar = tqdm( |
| total=len(requests), |
| disable=(disable_tqdm or (self.rank != 0)), |
| desc="Running loglikelihood requests", |
| ) |
| for chunk in chunks: |
| inps = [] |
| cont_toks_list = [] |
| inplens = [] |
|
|
| conts = [] |
| encoder_attns = [] |
|
|
| padding_len_inp = None |
| padding_len_cont = None |
| |
| |
| |
|
|
| for _, context_enc, continuation_enc in chunk: |
| |
| assert len(context_enc) > 0 |
| assert len(continuation_enc) > 0 |
| assert len(continuation_enc) <= self.max_length |
|
|
| |
| |
| |
| |
| |
| |
|
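| # how the causal-LM scoring works, illustrated: |
| #            CTX      CONT |
| # inp      0 1 2 3 : 4 5 6 7 8 9    <- last token is dropped by the [:-1] slice |
| # logits     1 2 3 : 4 5 6 7 8 9    <- position i holds the prediction for token i+1 |
| # cont_toks          4 5 6 7 8 9    <- only these positions are scored |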
|
| |
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: |
| inp = torch.tensor( |
| (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1], |
| dtype=torch.long, |
| device=self.device, |
| ) |
| (inplen,) = inp.shape |
| elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: |
| inp = torch.tensor( |
| (context_enc)[-self.max_length :], |
| dtype=torch.long, |
| device=self.device, |
| ) |
| (inplen,) = inp.shape |
|
|
| |
| encoder_attns.append(torch.ones_like(inp)) |
|
|
| cont = torch.tensor( |
| (continuation_enc)[-self.max_length :], |
| |
| |
| dtype=torch.long, |
| device=self.device, |
| ) |
| (contlen,) = cont.shape |
|
|
| conts.append(cont) |
|
|
| padding_len_cont = ( |
| max(padding_len_cont, contlen) |
| if padding_len_cont is not None |
| else contlen |
| ) |
|
|
| padding_len_inp = ( |
| max(padding_len_inp, inplen) |
| if padding_len_inp is not None |
| else inplen |
| ) |
|
|
| inps.append(inp) |
| cont_toks_list.append(continuation_enc) |
| inplens.append(inplen) |
|
|
| |
| call_kwargs = {} |
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: |
| batched_inps = pad_and_concat( |
| padding_len_inp, inps, padding_side="right" |
| ) |
| elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: |
| |
| batched_inps = pad_and_concat( |
| padding_len_inp, inps |
| ) |
| batched_conts = pad_and_concat( |
| padding_len_cont, conts |
| ) |
| batched_encoder_mask = pad_and_concat( |
| padding_len_inp, encoder_attns |
| ) |
| call_kwargs = { |
| "attn_mask": batched_encoder_mask, |
| "labels": batched_conts, |
| } |
|
|
| multi_logits = F.log_softmax( |
| self._model_call(batched_inps, **call_kwargs), dim=-1 |
| ) |
|
|
| for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip( |
| chunk, multi_logits, inplens, cont_toks_list |
| ): |
| |
| contlen = len(cont_toks) |
| |
| |
| |
| |
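| # for causal LMs, ctx_len re-derives this sequence's true length inside the right-padded logits tensor |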
| ctx_len = ( |
| inplen + (logits.shape[0] - padding_len_inp) |
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM |
| else None |
| ) |
| logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) |
| logits = logits.unsqueeze(0) |
|
|
| |
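| # per-position argmax tokens, used below for the exact-match flag |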
| greedy_tokens = logits.argmax(dim=-1) |
|
|
| |
| |
| |
| |
| |
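| # re_ord.get_cache is a no-op unless grouping by contexts: on a cache hit it expands |
| # the shared logits across the grouped requests and yields each with its continuation |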
| for request_str, cont_toks, logits in re_ord.get_cache( |
| req_str=request_str, |
| cxt_toks=ctx_tokens, |
| cont_toks=cont_toks, |
| logits=logits, |
| ): |
| cont_toks = torch.tensor( |
| cont_toks, dtype=torch.long, device=self.device |
| ).unsqueeze(0) |
| max_equal = (greedy_tokens == cont_toks).all() |
|
|
| |
| |
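| # gather the log-prob of each actual continuation token from the vocab dimension |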
| logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze( |
| -1 |
| ) |
|
|
| |
| answer = (float(logits.sum()), bool(max_equal)) |
|
|
| res.append(answer) |
|
|
| self.cache_hook.add_partial("loglikelihood", request_str, answer) |
| pbar.update(1) |
|
|
| pbar.close() |
|
|
| return re_ord.get_original(res) |
|
|
| def generate_until( |
| self, requests: List[Instance], disable_tqdm: bool = False |
| ) -> List[str]: |
| res = [] |
|
|
| def _collate(req: Tuple[str, dict]): |
| """Defines the key for the sorted method""" |
| |
| |
| |
| |
| |
| |
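| # sort by descending prompt length so padded batches are maximally packed and OOMs surface early |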
| toks = self.tok_encode(req[0]) |
| return -len(toks), req[0] |
|
|
| pbar = tqdm( |
| total=len(requests), |
| disable=(disable_tqdm or (self.rank != 0)), |
| desc="Running generate_until requests", |
| ) |
| adaptive_batch_size = None |
| if self.batch_size == "auto": |
| |
| print("Passed argument batch_size = auto. Detecting largest batch size") |
| batch_size = self._detect_batch_size() |
| print(f"Determined Largest batch size: {batch_size}") |
| adaptive_batch_size = batch_size |
| |
| batch_size = ( |
| self.batch_size |
| if self.batch_size != "auto" |
| else adaptive_batch_size |
| if adaptive_batch_size is not None |
| else 0 |
| ) |
| batch_fn = ( |
| self._batch_scheduler |
| if self.batch_size == "auto" and not adaptive_batch_size |
| else None |
| ) |
|
|
| |
| |
| |
| |
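| # group requests by their generation kwargs, so that e.g. greedy decoding and |
| # temperature-0.8 sampling are never batched together |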
| re_ords = Collator( |
| [reg.args for reg in requests], |
| sort_fn=_collate, |
| group_by="gen_kwargs", |
| group_fn=lambda x: x[1], |
| ) |
| chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn) |
| for chunk in chunks: |
| contexts, all_gen_kwargs = zip(*chunk) |
| |
| |
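| # all gen kwargs in the batch are identical; the Collator grouped requests by gen_kwargs above |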
| gen_kwargs = all_gen_kwargs[0] |
| |
| until = None |
| if isinstance(gen_kwargs, dict): |
| kwargs = copy.deepcopy(gen_kwargs) |
| if "until" in kwargs.keys(): |
| until = kwargs.pop("until") |
| if isinstance(until, str): |
| until = [until] |
| elif not isinstance(until, list): |
| raise ValueError( |
| f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}" |
| ) |
| else: |
| raise ValueError( |
| f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" |
| ) |
| |
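| # always add the EOS string to the stop sequences |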
| eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False) |
| if not until: |
| until = [eos] |
| else: |
| until.append(eos) |
| if "max_gen_toks" in kwargs.keys(): |
| max_gen_toks = kwargs.pop("max_gen_toks") |
| else: |
| max_gen_toks = self.max_gen_toks |
|
|
| |
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: |
| |
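| # max input length = max_length, minus room for the max new tokens to generate |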
| max_ctx_len = self.max_length - max_gen_toks |
| elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM: |
| |
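| # encoder input can use the full max_length; generation happens on the decoder side |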
| max_ctx_len = self.max_length |
|
|
| |
| context_enc, attn_masks = self.tok_batch_encode( |
| contexts, |
| left_truncate_len=max_ctx_len, |
| truncation=self.truncation, |
| ) |
| context_enc = context_enc.to(self.device) |
| attn_masks = attn_masks.to(self.device) |
|
|
| if "max_length" not in kwargs: |
| kwargs["max_length"] = context_enc.shape[1] + max_gen_toks |
|
|
| |
| cont = self._model_generate( |
| context=context_enc, |
| attention_mask=attn_masks, |
| stop=until, |
| **kwargs, |
| ) |
|
|
| cont_toks_list = cont.tolist() |
| for cont_toks, context in zip(cont_toks_list, contexts): |
| |
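| # for causal LMs, strip the (left-padded) input prompt from the returned sequence |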
| if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM: |
| cont_toks = cont_toks[context_enc.shape[1] :] |
|
|
| s = self.tok_decode(cont_toks) |
|
|
| |
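| # cut the generation off at the first occurrence of any stop sequence, post-hoc |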
| for term in until: |
| if len(term) > 0: |
| |
| |
| s = s.split(term)[0] |
|
|
| res.append(s) |
|
|
| self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s) |
| pbar.update(1) |
| |
| res = re_ords.get_original(res) |
|
|
| pbar.close() |
|
|
| return res |
|
|
| def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str: |
| """ |
| Apply the tokenizer's chat template to a chat history (a list of user/model messages). |
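| |
| Example (a sketch; assumes the tokenizer defines a chat template): |
| >>> lm.apply_chat_template([{"role": "user", "content": "Hi!"}]) |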
| """ |
| return self.tokenizer.apply_chat_template( |
| chat_history, tokenize=False, add_generation_prompt=True |
| ) |
|
|
| def get_model_info(self) -> dict: |
| """ |
| Method to get Hugging Face model information for experiment reproducibility. |
| """ |
|
|
| def get_model_num_params(model) -> int: |
| if hasattr(model, "num_parameters"): |
| return model.num_parameters() |
| if hasattr(model, "parameters"): |
| return sum(p.numel() for p in model.parameters()) |
| else: |
| return -1 |
|
|
| def get_model_dtype(model) -> str: |
| if hasattr(model, "dtype"): |
| return str(model.dtype) |
| else: |
| return "" |
|
|
| def get_model_sha(pretrained: str, revision: str) -> str: |
| try: |
| model_info = HfApi().model_info(repo_id=pretrained, revision=revision) |
| return model_info.sha |
| except Exception as e: |
| eval_logger.warning( |
| f"Failed to get model SHA for {pretrained} at revision {revision}. Error: {e}" |
| ) |
| return "" |
|
|
| model_info = { |
| "model_num_parameters": get_model_num_params(self._model), |
| "model_dtype": get_model_dtype(self._model), |
| "model_revision": self.revision, |
| "model_sha": get_model_sha(self.pretrained, self.revision), |
| } |
| if self.peft: |
| model_info["peft_sha"] = get_model_sha(self.peft, self.revision) |
| if self.delta: |
| model_info["delta_sha"] = get_model_sha(self.delta, self.revision) |
| return model_info |
|
|