import math
import os
import sys
from pathlib import Path
from typing import Optional
from typing import Union

import librosa
import numpy as np
import rich
import soundfile as sf
import torch
from audiotools import AudioSignal
from audiotools.core.util import random_state
from flatten_dict import flatten
from flatten_dict import unflatten

################################################################################
# General utilities
################################################################################


def count_parameters(m: torch.nn.Module, trainable: bool = False) -> int:
    """Count the parameters of a module.

    Parameters
    ----------
    m : torch.nn.Module
        Module whose parameters are counted.
    trainable : bool
        If True, count only parameters with ``requires_grad=True``.

    Returns
    -------
    int
        Total number of (optionally trainable) parameters.
    """
    params = m.parameters()
    if trainable:
        params = (p for p in params if p.requires_grad)
    return sum(p.shape.numel() for p in params)


def exists(val) -> bool:
    """Return True if `val` is not None."""
    return val is not None


def print(*args, **kwargs):
    """Rank-zero print.

    NOTE: intentionally shadows the builtin ``print``. Only the process whose
    ``LOCAL_RANK`` env var is 0 (or unset) prints, via ``rich.print``, to
    stderr — this keeps multi-process (e.g. DDP) logs readable.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if not local_rank:
        rich.print(*args, **kwargs, file=sys.stderr)


def ensure_dir(directory: Union[str, Path]):
    """Create `directory` (and any missing parents) if it does not exist."""
    directory = str(directory)
    # `os.makedirs("")` raises, so skip empty paths (e.g. bare filenames);
    # `exist_ok=True` already makes the call race-free when the dir exists.
    if directory:
        os.makedirs(directory, exist_ok=True)


def ensure_dir_for_filename(filename: str):
    """Ensure the parent directory of `filename` exists."""
    ensure_dir(os.path.dirname(filename))


def collate(list_of_dicts: list, n_splits: int = None):
    """
    Collates a list of dictionaries (e.g. as returned by a dataloader) into a
    dictionary with batched values. This function takes `n_splits` to enable
    splitting a batch into multiple sub-batches for the purposes of gradient
    accumulation, etc. Adapted from `audiotools.core.util.collate`.

    Parameters
    ----------
    list_of_dicts : list
        List of dictionaries to be collated.
    n_splits : int
        Number of splits to make when creating the batches (split into sub-
        batches). Useful for things like gradient accumulation.

    Returns
    -------
    dict
        Dictionary containing batched data (or a list of such dictionaries
        when `n_splits` is given).
    """
    batches = []
    list_len = len(list_of_dicts)

    # When `n_splits` is omitted we build a single batch and return the dict
    # itself rather than a one-element list.
    return_list = n_splits is not None
    n_splits = 1 if n_splits is None else n_splits
    n_items = int(math.ceil(list_len / n_splits))

    for i in range(0, list_len, n_items):
        # Flatten nested dicts so heterogeneous leaves can be collated per-key.
        list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
        dict_of_lists = {
            k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
        }

        batch = {}
        for k, v in dict_of_lists.items():
            if not isinstance(v, list):
                continue

            # AudioSignal → pad to a common length & batch
            if all(isinstance(s, AudioSignal) for s in v):
                batch[k] = AudioSignal.batch(v, pad_signals=True)
            # Strings / Paths → keep as list
            elif all(isinstance(s, (str, Path)) for s in v):
                batch[k] = v
            # All None → keep as list
            elif all(s is None for s in v):
                batch[k] = v
            else:
                # Fallback to torch default collate (tensors, numbers,
                # mappings, etc.). Use the public API rather than the
                # private `torch.utils.data._utils.collate` path.
                try:
                    batch[k] = torch.utils.data.default_collate(v)
                except TypeError:
                    # Last-resort: keep raw list
                    batch[k] = v

        batches.append(unflatten(batch))

    return batches[0] if not return_list else batches


def get_info(path: Union[str, Path]):
    """Return ``(duration_seconds, sample_rate)`` for an audio file.

    Reads only the header via `soundfile.info` — no audio is decoded.
    """
    info = sf.info(str(path))
    return float(info.duration), int(info.samplerate)


def load_audio(
    path: Union[str, Path],
    offset: float,
    duration: float,
    file_sample_rate: Optional[int] = None,
):
    """
    Load a window of audio from `path` as a batched `AudioSignal`.

    SoundFile windowed loading seems to outperform `librosa.load` (used
    throughout `AudioSignal`) in limiting memory consumption; this helps avert
    crashes when training with large `num_workers`.

    Parameters
    ----------
    path : Union[str, Path]
        Audio file to read.
    offset : float
        Start of the window, in seconds.
    duration : float
        Length of the window, in seconds.
    file_sample_rate : Optional[int]
        Known sample rate of the file. If given, the header probe via
        `get_info` is skipped.

    Returns
    -------
    AudioSignal
        Signal with data of shape (1, n_channels, n_samples).
    """
    if file_sample_rate is None:
        _duration, sample_rate = get_info(path)
    else:
        # BUGFIX: previously `sample_rate` was left undefined when the caller
        # supplied `file_sample_rate`, raising NameError below.
        sample_rate = file_sample_rate

    start = int(offset * sample_rate)
    n_samples = int(duration * sample_rate)

    with sf.SoundFile(str(path), "r") as f:
        f.seek(start)
        x = f.read(
            n_samples, dtype="float32", always_2d=True
        ).T  # (n_channels, n_samples)

    x = torch.from_numpy(x)[None, :, :]  # (n_batch==1, n_channels, n_samples)
    return AudioSignal(x, sample_rate=sample_rate)


def rms_salience(
    path: str,
    duration: float,
    cutoff_db: float = -40.0,
    num_tries: int = 3,
    state: Optional[int] = None,
    file_duration: Optional[float] = None,
    file_sample_rate: Optional[int] = None,
) -> float:
    """
    Sample random window offsets in an audio file and return the first whose
    RMS level (dBFS, mono downmix) meets `cutoff_db`; otherwise return the
    loudest offset among those tried.

    Parameters
    ----------
    path : str
        Audio file to probe.
    duration : float
        Window length in seconds.
    cutoff_db : float
        RMS level (dB) at which a window is immediately accepted.
    num_tries : int
        Number of random offsets to try (clamped to at least 1).
    state : Optional[int]
        Seed / random state forwarded to `random_state`.
    file_duration, file_sample_rate : Optional
        Known file metadata; if either is missing the header is probed via
        `get_info`.

    Returns
    -------
    float
        Offset in seconds. 0.0 when the file is unreadable/too short; a random
        offset when every windowed read failed.
    """
    if file_duration is None or file_sample_rate is None:
        _duration, sample_rate = get_info(path)
    else:
        _duration, sample_rate = file_duration, file_sample_rate

    # Degenerate files, or files shorter than the window, get offset 0.
    if not np.isfinite(_duration) or _duration <= 0 or _duration <= duration:
        return 0.0

    state = random_state(state)
    max_offset = _duration - duration
    n_samples = int(duration * sample_rate)
    tries = max(1, int(num_tries))

    best_db = -np.inf
    best_offset = None

    with sf.SoundFile(str(path), "r") as f:
        for _ in range(tries):
            offset = float(state.rand() * max_offset)
            start = int(offset * sample_rate)
            try:
                f.seek(start)
                y = f.read(
                    n_samples, dtype="float32", always_2d=True
                )  # (n_samples, n_channels)
                y = y.mean(axis=1, dtype=np.float32)  # downmix to mono
                # Epsilon guards keep log10 finite on silent windows.
                rms = float(np.sqrt(np.mean(y * y) + 1e-12))
                db = 20.0 * np.log10(max(rms, 1e-12))
            except Exception:
                # Best-effort: skip unreadable windows rather than crash a
                # dataloader worker.
                continue

            if db >= cutoff_db:
                return offset
            if db > best_db:
                best_db, best_offset = db, offset

    # No window met the cutoff: return the loudest one tried, or a random
    # offset if every read failed.
    return float(best_offset if best_offset is not None else state.rand() * max_offset)