Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class FsmnLayer(nn.Module):
    """FSMN-style memory block built from depthwise 1-D convolutions.

    The input is projected to ``hidden_dim`` channels with a pointwise
    convolution, combined with a depthwise convolution over past frames
    (``conv_left``) and one over future frames (``conv_right``), and projected
    to ``out_dim`` channels followed by a ReLU.

    Two entry points are provided:

    * ``forward`` processes a whole sequence and zero-pads its edges.
    * ``infer`` processes one chunk at a time, reading the previous context
      from a flat state buffer and emitting the new context, at the cost of a
      delay of ``right_frame * right_dilation`` frames.

    Attributes such as ``p_in_raw_chache_size`` keep their original
    (misspelled) names because they are part of the public interface.
    """

    def __init__(
        self,
        input_dim: int,
        out_dim: int,
        hidden_dim: int,
        left_frame: int = 1,
        right_frame: int = 1,
        left_dilation: int = 1,
        right_dilation: int = 1,
    ):
        """Build the projections and the left/right memory convolutions.

        Args:
            input_dim: Number of input channels.
            out_dim: Number of output channels.
            hidden_dim: Number of channels of the memory block.
            left_frame: Number of past taps; ``0`` disables the left branch.
            right_frame: Number of future taps; ``0`` disables the right branch.
            left_dilation: Spacing, in frames, between past taps.
            right_dilation: Spacing, in frames, between future taps.
        """
        super().__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.hidden_dim = hidden_dim
        self.left_frame = left_frame
        self.right_frame = right_frame
        self.left_dilation = left_dilation
        self.right_dilation = right_dilation

        self.conv_in = nn.Conv1d(input_dim, hidden_dim, kernel_size=1)
        if left_frame > 0:
            # Causal padding: the last tap of conv_left lands on the current frame.
            self.pad_left = nn.ConstantPad1d([left_dilation * left_frame, 0], 0.0)
            self.conv_left = nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_size=left_frame + 1,
                dilation=left_dilation,
                bias=False,
                groups=hidden_dim,
            )
        if right_frame > 0:
            # The negative left pad crops `right_dilation` frames so that the
            # first tap of conv_right lands on frame t + right_dilation.
            self.pad_right = nn.ConstantPad1d(
                [-right_dilation, right_dilation * right_frame], 0.0
            )
            self.conv_right = nn.Conv1d(
                hidden_dim,
                hidden_dim,
                kernel_size=right_frame,
                dilation=right_dilation,
                bias=False,
                groups=hidden_dim,
            )
        self.conv_out = nn.Conv1d(hidden_dim, out_dim, kernel_size=1)

        # Streaming state sizes. Each state is stored flattened, so every
        # "*_buffer_size" is hidden_dim times the matching number of frames.
        # Context of projected frames: (1, hidden_dim, left + right context).
        self.cache_size = left_frame * left_dilation + right_frame * right_dilation
        self.buffer_size = self.hidden_dim * self.cache_size
        # Delay line for the projected input (look-ahead length).
        self.p_in_raw_chache_size = self.right_frame * self.right_dilation
        self.p_in_raw_buffer_size = self.hidden_dim * self.p_in_raw_chache_size
        # Delay line for the optional skip connection (look-ahead length).
        self.hidden_chache_size = self.right_frame * self.right_dilation
        self.hidden_buffer_size = self.hidden_dim * self.hidden_chache_size

    def forward(self, x, hidden=None):
        """Process a full sequence.

        Args:
            x: Tensor of shape ``(batch, time, input_dim)``.
            hidden: Optional tensor added to the memory output; it must be
                broadcastable to ``(batch, hidden_dim, time)``.

        Returns:
            A tuple ``(out, p_out)`` where ``out`` has shape
            ``(batch, time, out_dim)`` and ``p_out`` is the pre-projection
            memory output of shape ``(batch, hidden_dim, time)``.
        """
        # Conv1d expects channels first.
        p_in = self.conv_in(x.transpose(1, 2))

        p_out = p_in
        if self.right_frame > 0:
            p_out = p_out + self.conv_right(self.pad_right(p_in))
        if self.left_frame > 0:
            p_out = p_out + self.conv_left(self.pad_left(p_in))
        if hidden is not None:
            p_out = hidden + p_out

        out = F.relu(self.conv_out(p_out))
        return out.transpose(1, 2), p_out

    def infer(self, x, buffer, buffer_index, buffer_out, hidden=None):
        # type: (Tensor, Tensor, int, List[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor, int, List[Tensor], Tensor]
        """Process one chunk in streaming mode.

        The outputs are delayed by ``right_frame * right_dilation`` frames
        with respect to ``x``, because the look-ahead frames have to be
        received before a frame can be emitted.

        Args:
            x: Chunk of shape ``(1, input_dim, time)`` (channels first; it is
                passed to ``conv_in`` without transposition).
            buffer: State storage holding the previous states of all layers;
                expected to be a flat (1-D) tensor. It is only read here.
            buffer_index: Offset in ``buffer`` of this layer's first state.
            buffer_out: List to which the flattened updated states are
                appended, in the same order in which they are read.
            hidden: Optional skip input of shape ``(1, hidden_dim, time)``.

        Returns:
            ``(out, buffer, buffer_index, buffer_out, p_out)`` where ``out``
            has shape ``(1, out_dim, time)``, ``buffer`` is returned as
            received, ``buffer_index`` points past the states consumed by this
            layer, and ``p_out`` is the pre-projection memory output.
        """
        p_in_raw = self.conv_in(x)
        num_frames = p_in_raw.shape[2]
        lookahead = self.right_frame * self.right_dilation

        # Prepend the cached context of projected frames.
        cnn_buffer = buffer[buffer_index : buffer_index + self.buffer_size].reshape(
            [1, self.hidden_dim, self.cache_size]
        )
        p_in = torch.cat([cnn_buffer, p_in_raw], dim=2)
        # The new context is the trailing `cache_size` frames. An explicit
        # start index is used because `-0:` would select the whole tensor.
        buffer_out.append(p_in[:, :, num_frames:].reshape(-1))
        buffer_index = buffer_index + self.buffer_size

        if self.left_frame > 0:
            # Drop the look-ahead frames: the left branch ends on the
            # (delayed) current frame.
            p_left = p_in[:, :, : p_in.shape[2] - lookahead]
            p_left_out = self.conv_left(p_left)
        else:
            # Broadcastable zero on the activations' device and dtype.
            p_left_out = p_in_raw.new_zeros(1)

        if self.right_frame > 0:
            # The first future tap sits `right_dilation` frames after the
            # current frame, mirroring the crop applied by `pad_right`.
            p_right = p_in[
                :, :, self.left_frame * self.left_dilation + self.right_dilation :
            ]
            p_right_out = self.conv_right(p_right)
        else:
            p_right_out = p_in_raw.new_zeros(1)

        if self.right_frame > 0:
            # Delay the projected input by the look-ahead length.
            p_in_raw_cnn_buffer = buffer[
                buffer_index : buffer_index + self.p_in_raw_buffer_size
            ].reshape([1, self.hidden_dim, self.p_in_raw_chache_size])
            p_in_raw = torch.cat([p_in_raw_cnn_buffer, p_in_raw], dim=2)
            buffer_out.append(
                p_in_raw[:, :, -self.p_in_raw_chache_size :].reshape(-1)
            )
            buffer_index = buffer_index + self.p_in_raw_buffer_size
            p_in_raw = p_in_raw[:, :, : -self.p_in_raw_chache_size]

        p_out = p_in_raw + p_left_out + p_right_out

        if hidden is not None:
            if self.right_frame > 0:
                # Delay the skip input by the same amount.
                hidden_cnn_buffer = buffer[
                    buffer_index : buffer_index + self.hidden_buffer_size
                ].reshape([1, self.hidden_dim, self.hidden_chache_size])
                hidden = torch.cat([hidden_cnn_buffer, hidden], dim=2)
                buffer_out.append(
                    hidden[:, :, -self.hidden_chache_size :].reshape(-1)
                )
                buffer_index = buffer_index + self.hidden_buffer_size
                hidden = hidden[:, :, : -self.hidden_chache_size]
            p_out = hidden + p_out

        out = F.relu(self.conv_out(p_out))
        return out, buffer, buffer_index, buffer_out, p_out