import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

from .modeling_base import PreTrainedModelWrapper


class ValueHead(nn.Module):
    r"""
    The ValueHead class implements a head for autoregressive and seq2seq language models that returns a scalar value
    for each output token.
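
    Example (an illustrative sketch, not part of the public API; `GPT2Config` and the tensor shapes
    below are assumptions used only for demonstration):

        >>> import torch
        >>> from transformers import GPT2Config
        >>> config = GPT2Config()
        >>> head = ValueHead(config)
        >>> hidden_states = torch.randn(2, 5, config.hidden_size)
        >>> values = head(hidden_states)  # shape (2, 5, 1): one scalar value per token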
| | """ |
| |
|

    def __init__(self, config, **kwargs):
        super().__init__()
        if not hasattr(config, "summary_dropout_prob"):
            summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1)
        else:
            summary_dropout_prob = config.summary_dropout_prob

        self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()

        # Determine the hidden size of the token representations that feed the value head.
        # Some models (e.g. OPT-350m) expose a `word_embed_proj_dim` that differs from `hidden_size`,
        # and encoder-decoder configs may keep the hidden size on the decoder sub-config.
        if hasattr(config, "hidden_size"):
            hidden_size = config.hidden_size
        if hasattr(config, "word_embed_proj_dim"):
            hidden_size = config.word_embed_proj_dim
        elif hasattr(config, "is_encoder_decoder"):
            if config.is_encoder_decoder and hasattr(config, "decoder"):
                if hasattr(config.decoder, "hidden_size"):
                    hidden_size = config.decoder.hidden_size

        self.summary = nn.Linear(hidden_size, 1)

        self.flatten = nn.Flatten()

    def forward(self, hidden_states):
        output = self.dropout(hidden_states)

        # Cast the hidden states to the dtype of the value head weights (e.g. upcast half-precision
        # activations to fp32) for numerical stability before the linear projection.
        if output.dtype != self.summary.weight.dtype:
            output = output.to(self.summary.weight.dtype)

        output = self.summary(output)
        return output


class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
    r"""
    An autoregressive model with a value head in addition to the language model head.
    This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
    `transformers.PreTrainedModel` class. The wrapper class supports classic functions
    such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped
    model, simply manipulate the `pretrained_model` attribute of this class.

    Class attributes:
        - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This
            should be set to `transformers.AutoModelForCausalLM` for this class.
        - **lm_head_namings** (`tuple`) -- A tuple of strings that are used to identify the language model head of the
            wrapped model. This is set to `("lm_head", "embed_out")` for this class but can be changed for other models
            in the future.
        - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported
            by the `ValueHead` class. Currently, the supported args are:
            - **summary_dropout_prob** (`float`, `optional`, defaults to `0.1`) -- The dropout probability for the
                `ValueHead` class; `config.summary_dropout_prob` takes precedence when it is set on the model config.
            - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the
                `ValueHead` if a specific initialization strategy is selected.
            - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the
                `ValueHead`. Currently, the supported strategies are:
                - **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default
                    strategy.
                - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution.
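
    Examples (a minimal usage sketch; the "gpt2" checkpoint name is only an assumption for
    illustration -- any model mapped inside `AutoModelForCausalLM` works the same way):

        >>> from transformers import AutoTokenizer
        >>> from trl import AutoModelForCausalLMWithValueHead

        >>> model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2", v_head_init_strategy="normal")
        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> lm_logits, loss, value = model(**inputs)  # value has shape (batch_size, sequence_length)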
| | """ |
| |
|
| | transformers_parent_class = AutoModelForCausalLM |
| | lm_head_namings = ["lm_head", "embed_out"] |
| | supported_args = ( |
| | "summary_dropout_prob", |
| | "v_head_initializer_range", |
| | "v_head_init_strategy", |
| | ) |
| |
|

    def __init__(self, pretrained_model, **kwargs):
        r"""
        Initializes the model.

        Args:
            pretrained_model (`transformers.PreTrainedModel`):
                The model to wrap. It should be a causal language model such as GPT2,
                or any model mapped inside the `AutoModelForCausalLM` class.
            kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the `ValueHead` class.
        """
        super().__init__(pretrained_model, **kwargs)
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)

        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
            raise ValueError("The model does not have a language model head, please use a model that has one.")

        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

        self._init_weights(**v_head_kwargs)

    def _init_weights(self, **kwargs):
        r"""
        Initializes the weights of the value head. The default initialization strategy is random.
        Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument
        when calling `.from_pretrained`. Supported strategies are:
            - `normal`: initializes the weights with a normal distribution.

        Args:
            **kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the `ValueHead` class. These arguments
                can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range`
                argument.
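
        Example (a sketch; "gpt2" is an assumed placeholder checkpoint name):

            >>> model = AutoModelForCausalLMWithValueHead.from_pretrained(
            ...     "gpt2", v_head_init_strategy="normal", v_head_initializer_range=0.05
            ... )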
| | """ |
| | initializer_range = kwargs.pop("v_head_initializer_range", 0.2) |
| | |
| | init_strategy = kwargs.pop("v_head_init_strategy", None) |
| | if init_strategy is None: |
| | |
| | pass |
| | elif init_strategy == "normal": |
| | self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) |
| | self.v_head.summary.bias.data.zero_() |
| |
|

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        r"""
        Applies a forward pass to the wrapped model and returns the language modeling logits, the loss
        (when the wrapped model computes one) and the value head output, as a tuple `(lm_logits, loss, value)`.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
                Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
                (see `past_key_values` input) to speed up sequential decoding.
            attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
                Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
            kwargs (`dict`, `optional`):
                Additional keyword arguments, that are passed to the wrapped model.
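
        Example (a sketch; assumes `model` and `tokenizer` were created as in the class docstring):

            >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
            >>> lm_logits, loss, value = model(**inputs)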
| | """ |
| | kwargs["output_hidden_states"] = True |
| | kwargs["past_key_values"] = past_key_values |
| |
|
| | if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": |
| | kwargs.pop("past_key_values") |
| |
|
| | base_model_output = self.pretrained_model( |
| | input_ids=input_ids, |
| | attention_mask=attention_mask, |
| | **kwargs, |
| | ) |
| |
|
| | last_hidden_state = base_model_output.hidden_states[-1] |
| | lm_logits = base_model_output.logits |
| | loss = base_model_output.loss |
| |
|
| | if last_hidden_state.device != self.v_head.summary.weight.device: |
| | last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) |
| |
|
| | value = self.v_head(last_hidden_state).squeeze(-1) |
| |
|
| | |
| | if lm_logits.dtype != torch.float32: |
| | lm_logits = lm_logits.float() |
| |
|
| | return (lm_logits, loss, value) |

    def generate(self, *args, **kwargs):
        r"""
        A simple wrapper around the `generate` method of the wrapped model.
        Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils)
        method of the wrapped model for more information about the supported arguments.

        Args:
            *args (`list`, *optional*):
                Positional arguments passed to the `generate` method of the wrapped model.
            **kwargs (`dict`, *optional*):
                Keyword arguments passed to the `generate` method of the wrapped model.
        """
        return self.pretrained_model.generate(*args, **kwargs)

    def state_dict(self, *args, **kwargs):
        r"""
        Returns the state dictionary of the model. We add the state dictionary of the value head
        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
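
        Example (a sketch; assumes `model` is an `AutoModelForCausalLMWithValueHead` instance):

            >>> sd = model.state_dict()
            >>> [k for k in sd if k.startswith("v_head.")]
            ['v_head.summary.weight', 'v_head.summary.bias']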
| | """ |
| | if not self.is_peft_model: |
| | pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) |
| | else: |
| | |
| | pretrained_model_state_dict = {} |
| |
|
| | v_head_state_dict = self.v_head.state_dict(*args, **kwargs) |
| | for k, v in v_head_state_dict.items(): |
| | pretrained_model_state_dict[f"v_head.{k}"] = v |
| | return pretrained_model_state_dict |
| |

    def push_to_hub(self, *args, **kwargs):
        # attach the value head to the wrapped model so that it is serialized along with it
        setattr(self.pretrained_model, "v_head", self.v_head)

        return self.pretrained_model.push_to_hub(*args, **kwargs)

    def post_init(self, state_dict):
        r"""
        We add the state dictionary of the value head to the state dictionary of the wrapped model
        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
        keys of the value head state dictionary.
        """
        for k in list(state_dict.keys()):
            if "v_head." in k:
                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

        if hasattr(self.pretrained_model, "hf_device_map"):
            if (
                "cpu" in self.pretrained_model.hf_device_map.values()
                or "disk" in self.pretrained_model.hf_device_map.values()
            ):
                raise ValueError(
                    "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
                )

            first_device = list(set(self.pretrained_model.hf_device_map.values()))[0]

            self.v_head = self.v_head.to(first_device)

            def set_device_hook(module, input, outputs):
                new_output = ()
                for output in outputs:
                    if isinstance(output, torch.Tensor):
                        new_output += (output.to(first_device),)
                    else:
                        new_output += (output,)
                return new_output

            self.register_forward_hook(set_device_hook)

            self.is_sequential_parallel = True


class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper):
    r"""
    A seq2seq model with a value head in addition to the language model head.
    This class inherits from `~trl.PreTrainedModelWrapper` and wraps a
    `transformers.PreTrainedModel` class. The wrapper class supports classic functions
    such as `from_pretrained` and `push_to_hub` and also provides some additional
    functionalities such as `generate`.

    Args:
        pretrained_model (`transformers.PreTrainedModel`):
            The model to wrap. It should be a seq2seq language model such as T5,
            or any model mapped inside the `AutoModelForSeq2SeqLM` class.
        kwargs:
            Additional keyword arguments passed along to the `ValueHead` class.
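
    Examples (a minimal usage sketch; "t5-small" is an assumed checkpoint name used only for
    illustration -- any model mapped inside `AutoModelForSeq2SeqLM` works the same way):

        >>> from transformers import AutoTokenizer
        >>> from trl import AutoModelForSeq2SeqLMWithValueHead

        >>> model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained("t5-small")
        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
        >>> generated = model.generate(**inputs, max_new_tokens=10)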
| | """ |
| |
|
| | transformers_parent_class = AutoModelForSeq2SeqLM |
| | lm_head_namings = ["lm_head", "embed_out", "output_projection"] |
| | supported_args = ( |
| | "summary_dropout_prob", |
| | "v_head_initializer_range", |
| | "v_head_init_strategy", |
| | ) |
| |

    def __init__(self, pretrained_model, **kwargs):
        super().__init__(pretrained_model, **kwargs)
        v_head_kwargs, _, _ = self._split_kwargs(kwargs)
        self.is_encoder_decoder = True

        if not self._has_lm_head():
            raise ValueError("The model does not have a language model head, please use a model that has one.")

        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)

        self._init_weights(**v_head_kwargs)

    def _has_lm_head(self):
        # check module names of all modules inside `pretrained_model` to find the language model head
        for name, module in self.pretrained_model.named_modules():
            if any(attribute in name for attribute in self.lm_head_namings):
                return True
        return False

    def post_init(self, state_dict):
        r"""
        We add the state dictionary of the value head to the state dictionary of the wrapped model
        by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the
        keys of the value head state dictionary.
        """
        for k in list(state_dict.keys()):
            if "v_head." in k:
                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
        self.v_head.load_state_dict(state_dict, strict=False)
        del state_dict

        if hasattr(self.pretrained_model, "hf_device_map"):
            if (
                "cpu" in self.pretrained_model.hf_device_map.values()
                or "disk" in self.pretrained_model.hf_device_map.values()
            ):
                raise ValueError(
                    "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
                )

            # get the device of the language model head
            for name, module in self.pretrained_model.named_modules():
                if any(attribute in name for attribute in self.lm_head_namings):
                    lm_head_device = module.weight.device
                    break

            # put the value head on the same device as the lm head to avoid device mismatches
            self.v_head = self.v_head.to(lm_head_device)

            def set_device_hook(module, input, outputs):
                r"""
                A hook that moves the tensor outputs of the model to the device of the
                language model head.

                Args:
                    module (`nn.Module`):
                        The module to which the hook is attached.
                    input (`tuple`):
                        The input to the module.
                    outputs (`tuple`):
                        The output of the module.
                """
                new_output = ()
                for output in outputs:
                    if isinstance(output, torch.Tensor):
                        new_output += (output.to(lm_head_device),)
                    else:
                        new_output += (output,)
                return new_output

            self.register_forward_hook(set_device_hook)
            self.is_sequential_parallel = True

    def state_dict(self, *args, **kwargs):
        r"""
        Returns the state dictionary of the model. We add the state dictionary of the value head
        to the state dictionary of the wrapped model by prepending the key with `v_head.`.
        """
        if not self.is_peft_model:
            pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
        else:
            # if it is a peft model, only save the value head
            pretrained_model_state_dict = {}

        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
        for k, v in v_head_state_dict.items():
            pretrained_model_state_dict[f"v_head.{k}"] = v
        return pretrained_model_state_dict

    def push_to_hub(self, *args, **kwargs):
        # attach the value head to the wrapped model so that it is serialized along with it
        setattr(self.pretrained_model, "v_head", self.v_head)

        return self.pretrained_model.push_to_hub(*args, **kwargs)

    def _init_weights(self, **kwargs):
        r"""
        We initialize the weights of the value head.
        """
        initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
        # `None` keeps the default random initialization from `nn.Linear`.
        init_strategy = kwargs.pop("v_head_init_strategy", None)
        if init_strategy is None:
            # do nothing
            pass
        elif init_strategy == "normal":
            self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
            self.v_head.summary.bias.data.zero_()

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        r"""
        Applies a forward pass to the wrapped model and returns a tuple `(lm_logits, loss, value)`,
        where the value head is applied to the last decoder hidden state.
        """
        kwargs["past_key_values"] = past_key_values
        if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
            kwargs.pop("past_key_values")

        base_model_output = self.pretrained_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # the value head is applied to the last decoder hidden state
            **kwargs,
        )

        last_hidden_state = base_model_output.decoder_hidden_states[-1]
        lm_logits = base_model_output.logits
        loss = base_model_output.loss

        value = self.v_head(last_hidden_state).squeeze(-1)

        # force upcast in fp32 if logits are in half-precision
        if lm_logits.dtype != torch.float32:
            lm_logits = lm_logits.float()

        return (lm_logits, loss, value)

    def generate(self, *args, **kwargs):
        r"""
        We call `generate` on the wrapped model.
        """
        return self.pretrained_model.generate(*args, **kwargs)