Spaces:
Running
Running
| import math | |
| import torch | |
| import torch.nn as nn | |
class AlignBlock(nn.Module):
    """Cross-attention soft delay alignment.

    Estimates a delay distribution D(t) over [0, dmax) for each frame,
    then applies it to align the reference signal to the microphone signal.

    Fixed bugs from Xiaobin-Rong implementation:
      - Line 57 used K.shape[1] (hidden_channels) instead of x_ref.shape[1]
        (in_channels) for the weighted sum reshape.

    Args:
        in_channels: channel count fed to the two pointwise projections.
        hidden_channels: channel count of the projected Q / K features.
        dmax: number of candidate delays, i.e. the length (in frames) of the
            causal look-back window taken over the reference features.
        temperature: softmax temperature, stored as the ``temperature``
            buffer.
    """

    def __init__(self, in_channels, hidden_channels, dmax=32, temperature=1.0):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.dmax = dmax

        # Registered as a buffer so the trained value persists in state_dict.
        # Mutate via .fill_(), never re-assign.
        self.register_buffer("temperature", torch.tensor(float(temperature)))

        # Pointwise projections for Q and K.
        self.pconv_mic = nn.Conv2d(in_channels, hidden_channels, 1)
        self.pconv_ref = nn.Conv2d(in_channels, hidden_channels, 1)

        # Causal unfold: pad dmax-1 frames on top (past), zero on bottom.
        self.unfold_k = nn.Sequential(
            nn.ZeroPad2d([0, 0, dmax - 1, 0]),
            nn.Unfold((dmax, 1)),
        )

        # Smoothing conv over the similarity scores. The 4 frames of top
        # padding keep the 5-tap time kernel causal; the 1/1 side padding
        # keeps the delay axis at length dmax under the 3-tap kernel.
        self.conv = nn.Sequential(
            nn.ZeroPad2d([1, 1, 4, 0]),
            nn.Conv2d(hidden_channels, 1, (5, 3)),
        )

    def fold_temperature(self, temperature=None):
        """Bake the softmax temperature into the smoothing-conv weights.

        ``forward`` computes ``softmax(conv(V) / temperature)``; dividing the
        conv weight and bias by the temperature gives the same logits with a
        temperature of 1. After folding, the ``temperature`` buffer is reset
        to 1.0 so subsequent no-argument calls are no-ops.

        Args:
            temperature: value to fold. When ``None`` the ``temperature``
                buffer (carried by the checkpoint) is used.

        Raises:
            ValueError: if the temperature is zero or not finite. Folding
                such a value would overwrite the weights with inf/NaN, so it
                is rejected before anything is modified.
        """
        t = float(self.temperature if temperature is None else temperature)
        if t == 0.0 or not math.isfinite(t):
            raise ValueError(
                f"temperature must be finite and non-zero, got {t}"
            )
        if t == 1.0:
            return

        smoothing = self.conv[1]
        with torch.no_grad():
            smoothing.weight.div_(t)
            smoothing.bias.div_(t)
            self.temperature.fill_(1.0)

    def forward(self, x_mic, x_ref, return_delay=False):
        """Align the far-end features to the microphone features.

        Args:
            x_mic: (B, C, T, F) -- microphone encoder features.
            x_ref: (B, C, T, F) -- far-end encoder features.
            return_delay: when true, also return the delay distribution.

        Returns:
            aligned: (B, C, T, F) -- aligned far-end features.
            delay_dist: (B, T, dmax) -- delay distribution, only when
                ``return_delay`` is true. Index ``d`` on the last axis
                weights reference frame ``t - (dmax - 1 - d)``, so the last
                index corresponds to zero delay.
        """
        batch, _, frames, freqs = x_ref.shape
        dmax = self.dmax

        # Q from the microphone path, K from the reference path.
        query = self.pconv_mic(x_mic)  # (B, H, T, F)
        key = self.pconv_ref(x_ref)  # (B, H, T, F)

        # dmax delayed copies of K: (B, H*dmax, T*F) -> (B, H, dmax, T, F).
        key_windows = self.unfold_k(key).view(
            batch, self.hidden_channels, dmax, frames, freqs
        )

        # Cross-attention similarity: sum over frequency, scaled by sqrt(F).
        # Reducing first and transposing the small result avoids copying
        # the dmax-times-larger windowed tensor.
        scores = (query.unsqueeze(2) * key_windows).sum(dim=-1)  # (B, H, dmax, T)
        scores = scores.transpose(2, 3) / math.sqrt(freqs)  # (B, H, T, dmax)

        # Smooth and reduce to single-head attention.
        logits = self.conv(scores)  # (B, 1, T, dmax)

        # Softmax over the delay axis (temperature < 1 sharpens it).
        attn = torch.softmax(logits / self.temperature, dim=-1)  # (B, 1, T, dmax)

        # Causal windows over the full-channel reference (in_channels, not
        # hidden_channels). Tensor.unfold returns a view, so no copy here.
        ref_padded = nn.functional.pad(x_ref, [0, 0, dmax - 1, 0])
        ref_windows = ref_padded.unfold(2, dmax, 1)  # (B, C, T, F, dmax)

        # Attention-weighted sum over the delay axis.
        aligned = (ref_windows * attn.unsqueeze(-2)).sum(dim=-1)  # (B, C, T, F)

        if return_delay:
            return aligned, attn.squeeze(1)  # (B, T, dmax)
        return aligned