File size: 3,635 Bytes
3e21dc5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# gtw.py  – ZeroBAS‑faithful GTW, batch‑vectorised
import torch, math
from torch import Tensor
import torch.nn.functional as F

def _lagrange_weights(d: Tensor, taps: int = 8) -> Tensor:
    """Return (B, taps) weights for 0 ≤ d < 1."""
    n = torch.arange(taps, device=d.device, dtype=d.dtype)   # 0..7
    w = torch.ones(d.shape + (taps,), dtype=d.dtype, device=d.device)
    for k in range(taps):
        others = torch.cat([n[:k], n[k+1:]])
        w[..., k] = torch.prod((d.unsqueeze(-1) - others) / (n[k] - others), dim=-1)
    return w                                                # (B, taps)

def gtw_shift(x: Tensor, delay: Tensor) -> Tensor:
    """
    ZeroBAS-style GTW: constant ITD per clip, 8-tap Lagrange fractional shift.

    Computes ``y[b, t] = x_b(t + delay[b])`` on a cyclic time axis.
    ``x_b(s)`` is the 8-point Lagrange interpolant through the samples
    ``x[b, floor(s) - 3 : floor(s) + 5]``, with indices taken modulo T.
    The target therefore always lies between the two middle nodes.

    * Positive delay -> advance (earlier); negative delay -> retard (later).
    * For finite ``x``, an integer delay reduces exactly to
      ``torch.roll(x[b], -delay[b])``.

      x:     (B, T)
      delay: scalar, (B,), or any constant-valued (B, T)

    Returns a (B, T) tensor.

    Raises ValueError if ``x`` is not 2-D, if a (B, T) delay varies along
    time, or if ``delay`` does not hold 1 or B items.
    """
    if x.dim() != 2:
        raise ValueError("x must have shape (B, T)")
    B, T = x.shape

    delay = torch.as_tensor(delay, device=x.device)
    if delay.dim() == 2:                              # squeeze if constant
        if not torch.allclose(delay, delay[:, :1].expand_as(delay)):
            raise ValueError("delay must be constant per item")
        delay = delay[:, 0]
    delay = delay.reshape(-1)
    if delay.numel() == 1:
        delay = delay.expand(B)                       # one delay for all rows
    elif delay.numel() != B:
        raise ValueError(f"delay must have 1 or {B} items, got {delay.numel()}")

    taps = 8
    centre = taps // 2 - 1                            # stencil offsets -3 .. +4
    wdtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()

    shift = delay.to(wdtype)
    base = torch.floor(shift)
    frac = shift - base                               # 0 <= frac <= 1
    # Evaluating the basis (nodes 0..taps-1) at centre + frac places the target
    # between the two middle nodes. This gives the best accuracy, and frac == 0
    # yields an exact one-hot weight on node `centre`.
    weights = _lagrange_weights(frac + centre, taps)  # (B, taps)

    # first[b, t]: sample index of node 0 for output t, i.e. t + floor(shift) - centre.
    first = (torch.arange(T, device=x.device).unsqueeze(0)
             + base.to(torch.int64).unsqueeze(1) - centre)   # (B, T)
    modulus = max(T, 1)                               # avoid modulo-by-zero when T == 0

    # Cyclic gather of each tap, weighted and accumulated.
    y = sum(
        weights[:, k:k + 1] * x.gather(1, torch.remainder(first + k, modulus))
        for k in range(taps)
    )
    return y


def _linear_weights(d: torch.Tensor) -> torch.Tensor:
    # (B,) -> (B,2)
    return torch.stack([1.0 - d, d], dim=-1)

import torch

def gtw_shift_linear(x: torch.Tensor,
                     delay: torch.Tensor,
                     *, debug: bool = False) -> torch.Tensor:
    """
    Linear-interpolation fractional delay.

    • Positive delay  → advance (earlier), just like ZeroBAS / the tests
    • Negative delay  → retard (later)
    • When `delay` is an *exact integer* (|delay - round(delay)| <= 1e-7,
      with no relative tolerance), the output row is a pure cyclic roll,
      matching the reference tests.
    • Otherwise ``y[b, t]`` linearly interpolates ``x[b, t + delay[b]]``.
      Source positions are clamped to ``[0, T-1]``, so edge samples are
      repeated rather than wrapped. This deliberately differs from the
      cyclic integer case.

    Shapes
    ------
    x      : (B, T)
    delay  : (B,), or a scalar / single-element tensor applied to every row

    Returns
    -------
    (B, T) tensor. Its dtype is x's dtype for floating ``x``, and the default
    floating dtype otherwise.

    Raises
    ------
    ValueError if ``delay`` holds neither 1 nor B elements.
    """
    B, T = x.shape
    dev = x.device
    # Interpolate in floating point even for integer x: casting the delay to an
    # integer dtype would silently drop its fractional part.
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()

    # Move to x's device *and* dtype, then broadcast a scalar delay to all rows.
    delay = torch.as_tensor(delay, dtype=dtype, device=dev).reshape(-1)
    if delay.numel() == 1:
        delay = delay.expand(B)
    elif delay.numel() != B:
        raise ValueError(f"delay must have 1 or {B} elements, got {delay.numel()}")

    int_part = delay.round().to(torch.int64)          # nearest integer
    # rtol=0: a relative tolerance would classify e.g. 1000.005 as the integer
    # 1000 and silently discard the fractional shift.
    is_integer = torch.isclose(delay, int_part.to(dtype), rtol=0.0, atol=1e-7)

    # ── Common path: direct gather-style interpolation (edge-clamped) ─────
    n = torch.arange(T, device=dev, dtype=dtype).unsqueeze(0)        # (1, T)
    src = torch.clamp(n + delay.unsqueeze(1), 0, T - 1)              # (B, T)
    i0 = src.floor().to(torch.long)
    frac = src - i0.to(dtype)
    i1 = torch.clamp(i0 + 1, max=T - 1)
    y = (1.0 - frac) * x.gather(1, i0) + frac * x.gather(1, i1)

    # ── Integer rows: cyclic roll, y[b, t] = x[b, (t + k) mod T] ──────────
    # This equals torch.roll(x[b], -k) and is vectorised over rows, so there
    # are no per-row .item() host syncs. max(T, 1) avoids modulo-by-zero on
    # empty input.
    roll_idx = torch.remainder(
        torch.arange(T, device=dev).unsqueeze(0) + int_part.unsqueeze(1), max(T, 1))
    rolled = x.gather(1, roll_idx).to(y.dtype)
    y = torch.where(is_integer.unsqueeze(1), rolled, y)

    if debug:
        print("delay      :", delay)
        print("is_integer :", is_integer)
        print("int_part   :", int_part)
    return y