OliBomby committed on
Commit
17137e3
·
verified ·
1 Parent(s): c488d69

Add CM3P model

Browse files
Files changed (2) hide show
  1. configuration_cm3p.py +28 -6
  2. modeling_cm3p.py +83 -61
configuration_cm3p.py CHANGED
@@ -13,7 +13,7 @@ class CM3PMetadataConfig(PretrainedConfig):
13
 
14
  def __init__(
15
  self,
16
- cls_embed=False,
17
 
18
  projection_dim=512,
19
  initializer_factor=1.0,
@@ -177,6 +177,7 @@ class CM3PAudioConfig(PretrainedConfig):
177
 
178
  class CM3PBeatmapConfig(PretrainedConfig):
179
  model_type = "CM3PBeatmap"
 
180
  base_config_key = "beatmap_config"
181
  sub_configs = {"audio_config": CM3PAudioConfig}
182
 
@@ -186,7 +187,7 @@ class CM3PBeatmapConfig(PretrainedConfig):
186
  audio_sos_token_id=3164,
187
  audio_eos_token_id=3165,
188
  audio_token_id=3166,
189
- cls_embed=False,
190
 
191
  projection_dim=512,
192
  initializer_factor=1.0,
@@ -222,12 +223,15 @@ class CM3PBeatmapConfig(PretrainedConfig):
222
  sparse_pred_ignore_index=-100,
223
  reference_compile=None,
224
  repad_logits_with_grad=False,
 
 
225
  **kwargs,
226
  ):
227
  super().__init__(
228
  pad_token_id=pad_token_id,
229
  bos_token_id=bos_token_id,
230
  eos_token_id=eos_token_id,
 
231
  **kwargs,
232
  )
233
 
@@ -235,7 +239,11 @@ class CM3PBeatmapConfig(PretrainedConfig):
235
  audio_config = {}
236
  logger.info("`audio_config` is `None`. Initializing the `CM3PAudioConfig` with default values.")
237
 
238
- self.audio_config = CM3PAudioConfig(**audio_config)
 
 
 
 
239
  self.audio_sos_token_id = audio_sos_token_id
240
  self.audio_eos_token_id = audio_eos_token_id
241
  self.audio_token_id = audio_token_id
@@ -280,6 +288,7 @@ class CM3PBeatmapConfig(PretrainedConfig):
280
 
281
  class CM3PConfig(PretrainedConfig):
282
  model_type = "CM3P"
 
283
  sub_configs = {"metadata_config": CM3PMetadataConfig, "beatmap_config": CM3PBeatmapConfig}
284
 
285
  def __init__(
@@ -291,9 +300,15 @@ class CM3PConfig(PretrainedConfig):
291
  initializer_factor=1.0,
292
  initializer_range=0.02,
293
  loss_type=None,
 
 
 
294
  **kwargs
295
  ):
296
- super().__init__(**kwargs)
 
 
 
297
 
298
  if metadata_config is None:
299
  metadata_config = {}
@@ -303,14 +318,21 @@ class CM3PConfig(PretrainedConfig):
303
  beatmap_config = {}
304
  logger.debug("`beatmap_config` is `None`. initializing the `CM3PBeatmapConfig` with default values.")
305
 
306
- self.metadata_config = CM3PMetadataConfig(**metadata_config)
307
- self.beatmap_config = CM3PBeatmapConfig(**beatmap_config)
 
 
 
 
 
 
308
 
309
  self.projection_dim = projection_dim
310
  self.logit_scale_init_value = logit_scale_init_value
311
  self.initializer_factor = initializer_factor
312
  self.initializer_range = initializer_range
313
  self.loss_type = loss_type
 
314
 
315
 
316
  AutoConfig.register("CM3PMetadata", CM3PMetadataConfig)
 
13
 
