Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from torch import nn | |
| from torch.nn import functional as F | |
class Conv1d(nn.Conv1d):
    """Extended nn.Conv1d for incremental dilated convolutions.

    Supports step-by-step (autoregressive) inference via
    ``incremental_forward``: an internal history buffer of past input
    timesteps lets each new output step be computed with a single
    ``F.linear`` call instead of re-running the convolution over the
    whole history.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.clear_buffer()
        self._linearized_weight = None
        # Invalidate the cached linearized weight whenever a backward pass
        # runs through this module, i.e. whenever the weights are about to
        # change. ``register_backward_hook`` is deprecated in PyTorch; the
        # full variant is the supported replacement and is equivalent here
        # since the hook ignores its gradient arguments.
        self.register_full_backward_hook(self._clear_linearized_weight)

    def incremental_forward(self, input):
        """Compute one output timestep from the latest input timestep.

        Args:
            input: tensor of shape (B, T, C); only the last timestep
                ``input[:, -1, :]`` is consumed.

        Returns:
            Tensor of shape (B, 1, out_channels).
        """
        # Run forward pre-hooks (e.g. weight normalization) manually, since
        # this path bypasses the regular forward().
        for hook in self._forward_pre_hooks.values():
            hook(self, input)
        # Weight flattened to (out_channels, kw * in_channels).
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]
        dilation = self.dilation[0]

        bsz = input.size(0)  # input: (B, T, C)
        if kw > 1:
            input = input.data
            if self.input_buffer is None:
                # Lazily allocate a zero-filled history buffer spanning the
                # full receptive field of the dilated kernel. Starting from
                # zeros matches left (causal) zero-padding.
                self.input_buffer = input.new(
                    bsz, kw + (kw - 1) * (dilation - 1), input.size(2)
                )
                self.input_buffer.zero_()
            else:
                # Shift the buffer left by one timestep.
                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
            # Append the newest input timestep.
            self.input_buffer[:, -1, :] = input[:, -1, :]
            input = self.input_buffer
            if dilation > 1:
                # Keep only the timesteps the dilated kernel actually taps.
                input = input[:, 0::dilation, :].contiguous()
        output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)

    def clear_buffer(self):
        """Reset the incremental input history (call between utterances)."""
        self.input_buffer = None

    def _get_linearized_weight(self):
        # Cache the conv weight reshaped to (out_channels, kw * in_channels)
        # so each incremental step reduces to a single matrix multiply.
        if self._linearized_weight is None:
            kw = self.kernel_size[0]
            # nn.Conv1d stores weight as (out_channels, in_channels, kw).
            if self.weight.size() == (self.out_channels, self.in_channels, kw):
                weight = self.weight.transpose(1, 2).contiguous()
            else:
                # fairseq.modules.conv_tbc.ConvTBC layout: (kw, in, out).
                weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
            assert weight.size() == (self.out_channels, kw, self.in_channels)
            self._linearized_weight = weight.view(self.out_channels, -1)
        return self._linearized_weight

    def _clear_linearized_weight(self, *args):
        self._linearized_weight = None