Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py | |
| """ | |
| import math | |
| import torch | |
def overlap_and_add(signal: torch.Tensor, frame_step: int) -> torch.Tensor:
    """
    Reconstruct a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    Overlapping samples are summed. If `frame_step > frame_length`, the gaps
    between frames are filled with zeros.

    The resulting tensor has shape `[..., output_size]` where
        output_size = (frames - 1) * frame_step + frame_length

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py

    :param signal: Tensor, shape: [..., frames, frame_length]. Rank must be at
        least 2. Non-contiguous tensors are accepted.
    :param frame_step: int, offset between the starts of consecutive frames.
        Must be positive.
    :return: Tensor, shape: [..., output_size], containing the overlap-added
        frames of signal's inner-most two dimensions. It has the same dtype
        and device as `signal`.
    :raises ValueError: if `signal` has rank < 2 or `frame_step` is not positive.
    """
    if signal.dim() < 2:
        raise ValueError(
            "signal must have shape [..., frames, frame_length] (rank >= 2), "
            f"got shape {tuple(signal.shape)}"
        )
    if frame_step <= 0:
        raise ValueError(f"frame_step must be a positive integer, got {frame_step}")

    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Cut each frame into equal "subframes" whose length divides both
    # frame_length and frame_step, so every frame starts on a subframe boundary.
    subframe_length = math.gcd(frame_length, frame_step)  # gcd = Greatest Common Divisor
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
    subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)

    # Row i of the unfolded index lists the output-subframe slots covered by
    # frame i. Flattened, it maps each input subframe to its destination slot.
    frame = torch.arange(output_subframes, device=signal.device)
    frame = frame.unfold(0, subframes_per_frame, subframe_step).reshape(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    # index_add_ accumulates, so subframes landing on the same slot are summed.
    result.index_add_(-2, frame, subframe_signal)
    return result.reshape(*outer_dimensions, -1)
if __name__ == "__main__":
    # Intentionally a no-op: this module only provides helpers for import.
    ...