emarro committed on
Commit
3b6966d
·
verified ·
1 Parent(s): d443eaa

Upload CaduceusForMaskedLM

Browse files
config.json ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_attn_implementation_autoset": true,
3
+ "_name_or_path": "hf-compo-cad2-l24-ms-v-chtk-c12k-1t-v2-b2-lr4e4-pHntE6-ep1-ba320185/",
4
+ "architectures": [
5
+ "CaduceusForMaskedLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_caduceus.CaduceusConfig",
9
+ "AutoModel": "modeling_caduceus.Caduceus",
10
+ "AutoModelForMaskedLM": "modeling_caduceus.CaduceusForMaskedLM",
11
+ "AutoModelForSequenceClassification": "modeling_caduceus.CaduceusForSequenceClassification"
12
+ },
13
+ "bidirectional": true,
14
+ "bidirectional_strategy": "add",
15
+ "bidirectional_weight_tie": true,
16
+ "complement_map": {
17
+ "0": 0,
18
+ "1": 1,
19
+ "2": 2,
20
+ "3": 6,
21
+ "4": 5,
22
+ "5": 4,
23
+ "6": 3,
24
+ "7": 7
25
+ },
26
+ "d_intermediate": 0,
27
+ "d_model": 768,
28
+ "fused_add_norm": true,
29
+ "initializer_cfg": {
30
+ "initializer_range": 0.02,
31
+ "n_residuals_per_layer": 1,
32
+ "rescale_prenorm_residual": true
33
+ },
34
+ "model_type": "caduceus",
35
+ "n_layer": 24,
36
+ "norm_epsilon": 1e-05,
37
+ "pad_token_id": -100,
38
+ "pad_vocab_size_multiple": 8,
39
+ "rcps": true,
40
+ "residual_in_fp32": false,
41
+ "rms_norm": true,
42
+ "ssm_cfg": {
43
+ "bias": false,
44
+ "conv_bias": true,
45
+ "d_conv": 4,
46
+ "d_state": 64,
47
+ "dt_init_floor": 0.0001,
48
+ "dt_max": 0.1,
49
+ "dt_min": 0.001,
50
+ "expand": 2
51
+ },
52
+ "torch_dtype": "float32",
53
+ "transformers_version": "4.36.1",
54
+ "vocab_size": 8
55
+ }
configuration_caduceus.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Caduceus config for Hugging Face.
2
+
3
+ """
4
+
5
+ from typing import Optional, Union
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+
10
class CaduceusConfig(PretrainedConfig):
    """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance.

    Defaults mirror the original MambaConfig; Caduceus-specific fields control bi-directional
    processing and reverse-complement parameter sharing (RCPS).
    """
    model_type = "caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,

        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,

        # Used in init_weights
        initializer_cfg: Optional[dict] = None,

        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Union[str, None] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        # Fix: the modeling code reads `config.d_intermediate` when building blocks, but the
        # config never declared it — a config constructed without that kwarg raised
        # AttributeError. Declared explicitly with the shipped default (0 = no per-block MLP).
        d_intermediate: int = 0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map
        self.d_intermediate = d_intermediate
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62402395ca103398c1e78ee8749934e436acba51b28c027cf23512dc71933f11
3
+ size 353000744
modeling_caduceus.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Caduceus model for Hugging Face.
2
+
3
+ """
4
+ import copy
5
+ import math
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ # from mamba_ssm.modules.mamba_simple import Mamba
11
+ # #from mamba_ssm.modules.mamba2_simple import Mamba2Simple as Mamba2
12
+ # from mamba_ssm import Mamba2
13
+ # from mamba_ssm.modules.block import Block
14
+
15
+
16
+ from mamba_ssm.models.config_mamba import MambaConfig
17
+ from mamba_ssm.modules.mamba_simple import Mamba
18
+ from mamba_ssm.modules.mamba2 import Mamba2
19
+ from mamba_ssm.modules.mha import MHA
20
+ from mamba_ssm.modules.mlp import GatedMLP
21
+ from mamba_ssm.modules.block import Block
22
+ from mamba_ssm.utils.generation import GenerationMixin
23
+ from torch import nn
24
+ from torch.nn import functional as F
25
+ from transformers import PreTrainedModel
26
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
27
+
28
+ #try:
29
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
30
+ #except ImportError:
31
+ # RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
32
+
33
+ from .configuration_caduceus import CaduceusConfig
34
+ from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
35
+
36
+
37
+ # def create_block(
38
+ # d_model,
39
+ # d_intermediate,
40
+ # ssm_cfg=None,
41
+ # norm_epsilon=1e-5,
42
+ # rms_norm=False,
43
+ # residual_in_fp32=False,
44
+ # fused_add_norm=False,
45
+ # layer_idx=None,
46
+ # bidirectional=True,
47
+ # bidirectional_strategy="add",
48
+ # bidirectional_weight_tie=True,
49
+ # rcps=False,
50
+ # device=None,
51
+ # dtype=None,
52
+ # ):
53
+ # """Create Caduceus block.
54
+
55
+ # Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
56
+ # """
57
+ # if ssm_cfg is None:
58
+ # ssm_cfg = {}
59
+ # factory_kwargs = {"device": device, "dtype": dtype}
60
+ # bidirectional_kwargs = {
61
+ # "bidirectional": bidirectional,
62
+ # "bidirectional_strategy": bidirectional_strategy,
63
+ # "bidirectional_weight_tie": bidirectional_weight_tie,
64
+ # }
65
+ # mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
66
+ # norm_cls = partial(
67
+ # nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
68
+ # )
69
+ # block_cls = RCPSMambaBlock if rcps else Block
70
+ # if d_intermediate == 0:
71
+ # mlp_cls = nn.Identity
72
+ # else:
73
+ # mlp_cls = partial(
74
+ # GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
75
+ # )
76
+ # block = block_cls(
77
+ # d_model,
78
+ # mixer_cls,
79
+ # mlp_cls,
80
+ # norm_cls=norm_cls,
81
+ # fused_add_norm=fused_add_norm,
82
+ # residual_in_fp32=residual_in_fp32,
83
+ # )
84
+ # block.layer_idx = layer_idx
85
+ # return block
86
+
87
def create_block(
    d_model,
    d_intermediate,
    ssm_cfg=None,
    attn_layer_idx=None,
    attn_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
    bidirectional=True,
    bidirectional_strategy="add",
    bidirectional_weight_tie=True,
    rcps=False,
):
    """Create a single Caduceus block: mixer + optional gated MLP + pre-norm.

    Adapted from mamba_ssm's `mixer_seq_simple.create_block`. Layers whose index appears in
    `attn_layer_idx` get a multi-head-attention mixer; all other layers get a bi-directional
    Mamba mixer (`BiMambaWrapper`). `d_intermediate == 0` disables the per-block MLP.
    Returns the constructed block with `block.layer_idx` set.
    """
    if ssm_cfg is None:
        ssm_cfg = {}
    if attn_layer_idx is None:
        attn_layer_idx = []
    if attn_cfg is None:
        attn_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    bidirectional_kwargs = {
        "bidirectional": bidirectional,
        "bidirectional_strategy": bidirectional_strategy,
        "bidirectional_weight_tie": bidirectional_weight_tie,
    }
    if layer_idx not in attn_layer_idx:
        # Create a copy of the config to modify (pop below must not mutate the shared dict)
        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
        ssm_layer = ssm_cfg.pop("ssm_layer", "Mamba1")
        if ssm_layer not in ["Mamba1", "Mamba2"]:
            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
        # NOTE(review): `ssm_layer` is validated and popped here but never forwarded to
        # BiMambaWrapper, whose own `ssm_layer` default ("Mamba2") therefore always wins.
        # The accepted names also differ ("Mamba1"/"Mamba2" here vs "Mamba"/"Mamba2" in
        # BiMambaWrapper) — confirm intended behavior before wiring it through.
        mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
    else:
        # TODO: add bidirectional support for attention layers
        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    if d_intermediate == 0:
        mlp_cls = nn.Identity  # no per-block MLP
    else:
        mlp_cls = partial(
            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
        )
    # RCPSMambaBlock wraps the block so parameters are shared between strands.
    block_cls = RCPSMambaBlock if rcps else Block

    block = block_cls(
        d_model,
        mixer_cls,
        mlp_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block
155
+
156
+
157
class BiMambaWrapper(nn.Module):
    """Thin wrapper around Mamba to support bi-directionality.

    Runs a forward-direction Mamba/Mamba2 block and, when `bidirectional` is set, a second
    block over the sequence-reversed input, combining the two outputs according to
    `bidirectional_strategy` ("add" or "ew_multiply").

    Args:
        d_model: Channel dimension passed to the underlying block(s).
        bidirectional: Whether to also process the reversed sequence.
        bidirectional_strategy: How to combine forward/reverse outputs; None maps to "add".
        bidirectional_weight_tie: Tie in/out projection weights between the two directions.
        ssm_layer: "Mamba" or "Mamba2", selecting the mixer implementation.
        **mamba_kwargs: Forwarded verbatim to the block constructor.
    """

    def __init__(
        self,
        d_model: int,
        bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
        bidirectional_weight_tie: bool = True,
        ssm_layer="Mamba2",
        **mamba_kwargs,
    ):
        super().__init__()
        # Fix: the original assert / else-branch messages referenced an undefined name
        # (`block_name`), so an invalid `ssm_layer` raised NameError instead of a useful error.
        if ssm_layer not in ("Mamba", "Mamba2"):
            raise ValueError(f"Invalid ssm_layer: {ssm_layer!r}, only support Mamba and Mamba2")
        if bidirectional and bidirectional_strategy is None:
            bidirectional_strategy = "add"  # Default strategy: `add`
        if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
            raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        # Select the block class once instead of duplicating the constructor call per branch.
        block_cls = Mamba if ssm_layer == "Mamba" else Mamba2
        self.mamba_fwd = block_cls(
            d_model=d_model,
            **mamba_kwargs
        )
        if bidirectional:
            self.mamba_rev = block_cls(
                d_model=d_model,
                **mamba_kwargs
            )
            if bidirectional_weight_tie:  # Tie in and out projections (where most of param count lies)
                self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
                self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
                self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
                self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
        else:
            self.mamba_rev = None

    def forward(self, hidden_states, inference_params=None):
        """Bidirectional-enabled forward pass.

        Args:
            hidden_states: (B, L, D) input activations.
            inference_params: Optional cache object passed through to the block(s).

        Returns:
            Tensor of the same shape as `hidden_states`.
        """
        out = self.mamba_fwd(hidden_states.contiguous(), inference_params=inference_params)
        if self.bidirectional:
            out_rev = self.mamba_rev(
                hidden_states.flip(dims=(1,)),  # Flip along the sequence length dimension
                inference_params=inference_params
            ).flip(dims=(1,))  # Flip back for combining with forward hidden states
            if self.bidirectional_strategy == "add":
                out = out + out_rev
            elif self.bidirectional_strategy == "ew_multiply":
                out = out * out_rev
            else:
                raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")
        return out.contiguous()
223
+
224
+
225
class CaduceusEmbeddings(nn.Module):
    """Token embedding layer that dispatches to an RC-equivariant embedding when `config.rcps` is set."""

    def __init__(self, config: CaduceusConfig, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        if config.rcps:
            embedding = RCPSEmbedding(
                config.vocab_size, config.d_model, config.complement_map, **factory_kwargs
            )
        else:
            embedding = nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs)
        self.word_embeddings = embedding

    def forward(self, input_ids):
        """Embed `input_ids` of shape (batch, seqlen)."""
        return self.word_embeddings(input_ids)
246
+
247
+
248
class CaduceusMixerModel(nn.Module):
    """Backbone: embeddings followed by a stack of (optionally RC-equivariant) bi-directional Mamba blocks.

    Produces the final (post-norm) hidden states and, optionally, the hidden states
    collected before each layer and after the final norm.
    """
    def __init__(
        self,
        config: CaduceusConfig,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.fused_add_norm = config.fused_add_norm
        self.rcps = config.rcps
        self.residual_in_fp32 = config.residual_in_fp32

        self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)

        # Mamba changes the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        if config.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    config.d_model,
                    d_intermediate=config.d_intermediate,
                    ssm_cfg=config.ssm_cfg,
                    norm_epsilon=config.norm_epsilon,
                    rms_norm=config.rms_norm,
                    residual_in_fp32=config.residual_in_fp32,
                    fused_add_norm=config.fused_add_norm,
                    layer_idx=i,
                    bidirectional=config.bidirectional,
                    bidirectional_strategy=config.bidirectional_strategy,
                    bidirectional_weight_tie=config.bidirectional_weight_tie,
                    rcps=config.rcps,
                    **factory_kwargs,
                )
                for i in range(config.n_layer)
            ]
        )

        norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
            config.d_model, eps=config.norm_epsilon, **factory_kwargs
        )
        # With fused add+norm (or without RCPS) the plain norm module is applied directly in
        # forward(); otherwise wrap it so the final norm is applied RC-equivariantly.
        self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f)

    def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
        """Mixer forward.

        Args:
            input_ids: (batch, seqlen) token ids; ignored when `inputs_embeds` is given.
            inputs_embeds: optional precomputed embeddings.
            output_hidden_states: if True, collect hidden states before each layer and
                after the final norm.

        Returns:
            Tuple `(hidden_states, all_hidden_states)`; the list is empty unless
            `output_hidden_states` is True.
        """
        all_hidden_states = []
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embeddings(input_ids)

        residual = None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            # TODO: Add support for gradient checkpointing
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=None
            )

        if not self.fused_add_norm:
            if self.rcps:
                # Set prenorm=False here since we don't need the residual
                hidden_states = self.norm_f(hidden_states, residual=residual, prenorm=False)
            else:
                residual = (hidden_states + residual) if residual is not None else hidden_states
                hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            if self.rcps:
                # For RCPS, the channel dim holds [forward | reverse-complement] halves.
                # Normalize each half separately, flipping the RC half along (seq, channel)
                # so both passes see the same layout, then flip back and re-concatenate.
                # Set prenorm=False here since we don't need the residual
                hidden_states_fwd = fused_add_norm_fn(
                    hidden_states[..., :hidden_states.shape[-1] // 2],
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual[..., :hidden_states.shape[-1] // 2],
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
                hidden_states_rc = fused_add_norm_fn(
                    hidden_states[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
                hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = fused_add_norm_fn(
                    hidden_states,
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual,
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
        if output_hidden_states:
            all_hidden_states.append(hidden_states)
        return hidden_states, all_hidden_states
360
+
361
+
362
def cross_entropy(logits, y, ignore_index=-100):
    """Cross entropy loss over all token positions, flattened to (batch * seqlen).

    Args:
        logits: (..., vocab) unnormalized scores.
        y: integer targets matching the leading dims of `logits`.
        ignore_index: target value excluded from the loss.
    """
    vocab = logits.shape[-1]
    flat_logits = logits.view(-1, vocab)
    flat_targets = y.view(-1)
    return F.cross_entropy(flat_logits, flat_targets, ignore_index=ignore_index)
367
+
368
+
369
def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
    """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome).

    Args:
        logits: (..., vocab) unnormalized scores.
        y: integer targets matching the leading dims of `logits`.
        loss_weights: per-token weights, same number of elements as `y`.
        ignore_index: target value excluded from the loss (its weight is zeroed).
    """
    logits = logits.view(-1, logits.shape[-1])
    y = y.view(-1)
    ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
    # Fix: the original assigned through the flattened view (`loss_weights[y == ignore_index] = 0.0`),
    # silently mutating the caller's tensor. `masked_fill` is out-of-place.
    loss_weights = loss_weights.view(-1).masked_fill(y == ignore_index, 0.0)
    # TODO: Follows GPN implementation, but should we remove weight normalization?
    return (ce * (loss_weights / loss_weights.sum())).sum()
378
+
379
+
380
class CaduceusPreTrainedModel(PreTrainedModel):
    """PreTrainedModel wrapper for Caduceus backbone."""
    config_class = CaduceusConfig
    base_model_prefix = "caduceus"
    supports_gradient_checkpointing = False
    _no_split_modules = ["BiMambaWrapper"]

    def _init_weights(
        self,
        module,
        initializer_range=0.02,  # Now only used for embedding layer.
        **kwargs,
    ):
        """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py

        Values in `config.initializer_cfg` override the defaults below.
        """
        n_layer = self.config.n_layer
        initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {}
        rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
        initializer_range = initialized_cfg.get("initializer_range", initializer_range)
        n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)

        if isinstance(module, nn.Linear):
            # Zero linear biases unless explicitly marked to skip re-init.
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            # > A modified initialization which accounts for the accumulation on the residual path with model depth.
            # > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
            # residual layers.
            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight", "fc2.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(n_residuals_per_layer * n_layer)
425
+
426
+
427
class Caduceus(CaduceusPreTrainedModel):
    """Caduceus model that can be instantiated using HF patterns."""
    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
        super().__init__(config)

        if config.rcps:
            assert config.complement_map is not None, "Complement map must be provided for RCPS."

        # Adjust vocab size and complement maps if vocab padding is set.
        # NOTE(review): this mutates the passed-in config object in place.
        if config.vocab_size % config.pad_vocab_size_multiple != 0:
            config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)
        if config.complement_map is not None and config.vocab_size > len(config.complement_map):
            # Padding tokens are mapped to themselves (their own complement).
            for i in range(len(config.complement_map), config.vocab_size):
                config.complement_map[i] = i

        self.config = config
        factory_kwargs = {"device": device, "dtype": dtype}
        self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
        """HF-compatible forward method.

        Returns a BaseModelOutputWithNoAttention when `return_dict` is truthy; otherwise a
        `(hidden_states, all_hidden_states)` tuple when `output_hidden_states` is set, and a
        bare hidden-states tensor otherwise.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states, all_hidden_states = self.backbone(
            input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states
        )
        if return_dict:
            return BaseModelOutputWithNoAttention(
                last_hidden_state=hidden_states,
                hidden_states=all_hidden_states if output_hidden_states else None
            )
        elif output_hidden_states:
            return hidden_states, all_hidden_states
        else:
            return hidden_states
473
+
474
+
475
class CaduceusForMaskedLM(CaduceusPreTrainedModel):
    """HF-compatible Caduceus model for masked language modeling."""

    def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
        super().__init__(config, **kwargs)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
        # RCPS models need an RC-equivariant LM head; otherwise a plain linear decoder is used.
        if config.rcps:
            self.lm_head = RCPSLMHead(
                complement_map=self.config.complement_map,  # Use caduceus config as it might have been updated
                vocab_size=self.config.vocab_size,  # Use caduceus config as it might have been updated
                true_dim=config.d_model,
                dtype=dtype
            )
        else:
            self.lm_head = nn.Linear(
                config.d_model,
                self.config.vocab_size,  # Use caduceus config as it might have been updated
                bias=False,
                **factory_kwargs
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone token embedding module."""
        return self.caduceus.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone token embedding module (unsupported for RCPS models)."""
        if self.config.rcps:
            raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
        self.caduceus.backbone.embeddings.word_embeddings = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Overrides output embeddings."""
        if self.config.rcps:
            raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Tie weights, accounting for RCPS."""
        if self.config.rcps:
            self.lm_head.set_weight(self.get_input_embeddings().weight)
        else:
            super().tie_weights()

    def get_decoder(self):
        """Get decoder (backbone) for the model."""
        return self.caduceus

    def set_decoder(self, decoder):
        """Set decoder (backbone) for the model."""
        self.caduceus = decoder

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        loss_weights: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """HF-compatible forward method.

        Args:
            input_ids: (batch, seqlen) token ids.
            inputs_embeds: optional precomputed embeddings (bypasses `input_ids`).
            labels: MLM targets; positions equal to `config.pad_token_id` are ignored.
            loss_weights: optional per-token loss weights (see `weighted_cross_entropy`).
            output_hidden_states / return_dict: standard HF output switches.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.caduceus(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # NOTE(review): when return_dict=False and output_hidden_states=False, self.caduceus
        # returns a bare tensor, so `outputs[0]` would index the batch dimension — confirm.
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # config.pad_token_id doubles as the cross-entropy ignore_index here.
            if loss_weights is not None:
                loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=self.config.pad_token_id)
            else:
                loss = cross_entropy(logits, labels, ignore_index=self.config.pad_token_id)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
576
+
577
+
578
class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
    """HF-compatible Caduceus model for sequence-level classification / regression.

    Pools backbone hidden states along the sequence dimension and applies a linear `score`
    head. Supports RCPS models (forward / reverse-complement channel halves) and optional
    conjoining of the two strands at train and/or eval time.
    """

    def __init__(
            self,
            config: CaduceusConfig,
            pooling_strategy: str = "mean",
            conjoin_train: bool = False,
            conjoin_eval: bool = False,
            device=None,
            dtype=None,
            **kwargs):
        super().__init__(config, **kwargs)
        if pooling_strategy not in ["mean", "max", "first", "last"]:
            raise NotImplementedError(f"Pooling strategy `{pooling_strategy}` not implemented.")
        self.pooling_strategy = pooling_strategy
        factory_kwargs = {"device": device, "dtype": dtype}
        self.num_labels = kwargs.get("num_labels", config.num_labels)
        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)

        self.conjoin_train = conjoin_train
        self.conjoin_eval = conjoin_eval

        # Initialize weights and apply final processing
        self.post_init()
        self.init_scorer()

    def init_scorer(self, initializer_range=0.02):
        """(Re-)initialize the classification head; `config.initializer_cfg` overrides the default std."""
        initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
            if self.config.initializer_cfg is not None else initializer_range
        self.score.weight.data.normal_(std=initializer_range)

    def get_input_embeddings(self):
        """Return the backbone token embedding module."""
        return self.caduceus.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone token embedding module (unsupported for RCPS models)."""
        if self.config.rcps:
            raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
        self.caduceus.backbone.embeddings.word_embeddings = value

    def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
        """Pools hidden states along sequence length dimension."""
        if self.pooling_strategy == "mean":  # Mean pooling along sequence length dimension
            return hidden_states.mean(dim=sequence_length_dim)
        if self.pooling_strategy == "max":  # Max pooling along sequence length dimension
            return hidden_states.max(dim=sequence_length_dim).values
        # Fix: the original called `hidden_states.moveaxis(hidden_states, dim, 0)`, passing the
        # tensor itself as the `source` axis argument — a TypeError at runtime for the
        # "first"/"last" strategies. Tensor.moveaxis takes (source, destination) only.
        if self.pooling_strategy == "last":  # Use embedding of last token in the sequence
            return hidden_states.moveaxis(sequence_length_dim, 0)[-1, ...]
        if self.pooling_strategy == "first":  # Use embedding of first token in the sequence
            return hidden_states.moveaxis(sequence_length_dim, 0)[0, ...]
        # Unreachable for instances built through __init__, which validates pooling_strategy;
        # guards against returning None if the attribute is mutated later.
        raise NotImplementedError(f"Pooling strategy `{self.pooling_strategy}` not implemented.")

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get hidden representations from the backbone
        if self.config.rcps:  # Hidden states have 2 * d_model channels for RCPS
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=inputs_embeds,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # Split the [forward | reverse-complement] channel halves and stack them on a new
            # trailing dim, flipping the RC half back into forward orientation.
            hidden_states = torch.stack(
                [
                    transformer_outputs[0][..., :self.config.d_model],
                    torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2])
                ],
                dim=-1
            )
        elif self.conjoin_train or (self.conjoin_eval and not self.training):  # For conjoining / post-hoc conjoining
            assert input_ids is not None, "`input_ids` must be provided for conjoining."
            assert input_ids.ndim == 3, "`input_ids` must be 3D tensor: channels corresponds to forward and rc strands."
            transformer_outputs = self.caduceus(
                input_ids[..., 0],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            transformer_outputs_rc = self.caduceus(
                input_ids[..., 1],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # Stack along channel dimension (dim=-1)
            hidden_states = torch.stack([transformer_outputs[0], transformer_outputs_rc[0]], dim=-1)
        else:
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]

        # Pool and get logits
        pooled_hidden_states = self.pool_hidden_states(hidden_states)
        # Potentially run `score` twice (with parameters shared) for conjoining
        if hidden_states.ndim == 4:  # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
            logits_fwd = self.score(pooled_hidden_states[..., 0])
            logits_rc = self.score(pooled_hidden_states[..., 1])
            logits = (logits_fwd + logits_rc) / 2
        else:
            logits = self.score(pooled_hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once (standard HF convention) if not set in the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)
        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
        )
modeling_rcps.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Reverse-complement equivariant modules.
2
+
3
+ """
4
+ from collections import OrderedDict
5
+ from typing import Optional
6
+
7
+ import torch
8
+ from torch import Tensor
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ # try:
13
+ # from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
14
+ # except ImportError:
15
+ # RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
16
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
17
+
18
+
19
class RCPSEmbedding(nn.Module):
    """Embedding layer that supports reverse-complement equivariance.

    Produces a doubled-width embedding: the first half of the channels embeds the
    sequence as given, the second half embeds its reverse complement (flipped back
    into alignment along both length and channel dimensions).
    """
    def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):
        """
        Args:
            vocab_size: Size of vocabulary.
            d_model: Dimensionality of embedding (actual output dim is 2 * d_model).
            complement_map: Dictionary mapping each token id to its complement.
        """
        super().__init__()
        complement_ids = list(OrderedDict(complement_map).values())
        self.register_buffer(
            "complement_map",
            torch.tensor(complement_ids, dtype=torch.long),
        )
        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

    @property
    def weight(self):
        """Embedding weights."""
        return self.embedding.weight

    def set_weight(self, value):
        """Set embedding weights."""
        self.embedding.weight = value

    def rc(self, x):
        """Reverse-complement a batch of input_ids.

        Flips each sequence along its length dimension, then maps every token id
        to its complement via a batched gather on `complement_map`.
        """
        reversed_ids = torch.flip(x, dims=[-1])
        lookup = self.complement_map.unsqueeze(0).expand(x.shape[0], -1)
        return torch.gather(lookup, dim=1, index=reversed_ids)

    def forward(self, input_ids):
        """Reverse-complement equivariant forward pass.

        This embedding module doubles the output dimensionality to support
        reverse-complement equivariance.

        Args:
            input_ids: Input tensor of shape (batch_size, seq_len)
        Returns:
            Embedding tensor of shape (batch_size, seq_len, d_model * 2)
        """
        embedded_fwd = self.embedding(input_ids)
        # Embed the RC sequence, then flip along length AND channel so it is
        # positionally aligned with the forward stream.
        embedded_rc = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])
        return torch.cat((embedded_fwd, embedded_rc), dim=-1)
