| | import numpy as np |
| | import torch as t |
| | import models.utils.dist_adapter as dist |
| | import soundfile |
| | import librosa |
| | from models.utils.dist_utils import print_once |
| |
|
class DefaultSTFTValues:
    """Fixed STFT settings paired with the sample rate taken from ``hps``.

    Attributes set on the instance:
        sr: copied from ``hps.sr``.
        n_fft: always 2048.
        hop_length: always 256.
        window_size: six hops, i.e. 1536.
    """

    def __init__(self, hps):
        hop = 256
        self.sr = hps.sr
        self.n_fft = 2048
        self.hop_length = hop
        # The analysis window spans six hops.
        self.window_size = 6 * hop
| |
|
class STFTValues:
    """Caller-supplied STFT settings paired with the sample rate from ``hps``.

    Only ``hps.sr`` is read from ``hps``; the remaining values are stored
    exactly as passed in.
    """

    def __init__(self, hps, n_fft, hop_length, window_size):
        self.sr = hps.sr
        self.n_fft, self.hop_length, self.window_size = n_fft, hop_length, window_size
| |
|
def calculate_bandwidth(dataset, hps, duration=600):
    """Estimate signal statistics from a strided sample of ``dataset``.

    Items are read starting at this process's rank and stepping by
    ``max(16, world_size)`` until at least ``dataset.sr * duration`` sample
    values have been accumulated. When ``dist.is_available()`` is true the
    running sums are all-reduced across processes before the final
    statistics are formed.

    Args:
        dataset: indexable object exposing ``sr``; each item is an array, or
            a 2-element tuple/list whose first element is the array.
            Presumably arrays are laid out (time, channels), since axis 1 is
            averaged before the STFT -- TODO confirm against the dataset.
        hps: hyper-parameters; only ``hps.sr`` is read (via
            ``DefaultSTFTValues``).
        duration: amount of audio to accumulate, multiplied by
            ``dataset.sr`` to get the sample budget.

    Returns:
        dict with keys ``'l2'`` (variance of the sample values), ``'l1'``
        (mean absolute sample value) and ``'spec'`` (mean Frobenius norm of
        the magnitude spectrogram per item). The dict is also printed with
        ``print_once``.
    """
    hps = DefaultSTFTValues(hps)
    n_samples = int(dataset.sr * duration)
    l1, total, total_sq, n_seen, idx = 0.0, 0.0, 0.0, 0.0, dist.get_rank()
    spec_norm_total, spec_nelem = 0.0, 0.0
    # The stride does not change during the loop, so query it once.
    stride = max(16, dist.get_world_size())
    while n_seen < n_samples:
        x = dataset[idx]
        if isinstance(x, (tuple, list)):
            # Second element is not needed for the statistics.
            x, _ = x
        samples = x.astype(np.float64)
        # n_fft is passed by keyword: recent librosa releases make every
        # argument after the signal keyword-only.
        stft = librosa.stft(
            np.mean(samples, axis=1),
            n_fft=hps.n_fft,
            hop_length=hps.hop_length,
            win_length=hps.window_size,
        )
        spec = np.absolute(stft)
        spec_norm_total += np.linalg.norm(spec)
        spec_nelem += 1
        n_seen += int(np.prod(samples.shape))
        l1 += np.sum(np.abs(samples))
        total += np.sum(samples)
        total_sq += np.sum(samples ** 2)
        idx += stride

    if dist.is_available():
        # Prefer the package the rest of this module imports from; fall back
        # to the original location so existing installs keep working.
        try:
            from models.utils.dist_utils import allreduce
        except ImportError:
            from jukebox.utils.dist_utils import allreduce
        n_seen = allreduce(n_seen)
        total = allreduce(total)
        total_sq = allreduce(total_sq)
        l1 = allreduce(l1)
        spec_nelem = allreduce(spec_nelem)
        spec_norm_total = allreduce(spec_norm_total)

    mean = total / n_seen
    bandwidth = dict(l2 = total_sq / n_seen - mean ** 2,
                     l1 = l1 / n_seen,
                     spec = spec_norm_total / spec_nelem)
    print_once(bandwidth)
    return bandwidth
