import threading

import numpy as np
import torch
from types import SimpleNamespace

from src.spectttra.feature import FeatureExtractor
from src.spectttra.spectttra import (
    SpecTTTra,
    build_spectttra_from_cfg,
    load_frozen_spectttra,
)

# Shared variables for the model and setup, loaded only once and reused (cache)
_PREDICTOR_LOCK = threading.Lock()
_FEAT_EXT = None
_MODEL = None
_CFG = None
_DEVICE = None


def build_spectttra(cfg, device):
    """
    Wrapper that builds SpecTTTra + FeatureExtractor and loads the frozen checkpoint.
    """
    feat_ext, model = build_spectttra_from_cfg(cfg, device)
    model = load_frozen_spectttra(
        model, "models/spectttra/spectttra_frozen.pth", device
    )
    return feat_ext, model


def _init_predictor_once():
    """
    Initialize and cache the FeatureExtractor and SpecTTTra once per process.

    Ensures thread-safe, one-time initialization of the feature extractor and
    transformer model, including moving them to the appropriate device. Also
    sets the default configuration for audio, mel-spectrogram extraction, and
    model architecture.
    """
    global _FEAT_EXT, _MODEL, _CFG, _DEVICE
    if _MODEL is not None and _FEAT_EXT is not None:
        return
    with _PREDICTOR_LOCK:
        # Double-checked locking: another thread may have finished the
        # initialization while we were waiting on the lock
        if _MODEL is not None and _FEAT_EXT is not None:
            return

        # Configuration of the best-performing variant for 120 s inputs
        cfg = SimpleNamespace(
            audio=SimpleNamespace(
                sample_rate=16000, max_time=120, max_len=16000 * 120
            ),
            melspec=SimpleNamespace(
                n_fft=2048,
                hop_length=512,
                win_length=2048,
                n_mels=128,
                f_min=20,
                f_max=8000,
                power=2,
                top_db=80,
                norm="mean_std",
            ),
            model=SimpleNamespace(
                embed_dim=384,
                num_heads=6,
                num_layers=12,
                t_clip=3,
                f_clip=1,
                pre_norm=True,
                pe_learnable=True,
                pos_drop_rate=0.1,
                attn_drop_rate=0.1,
                proj_drop_rate=0.0,
                mlp_ratio=2.67,
            ),
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        feat_ext, model = build_spectttra(cfg, device)
        feat_ext.to(device)
        # Move the model to the device (GPU if available) and switch to eval
        # mode; mixed precision is applied later via autocast during inference
        model.to(device).eval()

        # Cache
        _FEAT_EXT, _MODEL, _CFG, _DEVICE = feat_ext, model, cfg, device


def spectttra_predict(audio_tensor):
    """
    Run single-input inference with SpecTTTra.

    Args:
        audio_tensor (torch.Tensor): Input waveform of shape (1, num_samples).
            Must already be preprocessed, including resampling to the target
            sampling rate (16 kHz).

    Returns:
        np.ndarray: 1D embedding vector of shape (embed_dim,). The embedding
            is obtained by mean-pooling the transformer token outputs.
    """
    _init_predictor_once()
    device = _DEVICE
    feat_ext = _FEAT_EXT
    model = _MODEL

    # Move the waveform to the device but keep float32 for mel extraction
    waveform = audio_tensor.to(device, dtype=torch.float32)

    with torch.no_grad():
        # Extract mel-spectrogram
        melspec = feat_ext(waveform)

        # Ensure the mel-spectrogram time dimension matches the model's
        # expectation: truncate if too long, zero-pad if too short
        expected_frames = model.input_temp_dim  # expected_frames is 3744
        if melspec.shape[2] > expected_frames:
            melspec = melspec[:, :, :expected_frames]
        elif melspec.shape[2] < expected_frames:
            padding = expected_frames - melspec.shape[2]
            melspec = torch.nn.functional.pad(melspec, (0, padding))

        if device.type == "cuda":
            # Mixed precision for faster GPU inference
            with torch.amp.autocast("cuda", enabled=True):
                tokens = model(melspec)
                pooled = tokens.mean(dim=1)
        else:
            tokens = model(melspec)
            pooled = tokens.mean(dim=1)

    out = pooled.squeeze(0).cpu().numpy()
    return out
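
# A minimal preprocessing sketch showing one way to produce the
# (1, num_samples) 16 kHz waveform that spectttra_predict expects. This is
# illustrative, not part of this module's API: the helper name is
# hypothetical and it assumes torchaudio is installed (it is not imported
# at module level above, so the import stays local to the sketch).
def _load_waveform_16k(path, target_sr=16000):
    import torchaudio

    waveform, sr = torchaudio.load(path)           # (channels, num_samples)
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono -> (1, n)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    return waveform
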
def spectttra_train(audio_tensors):
    """
    Run batched inference with SpecTTTra to extract embeddings for a list of
    waveforms (e.g., as features for downstream training); no gradient
    updates are performed.

    Args:
        audio_tensors (list[torch.Tensor]): List of input waveforms. Each
            element should be shaped either (num_samples,) or
            (1, num_samples). Each waveform is processed independently and
            its pooled embedding is collected.

    Returns:
        np.ndarray: 2D array of shape (batch_size, embed_dim), where each row
            corresponds to the pooled embedding for one input waveform.
    """
    _init_predictor_once()
    if not audio_tensors:
        return np.empty((0, _CFG.model.embed_dim))
    feat_ext = _FEAT_EXT
    model = _MODEL
    device = _DEVICE

    # Chunked processing: run in smaller batches to bound memory use
    chunk_size = 50
    all_embeddings = []
    for i in range(0, len(audio_tensors), chunk_size):
        chunk = audio_tensors[i : i + chunk_size]
        num_chunks = (len(audio_tensors) - 1) // chunk_size + 1
        print(
            f"[INFO] Processing chunk {i // chunk_size + 1}/{num_chunks} "
            f"({len(chunk)} samples)"
        )

        # Normalize every waveform to (1, num_samples) so concatenation along
        # dim=0 yields a (batch, num_samples) tensor; 1-D inputs would
        # otherwise be joined end-to-end into one long waveform
        chunk = [w if w.dim() == 2 else w.unsqueeze(0) for w in chunk]
        try:
            # torch.cat requires equal-length waveforms; unequal lengths
            # raise and fall through to the per-sample fallback below
            waveforms_batch = torch.cat(chunk, dim=0).to(device).float()
        except Exception as e:
            print(
                f"[INFO] Error during tensor concatenation, "
                f"falling back to loop. Error: {e}"
            )
            batch_list = [spectttra_predict(w) for w in chunk]
            all_embeddings.extend(batch_list)
            continue

        with torch.no_grad():
            melspec = feat_ext(waveforms_batch)

            # Ensure the mel-spectrogram time dimension matches the model's
            # expectation: truncate if too long, zero-pad if too short
            expected_frames = model.input_temp_dim
            if melspec.shape[2] > expected_frames:
                melspec = melspec[:, :, :expected_frames]
            elif melspec.shape[2] < expected_frames:
                padding = expected_frames - melspec.shape[2]
                melspec = torch.nn.functional.pad(melspec, (0, padding))

            if device.type == "cuda":
                # Mixed precision for faster GPU inference, using the same
                # torch.amp API as spectttra_predict (torch.cuda.amp.autocast
                # is deprecated)
                with torch.amp.autocast("cuda", enabled=True):
                    tokens = model(melspec)
                    pooled = tokens.mean(dim=1)
            else:
                tokens = model(melspec)
                pooled = tokens.mean(dim=1)

        chunk_embeddings = pooled.cpu().numpy()
        all_embeddings.append(chunk_embeddings)

        # Clear the GPU cache after each chunk to limit memory pressure
        if device.type == "cuda":
            torch.cuda.empty_cache()

    return np.vstack(all_embeddings)
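
# Minimal smoke test, a sketch rather than part of the pipeline: it assumes
# the frozen checkpoint exists at models/spectttra/spectttra_frozen.pth and
# uses random 120 s "clips" purely to exercise the two entry points.
if __name__ == "__main__":
    clip = torch.randn(1, 16000 * 120)  # (1, num_samples) at 16 kHz
    emb = spectttra_predict(clip)
    print("single embedding:", emb.shape)  # expected: (384,)

    batch = [torch.randn(1, 16000 * 120) for _ in range(3)]
    embs = spectttra_train(batch)
    print("batch embeddings:", embs.shape)  # expected: (3, 384)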