# richiejp's picture
# Initial upload: LocalVQE demo Space
# 6e0a6e4 verified
import math
import torch
import torch.nn as nn
class AlignBlock(nn.Module):
    """Cross-attention soft delay alignment.

    For every frame the block estimates a delay distribution D(t) over
    ``dmax`` candidate delays, then uses it as attention weights over the
    last ``dmax`` reference frames to align the reference signal to the
    microphone signal.

    Delay index convention (follows from padding ``dmax - 1`` zero frames
    before the first frame and sliding a ``dmax``-long window over time):
    index ``dmax - 1`` is the current frame (zero delay) and index ``0`` is
    the frame ``dmax - 1`` steps in the past.

    Fixed bug relative to the Xiaobin-Rong implementation:
    - Line 57 there used K.shape[1] (hidden_channels) instead of
      x_ref.shape[1] (in_channels) for the weighted-sum reshape.
    """

    def __init__(self, in_channels: int, hidden_channels: int,
                 dmax: int = 32, temperature: float = 1.0):
        """Build the projection, unfolding and smoothing layers.

        in_channels:     channel count of both input feature maps.
        hidden_channels: channel count of the Q/K projections.
        dmax:            number of candidate delays (window length in frames).
        temperature:     softmax temperature; values < 1 sharpen the
                         delay distribution.
        """
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.dmax = dmax

        # Registered as a buffer so the trained value persists in state_dict.
        # Mutate via .fill_(), never re-assign.
        self.register_buffer("temperature", torch.tensor(float(temperature)))

        # Pointwise (1x1) projections producing Q (mic) and K (reference).
        self.pconv_mic = nn.Conv2d(in_channels, hidden_channels, 1)
        self.pconv_ref = nn.Conv2d(in_channels, hidden_channels, 1)

        # Causal unfold: pad dmax-1 frames on top (past), none on the bottom,
        # then extract a dmax-long time window at every (frame, bin).
        self.unfold_k = nn.Sequential(
            nn.ZeroPad2d([0, 0, dmax - 1, 0]),
            nn.Unfold((dmax, 1)),
        )

        # Smoothing conv over the similarity scores: kernel 5 over time
        # (4 frames of past padding => causal) and 3 over the delay axis
        # (padded 1 on each side); reduces hidden_channels to one head.
        self.conv = nn.Sequential(
            nn.ZeroPad2d([1, 1, 4, 0]),  # causal
            nn.Conv2d(hidden_channels, 1, (5, 3)),
        )

    def fold_temperature(self, temperature=None):
        """Bake the softmax temperature into the smoothing-conv weights.

        Because softmax((W*V + b) / t) == softmax((W/t)*V + b/t), dividing
        the smoothing conv's weight and bias by ``t`` lets the buffer be
        reset to 1.0 without changing the module's output.

        Reads ``self.temperature`` (a buffer carried by the checkpoint) when
        called with no argument; an explicit ``temperature`` is used instead
        of the buffer. After folding the buffer is 1.0, so subsequent
        argument-less calls are no-ops. A value of exactly 1.0 returns
        immediately without touching weights or buffer.

        Raises:
            ValueError: if the temperature is not a finite positive number.
                The check runs before any in-place mutation, so a bad value
                cannot corrupt the trained weights.
        """
        t = float(self.temperature if temperature is None else temperature)
        if not (math.isfinite(t) and t > 0.0):
            raise ValueError(
                f"temperature must be a finite positive number, got {t!r}"
            )
        if t == 1.0:
            return

        smoothing = self.conv[1]
        with torch.no_grad():
            smoothing.weight.div_(t)
            if smoothing.bias is not None:
                smoothing.bias.div_(t)
            self.temperature.fill_(1.0)

    def forward(self, x_mic, x_ref, return_delay=False):
        """Align ``x_ref`` to ``x_mic`` with a soft, per-frame delay.

        x_mic: (B, C, T, F) - microphone encoder features
        x_ref: (B, C, T, F) - far-end encoder features

        Returns:
            aligned:    (B, C, T, F) - aligned far-end features
            delay_dist: (B, T, dmax) - delay distribution
                        (only if return_delay=True)
        """
        batch, _, frames, n_freq = x_ref.shape

        # Compute Q and K.
        q = self.pconv_mic(x_mic)  # (B, H, T, F)
        k = self.pconv_ref(x_ref)  # (B, H, T, F)

        # Unfold K along time: dmax delayed copies of every frame.
        k_win = self.unfold_k(k)  # (B, H*dmax, T*F)
        k_win = k_win.view(
            batch, self.hidden_channels, self.dmax, frames, n_freq
        )
        # Permuted *view* (no copy): the broadcast multiply below accepts
        # arbitrary strides, so materialising it would only cost memory.
        k_win = k_win.permute(0, 1, 3, 2, 4)  # (B, H, T, dmax, F)

        # Cross-attention similarity: sum over frequency, scaled by sqrt(F).
        scores = torch.sum(q.unsqueeze(-2) * k_win, dim=-1)  # (B, H, T, dmax)
        scores = scores / math.sqrt(n_freq)

        # Smooth and reduce to single-head attention.
        scores = self.conv(scores)  # (B, 1, T, dmax)

        # Softmax over the delay dimension (temperature < 1 sharpens it).
        attn = torch.softmax(scores / self.temperature, dim=-1)  # (B, 1, T, dmax)

        # Window the *full-channel* reference (in_channels, not
        # hidden_channels) with the same causal padding used for K.
        ref_padded = nn.functional.pad(x_ref, [0, 0, self.dmax - 1, 0])
        # ref_padded: (B, C, T+dmax-1, F); unfold returns a strided view.
        ref_win = ref_padded.unfold(2, self.dmax, 1)  # (B, C, T, F, dmax)

        # Attention-weighted sum over the delay axis; attn broadcasts over
        # channels and frequency.
        aligned = torch.sum(ref_win * attn.unsqueeze(-2), dim=-1)  # (B, C, T, F)

        if return_delay:
            return aligned, attn.squeeze(1)  # (B, T, dmax)
        return aligned