OliBomby commited on
Commit
9bbf03d
·
verified ·
1 Parent(s): b99f9cc

Add CM3P model

Browse files
Files changed (5) hide show
  1. README.md +199 -0
  2. config.json +185 -0
  3. configuration_cm3p.py +323 -0
  4. model.safetensors +3 -0
  5. modeling_cm3p.py +1389 -0
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CM3PModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_cm3p.CM3PConfig",
7
+ "AutoModel": "modeling_cm3p.CM3PModel"
8
+ },
9
+ "beatmap_config": {
10
+ "attention_bias": false,
11
+ "attention_dropout": 0.0,
12
+ "audio_config": {
13
+ "_name_or_path": "",
14
+ "add_cross_attention": false,
15
+ "architectures": null,
16
+ "attention_bias": false,
17
+ "attention_dropout": 0.0,
18
+ "bad_words_ids": null,
19
+ "begin_suppress_tokens": null,
20
+ "bos_token_id": null,
21
+ "chunk_size_feed_forward": 0,
22
+ "cross_attention_hidden_size": null,
23
+ "decoder_bias": true,
24
+ "decoder_start_token_id": null,
25
+ "deterministic_flash_attn": false,
26
+ "diversity_penalty": 0.0,
27
+ "do_sample": false,
28
+ "early_stopping": false,
29
+ "embedding_dropout": 0.0,
30
+ "encoder_no_repeat_ngram_size": 0,
31
+ "eos_token_id": null,
32
+ "exponential_decay_length_penalty": null,
33
+ "f_max": 8000,
34
+ "f_min": 0,
35
+ "finetuning_task": null,
36
+ "forced_bos_token_id": null,
37
+ "forced_eos_token_id": null,
38
+ "global_attn_every_n_layers": 3,
39
+ "global_rope_theta": 160000.0,
40
+ "hidden_activation": "gelu",
41
+ "hidden_size": 512,
42
+ "hop_length": 128,
43
+ "id2label": {
44
+ "0": "LABEL_0",
45
+ "1": "LABEL_1"
46
+ },
47
+ "initializer_cutoff_factor": 2.0,
48
+ "initializer_range": 0.02,
49
+ "intermediate_size": 1024,
50
+ "is_decoder": false,
51
+ "is_encoder_decoder": false,
52
+ "label2id": {
53
+ "LABEL_0": 0,
54
+ "LABEL_1": 1
55
+ },
56
+ "length_penalty": 1.0,
57
+ "local_attention": 128,
58
+ "local_rope_theta": 10000.0,
59
+ "max_length": 20,
60
+ "max_position_embeddings": 4096,
61
+ "min_length": 0,
62
+ "mlp_bias": false,
63
+ "mlp_dropout": 0.0,
64
+ "model_type": "CM3PAudio",
65
+ "n_ftt": 2048,
66
+ "n_mels": 80,
67
+ "no_repeat_ngram_size": 0,
68
+ "norm_bias": false,
69
+ "norm_eps": 1e-05,
70
+ "num_attention_heads": 8,
71
+ "num_beam_groups": 1,
72
+ "num_beams": 1,
73
+ "num_hidden_layers": 6,
74
+ "num_return_sequences": 1,
75
+ "output_attentions": false,
76
+ "output_hidden_states": false,
77
+ "output_scores": false,
78
+ "pad_mode": "constant",
79
+ "pad_token_id": null,
80
+ "prefix": null,
81
+ "problem_type": null,
82
+ "projector_dim": 768,
83
+ "projector_hidden_act": "gelu",
84
+ "projector_intermediate_size": 2048,
85
+ "pruned_heads": {},
86
+ "remove_invalid_values": false,
87
+ "repetition_penalty": 1.0,
88
+ "return_dict": true,
89
+ "return_dict_in_generate": false,
90
+ "sample_rate": 16000,
91
+ "sep_token_id": null,
92
+ "suppress_tokens": null,
93
+ "task_specific_params": null,
94
+ "temperature": 1.0,
95
+ "tf_legacy_loss": false,
96
+ "tie_encoder_decoder": false,
97
+ "tie_word_embeddings": true,
98
+ "tokenizer_class": null,
99
+ "top_k": 50,
100
+ "top_p": 1.0,
101
+ "torch_dtype": null,
102
+ "torchscript": false,
103
+ "typical_p": 1.0,
104
+ "use_bfloat16": false,
105
+ "vocab_size": 1
106
+ },
107
+ "audio_eos_token_id": 3966,
108
+ "audio_sos_token_id": null,
109
+ "audio_token_id": 3967,
110
+ "bos_token_id": 3958,
111
+ "classifier_activation": "gelu",
112
+ "classifier_bias": false,
113
+ "cls_embed": true,
114
+ "decoder_bias": true,
115
+ "deterministic_flash_attn": false,
116
+ "embedding_dropout": 0.0,
117
+ "eos_token_id": 3959,
118
+ "global_attn_every_n_layers": 3,
119
+ "global_rope_theta": 160000.0,
120
+ "hidden_activation": "gelu",
121
+ "hidden_size": 768,
122
+ "initializer_cutoff_factor": 2.0,
123
+ "initializer_factor": 1.0,
124
+ "initializer_range": 0.02,
125
+ "intermediate_size": 1152,
126
+ "local_attention": 128,
127
+ "local_rope_theta": 10000.0,
128
+ "max_position_embeddings": 8192,
129
+ "mlp_bias": false,
130
+ "mlp_dropout": 0.0,
131
+ "model_type": "CM3PBeatmap",
132
+ "norm_bias": false,
133
+ "norm_eps": 1e-05,
134
+ "num_attention_heads": 12,
135
+ "num_hidden_layers": 22,
136
+ "pad_token_id": 3962,
137
+ "projection_dim": 512,
138
+ "repad_logits_with_grad": false,
139
+ "sparse_pred_ignore_index": -100,
140
+ "sparse_prediction": false,
141
+ "torch_dtype": "bfloat16",
142
+ "vocab_size": 3968
143
+ },
144
+ "has_decoder_head": true,
145
+ "initializer_factor": 1.0,
146
+ "initializer_range": 0.02,
147
+ "logit_scale_init_value": 2.6592,
148
+ "loss_type": "ForMaskedLM",
149
+ "metadata_config": {
150
+ "attention_bias": false,
151
+ "attention_dropout": 0.0,
152
+ "bos_token_id": 23127,
153
+ "cls_embed": true,
154
+ "decoder_bias": true,
155
+ "deterministic_flash_attn": false,
156
+ "embedding_dropout": 0.0,
157
+ "eos_token_id": 23128,
158
+ "global_attn_every_n_layers": 1,
159
+ "global_rope_theta": 10000.0,
160
+ "hidden_activation": "gelu",
161
+ "hidden_size": 256,
162
+ "initializer_cutoff_factor": 2.0,
163
+ "initializer_factor": 1.0,
164
+ "initializer_range": 0.02,
165
+ "intermediate_size": 512,
166
+ "local_attention": 128,
167
+ "local_rope_theta": 10000.0,
168
+ "max_position_embeddings": 128,
169
+ "mlp_bias": false,
170
+ "mlp_dropout": 0.0,
171
+ "model_type": "CM3PMetadata",
172
+ "norm_bias": false,
173
+ "norm_eps": 1e-05,
174
+ "num_attention_heads": 4,
175
+ "num_hidden_layers": 6,
176
+ "pad_token_id": 23129,
177
+ "projection_dim": 512,
178
+ "torch_dtype": "bfloat16",
179
+ "vocab_size": 23145
180
+ },
181
+ "model_type": "CM3P",
182
+ "projection_dim": 512,
183
+ "torch_dtype": "bfloat16",
184
+ "transformers_version": "4.55.0"
185
+ }
configuration_cm3p.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """CM3P model configuration"""
2
+ from transformers import AutoConfig
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.utils import logging
5
+
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class CM3PMetadataConfig(PretrainedConfig):
11
+ model_type = "CM3PMetadata"
12
+ base_config_key = "metadata_config"
13
+
14
+ def __init__(
15
+ self,
16
+ cls_embed=False,
17
+
18
+ projection_dim=512,
19
+ initializer_factor=1.0,
20
+
21
+ vocab_size=1000,
22
+ hidden_size=256,
23
+ intermediate_size=512,
24
+ num_hidden_layers=6,
25
+ num_attention_heads=4,
26
+ hidden_activation="gelu",
27
+ max_position_embeddings=128,
28
+ initializer_range=0.02,
29
+ initializer_cutoff_factor=2.0,
30
+ norm_eps=1e-5,
31
+ norm_bias=False,
32
+ pad_token_id=0,
33
+ bos_token_id=1,
34
+ eos_token_id=2,
35
+ global_rope_theta=10000.0,
36
+ attention_bias=False,
37
+ attention_dropout=0.0,
38
+ global_attn_every_n_layers=1,
39
+ local_attention=128,
40
+ local_rope_theta=10000.0,
41
+ embedding_dropout=0.0,
42
+ mlp_bias=False,
43
+ mlp_dropout=0.0,
44
+ decoder_bias=True,
45
+ deterministic_flash_attn=False,
46
+ reference_compile=None,
47
+ **kwargs,
48
+ ):
49
+ super().__init__(
50
+ pad_token_id=pad_token_id,
51
+ bos_token_id=bos_token_id,
52
+ eos_token_id=eos_token_id,
53
+ **kwargs,
54
+ )
55
+
56
+ self.cls_embed = cls_embed
57
+
58
+ self.projection_dim = projection_dim
59
+ self.initializer_range = initializer_range
60
+ self.initializer_factor = initializer_factor
61
+ self.attention_dropout = attention_dropout
62
+
63
+ self.vocab_size = vocab_size
64
+ self.max_position_embeddings = max_position_embeddings
65
+ self.hidden_size = hidden_size
66
+ self.intermediate_size = intermediate_size
67
+ self.num_hidden_layers = num_hidden_layers
68
+ self.num_attention_heads = num_attention_heads
69
+ self.initializer_range = initializer_range
70
+ self.initializer_cutoff_factor = initializer_cutoff_factor
71
+ self.norm_eps = norm_eps
72
+ self.norm_bias = norm_bias
73
+ self.global_rope_theta = global_rope_theta
74
+ self.attention_bias = attention_bias
75
+ self.attention_dropout = attention_dropout
76
+ self.hidden_activation = hidden_activation
77
+ self.global_attn_every_n_layers = global_attn_every_n_layers
78
+ self.local_attention = local_attention
79
+ self.local_rope_theta = local_rope_theta
80
+ self.embedding_dropout = embedding_dropout
81
+ self.mlp_bias = mlp_bias
82
+ self.mlp_dropout = mlp_dropout
83
+ self.decoder_bias = decoder_bias
84
+ self.deterministic_flash_attn = deterministic_flash_attn
85
+ self.reference_compile = reference_compile
86
+
87
+ def to_dict(self):
88
+ output = super().to_dict()
89
+ output.pop("reference_compile", None)
90
+ return output
91
+
92
+
93
+ class CM3PAudioConfig(PretrainedConfig):
94
+ model_type = "CM3PAudio"
95
+ base_config_key = "audio_config"
96
+
97
+ def __init__(
98
+ self,
99
+ hidden_size=512,
100
+ intermediate_size=1024,
101
+ num_hidden_layers=6,
102
+ num_attention_heads=8,
103
+ hidden_activation="gelu",
104
+ max_position_embeddings=4096,
105
+ initializer_range=0.02,
106
+ initializer_cutoff_factor=2.0,
107
+ norm_eps=1e-5,
108
+ norm_bias=False,
109
+ global_rope_theta=160000.0,
110
+ attention_bias=False,
111
+ attention_dropout=0.0,
112
+ global_attn_every_n_layers=3,
113
+ local_attention=128,
114
+ local_rope_theta=10000.0,
115
+ embedding_dropout=0.0,
116
+ mlp_bias=False,
117
+ mlp_dropout=0.0,
118
+ decoder_bias=True,
119
+ deterministic_flash_attn=False,
120
+ reference_compile=None,
121
+
122
+ projector_intermediate_size=2048, # 4 * hidden_size for a 4x reduction in tokens
123
+ projector_dim=768,
124
+ projector_hidden_act="gelu",
125
+
126
+ sample_rate: int = 16000,
127
+ n_ftt: int = 2048,
128
+ n_mels: int = 80,
129
+ hop_length: int = 128,
130
+ f_min: int = 0,
131
+ f_max: int = 8000,
132
+ pad_mode: str = "constant",
133
+ **kwargs,
134
+ ):
135
+ super().__init__(**kwargs)
136
+ self.vocab_size = 1
137
+ self.max_position_embeddings = max_position_embeddings
138
+ self.hidden_size = hidden_size
139
+ self.intermediate_size = intermediate_size
140
+ self.num_hidden_layers = num_hidden_layers
141
+ self.num_attention_heads = num_attention_heads
142
+ self.initializer_range = initializer_range
143
+ self.initializer_cutoff_factor = initializer_cutoff_factor
144
+ self.norm_eps = norm_eps
145
+ self.norm_bias = norm_bias
146
+ self.global_rope_theta = global_rope_theta
147
+ self.attention_bias = attention_bias
148
+ self.attention_dropout = attention_dropout
149
+ self.hidden_activation = hidden_activation
150
+ self.global_attn_every_n_layers = global_attn_every_n_layers
151
+ self.local_attention = local_attention
152
+ self.local_rope_theta = local_rope_theta
153
+ self.embedding_dropout = embedding_dropout
154
+ self.mlp_bias = mlp_bias
155
+ self.mlp_dropout = mlp_dropout
156
+ self.decoder_bias = decoder_bias
157
+ self.deterministic_flash_attn = deterministic_flash_attn
158
+ self.reference_compile = reference_compile
159
+
160
+ self.projector_intermediate_size = projector_intermediate_size
161
+ self.projector_dim = projector_dim
162
+ self.projector_hidden_act = projector_hidden_act
163
+
164
+ self.sample_rate = sample_rate
165
+ self.n_ftt = n_ftt
166
+ self.n_mels = n_mels
167
+ self.hop_length = hop_length
168
+ self.f_min = f_min
169
+ self.f_max = f_max
170
+ self.pad_mode = pad_mode
171
+
172
+ def to_dict(self):
173
+ output = super().to_dict()
174
+ output.pop("reference_compile", None)
175
+ return output
176
+
177
+
178
+ class CM3PBeatmapConfig(PretrainedConfig):
179
+ model_type = "CM3PBeatmap"
180
+ base_config_key = "beatmap_config"
181
+ sub_configs = {"audio_config": CM3PAudioConfig}
182
+
183
+ def __init__(
184
+ self,
185
+ audio_config: dict = None,
186
+ audio_sos_token_id=3164,
187
+ audio_eos_token_id=3165,
188
+ audio_token_id=3166,
189
+ cls_embed=False,
190
+
191
+ projection_dim=512,
192
+ initializer_factor=1.0,
193
+
194
+ vocab_size=3167,
195
+ hidden_size=768,
196
+ intermediate_size=1152,
197
+ num_hidden_layers=22,
198
+ num_attention_heads=12,
199
+ hidden_activation="gelu",
200
+ max_position_embeddings=8192,
201
+ initializer_range=0.02,
202
+ initializer_cutoff_factor=2.0,
203
+ norm_eps=1e-5,
204
+ norm_bias=False,
205
+ pad_token_id=0,
206
+ bos_token_id=1,
207
+ eos_token_id=2,
208
+ global_rope_theta=160000.0,
209
+ attention_bias=False,
210
+ attention_dropout=0.0,
211
+ global_attn_every_n_layers=3,
212
+ local_attention=128,
213
+ local_rope_theta=10000.0,
214
+ embedding_dropout=0.0,
215
+ mlp_bias=False,
216
+ mlp_dropout=0.0,
217
+ decoder_bias=True,
218
+ classifier_bias=False,
219
+ classifier_activation="gelu",
220
+ deterministic_flash_attn=False,
221
+ sparse_prediction=False,
222
+ sparse_pred_ignore_index=-100,
223
+ reference_compile=None,
224
+ repad_logits_with_grad=False,
225
+ **kwargs,
226
+ ):
227
+ super().__init__(
228
+ pad_token_id=pad_token_id,
229
+ bos_token_id=bos_token_id,
230
+ eos_token_id=eos_token_id,
231
+ **kwargs,
232
+ )
233
+
234
+ if audio_config is None:
235
+ audio_config = {}
236
+ logger.info("`audio_config` is `None`. Initializing the `CM3PAudioConfig` with default values.")
237
+
238
+ self.audio_config = CM3PAudioConfig(**audio_config)
239
+ self.audio_sos_token_id = audio_sos_token_id
240
+ self.audio_eos_token_id = audio_eos_token_id
241
+ self.audio_token_id = audio_token_id
242
+ self.cls_embed = cls_embed
243
+
244
+ self.projection_dim = projection_dim
245
+ self.initializer_factor = initializer_factor
246
+ self.vocab_size = vocab_size
247
+ self.max_position_embeddings = max_position_embeddings
248
+ self.hidden_size = hidden_size
249
+ self.intermediate_size = intermediate_size
250
+ self.num_hidden_layers = num_hidden_layers
251
+ self.num_attention_heads = num_attention_heads
252
+ self.initializer_range = initializer_range
253
+ self.initializer_cutoff_factor = initializer_cutoff_factor
254
+ self.norm_eps = norm_eps
255
+ self.norm_bias = norm_bias
256
+ self.global_rope_theta = global_rope_theta
257
+ self.attention_bias = attention_bias
258
+ self.attention_dropout = attention_dropout
259
+ self.hidden_activation = hidden_activation
260
+ self.global_attn_every_n_layers = global_attn_every_n_layers
261
+ self.local_attention = local_attention
262
+ self.local_rope_theta = local_rope_theta
263
+ self.embedding_dropout = embedding_dropout
264
+ self.mlp_bias = mlp_bias
265
+ self.mlp_dropout = mlp_dropout
266
+ self.decoder_bias = decoder_bias
267
+ self.classifier_bias = classifier_bias
268
+ self.classifier_activation = classifier_activation
269
+ self.deterministic_flash_attn = deterministic_flash_attn
270
+ self.sparse_prediction = sparse_prediction
271
+ self.sparse_pred_ignore_index = sparse_pred_ignore_index
272
+ self.reference_compile = reference_compile
273
+ self.repad_logits_with_grad = repad_logits_with_grad
274
+
275
+ def to_dict(self):
276
+ output = super().to_dict()
277
+ output.pop("reference_compile", None)
278
+ return output
279
+
280
+
281
+ class CM3PConfig(PretrainedConfig):
282
+ model_type = "CM3P"
283
+ sub_configs = {"metadata_config": CM3PMetadataConfig, "beatmap_config": CM3PBeatmapConfig}
284
+
285
+ def __init__(
286
+ self,
287
+ metadata_config=None,
288
+ beatmap_config=None,
289
+ projection_dim=512,
290
+ logit_scale_init_value=2.6592,
291
+ initializer_factor=1.0,
292
+ initializer_range=0.02,
293
+ loss_type=None,
294
+ has_decoder_head=False,
295
+ **kwargs
296
+ ):
297
+ super().__init__(**kwargs)
298
+
299
+ if metadata_config is None:
300
+ metadata_config = {}
301
+ logger.debug("`metadata_config` is `None`. Initializing the `CM3PMetadataConfig` with default values.")
302
+
303
+ if beatmap_config is None:
304
+ beatmap_config = {}
305
+ logger.debug("`beatmap_config` is `None`. initializing the `CM3PBeatmapConfig` with default values.")
306
+
307
+ self.metadata_config = CM3PMetadataConfig(**metadata_config)
308
+ self.beatmap_config = CM3PBeatmapConfig(**beatmap_config)
309
+
310
+ self.projection_dim = projection_dim
311
+ self.logit_scale_init_value = logit_scale_init_value
312
+ self.initializer_factor = initializer_factor
313
+ self.initializer_range = initializer_range
314
+ self.loss_type = loss_type
315
+ self.has_decoder_head = has_decoder_head
316
+
317
+
318
+ AutoConfig.register("CM3PMetadata", CM3PMetadataConfig)
319
+ AutoConfig.register("CM3PAudio", CM3PAudioConfig)
320
+ AutoConfig.register("CM3PBeatmap", CM3PBeatmapConfig)
321
+ AutoConfig.register("CM3P", CM3PConfig)
322
+
323
+ __all__ = ["CM3PConfig", "CM3PMetadataConfig", "CM3PAudioConfig", "CM3PBeatmapConfig"]
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14f34fe88d4a181864cf4ed638093857e8be5ab6033a69f11dd9e56f6fa44cc9
3
+ size 292456522
modeling_cm3p.py ADDED
@@ -0,0 +1,1389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch CM3P model."""
2
+ from contextlib import nullcontext
3
+ from dataclasses import dataclass
4
+ from typing import Any, Optional, Union
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
+ from transformers import ModernBertModel, AutoModel, AutoModelForSequenceClassification, AutoModelForMaskedLM
11
+ from transformers.activations import ACT2FN
12
+ from transformers.modeling_outputs import (
13
+ BaseModelOutput,
14
+ BaseModelOutputWithPooling, MaskedLMOutput,
15
+ )
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import ModelOutput, auto_docstring, can_return_tuple, logging
18
+
19
+ from .configuration_cm3p import CM3PConfig, CM3PMetadataConfig, CM3PBeatmapConfig, CM3PAudioConfig
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ # contrastive loss function, adapted from
26
+ # https://sachinruk.github.io/blog/2021-03-07-clip.html
27
+ def contrastive_loss(logits: torch.Tensor, target: torch.Tensor = None) -> torch.Tensor:
28
+ target = target if target is not None else torch.arange(len(logits), device=logits.device)
29
+ return nn.functional.cross_entropy(logits, target)
30
+
31
+
32
+ # CM3P loss function, adapted from CLIP
33
+ def cm3p_loss(similarity: torch.Tensor, metadata_variation_classes: torch.LongTensor = None) -> torch.Tensor:
34
+ if similarity.dim() == 3: # (metadata_batch_size, variations, beatmap_batch_size)
35
+ metadata_batch_size = similarity.size(0)
36
+ num_variations = similarity.size(1)
37
+ beatmap_batch_size = similarity.size(2)
38
+ assert metadata_batch_size == beatmap_batch_size
39
+
40
+ true_metadata_indices = (metadata_variation_classes == 0).int().argmax(dim=1)
41
+ metadata_loss = contrastive_loss(similarity[torch.arange(metadata_batch_size), true_metadata_indices]) # only use original metadata for loss
42
+
43
+ beatmap_similarity = similarity.permute(2, 0, 1) # (beatmap_batch_size, metadata_batch_size, variations)
44
+ beatmap_similarity = beatmap_similarity.reshape(beatmap_batch_size, -1) # (beatmap_batch_size, metadata_batch_size * variations)
45
+ target = torch.arange(0, beatmap_similarity.size(1), num_variations, device=similarity.device) # (metadata_batch_size,)
46
+ target += true_metadata_indices
47
+ beatmap_loss = contrastive_loss(beatmap_similarity, target=target)
48
+ else:
49
+ metadata_loss = contrastive_loss(similarity)
50
+ beatmap_loss = contrastive_loss(similarity.t())
51
+ return (metadata_loss + beatmap_loss) / 2.0
52
+
53
+
54
+ def _get_vector_norm(tensor: torch.Tensor) -> torch.Tensor:
55
+ """
56
+ This method is equivalent to tensor.norm(p=2, dim=-1, keepdim=True) and used to make
57
+ model `executorch` exportable. See issue https://github.com/pytorch/executorch/issues/3566
58
+ """
59
+ square_tensor = torch.pow(tensor, 2)
60
+ sum_tensor = torch.sum(square_tensor, dim=-1, keepdim=True)
61
+ normed_tensor = torch.pow(sum_tensor, 0.5)
62
+ return normed_tensor
63
+
64
+
65
+ def _unpad_cm3p_input(
66
+ inputs: torch.Tensor,
67
+ attention_mask: torch.Tensor,
68
+ position_ids: Optional[torch.Tensor] = None,
69
+ labels: Optional[torch.Tensor] = None,
70
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
71
+ """
72
+ Remove padding from input sequences.
73
+
74
+ Args:
75
+ inputs: (batch, seqlen, ...) or (batch, seqlen)
76
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
77
+ position_ids: (batch, seqlen), int, position ids
78
+ labels: (batch, seqlen), int, labels
79
+
80
+ Returns:
81
+ unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
82
+ indices: (total_nnz)
83
+ cu_seqlens: (batch + 1), the cumulative sequence lengths
84
+ max_seqlen_in_batch: int
85
+ unpadded_position_ids: (total_nnz) or None
86
+ unpadded_labels: (total_nnz) or None
87
+ """
88
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
89
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
90
+ max_seqlen_in_batch = int(seqlens_in_batch.max().item())
91
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
92
+
93
+ if inputs.dim() == 2:
94
+ unpadded_inputs = inputs.flatten()[indices]
95
+ else:
96
+ batch, seqlen, *rest = inputs.shape
97
+ shape = batch * seqlen
98
+ unpadded_inputs = inputs.view(shape, *rest)[indices]
99
+
100
+ unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
101
+ unpadded_labels = labels.flatten()[indices] if labels is not None else None
102
+
103
+ return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
104
+
105
+
106
+ def _pad_cm3p_output(
107
+ inputs: torch.Tensor,
108
+ indices: torch.Tensor,
109
+ batch: int,
110
+ seqlen: int,
111
+ ) -> torch.Tensor:
112
+ """
113
+ Add padding to sequences.
114
+
115
+ Args:
116
+ inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
117
+ indices: (total_nnz)
118
+ batch: int, batch size
119
+ seqlen: int, max sequence length
120
+
121
+ Returns:
122
+ padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
123
+ """
124
+ if inputs.dim() == 1:
125
+ output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
126
+ output[indices] = inputs
127
+ padded_inputs = output.view(batch, seqlen)
128
+ else:
129
+ _, *rest = inputs.shape
130
+ output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
131
+ output[indices] = inputs
132
+ padded_inputs = output.view(batch, seqlen, *rest)
133
+
134
+ return padded_inputs
135
+
136
+
137
+ @dataclass
138
+ class BeatmapClassifierOutput(ModelOutput):
139
+ """
140
+ Base class for outputs of beatmap classification models.
141
+
142
+ Args:
143
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
144
+ Classification (or regression if config.num_labels==1) loss.
145
+ logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
146
+ Classification (or regression if config.num_labels==1) scores (before SoftMax).
147
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
148
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
149
+ one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
150
+ (also called feature maps) of the model at the output of each stage.
151
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
152
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
153
+ sequence_length)`.
154
+
155
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
156
+ heads.
157
+ """
158
+
159
+ loss: Optional[torch.FloatTensor] = None
160
+ logits: Optional[torch.FloatTensor] = None
161
+ hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
162
+ attentions: Optional[tuple[torch.FloatTensor, ...]] = None
163
+
164
+
165
+ @dataclass
166
+ @auto_docstring(
167
+ custom_intro="""
168
+ Base class for audio model's outputs that also contains a pooling of the last hidden states.
169
+ """
170
+ )
171
+ class CM3PAudioModelOutput(BaseModelOutput):
172
+ r"""
173
+ audio_embeds (`torch.FloatTensor` of shape `(batch_size * sequence_length, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
174
+ The audio embeddings obtained by applying the projection layer to the last hidden state.
175
+ """
176
+
177
+ audio_embeds: Optional[torch.FloatTensor] = None
178
+
179
+
180
+ @dataclass
181
+ @auto_docstring(
182
+ custom_intro="""
183
+ Base class for beatmap model's outputs that also contains beatmap embeddings of the pooling of the last hidden states.
184
+ """
185
+ )
186
+ class CM3PBeatmapModelOutput(BaseModelOutputWithPooling):
187
+ r"""
188
+ audio_model_output (`BaseModelOutput`):
189
+ The output of the audio model, which contains the last hidden state, hidden states, and attentions.
190
+ beatmap_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
191
+ The beatmap embeddings obtained by applying the projection layer to the pooler_output.
192
+ """
193
+
194
+ beatmap_embeds: Optional[torch.FloatTensor] = None
195
+ audio_model_output: Optional[CM3PAudioModelOutput] = None
196
+
197
+
198
+ @dataclass
199
+ @auto_docstring(
200
+ custom_intro="""
201
+ Base class for metadata model's outputs that also contains a pooling of the last hidden states.
202
+ """
203
+ )
204
+ class CM3PMetadataModelOutput(BaseModelOutput):
205
+ r"""
206
+ metadata_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
207
+ The metadata embeddings obtained by applying the projection layer to the pooler_output.
208
+ """
209
+
210
+ metadata_embeds: Optional[torch.FloatTensor] = None
211
+
212
+
213
+ @dataclass
214
+ @auto_docstring
215
+ class CM3POutput(ModelOutput):
216
+ r"""
217
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
218
+ Contrastive loss for beatmap-metadata similarity.
219
+ logits_per_beatmap (`torch.FloatTensor` of shape `(beatmap_batch_size, metadata_batch_size)`):
220
+ The scaled dot product scores between `beatmap_embeds` and `metadata_embeds`. This represents the beatmap-metadata
221
+ similarity scores.
222
+ logits_per_metadata (`torch.FloatTensor` of shape `(metadata_batch_size, beatmap_batch_size)`):
223
+ The scaled dot product scores between `metadata_embeds` and `beatmap_embeds`. This represents the metadata-beatmap
224
+ similarity scores.
225
+ metadata_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
226
+ The metadata embeddings obtained by applying the projection layer to the pooled output of [`CM3PMetadataModel`].
227
+ beatmap_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
228
+ The beatmap embeddings obtained by applying the projection layer to the pooled output of [`CM3PBeatmapModel`].
229
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`, *optional*, returned when `labels` is provided):
230
+ Prediction scores of the masked language modeling head. Only computed if `labels` is provided.
231
+ metadata_model_output (`BaseModelOutputWithPooling`):
232
+ The output of the [`CM3PMetadataModel`].
233
+ beatmap_model_output (`BaseModelOutputWithPooling`):
234
+ The output of the [`CM3PBeatmapModel`].
235
+ """
236
+
237
+ loss: Optional[torch.FloatTensor] = None
238
+ logits_per_beatmap: Optional[torch.Tensor] = None
239
+ logits_per_metadata: Optional[torch.Tensor] = None
240
+ metadata_embeds: Optional[torch.FloatTensor] = None
241
+ beatmap_embeds: Optional[torch.FloatTensor] = None
242
+ logits: Optional[torch.FloatTensor] = None
243
+ metadata_model_output: BaseModelOutputWithPooling = None
244
+ beatmap_model_output: BaseModelOutputWithPooling = None
245
+
246
+ def to_tuple(self) -> tuple[Any]:
247
+ return tuple(
248
+ self[k] if k not in ["metadata_model_output", "beatmap_model_output"] else getattr(self, k).to_tuple()
249
+ for k in self.keys()
250
+ )
251
+
252
+
253
+ @auto_docstring
254
+ class CM3PPreTrainedModel(PreTrainedModel):
255
+ config_class = CM3PConfig
256
+ base_model_prefix = "cm3p"
257
+ supports_gradient_checkpointing = True
258
+ _supports_flash_attn_2 = True
259
+ _supports_sdpa = True
260
+ _supports_flex_attn = False
261
+
262
+ def _init_weights(self, module):
263
+ """Initialize the weights"""
264
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
265
+ nn.init.normal_(module.weight, std=self.config.initializer_range)
266
+ if module.bias is not None:
267
+ module.bias.data.zero_()
268
+ elif isinstance(module, nn.LayerNorm):
269
+ module.weight.data.fill_(1.0)
270
+ if module.bias is not None:
271
+ module.bias.data.zero_()
272
+ elif isinstance(module, ModernBertModel):
273
+ module.initialize_weights()
274
+ elif isinstance(module, CM3PModel):
275
+ nn.init.normal_(
276
+ module.metadata_projection.weight,
277
+ std=module.metadata_embed_dim**-0.5 * self.config.initializer_factor,
278
+ )
279
+ nn.init.normal_(
280
+ module.beatmap_projection.weight,
281
+ std=module.beatmap_embed_dim**-0.5 * self.config.initializer_factor,
282
+ )
283
+ elif isinstance(module, CM3PBeatmapModelWithProjection):
284
+ nn.init.normal_(
285
+ module.beatmap_projection.weight,
286
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
287
+ )
288
+ elif isinstance(module, CM3PMetadataModelWithProjection):
289
+ nn.init.normal_(
290
+ module.metadata_projection.weight,
291
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
292
+ )
293
+ elif isinstance(module, CM3PForBeatmapClassification):
294
+ nn.init.normal_(
295
+ module.classifier.weight,
296
+ std=self.config.hidden_size**-0.5 * self.config.initializer_factor,
297
+ )
298
+
299
+
300
+ class CM3PMetadataTransformer(nn.Module):
301
+ def __init__(self, config: CM3PMetadataConfig):
302
+ super().__init__()
303
+ self.config = config
304
+ # noinspection PyTypeChecker
305
+ self.encoder = ModernBertModel(config)
306
+
307
+ def get_input_embeddings(self):
308
+ return self.encoder.get_input_embeddings()
309
+
310
+ def set_input_embeddings(self, value):
311
+ self.encoder.set_input_embeddings(value)
312
+
313
+ @can_return_tuple
314
+ @auto_docstring
315
+ def forward(
316
+ self,
317
+ input_ids: Optional[torch.Tensor] = None,
318
+ attention_mask: Optional[torch.Tensor] = None,
319
+ indices: Optional[torch.Tensor] = None,
320
+ cu_seqlens: Optional[torch.Tensor] = None,
321
+ max_seqlen: Optional[int] = None,
322
+ batch_size: Optional[int] = None,
323
+ seq_len: Optional[int] = None,
324
+ output_attentions: Optional[bool] = None,
325
+ output_hidden_states: Optional[bool] = None,
326
+ output_pooler: bool = True,
327
+ ) -> BaseModelOutputWithPooling:
328
+ r"""
329
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
330
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
331
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
332
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
333
+ max_seqlen (`int`, *optional*):
334
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
335
+ batch_size (`int`, *optional*):
336
+ Batch size of the input sequences. Used to pad the output tensors.
337
+ seq_len (`int`, *optional*):
338
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
339
+ output_pooler (`bool`, *optional*, defaults to `True`):
340
+ Whether to return the pooled output of the model. The pooled output is usually the representation of
341
+ the first token (CLS) or the mean of the token representations.
342
+ """
343
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
344
+ output_hidden_states = (
345
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
346
+ )
347
+
348
+ if input_ids is None:
349
+ raise ValueError("You have to specify input_ids")
350
+
351
+ is_3d = input_ids.dim() == 3
352
+ batch_size_3d = input_ids.size(0)
353
+ if is_3d:
354
+ # flatten to 2D batch if multiple metadata variations are provided
355
+ input_ids = input_ids.view(-1, input_ids.size(-1))
356
+ if attention_mask is not None:
357
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1))
358
+
359
+ encoder_outputs: BaseModelOutput = self.encoder(
360
+ input_ids=input_ids,
361
+ attention_mask=attention_mask,
362
+ indices=indices,
363
+ cu_seqlens=cu_seqlens,
364
+ max_seqlen=max_seqlen,
365
+ batch_size=batch_size,
366
+ seq_len=seq_len,
367
+ output_attentions=output_attentions,
368
+ output_hidden_states=output_hidden_states,
369
+ )
370
+
371
+ last_hidden_state = encoder_outputs.last_hidden_state
372
+ pooled_output = None
373
+
374
+ if is_3d:
375
+ # un-flatten back to 3D batch (batch_size, variations, seq_length, hidden_size)
376
+ last_hidden_state = last_hidden_state.view(
377
+ batch_size_3d, -1, last_hidden_state.size(-2), last_hidden_state.size(-1)
378
+ )
379
+ if attention_mask is not None:
380
+ attention_mask = attention_mask.view(batch_size_3d, -1, attention_mask.size(-1))
381
+
382
+ if output_pooler:
383
+ if indices is not None:
384
+ raise NotImplementedError("Pooling with unpadded input is not implemented yet.")
385
+ if self.config.cls_embed:
386
+ pooled_output = last_hidden_state[..., 0, :]
387
+ elif attention_mask is not None:
388
+ # Use the attention mask to exclude padding tokens
389
+ expanded_attention_mask = attention_mask.unsqueeze(-1).float()
390
+ masked_hidden_states = last_hidden_state * expanded_attention_mask
391
+ sum_hidden_states = torch.sum(masked_hidden_states, dim=-2)
392
+ sum_attention_mask = torch.sum(expanded_attention_mask, dim=-2)
393
+ pooled_output = sum_hidden_states / torch.clamp(sum_attention_mask, min=1e-9)
394
+ pooled_output = pooled_output.to(dtype=last_hidden_state.dtype)
395
+ else:
396
+ pooled_output = torch.mean(last_hidden_state, dim=-2)
397
+
398
+ return BaseModelOutputWithPooling(
399
+ last_hidden_state=last_hidden_state,
400
+ pooler_output=pooled_output,
401
+ hidden_states=encoder_outputs.hidden_states,
402
+ attentions=encoder_outputs.attentions,
403
+ )
404
+
405
+
406
+ @auto_docstring(
407
+ custom_intro="""
408
+ The metadata model from CM3P without any head or projection on top.
409
+ """
410
+ )
411
+ class CM3PMetadataModel(CM3PPreTrainedModel):
412
+ config_class = CM3PMetadataConfig
413
+
414
+ def __init__(self, config: CM3PMetadataConfig):
415
+ super().__init__(config)
416
+ self.metadata_model = CM3PMetadataTransformer(config)
417
+ # Initialize weights and apply final processing
418
+ self.post_init()
419
+
420
+ def get_input_embeddings(self) -> nn.Module:
421
+ return self.metadata_model.encoder.embeddings.tok_embeddings
422
+
423
+ def set_input_embeddings(self, value):
424
+ self.metadata_model.encoder.embeddings.tok_embeddings = value
425
+
426
+ @can_return_tuple
427
+ @auto_docstring
428
+ def forward(
429
+ self,
430
+ input_ids: Optional[torch.Tensor] = None,
431
+ attention_mask: Optional[torch.Tensor] = None,
432
+ indices: Optional[torch.Tensor] = None,
433
+ cu_seqlens: Optional[torch.Tensor] = None,
434
+ max_seqlen: Optional[int] = None,
435
+ batch_size: Optional[int] = None,
436
+ seq_len: Optional[int] = None,
437
+ output_attentions: Optional[bool] = None,
438
+ output_hidden_states: Optional[bool] = None,
439
+ output_pooler: bool = True,
440
+ ) -> BaseModelOutputWithPooling:
441
+ r"""
442
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
443
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
444
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
445
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
446
+ max_seqlen (`int`, *optional*):
447
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
448
+ batch_size (`int`, *optional*):
449
+ Batch size of the input sequences. Used to pad the output tensors.
450
+ seq_len (`int`, *optional*):
451
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
452
+ output_pooler (`bool`, *optional*, defaults to `True`):
453
+ Whether to return the pooled output of the model. The pooled output is usually the representation of
454
+ the first token (CLS) or the mean of the token representations.
455
+ """
456
+ return self.metadata_model(
457
+ input_ids=input_ids,
458
+ attention_mask=attention_mask,
459
+ indices=indices,
460
+ cu_seqlens=cu_seqlens,
461
+ max_seqlen=max_seqlen,
462
+ batch_size=batch_size,
463
+ seq_len=seq_len,
464
+ output_attentions=output_attentions,
465
+ output_hidden_states=output_hidden_states,
466
+ output_pooler=output_pooler,
467
+ )
468
+
469
+
470
+ class CM3PMultiModalProjector(nn.Module):
471
+ def __init__(self, config: CM3PAudioConfig):
472
+ super().__init__()
473
+ self.linear_1 = nn.Linear(config.projector_intermediate_size, config.projector_dim, bias=False)
474
+ self.act = ACT2FN[config.projector_hidden_act]
475
+ self.linear_2 = nn.Linear(config.projector_dim, config.projector_dim, bias=False)
476
+
477
+ def forward(self, audio_features):
478
+ hidden_states = self.linear_1(audio_features)
479
+ hidden_states = self.act(hidden_states)
480
+ hidden_states = self.linear_2(hidden_states)
481
+ return hidden_states
482
+
483
+
484
+ class CM3PAudioEncoder(nn.Module):
485
+ def __init__(self, config: CM3PAudioConfig):
486
+ super().__init__()
487
+ self.config = config
488
+ self.conv1 = nn.Conv1d(config.n_mels, config.hidden_size, kernel_size=3, padding=1)
489
+ self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)
490
+ # noinspection PyTypeChecker
491
+ self.encoder = ModernBertModel(config)
492
+ self.multi_modal_projector = CM3PMultiModalProjector(config)
493
+
494
+ def forward(
495
+ self,
496
+ input_features: torch.FloatTensor,
497
+ output_attentions: Optional[bool] = None,
498
+ output_hidden_states: Optional[bool] = None,
499
+ ) -> CM3PAudioModelOutput:
500
+ # Conv layers from Whisper followed by an modern Bert encoder
501
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features))
502
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
503
+
504
+ inputs_embeds = inputs_embeds.permute(0, 2, 1).contiguous()
505
+
506
+ position_ids = torch.arange(inputs_embeds.size(1), device=inputs_embeds.device).unsqueeze(0).repeat(
507
+ inputs_embeds.size(0), 1)
508
+
509
+ encoder_outputs: BaseModelOutput = self.encoder(
510
+ inputs_embeds=inputs_embeds,
511
+ position_ids=position_ids,
512
+ output_attentions=output_attentions,
513
+ output_hidden_states=output_hidden_states,
514
+ )
515
+
516
+ # Reduce the sequence length and project to the beatmap hidden size
517
+ audio_hidden_states = encoder_outputs.last_hidden_state
518
+ audio_hidden_states = audio_hidden_states.reshape(-1, self.config.projector_intermediate_size)
519
+ audio_embeds = self.multi_modal_projector(audio_hidden_states)
520
+
521
+ audio_outputs = CM3PAudioModelOutput(
522
+ audio_embeds=audio_embeds,
523
+ last_hidden_state=encoder_outputs.last_hidden_state,
524
+ hidden_states=encoder_outputs.hidden_states,
525
+ attentions=encoder_outputs.attentions,
526
+ )
527
+
528
+ return audio_outputs
529
+
530
+
531
+ class CM3PBeatmapTransformer(nn.Module):
532
+ def __init__(self, config: CM3PBeatmapConfig):
533
+ super().__init__()
534
+ self.config = config
535
+ self.audio_encoder = CM3PAudioEncoder(config.audio_config)
536
+ # noinspection PyTypeChecker
537
+ self.encoder = ModernBertModel(config)
538
+
539
+ def get_input_embeddings(self):
540
+ return self.encoder.get_input_embeddings()
541
+
542
+ def set_input_embeddings(self, value):
543
+ self.encoder.set_input_embeddings(value)
544
+
545
+ @can_return_tuple
546
+ @auto_docstring
547
+ def forward(
548
+ self,
549
+ input_ids: Optional[torch.LongTensor] = None,
550
+ input_features: Optional[torch.FloatTensor] = None,
551
+ attention_mask: Optional[torch.FloatTensor] = None,
552
+ sliding_window_mask: Optional[torch.FloatTensor] = None,
553
+ position_ids: Optional[torch.LongTensor] = None,
554
+ inputs_embeds: Optional[torch.FloatTensor] = None,
555
+ indices: Optional[torch.Tensor] = None,
556
+ cu_seqlens: Optional[torch.Tensor] = None,
557
+ max_seqlen: Optional[int] = None,
558
+ batch_size: Optional[int] = None,
559
+ seq_len: Optional[int] = None,
560
+ output_attentions: Optional[bool] = None,
561
+ output_hidden_states: Optional[bool] = None,
562
+ output_pooler: bool = True,
563
+ ) -> CM3PBeatmapModelOutput:
564
+ r"""
565
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
566
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
567
+ compute the beatmap embeddings.
568
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
569
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
570
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
571
+ far-away tokens in the local attention layers when not using Flash Attention.
572
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
573
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
574
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
575
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
576
+ max_seqlen (`int`, *optional*):
577
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
578
+ batch_size (`int`, *optional*):
579
+ Batch size of the input sequences. Used to pad the output tensors.
580
+ seq_len (`int`, *optional*):
581
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
582
+ output_pooler (`bool`, *optional*, defaults to `True`):
583
+ Whether to return the pooled output of the model. The pooled output is usually the representation of
584
+ the first token (CLS) or the mean of the token representations.
585
+ """
586
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
587
+ output_hidden_states = (
588
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
589
+ )
590
+
591
+ if inputs_embeds is None:
592
+ inputs_embeds = self.get_input_embeddings()(input_ids)
593
+
594
+ audio_model_outputs = None
595
+ if input_features is not None:
596
+ audio_model_outputs = self.audio_encoder(
597
+ input_features=input_features,
598
+ output_attentions=output_attentions,
599
+ output_hidden_states=output_hidden_states,
600
+ )
601
+
602
+ # replace text-audio token placeholders with audio embeddings
603
+ audio_embeds = audio_model_outputs.audio_embeds.to(dtype=inputs_embeds.dtype)
604
+ audio_token_mask = input_ids == self.config.audio_token_id
605
+ inputs_embeds[audio_token_mask] = audio_embeds
606
+
607
+ encoder_outputs: BaseModelOutput = self.encoder(
608
+ inputs_embeds=inputs_embeds,
609
+ attention_mask=attention_mask,
610
+ sliding_window_mask=sliding_window_mask,
611
+ position_ids=position_ids,
612
+ indices=indices,
613
+ cu_seqlens=cu_seqlens,
614
+ max_seqlen=max_seqlen,
615
+ batch_size=batch_size,
616
+ seq_len=seq_len,
617
+ output_attentions=output_attentions,
618
+ output_hidden_states=output_hidden_states,
619
+ )
620
+
621
+ last_hidden_state = encoder_outputs.last_hidden_state
622
+ pooled_output = None
623
+
624
+ if output_pooler:
625
+ if indices is not None:
626
+ if self.config.cls_embed:
627
+ pooled_output = last_hidden_state[cu_seqlens[:-1]]
628
+ else:
629
+ raise NotImplementedError("Pooling with unpadded input is not implemented yet.")
630
+ else:
631
+ if self.config.cls_embed:
632
+ pooled_output = last_hidden_state[:, 0]
633
+ elif attention_mask is not None:
634
+ # Use the attention mask to exclude padding tokens
635
+ expanded_attention_mask = attention_mask.unsqueeze(-1).float()
636
+ masked_hidden_states = last_hidden_state * expanded_attention_mask
637
+ sum_hidden_states = torch.sum(masked_hidden_states, dim=1)
638
+ sum_attention_mask = torch.sum(expanded_attention_mask, dim=1)
639
+ pooled_output = sum_hidden_states / torch.clamp(sum_attention_mask, min=1e-9)
640
+ pooled_output = pooled_output.to(dtype=last_hidden_state.dtype)
641
+ else:
642
+ pooled_output = torch.mean(last_hidden_state, dim=1)
643
+
644
+ return CM3PBeatmapModelOutput(
645
+ last_hidden_state=last_hidden_state,
646
+ pooler_output=pooled_output,
647
+ hidden_states=encoder_outputs.hidden_states,
648
+ attentions=encoder_outputs.attentions,
649
+ audio_model_output=audio_model_outputs,
650
+ )
651
+
652
+
653
+ @auto_docstring(
654
+ custom_intro="""
655
+ The beatmap model from CM3P without any head or projection on top.
656
+ """
657
+ )
658
+ class CM3PBeatmapModel(CM3PPreTrainedModel):
659
+ config_class = CM3PBeatmapConfig
660
+ main_input_name = "input_ids"
661
+
662
+ def __init__(self, config: CM3PBeatmapConfig):
663
+ super().__init__(config)
664
+ self.beatmap_model = CM3PBeatmapTransformer(config)
665
+ # Initialize weights and apply final processing
666
+ self.post_init()
667
+
668
+ def get_input_embeddings(self) -> nn.Module:
669
+ return self.beatmap_model.encoder.embeddings.tok_embeddings
670
+
671
+ def set_input_embeddings(self, value):
672
+ self.beatmap_model.encoder.embeddings.tok_embeddings = value
673
+
674
+ @can_return_tuple
675
+ @auto_docstring
676
+ def forward(
677
+ self,
678
+ input_ids: Optional[torch.LongTensor] = None,
679
+ input_features: Optional[torch.FloatTensor] = None,
680
+ attention_mask: Optional[torch.FloatTensor] = None,
681
+ position_ids: Optional[torch.LongTensor] = None,
682
+ inputs_embeds: Optional[torch.FloatTensor] = None,
683
+ indices: Optional[torch.Tensor] = None,
684
+ cu_seqlens: Optional[torch.Tensor] = None,
685
+ max_seqlen: Optional[int] = None,
686
+ batch_size: Optional[int] = None,
687
+ seq_len: Optional[int] = None,
688
+ output_attentions: Optional[bool] = None,
689
+ output_hidden_states: Optional[bool] = None,
690
+ output_pooler: bool = True,
691
+ ) -> CM3PBeatmapModelOutput:
692
+ r"""
693
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
694
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
695
+ compute the beatmap embeddings.
696
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
697
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
698
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
699
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
700
+ max_seqlen (`int`, *optional*):
701
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
702
+ batch_size (`int`, *optional*):
703
+ Batch size of the input sequences. Used to pad the output tensors.
704
+ seq_len (`int`, *optional*):
705
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
706
+ output_pooler (`bool`, *optional*, defaults to `True`):
707
+ Whether to return the pooled output of the model. The pooled output is usually the representation of
708
+ the first token (CLS) or the mean of the token representations.
709
+ """
710
+
711
+ return self.beatmap_model(
712
+ input_ids=input_ids,
713
+ input_features=input_features,
714
+ attention_mask=attention_mask,
715
+ position_ids=position_ids,
716
+ inputs_embeds=inputs_embeds,
717
+ indices=indices,
718
+ cu_seqlens=cu_seqlens,
719
+ max_seqlen=max_seqlen,
720
+ batch_size=batch_size,
721
+ seq_len=seq_len,
722
+ output_attentions=output_attentions,
723
+ output_hidden_states=output_hidden_states,
724
+ output_pooler=output_pooler,
725
+ )
726
+
727
+
728
+ @auto_docstring
729
+ class CM3PModel(CM3PPreTrainedModel):
730
+ config_class = CM3PConfig
731
+
732
+ def __init__(self, config: CM3PConfig):
733
+ super().__init__(config)
734
+
735
+ if not isinstance(config.metadata_config, CM3PMetadataConfig):
736
+ raise TypeError(
737
+ "config.metadata_config is expected to be of type CM3PMetadataConfig but is of type"
738
+ f" {type(config.metadata_config)}."
739
+ )
740
+
741
+ if not isinstance(config.beatmap_config, CM3PBeatmapConfig):
742
+ raise TypeError(
743
+ "config.beatmap_config is expected to be of type CM3PBeatmapConfig but is of type"
744
+ f" {type(config.beatmap_config)}."
745
+ )
746
+
747
+ metadata_config = config.metadata_config
748
+ beatmap_config = config.beatmap_config
749
+
750
+ self.projection_dim: int = config.projection_dim
751
+ self.metadata_embed_dim: int = metadata_config.hidden_size
752
+ self.beatmap_embed_dim: int = beatmap_config.hidden_size
753
+ self.loss_type = config.loss_type
754
+
755
+ metadata_model = CM3PMetadataModel._from_config(metadata_config)
756
+ self.metadata_model = metadata_model.metadata_model
757
+
758
+ beatmap_model = CM3PBeatmapModel._from_config(beatmap_config)
759
+ self.beatmap_model = beatmap_model.beatmap_model
760
+
761
+ self.beatmap_projection = nn.Linear(self.beatmap_embed_dim, self.projection_dim, bias=False)
762
+ self.metadata_projection = nn.Linear(self.metadata_embed_dim, self.projection_dim, bias=False)
763
+ self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
764
+
765
+ if config.has_decoder_head:
766
+ self.head = CM3PPredictionHead(beatmap_config)
767
+ self.decoder = nn.Linear(beatmap_config.hidden_size, beatmap_config.vocab_size, bias=beatmap_config.decoder_bias)
768
+
769
+ # Initialize weights and apply final processing
770
+ self.post_init()
771
+
772
+ @auto_docstring
773
+ def get_metadata_features(
774
+ self,
775
+ input_ids: Optional[torch.LongTensor] = None,
776
+ output_attentions: Optional[bool] = None,
777
+ output_hidden_states: Optional[bool] = None,
778
+ ) -> torch.FloatTensor:
779
+ r"""
780
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
781
+ The input IDs for the metadata model. The model will use these IDs to compute the metadata embeddings.
782
+ Returns:
783
+ metadata_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The metadata embeddings obtained by
784
+ applying the projection layer to the pooled output of [`CM3PMetadataModel`].
785
+ """
786
+ # Use CM3P model's config for some fields (if specified) instead of those of beatmap & metadata components.
787
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
788
+ output_hidden_states = (
789
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
790
+ )
791
+
792
+ metadata_outputs: BaseModelOutputWithPooling = self.metadata_model(
793
+ input_ids=input_ids,
794
+ output_attentions=output_attentions,
795
+ output_hidden_states=output_hidden_states,
796
+ )
797
+
798
+ pooled_output = metadata_outputs.pooler_output
799
+ metadata_features = self.metadata_projection(pooled_output)
800
+
801
+ return metadata_features
802
+
803
+ @auto_docstring
804
+ def get_beatmap_features(
805
+ self,
806
+ input_ids: Optional[torch.LongTensor] = None,
807
+ input_features: Optional[torch.FloatTensor] = None,
808
+ attention_mask: Optional[torch.Tensor] = None,
809
+ position_ids: Optional[torch.LongTensor] = None,
810
+ inputs_embeds: Optional[torch.FloatTensor] = None,
811
+ output_attentions: Optional[bool] = None,
812
+ output_hidden_states: Optional[bool] = None,
813
+ ) -> torch.FloatTensor:
814
+ r"""
815
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
816
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
817
+ compute the beatmap embeddings.
818
+ Returns:
819
+ beatmap_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The beatmap embeddings obtained by
820
+ applying the projection layer to the pooled output of [`CM3PBeatmapModel`].
821
+ """
822
+ # Use CM3P model's config for some fields (if specified) instead of those of beatmap & metadata components.
823
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
824
+ output_hidden_states = (
825
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
826
+ )
827
+
828
+ beatmap_outputs: BaseModelOutputWithPooling = self.beatmap_model(
829
+ input_ids=input_ids,
830
+ input_features=input_features,
831
+ attention_mask=attention_mask,
832
+ position_ids=position_ids,
833
+ inputs_embeds=inputs_embeds,
834
+ output_attentions=output_attentions,
835
+ output_hidden_states=output_hidden_states,
836
+ )
837
+
838
+ pooled_output = beatmap_outputs.pooler_output
839
+ beatmap_features = self.beatmap_projection(pooled_output)
840
+
841
+ return beatmap_features
842
+
843
+ @torch.compile(dynamic=True)
844
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
845
+ return self.decoder(self.head(output))
846
+
847
+ @can_return_tuple
848
+ @auto_docstring
849
+ def forward(
850
+ self,
851
+ input_ids: Optional[torch.LongTensor] = None,
852
+ input_features: Optional[torch.FloatTensor] = None,
853
+ metadata_ids: Optional[torch.LongTensor] = None,
854
+ attention_mask: Optional[torch.Tensor] = None,
855
+ metadata_attention_mask: Optional[torch.Tensor] = None,
856
+ position_ids: Optional[torch.LongTensor] = None,
857
+ inputs_embeds: Optional[torch.FloatTensor] = None,
858
+ metadata_variation_classes: Optional[torch.LongTensor] = None,
859
+ labels: Optional[torch.Tensor] = None,
860
+ indices: Optional[torch.Tensor] = None,
861
+ cu_seqlens: Optional[torch.Tensor] = None,
862
+ max_seqlen: Optional[int] = None,
863
+ batch_size: Optional[int] = None,
864
+ seq_len: Optional[int] = None,
865
+ return_loss: Optional[bool] = True,
866
+ output_attentions: Optional[bool] = None,
867
+ output_hidden_states: Optional[bool] = None,
868
+ output_logits: Optional[bool] = None,
869
+ **kwargs,
870
+ ) -> CM3POutput:
871
+ r"""
872
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
873
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
874
+ compute the beatmap embeddings.
875
+ metadata_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)` or `(batch_size, variations, sequence_length)`):
876
+ The input IDs for the metadata model. The model will use these IDs to compute the metadata embeddings.
877
+ metadata_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)` or `(batch_size, variations, sequence_length)`, *optional*):
878
+ The attention mask for the metadata model. If provided, the model will not attend to the padded tokens.
879
+ metadata_variation_classes (`torch.LongTensor` of shape `(batch_size, variations)`, *optional*):
880
+ Tells the model what kind of variation each metadata sequence is.
881
+ 0 indicates the original metadata, -1 indicates paddidng, and any positive integer indicates a specific variation class.
882
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
883
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
884
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
885
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
886
+ max_seqlen (`int`, *optional*):
887
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
888
+ batch_size (`int`, *optional*):
889
+ Batch size of the input sequences. Used to pad the output tensors.
890
+ seq_len (`int`, *optional*):
891
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
892
+ return_loss (`bool`, *optional*):
893
+ Whether to return the contrastive loss.
894
+ output_logits (`bool`, *optional*):
895
+ Whether to return the logits from the decoder head.
896
+ """
897
+ # Use CM3P model's config for some fields (if specified) instead of those of beatmap & metadata components.
898
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
899
+ output_hidden_states = (
900
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
901
+ )
902
+ output_logits = output_logits if output_logits is not None else self.config.has_decoder_head
903
+
904
+ if metadata_ids.dim() == 3 and return_loss and metadata_variation_classes is None:
905
+ raise ValueError("When providing multiple metadata variations, metadata_variation_classes must be provided in order to compute loss correctly.")
906
+
907
+ if output_logits and not self.config.has_decoder_head:
908
+ raise ValueError("Cannot return logits when the model is not configured with a decoder head.")
909
+
910
+ # noinspection PyProtectedMember
911
+ if self.config._attn_implementation == "flash_attention_2":
912
+ if indices is None and cu_seqlens is None and max_seqlen is None:
913
+ if batch_size is None and seq_len is None:
914
+ if inputs_embeds is not None:
915
+ batch_size, seq_len = inputs_embeds.shape[:2]
916
+ else:
917
+ batch_size, seq_len = input_ids.shape[:2]
918
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
919
+
920
+ if attention_mask is None:
921
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
922
+
923
+ if inputs_embeds is None:
924
+ with torch.no_grad():
925
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_cm3p_input(
926
+ inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
927
+ )
928
+ else:
929
+ inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_cm3p_input(
930
+ inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
931
+ )
932
+
933
+ beatmap_outputs: BaseModelOutputWithPooling = self.beatmap_model(
934
+ input_ids=input_ids,
935
+ input_features=input_features,
936
+ attention_mask=attention_mask,
937
+ position_ids=position_ids,
938
+ inputs_embeds=inputs_embeds,
939
+ indices=indices,
940
+ cu_seqlens=cu_seqlens,
941
+ max_seqlen=max_seqlen,
942
+ batch_size=batch_size,
943
+ seq_len=seq_len,
944
+ output_attentions=output_attentions,
945
+ output_hidden_states=output_hidden_states,
946
+ )
947
+
948
+ metadata_outputs: BaseModelOutputWithPooling = self.metadata_model(
949
+ input_ids=metadata_ids,
950
+ attention_mask=metadata_attention_mask,
951
+ output_attentions=output_attentions,
952
+ output_hidden_states=output_hidden_states,
953
+ )
954
+
955
+ beatmap_embeds = beatmap_outputs.pooler_output
956
+ beatmap_embeds = self.beatmap_projection(beatmap_embeds)
957
+
958
+ metadata_embeds = metadata_outputs.pooler_output
959
+ metadata_embeds = self.metadata_projection(metadata_embeds)
960
+
961
+ # normalized features
962
+ beatmap_embeds = beatmap_embeds / _get_vector_norm(beatmap_embeds)
963
+ metadata_embeds = metadata_embeds / _get_vector_norm(metadata_embeds)
964
+
965
+ # cosine similarity as logits
966
+ logits_per_metadata = torch.matmul(metadata_embeds, beatmap_embeds.t().to(metadata_embeds.device))
967
+ logits_per_metadata = logits_per_metadata * self.logit_scale.exp().to(metadata_embeds.device)
968
+
969
+ if logits_per_metadata.dim() == 3:
970
+ logits_per_beatmap = logits_per_metadata.permute(2, 0, 1)
971
+ else:
972
+ logits_per_beatmap = logits_per_metadata.t()
973
+
974
+ loss = None
975
+ if return_loss:
976
+ loss = cm3p_loss(logits_per_metadata, metadata_variation_classes)
977
+
978
+ logits = None
979
+ if output_logits:
980
+ logits = (
981
+ self.compiled_head(beatmap_outputs.last_hidden_state)
982
+ if self.config.beatmap_config.reference_compile
983
+ else self.decoder(self.head(beatmap_outputs.last_hidden_state))
984
+ )
985
+
986
+ if labels is not None and return_loss:
987
+ mlm_loss = self.loss_function(logits, labels, vocab_size=self.config.beatmap_config.vocab_size, **kwargs)
988
+ loss += 0.5 * mlm_loss
989
+
990
+ # noinspection PyProtectedMember
991
+ if self.config._attn_implementation == "flash_attention_2":
992
+ with nullcontext() if self.config.beatmap_config.repad_logits_with_grad or labels is None else torch.no_grad():
993
+ logits = _pad_cm3p_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
994
+
995
+ return CM3POutput(
996
+ loss=loss,
997
+ logits_per_beatmap=logits_per_beatmap,
998
+ logits_per_metadata=logits_per_metadata,
999
+ metadata_embeds=metadata_embeds,
1000
+ beatmap_embeds=beatmap_embeds,
1001
+ logits=logits,
1002
+ metadata_model_output=metadata_outputs,
1003
+ beatmap_model_output=beatmap_outputs,
1004
+ )
1005
+
1006
+
1007
+ @auto_docstring
1008
+ class CM3PMetadataModelWithProjection(CM3PPreTrainedModel):
1009
+ config_class = CM3PMetadataConfig
1010
+
1011
+ def __init__(self, config: CM3PMetadataConfig):
1012
+ super().__init__(config)
1013
+
1014
+ metadata_model = CM3PMetadataModel._from_config(config)
1015
+ self.metadata_model = metadata_model.metadata_model
1016
+
1017
+ self.metadata_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1018
+
1019
+ # Initialize weights and apply final processing
1020
+ self.post_init()
1021
+
1022
+ def get_input_embeddings(self) -> nn.Module:
1023
+ return self.metadata_model.get_input_embeddings()
1024
+
1025
+ def set_input_embeddings(self, value):
1026
+ self.metadata_model.set_input_embeddings(value)
1027
+
1028
+ @can_return_tuple
1029
+ @auto_docstring
1030
+ def forward(
1031
+ self,
1032
+ input_ids: Optional[torch.Tensor] = None,
1033
+ attention_mask: Optional[torch.Tensor] = None,
1034
+ output_attentions: Optional[bool] = None,
1035
+ output_hidden_states: Optional[bool] = None,
1036
+ ) -> CM3PMetadataModelOutput:
1037
+ r"""
1038
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1039
+ The input IDs for the metadata model. The model will use these IDs to compute the metadata embeddings.
1040
+ Returns:
1041
+ metadata_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The metadata embeddings obtained by
1042
+ applying the projection layer to the pooled output of [`CM3PMetadataModel`].
1043
+ """
1044
+ metadata_outputs: BaseModelOutputWithPooling = self.metadata_model(
1045
+ input_ids=input_ids,
1046
+ attention_mask=attention_mask,
1047
+ output_attentions=output_attentions,
1048
+ output_hidden_states=output_hidden_states,
1049
+ )
1050
+ pooled_output = metadata_outputs.pooler_output
1051
+ metadata_embeds = self.metadata_projection(pooled_output)
1052
+
1053
+ return CM3PMetadataModelOutput(
1054
+ metadata_embeds=metadata_embeds,
1055
+ last_hidden_state=metadata_outputs.last_hidden_state,
1056
+ hidden_states=metadata_outputs.hidden_states,
1057
+ attentions=metadata_outputs.attentions,
1058
+ )
1059
+
1060
+
1061
+ @auto_docstring
1062
+ class CM3PBeatmapModelWithProjection(CM3PPreTrainedModel):
1063
+ config_class = CM3PBeatmapConfig
1064
+
1065
+ def __init__(self, config: CM3PBeatmapConfig):
1066
+ super().__init__(config)
1067
+
1068
+ beatmap_model = CM3PBeatmapModel._from_config(config)
1069
+ self.beatmap_model = beatmap_model.beatmap_model
1070
+
1071
+ self.beatmap_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)
1072
+
1073
+ # Initialize weights and apply final processing
1074
+ self.post_init()
1075
+
1076
+ def get_input_embeddings(self) -> nn.Module:
1077
+ return self.beatmap_model.get_input_embeddings()
1078
+
1079
+ def set_input_embeddings(self, value):
1080
+ self.beatmap_model.set_input_embeddings(value)
1081
+
1082
+ @can_return_tuple
1083
+ @auto_docstring
1084
+ def forward(
1085
+ self,
1086
+ input_ids: Optional[torch.LongTensor] = None,
1087
+ input_features: Optional[torch.FloatTensor] = None,
1088
+ attention_mask: Optional[torch.Tensor] = None,
1089
+ position_ids: Optional[torch.LongTensor] = None,
1090
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1091
+ output_attentions: Optional[bool] = None,
1092
+ output_hidden_states: Optional[bool] = None,
1093
+ ) -> CM3PBeatmapModelOutput:
1094
+ r"""
1095
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
1096
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
1097
+ compute the beatmap embeddings.
1098
+ Returns:
1099
+ beatmap_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The beatmap embeddings obtained by
1100
+ applying the projection layer to the pooled output of [`CM3PBeatmapModel`].
1101
+ """
1102
+ beatmap_outputs: BaseModelOutputWithPooling = self.beatmap_model(
1103
+ input_ids=input_ids,
1104
+ input_features=input_features,
1105
+ attention_mask=attention_mask,
1106
+ position_ids=position_ids,
1107
+ inputs_embeds=inputs_embeds,
1108
+ output_attentions=output_attentions,
1109
+ output_hidden_states=output_hidden_states,
1110
+ )
1111
+ pooled_output = beatmap_outputs.pooler_output
1112
+ beatmap_embeds = self.beatmap_projection(pooled_output)
1113
+
1114
+ return CM3PBeatmapModelOutput(
1115
+ beatmap_embeds=beatmap_embeds,
1116
+ pooler_output=pooled_output,
1117
+ last_hidden_state=beatmap_outputs.last_hidden_state,
1118
+ hidden_states=beatmap_outputs.hidden_states,
1119
+ attentions=beatmap_outputs.attentions,
1120
+ )
1121
+
1122
+
1123
+ @auto_docstring(
1124
+ custom_intro="""
1125
+ CM3P beatmap encoder with an beatmap classification head on top (a linear layer on top of the pooled final hidden states of
1126
+ the beatmap embeddings) e.g. for BeatmapNet.
1127
+ """
1128
+ )
1129
+ class CM3PForBeatmapClassification(CM3PPreTrainedModel):
1130
+ config_class = CM3PBeatmapConfig
1131
+ base_model_prefix = "beatmap_model"
1132
+
1133
+ def __init__(self, config: CM3PBeatmapConfig) -> None:
1134
+ super().__init__(config)
1135
+
1136
+ self.num_labels = config.num_labels
1137
+ beatmap_model = CM3PBeatmapModel._from_config(config)
1138
+ self.beatmap_model = beatmap_model.beatmap_model
1139
+
1140
+ # Classifier head
1141
+ self.classifier = (
1142
+ nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1143
+ )
1144
+
1145
+ # Initialize weights and apply final processing
1146
+ self.post_init()
1147
+
1148
+ @can_return_tuple
1149
+ @auto_docstring
1150
+ def forward(
1151
+ self,
1152
+ input_ids: Optional[torch.LongTensor] = None,
1153
+ input_features: Optional[torch.FloatTensor] = None,
1154
+ attention_mask: Optional[torch.Tensor] = None,
1155
+ position_ids: Optional[torch.LongTensor] = None,
1156
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1157
+ labels: Optional[torch.Tensor] = None,
1158
+ output_attentions: Optional[bool] = None,
1159
+ output_hidden_states: Optional[bool] = None,
1160
+ ) -> BeatmapClassifierOutput:
1161
+ r"""
1162
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
1163
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
1164
+ compute the beatmap embeddings.
1165
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1166
+ Labels for computing the beatmap classification/regression loss. Indices should be in `[0, ...,
1167
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1168
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1169
+ """
1170
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1171
+ output_hidden_states = (
1172
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1173
+ )
1174
+
1175
+ outputs: BaseModelOutputWithPooling = self.beatmap_model(
1176
+ input_ids=input_ids,
1177
+ input_features=input_features,
1178
+ attention_mask=attention_mask,
1179
+ position_ids=position_ids,
1180
+ inputs_embeds=inputs_embeds,
1181
+ output_attentions=output_attentions,
1182
+ output_hidden_states=output_hidden_states,
1183
+ )
1184
+
1185
+ pooled_output = outputs.pooler_output
1186
+ logits = self.classifier(pooled_output)
1187
+
1188
+ loss = None
1189
+ if labels is not None:
1190
+ # move labels to correct device to enable model parallelism
1191
+ labels = labels.to(logits.device)
1192
+ if self.config.problem_type is None:
1193
+ if self.num_labels == 1:
1194
+ self.config.problem_type = "regression"
1195
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1196
+ self.config.problem_type = "single_label_classification"
1197
+ else:
1198
+ self.config.problem_type = "multi_label_classification"
1199
+
1200
+ if self.config.problem_type == "regression":
1201
+ loss_fct = MSELoss()
1202
+ if self.num_labels == 1:
1203
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1204
+ else:
1205
+ loss = loss_fct(logits, labels)
1206
+ elif self.config.problem_type == "single_label_classification":
1207
+ loss_fct = CrossEntropyLoss()
1208
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1209
+ elif self.config.problem_type == "multi_label_classification":
1210
+ loss_fct = BCEWithLogitsLoss()
1211
+ loss = loss_fct(logits, labels)
1212
+
1213
+ return BeatmapClassifierOutput(
1214
+ loss=loss,
1215
+ logits=logits,
1216
+ hidden_states=outputs.hidden_states,
1217
+ attentions=outputs.attentions,
1218
+ )
1219
+
1220
+
1221
+ class CM3PPredictionHead(nn.Module):
1222
+ def __init__(self, config: CM3PBeatmapConfig):
1223
+ super().__init__()
1224
+ self.config = config
1225
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias)
1226
+ self.act = ACT2FN[config.classifier_activation]
1227
+ self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias)
1228
+
1229
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1230
+ return self.norm(self.act(self.dense(hidden_states)))
1231
+
1232
+
1233
+ class CM3PForMaskedLM(CM3PPreTrainedModel):
1234
+ config_class = CM3PBeatmapConfig
1235
+ base_model_prefix = "beatmap_model"
1236
+ _tied_weights_keys = ["decoder.weight"]
1237
+
1238
+ def __init__(self, config: CM3PBeatmapConfig):
1239
+ super().__init__(config)
1240
+ self.config = config
1241
+ beatmap_model = CM3PBeatmapModel._from_config(config)
1242
+ self.beatmap_model = beatmap_model.beatmap_model
1243
+ self.head = CM3PPredictionHead(config)
1244
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=config.decoder_bias)
1245
+
1246
+ self.sparse_prediction = self.config.sparse_prediction
1247
+ self.sparse_pred_ignore_index = self.config.sparse_pred_ignore_index
1248
+
1249
+ # Initialize weights and apply final processing
1250
+ self.post_init()
1251
+
1252
+ def get_output_embeddings(self):
1253
+ return self.decoder
1254
+
1255
+ def set_output_embeddings(self, new_embeddings: nn.Linear):
1256
+ self.decoder = new_embeddings
1257
+
1258
+ @torch.compile(dynamic=True)
1259
+ def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
1260
+ return self.decoder(self.head(output))
1261
+
1262
+ @auto_docstring
1263
+ def forward(
1264
+ self,
1265
+ input_ids: Optional[torch.LongTensor] = None,
1266
+ input_features: Optional[torch.FloatTensor] = None,
1267
+ attention_mask: Optional[torch.Tensor] = None,
1268
+ sliding_window_mask: Optional[torch.Tensor] = None,
1269
+ position_ids: Optional[torch.Tensor] = None,
1270
+ inputs_embeds: Optional[torch.Tensor] = None,
1271
+ labels: Optional[torch.Tensor] = None,
1272
+ indices: Optional[torch.Tensor] = None,
1273
+ cu_seqlens: Optional[torch.Tensor] = None,
1274
+ max_seqlen: Optional[int] = None,
1275
+ batch_size: Optional[int] = None,
1276
+ seq_len: Optional[int] = None,
1277
+ output_attentions: Optional[bool] = None,
1278
+ output_hidden_states: Optional[bool] = None,
1279
+ **kwargs,
1280
+ ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
1281
+ r"""
1282
+ input_features (`torch.FloatTensor` of shape `(batch_size, num_frames, num_mels)`, *optional*):
1283
+ The audio frames to be processed by the audio encoder. If provided, the model will use these frames to
1284
+ compute the beatmap embeddings.
1285
+ sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1286
+ Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
1287
+ perform global attention, while the rest perform local attention. This mask is used to avoid attending to
1288
+ far-away tokens in the local attention layers when not using Flash Attention.
1289
+ indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
1290
+ Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
1291
+ cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
1292
+ Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
1293
+ max_seqlen (`int`, *optional*):
1294
+ Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
1295
+ batch_size (`int`, *optional*):
1296
+ Batch size of the input sequences. Used to pad the output tensors.
1297
+ seq_len (`int`, *optional*):
1298
+ Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
1299
+ """
1300
+ # noinspection PyProtectedMember
1301
+ if self.config._attn_implementation == "flash_attention_2":
1302
+ if indices is None and cu_seqlens is None and max_seqlen is None:
1303
+ if batch_size is None and seq_len is None:
1304
+ if inputs_embeds is not None:
1305
+ batch_size, seq_len = inputs_embeds.shape[:2]
1306
+ else:
1307
+ batch_size, seq_len = input_ids.shape[:2]
1308
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1309
+
1310
+ if attention_mask is None:
1311
+ attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
1312
+
1313
+ if inputs_embeds is None:
1314
+ with torch.no_grad():
1315
+ input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_cm3p_input(
1316
+ inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1317
+ )
1318
+ else:
1319
+ inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_cm3p_input(
1320
+ inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
1321
+ )
1322
+
1323
+ outputs = self.beatmap_model(
1324
+ input_ids=input_ids,
1325
+ input_features=input_features,
1326
+ attention_mask=attention_mask,
1327
+ sliding_window_mask=sliding_window_mask,
1328
+ position_ids=position_ids,
1329
+ inputs_embeds=inputs_embeds,
1330
+ indices=indices,
1331
+ cu_seqlens=cu_seqlens,
1332
+ max_seqlen=max_seqlen,
1333
+ batch_size=batch_size,
1334
+ seq_len=seq_len,
1335
+ output_attentions=output_attentions,
1336
+ output_hidden_states=output_hidden_states,
1337
+ output_pooler=False,
1338
+ )
1339
+ last_hidden_state = outputs.last_hidden_state
1340
+
1341
+ if self.sparse_prediction and labels is not None:
1342
+ # flatten labels and output first
1343
+ labels = labels.view(-1)
1344
+ last_hidden_state = last_hidden_state.view(labels.shape[0], -1)
1345
+
1346
+ # then filter out the non-masked tokens
1347
+ mask_tokens = labels != self.sparse_pred_ignore_index
1348
+ last_hidden_state = last_hidden_state[mask_tokens]
1349
+ labels = labels[mask_tokens]
1350
+
1351
+ logits = (
1352
+ self.compiled_head(last_hidden_state)
1353
+ if self.config.reference_compile
1354
+ else self.decoder(self.head(last_hidden_state))
1355
+ )
1356
+
1357
+ loss = None
1358
+ if labels is not None:
1359
+ loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
1360
+
1361
+ # noinspection PyProtectedMember
1362
+ if self.config._attn_implementation == "flash_attention_2":
1363
+ with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
1364
+ logits = _pad_cm3p_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
1365
+
1366
+ return MaskedLMOutput(
1367
+ loss=loss,
1368
+ logits=logits,
1369
+ hidden_states=outputs.hidden_states,
1370
+ attentions=outputs.attentions,
1371
+ )
1372
+
1373
+
1374
+ AutoModel.register(CM3PMetadataConfig, CM3PMetadataModel)
1375
+ AutoModel.register(CM3PBeatmapConfig, CM3PBeatmapModel)
1376
+ AutoModel.register(CM3PConfig, CM3PModel)
1377
+ AutoModelForSequenceClassification.register(CM3PBeatmapConfig, CM3PForBeatmapClassification)
1378
+ AutoModelForMaskedLM.register(CM3PBeatmapConfig, CM3PForMaskedLM)
1379
+
1380
+ __all__ = [
1381
+ "CM3PModel",
1382
+ "CM3PPreTrainedModel",
1383
+ "CM3PMetadataModel",
1384
+ "CM3PMetadataModelWithProjection",
1385
+ "CM3PBeatmapModel",
1386
+ "CM3PBeatmapModelWithProjection",
1387
+ "CM3PForBeatmapClassification",
1388
+ "CM3PForMaskedLM",
1389
+ ]