| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from einops import rearrange
|
|
|
def set_sigma_for_DCLS(model, s):
    """Overwrite ``sigma`` on every ``DelayConv`` sub-module of ``model``.

    Modules are matched by class *name* (``'DelayConv'``), not by identity,
    and are only touched when they already expose a ``sigma`` attribute.
    One line is printed for each module that gets updated.

    model: object providing ``named_modules()`` (e.g. an ``nn.Module``).
    s: the new value assigned to ``module.sigma``.
    """
    for _, submodule in model.named_modules():
        # Guard clauses: skip anything that is not a DelayConv carrying sigma.
        if submodule.__class__.__name__ != 'DelayConv':
            continue
        if not hasattr(submodule, 'sigma'):
            continue
        submodule.sigma = s
        print('Set sigma to ', s)
|
|
|
class DropoutNd(nn.Module):
    """Dropout for (batch, dim, lengths...) tensors with an optionally tied mask."""

    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        p: probability of zeroing an element; must satisfy 0 <= p < 1.
        tie: share one mask entry per (batch, dim) pair across all the
            remaining axes (the Dropout1d/2d/3d behaviour).
        transposed: True when the input is already (batch, dim, lengths...);
            False when the feature axis is last, in which case it is moved
            to position 1 for masking and moved back afterwards.
        """
        super().__init__()
        if p < 0 or p >= 1:
            message = "dropout probability has to be in [0, 1), but got {}"
            raise ValueError(message.format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Built at construction time; forward() samples its mask with
        # torch.rand and does not read this attribute.
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        # Outside of training the input is handed back untouched.
        if not self.training:
            return X

        feature_last = not self.transposed
        if feature_last:
            X = rearrange(X, 'b ... d -> b d ...')

        if self.tie:
            # One mask entry per (batch, dim), broadcast over trailing axes.
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2)
        else:
            mask_shape = X.shape

        keep = torch.rand(*mask_shape, device=X.device) < 1. - self.p
        # Rescale the survivors by 1 / (1 - p).
        X = X * keep * (1.0 / (1 - self.p))

        if feature_last:
            X = rearrange(X, 'b d ... -> b ... d')
        return X
|
|
|
class DelayConv(nn.Module):
    """Causal 1-D convolution over time whose taps are placed by learnable delays.

    For every (output channel, input channel) pair the layer stores
    ``n_delay`` delay positions (parameter ``d``) and a length-``k`` weight
    vector (parameter ``weight``).  On each forward pass the delays are
    rendered into a "bump" over the ``k`` kernel taps, multiplied by
    ``weight`` and used as a regular ``conv1d`` kernel with left-only padding.
    """

    def __init__(
        self,
        in_c,
        k,
        dropout=0.0,
        n_delay=1,
        dilation=1,
        kernel_type='triangle_r_temp'
    ):
        """
        in_c: number of channels; used for both input and output.
        k: kernel length in taps (must be >= 3 so that initial delays can be
            drawn from 1 .. k - 2).
        dropout: when > 0, the output goes through ``nn.Dropout(dropout / 5)``;
            otherwise through ``nn.Identity``.
        n_delay: number of delays per (output, input) channel pair.
        dilation: dilation passed to ``conv1d`` (also scales the left padding).
        kernel_type: one of 'gauss', 'triangle', 'triangle_r',
            'triangle_r_temp'; checked when the kernel is built.

        Raises ValueError if ``k < 3``.
        """
        super().__init__()
        if k < 3:
            raise ValueError(
                f"k must be at least 3 so that delays can be drawn from "
                f"[1, k - 2], but got {k}"
            )

        self.C = in_c
        self.win_len = k
        self.dilation = dilation
        self.n_delay = n_delay
        self.kernel_type = kernel_type

        # Tap indices 0 .. k-1, shape (1, k).
        self.t = torch.arange(self.win_len).float().unsqueeze(0)
        # Bump width; can be overwritten from outside (see the 'sigma' attribute).
        self.sigma = self.win_len // 2

        # Caches filled by update_kernel().
        self.delay_kernel = None
        self.bump = None

        # The torch.rand values are overwritten below; the call is kept so the
        # sequence of RNG draws (and hence seeded initialisation) is unchanged.
        d = torch.rand(self.C, self.C, self.n_delay)
        for co in range(self.C):
            for ci in range(self.C):
                # Integer delays in [1, k - 2], distinct while n_delay <= k - 2.
                d[co, ci, :] = torch.randperm(self.win_len - 2)[:self.n_delay] + 1
        self.register("d", d, lr=1e-2)

        # Every channel pair starts from the same profile: 1 at the last tap,
        # halving at each earlier tap.  Build it once over k and tile it,
        # rather than writing C * C * k tensor elements from Python.
        decay = torch.ones(k)
        for i in range(k - 2, -1, -1):
            decay[i] = decay[i + 1] / 2
        self.weight = nn.Parameter(decay.repeat(self.C, self.C, 1))

        self.dropout = nn.Dropout(dropout / 5) if dropout > 0.0 else nn.Identity()

    def register(self, name, tensor, lr=None):
        """Register ``tensor`` as a trainable parameter or as a fixed buffer.

        ``lr == 0.0`` registers a buffer.  Any other value registers an
        ``nn.Parameter`` and attaches an ``_optim`` dict to it containing
        ``weight_decay = 0`` and, when ``lr`` is not None, ``lr``.
        NOTE(review): ``_optim`` is presumably read by the optimizer setup
        elsewhere in the project -- nothing in this class consumes it.
        """
        if lr == 0.0:
            self.register_buffer(name, tensor)
            return

        self.register_parameter(name, nn.Parameter(tensor))
        optim = {"weight_decay": 0}
        if lr is not None:
            optim["lr"] = lr
        setattr(getattr(self, name), "_optim", optim)

    def update_kernel(self, device):
        """Rebuild ``self.delay_kernel`` (shape [C_out, C_in, k]) from ``d`` and ``weight``.

        Also stores a detached copy of the per-delay bumps
        (shape [C_out, C_in, n_delay, k]) in ``self.bump``.

        Raises ValueError if ``self.kernel_type`` is not recognised.
        """
        t = self.t.to(device).view(1, 1, 1, -1)
        d = self.d.to(device)
        kind = self.kernel_type

        if kind == 'gauss':
            # A delay of d is centred on tap k - 1 - d.
            offset = t - self.win_len + d.unsqueeze(-1) + 1
            bump = torch.exp(-0.5 * (offset / self.sigma) ** 2)
            # Floor the tails at 1e-3 before normalising.
            bump = (bump - 1e-3).relu() + 1e-3
            bump = bump / (bump.sum(dim=-1, keepdim=True) + 1e-7)

        elif kind in ('triangle', 'triangle_r', 'triangle_r_temp'):
            if kind == 'triangle':
                d_eff = d
            elif kind == 'triangle_r':
                # Forward value is round(d); gradient flows as if it were d.
                d_eff = (d.round() - d).detach() + d
            else:
                # Rounding correction scaled by 1/sigma, capped at 1.
                scale = min(1.0, 1.0 / self.sigma)
                d_eff = (d.round() - d).detach() * scale + d

            offset = t - self.win_len + d_eff.unsqueeze(-1) + 1
            bump = torch.relu(1 - torch.abs(offset / self.sigma))
            # The normaliser is detached, so it is a constant for autograd.
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)

            if kind == 'triangle_r_temp' and not self.training:
                # In eval mode keep only the strongest tap of each bump.
                max_idx = bump.argmax(dim=-1, keepdim=True)
                hard_mask = torch.zeros_like(bump)
                hard_mask.scatter_(-1, max_idx, 1.0)
                bump = bump * hard_mask

        else:
            raise ValueError(f"Unknown kernel_type: {self.kernel_type}")

        self.bump = bump.detach().clone()

        # Sum the n_delay bumps of each channel pair into one length-k profile.
        bump_sum = bump.sum(dim=2)
        self.delay_kernel = self.weight * bump_sum

    def forward(self, x):
        """
        x: (T, B, C, N) -- axis 2 is the one checked against ``in_c``
        return: (T*B, C, N)
        """
        T, B, C, N = x.shape
        assert C == self.C, f"Input channel mismatch: {C} vs {self.C}"

        # (T, B, C, N) -> (B, N, C, T) -> (B*N, C, T): each (batch, N) position
        # becomes an independent sequence convolved along time.
        x_flat = x.permute(1, 3, 2, 0).reshape(B * N, C, T)

        self.update_kernel(x.device)
        kernel = self.delay_kernel

        # Left-only padding keeps the convolution causal and the length at T.
        pad_left = (self.win_len - 1) * self.dilation
        x_padded = F.pad(x_flat, (pad_left, 0))

        y = F.conv1d(x_padded, kernel, stride=1, dilation=self.dilation, groups=1)

        # (B*N, C, T) -> (T, B, C, N) -> (T*B, C, N)
        y = y.view(B, N, C, T).permute(3, 0, 2, 1).reshape(-1, C, N)

        return self.dropout(y)