File size: 1,900 Bytes
85ba398 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from typing import List
import numpy as np
from fairseq.data.audio.dataset_transforms import (
AudioDatasetTransform,
register_audio_dataset_transform,
)
_DEFAULTS = {"rate": 0.25, "max_tokens": 3000, "attempts": 5}
@register_audio_dataset_transform("concataugment")
class ConcatAugment(AudioDatasetTransform):
    """Dataset-level transform that randomly pairs a sample with a second one.

    This class only *selects* indices: ``find_indices`` returns either the
    original index alone or the original index plus a randomly drawn partner
    index. Combining the selected samples is left to the caller.

    Attributes:
        rate: threshold compared with a uniform ``[0, 1)`` draw; the transform
            is skipped whenever the draw exceeds it.
        max_tokens: length budget checked against ``n_frames``; a falsy value
            (``0`` or ``None``) disables every length check.
        attempts: maximum number of random draws used to find a partner.
    """

    @classmethod
    def from_config_dict(cls, config=None):
        """Build an instance from an optional config mapping.

        Args:
            config: mapping that may hold the keys ``"rate"``, ``"max_tokens"``
                and ``"attempts"``. ``None`` is treated as an empty mapping and
                missing keys fall back to ``_DEFAULTS``.

        Returns:
            A new instance of ``cls``.
        """
        _config = {} if config is None else config
        # Instantiate through ``cls`` (not the hard-coded class name) so that a
        # subclass calling this constructor gets an instance of its own type.
        return cls(
            _config.get("rate", _DEFAULTS["rate"]),
            _config.get("max_tokens", _DEFAULTS["max_tokens"]),
            _config.get("attempts", _DEFAULTS["attempts"]),
        )

    def __init__(
        self,
        rate: float = _DEFAULTS["rate"],
        max_tokens: int = _DEFAULTS["max_tokens"],
        attempts: int = _DEFAULTS["attempts"],
    ):
        """Store the three settings; see the class docstring for their meaning."""
        self.rate, self.max_tokens, self.attempts = rate, max_tokens, attempts

    def __repr__(self) -> str:
        """Return ``ClassName(rate=..., max_tokens=..., attempts=...)``."""
        return (
            f"{self.__class__.__name__}("
            f"rate={self.rate}, "
            f"max_tokens={self.max_tokens}, "
            f"attempts={self.attempts})"
        )

    def find_indices(
        self, index: int, n_frames: List[int], n_samples: int
    ) -> List[int]:
        """Choose the sample indices to concatenate for ``index``.

        Args:
            index: index of the primary sample.
            n_frames: per-sample lengths, looked up by sample index.
            n_samples: exclusive upper bound for the randomly drawn partner
                index (presumably the dataset size -- confirm with the caller).

        Returns:
            ``[index]`` when the transform is skipped or no suitable partner
            is found within ``attempts`` draws, otherwise ``[index, index2]``.
        """
        # skip conditions: application rate, max_tokens limit exceeded
        if np.random.random() > self.rate:
            return [index]
        if self.max_tokens and n_frames[index] > self.max_tokens:
            return [index]

        # pick second sample to concatenate
        for _ in range(self.attempts):
            index2 = np.random.randint(0, n_samples)
            # A partner must differ from the primary sample and, when a budget
            # is set, keep the combined length strictly below ``max_tokens``.
            if index2 != index and (
                not self.max_tokens
                or n_frames[index] + n_frames[index2] < self.max_tokens
            ):
                return [index, index2]

        # Every attempt was rejected: fall back to the sample on its own.
        return [index]
|