Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def mask_zero(mask, value):
    """Keep ``value`` where ``mask`` is true and put zeros everywhere else.

    Args:
        mask: Boolean tensor used as the selection condition.
        value: Tensor whose entries are kept where ``mask`` is true.
    Returns:
        Result of ``torch.where`` between ``value`` and a zero tensor
        shaped like ``value``.
    """
    zero_fill = torch.zeros_like(value)
    return torch.where(mask, value, zero_fill)
def clampped_one_hot(x, num_classes):
    """One-hot encode ``x``, giving an all-zero row for out-of-range indices.

    Args:
        x: Integer index tensor, e.g. (N, L).
        num_classes: Number of classes C; valid indices are ``0 .. C-1``.
    Returns:
        Tensor with one extra trailing dimension of size ``num_classes``,
        e.g. (N, L, C). Rows for indices outside ``[0, C)`` are all zeros.
    """
    in_range = (x >= 0) & (x < num_classes)
    # Clamp so F.one_hot never sees an invalid index; those rows are zeroed below.
    safe_index = x.clamp(min=0, max=num_classes - 1)
    encoded = F.one_hot(safe_index, num_classes)
    return encoded * in_range.unsqueeze(-1)
class DistanceToBins(nn.Module):
    """Expand a single distance channel into ``num_bins`` bin channels.

    Two modes are supported:

    * ``use_onehot=True``: ``num_bins`` offsets are spaced evenly over
      ``[dist_min, dist_max]`` and the channel of the nearest offset is set
      to 1.
    * ``use_onehot=False``: ``num_bins - 1`` offsets are spaced evenly over
      ``[dist_min, dist_max]``; each gives a Gaussian response to the
      distance, and one extra trailing channel flags ``dist >= dist_max``.
    """

    def __init__(self, dist_min=0.0, dist_max=20.0, num_bins=64, use_onehot=False):
        super().__init__()
        self.dist_min = dist_min
        self.dist_max = dist_max
        self.num_bins = num_bins
        self.use_onehot = use_onehot

        if use_onehot:
            offset = torch.linspace(dist_min, dist_max, self.num_bins)
        else:
            # One channel is reserved for the overflow flag.
            offset = torch.linspace(dist_min, dist_max, self.num_bins - 1)
        # Gaussian width is 0.2x the bin spacing so that it is not too blurred.
        self.coeff = -0.5 / ((offset[1] - offset[0]) * 0.2).item() ** 2
        self.register_buffer('offset', offset)

    def out_channels(self):
        """Return the number of channels produced along ``dim`` by ``forward``."""
        return self.num_bins

    def forward(self, dist, dim, normalize=True):
        """
        Args:
            dist: (N, *, 1, *) distances; the axis ``dim`` must have size 1.
            dim: Index of the size-1 axis that is expanded into bins.
            normalize: If true, divide by the sum over ``dim``.
        Returns:
            (N, *, num_bins, *)
        """
        assert dist.size()[dim] == 1, 'dist must have size 1 along `dim`'
        offset_shape = [1] * dist.dim()
        offset_shape[dim] = -1
        offset = self.offset.view(*offset_shape)

        if self.use_onehot:
            diff = torch.abs(dist - offset)  # (N, *, num_bins, *)
            bin_idx = torch.argmin(diff, dim=dim, keepdim=True)  # (N, *, 1, *)
            y = torch.zeros_like(diff).scatter_(dim=dim, index=bin_idx, value=1.0)
        else:
            overflow_symb = (dist >= self.dist_max).float()  # (N, *, 1, *)
            y = dist - offset  # (N, *, num_bins-1, *)
            y = torch.exp(self.coeff * torch.pow(y, 2))  # (N, *, num_bins-1, *)
            y = torch.cat([y, overflow_symb], dim=dim)  # (N, *, num_bins, *)

        if normalize:
            total = y.sum(dim=dim, keepdim=True)
            # If every Gaussian underflows and the overflow flag is off (e.g. a
            # distance far below dist_min), the sum is exactly 0 and a plain
            # division would give NaN. Clamping to the smallest normal float
            # leaves every non-degenerate result unchanged and yields zeros
            # for the degenerate case.
            y = y / total.clamp_min(torch.finfo(y.dtype).tiny)

        return y
