# gtw.py - ZeroBAS-faithful GTW, batch-vectorised
import math

import torch
import torch.nn.functional as F
from torch import Tensor


def _lagrange_weights(d: Tensor, taps: int = 8) -> Tensor:
    """Return Lagrange interpolation weights for the nodes ``0 .. taps-1``.

    Args:
        d: Evaluation positions (callers in this file pass a ``(B,)`` tensor
            with ``0 <= d < 1``).
        taps: Number of interpolation nodes.

    Returns:
        Tensor of shape ``d.shape + (taps,)``; entry ``[..., k]`` is the
        k-th Lagrange basis polynomial evaluated at ``d``.
    """
    nodes = torch.arange(taps, device=d.device, dtype=d.dtype)  # 0 .. taps-1
    weights = torch.ones(d.shape + (taps,), dtype=d.dtype, device=d.device)
    for k in range(taps):
        # Every node except the k-th one contributes a factor to basis k.
        others = torch.cat([nodes[:k], nodes[k + 1:]])
        weights[..., k] = torch.prod(
            (d.unsqueeze(-1) - others) / (nodes[k] - others), dim=-1
        )
    return weights


def gtw_shift(x: Tensor, delay: Tensor) -> Tensor:
    """ZeroBAS-style GTW: one constant shift per clip via an 8-tap Lagrange FIR.

    Args:
        x: Signals of shape ``(B, T)``.
        delay: Per-item shift. Either a 0-dim tensor (applied to every item),
            a ``(B,)`` tensor, or a ``(B, T)`` tensor that is constant along
            its second dimension.

    Returns:
        Tensor of shape ``(B, T)``.

    Raises:
        ValueError: If a 2-D ``delay`` is not constant along dim 1.

    NOTE(review): tracing the indices, a zero delay seems to return ``x``
    advanced by ``taps - 1`` samples, and the integer roll and the fractional
    filter appear to move in opposite directions - verify against the
    reference implementation before relying on the absolute alignment.
    """
    batch = x.size(0)

    if delay.dim() == 0:
        # A scalar delay applies to every item; the grouped convolution
        # below needs exactly one kernel per batch row.
        delay = delay.expand(batch)
    if delay.dim() == 2:  # squeeze if constant
        if not torch.allclose(delay, delay[:, :1].expand_as(delay)):
            raise ValueError("delay must be constant per item")
        delay = delay[:, 0]

    taps, pad = 8, 4
    total = -delay  # sign convention of the original code
    d_int = torch.floor(total).to(torch.int64)
    d_frac = (total - d_int).float()  # 0 <= d_frac < 1

    # One FIR kernel per batch item: (B, 1, taps), cast to the signal dtype so
    # conv1d does not reject non-float32 input.
    kernel = _lagrange_weights(d_frac, taps).flip(-1).unsqueeze(1).to(x.dtype)

    # With groups=B each kernel must see its own *channel*, so the batch is
    # moved onto the channel axis: (1, B, T).  (Using (B, 1, T) only works
    # when B == 1.)
    y = F.conv1d(x.unsqueeze(0), kernel, padding=pad, groups=batch).squeeze(0)
    y = y.roll(-pad, dims=1)[..., : x.size(1)]

    # Integer part: cyclic roll, applied row by row.
    for b in range(batch):
        whole = int(d_int[b])
        if whole != 0:
            y[b] = torch.roll(y[b], -whole, 0)
    return y


def _linear_weights(d: torch.Tensor) -> torch.Tensor:
    """Return the two linear-interpolation weights ``[1 - d, d]``.

    Maps a ``(B,)`` tensor to ``(B, 2)`` (stacked on a new last dimension).
    """
    return torch.stack([1.0 - d, d], dim=-1)


def gtw_shift_linear(x: torch.Tensor,
                     delay: torch.Tensor,
                     *,
                     debug: bool = False) -> torch.Tensor:
    """Linear-interpolation fractional shift.

    * Positive delay -> advance (earlier): ``y[t]`` samples ``x`` at
      ``t + delay``.
    * Negative delay -> retard (later).
    * Non-integer delays clamp the read position to ``[0, T - 1]`` (edge
      samples are repeated).
    * When ``delay`` is (close to) an integer, the row is instead a pure
      cyclic roll of ``x``.

    Args:
        x: Signals of shape ``(B, T)``.
        delay: Per-item shift of shape ``(B,)``.
        debug: If true, print the delay, the integer mask and the rounded
            delays.

    Returns:
        Tensor of shape ``(B, T)``.
    """
    B, T = x.shape
    dtype, dev = x.dtype, x.device

    delay = delay.to(device=dev, dtype=dtype)       # same dtype AND device as x
    int_part = delay.round().to(torch.int64)        # nearest integer
    # torch.isclose also applies its default relative tolerance on top of
    # atol, so large near-integer delays are treated as integers too.
    is_integer = torch.isclose(delay, int_part.to(dtype), atol=1e-7)

    # Common path: gather-style interpolation at the clamped read positions.
    n = torch.arange(T, device=dev, dtype=dtype).unsqueeze(0)   # (1, T)
    src = n + delay.unsqueeze(1)                                # (B, T)
    src_clamped = torch.clamp(src, 0, T - 1)

    i0 = src_clamped.floor().to(torch.long)                     # (B, T)
    frac = src_clamped - i0.to(dtype)
    i1 = torch.clamp(i0 + 1, max=T - 1)

    y = (1.0 - frac) * x.gather(1, i0) + frac * x.gather(1, i1)

    # Rows whose delay is an integer are overwritten with a cyclic roll.
    for b in range(B):
        if is_integer[b]:
            shift = -int(int_part[b].item())        # advance <=> negative roll
            if shift:
                y[b] = torch.roll(x[b], shifts=shift, dims=0)

    if debug:
        print("delay :", delay)
        print("is_integer :", is_integer)
        print("int_part :", int_part)

    return y