| |
|
def audio_preprocess(x, hps):
    """Pre-processing hook applied to audio; currently a pass-through.

    Args:
        x: the audio input.
        hps: hyper-parameters; accepted for interface compatibility but not
            read.

    Returns:
        ``x`` itself, unmodified.
    """
    return x
| |
|
def audio_postprocess(x, hps):
    """Post-processing hook applied to audio; currently a pass-through.

    Args:
        x: the audio output.
        hps: hyper-parameters; accepted for interface compatibility but not
            read.

    Returns:
        ``x`` itself, unmodified.
    """
    return x
| |
|
def stft(sig, hps):
    """Short-time Fourier transform of ``sig`` with a Hann window.

    Args:
        sig: real-valued tensor accepted by ``torch.stft``; the window is
            created on ``sig.device``.
        hps: object providing ``n_fft``, ``hop_length`` and ``window_size``.

    Returns:
        Real tensor whose last axis has size 2 and holds the (real,
        imaginary) parts of each STFT coefficient.
    """
    window = t.hann_window(hps.window_size, device=sig.device)
    # Current PyTorch requires return_complex to be given for real inputs.
    # Ask for complex output and convert back so callers still receive the
    # trailing (real, imag) axis they reduce over.
    coeffs = t.stft(
        sig,
        hps.n_fft,
        hps.hop_length,
        win_length=hps.window_size,
        window=window,
        return_complex=True,
    )
    return t.view_as_real(coeffs)
| |
|
def spec(x, hps):
    """Reduce the output of :func:`stft` with an L2 norm over its last axis.

    Presumably the last axis holds the real/imaginary pair of each
    coefficient, which would make the result a magnitude spectrogram --
    confirm against the ``stft`` output layout.
    """
    transformed = stft(x, hps)
    return t.norm(transformed, p=2, dim=-1)
| |
|
def norm(x):
    """Euclidean norm of each slice along the first dimension.

    All dimensions after the first are flattened, so the result has shape
    ``(x.shape[0],)``.

    Args:
        x: tensor with at least one dimension.

    Returns:
        1-D tensor of per-slice L2 norms.
    """
    # reshape rather than view: identical for contiguous inputs, but also
    # accepts non-contiguous tensors (e.g. after permute/transpose), where
    # view raises.
    flat = x.reshape(x.shape[0], -1)
    return (flat ** 2).sum(dim=-1).sqrt()
| |
|
def squeeze(x):
    """Collapse a 3-D tensor to 2-D by averaging its last axis.

    A 2-D input is returned as-is. A 3-D input must have a last axis of
    size 1 or 2 (checked with an assertion) and is averaged over that axis.

    Raises:
        ValueError: if the input is neither 2-D nor 3-D.
    """
    rank = len(x.shape)
    if rank == 3:
        assert x.shape[-1] in [1,2]
        return t.mean(x, -1)
    if rank == 2:
        return x
    raise ValueError(f'Unknown input shape {x.shape}')
| |
|
def spectral_loss(x_in, x_out, hps):
    """L2 distance between the spectrograms of ``x_in`` and ``x_out``.

    Both signals are cast to float, passed through :func:`squeeze` and
    :func:`spec` using the default STFT settings, and the difference is
    reduced with :func:`norm` (one value per entry of the first dimension).
    """
    stft_cfg = DefaultSTFTValues(hps)
    mag_in, mag_out = [spec(squeeze(sig.float()), stft_cfg) for sig in (x_in, x_out)]
    return norm(mag_in - mag_out)
