File size: 6,738 Bytes
fc7b4a9
 
 
 
 
 
e26dafd
 
 
 
 
fc7b4a9
 
 
 
 
 
 
 
 
 
 
75d43d2
fc7b4a9
75d43d2
e26dafd
 
 
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75d43d2
fc7b4a9
 
75d43d2
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75d43d2
fc7b4a9
 
 
 
 
 
 
 
 
 
e26dafd
fc7b4a9
 
 
75d43d2
 
 
 
 
 
 
 
 
fc7b4a9
 
c51ad28
75d43d2
 
fc7b4a9
 
 
 
75d43d2
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61f21af
 
 
 
 
 
e26dafd
61f21af
e26dafd
75d43d2
61f21af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75d43d2
61f21af
 
fc7b4a9
61f21af
75d43d2
61f21af
fc7b4a9
61f21af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import threading
import torch
import numpy as np
from types import SimpleNamespace

from src.spectttra.feature import FeatureExtractor
from src.spectttra.spectttra import (
    SpecTTTra,
    build_spectttra_from_cfg,
    load_frozen_spectttra,
)

# Shared variables for the model and setup, loaded only once and reused (cache)
_PREDICTOR_LOCK = threading.Lock()  # guards the one-time init in _init_predictor_once
_FEAT_EXT = None  # cached FeatureExtractor, set once per process
_MODEL = None  # cached SpecTTTra model (moved to _DEVICE, eval mode)
_CFG = None  # SimpleNamespace config the cached model was built with
_DEVICE = None  # torch.device selected at init ("cuda" if available, else "cpu")


def build_spectttra(cfg, device):
    """
    Build the SpecTTTra model and its FeatureExtractor, then restore the
    frozen checkpoint weights into the model.

    Args:
        cfg: Namespace with ``audio``, ``melspec`` and ``model`` settings.
        device: torch.device the model is built for.

    Returns:
        tuple: ``(feature_extractor, model)`` with checkpoint weights loaded.
    """
    checkpoint_path = "models/spectttra/spectttra_frozen.pth"
    extractor, net = build_spectttra_from_cfg(cfg, device)
    net = load_frozen_spectttra(net, checkpoint_path, device)
    return extractor, net


def _init_predictor_once():
    """
    Build and cache the FeatureExtractor and SpecTTTra model, once per process.

    Uses double-checked locking: a lock-free fast path returns immediately when
    the cache is populated, and the check is repeated under _PREDICTOR_LOCK so
    that exactly one thread performs the (expensive) build.

    Side effects: populates the module-level _FEAT_EXT, _MODEL, _CFG and
    _DEVICE globals and moves both modules to the selected device.
    """

    global _FEAT_EXT, _MODEL, _CFG, _DEVICE

    # Fast path: another call already finished initialization.
    if _FEAT_EXT is not None and _MODEL is not None:
        return

    with _PREDICTOR_LOCK:
        # Re-check under the lock in case another thread initialized first.
        if _FEAT_EXT is not None and _MODEL is not None:
            return

        # Configurations of best performing variant for 120s
        sample_rate = 16000
        max_time = 120
        cfg = SimpleNamespace(
            audio=SimpleNamespace(
                sample_rate=sample_rate,
                max_time=max_time,
                max_len=sample_rate * max_time,
            ),
            melspec=SimpleNamespace(
                n_fft=2048,
                hop_length=512,
                win_length=2048,
                n_mels=128,
                f_min=20,
                f_max=8000,
                power=2,
                top_db=80,
                norm="mean_std",
            ),
            model=SimpleNamespace(
                embed_dim=384,
                num_heads=6,
                num_layers=12,
                t_clip=3,
                f_clip=1,
                pre_norm=True,
                pe_learnable=True,
                pos_drop_rate=0.1,
                attn_drop_rate=0.1,
                proj_drop_rate=0.0,
                mlp_ratio=2.67,
            ),
        )

        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")

        feat_ext, model = build_spectttra(cfg, device)
        feat_ext.to(device)
        # Inference only: move to device and switch off training-mode layers.
        model.to(device).eval()

        # Publish to the module-level cache last, after setup is complete.
        _FEAT_EXT, _MODEL, _CFG, _DEVICE = feat_ext, model, cfg, device


