import contextlib
import importlib
import json
import os
import random
import subprocess
import tempfile
import time
import wave
from inspect import isfunction

import numpy as np
import progressbar
import soundfile as sf
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from librosa.filters import mel as librosa_mel_fn

from audiosr.lowpass import lowpass

# Per-device caches used by mel_spectrogram_train, keyed by
# "<fmax>_<device>" (mel_basis) and "<device>" (hann_window).
hann_window = {}
mel_basis = {}

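# Log-domain dynamic range compression and its inverse; clip_val guards
# against log(0) on silent frames.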
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    return dynamic_range_decompression_torch(magnitudes)

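# Estimate the effective bandwidth of a recording: walking down from the
# highest frequency bin, return the first bin whose cumulative spectral
# energy falls below the given percentile of the total energy.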
def _locate_cutoff_freq(stft, percentile=0.97):
    def _find_cutoff(x, percentile=0.95):
        percentile = x[-1] * percentile
        for i in range(1, x.shape[0]):
            if x[-i] < percentile:
                return x.shape[0] - i
        return 0

    magnitude = torch.abs(stft)
    energy = torch.cumsum(torch.sum(magnitude, dim=0), dim=0)
    return _find_cutoff(energy, percentile)

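# Right-pad a (1, T) waveform with zeros up to target_length samples.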
def pad_wav(waveform, target_length):
    waveform_length = waveform.shape[-1]
    assert waveform_length > 100, "Waveform is too short, %s" % waveform_length

    if waveform_length == target_length:
        return waveform

    temp_wav = np.zeros((1, target_length), dtype=np.float32)
    temp_wav[:, :waveform_length] = waveform
    return temp_wav

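# Prepare the conditioning input for inference: detect the input's cutoff
# frequency from its STFT, then low-pass filter the waveform at that cutoff
# with a randomly chosen filter family.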
def lowpass_filtering_prepare_inference(dl_output):
    waveform = dl_output["waveform"]
    sampling_rate = dl_output["sampling_rate"]

    cutoff_freq = (
        _locate_cutoff_freq(dl_output["stft"], percentile=0.985) / 1024
    ) * 24000

    # If the detected cutoff is implausibly low, treat the input as full-band.
    if cutoff_freq < 1000:
        cutoff_freq = 24000

    order = 8
    ftype = np.random.choice(["butter", "cheby1", "ellip", "bessel"])
    filtered_audio = lowpass(
        waveform.numpy().squeeze(),
        highcut=cutoff_freq,
        fs=sampling_rate,
        order=order,
        _type=ftype,
    )

    filtered_audio = torch.FloatTensor(filtered_audio.copy()).unsqueeze(0)

    # Match the filtered signal's length to the original waveform.
    if waveform.size(-1) <= filtered_audio.size(-1):
        filtered_audio = filtered_audio[..., : waveform.size(-1)]
    else:
        filtered_audio = torch.nn.functional.pad(
            filtered_audio, (0, waveform.size(-1) - filtered_audio.size(-1))
        )

    return {"waveform_lowpass": filtered_audio}

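# Compute a log-mel spectrogram and magnitude STFT with the fixed 48 kHz
# front-end settings below: n_fft = win_length = 2048, hop = 480 samples
# (10 ms), 256 mel bins spanning 20 Hz to 24 kHz.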
def mel_spectrogram_train(y):
    global mel_basis, hann_window

    sampling_rate = 48000
    filter_length = 2048
    hop_length = 480
    win_length = 2048
    n_mel = 256
    mel_fmin = 20
    mel_fmax = 24000

    # Build and cache the mel filterbank and Hann window per device.
    mel_key = "%s_%s" % (mel_fmax, y.device)
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_length).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1),
        (int((filter_length - hop_length) / 2), int((filter_length - hop_length) / 2)),
        mode="reflect",
    )

    y = y.squeeze(1)

    stft_spec = torch.stft(
        y,
        filter_length,
        hop_length=hop_length,
        win_length=win_length,
        window=hann_window[str(y.device)],
        center=False,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )

    stft_spec = torch.abs(stft_spec)

    mel = spectral_normalize_torch(torch.matmul(mel_basis[mel_key], stft_spec))

    return mel[0], stft_spec[0]

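# Zero-pad or crop a (time, frequency) spectrogram to exactly target_frame
# frames, then trim the last frequency bin if needed so the frequency
# dimension is even.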
def pad_spec(log_mel_spec, target_frame):
    n_frames = log_mel_spec.shape[0]
    p = target_frame - n_frames

    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        log_mel_spec = m(log_mel_spec)
    elif p < 0:
        log_mel_spec = log_mel_spec[0:target_frame, :]

    if log_mel_spec.size(-1) % 2 != 0:
        log_mel_spec = log_mel_spec[..., :-1]

    return log_mel_spec

def wav_feature_extraction(waveform, target_frame):
    waveform = waveform[0, ...]
    waveform = torch.FloatTensor(waveform)

    log_mel_spec, stft = mel_spectrogram_train(waveform.unsqueeze(0))

    # Transpose to (time, frequency) before padding to target_frame frames.
    log_mel_spec = torch.FloatTensor(log_mel_spec.T)
    stft = torch.FloatTensor(stft.T)

    log_mel_spec = pad_spec(log_mel_spec, target_frame)
    stft = pad_spec(stft, target_frame)
    return log_mel_spec, stft

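# Center the waveform (remove DC offset) and peak-normalize to +/-0.5.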
def normalize_wav(waveform):
    waveform = waveform - np.mean(waveform)
    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
    return waveform * 0.5

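# Load audio, resample to 48 kHz, normalize, and zero-pad to the next
# multiple of 5.12 s (the model's chunk length); also return the matching
# number of 10 ms feature frames and the padded duration in seconds.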
def read_wav_file(filename):
    waveform, sr = torchaudio.load(filename)
    duration = waveform.size(-1) / sr

    if duration > 10.24:
        print(
            "\033[93m {}\033[00m".format(
                "Warning: audio longer than 10.24 seconds may degrade model "
                "performance. It is recommended to truncate your audio to "
                "5.12-second segments before passing it to AudioSR for the "
                "best results."
            )
        )

    if duration % 5.12 != 0:
        pad_duration = duration + (5.12 - duration % 5.12)
    else:
        pad_duration = duration

    target_frame = int(pad_duration * 100)

    waveform = torchaudio.functional.resample(waveform, sr, 48000)
    waveform = waveform.numpy()[0, ...]
    waveform = normalize_wav(waveform)
    waveform = waveform[None, ...]
    waveform = pad_wav(waveform, target_length=int(48000 * pad_duration))
    return waveform, target_frame, pad_duration


def read_audio_file(filename):
    waveform, target_frame, duration = read_wav_file(filename)
    log_mel_spec, stft = wav_feature_extraction(waveform, target_frame)
    return log_mel_spec, stft, waveform, duration, target_frame

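# A minimal usage sketch (the file path is hypothetical; mel/stft feed the
# model, waveform is kept for building the low-pass condition):
#
#     log_mel_spec, stft, waveform, duration, target_frame = read_audio_file(
#         "example_input.wav"
#     )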
def read_list(fname):
    result = []
    with open(fname, "r", encoding="utf-8") as f:
        for each in f.readlines():
            each = each.strip("\n")
            result.append(each)
    return result

def get_duration(fname):
    with contextlib.closing(wave.open(fname, "r")) as f:
        frames = f.getnframes()
        rate = f.getframerate()
        return frames / float(rate)


def get_bit_depth(fname):
    with contextlib.closing(wave.open(fname, "r")) as f:
        return f.getsampwidth() * 8


def get_time():
    t = time.localtime()
    return time.strftime("%d_%m_%Y_%H_%M_%S", t)

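# Seed every RNG that affects sampling (Python, NumPy, PyTorch CPU and CUDA)
# and configure cuDNN for deterministic behavior.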
def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner selects kernels non-deterministically, so it must
    # stay off for reproducible sampling.
    torch.backends.cudnn.benchmark = False

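# Because the input is zero-padded to a multiple of 5.12 s, the upsampled
# output can run longer than the source. Read the original duration with
# ffprobe and trim the output to match with ffmpeg (stream copy, no
# re-encode).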
def strip_silence(original_path, input_path, output_path):
    get_dur = subprocess.run(
        [
            "ffprobe",
            "-v", "error",
            "-select_streams", "a:0",
            "-show_entries", "format=duration",
            "-sexagesimal",
            "-of", "json",
            original_path,
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    duration = json.loads(get_dur.stdout)["format"]["duration"]

    subprocess.run(
        [
            "ffmpeg",
            "-y",
            "-ss", "00:00:00",
            "-i", input_path,
            "-t", duration,
            "-c", "copy",
            output_path,
        ]
    )
    os.remove(input_path)

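# Write each generated waveform to a temporary file, then trim it to the
# original input's duration via strip_silence and place it under savepath.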
def save_wave(waveform, inputpath, savepath, name="outwav", samplerate=16000):
    if not isinstance(name, list):
        name = [name] * waveform.shape[0]

    for i in range(waveform.shape[0]):
        basename = os.path.basename(name[i])
        if ".wav" in basename:
            basename = basename.split(".")[0]
        if waveform.shape[0] > 1:
            fname = "%s_%s.wav" % (basename, i)
        else:
            fname = "%s.wav" % basename

        # Very long names can exceed filesystem limits; fall back to a hash.
        if len(fname) > 255:
            fname = f"{hex(hash(fname))}.wav"

        save_path = os.path.join(savepath, fname)
        temp_path = os.path.join(tempfile.gettempdir(), fname)
        print(
            "\033[96m {}\033[00m".format(
                "Don't forget to try different seeds by setting --seed <int> "
                "so that AudioSR can have optimal performance on your hardware."
            )
        )
        print("Save audio to %s." % save_path)
        sf.write(temp_path, waveform[i, 0], samplerate=samplerate)
        strip_silence(inputpath, temp_path, save_path)

def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params

def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)

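# Instantiate an object from a config dict of the form
# {"target": "module.path.ClassName", "params": {...}}.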
def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

def default_audioldm_config(model_name="basic"):
    return get_basic_config()

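# Download progress hook compatible with urllib-style reporthooks
# (block_num, block_size, total_size).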
class MyProgressBar:
    def __init__(self):
        self.pbar = None

    def __call__(self, block_num, block_size, total_size):
        if not self.pbar:
            self.pbar = progressbar.ProgressBar(maxval=total_size)
            self.pbar.start()

        downloaded = block_num * block_size
        if downloaded < total_size:
            self.pbar.update(downloaded)
        else:
            self.pbar.finish()

def download_checkpoint(checkpoint_name="basic"):
    if checkpoint_name == "basic":
        model_id = "haoheliu/audiosr_basic"
    elif checkpoint_name == "speech":
        model_id = "haoheliu/audiosr_speech"
    else:
        raise ValueError("Invalid Model Name %s" % checkpoint_name)

    return hf_hub_download(repo_id=model_id, filename="pytorch_model.bin")

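# Default AudioSR model configuration: 48 kHz mel preprocessing, an
# AutoencoderKL first stage with 16 latent channels, and a v-parameterized
# latent diffusion model conditioned (by concatenation) on the low-passed
# mel spectrogram.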
def get_basic_config():
    return {
        "preprocessing": {
            "audio": {
                "sampling_rate": 48000,
                "max_wav_value": 32768,
                "duration": 10.24,
            },
            "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
            "mel": {"n_mel_channels": 256, "mel_fmin": 20, "mel_fmax": 24000},
        },
        "augmentation": {"mixup": 0.5},
        "model": {
            "target": "audiosr.latent_diffusion.models.ddpm.LatentDiffusion",
            "params": {
                "first_stage_config": {
                    "base_learning_rate": 0.000008,
                    "target": "audiosr.latent_encoder.autoencoder.AutoencoderKL",
                    "params": {
                        "reload_from_ckpt": "/mnt/bn/lqhaoheliu/project/audio_generation_diffusion/log/vae/vae_48k_256/ds_8_kl_1/checkpoints/ckpt-checkpoint-484999.ckpt",
                        "sampling_rate": 48000,
                        "batchsize": 4,
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 16,
                        "time_shuffle": 1,
                        "ddconfig": {
                            "double_z": True,
                            "mel_bins": 256,
                            "z_channels": 16,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4, 8],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0.1,
                        },
                    },
                },
                "base_learning_rate": 0.0001,
                "warmup_steps": 5000,
                "optimize_ddpm_parameter": True,
                "sampling_rate": 48000,
                "batchsize": 16,
                "beta_schedule": "cosine",
                "linear_start": 0.0015,
                "linear_end": 0.0195,
                "num_timesteps_cond": 1,
                "log_every_t": 200,
                "timesteps": 1000,
                "unconditional_prob_cfg": 0.1,
                "parameterization": "v",
                "first_stage_key": "fbank",
                "latent_t_size": 128,
                "latent_f_size": 32,
                "channels": 16,
                "monitor": "val/loss_simple_ema",
                "scale_by_std": True,
                "unet_config": {
                    "target": "audiosr.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel",
                    "params": {
                        "image_size": 64,
                        "in_channels": 32,
                        "out_channels": 16,
                        "model_channels": 128,
                        "attention_resolutions": [8, 4, 2],
                        "num_res_blocks": 2,
                        "channel_mult": [1, 2, 3, 5],
                        "num_head_channels": 32,
                        "extra_sa_layer": True,
                        "use_spatial_transformer": True,
                        "transformer_depth": 1,
                    },
                },
                "evaluation_params": {
                    "unconditional_guidance_scale": 3.5,
                    "ddim_sampling_steps": 200,
                    "n_candidates_per_samples": 1,
                },
                "cond_stage_config": {
                    "concat_lowpass_cond": {
                        "cond_stage_key": "lowpass_mel",
                        "conditioning_key": "concat",
                        "target": "audiosr.latent_diffusion.modules.encoders.modules.VAEFeatureExtract",
                        "params": {
                            "first_stage_config": {
                                "base_learning_rate": 0.000008,
                                "target": "audiosr.latent_encoder.autoencoder.AutoencoderKL",
                                "params": {
                                    "sampling_rate": 48000,
                                    "batchsize": 4,
                                    "monitor": "val/rec_loss",
                                    "image_key": "fbank",
                                    "subband": 1,
                                    "embed_dim": 16,
                                    "time_shuffle": 1,
                                    "ddconfig": {
                                        "double_z": True,
                                        "mel_bins": 256,
                                        "z_channels": 16,
                                        "resolution": 256,
                                        "downsample_time": False,
                                        "in_channels": 1,
                                        "out_ch": 1,
                                        "ch": 128,
                                        "ch_mult": [1, 2, 4, 8],
                                        "num_res_blocks": 2,
                                        "attn_resolutions": [],
                                        "dropout": 0.1,
                                    },
                                },
                            }
                        },
                    }
                },
            },
        },
    }