Add CM3P model
Browse files- processing_cm3p.py +3 -2
processing_cm3p.py
CHANGED
|
@@ -135,6 +135,7 @@ class CM3PTokenizerKwargs(TypedDict, total=False):
|
|
| 135 |
class CM3PBeatmapKwargs(CM3PTokenizerKwargs, total=False):
|
| 136 |
window_length_sec: float
|
| 137 |
window_stride_sec: float
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
class CM3PAudioKwargs(AudioKwargs, total=False):
|
|
@@ -563,7 +564,7 @@ class CM3PProcessor(ProcessorMixin):
|
|
| 563 |
**beatmap_kwargs,
|
| 564 |
)
|
| 565 |
|
| 566 |
-
if
|
| 567 |
data = dict(beatmap_encoding)
|
| 568 |
data["input_features"] = self._retrieve_input_features(batch_audio, **audio_kwargs)
|
| 569 |
beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
|
|
@@ -577,7 +578,7 @@ class CM3PProcessor(ProcessorMixin):
|
|
| 577 |
},
|
| 578 |
tensor_type=return_tensors,
|
| 579 |
)
|
| 580 |
-
if
|
| 581 |
data = dict(beatmap_encoding)
|
| 582 |
data["input_features"] = torch.zeros((0, self.audio_feature_extractor.feature_size, max_source_positions), dtype=torch.float) if return_tensors == "pt" else []
|
| 583 |
beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
|
|
|
|
| 135 |
class CM3PBeatmapKwargs(CM3PTokenizerKwargs, total=False):
|
| 136 |
window_length_sec: float
|
| 137 |
window_stride_sec: float
|
| 138 |
+
min_window_length_sec: float
|
| 139 |
|
| 140 |
|
| 141 |
class CM3PAudioKwargs(AudioKwargs, total=False):
|
|
|
|
| 564 |
**beatmap_kwargs,
|
| 565 |
)
|
| 566 |
|
| 567 |
+
if all(a is not None for a in audio):
|
| 568 |
data = dict(beatmap_encoding)
|
| 569 |
data["input_features"] = self._retrieve_input_features(batch_audio, **audio_kwargs)
|
| 570 |
beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
|
|
|
|
| 578 |
},
|
| 579 |
tensor_type=return_tensors,
|
| 580 |
)
|
| 581 |
+
if all(a is not None for a in audio):
|
| 582 |
data = dict(beatmap_encoding)
|
| 583 |
data["input_features"] = torch.zeros((0, self.audio_feature_extractor.feature_size, max_source_positions), dtype=torch.float) if return_tensors == "pt" else []
|
| 584 |
beatmap_encoding = BatchFeature(data, tensor_type=return_tensors)
|