# roombox / gtw.py
# Uploaded by ak36 via huggingface_hub ("Upload folder using huggingface_hub"), commit 3e21dc5.
# gtw.py – ZeroBAS‑faithful GTW, batch‑vectorised
import torch, math
from torch import Tensor
import torch.nn.functional as F
def _lagrange_weights(d: Tensor, taps: int = 8) -> Tensor:
"""Return (B, taps) weights for 0 ≤ d < 1."""
n = torch.arange(taps, device=d.device, dtype=d.dtype) # 0..7
w = torch.ones(d.shape + (taps,), dtype=d.dtype, device=d.device)
for k in range(taps):
others = torch.cat([n[:k], n[k+1:]])
w[..., k] = torch.prod((d.unsqueeze(-1) - others) / (n[k] - others), dim=-1)
return w # (B, taps)
def gtw_shift(x: Tensor, delay: Tensor) -> Tensor:
    """ZeroBAS-style GTW: shift every row of ``x`` by a constant (per-clip) delay.

    Each output sample reads the input at ``t + delay`` using centred 8-tap
    Lagrange interpolation, i.e. ``y[b, t] ≈ x[b, t + delay[b]]``:

    * positive delay → phase advance (the signal moves earlier),
    * negative delay → phase retard (later),
    * an exact-integer delay reproduces ``torch.roll(x, -delay)`` exactly
      (for finite ``x``).

    Sample indices wrap around, so the result is a cyclic fractional shift of
    each row (the same boundary rule a pure integer roll uses).

    Args:
        x: signal batch, shape ``(B, T)``.
        delay: per-item delay in samples: a scalar, shape ``(B,)``, or a
            ``(B, T)`` tensor that is constant along its last axis.

    Returns:
        Tensor of shape ``(B, T)``. Floating ``x`` keeps its dtype; integer
        ``x`` is promoted to the default float dtype.

    Raises:
        ValueError: if a 2-D ``delay`` varies along time, or ``delay`` does
            not provide one value per batch item.
    """
    taps = 8
    centre = taps // 2 - 1  # node index just left of the evaluation point (3 for 8 taps)

    B, T = x.shape
    if not x.is_floating_point():
        x = x.to(torch.get_default_dtype())
    # Work in x's dtype/device so the weights and the gather never mismatch.
    delay = delay.to(device=x.device, dtype=x.dtype)
    if delay.dim() == 2:  # (B, T) form: must be constant per item
        if not torch.allclose(delay, delay[:, :1].expand_as(delay)):
            raise ValueError("delay must be constant per item")
        delay = delay[:, 0]
    delay = delay.reshape(-1)
    if delay.numel() == 1:
        delay = delay.expand(B)  # one delay shared by every row
    elif delay.numel() != B:
        raise ValueError(
            f"delay must hold one value per batch item (B={B}), "
            f"got {delay.numel()} values"
        )
    if T == 0:
        return x.clone()

    # delay = base + frac with integer base and 0 <= frac < 1.
    base = torch.floor(delay)
    frac = delay - base
    # Nodes 0..taps-1 correspond to sample offsets -centre..taps-1-centre around
    # `base`, so evaluating at centre + frac keeps the stencil centred on the
    # target position t + delay.
    weights = _lagrange_weights(frac + centre, taps)                  # (B, taps)
    offsets = torch.arange(taps, device=x.device) - centre           # (taps,)
    t = torch.arange(T, device=x.device)                             # (T,)
    # `%` on tensors is torch.remainder (result has the divisor's sign), so
    # negative positions wrap to the end of the row.
    idx = (
        t[None, :, None]
        + base.to(torch.long)[:, None, None]
        + offsets[None, None, :]
    ) % T                                                            # (B, T, taps)
    neighbours = torch.gather(x, 1, idx.reshape(B, T * taps)).reshape(B, T, taps)
    return (neighbours * weights[:, None, :]).sum(dim=-1)
def _linear_weights(d: torch.Tensor) -> torch.Tensor:
# (B,) -> (B,2)
return torch.stack([1.0 - d, d], dim=-1)
import torch
def gtw_shift_linear(x: torch.Tensor,
                     delay: torch.Tensor,
                     *, debug: bool = False) -> torch.Tensor:
    """
    Linear-interpolation fractional delay.

    • Positive delay → advance (earlier), just like ZeroBAS / the tests:
      ``y[b, t] ≈ x[b, t + delay[b]]``.
    • Negative delay → retard (later).
    • When `delay` is an *exact integer* (absolute tolerance 1e-7), the output
      is a pure cyclic roll ``torch.roll(x[b], -delay[b])``, matching the
      reference tests.
    • Otherwise samples are linearly interpolated; source positions outside
      ``[0, T-1]`` are clamped to the edge samples (no wrap-around).
      NOTE: this boundary rule differs from the integer (cyclic) case.

    Shapes
    ------
    x     : (B, T)
    delay : (B,), or a scalar / single value broadcast to every row

    Returns a (B, T) tensor. When ``debug`` is true the delay, the integer
    mask and the rounded delays are printed.

    Raises
    ------
    ValueError : if ``delay`` does not provide one value per batch item.
    """
    B, T = x.shape
    dtype, dev = x.dtype, x.device
    delay = delay.to(device=dev, dtype=dtype)  # ensure same dtype *and* device
    if delay.numel() == 1:
        delay = delay.reshape(1).expand(B)  # one delay shared by every row
    if delay.shape != (B,):
        raise ValueError(
            f"delay must have shape ({B},), got {tuple(delay.shape)}"
        )

    int_part = delay.round().to(torch.int64)  # nearest integer
    # Absolute tolerance only: isclose's default rtol=1e-5 would classify e.g.
    # 100.0005 as an integer and silently drop its fractional part.
    is_integer = torch.isclose(delay, int_part.to(dtype), rtol=0.0, atol=1e-7)

    if T == 0:
        y = x.clone()
    else:
        # ── Common path: direct gather-style interpolation ───────────────
        n = torch.arange(T, device=dev, dtype=dtype).unsqueeze(0)    # (1,T)
        src = torch.clamp(n + delay.unsqueeze(1), 0, T - 1)           # (B,T)
        i0 = src.floor().to(torch.long)
        frac = src - i0.to(dtype)
        i1 = torch.clamp(i0 + 1, max=T - 1)
        y = (1.0 - frac) * x.gather(1, i0) + frac * x.gather(1, i1)

        # ── Rows whose delay is an exact integer: cyclic roll ────────────
        # roll(x, -s)[t] == x[(t + s) mod T]; `%` on tensors follows
        # torch.remainder (Python sign rule), so negative shifts wrap too.
        t_idx = torch.arange(T, device=dev).unsqueeze(0)              # (1,T)
        rolled = x.gather(1, (t_idx + int_part.unsqueeze(1)) % T)     # (B,T)
        y = torch.where(is_integer.unsqueeze(1), rolled, y)

    if debug:
        print("delay :", delay)
        print("is_integer :", is_integer)
        print("int_part :", int_part)
    return y