Upload modeling_BiMambaForMaskedLM.py
Browse files- modeling_BiMambaForMaskedLM.py +152 -0
modeling_BiMambaForMaskedLM.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from transformers import PreTrainedModel, AutoConfig
from transformers.modeling_outputs import MaskedLMOutput

from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

# Triton fused-norm kernels are optional; fall back to None sentinels so the
# non-fused code path still works without them.
try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 15 |
+
|
| 16 |
+
def convert_hf_config_to_mamba(hf_config) -> MambaConfig:
    """Translate a HuggingFace-style config object into a ``MambaConfig``.

    ``d_model`` and ``vocab_size`` are required attributes of *hf_config*;
    every other field falls back to the default shown below when absent.
    """

    def opt(name, default):
        # Read an optional attribute off the HF config.
        return getattr(hf_config, name, default)

    return MambaConfig(
        d_model=hf_config.d_model,
        d_intermediate=opt("intermediate_size", 4 * hf_config.d_model),
        n_layer=opt("n_layer", opt("num_hidden_layers", 12)),
        vocab_size=hf_config.vocab_size,
        ssm_cfg=opt("ssm_cfg", {}),
        attn_layer_idx=opt("attn_layer_idx", []),
        attn_cfg=opt("attn_cfg", {}),
        rms_norm=opt("rms_norm", True),
        residual_in_fp32=opt("residual_in_fp32", True),
        fused_add_norm=opt("fused_add_norm", False),
        pad_vocab_size_multiple=opt("pad_vocab_size_multiple", 8),
        tie_embeddings=opt("tie_embeddings", False),
    )
|
| 31 |
+
|
| 32 |
+
def patch_mixer_forward_to_accept_embeddings(model):
    """
    Inject a replacement ``forward`` into ``model.backbone`` (a MixerModel
    instance) so that it accepts either ``input_ids`` or ``inputs_embeds``,
    and optionally zeroes out padding positions via ``attention_mask``.

    The patch is applied in place by binding ``new_forward`` to the backbone
    *instance*; the class-level forward of other instances is untouched.
    """

    def new_forward(self, input_ids=None, inputs_embeds=None, inference_params=None, attention_mask=None, **mixer_kwargs):
        # Resolve the initial hidden states from whichever input was given.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        elif input_ids is not None:
            hidden_states = self.embedding(input_ids)
        else:
            raise ValueError("You must provide either input_ids or inputs_embeds.")

        residual = None

        # attention_mask: (batch_size, seq_len) -- 1 for real tokens, 0 for
        # padding. Broadcast to (batch_size, seq_len, 1) so it can scale the
        # (batch_size, seq_len, d_model) hidden states. A missing mask means
        # "no padding", so masking is skipped entirely.
        mask = attention_mask.unsqueeze(-1) if attention_mask is not None else None

        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
            )

        # Zero out padding positions before the final norm. residual is still
        # None when the backbone has no layers, so guard it.
        if mask is not None:
            hidden_states = hidden_states * mask
            if residual is not None:
                residual = residual * mask

        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Fused add + norm; prenorm=False since the residual stream is not
            # needed after the final norm. Requires the optional triton kernels
            # (layer_norm_fn / RMSNorm) to have imported successfully; RMSNorm
            # may be None, in which case isinstance() must not be called on it.
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=RMSNorm is not None and isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states

    # Bind the new forward method to the backbone instance.
    model.backbone.forward = new_forward.__get__(model.backbone, model.backbone.__class__)
|
| 80 |
+
|
| 81 |
+
class BiMambaForMaskedLM(PreTrainedModel):
    """Bidirectional Mamba encoder for masked language modeling.

    Runs two independent Mamba stacks -- one over the sequence as-is and one
    over the time-reversed sequence -- concatenates their hidden states,
    projects back down to ``d_model``, and scores tokens against the shared
    input embedding matrix (weight tying by construction).
    """

    config_class = AutoConfig
    base_model_prefix = "bimamba"

    def __init__(self, config):
        super().__init__(config)  # HF PreTrainedModel init
        mamba_cfg = convert_hf_config_to_mamba(config)

        # Shared token embedding + two directional Mamba stacks + projection
        # fusing the two directions back to d_model.
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.mamba_forward = MambaLMHeadModel(mamba_cfg)
        self.mamba_backward = MambaLMHeadModel(mamba_cfg)
        self.lm_head_proj = nn.Linear(config.d_model * 2, config.d_model, bias=False)

        # Patch both backbones so they accept inputs_embeds / attention_mask.
        patch_mixer_forward_to_accept_embeddings(self.mamba_forward)
        patch_mixer_forward_to_accept_embeddings(self.mamba_backward)

        # self.post_init()  # wires up HF weight-tying & save/load

    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, new_emb):
        self.token_embedding = new_emb

    def get_output_embeddings(self):
        # Decoding reuses token_embedding.weight via F.linear in forward();
        # expose the projection that precedes it.
        return self.lm_head_proj

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # NOTE(review): flips a flag on each block; assumes the mamba_ssm
        # Block honors `gradient_checkpointing` -- confirm with the installed
        # mamba_ssm version.
        for backbone in (self.mamba_forward.backbone,
                         self.mamba_backward.backbone):
            for block in backbone.layers:
                block.gradient_checkpointing = True

    def forward(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        labels=None,
        return_dict=True,
    ):
        """Run both directions and return MLM logits (and loss if labels given).

        Args:
            input_ids: (batch, seq) token ids; ignored if inputs_embeds given.
            inputs_embeds: (batch, seq, d_model) precomputed embeddings.
            attention_mask: (batch, seq) 1/0 mask; None means no padding.
            labels: (batch, seq) targets with -100 at ignored positions.
            return_dict: if False, return a plain tuple instead of
                MaskedLMOutput.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must provide either input_ids or inputs_embeds.")
            input_ids = input_ids.long()
            inputs_embeds = self.token_embedding(input_ids)

        # Treat a missing mask as "no padding" so the flips below are safe.
        if attention_mask is None:
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], dtype=inputs_embeds.dtype, device=inputs_embeds.device
            )

        hid_fwd = self.mamba_forward.backbone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        # Backward direction: flip along time, run, flip back so positions align.
        rev_emb = torch.flip(inputs_embeds, dims=[1])
        rev_mask = torch.flip(attention_mask, dims=[1])
        hid_bwd = self.mamba_backward.backbone(inputs_embeds=rev_emb, attention_mask=rev_mask)
        hid_bwd = torch.flip(hid_bwd, dims=[1])

        combined = torch.cat([hid_fwd, hid_bwd], dim=-1)
        projected = self.lm_head_proj(combined)
        # Score against the input embedding matrix (tied decoder weights).
        logits = F.linear(projected, self.token_embedding.weight)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
            # CrossEntropyLoss requires integer (long) targets.
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1).long())

        if not return_dict:
            out = (logits, combined)
            return (loss,) + out if loss is not None else out

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            # NOTE(review): HF convention is a tuple of per-layer states;
            # here a single projected tensor is returned -- kept for
            # backward compatibility with existing callers.
            hidden_states=projected,
        )
|