import warnings

from transformers import BertConfig as TransformersBertConfig


class BertConfig(TransformersBertConfig):
    def __init__(
        self,
        alibi_starting_size: int = 512,
        normalization: str = "layernorm",
        attention_probs_dropout_prob: float = 0.0,
        head_pred_act: str = "gelu",
        deterministic_fa2: bool = False,
        allow_embedding_resizing: bool = False,
        **kwargs,
    ):
| | """Configuration class for MosaicBert. |
| | |
| | Args: |
| | alibi_starting_size (int): Use `alibi_starting_size` to determine how large of an alibi tensor to |
| | create when initializing the model. You should be able to ignore this parameter in most cases. |
| | Defaults to 512. |
| | attention_probs_dropout_prob (float): By default, turn off attention dropout in MosaicBERT |
| | Note that the custom Triton Flash Attention with ALiBi implementation does not support droput. |
| | However, Flash Attention 2 supports ALiBi and dropout https://github.com/Dao-AILab/flash-attention |
| | embed_dropout_prob (float): Dropout probability for the embedding layer. |
| | attn_out_dropout_prob (float): Dropout probability for the attention output layer. |
| | mlp_dropout_prob (float): Dropout probability for the MLP layer. |
| | allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size. |
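
        Example:
            A minimal illustrative sketch; the remaining keyword arguments fall back to the
            standard Hugging Face `BertConfig` defaults:

            >>> config = BertConfig(alibi_starting_size=1024)
            >>> config.alibi_starting_size
            1024
            >>> config.attention_probs_dropout_prob  # attention dropout is off by default
            0.0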
| | """ |
| | super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs) |
| | self.alibi_starting_size = alibi_starting_size |
| | self.normalization = normalization |
| | self.head_pred_act = head_pred_act |
| | self.deterministic_fa2 = deterministic_fa2 |
| | self.allow_embedding_resizing = allow_embedding_resizing |
| |
|
| |
|
class FlexBertConfig(TransformersBertConfig):
    model_type = "flex_bert"

    def __init__(
        self,
        attention_layer: str = "base",
        attention_probs_dropout_prob: float = 0.0,
        attn_out_bias: bool = False,
        attn_out_dropout_prob: float = 0.0,
        attn_qkv_bias: bool = False,
        bert_layer: str = "prenorm",
        decoder_bias: bool = True,
        embed_dropout_prob: float = 0.0,
        embed_norm: bool = True,
        final_norm: bool = False,
        embedding_layer: str = "absolute_pos",
        encoder_layer: str = "base",
        loss_function: str = "cross_entropy",
        loss_kwargs: dict = {},
        mlp_dropout_prob: float = 0.0,
        mlp_in_bias: bool = False,
        mlp_layer: str = "mlp",
        mlp_out_bias: bool = False,
        norm_kwargs: dict = {},
        normalization: str = "rmsnorm",
        padding: str = "unpadded",
        head_class_act: str = "silu",
        head_class_bias: bool = False,
        head_class_dropout: float = 0.0,
        head_class_norm: str = False,
        head_pred_act: str = "silu",
        head_pred_bias: bool = False,
        head_pred_dropout: float = 0.0,
        head_pred_norm: bool = True,
        pooling_type: str = "cls",
        rotary_emb_dim: int | None = None,
        rotary_emb_base: float = 10000.0,
        rotary_emb_scale_base=None,
        rotary_emb_interleaved: bool = False,
        use_fa2: bool = True,
        use_sdpa_attn_mask: bool = False,
        allow_embedding_resizing: bool = False,
        init_method: str = "default",
        init_std: float = 0.02,
        init_cutoff_factor: float = 2.0,
        init_small_embedding: bool = False,
        initial_attention_layer: str | None = None,
        initial_bert_layer: str | None = None,
        initial_mlp_layer: str | None = None,
        num_initial_layers: int = 1,
        skip_first_prenorm: bool = False,
        deterministic_fa2: bool = False,
        sliding_window: int = -1,
        global_attn_every_n_layers: int = -1,
        local_attn_rotary_emb_base: float = -1,
        local_attn_rotary_emb_dim: int | None = None,
        unpad_embeddings: bool = False,
        pad_logits: bool = False,
        compile_model: bool = False,
        masked_prediction: bool = False,
        causal_mask: bool = False,
        **kwargs,
    ):
| | """ |
| | Args: |
| | attention_layer (str): Attention layer type. |
| | attention_probs_dropout_prob (float): Dropout probability for attention probabilities. |
| | attn_out_bias (bool): use bias in attention output projection. |
| | attn_out_dropout_prob (float): Dropout probability for attention output. |
| | attn_qkv_bias (bool): use bias for query, key, value linear layer(s). |
| | bert_layer (str): BERT layer type. |
| | decoder_bias (bool): use bias in decoder linear layer. |
| | embed_dropout_prob (float): Dropout probability for embeddings. |
| | embed_norm (bool): Normalize embedding output. |
| | final_norm (bool): Add normalization after the final encoder layer and before head. |
| | embedding_layer (str): Embedding layer type. |
| | encoder_layer (str): Encoder layer type. |
| | loss_function (str): Loss function to use. |
| | loss_kwargs (dict): Keyword arguments for loss function. |
| | mlp_dropout_prob (float): Dropout probability for MLP layers. |
| | mlp_in_bias (bool): Use bias in MLP input linear layer. |
| | mlp_layer (str): MLP layer type. |
| | mlp_out_bias (bool): Use bias in MLP output linear layer. |
| | norm_kwargs (dict): Keyword arguments for normalization layers. |
| | normalization (str): Normalization type. |
| | padding (str): Unpad inputs. Best with `use_fa2=True`. |
| | head_class_act (str): Activation function for classification head. |
| | head_class_bias (bool): Use bias in classification head linear layer(s). |
| | head_class_dropout (float): Dropout probability for classification head. |
| | head_class_norm (str): Normalization type for classification head. |
| | head_pred_act (str): Activation function for prediction head. |
| | head_pred_bias (bool): Use bias in prediction head linear layer(s). |
| | head_pred_dropout (float): Dropout probability for prediction head. |
| | head_pred_norm (bool): Normalize prediction head output. |
| | pooling_type (str): Pooling type. |
| | rotary_emb_dim (int | None): Rotary embedding dimension. |
| | rotary_emb_base (float): Rotary embedding base. |
| | rotary_emb_scale_base (float): Rotary embedding scale base. |
| | rotary_emb_interleaved (bool): Use interleaved rotary embeddings. |
| | use_fa2 (bool): Use FlashAttention2. Requires flash_attn package. |
| | use_sdpa_attn_mask (bool): Pass a mask to SDPA. This will prevent SDPA from using the PyTorch FA2 kernel. |
| | allow_embedding_resizing (bool): Embeddings will be automatically resized when they are smaller than the tokenizer vocab size. |
| | init_method (str): Model layers initialization method. |
| | init_std (float): Standard deviation for initialization. Used for normal and full_megatron init. |
| | init_cutoff_factor (float): Cutoff factor for initialization. Used for normal and full_megatron init. |
| | init_small_embedding (bool): Initialize embeddings with RWKV small init. |
| | initial_attention_layer (str | None): Replace first `num_initial_layers` attention_layer instance with this layer. |
| | initial_bert_layer (str | None): Replace first `num_initial_layers` bert_layer instance with this layer. |
| | initial_mlp_layer (str | None): Replace first `num_initial_layers` mlp_layer instance with this layer. |
| | num_initial_layers (int): Number of initial layers to set via `initial_attention_layer`, `initial_bert_layer`, and `initial_mlp_layer`. |
| | skip_first_prenorm (bool): Skip pre-normalization for the first bert layer. Requires `embed_norm=True`. |
| | deterministic_fa2 (bool): Use Flash Attention 2 deterministic mode. This is slower then the default non-deterministic mode. |
| | sliding_window (int): Use sliding window attention with window size `n`. -1 to disable. Window size split between the left and right context. Only supports FA2. |
| | global_attn_every_n_layers (int): Use global attention every `n` layers and sliding window for the rest. -1 to disable. |
| | local_attn_rotary_emb_base (float): Rotary embedding base for local attention. -1 to disable and use `rotary_emb_base` for all layers. |
| | local_attn_rotary_emb_dim (int | None): Rotary embedding dimension for local attention. None to disable and use `rotary_emb_dim` for all layers. |
| | unpad_embeddings (bool): Unpad inputs before the embedding layer. |
| | pad_logits (bool): Pad logits after the calculating the loss. |
| | compile_model (bool): Compile the subset of the model which can be compiled. |
| | masked_prediction (bool): Use only pass the masked tokens throught the final MLM layers |
| | causal (bool): Use a causal mask, defaulting to false. |
| | **kwargs: Additional keyword arguments. |
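
        Example:
            A minimal illustrative sketch; the remaining keyword arguments fall back to the
            standard Hugging Face `BertConfig` defaults, and the layer-type strings are only
            resolved later by the model's layer-building code:

            >>> config = FlexBertConfig(sliding_window=128, use_fa2=True)
            >>> config.padding
            'unpadded'
            >>> config.sliding_window
            128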
| | """ |
| | super().__init__(attention_probs_dropout_prob=attention_probs_dropout_prob, **kwargs) |
| | self.attention_layer = attention_layer |
| | self.attn_out_bias = attn_out_bias |
| | self.attn_out_dropout_prob = attn_out_dropout_prob |
| | self.attn_qkv_bias = attn_qkv_bias |
| | self.bert_layer = bert_layer |
| | self.decoder_bias = decoder_bias |
| | self.embed_dropout_prob = embed_dropout_prob |
| | self.embed_norm = embed_norm |
| | self.final_norm = final_norm |
| | self.embedding_layer = embedding_layer |
| | self.encoder_layer = encoder_layer |
| | self.loss_function = loss_function |
| | self.loss_kwargs = loss_kwargs |
| | self.mlp_dropout_prob = mlp_dropout_prob |
| | self.mlp_in_bias = mlp_in_bias |
| | self.mlp_layer = mlp_layer |
| | self.mlp_out_bias = mlp_out_bias |
| | self.norm_kwargs = norm_kwargs |
| | self.normalization = normalization |
| | self.padding = padding |
| | self.head_class_act = head_class_act |
| | self.head_class_bias = head_class_bias |
| | self.head_class_dropout = head_class_dropout |
| | self.head_class_norm = head_class_norm |
| | self.head_pred_act = head_pred_act |
| | self.head_pred_bias = head_pred_bias |
| | self.head_pred_dropout = head_pred_dropout |
| | self.head_pred_norm = head_pred_norm |
| | self.pooling_type = pooling_type |
| | self.rotary_emb_dim = rotary_emb_dim |
| | self.rotary_emb_base = rotary_emb_base |
| | self.rotary_emb_scale_base = rotary_emb_scale_base |
| | self.rotary_emb_interleaved = rotary_emb_interleaved |
| | self.use_fa2 = use_fa2 |
| | self.use_sdpa_attn_mask = use_sdpa_attn_mask |
| | self.allow_embedding_resizing = allow_embedding_resizing |
| | self.init_method = init_method |
| | self.init_std = init_std |
| | self.init_cutoff_factor = init_cutoff_factor |
| | self.init_small_embedding = init_small_embedding |
| | self.initial_attention_layer = initial_attention_layer |
| | self.initial_bert_layer = initial_bert_layer |
| | self.initial_mlp_layer = initial_mlp_layer |
| | self.num_initial_layers = num_initial_layers |
| | self.skip_first_prenorm = skip_first_prenorm |
| | self.deterministic_fa2 = deterministic_fa2 |
| | self.sliding_window = sliding_window |
| | self.global_attn_every_n_layers = global_attn_every_n_layers |
| | self.local_attn_rotary_emb_base = local_attn_rotary_emb_base |
| | self.local_attn_rotary_emb_dim = local_attn_rotary_emb_dim |
| | self.unpad_embeddings = unpad_embeddings |
| | self.pad_logits = pad_logits |
| | self.compile_model = compile_model |
| | self.masked_prediction = masked_prediction |
| | self.causal_mask = causal_mask |
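        # `return_z_loss` is only implemented by the `fa_cross_entropy` loss function and requires a
        # positive `lse_square_scale`; `inplace_backward` is forced off because it produces incorrect metrics.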
        if loss_kwargs.get("return_z_loss", False):
            if loss_function != "fa_cross_entropy":
                raise ValueError("loss_function must be 'fa_cross_entropy' when return_z_loss is True")
            if loss_kwargs.get("lse_square_scale", 0) <= 0:
                raise ValueError(
                    "lse_square_scale must be passed to `loss_kwargs` and must be greater than 0 for z_loss"
                )
            if loss_kwargs.get("inplace_backward", False):
                self.loss_kwargs["inplace_backward"] = False
                warnings.warn("`inplace_backward=True` will cause incorrect metrics. Automatically setting to False.")

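        # When alternating global and local attention, `global_attn_every_n_layers` must evenly
        # divide `num_hidden_layers - 1`.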
        if global_attn_every_n_layers > 0 and (self.num_hidden_layers - 1) % global_attn_every_n_layers != 0:
            raise ValueError(
                f"{global_attn_every_n_layers=} must be a divisor of one less than {self.num_hidden_layers=}"
            )

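        # Sliding window attention is only implemented for FlashAttention 2 and constrains the window
        # size; when no window is set, the local attention overrides must stay disabled.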
        if self.sliding_window != -1:
            if not self.use_fa2:
                raise ValueError("Sliding window attention is only supported with FlashAttention2")
            if self.sliding_window % 2 != 0 and self.sliding_window % 64 != 0:
                raise ValueError(
                    f"Sliding window must be an even number and divisible by 64: {self.sliding_window=} {self.sliding_window % 64} {self.sliding_window % 2}"
                )
        else:
            if self.global_attn_every_n_layers != -1:
                raise ValueError("global_attn_every_n_layers must be -1 when sliding_window is disabled")
            if self.local_attn_rotary_emb_base != -1:
                raise ValueError("local_attn_rotary_emb_base must be -1 when sliding_window is disabled")
            if self.local_attn_rotary_emb_dim is not None:
                raise ValueError("local_attn_rotary_emb_dim must be None when sliding_window is disabled")

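        # Unpadded embeddings require unpadded inputs throughout and are incompatible with the
        # absolute position embedding layer; `pad_logits` only applies when embeddings are unpadded.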
        if self.unpad_embeddings and self.padding != "unpadded":
            warnings.warn(
                "`unpad_embeddings=True` requires `padding='unpadded'`. Automatically setting `padding='unpadded'`."
            )
            self.padding = "unpadded"
        if self.pad_logits and not self.unpad_embeddings:
            raise ValueError("`pad_logits=True` requires `unpad_embeddings=True`")
        if self.unpad_embeddings and self.embedding_layer == "absolute_pos":
            raise ValueError(f"{self.unpad_embeddings=} is incompatible with {self.embedding_layer=}")


PADDING = ["unpadded", "padded"]


def maybe_add_padding(config: FlexBertConfig, config_option: str) -> str:
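    """Prefix `config_option` with the config's padding mode if it is not already prefixed.

    A small illustrative sketch (assumes a `FlexBertConfig` built with the default `padding="unpadded"`):

    >>> config = FlexBertConfig()
    >>> maybe_add_padding(config, "base")
    'unpadded_base'
    >>> maybe_add_padding(config, "padded_base")
    'padded_base'
    """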
    if config.padding not in PADDING:
        raise ValueError(f"Invalid padding type: {config.padding}, must be one of {PADDING}")

    if not any(config_option.startswith(pad + "_") for pad in PADDING):
        config_option = f"{config.padding}_{config_option}"

    return config_option