Spaces:
Sleeping
Sleeping
| import random | |
| import torch | |
| from src.simulation.effect import Effect | |
| ################################################################################ | |
| # Random time-domain dropout | |
| ################################################################################ | |
| class Dropout(Effect): | |
| def __init__(self, compute_grad: bool = True, rate: any = None): | |
| super().__init__(compute_grad) | |
| self.min_rate, self.max_rate = self.parse_range( | |
| rate, | |
| float, | |
| f'Invalid signal dropout rate {rate}' | |
| ) | |
| # store waveform mask as buffer to allow device movement | |
| self.register_buffer("mask", torch.zeros(1, dtype=torch.float32)) | |
| self.sample_params() | |
| def forward(self, x: torch.Tensor): | |
| return self.mask.clone().to(x) * x | |
| def sample_params(self): | |
| """ | |
| Sample dropout rate uniformly and apply random dropout | |
| """ | |
| rate = random.uniform(self.min_rate, self.max_rate) | |
| idx = torch.randperm(self.signal_length | |
| )[:round(rate * self.signal_length)] | |
| self.mask = torch.ones(self.signal_length).to(self.mask) | |
| self.mask[..., idx] = 0 | |