File size: 533 Bytes
affcd23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from typing import List, Callable
from torch import Tensor
import random
from hw_asr.augmentations.base import AugmentationBase


class SequentialRandomApply(AugmentationBase):
    """Run a list of augmentations in order, each gated by a random draw.

    For every call, the augmentations are visited in list order. Each one
    gets its own ``random.random()`` draw and is applied only when that
    draw is strictly less than ``p``; otherwise it is skipped and the
    current value is passed along unchanged.
    """

    def __init__(self, augmentation_list: List[Callable], p: float = 0.5):
        """Store the augmentations and the per-augmentation threshold.

        Args:
            augmentation_list: Callables to try, in the order they should
                be applied. The list is stored by reference, not copied.
            p: Threshold compared against ``random.random()`` for each
                augmentation. Defaults to 0.5.
        """
        self.p = p
        self.augmentation_list = augmentation_list

    def __call__(self, data: Tensor) -> Tensor:
        """Feed ``data`` through the randomly selected augmentations.

        Args:
            data: Input passed to the first augmentation that is selected.

        Returns:
            The output of the last applied augmentation, or ``data``
            itself when every augmentation was skipped (or the list is
            empty).
        """
        result = data
        for transform in self.augmentation_list:
            # One independent draw per augmentation; skip unless draw < p.
            skip = not (random.random() < self.p)
            if skip:
                continue
            result = transform(result)
        return result