File size: 5,505 Bytes
c2f1451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import numpy as np
from numpy.fft import rfft
from numpy.lib.stride_tricks import as_strided
from scipy.signal import get_window

def stft_multi(
    x,
    fs: float,
    win_s: float = 0.032,
    hop_s: float = 0.010,
    nfft: int | None = None,
    window: str | tuple | np.ndarray = "hann",
    center: bool = True,
    pad_mode: str = "reflect",
    out_dtype = np.complex64,
):
    """
    Multichannel STFT (vectorized).
    Args
    ----
    x       : np.ndarray, shape (N, C)  time-domain signal
    fs      : float, sampling rate (Hz)
    win_s   : float, window length in seconds (default 32 ms)
    hop_s   : float, hop length in seconds (default 10 ms)
    nfft    : int or None. If None, uses next power of two >= frame_len
    window  : str/tuple/array for scipy.signal.get_window or a length-L array
    center  : if True, pad by L//2 on both sides (librosa-style)
    pad_mode: np.pad mode (e.g., "reflect", "constant")
    out_dtype: dtype for STFT output (complex64 recommended)

    Returns
    -------
    X   : np.ndarray, shape (T, C, F) complex STFT
    freqs: np.ndarray, shape (F,) frequency bins in Hz
    times: np.ndarray, shape (T,) frame center times in seconds
    """
    x = np.asarray(x)
    if x.ndim == 1:
        x = x[:, None]  # (N,1)
    assert x.ndim == 2, "x must be (samples, channels)"
    N, C = x.shape

    # Window & hop in samples
    frame_len = int(round(win_s * fs))
    hop = int(round(hop_s * fs))
    if frame_len <= 0 or hop <= 0:
        raise ValueError("win_s and hop_s must be > 0")

    # FFT size
    def _next_pow2(n):
        return 1 << (int(n - 1).bit_length())
    nfft = _next_pow2(frame_len) if nfft is None else int(nfft)
    if nfft < frame_len:
        raise ValueError("nfft must be >= frame_len")

    # Window vector
    if isinstance(window, np.ndarray):
        w = window.astype(float, copy=False)
    else:
        w = get_window(window, frame_len, fftbins=True).astype(float)
    if w.shape[0] != frame_len:
        raise ValueError("Provided window length != frame_len")

    # Optional centering (pad by L//2 on both sides)
    pad = frame_len // 2 if center else 0
    if pad > 0:
        x_pad = np.pad(x, ((pad, pad), (0, 0)), mode=pad_mode)
    else:
        x_pad = x

    Np = x_pad.shape[0]
    if Np < frame_len:
        # ensure at least one frame
        x_pad = np.pad(x_pad, ((0, frame_len - Np), (0, 0)), mode=pad_mode)
        Np = x_pad.shape[0]

    # Number of frames
    T = 1 + (Np - frame_len) // hop
    if T <= 0:
        raise ValueError("Signal too short for given window/hop")

    # Stride-trick framing: (T, frame_len, C) view into x_pad
    s_t, s_c = x_pad.strides  # bytes per step in time/channel
    frames = as_strided(
        x_pad,
        shape=(T, frame_len, C),
        strides=(hop * s_t, s_t, s_c),
        writeable=False,
    )
    # Reorder to (T, C, frame_len) to apply window & FFT along the last axis
    frames = np.transpose(frames, (0, 2, 1))  # (T, C, L)

    # Apply window (broadcast over T and C)
    frames = frames * w[None, None, :]

    # Batched real FFT along last axis -> (T, C, F)
    X = rfft(frames, n=nfft, axis=-1).astype(out_dtype, copy=False)

    # Frequency and time vectors
    F = X.shape[-1]
    freqs = (fs / nfft) * np.arange(F)
    # Frame centers relative to original signal
    if center:
        # centers at sample indices: t*hop  (librosa convention)
        times = (np.arange(T) * hop) / fs
    else:
        # window centered at (frame_len/2) + t*hop
        times = (np.arange(T) * hop + frame_len / 2.0) / fs

    return X, freqs, times



def _wrap_to_2pi(x: np.ndarray) -> np.ndarray:
    """Wrap angles in radians to the half-open interval [0, 2π).

    Args
    ----
    x : np.ndarray of angles in radians

    Returns
    -------
    np.ndarray of the same shape with every finite value in [0, 2π)
    """
    two_pi = 2.0 * np.pi
    wrapped = np.mod(x, two_pi)
    # np.mod takes the sign of the divisor, so a tiny negative angle becomes
    # r + 2π, which rounds to exactly 2π in floating point. That value lies
    # outside the half-open interval; it is the same angle as 0.
    return np.where(wrapped >= two_pi, 0.0, wrapped)

def compute_mag_phase(
    X: np.ndarray,
    dtype=np.float32,
):
    """
    Per-channel magnitude and absolute phase (wrapped to [0, 2π)).

    Args
    ----
    X    : np.ndarray, shape (T, C, F), complex STFT
    dtype: output dtype

    Returns
    -------
    mag   : np.ndarray, shape (T, C, F) = |X|
    phase : np.ndarray, shape (T, C, F) = angle(X) in [0, 2π)

    Raises
    ------
    ValueError
        If X is not 3-dimensional.
    """
    X = np.asarray(X)
    if X.ndim != 3:
        # Raised explicitly: an `assert` would be stripped under `python -O`.
        raise ValueError("X must be (T, C, F)")

    mag = np.abs(X).astype(dtype, copy=False)
    phase = _wrap_to_2pi(np.angle(X)).astype(dtype, copy=False)

    # Casting to a narrower float can round a phase just below 2π up to a
    # value at or above 2π (float32(2π) > 2π). Compare in float64 and fold
    # those entries back to 0 so the documented range holds.
    overshoot = phase.astype(np.float64, copy=False) >= 2.0 * np.pi
    if overshoot.any():
        phase = np.where(overshoot, 0, phase).astype(dtype, copy=False)
    return mag, phase

def compute_mag_phase_cos_sin(
    X: np.ndarray,
    dtype=np.float32,
):
    """
    Stack magnitude, cos(phase) and sin(phase) of every channel along axis 1.

    Args
    ----
    X    : np.ndarray, shape (T, C, F), complex STFT
    dtype: output dtype

    Returns
    -------
    feats : np.ndarray, shape (T, 3*C, F)
        Channel-axis layout is [mag (C), cos(phase) (C), sin(phase) (C)],
        with phase = angle(X) wrapped to [0, 2π).
    """
    magnitude, angle = compute_mag_phase(X, dtype=dtype)
    # Evaluate both trig features on the already-cast phase, cosine first.
    trig_parts = [fn(angle).astype(dtype, copy=False) for fn in (np.cos, np.sin)]
    return np.concatenate([magnitude, *trig_parts], axis=1)

def compute_real_imag_features(
    X: np.ndarray,
    dtype=np.float32,
):
    """
    Concatenate per-channel real and imaginary parts.

    Args
    ----
    X    : np.ndarray, shape (T, C, F), complex STFT
    dtype: output dtype

    Returns
    -------
    feats : np.ndarray, shape (T, 2*C, F)
        Layout = [Re (C), Im (C)]

    Raises
    ------
    ValueError
        If X is not 3-dimensional.
    """
    X = np.asarray(X)
    if X.ndim != 3:
        # Raised explicitly: an `assert` would be stripped under `python -O`.
        raise ValueError("X must be (T, C, F)")
    real = X.real.astype(dtype, copy=False)
    imag = X.imag.astype(dtype, copy=False)
    # Stack along the channel axis so all real parts precede all imaginary parts
    return np.concatenate([real, imag], axis=1)