PyTorch
plantbimoe
biology
genomics
language model
plants
custom_code
linkp committed on
Commit 88ee214
1 Parent(s): 05a2b22

add: model file

config.json ADDED
@@ -0,0 +1,49 @@
+ {
+   "_name_or_path": "checkpoint-pretrain",
+   "architectures": [
+     "PlantbimoeForMaskedLM"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_plantbimoe.PlantbimoeConfig",
+     "AutoModel": "modeling_plantbimoe.Plantbimoe",
+     "AutoModelForMaskedLM": "modeling_plantbimoe.PlantbimoeForMaskedLM",
+     "AutoModelForSequenceClassification": "modeling_plantbimoe.PlantbimoeForSequenceClassification"
+   },
+   "bidirectional": true,
+   "bidirectional_strategy": "add",
+   "bidirectional_weight_tie": true,
+   "d_model": 512,
+   "fused_add_norm": true,
+   "initializer_cfg": {
+     "initializer_range": 0.02,
+     "n_residuals_per_layer": 1,
+     "rescale_prenorm_residual": true
+   },
+   "intermediate_size": 1408,
+   "model_type": "plantbimoe",
+   "n_layer": 16,
+   "norm_epsilon": 1e-05,
+   "num_experts": 4,
+   "num_experts_per_tok": 1,
+   "pad_token_id": 4,
+   "pad_vocab_size_multiple": 6,
+   "residual_in_fp32": false,
+   "rms_norm": true,
+   "ssm_cfg": {
+     "bias": false,
+     "conv_bias": true,
+     "d_conv": 4,
+     "d_state": 16,
+     "dt_init": "random",
+     "dt_init_floor": 0.0001,
+     "dt_max": 0.1,
+     "dt_min": 0.001,
+     "dt_rank": "auto",
+     "dt_scale": 1.0,
+     "expand": 2,
+     "use_fast_path": true
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.1",
+   "vocab_size": 12
+ }
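
The `auto_map` entries above register the custom classes with the Transformers Auto* loaders, so the checkpoint can be loaded with `trust_remote_code=True`. A minimal loading sketch, not part of this commit: the repo id is a placeholder, a CUDA build of `mamba-ssm` (with its Triton kernels) is assumed to be installed, and the `configuration_plantbimoe.py` / `modeling_plantbimoe.py` modules are assumed to be resolvable at the repository root as referenced by `auto_map`.

import torch
from transformers import AutoModelForMaskedLM

# "<repo-id>" is a placeholder; substitute the actual Hub path of this model.
model = AutoModelForMaskedLM.from_pretrained("<repo-id>", trust_remote_code=True).to("cuda")

# Ids follow the vocabulary in tokenization_plantbimoe.py: [CLS]=0, [SEP]=1, A=7, C=8, G=9, T=10
input_ids = torch.tensor([[0, 7, 8, 9, 10, 1]], device="cuda")
with torch.no_grad():
    logits = model(input_ids=input_ids).logits
print(logits.shape)  # expected (1, 6, 12) with this config's vocab_size
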
plantbimoe/configuration_plantbimoe.py ADDED
@@ -0,0 +1,56 @@
+ """Plantbimoe config for Hugging Face."""
+
+ from typing import Optional, Union
+
+ from transformers import PretrainedConfig
+
+
+ class PlantbimoeConfig(PretrainedConfig):
+     """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""
+     model_type = "plantbimoe"
+
+     def __init__(
+         self,
+         d_model: int = 2560,
+         n_layer: int = 64,
+         vocab_size: int = 50277,
+         ssm_cfg: Optional[dict] = None,
+         rms_norm: bool = True,
+         residual_in_fp32: bool = True,
+         fused_add_norm: bool = True,
+         pad_vocab_size_multiple: int = 8,
+         # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
+         norm_epsilon: float = 1e-5,
+         # Used in init_weights
+         initializer_cfg: Optional[dict] = None,
+         # Plantbimoe-specific params
+         bidirectional: bool = True,
+         bidirectional_strategy: Union[str, None] = "add",
+         bidirectional_weight_tie: bool = True,
+         intermediate_size: int = 3840,
+         num_experts: int = 16,
+         num_experts_per_tok: int = 2,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.d_model = d_model
+         self.n_layer = n_layer
+         self.vocab_size = vocab_size
+         self.ssm_cfg = ssm_cfg
+         self.rms_norm = rms_norm
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.pad_vocab_size_multiple = pad_vocab_size_multiple
+         self.norm_epsilon = norm_epsilon
+         self.initializer_cfg = initializer_cfg
+         self.bidirectional = bidirectional
+         self.bidirectional_strategy = bidirectional_strategy
+         self.bidirectional_weight_tie = bidirectional_weight_tie
+         self.intermediate_size = intermediate_size
+         self.num_experts = num_experts
+         self.num_experts_per_tok = num_experts_per_tok
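
The constructor defaults above follow the upstream MambaConfig scale; the values this checkpoint actually uses come from config.json. A minimal sketch of building the same configuration in code, assuming the `plantbimoe/` package directory is importable from the working directory:

from plantbimoe.configuration_plantbimoe import PlantbimoeConfig

# Values mirror config.json from this commit; unspecified fields keep the defaults above.
config = PlantbimoeConfig(
    d_model=512,
    n_layer=16,
    vocab_size=12,
    intermediate_size=1408,
    num_experts=4,
    num_experts_per_tok=1,
    pad_vocab_size_multiple=6,
    residual_in_fp32=False,
)
print(config.model_type)  # "plantbimoe"
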
plantbimoe/modeling_plantbimoe.py ADDED
@@ -0,0 +1,725 @@
+ """Plantbimoe model for Hugging Face."""
+
+ import inspect
+ import math
+ from functools import partial
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from torch import Tensor
+ from mamba_ssm.modules.mamba_simple import Mamba
+ try:
+     from mamba_ssm.modules.mamba_simple import Block  # Legacy mambav1 file structure
+ except ImportError:
+     from mamba_ssm.modules.block import Block  # mambav2 file structure
+ from torch import nn
+ from torch.nn import functional as F
+ from transformers import PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
+
+ try:
+     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn  # Legacy mambav1 file structure
+ except ImportError:
+     try:
+         from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn  # mambav2 file structure
+     except ImportError:
+         RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+
+ from .configuration_plantbimoe import PlantbimoeConfig
+
+
+ def create_block(
+     d_model,
+     ssm_cfg=None,
+     norm_epsilon=1e-5,
+     rms_norm=False,
+     residual_in_fp32=False,
+     fused_add_norm=False,
+     layer_idx=None,
+     bidirectional=True,
+     bidirectional_strategy="add",
+     bidirectional_weight_tie=True,
+     device=None,
+     dtype=None,
+ ):
+     """Create a Plantbimoe block.
+
+     Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
+     """
+     if ssm_cfg is None:
+         ssm_cfg = {}
+     factory_kwargs = {"device": device, "dtype": dtype}
+     bidirectional_kwargs = {
+         "bidirectional": bidirectional,
+         "bidirectional_strategy": bidirectional_strategy,
+         "bidirectional_weight_tie": bidirectional_weight_tie,
+     }
+     mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
+     norm_cls = partial(
+         nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
+     )
+     block_cls = Block
+     # mambav2 compatibility
+     if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
+         block = block_cls(
+             d_model,
+             mixer_cls,
+             mlp_cls=nn.Identity,
+             norm_cls=norm_cls,
+             fused_add_norm=fused_add_norm,
+             residual_in_fp32=residual_in_fp32,
+         )
+     else:
+         block = block_cls(
+             d_model,
+             mixer_cls,
+             norm_cls=norm_cls,
+             fused_add_norm=fused_add_norm,
+             residual_in_fp32=residual_in_fp32,
+         )
+     block.layer_idx = layer_idx
+     return block
+
+
+ class MambaBlock(nn.Module):
+     def __init__(self, config, layer_idx, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False, moe=False, device=None, dtype=None):
+         """
+         Simple block wrapping a mixer class with LayerNorm/RMSNorm and a residual connection.
+
+         This block has a slightly different structure compared to a regular
+         prenorm Transformer block.
+         The standard block is: LN -> MHA/MLP -> Add.
+         [Ref: https://arxiv.org/abs/2002.04745]
+         Here we have: Add -> LN -> Mixer, returning both
+         the hidden_states (output of the mixer) and the residual.
+         This is purely for performance reasons, as we can fuse add and LayerNorm.
+         The residual needs to be provided (except for the very first block).
+         """
+         super().__init__()
+         factory_kwargs = {"device": device, "dtype": dtype}
+
+         self.mixer = create_block(
+             config.d_model,
+             ssm_cfg=config.ssm_cfg,
+             norm_epsilon=config.norm_epsilon,
+             rms_norm=config.rms_norm,
+             residual_in_fp32=config.residual_in_fp32,
+             fused_add_norm=config.fused_add_norm,
+             layer_idx=layer_idx,
+             bidirectional=config.bidirectional,
+             bidirectional_strategy=config.bidirectional_strategy,
+             bidirectional_weight_tie=config.bidirectional_weight_tie,
+             **factory_kwargs,
+         )
+         ffn_layer_class = PlantbimoeSparseMoeBlock if moe else PlantbimoeMlp
+         self.feed_forward = ffn_layer_class(config)
+         self.residual_in_fp32 = residual_in_fp32
+         self.fused_add_norm = fused_add_norm
+         self.norm = norm_cls(config.d_model)
+
+         if self.fused_add_norm:
+             assert RMSNorm is not None, "RMSNorm import fails"
+             assert isinstance(
+                 self.norm, (nn.LayerNorm, RMSNorm)
+             ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
+
+     def forward(
+         self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
+     ):
+         r"""Pass the input through the encoder layer.
+
+         Args:
+             hidden_states: the sequence to the encoder layer (required).
+             residual: hidden_states = Mixer(LN(residual))
+         """
+         hidden_states, residual = self.mixer(hidden_states, residual, inference_params=inference_params)
+
+         if not self.fused_add_norm:
+             residual = (hidden_states + residual) if residual is not None else hidden_states
+             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
+             if self.residual_in_fp32:
+                 residual = residual.to(torch.float32)
+         else:
+             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
+             hidden_states, residual = fused_add_norm_fn(
+                 hidden_states,
+                 self.norm.weight,
+                 self.norm.bias,
+                 residual=residual,
+                 prenorm=True,
+                 residual_in_fp32=self.residual_in_fp32,
+                 eps=self.norm.eps,
+             )
+
+         hidden_states = self.feed_forward(hidden_states)
+
+         return hidden_states, residual
+
+
+ # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba
+ class PlantbimoeSparseMoeBlock(nn.Module):
+     """
+     This implementation is strictly equivalent to standard MoE with full capacity (no
+     dropped tokens). It is faster since it formulates MoE operations
+     in terms of block-sparse operations to accommodate imbalanced
+     assignments of tokens to experts, whereas standard MoE either
+     (1) drops tokens at the cost of reduced performance or (2) sets
+     the capacity factor to the number of experts and thus wastes computation
+     and memory on padding.
+     """
+
+     def __init__(self, config: PlantbimoeConfig):
+         super().__init__()
+         self.hidden_dim = config.d_model
+         # self.ffn_dim = config.intermediate_size
+         self.num_experts = config.num_experts
+         self.top_k = config.num_experts_per_tok
+
+         self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+         self.experts = nn.ModuleList([PlantbimoeMlp(config) for _ in range(self.num_experts)])
+
+     def forward(self, hidden_states: torch.Tensor):
+         batch_size, sequence_length, hidden_dim = hidden_states.shape
+
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         # router_logits: (batch * sequence_length, n_experts)
+         router_logits = self.router(hidden_states)
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+         # we cast back to the input dtype
+         routing_weights = routing_weights.to(hidden_states.dtype)
+
+         final_hidden_states = torch.zeros(
+             (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+         )
+
+         # One-hot encode the selected experts to create an expert mask;
+         # this will be used to easily index which expert is going to be solicited
+         expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+         # Loop over all available experts in the model and perform the computation on each expert
+         for expert_idx in range(self.num_experts):
+             expert_layer = self.experts[expert_idx]
+             idx, top_x = torch.where(expert_mask[expert_idx])
+
+             if top_x.shape[0] == 0:
+                 continue
+
+             # Index the correct hidden states and compute the expert hidden state for
+             # the current expert. We need to make sure to multiply the output hidden
+             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+             current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+             current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+
+             # However `index_add_` only supports torch tensors for indexing, so we use
+             # the `top_x` tensor here.
+             final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+         return final_hidden_states
+
+
+ # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba
+ class PlantbimoeMlp(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.d_model
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         self.act_fn = nn.SiLU()
+
+     def forward(self, x):
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class BiMambaWrapper(nn.Module):
+     """Thin wrapper around Mamba to support bi-directionality."""
+
+     def __init__(
+         self,
+         d_model: int,
+         bidirectional: bool = True,
+         bidirectional_strategy: Optional[str] = "add",
+         bidirectional_weight_tie: bool = True,
+         **mamba_kwargs,
+     ):
+         super().__init__()
+         if bidirectional and bidirectional_strategy is None:
+             bidirectional_strategy = "add"  # Default strategy: `add`
+         if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
+             raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
+         self.bidirectional = bidirectional
+         self.bidirectional_strategy = bidirectional_strategy
+         self.mamba_fwd = Mamba(
+             d_model=d_model,
+             **mamba_kwargs
+         )
+         if bidirectional:
+             self.mamba_rev = Mamba(
+                 d_model=d_model,
+                 **mamba_kwargs
+             )
+             if bidirectional_weight_tie:  # Tie in and out projections (where most of param count lies)
+                 self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
+                 self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
+                 self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
+                 self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
+         else:
+             self.mamba_rev = None
+
+     def forward(self, hidden_states, inference_params=None):
+         """Bidirectional-enabled forward pass.
+
+         hidden_states: (B, L, D)
+         Returns: same shape as hidden_states
+         """
+         out = self.mamba_fwd(hidden_states, inference_params=inference_params)
+         if self.bidirectional:
+             out_rev = self.mamba_rev(
+                 hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension
+                 inference_params=inference_params
+             ).flip(dims=(1,))  # Flip back for combining with forward hidden states
+             if self.bidirectional_strategy == "add":
+                 out = out + out_rev
+             elif self.bidirectional_strategy == "ew_multiply":
+                 out = out * out_rev
+             else:
+                 raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")
+         return out
+
+
+ class PlantbimoeEmbeddings(nn.Module):
+     def __init__(
+         self,
+         config: PlantbimoeConfig,
+         device=None,
+         dtype=None,
+     ):
+         super().__init__()
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs)
+
+     def forward(self, input_ids):
+         """
+         input_ids: (batch, seqlen)
+         """
+         return self.word_embeddings(input_ids)
+
+
+ class PlantbimoeMixerModel(nn.Module):
+     def __init__(
+         self,
+         config: PlantbimoeConfig,
+         device=None,
+         dtype=None,
+     ) -> None:
+         super().__init__()
+         factory_kwargs = {"device": device, "dtype": dtype}
+
+         self.fused_add_norm = config.fused_add_norm
+         self.residual_in_fp32 = config.residual_in_fp32
+
+         self.embeddings = PlantbimoeEmbeddings(config, **factory_kwargs)
+
+         # Mamba changes the order of residual and layer norm:
+         # Instead of LN -> Attn / MLP -> Add, we do:
+         # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
+         # the main branch (output of MLP / Mixer). The model definition is unchanged.
+         # This is for performance reasons: we can fuse add + layer_norm.
+         if config.fused_add_norm:
+             if layer_norm_fn is None or rms_norm_fn is None:
+                 raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
+
+         norm_cls = partial(
+             nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.norm_epsilon, **factory_kwargs
+         )
+
+         self.layers = nn.ModuleList()
+         for i in range(config.n_layer):
+             moe = (i + 1) % 2 == 0
+             self.layers.append(
+                 MambaBlock(
+                     config=config,
+                     layer_idx=i,
+                     norm_cls=norm_cls,
+                     fused_add_norm=config.fused_add_norm,
+                     residual_in_fp32=config.residual_in_fp32,
+                     moe=moe
+                 )
+             )
+
+         norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
+             config.d_model, eps=config.norm_epsilon, **factory_kwargs
+         )
+         self.norm_f = norm_f
+
+     def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
+         """Mixer forward."""
+         all_hidden_states = []
+         if inputs_embeds is not None:
+             hidden_states = inputs_embeds
+         else:
+             hidden_states = self.embeddings(input_ids)
+
+         residual = None
+         for layer in self.layers:
+             if output_hidden_states:
+                 all_hidden_states.append(hidden_states)
+             # TODO: Add support for gradient checkpointing
+             hidden_states, residual = layer(
+                 hidden_states, residual, inference_params=None
+             )
+
+         if not self.fused_add_norm:
+             residual = (hidden_states + residual) if residual is not None else hidden_states
+             hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
+         else:
+             fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
+             # Set prenorm=False here since we don't need the residual
+             hidden_states = fused_add_norm_fn(
+                 hidden_states,
+                 self.norm_f.weight,
+                 self.norm_f.bias,
+                 eps=self.norm_f.eps,
+                 residual=residual,
+                 prenorm=False,
+                 residual_in_fp32=self.residual_in_fp32,
+             )
+         if output_hidden_states:
+             all_hidden_states.append(hidden_states)
+         return hidden_states, all_hidden_states
+
+
+ def cross_entropy(logits, y, ignore_index=-100):
+     """Cross entropy loss."""
+     logits = logits.view(-1, logits.shape[-1])
+     y = y.view(-1)
+     return F.cross_entropy(logits, y, ignore_index=ignore_index)
+
+
+ def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
+     """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
+     logits = logits.view(-1, logits.shape[-1])
+     y = y.view(-1)
+     ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
+     loss_weights = loss_weights.view(-1)
+     loss_weights[y == ignore_index] = 0.0
+     # TODO: Follows GPN implementation, but should we remove weight normalization?
+     return (ce * (loss_weights / loss_weights.sum())).sum()
+
+
+ class PlantbimoePreTrainedModel(PreTrainedModel):
+     """PreTrainedModel wrapper for the Plantbimoe backbone."""
+     config_class = PlantbimoeConfig
+     # base_model_prefix = "plantbimoe"
+     supports_gradient_checkpointing = False
+     _no_split_modules = ["MambaBlock"]
+
+     def _init_weights(
+         self,
+         module,
+         initializer_range=0.02,  # Now only used for embedding layer.
+         **kwargs,
+     ):
+         """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""
+
+         n_layer = self.config.n_layer
+         initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {}
+         rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
+         initializer_range = initialized_cfg.get("initializer_range", initializer_range)
+         n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)
+
+         if isinstance(module, nn.Linear):
+             if module.bias is not None:
+                 if not getattr(module.bias, "_no_reinit", False):
+                     nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             nn.init.normal_(module.weight, std=initializer_range)
+
+         if rescale_prenorm_residual:
+             # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+             # > A modified initialization which accounts for the accumulation on the residual path with model depth.
+             # > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
+             #   residual layers.
+             # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+             #
+             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+             for name, p in module.named_parameters():
+                 if name in ["out_proj.weight", "fc2.weight"]:
+                     # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                     # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                     # We need to reinit p since this code could be called multiple times
+                     # Having just p *= scale would repeatedly scale it down
+                     nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                     with torch.no_grad():
+                         p /= math.sqrt(n_residuals_per_layer * n_layer)
+
+
+ class Plantbimoe(PlantbimoePreTrainedModel):
+     """Plantbimoe model that can be instantiated using HF patterns."""
+     def __init__(self, config: PlantbimoeConfig, device=None, dtype=None, **kwargs):
+         super().__init__(config)
+
+         # Adjust vocab size and complement maps if vocab padding is set.
+         if config.vocab_size % config.pad_vocab_size_multiple != 0:
+             config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)
+
+         self.config = config
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.backbone = PlantbimoeMixerModel(config, **factory_kwargs, **kwargs)
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
+         """HF-compatible forward method."""
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         hidden_states, all_hidden_states = self.backbone(
+             input_ids,
+             inputs_embeds=inputs_embeds,
+             output_hidden_states=output_hidden_states
+         )
+         if return_dict:
+             return BaseModelOutputWithNoAttention(
+                 last_hidden_state=hidden_states,
+                 hidden_states=all_hidden_states if output_hidden_states else None
+             )
+         elif output_hidden_states:
+             return hidden_states, all_hidden_states
+         else:
+             return hidden_states
+
+
+ class PlantbimoeForMaskedLM(PlantbimoePreTrainedModel):
+     """HF-compatible Plantbimoe model for masked language modeling."""
+
+     def __init__(self, config: PlantbimoeConfig, device=None, dtype=None, **kwargs):
+         super().__init__(config, **kwargs)
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.plantbimoe = Plantbimoe(config, **factory_kwargs, **kwargs)
+         self.lm_head = nn.Linear(
+             config.d_model,
+             self.config.vocab_size,  # Use plantbimoe config as it might have been updated
+             bias=False,
+             **factory_kwargs
+         )
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.plantbimoe.backbone.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.plantbimoe.backbone.embeddings.word_embeddings = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         """Overrides output embeddings."""
+         self.lm_head = new_embeddings
+
+     def tie_weights(self):
+         """Tie weights, accounting for RCPS."""
+         super().tie_weights()
+
+     def get_decoder(self):
+         """Get decoder (backbone) for the model."""
+         return self.plantbimoe
+
+     def set_decoder(self, decoder):
+         """Set decoder (backbone) for the model."""
+         self.plantbimoe = decoder
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         loss_weights: Optional[torch.FloatTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple, MaskedLMOutput]:
+         """HF-compatible forward method."""
+
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.plantbimoe(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+         logits = logits.float()
+
+         loss = None
+         if labels is not None:
+             if loss_weights is not None:
+                 loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=self.config.pad_token_id)
+             else:
+                 loss = cross_entropy(logits, labels, ignore_index=self.config.pad_token_id)
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return (loss,) + output if loss is not None else output
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+         )
+
+
+ class PlantbimoeForSequenceClassification(PlantbimoePreTrainedModel):
+     def __init__(
+         self,
+         config: PlantbimoeConfig,
+         pooling_strategy: str = "mean",
+         conjoin_train: bool = False,
+         conjoin_eval: bool = False,
+         device=None,
+         dtype=None,
+         **kwargs):
+         super().__init__(config, **kwargs)
+         if pooling_strategy not in ["mean", "max", "first", "last"]:
+             raise NotImplementedError(f"Pooling strategy `{pooling_strategy}` not implemented.")
+         self.pooling_strategy = pooling_strategy
+         factory_kwargs = {"device": device, "dtype": dtype}
+         self.num_labels = kwargs.get("num_labels", config.num_labels)
+         self.plantbimoe = Plantbimoe(config, **factory_kwargs, **kwargs)
+         self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
+
+         self.conjoin_train = conjoin_train
+         self.conjoin_eval = conjoin_eval
+
+         # Initialize weights and apply final processing
+         self.post_init()
+         self.init_scorer()
+
+     def init_scorer(self, initializer_range=0.02):
+         initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
+             if self.config.initializer_cfg is not None else initializer_range
+         self.score.weight.data.normal_(std=initializer_range)
+
+     def get_input_embeddings(self):
+         return self.plantbimoe.backbone.embeddings.word_embeddings
+
+     def set_input_embeddings(self, value):
+         self.plantbimoe.backbone.embeddings.word_embeddings = value
+
+     def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
+         """Pools hidden states along the sequence length dimension."""
+         if self.pooling_strategy == "mean":  # Mean pooling along sequence length dimension
+             return hidden_states.mean(dim=sequence_length_dim)
+         if self.pooling_strategy == "max":  # Max pooling along sequence length dimension
+             return hidden_states.max(dim=sequence_length_dim).values
+         if self.pooling_strategy == "last":  # Use embedding of last token in the sequence
+             return hidden_states.moveaxis(sequence_length_dim, 0)[-1, ...]
+         if self.pooling_strategy == "first":  # Use embedding of first token in the sequence
+             return hidden_states.moveaxis(sequence_length_dim, 0)[0, ...]
+         if self.pooling_strategy == "all":  # Segmentation pooling
+             return hidden_states[:, 1:-1, :]
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask=None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # Get hidden representations from the backbone
+         if self.conjoin_train or (self.conjoin_eval and not self.training):  # For conjoining / post-hoc conjoining
+             assert input_ids is not None, "`input_ids` must be provided for conjoining."
+             assert input_ids.ndim == 3, "`input_ids` must be a 3D tensor: channels correspond to forward and rc strands."
+             transformer_outputs = self.plantbimoe(
+                 input_ids[..., 0],
+                 inputs_embeds=None,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+             transformer_outputs_rc = self.plantbimoe(
+                 input_ids[..., 1],
+                 inputs_embeds=None,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+             # Stack along channel dimension (dim=-1)
+             hidden_states = torch.stack([transformer_outputs[0], transformer_outputs_rc[0]], dim=-1)
+         else:
+             transformer_outputs = self.plantbimoe(
+                 input_ids,
+                 inputs_embeds=None,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+             hidden_states = transformer_outputs[0]
+
+         # Pool and get logits
+         pooled_hidden_states = self.pool_hidden_states(hidden_states)
+         # Potentially run `score` twice (with parameters shared) for conjoining
+         if hidden_states.ndim == 4:  # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
+             logits_fwd = self.score(pooled_hidden_states[..., 0])
+             logits_rc = self.score(pooled_hidden_states[..., 1])
+             logits = (logits_fwd + logits_rc) / 2
+         else:
+             logits = self.score(pooled_hidden_states)
+
+         loss = None
+         if labels is not None:
+             labels = labels.to(logits.device)
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 if self.num_labels == 1:
+                     loss = F.mse_loss(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = F.mse_loss(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1).squeeze(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss = F.binary_cross_entropy_with_logits(logits, labels.float())
+         if not return_dict:
+             output = (logits,) + transformer_outputs[1:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=transformer_outputs.hidden_states,
+         )
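
The mixer model above places a `PlantbimoeSparseMoeBlock` on every second layer (`moe = (i + 1) % 2 == 0`) and a dense `PlantbimoeMlp` elsewhere; with this checkpoint's config the router selects 1 of 4 experts per token. A standalone sketch of that dispatch pattern with toy dimensions and plain PyTorch linear layers standing in for the experts; this is illustrative only, not the repository's code:

import torch
import torch.nn.functional as F

num_experts, top_k, hidden = 4, 1, 8
tokens = torch.randn(6, hidden)                            # (batch * seq_len, hidden)
router = torch.nn.Linear(hidden, num_experts, bias=False)
experts = [torch.nn.Linear(hidden, hidden) for _ in range(num_experts)]

logits = router(tokens)
weights = F.softmax(logits, dim=1, dtype=torch.float)
weights, selected = torch.topk(weights, top_k, dim=-1)     # top-1 expert per token
expert_mask = F.one_hot(selected, num_classes=num_experts).permute(2, 1, 0)

out = torch.zeros_like(tokens)
for e in range(num_experts):
    idx, top_x = torch.where(expert_mask[e])               # tokens routed to expert e
    if top_x.numel() == 0:
        continue
    # Run only the routed tokens through expert e, scale by the routing weight, and scatter back
    out.index_add_(0, top_x, experts[e](tokens[top_x]) * weights[top_x, idx, None])
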
plantbimoe/tokenization_plantbimoe.py ADDED
@@ -0,0 +1,113 @@
+ from transformers import PreTrainedTokenizer
+ from typing import List, Optional, Dict, Tuple
+
+
+ # Copied from HyenaDNATokenizer
+ class PlantbimoeTokenizer(PreTrainedTokenizer):
+     model_input_names = ["input_ids"]
+
+     def __init__(self,
+                  model_max_length: int,
+                  bos_token="[BOS]",
+                  eos_token="[SEP]",
+                  sep_token="[SEP]",
+                  cls_token="[CLS]",
+                  pad_token="[PAD]",
+                  mask_token="[MASK]",
+                  unk_token="[UNK]",
+                  **kwargs):
+         """Character tokenizer for Hugging Face transformers.
+
+         Args:
+             characters (Sequence[str]): List of desired characters. Any character which
+                 is not included in this list will be replaced by a special token called
+                 [UNK] with id=6. Following is the list of all special tokens with
+                 their corresponding ids:
+                     "[CLS]": 0
+                     "[SEP]": 1
+                     "[BOS]": 2
+                     "[MASK]": 3
+                     "[PAD]": 4
+                     "[RESERVED]": 5
+                     "[UNK]": 6
+                 An id (starting at 7) will be assigned to each character.
+             model_max_length (int): Model maximum sequence length.
+         """
+         self.characters = ('A', 'C', 'G', 'T', 'N')
+         self.model_max_length = model_max_length
+
+         self._vocab_str_to_int = {
+             "[CLS]": 0,
+             "[SEP]": 1,
+             "[BOS]": 2,
+             "[MASK]": 3,
+             "[PAD]": 4,
+             "[RESERVED]": 5,
+             "[UNK]": 6,
+             **{ch: i + 7 for i, ch in enumerate(self.characters)},
+         }
+         self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
+         add_prefix_space = kwargs.pop("add_prefix_space", False)
+         padding_side = kwargs.pop("padding_side", "left")
+
+         super().__init__(
+             bos_token=bos_token,
+             eos_token=eos_token,
+             sep_token=sep_token,
+             cls_token=cls_token,
+             pad_token=pad_token,
+             mask_token=mask_token,
+             unk_token=unk_token,
+             add_prefix_space=add_prefix_space,
+             model_max_length=model_max_length,
+             padding_side=padding_side,
+             **kwargs,
+         )
+
+     @property
+     def vocab_size(self) -> int:
+         return len(self._vocab_str_to_int)
+
+     def _tokenize(self, text: str) -> List[str]:
+         return list(text)
+
+     def _convert_token_to_id(self, token: str) -> int:
+         return self._vocab_str_to_int.get(token, self._vocab_str_to_int["[UNK]"])
+
+     def _convert_id_to_token(self, index: int) -> str:
+         return self._vocab_int_to_str[index]
+
+     def convert_tokens_to_string(self, tokens):
+         return "".join(tokens)
+
+     def get_special_tokens_mask(
+         self,
+         token_ids_0: List[int],
+         token_ids_1: Optional[List[int]] = None,
+         already_has_special_tokens: bool = False,
+     ) -> List[int]:
+         if already_has_special_tokens:
+             return super().get_special_tokens_mask(
+                 token_ids_0=token_ids_0,
+                 token_ids_1=token_ids_1,
+                 already_has_special_tokens=True,
+             )
+
+         result = ([0] * len(token_ids_0)) + [1]
+         if token_ids_1 is not None:
+             result += ([0] * len(token_ids_1)) + [1]
+         return result
+
+     def build_inputs_with_special_tokens(
+         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+     ) -> List[int]:
+         cls = [self.cls_token_id]
+         sep = [self.sep_token_id]
+         result = cls + token_ids_0 + sep
+         if token_ids_1 is not None:
+             result += token_ids_1 + sep
+         return result
+
+     def get_vocab(self) -> Dict[str, int]:
+         return self._vocab_str_to_int
+
+     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple:
+         return ()
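
This commit adds only the tokenizer class (no tokenizer_config.json wiring it into AutoTokenizer), so a direct-instantiation sketch is shown below; it assumes the file is importable from the working directory, and the expected ids follow the vocabulary defined above:

from tokenization_plantbimoe import PlantbimoeTokenizer

tokenizer = PlantbimoeTokenizer(model_max_length=512)
enc = tokenizer("ACGTN")
print(enc["input_ids"])  # expected [0, 7, 8, 9, 10, 11, 1] -> [CLS] A C G T N [SEP]
print(tokenizer.convert_tokens_to_string(tokenizer.tokenize("ACGTN")))  # "ACGTN"
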
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:945126f2ae0fdd3cd04f9708fa289c1273515fc40afa6ff624d368177fb59a95
+ size 462598584