Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from torch.autograd import Function | |
def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.

    Each slice of ``x`` along ``dim`` is repeated ``count`` times in place,
    e.g. ``[a, b]`` -> ``[a, a, b, b]`` (not ``[a, b, a, b]``).

    Args:
        x: input tensor.
        count: number of copies of each slice along ``dim``.
        dim: dimension to tile along; negative values count from the end.

    Returns:
        A new contiguous tensor whose size along ``dim`` is
        ``x.size(dim) * count``; all other sizes are unchanged.
    """
    # NOTE(review): ``tile`` is defined again later in this file with the same
    # signature; that later definition shadows this one at import time.
    # repeat_interleave performs the same element-wise tiling as a manual
    # permute/view/transpose/repeat chain, but also accepts tensors with a
    # zero-sized dimension, where ``view(batch, -1)`` cannot infer ``-1``.
    return torch.repeat_interleave(x, count, dim=dim).contiguous()
class Linear(torch.nn.Module):
    """``torch.nn.Linear`` wrapper whose weight is Xavier-uniform initialized.

    Args:
        in_dim: size of each input sample.
        out_dim: size of each output sample.
        bias: whether the wrapped layer has an additive bias.
        w_init_gain: nonlinearity name given to
            ``torch.nn.init.calculate_gain`` to scale the Xavier bound.
        param: optional parameter for that nonlinearity (e.g. the negative
            slope for ``'leaky_relu'``); ``None`` uses PyTorch's default.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear',
                 param=None):
        super(Linear, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        # Only the weight is re-initialized; the bias (if any) keeps the
        # default initialization done by torch.nn.Linear.
        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        return self.linear_layer(x)
class Conv1d(torch.nn.Module):
    """``torch.nn.Conv1d`` wrapper whose weight is Xavier-uniform initialized.

    When ``padding`` is omitted, ``kernel_size`` must be odd. The padding is
    then ``dilation * (kernel_size - 1) / 2``, which at ``stride=1`` keeps
    the output length equal to the input length.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel width.
        stride: convolution stride.
        padding: explicit padding, or ``None`` for the length-preserving
            value described above.
        dilation: kernel dilation.
        bias: whether the convolution has an additive bias.
        w_init_gain: nonlinearity name given to
            ``torch.nn.init.calculate_gain``.
        param: optional parameter for that nonlinearity.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super().__init__()
        if padding is None:
            # Symmetric padding on both sides only works for odd kernels.
            assert kernel_size % 2 == 1
            half_span = dilation * (kernel_size - 1) / 2
            padding = int(half_span)
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.conv = torch.nn.Conv1d(in_channels, out_channels, **conv_options)
        init_gain = torch.nn.init.calculate_gain(w_init_gain, param=param)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=init_gain)

    def forward(self, x):
        """Convolve ``x``, laid out as BxDxT (batch, channels, time)."""
        return self.conv(x)
def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.

    Each slice of ``x`` along ``dim`` is repeated ``count`` times in place,
    e.g. ``[a, b]`` -> ``[a, a, b, b]`` (not ``[a, b, a, b]``).

    Args:
        x: input tensor.
        count: number of copies of each slice along ``dim``.
        dim: dimension to tile along; negative values count from the end.

    Returns:
        A new contiguous tensor whose size along ``dim`` is
        ``x.size(dim) * count``; all other sizes are unchanged.
    """
    # NOTE(review): ``tile`` is also defined earlier in this file; being the
    # later definition, this one is what the module name ``tile`` binds to.
    # repeat_interleave performs the same element-wise tiling as a manual
    # permute/view/transpose/repeat chain, but also accepts tensors with a
    # zero-sized dimension, where ``view(batch, -1)`` cannot infer ``-1``.
    return torch.repeat_interleave(x, count, dim=dim).contiguous()