zzq1zh commited on
Commit
9948ba6
·
verified ·
1 Parent(s): 6538bc4

Upload modeling_BiMambaForMaskedLM.py

Browse files
Files changed (1) hide show
  1. modeling_BiMambaForMaskedLM.py +152 -0
modeling_BiMambaForMaskedLM.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import PreTrainedModel, AutoConfig
5
+ from transformers.modeling_outputs import MaskedLMOutput
6
+ from mamba_ssm.modules.mamba_simple import Mamba
7
+ from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
8
+ from mamba_ssm.models.config_mamba import MambaConfig
9
+ from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
10
+ try:
11
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
12
+ except ImportError:
13
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
14
+ from einops import rearrange
15
+
16
def convert_hf_config_to_mamba(hf_config) -> MambaConfig:
    """Translate a HuggingFace-style config object into a ``MambaConfig``.

    ``d_model`` and ``vocab_size`` are required attributes on *hf_config*;
    every other field falls back to a default when the attribute is absent.
    """

    def opt(name, default):
        # Read an optional attribute from the HF config.
        return getattr(hf_config, name, default)

    width = hf_config.d_model
    return MambaConfig(
        d_model=width,
        d_intermediate=opt("intermediate_size", 4 * width),
        n_layer=opt("n_layer", opt("num_hidden_layers", 12)),
        vocab_size=hf_config.vocab_size,
        ssm_cfg=opt("ssm_cfg", {}),
        attn_layer_idx=opt("attn_layer_idx", []),
        attn_cfg=opt("attn_cfg", {}),
        rms_norm=opt("rms_norm", True),
        residual_in_fp32=opt("residual_in_fp32", True),
        fused_add_norm=opt("fused_add_norm", False),
        pad_vocab_size_multiple=opt("pad_vocab_size_multiple", 8),
        tie_embeddings=opt("tie_embeddings", False),
    )
31
+
32
def patch_mixer_forward_to_accept_embeddings(model):
    """Inject a replacement ``forward`` into ``model.backbone``.

    The new forward accepts either ``input_ids`` or ``inputs_embeds`` (one of
    them is required) and an optional ``attention_mask`` used to zero out
    padding positions after the mixer layers.

    Args:
        model: object whose ``.backbone`` is a MixerModel-like module with
            ``embedding``, ``layers``, ``norm_f`` and ``fused_add_norm``
            attributes (e.g. a ``MambaLMHeadModel``).
    """

    def new_forward(self, input_ids=None, inputs_embeds=None, inference_params=None, attention_mask=None, **mixer_kwargs):
        # Resolve the starting hidden states from whichever input was given.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        elif input_ids is not None:
            hidden_states = self.embedding(input_ids)
        else:
            raise ValueError("You must provide either input_ids or inputs_embeds.")

        residual = None

        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
            )

        # Zero out padding positions before the final norm.
        # attention_mask: (batch, seq) — 1 for real tokens, 0 for padding.
        # Fix: the original dereferenced attention_mask unconditionally and
        # raised AttributeError when it was left as None; it also multiplied
        # `residual` without checking it had been set by at least one layer.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1)  # (batch, seq, 1)
            hidden_states = hidden_states * mask
            if residual is not None:
                residual = residual * mask

        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # prenorm=False: only the normalized output is needed here,
            # not the updated residual stream.
            hidden_states = layer_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm_f, RMSNorm),
            )
        return hidden_states

    # Bind the replacement as a bound method on this backbone instance only
    # (does not affect other instances of the same class).
    model.backbone.forward = new_forward.__get__(model.backbone, model.backbone.__class__)
80
+
81
class BiMambaForMaskedLM(PreTrainedModel):
    """Bidirectional Mamba encoder for masked language modeling.

    Runs two independent Mamba stacks — one over the sequence as-is, one over
    the time-reversed sequence — concatenates their per-position hidden
    states, projects back to ``d_model``, and scores tokens against the
    input embedding matrix (weight tying via ``F.linear``).
    """

    config_class = AutoConfig
    base_model_prefix = "bimamba"

    def __init__(self, config):
        super().__init__(config)  # <-- HF init
        mamba_cfg = convert_hf_config_to_mamba(config)

        # Shared token embedding + two direction-specific Mamba stacks + a
        # projection fusing the concatenated (2 * d_model) features.
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.mamba_forward = MambaLMHeadModel(mamba_cfg)
        self.mamba_backward = MambaLMHeadModel(mamba_cfg)
        self.lm_head_proj = nn.Linear(config.d_model * 2, config.d_model, bias=False)

        # Patch both backbones so they accept inputs_embeds directly.
        patch_mixer_forward_to_accept_embeddings(self.mamba_forward)
        patch_mixer_forward_to_accept_embeddings(self.mamba_backward)

        # self.post_init()  # wires up HF weight-tying & save/load

    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, new_emb):
        self.token_embedding = new_emb

    def get_output_embeddings(self):
        # NOTE(review): this returns the fusion projection (2*d_model -> d_model),
        # not a vocab-sized head — logits are produced by tying against
        # token_embedding.weight in forward(). Confirm this matches what HF
        # weight-tying expects before enabling post_init().
        return self.lm_head_proj

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # Sets a flag on every mixer block of both backbones; effective only
        # if the block's forward honors `gradient_checkpointing` — verify
        # against the installed mamba_ssm version.
        for backbone in (self.mamba_forward.backbone,
                         self.mamba_backward.backbone):
            for block in backbone.layers:
                block.gradient_checkpointing = True

    def forward(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        labels=None,
        return_dict=True,
    ):
        """Compute bidirectional hidden states and (optionally) the MLM loss.

        Args:
            input_ids: (batch, seq) token indices; ignored if inputs_embeds given.
            inputs_embeds: (batch, seq, d_model) precomputed embeddings.
            attention_mask: (batch, seq), 1 = real token, 0 = padding;
                defaults to all-ones (no padding) when omitted.
            labels: (batch, seq) targets; positions with -100 are ignored.
            return_dict: return a MaskedLMOutput instead of a tuple.

        Raises:
            ValueError: if neither input_ids nor inputs_embeds is provided.
        """
        if inputs_embeds is None:
            # Fix: previously this crashed with AttributeError on None.long()
            # when both inputs were omitted — raise a clear error instead.
            if input_ids is None:
                raise ValueError("You must provide either input_ids or inputs_embeds.")
            input_ids = input_ids.long()  # embedding lookup needs integer indices
            inputs_embeds = self.token_embedding(input_ids)

        # Fix: torch.flip(attention_mask, ...) and the patched backbones
        # dereference the mask unconditionally; default to "no padding"
        # (behavior-equivalent to an unmasked pass) when none is supplied.
        if attention_mask is None:
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device
            )

        hid_fwd = self.mamba_forward.backbone(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
        # Backward direction: reverse time, run, then reverse back so both
        # streams are aligned per position before concatenation.
        rev_emb = torch.flip(inputs_embeds, dims=[1])
        rev_mask = torch.flip(attention_mask, dims=[1])
        hid_bwd = self.mamba_backward.backbone(inputs_embeds=rev_emb, attention_mask=rev_mask)
        hid_bwd = torch.flip(hid_bwd, dims=[1])

        combined = torch.cat([hid_fwd, hid_bwd], dim=-1)
        projected = self.lm_head_proj(combined)
        # Weight tying: score each position against the input embedding matrix.
        logits = F.linear(projected, self.token_embedding.weight)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not return_dict:
            out = (logits, combined)
            return (loss,) + out if loss is not None else out

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=projected,
        )