| | |
| | import torch, math |
| | from torch import Tensor |
| | import torch.nn.functional as F |
| |
|
def _lagrange_weights(d: Tensor, taps: int = 8) -> Tensor:
    """Return Lagrange interpolation weights on the integer nodes ``0 .. taps-1``.

    ``w[..., k] = prod_{j != k} (d - j) / (k - j)``, so ``sum_k w[..., k] * f(k)``
    evaluates the polynomial through the points ``(k, f(k))`` at position ``d``.
    ``d`` may have any shape and any value; accuracy is best when ``d`` sits near
    the middle of the node range.

    Returns:
        Tensor of shape ``d.shape + (taps,)`` with ``d``'s dtype and device.
    """
    n = torch.arange(taps, device=d.device, dtype=d.dtype)
    w = torch.ones(d.shape + (taps,), dtype=d.dtype, device=d.device)
    for k in range(taps):
        # Every node except k: the basis polynomial L_k vanishes on all of them.
        others = torch.cat([n[:k], n[k + 1:]])
        w[..., k] = torch.prod((d.unsqueeze(-1) - others) / (n[k] - others), dim=-1)
    return w


def gtw_shift(x: Tensor, delay: Tensor, *, taps: int = 8) -> Tensor:
    """
    ZeroBAS-style GTW: constant ITD per clip, applied as a circular shift.

    Returns ``y`` with ``y[b, n] = x[b, (n - delay[b]) mod T]``. A positive delay
    moves the signal later in time. The integer part of the shift is an exact
    circular roll. The fractional part uses a ``taps``-point Lagrange
    interpolator whose evaluation point sits between its two middle nodes.
    For integer delays the result equals ``torch.roll(x[b], int(delay[b]))``.

    NOTE(review): ``gtw_shift_linear`` uses the opposite sign convention
    (positive delay -> advance); confirm which one callers expect.

    x: (B, T) floating-point tensor
    delay: scalar, (B,), or any constant-valued (B, T). A scalar or length-1
        delay is broadcast over the batch.
    taps: interpolator length (>= 1).

    Raises:
        ValueError: if a 2-D ``delay`` is not constant per item, or ``taps < 1``.
    """
    if taps < 1:
        raise ValueError("taps must be >= 1")
    B, T = x.shape
    # Match x's dtype/device so the weights can multiply x directly.
    delay = torch.as_tensor(delay, dtype=x.dtype, device=x.device)
    if delay.dim() == 2:
        if not torch.allclose(delay, delay[:, :1].expand_as(delay)):
            raise ValueError("delay must be constant per item")
        delay = delay[:, 0]
    # A scalar / length-1 delay applies to every item of the batch.
    delay = delay.reshape(-1).expand(B)
    if T == 0:
        return x.clone()

    # Net shift y[n] = x[n + total], split as total = d_int + d_frac with
    # d_frac nominally in [0, 1).
    total = -delay
    d_int = torch.floor(total)
    d_frac = total - d_int

    # Evaluate at centre + d_frac. The point then lies between the two middle
    # nodes, and for d_frac == 0 the weights are exactly one-hot on the centre
    # tap, giving a pure integer shift.
    centre = (taps - 1) // 2
    w = _lagrange_weights(d_frac + centre, taps)  # (B, taps)

    # Integer part (per item): coarse[b, n] = x[b, (n + d_int[b]) mod T].
    n = torch.arange(T, device=x.device)
    src = (n.unsqueeze(0) + d_int.to(torch.long).unsqueeze(1)).remainder(T)
    coarse = x.gather(1, src)

    # Fractional part: y[n] = sum_k w[k] * coarse[(n - centre + k) mod T].
    return sum(w[:, k, None] * coarse.roll(centre - k, dims=1) for k in range(taps))
| |
|
| |
|
| | def _linear_weights(d: torch.Tensor) -> torch.Tensor: |
| | |
| | return torch.stack([1.0 - d, d], dim=-1) |
| |
|
| | import torch |
| |
|
def gtw_shift_linear(x: torch.Tensor,
                     delay: torch.Tensor,
                     *, debug: bool = False) -> torch.Tensor:
    """
    Linear-interpolation fractional delay: ``y[b, n] = x[b, n + delay[b]]``.

    • Positive delay → advance (earlier), just like ZeroBAS / the tests
    • Negative delay → retard (later)
    • When `delay` is an *exact integer* (within an absolute 1e-7), the output
      is a pure cyclic roll, ``torch.roll(x[b], -delay[b])``, matching the
      reference tests.
    • Otherwise source positions are clamped to ``[0, T-1]``. Samples that would
      come from outside the clip repeat the edge sample instead of wrapping.

    NOTE(review): the integer (wrap-around) and fractional (edge-clamped) paths
    disagree near the clip boundaries, so edge samples jump as `delay` crosses
    an integer. Kept as-is for parity with the reference tests.

    Shapes
    ------
    x : (B, T)
    delay : (B,) — a scalar or length-1 delay is broadcast over the batch
    """
    B, T = x.shape
    dtype, dev = x.dtype, x.device

    delay = torch.as_tensor(delay, dtype=dtype, device=dev).reshape(-1).expand(B)
    int_part = delay.round().to(torch.int64)
    # Use an absolute tolerance only. The default rtol=1e-5 scales with the delay
    # and would classify e.g. 1000.005 as an integer.
    is_integer = torch.isclose(delay, int_part.to(dtype), rtol=0.0, atol=1e-7)

    # Fractional path: edge-clamped linear interpolation at position n + delay.
    n = torch.arange(T, device=dev, dtype=dtype).unsqueeze(0)
    src = torch.clamp(n + delay.unsqueeze(1), 0, T - 1)
    i0 = src.floor().to(torch.long)
    frac = src - i0.to(dtype)
    i1 = torch.clamp(i0 + 1, max=T - 1)
    y = (1.0 - frac) * x.gather(1, i0) + frac * x.gather(1, i1)

    # Integer path: exact cyclic roll x[(n + int_part) mod T], selected per item
    # with one vectorised op (no per-item .item() host sync). max(T, 1) only
    # guards against a modulo-by-zero for empty clips.
    wrap_idx = (torch.arange(T, device=dev).unsqueeze(0)
                + int_part.unsqueeze(1)).remainder(max(T, 1))
    y = torch.where(is_integer.unsqueeze(1), x.gather(1, wrap_idx), y)

    if debug:
        print("delay :", delay)
        print("is_integer :", is_integer)
        print("int_part :", int_part)
    return y
| |
|