14
  def __init__(
15
  self,
16
+ cls_embed=True,
17
 
18
  projection_dim=512,
19
  initializer_factor=1.0,
 
177
 
178
  class CM3PBeatmapConfig(PretrainedConfig):
179
  model_type = "CM3PBeatmap"
180
+ is_composition = True
181
  base_config_key = "beatmap_config"
182
  sub_configs = {"audio_config": CM3PAudioConfig}
183
 
 
187
  audio_sos_token_id=3164,
188
  audio_eos_token_id=3165,
189
  audio_token_id=3166,
190
+ cls_embed=True,
191
 
192
  projection_dim=512,
193
  initializer_factor=1.0,
 
223
  sparse_pred_ignore_index=-100,
224
  reference_compile=None,
225
  repad_logits_with_grad=False,
226
+
227
+ attn_implementation: str = None,
228
  **kwargs,
229
  ):
230
  super().__init__(
231
  pad_token_id=pad_token_id,
232
  bos_token_id=bos_token_id,
233
  eos_token_id=eos_token_id,
234
+ attn_implementation=attn_implementation,
235
  **kwargs,
236
  )
237
 
 
239
  audio_config = {}
240
  logger.info("`audio_config` is `None`. Initializing the `CM3PAudioConfig` with default values.")
241
 
242
+ self.audio_config = CM3PAudioConfig(
243
+ attn_implementation=attn_implementation,
244
+ **audio_config
245
+ )
246
+
247
  self.audio_sos_token_id = audio_sos_token_id
248
  self.audio_eos_token_id = audio_eos_token_id
249
  self.audio_token_id = audio_token_id
 
288
 
289
  class CM3PConfig(PretrainedConfig):
290
  model_type = "CM3P"
291
+ is_composition = True
292
  sub_configs = {"metadata_config": CM3PMetadataConfig, "beatmap_config": CM3PBeatmapConfig}
293
 
294
  def __init__(
 
300
  initializer_factor=1.0,
301
  initializer_range=0.02,
302
  loss_type=None,
303
+ has_decoder_head=False,
304
+
305
+ attn_implementation: str = None,
306
  **kwargs
307
  ):
308
+ super().__init__(
309
+ attn_implementation=attn_implementation,
310
+ **kwargs
311
+ )
312
 
313
  if metadata_config is None:
314
  metadata_config = {}
 
318
  beatmap_config = {}
319
  logger.debug("`beatmap_config` is `None`. initializing the `CM3PBeatmapConfig` with default values.")
320
 
321
+ self.metadata_config = CM3PMetadataConfig(
322
+ attn_implementation=attn_implementation,
323
+ **metadata_config
324
+ )
325
+ self.beatmap_config = CM3PBeatmapConfig(
326
+ attn_implementation=attn_implementation,
327
+ **beatmap_config
328
+ )
329
 
330
  self.projection_dim = projection_dim
331
  self.logit_scale_init_value = logit_scale_init_value
332
  self.initializer_factor = initializer_factor
333
  self.initializer_range = initializer_range
334
  self.loss_type = loss_type
335
+ self.has_decoder_head = has_decoder_head
336
 
337
 
338
  AutoConfig.register("CM3PMetadata", CM3PMetadataConfig)
modeling_cm3p.py CHANGED
@@ -24,7 +24,7 @@ logger = logging.get_logger(__name__)
24
 
25
  # contrastive loss function, adapted from
26
  # https://sachinruk.github.io/blog/2021-03-07-clip.html
27
- def contrastive_loss(logits: torch.Tensor, target: torch.LongTensor = None) -> torch.Tensor:
28
  target = target if target is not None else torch.arange(len(logits), device=logits.device)
29
  return nn.functional.cross_entropy(logits, target)
30
 
@@ -192,7 +192,7 @@ class CM3PBeatmapModelOutput(BaseModelOutputWithPooling):
192
  """
193
 
194
  beatmap_embeds: Optional[torch.FloatTensor] = None
195
- audio_model_output: CM3PAudioModelOutput = None
196
 
197
 
198
  @dataclass
@@ -235,8 +235,8 @@ class CM3POutput(ModelOutput):
235
  """