66
+
67
+
68
class RCPSWrapper(nn.Module):
    """Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module.

    See ref. "Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory
    Genomics", Zhou et al. (2022), https://proceedings.mlr.press/v165/zhou22a.html for more details.
    """
    def __init__(self, submodule: nn.Module):
        super().__init__()
        self.submodule = submodule

    @staticmethod
    def rc(x):
        """Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions."""
        return torch.flip(x, dims=[-2, -1])

    def forward(self, x, **kwargs):
        """Reverse-complement equivariant forward pass.

        Runs the (shared-weight) submodule once on the forward half of the
        channels and once on the reverse-complemented rc half, then flips the
        rc result back before concatenating.

        Args:
            x: Input tensor of shape (batch_size, seq_len, channels)
        Returns:
            Output tensor of shape (batch_size, seq_len, channels * 2)
        """
        half = x.shape[-1] // 2
        fwd_half = x[..., :half]
        rc_half = x[..., half:]
        fwd_out = self.submodule(fwd_half, **kwargs)
        rc_out = self.submodule(self.rc(rc_half), **kwargs)
        return torch.cat((fwd_out, self.rc(rc_out)), dim=-1)
98
+
99
+
100
class RCPSAddNormWrapper(RCPSWrapper):
    """RC equivariant AddNorm layer.

    Applies the wrapped norm `submodule` independently to the forward half and
    to the reverse-complemented rc half of the channels, optionally after
    adding a residual stream.
    """
    def __init__(self, submodule: nn.Module):
        super().__init__(submodule)

    def forward(self, x, residual=None, prenorm=False):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, channels)
            residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
            prenorm: Whether to return residual.
        """
        half = x.shape[-1] // 2
        norm_dtype = self.submodule.weight.dtype
        if residual is None:
            # First layer: the input itself becomes the residual stream.
            residual = x
            fwd_normed = self.submodule(x[..., :half].to(dtype=norm_dtype))
            rc_normed = self.submodule(self.rc(x[..., half:]).to(dtype=norm_dtype))
            out = torch.cat((fwd_normed, self.rc(rc_normed)), dim=-1)
        else:
            # Add residual per half (rc half is flipped into fwd frame first),
            # then norm each half with the shared weights.
            residual_fwd = x[..., :half] + residual[..., :half]
            residual_rc = self.rc(x[..., half:]) + self.rc(residual[..., half:])
            fwd_normed = self.submodule(residual_fwd.to(dtype=norm_dtype))
            rc_normed = self.submodule(residual_rc.to(dtype=norm_dtype))
            residual = torch.cat((residual_fwd, self.rc(residual_rc)), dim=-1)
            out = torch.cat((fwd_normed, self.rc(rc_normed)), dim=-1)

        return (out, residual) if prenorm else out
129
+
130
+
131
class RCPSMambaBlock(nn.Module):
    """RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

    Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
    """
    def __init__(
        self,
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,  # Keep for consistency with original Mamba Block
        dtype=None,  # Keep for consistency with original Mamba Block
    ):
        """
        Args:
            dim: Channel dimensionality passed to the mixer/norm/mlp factories.
            mixer_cls: Factory for the sequence mixer; its instance is wrapped in RCPSWrapper.
            mlp_cls: Factory for an optional MLP branch; pass nn.Identity to disable it.
            norm_cls: Normalization class (nn.LayerNorm or RMSNorm when fused_add_norm is used).
            fused_add_norm: Use the fused Triton add+norm kernels instead of RCPSAddNormWrapper.
            residual_in_fp32: Keep the residual stream in float32.
            device, dtype: Unused; kept only for signature compatibility with the original Mamba Block.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = RCPSWrapper(mixer_cls(dim))
        norm_f = norm_cls(dim)
        # When fused_add_norm is set, forward() splits the channels and calls the
        # fused kernel per half itself; otherwise the wrapper does the RC handling.
        self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual)).
            inference_params: inference parameters for mixer.
        Returns:
            Tuple of (hidden_states, residual) for the next block.
        """
        if not self.fused_add_norm:
            hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn

            # NOTE(review): the un-flipped slice below is taken from the SECOND channel
            # half, while RCPSAddNormWrapper (non-fused path) treats the FIRST half as
            # the forward stream — confirm this ordering is intentional.
            hidden_states_fwd, residual_fwd = fused_add_norm_fn(
                hidden_states[..., hidden_states.shape[-1] // 2:],
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., hidden_states.shape[-1] // 2:] if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )

            # RC stream: flip length and channel dims into the forward frame, norm
            # with the shared weights, then flip back before concatenating.
            hidden_states_rc, residual_rc = fused_add_norm_fn(
                hidden_states[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]),
                self.norm.weight,
                self.norm.bias,
                residual=residual[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]) if residual is not None else None,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
            hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
            residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)

        if self.mlp is not None:
            # NOTE(review): unlike norm/mixer above, this branch applies norm2 and the
            # MLP to the full (doubled) channel width without RC handling — verify
            # against configs that actually enable the MLP (d_intermediate > 0).
            if not self.fused_add_norm:
                residual = hidden_states + residual
                # BUGFIX: the normed value must go to hidden_states (input of the MLP),
                # not overwrite the residual stream — this now mirrors the fused branch
                # below and the upstream Mamba Block.
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                )
            hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate inference cache for mixer.

        Keep for compatibility with original Mamba Block.

        BUGFIX: `self.mixer` is an RCPSWrapper, which does not expose
        `allocate_inference_cache`; delegate to the wrapped mixer instead
        (previously raised AttributeError).
        """
        return self.mixer.submodule.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
