from typing import Optional

import torch
import torch.nn as nn

from .configuration_bert import FlexBertConfig
from .activation import get_act_fn
from .normalization import get_norm_layer
from .initialization import ModuleType, init_weights


class BertResidualGLU(nn.Module):
    """Applies the FFN at the end of each Mosaic BERT layer.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.models.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has similar functionality, but
    introduces Gated Linear Units.

    Note: Mosaic BERT adds parameters in order to implement Gated Linear Units. To keep the parameter count consistent with
    that of a standard Hugging Face BERT, scale `config.intermediate_size` down to 2/3 of the standard value. For example, a
    Mosaic BERT constructed with `config.intermediate_size=2048` has the same parameter footprint as a Hugging Face BERT
    constructed with `config.intermediate_size=3072`.
    However, in most cases it is not necessary to adjust `config.intermediate_size`: despite the larger parameter count,
    Mosaic BERT typically achieves higher net throughput than a Hugging Face BERT built from the same `config`.
    """

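    # Parameter-count arithmetic behind the 2/3 rule above (editor's sketch, not
    # part of the original module): with hidden size H and intermediate size I,
    # this block holds H * 2I weights in `gated_layers` plus I * H in `wo`,
    # i.e. 3 * H * I in total, while the standard BERT FFN holds H * I + I * H
    # = 2 * H * I. Matching 3 * H * 2048 against 2 * H * 3072 yields the equal
    # footprints quoted in the docstring.
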
    def __init__(
        self,
        config,
    ):
        super().__init__()
        self.config = config
        self.gated_layers = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=False)
        self.act = get_act_fn(config.hidden_act)
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.layernorm = get_norm_layer(config)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer, of shape [nnz, dim].
        """
        residual_connection = hidden_states

        hidden_states = self.gated_layers(hidden_states)
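        # The fused projection packs the activation input and the gate side by
        # side along the feature dimension; split them and apply the GLU gating.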
        gated = hidden_states[:, : self.config.intermediate_size]
        non_gated = hidden_states[:, self.config.intermediate_size :]
        hidden_states = self.act(gated) * non_gated
        hidden_states = self.dropout(hidden_states)

        hidden_states = self.wo(hidden_states)

        hidden_states = self.layernorm(hidden_states + residual_connection)
        return hidden_states
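
# Usage sketch (editor's addition): BertResidualGLU drops into the end of a
# Mosaic BERT layer. Assuming `config` provides the attributes read above
# (hidden_size, intermediate_size, hidden_act, hidden_dropout_prob, and
# whatever `get_norm_layer` expects):
#
#     ffn = BertResidualGLU(config)
#     out = ffn(hidden_states)  # hidden_states: [nnz, hidden_size], unpadded
#     assert out.shape == hidden_states.shape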


class FlexBertMLPBase(nn.Module):
    """A FlexBERT MLP base class for type hints."""

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_id = layer_id

    def _init_weights(self, reset_params: bool = False):
        raise NotImplementedError("This is a base class and should not be used directly.")

    def reset_parameters(self):
        self._init_weights(reset_params=True)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("This is a base class and should not be used directly.")


class FlexBertMLP(FlexBertMLPBase):
    """Applies the MLP at the end of each FlexBERT layer.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.models.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has similar functionality.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.Wi = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_in_bias)
        self.act = get_act_fn(config.hidden_act)
        self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        init_weights(
            self.config,
            self.Wi,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.intermediate_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute new hidden states from current hidden states.

        Args:
            hidden_states (torch.Tensor): The (unpadded) hidden states from
                the attention layer, of shape [nnz, dim].
        """
        return self.Wo(self.drop(self.act(self.Wi(hidden_states))))


class FlexBertGLU(FlexBertMLPBase):
    """Applies the GLU at the end of each FlexBERT layer.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.models.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has similar functionality, but
    introduces Gated Linear Units.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.Wi = nn.Linear(config.hidden_size, config.intermediate_size * 2, bias=config.mlp_in_bias)
        self.act = get_act_fn(config.hidden_act)
        self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        init_weights(
            self.config,
            self.Wi,
            layer_dim=self.config.hidden_size,
            layer_id=None,
            type_of_module=ModuleType.in_module,
        )
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.intermediate_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
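        # Wi produces 2 * intermediate_size features; the first half is the
        # activation input and the second half is the multiplicative gate.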
        input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        return self.Wo(self.drop(self.act(input) * gate))


class FlexBertParallelGLU(FlexBertMLPBase):
    """Applies the GLU at the end of each FlexBERT layer, using an intermediate_ff computed in parallel with the attention.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.models.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.models.bert.modeling_bert.BertOutput` with a single module that has similar functionality, but
    introduces Gated Linear Units and expects its input projection to be computed elsewhere.
    """

    def __init__(self, config: FlexBertConfig, layer_id: Optional[int] = None):
        super().__init__(config=config, layer_id=layer_id)
        self.act = get_act_fn(config.hidden_act)
        self.drop = nn.Dropout(config.mlp_dropout_prob) if config.mlp_dropout_prob > 0.0 else nn.Identity()
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_out_bias)

    def _init_weights(self, reset_params: bool = False):
        init_weights(
            self.config,
            self.Wo,
            layer_dim=self.config.intermediate_size,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def forward(self, intermediate_ff: torch.Tensor) -> torch.Tensor:
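        # intermediate_ff is expected to carry 2 * intermediate_size features,
        # produced upstream (e.g. by a projection fused with the attention's
        # QKV projection in a parallel block); this module only applies the
        # gating, dropout, and output projection.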
        input, gate = intermediate_ff.chunk(2, dim=-1)
        return self.Wo(self.drop(self.act(input) * gate))


MLP2CLS = {
    "mlp": FlexBertMLP,
    "glu": FlexBertGLU,
    "parallel_glu": FlexBertParallelGLU,
}


def get_mlp_layer(config: FlexBertConfig, layer_id: Optional[int] = None) -> FlexBertMLPBase:
    try:
        use_initial = (
            layer_id is not None
            and layer_id < config.num_initial_layers
            and getattr(config, "initial_mlp_layer", None) is not None
        )
        mlp_layer = config.initial_mlp_layer if use_initial else config.mlp_layer
        return MLP2CLS[mlp_layer](config, layer_id=layer_id)
    except KeyError as e:
        if use_initial:
            raise ValueError(
                f"Invalid MLP layer type: {config.initial_mlp_layer=}, must be one of {list(MLP2CLS.keys())}. {e}"
            ) from e
        else:
            raise ValueError(
                f"Invalid MLP layer type: {config.mlp_layer=}, must be one of {list(MLP2CLS.keys())}. {e}"
            ) from e