236
 
237
  loss: Optional[torch.FloatTensor] = None
238
- logits_per_beatmap: Optional[torch.FloatTensor] = None
239
- logits_per_metadata: Optional[torch.FloatTensor] = None
240
  metadata_embeds: Optional[torch.FloatTensor] = None
241
  beatmap_embeds: Optional[torch.FloatTensor] = None
242
  logits: Optional[torch.FloatTensor] = None
@@ -301,6 +301,7 @@ class CM3PMetadataTransformer(nn.Module):
301
  def __init__(self, config: CM3PMetadataConfig):
302
  super().__init__()
303
  self.config = config
 
304
  self.encoder = ModernBertModel(config)
305
 
306
  def get_input_embeddings(self):
@@ -486,6 +487,7 @@ class CM3PAudioEncoder(nn.Module):
486
  self.config = config
487
  self.conv1 = nn.Conv1d(config.n_mels, config.hidden_size, kernel_size=3, padding=1)
488
  self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)
 
489
  self.encoder = ModernBertModel(config)
490
  self.multi_modal_projector = CM3PMultiModalProjector(config)
491
 
@@ -531,6 +533,7 @@ class CM3PBeatmapTransformer(nn.Module):
531
  super().__init__()
532
  self.config = config
533
  self.audio_encoder = CM3PAudioEncoder(config.audio_config)
 
534
  self.encoder = ModernBertModel(config)
535
 
536
  def get_input_embeddings(self):
@@ -590,7 +593,7 @@ class CM3PBeatmapTransformer(nn.Module):
590
 
591
  audio_model_outputs = None
592
  if input_features is not None:
593
- audio_model_outputs: CM3PAudioModelOutput = self.audio_encoder(
594
  input_features=input_features,
595
  output_attentions=output_attentions,
596
  output_hidden_states=output_hidden_states,
@@ -744,9 +747,9 @@ class CM3PModel(CM3PPreTrainedModel):
744
  metadata_config = config.metadata_config
745
  beatmap_config = config.beatmap_config
746
 
747
- self.projection_dim = config.projection_dim
748
- self.metadata_embed_dim = metadata_config.hidden_size
749
- self.beatmap_embed_dim = beatmap_config.hidden_size
750
  self.loss_type = config.loss_type
751
 
752
  metadata_model = CM3PMetadataModel._from_config(metadata_config)
@@ -759,8 +762,9 @@ class CM3PModel(CM3PPreTrainedModel):
759
  self.metadata_projection = nn.Linear(self.metadata_embed_dim, self.projection_dim, bias=False)
760
  self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
761
 
762
- self.head = CM3PPredictionHead(beatmap_config)
763
- self.decoder = nn.Linear(beatmap_config.hidden_size, beatmap_config.vocab_size, bias=beatmap_config.decoder_bias)
 
764
 
765
  # Initialize weights and apply final processing
766
  self.post_init()
@@ -861,6 +865,7 @@ class CM3PModel(CM3PPreTrainedModel):
861
  return_loss: Optional[bool] = True,
862
  output_attentions: Optional[bool] = None,
863
  output_hidden_states: Optional[bool] = None,
 
864
  **kwargs,
865
  ) -> CM3POutput:
866
  r"""
@@ -886,16 +891,22 @@ class CM3PModel(CM3PPreTrainedModel):
886
  Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
887
  return_loss (`bool`, *optional*):
888
  Whether to return the contrastive loss.
 
 
889
  """
890
  # Use CM3P model's config for some fields (if specified) instead of those of beatmap & metadata components.
891
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
892
  output_hidden_states = (
893
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
894
  )
 
895
 
896
- if metadata_ids.dim() == 3 and return_loss and metadata_variation_classes is None:
897
  raise ValueError("When providing multiple metadata variations, metadata_variation_classes must be provided in order to compute loss correctly.")
