"""Helpers that turn an audio file into normalised mel-spectrogram image arrays."""
import logging

import numpy as np
import soundfile as sf
import torch

from src.audio import load_audio, get_melspec
from src.config import SR
from src.utils import get_idx, to_square

_logger = logging.getLogger(__name__)

# Upper bound on chunk-selection retries, so that a run of empty slices
# from ``get_idx`` can never spin forever.
_MAX_CHUNK_ATTEMPTS = 100


# https://www.kaggle.com/code/tarunpaparaju/birdcall-identification-spectrogram-loader
def to_imagenet(X, mean=None, std=None, norm_max=None, norm_min=None, eps=1e-6):
    """Standardise ``X`` and rescale the result to the ``[0, 1]`` range.

    ``X`` is centred by ``mean``, divided by ``std + eps``, clipped to
    ``[norm_min, norm_max]`` and finally mapped linearly onto ``[0, 1]``.
    Callers that want 8-bit pixel values multiply the result by 255.

    Args:
        X: Array to normalise.
        mean: Value subtracted from ``X``; defaults to ``X.mean()``.
        std: Divisor applied after centring; defaults to the standard
            deviation of ``X``.
        norm_max: Upper clipping bound in standardised units; defaults to
            the maximum of the standardised data.
        norm_min: Lower clipping bound in standardised units; defaults to
            the minimum of the standardised data.
        eps: Small constant that guards the division by ``std`` and serves
            as the threshold below which a range counts as degenerate.

    Returns:
        An array shaped like ``X`` with values in ``[0, 1]``, or all zeros
        when the standardised data (or the requested clipping range) is
        effectively constant.
    """
    # ``is None`` rather than ``or``: an explicit 0 is a legitimate value and
    # must not be silently replaced by the data-derived statistic.
    mean = X.mean() if mean is None else mean
    X = X - mean
    std = X.std() if std is None else std
    Xstd = X / (std + eps)

    _min, _max = Xstd.min(), Xstd.max()
    norm_max = _max if norm_max is None else norm_max
    norm_min = _min if norm_min is None else norm_min

    # Both the data range and the clipping range must be non-degenerate,
    # otherwise the rescale below would divide by (almost) zero.
    if (_max - _min) > eps and (norm_max - norm_min) > eps:
        # Clip, then normalise to [0, 1].
        V = np.clip(Xstd, norm_min, norm_max)
        V = (V - norm_min) / (norm_max - norm_min)
    else:
        # Nothing to scale: return zeros with the same dtype as the
        # normalised branch so callers see a consistent result type.
        V = np.zeros_like(Xstd)
    return V  # np.stack([V]*3, axis=-1)


def _select_chunk(y, duration, n_secs, random_chunk):
    """Return a non-empty slice of ``y`` whose bounds come from ``get_idx``.

    Raises:
        ValueError: If ``y`` is empty.
        RuntimeError: If no non-empty slice is found within
            ``_MAX_CHUNK_ATTEMPTS`` tries.
    """
    if len(y) == 0:
        raise ValueError("cannot extract a chunk from empty audio")
    for _ in range(_MAX_CHUNK_ATTEMPTS):
        start, end = get_idx(duration, n_secs, random_chunk=random_chunk)
        chunk = y[start:end]
        if len(chunk):
            return chunk
    raise RuntimeError(
        f"no non-empty audio chunk found after {_MAX_CHUNK_ATTEMPTS} attempts"
    )


def extract_melspec_as_imgarr(fp, n_secs=8, random_chunk=True, convert_to_int8=False):
    """Load audio from ``fp`` and return one chunk as a mel-spectrogram image array.

    The file is loaded at the project sample rate ``SR``, a chunk is selected
    with ``get_idx``, and its mel spectrogram is passed through ``to_square``
    and ``to_imagenet`` before being scaled by 255.

    Args:
        fp: Path of the audio file to read.
        n_secs: Chunk length forwarded to ``get_idx``
            (presumably seconds -- confirm against ``get_idx``).
        random_chunk: Forwarded to ``get_idx``.
        convert_to_int8: If true, values are truncated through ``np.uint8``
            (i.e. quantised to whole numbers) before the final float cast.

    Returns:
        A float array holding the scaled spectrogram with its first axis
        reversed.

    Raises:
        ValueError: If the loaded audio is empty.
        RuntimeError: If no non-empty chunk could be selected.
    """
    info = sf.info(fp)
    y, _ = load_audio(fp, SR)  # , offset=start, duration=n_secs
    y = _select_chunk(y, info.duration, n_secs, random_chunk)

    mel_dB = to_square(get_melspec(y, SR))
    try:
        normalised_db = to_imagenet(mel_dB)  # replaced minmax_scale
    except Exception:
        # Best effort: keep producing an image, but leave a trace instead of
        # hiding the failure (and never swallow KeyboardInterrupt/SystemExit).
        _logger.warning(
            "to_imagenet failed for %s; falling back to an all-zero image",
            fp,
            exc_info=True,
        )
        normalised_db = np.zeros_like(np.asarray(mel_dB), dtype=float)

    db_array = np.asarray(normalised_db) * 255
    if convert_to_int8:
        db_array = db_array.astype(np.uint8)
    # Reverse the first axis; the cast also yields a contiguous copy.
    return db_array[::-1].astype(float)


def generate_test_images(fp, n=10):
    """Build ``n`` spectrogram images from ``fp`` and stack them into a tensor.

    Each image comes from a separate ``extract_melspec_as_imgarr`` call with
    default arguments, so the file is re-read for every image.

    Args:
        fp: Path of the audio file to read.
        n: Number of images to generate.

    Returns:
        A ``torch`` tensor of the stacked images with a singleton dimension
        inserted at axis 1 (assumes every image has the same shape).
    """
    arrs = [extract_melspec_as_imgarr(fp) for _ in range(n)]
    return torch.as_tensor(np.array(arrs)).unsqueeze(1)