| |
|
def multispectral_loss(x_in, x_out, hps):
    """Average spectral L2 distance over several STFT configurations.

    Args:
        x_in: reference signal tensor.
        x_out: signal tensor compared against ``x_in``.
        hps: object providing ``sr`` and three equally long sequences,
            ``multispec_loss_n_fft``, ``multispec_loss_hop_length`` and
            ``multispec_loss_window_size``; entry ``i`` of each defines one
            STFT configuration.

    Returns:
        The mean over configurations of ``norm(spec(x_in) - spec(x_out))``.
    """
    assert len(hps.multispec_loss_n_fft) == len(hps.multispec_loss_hop_length) == len(hps.multispec_loss_window_size)
    # The squeezed float signals do not depend on the STFT settings, so
    # compute them once instead of once per configuration.
    sig_in = squeeze(x_in.float())
    sig_out = squeeze(x_out.float())
    configs = zip(hps.multispec_loss_n_fft,
                  hps.multispec_loss_hop_length,
                  hps.multispec_loss_window_size)
    losses = []
    for n_fft, hop_length, window_size in configs:
        # Use a separate name so the caller's hps is not overwritten: every
        # configuration is built from the original hyper-parameters rather
        # than from the previous iteration's STFTValues.
        stft_cfg = STFTValues(hps, n_fft, hop_length, window_size)
        losses.append(norm(spec(sig_in, stft_cfg) - spec(sig_out, stft_cfg)))
    return sum(losses) / len(losses)
| |
|
def spectral_convergence(x_in, x_out, hps, epsilon=2e-3):
    """Spectral residual norm relative to the norm of the reference spectrum.

    Computes ``norm(spec(x_in) - spec(x_out)) / norm(spec(x_in))`` using the
    default STFT settings. Entries whose reference norm does not exceed
    ``epsilon`` are zeroed, and the denominator is clamped to at least
    ``epsilon``.
    """
    stft_cfg = DefaultSTFTValues(hps)
    reference = spec(squeeze(x_in.float()), stft_cfg)
    estimate = spec(squeeze(x_out.float()), stft_cfg)

    reference_norm = norm(reference)
    error_norm = norm(reference - estimate)
    # 1.0 where the reference carries enough energy, 0.0 elsewhere.
    keep = (reference_norm > epsilon).float()
    return (error_norm * keep) / t.clamp(reference_norm, min=epsilon)
| |
|
def log_magnitude_loss(x_in, x_out, hps, epsilon=1e-4):
    """Mean absolute difference between log spectrograms.

    Each signal is cast to float, squeezed, transformed with :func:`spec`
    under the default STFT settings, offset by ``epsilon`` and passed
    through ``log``; the result is the mean of the absolute difference.
    """
    stft_cfg = DefaultSTFTValues(hps)

    def log_mag(sig):
        # epsilon keeps the logarithm finite where the spectrum is zero.
        return t.log(spec(squeeze(sig.float()), stft_cfg) + epsilon)

    log_in = log_mag(x_in)
    log_out = log_mag(x_out)
    return t.mean(t.abs(log_in - log_out))
| |
|
def load_audio(file, sr, offset, duration, mono=False):
    """Load a slice of an audio file with librosa.

    ``offset`` and ``duration`` are each divided by ``sr`` before being
    handed to ``librosa.load``, so callers pass them in samples.

    Returns:
        The loaded array; a 1-D result is reshaped to ``(1, -1)`` so the
        return value always has at least two dimensions.
    """
    start_sec = offset / sr
    length_sec = duration / sr
    audio, _ = librosa.load(file, sr=sr, mono=mono, offset=start_sec, duration=length_sec)
    if audio.ndim == 1:
        audio = audio.reshape((1, -1))
    return audio
| |
|
| |
|
def save_wav(fname, aud, sr):
    """Write each entry along the first axis of ``aud`` to its own WAV file.

    The tensor is clamped to ``[-1, 1]``, moved to the CPU and converted to
    NumPy; entry ``i`` is written to ``{fname}/item_{i}.wav`` at sample rate
    ``sr``. ``fname`` is therefore used as a directory path.
    """
    clipped = t.clamp(aud, -1, 1).cpu().numpy()
    n_items = clipped.shape[0]
    for i in range(n_items):
        soundfile.write(f'{fname}/item_{i}.wav', clipped[i], samplerate=sr, format='wav')
| |
|
| |
|
| |
|