898
 
 
 
 
899
  # noinspection PyProtectedMember
900
  if self.config._attn_implementation == "flash_attention_2":
901
  if indices is None and cu_seqlens is None and max_seqlen is None:
@@ -919,65 +930,75 @@ class CM3PModel(CM3PPreTrainedModel):
919
  inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
920
  )
921
 
922
- beatmap_outputs: BaseModelOutputWithPooling = self.beatmap_model(
923
- input_ids=input_ids,
924
- input_features=input_features,
925
- attention_mask=attention_mask,
926
- position_ids=position_ids,
927
- inputs_embeds=inputs_embeds,
928
- indices=indices,
929
- cu_seqlens=cu_seqlens,
930
- max_seqlen=max_seqlen,
931
- batch_size=batch_size,
932
- seq_len=seq_len,
933
- output_attentions=output_attentions,
934
- output_hidden_states=output_hidden_states,
935
- )
936
-
937
- metadata_outputs: BaseModelOutputWithPooling = self.metadata_model(
938
- input_ids=metadata_ids,
939
- attention_mask=metadata_attention_mask,
940
- output_attentions=output_attentions,
941
- output_hidden_states=output_hidden_states,
942
- )
 
 
 
943
 
944
- beatmap_embeds = beatmap_outputs.pooler_output
945
- beatmap_embeds = self.beatmap_projection(beatmap_embeds)
 
946
 
947
- metadata_embeds = metadata_outputs.pooler_output
948
- metadata_embeds = self.metadata_projection(metadata_embeds)
 
 
 
 
 
949
 
950
- # normalized features
951
- beatmap_embeds = beatmap_embeds / _get_vector_norm(beatmap_embeds)
952
- metadata_embeds = metadata_embeds / _get_vector_norm(metadata_embeds)
953
 
954
- # cosine similarity as logits
955
- logits_per_metadata = torch.matmul(metadata_embeds, beatmap_embeds.t().to(metadata_embeds.device))
956
- logits_per_metadata = logits_per_metadata * self.logit_scale.exp().to(metadata_embeds.device)
 
957
 
958
- if logits_per_metadata.dim() == 3:
959
- logits_per_beatmap = logits_per_metadata.permute(2, 0, 1)
960
- else:
961
- logits_per_beatmap = logits_per_metadata.t()
962
 
963
- loss = None
964
- if return_loss:
965
- loss = cm3p_loss(logits_per_metadata, metadata_variation_classes)
966
 
967
- logits = (
968
- self.compiled_head(beatmap_outputs.last_hidden_state)
969
- if self.config.beatmap_config.reference_compile
970
- else self.decoder(self.head(beatmap_outputs.last_hidden_state))
971
- )
 
972
 
973
- if labels is not None and return_loss:
974
- mlm_loss = self.loss_function(logits, labels, vocab_size=self.config.beatmap_config.vocab_size, **kwargs)
975
- loss += 0.5 * mlm_loss
976
 
977
- # noinspection PyProtectedMember
978
- if self.config._attn_implementation == "flash_attention_2":
979
- with nullcontext() if self.config.beatmap_config.repad_logits_with_grad or labels is None else torch.no_grad():
980
- logits = _pad_cm3p_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
981
 