def spectttra_predict(audio_tensor):
    """
    Run single-input inference with SpecTTTra.

    Args:
        audio_tensor (torch.Tensor): Input waveform of shape (1, num_samples).
            Must already be preprocessed including resampled to the target sampling rate (16 kHz).

    Returns:
        np.ndarray:
            1D embedding vector of shape (embed_dim,). The embedding is obtained
            by mean-pooling the transformer token outputs.
    """

    global _FEAT_EXT, _MODEL, _CFG, _DEVICE

    _init_predictor_once()

    feat_ext, model, device = _FEAT_EXT, _MODEL, _DEVICE

    # Move waveform to device but keep float for mel extraction
    wav = audio_tensor.to(device, dtype=torch.float32)

    with torch.no_grad():
        # Extract mel-spectrogram
        mel = feat_ext(wav)

        # Truncate or right-pad the time axis to the model's expected frame
        # count (input_temp_dim; 3744 for this configuration).
        target_frames = model.input_temp_dim
        n_frames = mel.shape[2]
        if n_frames > target_frames:
            mel = mel[:, :, :target_frames]
        elif n_frames < target_frames:
            mel = torch.nn.functional.pad(mel, (0, target_frames - n_frames))

        # Mixed precision only makes sense on CUDA; plain fp32 on CPU.
        if device.type == "cuda":
            with torch.amp.autocast("cuda", enabled=True):
                pooled = model(mel).mean(dim=1)
        else:
            pooled = model(mel).mean(dim=1)

    # Drop the batch dimension and hand back a NumPy vector.
    return pooled.squeeze(0).cpu().numpy()


def spectttra_train(audio_tensors):
    """
    Run batch input training with SpecTTTra.

    Args:
        audio_tensors (list[torch.Tensor]):
            List of input waveforms. Each element should be shaped either
            (num_samples,) or (1, num_samples). Each waveform is processed
            independently and its pooled embedding is collected.

    Returns:
        np.ndarray:
            2D array of shape (batch_size, embed_dim), where each row
            corresponds to the pooled embedding for one input waveform.
    """

    global _FEAT_EXT, _MODEL, _CFG, _DEVICE

    _init_predictor_once()

    # Empty input -> empty (0, embed_dim) result, no device work.
    if not audio_tensors:
        return np.empty((0, _CFG.model.embed_dim))

    feat_ext = _FEAT_EXT
    model = _MODEL
    device = _DEVICE

    # Chunk processing: Process in smaller batches
    chunk_size = 50
    all_embeddings = []

    for i in range(0, len(audio_tensors), chunk_size):
        chunk = audio_tensors[i : i + chunk_size]
        print(
            f"[INFO] Processing chunk {i//chunk_size + 1}/{(len(audio_tensors)-1)//chunk_size + 1} ({len(chunk)} samples)"
        )

        try:
            # BUG FIX: torch.cat on 1-D tensors (num_samples,) silently
            # concatenates samples into ONE long waveform instead of stacking
            # a batch (and raises nothing, so the fallback never ran).
            # Normalize every waveform to (1, num_samples) first so cat
            # produces a proper (batch, num_samples) tensor.
            normalized = [w if w.dim() == 2 else w.unsqueeze(0) for w in chunk]
            waveforms_batch = torch.cat(normalized, dim=0).to(device).float()
        except Exception as e:
            # Best-effort fallback: mismatched lengths (or other cat failures)
            # degrade to per-waveform inference via spectttra_predict.
            print(
                f"[INFO] Error during tensor concatenation, falling back to loop. Error: {e}"
            )
            batch_list = [spectttra_predict(w) for w in chunk]
            all_embeddings.extend(batch_list)
            continue

        with torch.no_grad():
            melspec = feat_ext(waveforms_batch)

            # Ensure melspec shape matches model's expectation
            expected_frames = model.input_temp_dim
            if melspec.shape[2] > expected_frames:
                melspec = melspec[:, :, :expected_frames]
            elif melspec.shape[2] < expected_frames:
                padding = expected_frames - melspec.shape[2]
                melspec = torch.nn.functional.pad(melspec, (0, padding))

            # Use the modern torch.amp API (torch.cuda.amp.autocast is
            # deprecated) — consistent with spectttra_predict.
            if device.type == "cuda":
                with torch.amp.autocast("cuda", enabled=True):
                    tokens = model(melspec)
                    pooled = tokens.mean(dim=1)
            else:
                tokens = model(melspec)
                pooled = tokens.mean(dim=1)

        chunk_embeddings = pooled.cpu().numpy()
        all_embeddings.append(chunk_embeddings)

        # Clear GPU cache after each chunk
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # Mixes (embed_dim,) rows from the fallback with (chunk, embed_dim)
    # arrays from batches; vstack handles both.
    return np.vstack(all_embeddings)