class PositionalEncoding(nn.Module):
    """Sinusoidal feature expansion with frequencies ``2**0 .. 2**(num_funcs-1)``.

    Each input feature ``v`` is mapped to
    ``[v, sin(v*f_1), ..., sin(v*f_F), cos(v*f_1), ..., cos(v*f_F)]`` and the
    per-feature codes are concatenated along the last dimension.
    """

    def __init__(self, num_funcs=6):
        super().__init__()
        self.num_funcs = num_funcs
        self.register_buffer('freq_bands', 2.0 ** torch.linspace(0.0, num_funcs - 1, num_funcs))

    def get_out_dim(self, in_dim):
        """Return the size of the last output dimension for ``in_dim`` input features."""
        return in_dim * (2 * self.num_funcs + 1)

    def forward(self, x):
        """
        Args:
            x: (..., d).
        Returns:
            (..., d * (2 * num_funcs + 1))
        """
        lead_shape = x.shape[:-1]
        in_dim = x.shape[-1]
        x = x.unsqueeze(-1)  # (..., d, 1)
        phase = x * self.freq_bands  # (..., d, f)
        code = torch.cat([x, torch.sin(phase), torch.cos(phase)], dim=-1)  # (..., d, 2f+1)
        # Use an explicit last-dimension size rather than -1: with -1, reshape
        # raises on zero-element inputs (e.g. an empty batch) because the
        # inferred size is ambiguous.
        return code.reshape(*lead_shape, in_dim * code.shape[-1])
class AngularEncoding(nn.Module):
    """Sinusoidal feature expansion using integer and reciprocal frequencies.

    The frequency set is ``[1, 2, ..., F, 1/1, 1/2, ..., 1/F]`` with
    ``F = num_funcs`` (2F frequencies in total). Each input feature ``v`` is
    mapped to ``[v, sin(v*f_1), ..., sin(v*f_2F), cos(v*f_1), ..., cos(v*f_2F)]``
    and the per-feature codes are concatenated along the last dimension.
    """

    def __init__(self, num_funcs=3):
        super().__init__()
        self.num_funcs = num_funcs
        harmonics = [float(i + 1) for i in range(num_funcs)]
        reciprocals = [1. / (i + 1) for i in range(num_funcs)]
        self.register_buffer(
            'freq_bands',
            torch.tensor(harmonics + reciprocals, dtype=torch.float32),
        )

    def get_out_dim(self, in_dim):
        """Return the size of the last output dimension for ``in_dim`` input features."""
        return in_dim * (1 + 2 * 2 * self.num_funcs)

    def forward(self, x):
        """
        Args:
            x: (..., d).
        Returns:
            (..., d * (4 * num_funcs + 1))
        """
        lead_shape = x.shape[:-1]
        in_dim = x.shape[-1]
        x = x.unsqueeze(-1)  # (..., d, 1)
        phase = x * self.freq_bands  # (..., d, 2f)
        code = torch.cat([x, torch.sin(phase), torch.cos(phase)], dim=-1)  # (..., d, 4f+1)
        # Use an explicit last-dimension size rather than -1: with -1, reshape
        # raises on zero-element inputs (e.g. an empty batch) because the
        # inferred size is ambiguous.
        return code.reshape(*lead_shape, in_dim * code.shape[-1])
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension of the input.

    See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf)
    """

    def __init__(self,
                 normal_shape,
                 gamma=True,
                 beta=True,
                 epsilon=1e-10):
        """Build the layer.

        :param normal_shape: The shape of the input tensor or the last dimension of the input tensor.
                             Only the size of the last dimension is kept.
        :param gamma: Add a scale parameter if it is True.
        :param beta: Add an offset parameter if it is True.
        :param epsilon: Value added to the variance before taking the square root.
        """
        super().__init__()
        last_dim = normal_shape if isinstance(normal_shape, int) else normal_shape[-1]
        feature_shape = (last_dim,)
        self.normal_shape = torch.Size(feature_shape)
        self.epsilon = epsilon
        # Parameters start uninitialised; reset_parameters() sets their values.
        self.register_parameter(
            'gamma', nn.Parameter(torch.empty(feature_shape)) if gamma else None)
        self.register_parameter(
            'beta', nn.Parameter(torch.empty(feature_shape)) if beta else None)
        self.reset_parameters()

    def reset_parameters(self):
        """Set the scale to ones and the offset to zeros, where present."""
        if self.gamma is not None:
            nn.init.ones_(self.gamma)
        if self.beta is not None:
            nn.init.zeros_(self.beta)

    def forward(self, x):
        """Normalise ``x`` along its last dimension, then apply scale and offset if present."""
        centered = x - x.mean(dim=-1, keepdim=True)
        # Biased variance: mean of squared deviations.
        variance = centered.pow(2).mean(dim=-1, keepdim=True)
        out = centered / torch.sqrt(variance + self.epsilon)
        if self.gamma is not None:
            out.mul_(self.gamma)
        if self.beta is not None:
            out.add_(self.beta)
        return out

    def extra_repr(self):
        """Summarise the configuration for ``repr(module)``."""
        has_gamma = self.gamma is not None
        has_beta = self.beta is not None
        return f'normal_shape={self.normal_shape}, gamma={has_gamma}, beta={has_beta}, epsilon={self.epsilon}'