from typing import Optional, Union

import torch

import lm_eval.models.utils
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


@register_model("mamba_ssm")
class MambaLMWrapper(HFLM):
    def __init__(
        self,
        pretrained="state-spaces/mamba-130m",
        is_hf: bool = False,
        **kwargs,
    ) -> None:
| """ |
| Mamba (via the `mamba_ssm` package) supports the following args: |
| ``` |
| d_model: int, |
| n_layer: int, |
| vocab_size: int, |
| initializer_cfg=None, |
| pad_vocab_size_multiple: int = 1, |
| ssm_cfg=None, |
| norm_epsilon: float = 1e-5, |
| rms_norm: bool = False, |
| initializer_cfg=None, |
| fused_add_norm=False, |
| residual_in_fp32=False, |
| ``` |
| |
| See https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L175 for more info. |
| The above can all be passed via `--model_args` or to this __init__() directly |
| but we recommend placing many of these within the config.json file uploaded alongside your |
| Mamba model to the HF Hub instead. |
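
        A minimal, hypothetical config.json might look like the following
        (field values here are illustrative, not copied from any particular
        checkpoint):
        ```
        {
            "d_model": 768,
            "n_layer": 24,
            "vocab_size": 50277,
            "rms_norm": true,
            "fused_add_norm": true,
            "residual_in_fp32": true,
            "pad_vocab_size_multiple": 8
        }
        ```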

        All other HuggingFace from_pretrained() kwargs, such as those related
        to `parallelize=True`, PEFT, or autoGPTQ, and any sub-configurations
        of these advanced options, are unsupported by the `mamba_ssm` package.

        The HFLM arguments

        `backend`, `tokenizer`, `truncation`, `max_length`,
        `device`, `dtype`, `batch_size`, `max_batch_size`, `trust_remote_code`, `use_fast_tokenizer`

        are all supported by Mamba where they do not conflict with
        Mamba-specific restrictions (e.g., only causal LMs are supported).
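
        Example CLI invocation (the task and batch size here are illustrative):
        ```
        lm_eval --model mamba_ssm \
            --model_args pretrained=state-spaces/mamba-130m \
            --tasks lambada_openai \
            --device cuda \
            --batch_size 64
        ```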
| """ |
|
|
| if "backend" in kwargs: |
| |
| assert kwargs["backend"] == "causal" |
| self.is_hf = is_hf or (True if pretrained.endswith("hf") else False) |
| super().__init__( |
| pretrained=pretrained, |
| |
| backend=kwargs.pop("backend", "causal"), |
| tokenizer=kwargs.pop("tokenizer", "EleutherAI/gpt-neox-20b"), |
| max_length=kwargs.pop("max_length", 2048), |
| **kwargs, |
| ) |

    def _get_config(
        self,
        pretrained: str,
        **kwargs,
    ) -> None:
        if self.is_hf:
            super()._get_config(pretrained, **kwargs)
        else:
            try:
                from mamba_ssm.utils.hf import load_config_hf
            except ModuleNotFoundError as exception:
                raise type(exception)(
                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. "
                    "Please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`"
                ) from exception

            self._config = load_config_hf(pretrained)

    def _create_model(
        self,
        pretrained: str,
        dtype: Optional[Union[str, torch.dtype]] = "float16",
        # no `parallelize=True`, PEFT, or quantization options here:
        # Mamba does not support arbitrary HF from_pretrained() kwargs
        **kwargs,
    ) -> None:
        if self.is_hf:
            super()._create_model(pretrained, dtype=dtype, **kwargs)
        else:
            try:
                from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
            except ModuleNotFoundError as exception:
                raise type(exception)(
                    "attempted to use 'mamba_ssm' LM type, but package `mamba_ssm` is not installed. "
                    "Please install mamba via `pip install lm-eval[mamba]` or `pip install -e .[mamba]`"
                ) from exception

            self._model = MambaLMHeadModel.from_pretrained(
                pretrained,
                device=self._device,
                # `mamba_ssm` does not recognize dtype="auto"; fall back to float16
                dtype=torch.float16
                if dtype == "auto"
                else lm_eval.models.utils.get_dtype(dtype),
            )

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        remove_arg = (
            ["attention_mask"] if self.is_hf else ["do_sample", "attention_mask"]
        )
        for key in remove_arg:
            if key in generation_kwargs:
                generation_kwargs.pop(key)

        if not self.is_hf:
            # mamba_ssm's custom GenerationMixin does not accept stopping
            # criteria, so we generate up to max_length and rely on the
            # harness to truncate at the stop sequences afterwards
            # (an equivalent result)
            return self.model.generate(
                input_ids=context,
                max_length=max_length,
                **generation_kwargs,
            )
        else:
            stopping_criteria = lm_eval.models.utils.stop_sequences_criteria(
                self.tokenizer,
                stop,
                context.shape[1],
                context.shape[0],
            )

            generation_kwargs["temperature"] = generation_kwargs.get(
                "temperature", 0.0
            )
            do_sample = generation_kwargs.get("do_sample", None)

            # a temperature of 0.0 signals greedy decoding: force
            # do_sample=False, then drop the temperature entirely, since HF
            # generation requires any passed temperature to be strictly positive
            if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
                generation_kwargs["do_sample"] = do_sample = False
            if do_sample is False and generation_kwargs.get("temperature") == 0.0:
                generation_kwargs.pop("temperature")

            return self.model.generate(
                input_ids=context,
                max_length=max_length,
                stopping_criteria=stopping_criteria,
                pad_token_id=self.tokenizer.pad_token_id,
                use_cache=True,
                **generation_kwargs,
            )
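

# Example programmatic use (a sketch -- assumes `lm_eval` is installed with
# the mamba extra, i.e. `pip install lm-eval[mamba]`; model and task here
# are illustrative):
#
#   import lm_eval
#
#   results = lm_eval.simple_evaluate(
#       model="mamba_ssm",
#       model_args="pretrained=state-spaces/mamba-130m",
#       tasks=["lambada_openai"],
#   )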