982
  return CM3POutput(
983
  loss=loss,
@@ -1372,4 +1393,5 @@ __all__ = [
1372
  "CM3PBeatmapModel",
1373
  "CM3PBeatmapModelWithProjection",
1374
  "CM3PForBeatmapClassification",
 
1375
  ]
 
24
 
25
  # contrastive loss function, adapted from
26
  # https://sachinruk.github.io/blog/2021-03-07-clip.html
27
+ def contrastive_loss(logits: torch.Tensor, target: torch.Tensor = None) -> torch.Tensor:
28
  target = target if target is not None else torch.arange(len(logits), device=logits.device)
29
  return nn.functional.cross_entropy(logits, target)
30
 
 
192
  """
193
 
194
  beatmap_embeds: Optional[torch.FloatTensor] = None
195
+ audio_model_output: Optional[CM3PAudioModelOutput] = None
196
 
197
 
198
  @dataclass
 
235
  """
236
 
237
  loss: Optional[torch.FloatTensor] = None
238
+ logits_per_beatmap: Optional[torch.Tensor] = None
239
+ logits_per_metadata: Optional[torch.Tensor] = None
240
  metadata_embeds: Optional[torch.FloatTensor] = None
241
  beatmap_embeds: Optional[torch.FloatTensor] = None
242
  logits: Optional[torch.FloatTensor] = None
 
301
  def __init__(self, config: CM3PMetadataConfig):
302
  super().__init__()
303
  self.config = config
304
+ # noinspection PyTypeChecker
305
  self.encoder = ModernBertModel(config)
306
 
307
  def get_input_embeddings(self):
 
487
  self.config = config
488
  self.conv1 = nn.Conv1d(config.n_mels, config.hidden_size, kernel_size=3, padding=1)
489
  self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)
490
+ # noinspection PyTypeChecker
491
  self.encoder = ModernBertModel(config)
492
  self.multi_modal_projector = CM3PMultiModalProjector(config)
493
 
 
533
  super().__init__()
534
  self.config = config
535
  self.audio_encoder = CM3PAudioEncoder(config.audio_config)
536
+ # noinspection PyTypeChecker
537
  self.encoder = ModernBertModel(config)
538
 
539
  def get_input_embeddings(self):
 
593
 
594
  audio_model_outputs = None
595
  if input_features is not None:
596
+ audio_model_outputs = self.audio_encoder(
597
  input_features=input_features,
598
  output_attentions=output_attentions,
599
  output_hidden_states=output_hidden_states,
 
747
  metadata_config = config.metadata_config
748
  beatmap_config = config.beatmap_config
749
 
750
+ self.projection_dim: int = config.projection_dim
751
+ self.metadata_embed_dim: int = metadata_config.hidden_size
752
+ self.beatmap_embed_dim: int = beatmap_config.hidden_size
753
  self.loss_type = config.loss_type
754
 
755
  metadata_model = CM3PMetadataModel._from_config(metadata_config)
 
762
  self.metadata_projection = nn.Linear(self.metadata_embed_dim, self.projection_dim, bias=False)
763
  self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
764
 
765
+ if config.has_decoder_head:
766
+ self.head = CM3PPredictionHead(beatmap_config)
767
+ self.decoder = nn.Linear(beatmap_config.hidden_size, beatmap_config.vocab_size, bias=beatmap_config.decoder_bias)
768
 
769
  # Initialize weights and apply final processing
770
  self.post_init()
 
865
  return_loss: Optional[bool] = True,
866
  output_attentions: Optional[bool] = None,
867
  output_hidden_states: Optional[bool] = None,
868
+ output_logits: Optional[bool] = None,
869
  **kwargs,
870
  ) -> CM3POutput:
871
  r"""
 
891
  Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
892
  return_loss (`bool`, *optional*):
893
  Whether to return the contrastive loss.
894
+ output_logits (`bool`, *optional*):
895
+ Whether to return the logits from the decoder head.
896
  """
897
  # Use CM3P model's config for some fields (if specified) instead of those of beatmap & metadata components.
898
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
899
  output_hidden_states = (
900
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
901
  )
902
+ output_logits = output_logits if output_logits is not None else self.config.has_decoder_head
903
 
904
+ if metadata_ids is not None and metadata_ids.dim() == 3 and return_loss and metadata_variation_classes is None:
905
  raise ValueError("When providing multiple metadata variations, metadata_variation_classes must be provided in order to compute loss correctly.")
906
 
907
+ if output_logits and not self.config.has_decoder_head:
908
+ raise ValueError("Cannot return logits when the model is not configured with a decoder head.")
909
+
910
  # noinspection PyProtectedMember
