OliBomby commited on
Commit
7b6f0a8
·
verified ·
1 Parent(s): 88986bc

Add CM3P model

Browse files
Files changed (1) hide show
  1. processing_cm3p.py +3 -2
processing_cm3p.py CHANGED
@@ -135,6 +135,7 @@ class CM3PTokenizerKwargs(TypedDict, total=False):
135
  class CM3PBeatmapKwargs(CM3PTokenizerKwargs, total=False):
136
  window_length_sec: float
137
  window_stride_sec: float
 
138
 
139
 
140
  class CM3PAudioKwargs(AudioKwargs, total=False):
@@ -563,7 +564,7 @@ class CM3PProcessor(ProcessorMixin):
563
  **beatmap_kwargs,
564
  )
565
 
566
- if audio is not None:
567
  data = dict(beatmap_encoding)
568
  data["input_features"] = self._retrieve_input_features(batch_audio, **audio_kwargs)
569
  beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
@@ -577,7 +578,7 @@ class CM3PProcessor(ProcessorMixin):
577
  },
578
  tensor_type=return_tensors,
579
  )
580
- if audio is not None:
581
  data = dict(beatmap_encoding)
582
  data["input_features"] = torch.zeros((0, self.audio_feature_extractor.feature_size, max_source_positions), dtype=torch.float) if return_tensors == "pt" else []
583
  beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
 
135
  class CM3PBeatmapKwargs(CM3PTokenizerKwargs, total=False):
136
  window_length_sec: float
137
  window_stride_sec: float
138
+ min_window_length_sec: float
139
 
140
 
141
  class CM3PAudioKwargs(AudioKwargs, total=False):
 
564
  **beatmap_kwargs,
565
  )
566
 
567
+ if all(a is not None for a in audio):
568
  data = dict(beatmap_encoding)
569
  data["input_features"] = self._retrieve_input_features(batch_audio, **audio_kwargs)
570
  beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
 
578
  },
579
  tensor_type=return_tensors,
580
  )
581
+ if all(a is not None for a in audio):
582
  data = dict(beatmap_encoding)
583
  data["input_features"] = torch.zeros((0, self.audio_feature_extractor.feature_size, max_source_positions), dtype=torch.float) if return_tensors == "pt" else []
584
  beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)