gnikesh committed
Commit 18c9a0d · verified · 1 Parent(s): 312c181

Final trained model

README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
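One way to plug this checkpoint into a downstream pipeline is to reuse its hidden states as sequence embeddings. A minimal sketch, assuming mamba-ssm is installed (a CUDA GPU is typically required), that the placeholder path is replaced with this repository's id, and that the tokenizer of the base checkpoint `kuleshov-group/PlantCaduceus_l20` named in `config.json` is compatible:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "path/to/this-repo"  # placeholder: this model's Hub id or a local clone
tokenizer = AutoTokenizer.from_pretrained("kuleshov-group/PlantCaduceus_l20", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True).eval()

inputs = tokenizer("ATGCATTTCCGGA", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=inputs["input_ids"], output_hidden_states=True)

embeddings = out.hidden_states[-1]  # (batch, seq_len, hidden); feed into a downstream classifier
```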
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
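No official snippet is provided yet; the following is a minimal sketch. It assumes mamba-ssm is installed (a CUDA GPU is typically required), that the placeholder path is replaced with this repository's id, and that the tokenizer of the base checkpoint `kuleshov-group/PlantCaduceus_l20` named in `config.json` is compatible. `trust_remote_code=True` is needed because `config.json` maps the model classes to the bundled `configuration_caduceus.py` / `modeling_caduceus.py`:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo_id = "path/to/this-repo"  # placeholder: this model's Hub id or a local clone
tokenizer = AutoTokenizer.from_pretrained("kuleshov-group/PlantCaduceus_l20", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True).eval()

sequence = "ATGCATTTCCGGA"  # toy DNA sequence
inputs = tokenizer(sequence, return_tensors="pt")
with torch.no_grad():
    logits = model(input_ids=inputs["input_ids"]).logits  # (1, seq_len, vocab_size)
print(logits.argmax(-1))  # per-position token predictions
```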
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json CHANGED
@@ -1,12 +1,12 @@
1
  {
2
- "_name_or_path": "kuleshov-group/PlantCaduceus_l20",
3
  "architectures": [
4
  "CaduceusForMaskedLM"
5
  ],
6
  "auto_map": {
7
- "AutoConfig": "yairschiff/caduceus_base--configuration_caduceus.CaduceusConfig",
8
  "AutoModel": "yairschiff/caduceus_base--modeling_caduceus.Caduceus",
9
- "AutoModelForMaskedLM": "yairschiff/caduceus_base--modeling_caduceus.CaduceusForMaskedLM",
10
  "AutoModelForSequenceClassification": "yairschiff/caduceus_base--modeling_caduceus.CaduceusForSequenceClassification"
11
  },
12
  "bidirectional": true,
 
1
  {
2
+ "_name_or_path": "/home/g/gnikesh/projects/PlantCaduceus/FunDLM-20M_WLCF_M0_512bp",
3
  "architectures": [
4
  "CaduceusForMaskedLM"
5
  ],
6
  "auto_map": {
7
+ "AutoConfig": "configuration_caduceus.CaduceusConfig",
8
  "AutoModel": "yairschiff/caduceus_base--modeling_caduceus.Caduceus",
9
+ "AutoModelForMaskedLM": "modeling_caduceus.CaduceusForMaskedLM",
10
  "AutoModelForSequenceClassification": "yairschiff/caduceus_base--modeling_caduceus.CaduceusForSequenceClassification"
11
  },
12
  "bidirectional": true,
configuration_caduceus.py ADDED
@@ -0,0 +1,55 @@
1
+ """Caduceus config for Hugging Face.
2
+
3
+ """
4
+
5
+ from typing import Optional, Union
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+
10
+ class CaduceusConfig(PretrainedConfig):
11
+ """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""
12
+ model_type = "caduceus"
13
+
14
+ def __init__(
15
+ self,
16
+ # From original MambaConfig
17
+ d_model: int = 2560,
18
+ n_layer: int = 64,
19
+ vocab_size: int = 50277,
20
+ ssm_cfg: Optional[dict] = None,
21
+ rms_norm: bool = True,
22
+ residual_in_fp32: bool = True,
23
+ fused_add_norm: bool = True,
24
+ pad_vocab_size_multiple: int = 8,
25
+
26
+ # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
27
+ norm_epsilon: float = 1e-5,
28
+
29
+ # Used in init_weights
30
+ initializer_cfg: Optional[dict] = None,
31
+
32
+ # Caduceus-specific params
33
+ bidirectional: bool = True,
34
+ bidirectional_strategy: Union[str, None] = "add",
35
+ bidirectional_weight_tie: bool = True,
36
+ rcps: bool = False,
37
+ complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
38
+ **kwargs,
39
+ ):
40
+ super().__init__(**kwargs)
41
+ self.d_model = d_model
42
+ self.n_layer = n_layer
43
+ self.vocab_size = vocab_size
44
+ self.ssm_cfg = ssm_cfg
45
+ self.rms_norm = rms_norm
46
+ self.residual_in_fp32 = residual_in_fp32
47
+ self.fused_add_norm = fused_add_norm
48
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
49
+ self.norm_epsilon = norm_epsilon
50
+ self.initializer_cfg = initializer_cfg
51
+ self.bidirectional = bidirectional
52
+ self.bidirectional_strategy = bidirectional_strategy
53
+ self.bidirectional_weight_tie = bidirectional_weight_tie
54
+ self.rcps = rcps
55
+ self.complement_map = complement_map
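Because `config.json` now maps `AutoConfig` to this file, loading the checkpoint with `trust_remote_code=True` instantiates `CaduceusConfig` from the repository itself. A small sketch of both paths, run from a clone of this repository; the values in the direct construction are illustrative only, not this checkpoint's settings:

```python
from transformers import AutoConfig
from configuration_caduceus import CaduceusConfig

# Via auto_map (reads config.json in the current directory).
config = AutoConfig.from_pretrained(".", trust_remote_code=True)
print(type(config).__name__, config.d_model, config.n_layer, config.bidirectional)

# Direct construction with illustrative toy values.
toy = CaduceusConfig(d_model=256, n_layer=4, vocab_size=16,
                     bidirectional=True, bidirectional_strategy="add", rcps=False)
print(toy.to_json_string())
```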
modeling_caduceus.py ADDED
@@ -0,0 +1,664 @@
1
+ """Caduceus model for Hugging Face.
2
+
3
+ """
4
+
5
+ import inspect
6
+ import math
7
+ from functools import partial
8
+ from typing import Optional, Tuple, Union
9
+
10
+ import torch
11
+ from mamba_ssm.modules.mamba_simple import Mamba
12
+ try:
13
+ from mamba_ssm.modules.mamba_simple import Block # Legacy mambav1 file structure
14
+ except ImportError:
15
+ from mamba_ssm.modules.block import Block # mambav2 file structure
16
+ from torch import nn
17
+ from torch.nn import functional as F
18
+ from transformers import PreTrainedModel
19
+ from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
20
+
21
+ try:
22
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
23
+ except ImportError:
24
+ try:
25
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
26
+ except ImportError:
27
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
28
+
29
+ from .configuration_caduceus import CaduceusConfig
30
+ from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
31
+
32
+
33
+ def create_block(
34
+ d_model,
35
+ ssm_cfg=None,
36
+ norm_epsilon=1e-5,
37
+ rms_norm=False,
38
+ residual_in_fp32=False,
39
+ fused_add_norm=False,
40
+ layer_idx=None,
41
+ bidirectional=True,
42
+ bidirectional_strategy="add",
43
+ bidirectional_weight_tie=True,
44
+ rcps=False,
45
+ device=None,
46
+ dtype=None,
47
+ ):
48
+ """Create Caduceus block.
49
+
50
+ Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
51
+ """
52
+ if ssm_cfg is None:
53
+ ssm_cfg = {}
54
+ factory_kwargs = {"device": device, "dtype": dtype}
55
+ bidirectional_kwargs = {
56
+ "bidirectional": bidirectional,
57
+ "bidirectional_strategy": bidirectional_strategy,
58
+ "bidirectional_weight_tie": bidirectional_weight_tie,
59
+ }
60
+ mixer_cls = partial(BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs)
61
+ norm_cls = partial(
62
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
63
+ )
64
+ block_cls = RCPSMambaBlock if rcps else Block
65
+ # mambav2 compatibility
66
+ if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
67
+ block = block_cls(
68
+ d_model,
69
+ mixer_cls,
70
+ mlp_cls=nn.Identity,
71
+ norm_cls=norm_cls,
72
+ fused_add_norm=fused_add_norm,
73
+ residual_in_fp32=residual_in_fp32,
74
+ )
75
+ else:
76
+ block = block_cls(
77
+ d_model,
78
+ mixer_cls,
79
+ norm_cls=norm_cls,
80
+ fused_add_norm=fused_add_norm,
81
+ residual_in_fp32=residual_in_fp32,
82
+ )
83
+ block.layer_idx = layer_idx
84
+ return block
85
+
86
+
87
+ class BiMambaWrapper(nn.Module):
88
+ """Thin wrapper around Mamba to support bi-directionality."""
89
+
90
+ def __init__(
91
+ self,
92
+ d_model: int,
93
+ bidirectional: bool = True,
94
+ bidirectional_strategy: Optional[str] = "add",
95
+ bidirectional_weight_tie: bool = True,
96
+ **mamba_kwargs,
97
+ ):
98
+ super().__init__()
99
+ if bidirectional and bidirectional_strategy is None:
100
+ bidirectional_strategy = "add" # Default strategy: `add`
101
+ if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
102
+ raise NotImplementedError(f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!")
103
+ self.bidirectional = bidirectional
104
+ self.bidirectional_strategy = bidirectional_strategy
105
+ self.mamba_fwd = Mamba(
106
+ d_model=d_model,
107
+ **mamba_kwargs
108
+ )
109
+ if bidirectional:
110
+ self.mamba_rev = Mamba(
111
+ d_model=d_model,
112
+ **mamba_kwargs
113
+ )
114
+ if bidirectional_weight_tie: # Tie in and out projections (where most of param count lies)
115
+ self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
116
+ self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
117
+ self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
118
+ self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
119
+ else:
120
+ self.mamba_rev = None
121
+
122
+ def forward(self, hidden_states, inference_params=None):
123
+ """Bidirectional-enabled forward pass
124
+
125
+ hidden_states: (B, L, D)
126
+ Returns: same shape as hidden_states
127
+ """
128
+ out = self.mamba_fwd(hidden_states, inference_params=inference_params)
129
+ if self.bidirectional:
130
+ out_rev = self.mamba_rev(
131
+ hidden_states.flip(dims=(1,)), # Flip along the sequence length dimension
132
+ inference_params=inference_params
133
+ ).flip(dims=(1,)) # Flip back for combining with forward hidden states
134
+ if self.bidirectional_strategy == "add":
135
+ out = out + out_rev
136
+ elif self.bidirectional_strategy == "ew_multiply":
137
+ out = out * out_rev
138
+ else:
139
+ raise NotImplementedError(f"`{self.bidirectional_strategy}` for bi-directionality not implemented!")
140
+ return out
141
+
142
+
143
+ class CaduceusEmbeddings(nn.Module):
144
+ def __init__(
145
+ self,
146
+ config: CaduceusConfig,
147
+ device=None,
148
+ dtype=None,
149
+ ):
150
+ super().__init__()
151
+ factory_kwargs = {"device": device, "dtype": dtype}
152
+ if config.rcps:
153
+ self.word_embeddings = RCPSEmbedding(
154
+ config.vocab_size, config.d_model, config.complement_map, **factory_kwargs
155
+ )
156
+ else:
157
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, **factory_kwargs)
158
+
159
+ def forward(self, input_ids):
160
+ """
161
+ input_ids: (batch, seqlen)
162
+ """
163
+ return self.word_embeddings(input_ids)
164
+
165
+
166
+ class CaduceusMixerModel(nn.Module):
167
+ def __init__(
168
+ self,
169
+ config: CaduceusConfig,
170
+ device=None,
171
+ dtype=None,
172
+ ) -> None:
173
+ super().__init__()
174
+ factory_kwargs = {"device": device, "dtype": dtype}
175
+
176
+ self.fused_add_norm = config.fused_add_norm
177
+ self.rcps = config.rcps
178
+ self.residual_in_fp32 = config.residual_in_fp32
179
+
180
+ self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
181
+
182
+ # Mamba changes the order of residual and layer norm:
183
+ # Instead of LN -> Attn / MLP -> Add, we do:
184
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
185
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
186
+ # This is for performance reasons: we can fuse add + layer_norm.
187
+ if config.fused_add_norm:
188
+ if layer_norm_fn is None or rms_norm_fn is None:
189
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
190
+
191
+ self.layers = nn.ModuleList(
192
+ [
193
+ create_block(
194
+ config.d_model,
195
+ ssm_cfg=config.ssm_cfg,
196
+ norm_epsilon=config.norm_epsilon,
197
+ rms_norm=config.rms_norm,
198
+ residual_in_fp32=config.residual_in_fp32,
199
+ fused_add_norm=config.fused_add_norm,
200
+ layer_idx=i,
201
+ bidirectional=config.bidirectional,
202
+ bidirectional_strategy=config.bidirectional_strategy,
203
+ bidirectional_weight_tie=config.bidirectional_weight_tie,
204
+ rcps=config.rcps,
205
+ **factory_kwargs,
206
+ )
207
+ for i in range(config.n_layer)
208
+ ]
209
+ )
210
+
211
+ norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
212
+ config.d_model, eps=config.norm_epsilon, **factory_kwargs
213
+ )
214
+ self.norm_f = norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f)
215
+
216
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
217
+ """Mixer forward."""
218
+ all_hidden_states = []
219
+ if inputs_embeds is not None:
220
+ hidden_states = inputs_embeds
221
+ else:
222
+ hidden_states = self.embeddings(input_ids)
223
+
224
+ residual = None
225
+ for layer in self.layers:
226
+ if output_hidden_states:
227
+ all_hidden_states.append(hidden_states)
228
+ # TODO: Add support for gradient checkpointing
229
+ hidden_states, residual = layer(
230
+ hidden_states, residual, inference_params=None
231
+ )
232
+
233
+ if not self.fused_add_norm:
234
+ if self.rcps:
235
+ # Set prenorm=False here since we don't need the residual
236
+ hidden_states = self.norm_f(hidden_states, residual=residual, prenorm=False)
237
+ else:
238
+ residual = (hidden_states + residual) if residual is not None else hidden_states
239
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
240
+ else:
241
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
242
+ if self.rcps:
243
+ # Set prenorm=False here since we don't need the residual
244
+ hidden_states_fwd = fused_add_norm_fn(
245
+ hidden_states[..., :hidden_states.shape[-1] // 2],
246
+ self.norm_f.weight,
247
+ self.norm_f.bias,
248
+ eps=self.norm_f.eps,
249
+ residual=residual[..., :hidden_states.shape[-1] // 2],
250
+ prenorm=False,
251
+ residual_in_fp32=self.residual_in_fp32,
252
+ )
253
+ hidden_states_rc = fused_add_norm_fn(
254
+ hidden_states[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
255
+ self.norm_f.weight,
256
+ self.norm_f.bias,
257
+ eps=self.norm_f.eps,
258
+ residual=residual[..., hidden_states.shape[-1] // 2:].flip(dims=[-2, -1]),
259
+ prenorm=False,
260
+ residual_in_fp32=self.residual_in_fp32,
261
+ )
262
+ hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
263
+ else:
264
+ # Set prenorm=False here since we don't need the residual
265
+ hidden_states = fused_add_norm_fn(
266
+ hidden_states,
267
+ self.norm_f.weight,
268
+ self.norm_f.bias,
269
+ eps=self.norm_f.eps,
270
+ residual=residual,
271
+ prenorm=False,
272
+ residual_in_fp32=self.residual_in_fp32,
273
+ )
274
+ if output_hidden_states:
275
+ all_hidden_states.append(hidden_states)
276
+ return hidden_states, all_hidden_states
277
+
278
+
279
+ def cross_entropy(logits, y, ignore_index=-100):
280
+ """Cross entropy loss."""
281
+ logits = logits.view(-1, logits.shape[-1])
282
+ y = y.view(-1)
283
+ return F.cross_entropy(logits, y, ignore_index=ignore_index)
284
+
285
+
286
+ def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
287
+ """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
288
+ logits = logits.view(-1, logits.shape[-1])
289
+ y = y.view(-1)
290
+ ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
291
+ loss_weights = loss_weights.view(-1)
292
+ loss_weights[y == ignore_index] = 0.0
293
+ # TODO: Follows GPN implementation, but should we remove weight normalization?
294
+ return (ce * (loss_weights / loss_weights.sum())).sum()
295
+
296
+
297
+ class CaduceusPreTrainedModel(PreTrainedModel):
298
+ """PreTrainedModel wrapper for Caduceus backbone."""
299
+ config_class = CaduceusConfig
300
+ base_model_prefix = "caduceus"
301
+ supports_gradient_checkpointing = False
302
+ _no_split_modules = ["BiMambaWrapper"]
303
+
304
+ def _init_weights(
305
+ self,
306
+ module,
307
+ initializer_range=0.02, # Now only used for embedding layer.
308
+ **kwargs,
309
+ ):
310
+ """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""
311
+
312
+ n_layer = self.config.n_layer
313
+ initialized_cfg = self.config.initializer_cfg if self.config.initializer_cfg is not None else {}
314
+ rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
315
+ initializer_range = initialized_cfg.get("initializer_range", initializer_range)
316
+ n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)
317
+
318
+ if isinstance(module, nn.Linear):
319
+ if module.bias is not None:
320
+ if not getattr(module.bias, "_no_reinit", False):
321
+ nn.init.zeros_(module.bias)
322
+ elif isinstance(module, nn.Embedding):
323
+ nn.init.normal_(module.weight, std=initializer_range)
324
+
325
+ if rescale_prenorm_residual:
326
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
327
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth.
328
+ # > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
329
+ # residual layers.
330
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
331
+ #
332
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
333
+ for name, p in module.named_parameters():
334
+ if name in ["out_proj.weight", "fc2.weight"]:
335
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
336
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
337
+ # We need to reinit p since this code could be called multiple times
338
+ # Having just p *= scale would repeatedly scale it down
339
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
340
+ with torch.no_grad():
341
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
342
+
343
+
344
+ class Caduceus(CaduceusPreTrainedModel):
345
+ """Caduceus model that can be instantiated using HF patterns."""
346
+ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
347
+ super().__init__(config)
348
+
349
+ if config.rcps:
350
+ assert config.complement_map is not None, "Complement map must be provided for RCPS."
351
+
352
+ # Adjust vocab size and complement maps if vocab padding is set.
353
+ if config.vocab_size % config.pad_vocab_size_multiple != 0:
354
+ config.vocab_size += config.pad_vocab_size_multiple - (config.vocab_size % config.pad_vocab_size_multiple)
355
+ if config.complement_map is not None and config.vocab_size > len(config.complement_map):
356
+ for i in range(len(config.complement_map), config.vocab_size):
357
+ config.complement_map[i] = i
358
+
359
+ self.config = config
360
+ factory_kwargs = {"device": device, "dtype": dtype}
361
+ self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
362
+
363
+ def maybe_weight_tie_mamba(self):
364
+ if getattr(self.config, 'bidirectional', False) and getattr(self.config, 'bidirectional_weight_tie', False):
365
+ if getattr(self.config, 'rcps', False):
366
+ for layer in self.backbone.layers:
367
+ layer.mixer.submodule.mamba_rev.in_proj.weight = layer.mixer.submodule.mamba_fwd.in_proj.weight
368
+ layer.mixer.submodule.mamba_rev.in_proj.bias = layer.mixer.submodule.mamba_fwd.in_proj.bias
369
+ layer.mixer.submodule.mamba_rev.out_proj.weight = layer.mixer.submodule.mamba_fwd.out_proj.weight
370
+ layer.mixer.submodule.mamba_rev.out_proj.bias = layer.mixer.submodule.mamba_fwd.out_proj.bias
371
+ else:
372
+ for layer in self.backbone.layers:
373
+ layer.mixer.mamba_rev.in_proj.weight = layer.mixer.mamba_fwd.in_proj.weight
374
+ layer.mixer.mamba_rev.in_proj.bias = layer.mixer.mamba_fwd.in_proj.bias
375
+ layer.mixer.mamba_rev.out_proj.weight = layer.mixer.mamba_fwd.out_proj.weight
376
+ layer.mixer.mamba_rev.out_proj.bias = layer.mixer.mamba_fwd.out_proj.bias
377
+
378
+ def tie_weights(self):
379
+ self.maybe_weight_tie_mamba()
380
+
381
+ def forward(
382
+ self,
383
+ input_ids: torch.LongTensor = None,
384
+ inputs_embeds: Optional[torch.FloatTensor] = None,
385
+ output_hidden_states: Optional[bool] = None,
386
+ return_dict: Optional[bool] = None,
387
+ ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
388
+ """HF-compatible forward method."""
389
+ output_hidden_states = (
390
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
391
+ )
392
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
393
+
394
+ hidden_states, all_hidden_states = self.backbone(
395
+ input_ids,
396
+ inputs_embeds=inputs_embeds,
397
+ output_hidden_states=output_hidden_states
398
+ )
399
+ if return_dict:
400
+ return BaseModelOutputWithNoAttention(
401
+ last_hidden_state=hidden_states,
402
+ hidden_states=all_hidden_states if output_hidden_states else None
403
+ )
404
+ elif output_hidden_states:
405
+ return hidden_states, all_hidden_states
406
+ else:
407
+ return hidden_states
408
+
409
+
410
+ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
411
+ """HF-compatible Caduceus model for masked language modeling."""
412
+
413
+ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
414
+ super().__init__(config, **kwargs)
415
+ factory_kwargs = {"device": device, "dtype": dtype}
416
+ self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
417
+ if config.rcps:
418
+ self.lm_head = RCPSLMHead(
419
+ complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
420
+ vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
421
+ true_dim=config.d_model,
422
+ dtype=dtype
423
+ )
424
+ else:
425
+ self.lm_head = nn.Linear(
426
+ config.d_model,
427
+ self.config.vocab_size, # Use caduceus config as it might have been updated
428
+ bias=False,
429
+ **factory_kwargs
430
+ )
431
+
432
+ # Initialize weights and apply final processing
433
+ self.post_init()
434
+
435
+ def get_input_embeddings(self):
436
+ return self.caduceus.backbone.embeddings.word_embeddings
437
+
438
+ def set_input_embeddings(self, value):
439
+ if self.config.rcps:
440
+ raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
441
+ self.caduceus.backbone.embeddings.word_embeddings = value
442
+
443
+ def get_output_embeddings(self):
444
+ return self.lm_head
445
+
446
+ def set_output_embeddings(self, new_embeddings):
447
+ """Overrides output embeddings."""
448
+ if self.config.rcps:
449
+ raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
450
+ self.lm_head = new_embeddings
451
+
452
+ def maybe_weight_tie_mamba(self):
453
+ self.caduceus.maybe_weight_tie_mamba()
454
+
455
+ def tie_weights(self):
456
+ """Tie weights, accounting for RCPS."""
457
+ self.maybe_weight_tie_mamba()
458
+ if self.config.rcps:
459
+ self.lm_head.set_weight(self.get_input_embeddings().weight)
460
+ else:
461
+ super().tie_weights()
462
+
463
+ def get_decoder(self):
464
+ """Get decoder (backbone) for the model."""
465
+ return self.caduceus
466
+
467
+ def set_decoder(self, decoder):
468
+ """Set decoder (backbone) for the model."""
469
+ self.caduceus = decoder
470
+
471
+ def forward(
472
+ self,
473
+ input_ids: torch.LongTensor = None,
474
+ inputs_embeds: Optional[torch.FloatTensor] = None,
475
+ labels: Optional[torch.LongTensor] = None,
476
+ loss_weights: Optional[torch.FloatTensor] = None,
477
+ output_hidden_states: Optional[bool] = None,
478
+ return_dict: Optional[bool] = None,
479
+ ) -> Union[Tuple, MaskedLMOutput]:
480
+ """HF-compatible forward method."""
481
+
482
+ output_hidden_states = (
483
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
484
+ )
485
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
486
+
487
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
488
+ outputs = self.caduceus(
489
+ input_ids=input_ids,
490
+ inputs_embeds=inputs_embeds,
491
+ output_hidden_states=output_hidden_states,
492
+ return_dict=return_dict,
493
+ )
494
+
495
+ hidden_states = outputs[0]
496
+ logits = self.lm_head(hidden_states)
497
+ logits = logits.float()
498
+
499
+ loss = None
500
+ if labels is not None:
501
+ if loss_weights is not None:
502
+ loss = weighted_cross_entropy(logits, labels, loss_weights, ignore_index=self.config.pad_token_id)
503
+ else:
504
+ loss = cross_entropy(logits, labels, ignore_index=self.config.pad_token_id)
505
+
506
+ if not return_dict:
507
+ output = (logits,) + outputs[1:]
508
+ return (loss,) + output if loss is not None else output
509
+
510
+ return MaskedLMOutput(
511
+ loss=loss,
512
+ logits=logits,
513
+ hidden_states=outputs.hidden_states,
514
+ )
515
+
516
+
517
+ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
518
+ def __init__(
519
+ self,
520
+ config: CaduceusConfig,
521
+ pooling_strategy: str = "mean",
522
+ conjoin_train: bool = False,
523
+ conjoin_eval: bool = False,
524
+ device=None,
525
+ dtype=None,
526
+ **kwargs):
527
+ super().__init__(config, **kwargs)
528
+ if pooling_strategy not in ["mean", "max", "first", "last"]:
529
+ raise NotImplementedError(f"Pooling strategy `{pooling_strategy}` not implemented.")
530
+ self.pooling_strategy = pooling_strategy
531
+ factory_kwargs = {"device": device, "dtype": dtype}
532
+ self.num_labels = kwargs.get("num_labels", config.num_labels)
533
+ self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
534
+ self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
535
+
536
+ self.conjoin_train = conjoin_train
537
+ self.conjoin_eval = conjoin_eval
538
+
539
+ # Initialize weights and apply final processing
540
+ self.post_init()
541
+
542
+ def get_input_embeddings(self):
543
+ return self.caduceus.backbone.embeddings.word_embeddings
544
+
545
+ def set_input_embeddings(self, value):
546
+ if self.config.rcps:
547
+ raise NotImplementedError("Setting input embeddings for RCPS LM is not supported.")
548
+ self.caduceus.backbone.embeddings.word_embeddings = value
549
+
550
+ def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
551
+ """Pools hidden states along sequence length dimension."""
552
+ if self.pooling_strategy == "mean": # Mean pooling along sequence length dimension
553
+ return hidden_states.mean(dim=sequence_length_dim)
554
+ if self.pooling_strategy == "max": # Max pooling along sequence length dimension
555
+ return hidden_states.max(dim=sequence_length_dim).values
556
+ if self.pooling_strategy == "last": # Use embedding of last token in the sequence
557
+ return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[-1, ...]
558
+ if self.pooling_strategy == "first": # Use embedding of first token in the sequence
559
+ return hidden_states.moveaxis(sequence_length_dim, 0)[0, ...]  # Tensor.moveaxis takes (source, destination)
560
+
561
+ def maybe_weight_tie_mamba(self):
562
+ self.caduceus.maybe_weight_tie_mamba()
563
+
564
+ def tie_weights(self):
565
+ self.maybe_weight_tie_mamba()
566
+ super().tie_weights()
567
+
568
+ def forward(
569
+ self,
570
+ input_ids: torch.LongTensor = None,
571
+ inputs_embeds: Optional[torch.FloatTensor] = None,
572
+ labels: Optional[torch.LongTensor] = None,
573
+ output_hidden_states: Optional[bool] = None,
574
+ return_dict: Optional[bool] = None,
575
+ **kwargs,
576
+ ) -> Union[Tuple, SequenceClassifierOutput]:
577
+ r"""
578
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
579
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
580
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
581
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
582
+ """
583
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
584
+
585
+ # Get hidden representations from the backbone
586
+ if self.config.rcps: # Hidden states have 2 * d_model channels for RCPS
587
+ transformer_outputs = self.caduceus(
588
+ input_ids,
589
+ inputs_embeds=inputs_embeds,
590
+ output_hidden_states=output_hidden_states,
591
+ return_dict=return_dict,
592
+ )
593
+ hidden_states = torch.stack(
594
+ [
595
+ transformer_outputs[0][..., :self.config.d_model],
596
+ torch.flip(transformer_outputs[0][..., self.config.d_model:], dims=[1, 2])
597
+ ],
598
+ dim=-1
599
+ )
600
+ elif self.conjoin_train or (self.conjoin_eval and not self.training): # For conjoining / post-hoc conjoining
601
+ assert input_ids is not None, "`input_ids` must be provided for conjoining."
602
+ assert input_ids.ndim == 3, "`input_ids` must be a 3D tensor; the last dimension corresponds to the forward and rc strands."
603
+ transformer_outputs = self.caduceus(
604
+ input_ids[..., 0],
605
+ inputs_embeds=None,
606
+ output_hidden_states=output_hidden_states,
607
+ return_dict=return_dict,
608
+ )
609
+ transformer_outputs_rc = self.caduceus(
610
+ input_ids[..., 1],
611
+ inputs_embeds=None,
612
+ output_hidden_states=output_hidden_states,
613
+ return_dict=return_dict,
614
+ )
615
+ # Stack along channel dimension (dim=-1)
616
+ hidden_states = torch.stack([transformer_outputs[0], transformer_outputs_rc[0]], dim=-1)
617
+ else:
618
+ transformer_outputs = self.caduceus(
619
+ input_ids,
620
+ inputs_embeds=None,
621
+ output_hidden_states=output_hidden_states,
622
+ return_dict=return_dict,
623
+ )
624
+ hidden_states = transformer_outputs[0]
625
+
626
+ # Pool and get logits
627
+ pooled_hidden_states = self.pool_hidden_states(hidden_states)
628
+ # Potentially run `score` twice (with parameters shared) for conjoining
629
+ if hidden_states.ndim == 4: # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
630
+ logits_fwd = self.score(pooled_hidden_states[..., 0])
631
+ logits_rc = self.score(pooled_hidden_states[..., 1])
632
+ logits = (logits_fwd + logits_rc) / 2
633
+ else:
634
+ logits = self.score(pooled_hidden_states)
635
+
636
+ loss = None
637
+ if labels is not None:
638
+ labels = labels.to(logits.device)
639
+ if self.config.problem_type is None:
640
+ if self.num_labels == 1:
641
+ self.config.problem_type = "regression"
642
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
643
+ self.config.problem_type = "single_label_classification"
644
+ else:
645
+ self.config.problem_type = "multi_label_classification"
646
+
647
+ if self.config.problem_type == "regression":
648
+ if self.num_labels == 1:
649
+ loss = F.mse_loss(logits.squeeze(), labels.squeeze())
650
+ else:
651
+ loss = F.mse_loss(logits, labels)
652
+ elif self.config.problem_type == "single_label_classification":
653
+ loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
654
+ elif self.config.problem_type == "multi_label_classification":
655
+ loss = F.binary_cross_entropy_with_logits(logits, labels)
656
+ if not return_dict:
657
+ output = (logits,) + transformer_outputs[1:]
658
+ return ((loss,) + output) if loss is not None else output
659
+
660
+ return SequenceClassifierOutput(
661
+ loss=loss,
662
+ logits=logits,
663
+ hidden_states=transformer_outputs.hidden_states,
664
+ )
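`CaduceusForMaskedLM` accepts an optional per-token `loss_weights` tensor and routes it through `weighted_cross_entropy` above (a GPN-style weighting that discounts, e.g., repetitive positions). A torch-only sketch of that behaviour on toy tensors; the helper below restates the function because importing `modeling_caduceus` itself requires `mamba_ssm`:

```python
import torch
import torch.nn.functional as F

def weighted_ce(logits, y, w, ignore_index=-100):
    """Mirror of weighted_cross_entropy from modeling_caduceus.py (operating on a copy of the weights)."""
    logits, y, w = logits.view(-1, logits.shape[-1]), y.view(-1), w.reshape(-1).clone()
    ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
    w[y == ignore_index] = 0.0
    return (ce * (w / w.sum())).sum()

torch.manual_seed(0)
logits = torch.randn(1, 4, 8)                      # (batch, seq_len, vocab)
labels = torch.tensor([[1, 2, -100, 3]])           # -100 marks an ignored position
weights = torch.tensor([[1.0, 0.25, 1.0, 0.25]])   # down-weight the 2nd and 4th tokens

print(weighted_ce(logits, labels, weights))
# With uniform weights this reduces to the ordinary mean over non-ignored tokens:
print(weighted_ce(logits, labels, torch.ones_like(weights)))
print(F.cross_entropy(logits.view(-1, 8), labels.view(-1), ignore_index=-100))
```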
modeling_rcps.py ADDED
@@ -0,0 +1,246 @@
1
+ """Reverse-complement equivariant modules.
2
+
3
+ """
4
+ from collections import OrderedDict
5
+ from typing import Optional
6
+
7
+ import torch
8
+ from torch import Tensor
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ try:
13
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
14
+ except ImportError:
15
+ try:
16
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
17
+ except ImportError:
18
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
19
+
20
+
21
+ class RCPSEmbedding(nn.Module):
22
+ """Embedding layer that supports reverse-complement equivariance."""
23
+ def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):
24
+ """
25
+ Args:
26
+ vocab_size: Size of vocabulary.
27
+ d_model: Dimensionality of embedding (actual embedding matrix will have 1/2 the output dim).
28
+ complement_map: Dictionary mapping each token id to its complement.
29
+ """
30
+ super().__init__()
31
+ self.register_buffer(
32
+ "complement_map",
33
+ torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
34
+ )
35
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
36
+
37
+ @property
38
+ def weight(self):
39
+ """Embedding weights."""
40
+ return self.embedding.weight
41
+
42
+ def set_weight(self, value):
43
+ """Set embedding weights."""
44
+ self.embedding.weight = value
45
+
46
+ def rc(self, x):
47
+ """Reverse-complement a tensor of input_ids by flipping along length dimension and complementing the ids."""
48
+ return torch.gather(
49
+ self.complement_map.unsqueeze(0).expand(x.shape[0], -1),
50
+ dim=1,
51
+ index=torch.flip(x, dims=[-1])
52
+ )
53
+
54
+ def forward(self, input_ids):
55
+ """Reverse-complement equivariant forward pass.
56
+
57
+ This embedding module doubles the output dimensionality to support reverse-complement equivariance.
58
+
59
+ Args:
60
+ input_ids: Input tensor of shape (batch_size, seq_len)
61
+ Returns:
62
+ Embedding tensor of shape (batch_size, seq_len, d_model * 2)
63
+ """
64
+ fwd_out = self.embedding(input_ids)
65
+ rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])
66
+
67
+ return torch.cat([fwd_out, rc_out], dim=-1)
68
+
69
+
70
+ class RCPSWrapper(nn.Module):
71
+ """Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module.
72
+
73
+ See ref. "Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory
74
+ Genomics", Zhou et al. (2022), https://proceedings.mlr.press/v165/zhou22a.html for more details.
75
+ """
76
+ def __init__(self, submodule: nn.Module):
77
+ super().__init__()
78
+ self.submodule = submodule
79
+
80
+ @staticmethod
81
+ def rc(x):
82
+ """Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions."""
83
+ return torch.flip(x, dims=[-2, -1])
84
+
85
+ def forward(self, x, **kwargs):
86
+ """Reverse-complement equivariant forward pass.
87
+
88
+ Args:
89
+ x: Input tensor of shape (batch_size, seq_len, channels)
90
+ Returns:
91
+ Output tensor of shape (batch_size, seq_len, channels * 2)
92
+ """
93
+ n_channels = x.shape[-1]
94
+ # Run submodule along sequence
95
+ fwd_out = self.submodule(x[..., :n_channels // 2], **kwargs)
96
+ # Run submodule along rc-sequence
97
+ rc_out = self.submodule(self.rc(x[..., n_channels // 2:]), **kwargs)
98
+ # Concatenate along channel dimension (dim=-1)
99
+ return torch.cat([fwd_out, self.rc(rc_out)], dim=-1)
100
+
101
+
102
+ class RCPSAddNormWrapper(RCPSWrapper):
103
+ """RC equivariant AddNorm layer."""
104
+ def __init__(self, submodule: nn.Module):
105
+ super().__init__(submodule)
106
+
107
+ def forward(self, x, residual=None, prenorm=False):
108
+ """
109
+ Args:
110
+ x: Input tensor of shape (batch_size, seq_len, channels)
111
+ residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
112
+ prenorm: Whether to return residual.
113
+ """
114
+ n_channels = x.shape[-1]
115
+ if residual is None:
116
+ residual = x
117
+ x_fwd = self.submodule(x[..., :n_channels // 2].to(dtype=self.submodule.weight.dtype))
118
+ x_rc = self.submodule(self.rc(x[..., n_channels // 2:]).to(dtype=self.submodule.weight.dtype))
119
+ x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
120
+ else:
121
+ residual_fwd = x[..., :n_channels // 2] + residual[..., :n_channels // 2]
122
+ x_fwd = self.submodule(residual_fwd.to(dtype=self.submodule.weight.dtype))
123
+
124
+ residual_rc = self.rc(x[..., n_channels // 2:]) + self.rc(residual[..., n_channels // 2:])
125
+ x_rc = self.submodule(residual_rc.to(dtype=self.submodule.weight.dtype))
126
+
127
+ residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
128
+ x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
129
+
130
+ return x if not prenorm else (x, residual)
131
+
132
+
133
+ class RCPSMambaBlock(nn.Module):
134
+ def __init__(
135
+ self,
136
+ dim,
137
+ mixer_cls,
138
+ norm_cls=nn.LayerNorm,
139
+ fused_add_norm=False,
140
+ residual_in_fp32=False,
141
+ device=None, # Keep for consistency with original Mamba Block
142
+ dtype=None, # Keep for consistency with original Mamba Block
143
+ ):
144
+ """RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.
145
+
146
+ Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
147
+ """
148
+ super().__init__()
149
+ self.residual_in_fp32 = residual_in_fp32
150
+ self.fused_add_norm = fused_add_norm
151
+ self.mixer = RCPSWrapper(mixer_cls(dim))
152
+ norm_f = norm_cls(dim)
153
+ self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
154
+ if self.fused_add_norm:
155
+ assert RMSNorm is not None, "RMSNorm import fails"
156
+ assert isinstance(
157
+ self.norm, (nn.LayerNorm, RMSNorm)
158
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
159
+
160
+ def forward(
161
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
162
+ ):
163
+ r"""Pass the input through the encoder layer.
164
+
165
+ Args:
166
+ hidden_states: the sequence to the encoder layer (required).
167
+ residual: hidden_states = Mixer(LN(residual)).
168
+ inference_params: inference parameters for mixer.
169
+ """
170
+ if not self.fused_add_norm:
171
+ hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
172
+ if self.residual_in_fp32:
173
+ residual = residual.to(torch.float32)
174
+ else:
175
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
176
+
177
+ hidden_states_fwd, residual_fwd = fused_add_norm_fn(
178
+ hidden_states[..., hidden_states.shape[-1] // 2:],
179
+ self.norm.weight,
180
+ self.norm.bias,
181
+ residual=residual[..., hidden_states.shape[-1] // 2:] if residual is not None else None,
182
+ prenorm=True,
183
+ residual_in_fp32=self.residual_in_fp32,
184
+ eps=self.norm.eps,
185
+ )
186
+
187
+ hidden_states_rc, residual_rc = fused_add_norm_fn(
188
+ hidden_states[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]),
189
+ self.norm.weight,
190
+ self.norm.bias,
191
+ residual=residual[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]) if residual is not None else None,
192
+ prenorm=True,
193
+ residual_in_fp32=self.residual_in_fp32,
194
+ eps=self.norm.eps,
195
+ )
196
+ hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
197
+ residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)
198
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
199
+ return hidden_states, residual
200
+
201
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
202
+ """Allocate inference cache for mixer.
203
+
204
+ Keep for compatibility with original Mamba Block.
205
+ """
206
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
207
+
208
+
209
+ class RCPSLMHead(nn.Module):
210
+ """LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs."""
211
+ def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):
212
+ """
213
+ `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement
214
+ equivariant, i.e. 0.5 times the actual input dim.
215
+ """
216
+ super().__init__()
217
+ self.register_buffer(
218
+ "complement_map",
219
+ torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
220
+ )
221
+ self.true_dim = true_dim
222
+ self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)
223
+
224
+ @property
225
+ def weight(self):
226
+ """LM head weights."""
227
+ return self.lm_head.weight
228
+
229
+ def set_weight(self, value):
230
+ """Set LM head weights."""
231
+ self.lm_head.weight = value
232
+
233
+ def forward(self, x):
234
+ """
235
+ Args:
236
+ x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.
237
+ """
238
+ n_channels = x.shape[-1]
239
+ assert n_channels == 2 * self.true_dim, "Input must have 2 * true_dim channels."
240
+ fwd_logits = F.linear(x[..., :n_channels // 2], self.weight, bias=self.lm_head.bias)
241
+ rc_logits = F.linear(
242
+ torch.flip(x[..., n_channels // 2:], dims=[-1]),
243
+ self.weight[self.complement_map, :],
244
+ bias=self.lm_head.bias
245
+ )
246
+ return fwd_logits + rc_logits
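A small sketch of the property these modules encode: wrapping any `(batch, length, d) -> (batch, length, d)` module in `RCPSWrapper` makes it reverse-complement equivariant, i.e. reverse-complementing the input (flipping the length and channel dimensions) reverse-complements the output. `modeling_rcps.py` only needs `torch` when the Triton norms are unavailable, so this check runs on CPU from a clone of this repository:

```python
import torch
from torch import nn
from modeling_rcps import RCPSWrapper  # local file added in this commit

torch.manual_seed(0)
d = 8                                    # per-strand channel width
wrapper = RCPSWrapper(nn.Linear(d, d))   # stand-in for the wrapped Mamba mixer

x = torch.randn(2, 5, 2 * d)             # (batch, length, 2 * d): forward and rc halves
rc = lambda t: torch.flip(t, dims=[-2, -1])

# F(rc(x)) == rc(F(x)) holds by construction, whatever the inner module is.
assert torch.allclose(wrapper(rc(x)), rc(wrapper(x)), atol=1e-6)
print("RCPSWrapper output is reverse-complement equivariant")
```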
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:88157621e26dbcd2fc152c7aefb798cb9b1593f8e904836bfadf06888a3d2b44
3
- size 83610666
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:885e2831e9f268f01fb7f075333629c0b0e973990cf55e868dcbb1155d12e328
3
+ size 83605290