230
+
231
+
232
class RCPSLMHead(nn.Module):
    """LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs."""
    def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):
        """
        `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement
        equivariant, i.e. 0.5 times the actual input dim.
        """
        super().__init__()
        complement_ids = list(OrderedDict(complement_map).values())
        self.register_buffer(
            "complement_map",
            torch.tensor(complement_ids, dtype=torch.long),
        )
        self.true_dim = true_dim
        self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)

    @property
    def weight(self):
        """LM head weights."""
        return self.lm_head.weight

    def set_weight(self, value):
        """Set LM head weights."""
        self.lm_head.weight = value

    def forward(self, x):
        """Project RC-equivariant hidden states to vocabulary logits.

        Args:
            x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.
        """
        n_channels = x.shape[-1]
        assert n_channels == 2 * self.true_dim, "Input must have 2 * true_dim channels."
        half = n_channels // 2
        # Forward stream: ordinary projection with the shared weights.
        fwd_logits = F.linear(x[..., :half], self.weight, bias=self.lm_head.bias)
        # RC stream: flip the channel dim only, then project with the rows of the
        # weight matrix permuted to their complement tokens.
        rc_stream = torch.flip(x[..., half:], dims=[-1])
        complemented_weight = self.weight[self.complement_map, :]
        rc_logits = F.linear(rc_stream, complemented_weight, bias=self.lm_head.bias)
        return fwd_logits + rc_logits