911
  if self.config._attn_implementation == "flash_attention_2":
912
  if indices is None and cu_seqlens is None and max_seqlen is None:
 
930
  inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
931
  )
932
 
933
+ beatmap_embeds = None
934
+ beatmap_outputs = None
935
+ metadata_embeds = None
936
+ metadata_outputs = None
937
+ logits_per_beatmap = None
938
+ logits_per_metadata = None
939
+ loss = 0 if return_loss else None
940
+ logits = None
941
+
942
+ if input_ids is not None:
943
+ beatmap_outputs: BaseModelOutputWithPooling = self.beatmap_model(
944
+ input_ids=input_ids,
945
+ input_features=input_features,
946
+ attention_mask=attention_mask,
947
+ position_ids=position_ids,
948
+ inputs_embeds=inputs_embeds,
949
+ indices=indices,
950
+ cu_seqlens=cu_seqlens,
951
+ max_seqlen=max_seqlen,
952
+ batch_size=batch_size,
953
+ seq_len=seq_len,
954
+ output_attentions=output_attentions,
955
+ output_hidden_states=output_hidden_states,
956
+ )
957
 
958
+ beatmap_embeds = beatmap_outputs.pooler_output
959
+ beatmap_embeds = self.beatmap_projection(beatmap_embeds)
960
+ beatmap_embeds = beatmap_embeds / _get_vector_norm(beatmap_embeds)
961
 
962
+ if metadata_ids is not None:
963
+ metadata_outputs: BaseModelOutputWithPooling = self.metadata_model(
964
+ input_ids=metadata_ids,
965
+ attention_mask=metadata_attention_mask,
966
+ output_attentions=output_attentions,
967
+ output_hidden_states=output_hidden_states,
968
+ )
969
 
970
+ metadata_embeds = metadata_outputs.pooler_output
971
+ metadata_embeds = self.metadata_projection(metadata_embeds)
972
+ metadata_embeds = metadata_embeds / _get_vector_norm(metadata_embeds)
973
 
974
+ if metadata_embeds is not None and beatmap_embeds is not None:
975
+ # cosine similarity as logits
976
+ logits_per_metadata = torch.matmul(metadata_embeds, beatmap_embeds.t().to(metadata_embeds.device))
977
+ logits_per_metadata = logits_per_metadata * self.logit_scale.exp().to(metadata_embeds.device)
978
 
979
+ if logits_per_metadata.dim() == 3:
980
+ logits_per_beatmap = logits_per_metadata.permute(2, 0, 1)
981
+ else:
982
+ logits_per_beatmap = logits_per_metadata.t()
983
 
984
+ if return_loss:
985
+ loss = cm3p_loss(logits_per_metadata, metadata_variation_classes)
 
986
 
987
+ if output_logits:
988
+ logits = (
989
+ self.compiled_head(beatmap_outputs.last_hidden_state)
990
+ if self.config.beatmap_config.reference_compile
991
+ else self.decoder(self.head(beatmap_outputs.last_hidden_state))
992
+ )
993
 
994
+ if labels is not None and return_loss:
995
+ mlm_loss = self.loss_function(logits, labels, vocab_size=self.config.beatmap_config.vocab_size, **kwargs)
996
+ loss += 0.5 * mlm_loss
997
 
998
+ # noinspection PyProtectedMember
999
+ if self.config._attn_implementation == "flash_attention_2":
1000
+ with nullcontext() if self.config.beatmap_config.repad_logits_with_grad or labels is None else torch.no_grad():
1001
+ logits = _pad_cm3p_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
1002
 
1003
  return CM3POutput(
1004
  loss=loss,
 
1393
  "CM3PBeatmapModel",
1394
  "CM3PBeatmapModelWithProjection",
1395
  "CM3PForBeatmapClassification",
1396
+ "CM3PForMaskedLM",
1397
  ]