diff --git a/app.py b/app.py
index 968e840de7b8aa9fdc7215515d6e2ed7f069b6de..12f311ca41450bd880f76c7a475e672681d2ea05 100644
--- a/app.py
+++ b/app.py
@@ -16,14 +16,12 @@ from download import download_model
 # Download the model
 APP_DIR = op.dirname(op.abspath(__file__))
 download_model(APP_DIR)
-base_full_path = op.join(APP_DIR, "ckpt", "songgeneration-base-full")
-os.makedirs(base_full_path, exist_ok=True)
-download_model(base_full_path, repo_id="lglg666/SongGeneration-base-full", revision="19ebdb6")
+download_model(op.join(APP_DIR, "ckpt"), repo_id="waytan22/SongGeneration-v1.5-beta", revision="db10f47")
 print("Successfully downloaded model.")
 
 # Initialize the model
 from levo_inference import LeVoInference
-MODEL = LeVoInference(base_full_path)
+MODEL = LeVoInference(op.join(APP_DIR, "ckpt", "SongGeneration-v1.5-beta"))
 
 EXAMPLE_LYRICS = """
 [intro-medium]
@@ -225,7 +223,7 @@ lyrics
             minimum=0.1,
             maximum=2.0,
             step=0.1,
-            value=0.75,
+            value=0.8,
             interactive=True,
             elem_id="temperature",
         )
@@ -268,12 +266,12 @@ lyrics
     # Generate-button click events
     generate_btn.click(
         fn=generate_song,
-        inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(-1)],
+        inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50)],
         outputs=[output_audio, output_json]
     )
    generate_bgm_btn.click(
         fn=generate_song,
-        inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(-1), gr.State("bgm")],
+        inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, gr.State(50), gr.State("bgm")],
         outputs=[output_audio, output_json]
     )
diff --git a/codeclm/models/builders.py b/codeclm/models/builders.py
index 1b4f3f23bdd469a1b0653837b085dd189c849cfe..8e34e9c4f01d82c9a4eb532c575a3a8b8571f43b 100755
--- a/codeclm/models/builders.py
+++ b/codeclm/models/builders.py
@@ -52,7 +52,7 @@ def get_audio_tokenizer_model_cpu(checkpoint_path: str, cfg: omegaconf.DictConfi
     return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode, tango_device='cpu')
 
 
-def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
+def get_lm_model(cfg: omegaconf.DictConfig, version: str = 'v1.0'): #-> LMModel:
     """Instantiate a LM."""
     lm_kwargs = dict_from_config(getattr(cfg, 'lm'))
 
@@ -61,8 +61,8 @@ def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
     q_modeling = lm_kwargs.pop('q_modeling', None)
 
     # conditioner
-    condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg)
-
+    condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg, version=version)
+
     # codebook pattern: delay
     codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
     if codebooks_pattern_cfg.modeling is None:
@@ -97,7 +97,7 @@ def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel:
     raise KeyError(f"Unexpected LM model {lm_type}")
 
 
-def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider:
+def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig, version: str = 'v1.0') -> ConditionerProvider:
     """Instantiate a conditioning model."""
     cfg = getattr(cfg, 'conditioners')
     dict_cfg = {} if cfg is None else dict_from_config(cfg)
@@ -115,6 +115,7 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
         elif model_type == "QwTextTokenizer":
             conditioners[str(cond)] = QwTextConditioner(
                 output_dim=output_dim,
+                version=version,
                 **model_args
             )
         elif model_type == "qt_embedding":
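A note on the `version` flag introduced above: it is only a string threaded from `get_lm_model` through `get_conditioner_provider` into `QwTextConditioner`. A minimal sketch of the call chain; the config path and contents are assumptions for illustration, not part of this diff:

# Sketch only: cfg layout and path are hypothetical; the call chain is from the diff.
import omegaconf
from codeclm.models.builders import get_lm_model

cfg = omegaconf.OmegaConf.load("conf/levo_v1.5.yaml")  # placeholder config path
lm = get_lm_model(cfg, version='v1.5')
# get_lm_model(cfg, version) -> get_conditioner_provider(dim, cfg, version=version)
#                            -> QwTextConditioner(output_dim=dim, version=version, ...)
# 'v1.5' is what later enables the musicality special tokens in conditioners.py.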
diff --git a/codeclm/modules/conditioners.py b/codeclm/modules/conditioners.py
index 7a66fe8d768da38167a7dec5e0325ddcbaf077d0..fdcf08c92398545fbd6df6fb1229f50f30858975 100755
--- a/codeclm/modules/conditioners.py
+++ b/codeclm/modules/conditioners.py
@@ -188,10 +188,13 @@ class QwTokenizerConditioner(TextConditioner):
 class QwTextConditioner(TextConditioner):
     def __init__(self, output_dim: int,
                  token_path = "",
-                 max_len = 300): #""
+                 max_len = 300,
+                 version: str = 'v1.0'): #""
         from transformers import Qwen2Tokenizer
-        self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
+        self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path)
+        if version == 'v1.5':
+            self.text_tokenizer.add_tokens(['[Musicality-very-high]', '[Musicality-high]', '[Musicality-medium]', '[Musicality-low]', '[Musicality-very-low]'], special_tokens=True)
         voc_size = len(self.text_tokenizer.get_vocab())
         # here initialize a output_proj (nn.Embedding) layer
         super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643)
@@ -636,7 +639,14 @@ class ClassifierFreeGuidanceDropoutInference(ClassifierFreeGuidanceDropout):
                     sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0])
             else:
                 if customized is None:
-                    sample.text[condition] = None
+                    if condition in ['type_info'] and sample.text[condition] is not None:
+                        if "[Musicality-very-high]" in sample.text[condition]:
+                            sample.text[condition] = "[Musicality-very-low], ."
+                            print(f"cfg unconditioning: change sample.text[condition] to [Musicality-very-low]")
+                        else:
+                            sample.text[condition] = None
+                    else:
+                        sample.text[condition] = None
                 else:
                     text_cond = deepcopy(sample.text[condition])
                     if "structure" in customized:
diff --git a/codeclm/tokenizer/Flow1dVAE/cal_token_stat.py b/codeclm/tokenizer/Flow1dVAE/cal_token_stat.py
deleted file mode 100644
index e4c9f2ad6bcff850bdec1a53c7299bad63a63ab2..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/cal_token_stat.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import kaldiio
-from tqdm import tqdm
-import torch
-
-if __name__ == "__main__":
-    bar = torch.zeros(1, 16384)
-    with open('token.scp', 'r') as f:
-        for item_idx, line in tqdm(enumerate(f)):
-            idx, pos = line.strip().split()
-            codes = kaldiio.load_mat(pos)
-            for i0 in range(codes.shape[-1]):
-                bar[0, codes[0, 0, i0]] += 1
-            if(item_idx % 1000 == 0):
-                print("=========")
-                print(1 - (bar[0]==0).sum() / bar.shape[-1])
-                print("=========")
-    print("=========")
-    print(1 - (bar[0]==0).sum() / bar.shape[-1])
-    print("=========")
\ No newline at end of file
diff --git a/codeclm/tokenizer/Flow1dVAE/compare_model_weight.py b/codeclm/tokenizer/Flow1dVAE/compare_model_weight.py
deleted file mode 100644
index af9c77e6c3b0d5adc6f56bb19d0ecea966acecfd..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/compare_model_weight.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import torch
-import sys
-from safetensors.torch import load_file
-
-if __name__ == "__main__":
-    m0, m1 = sys.argv[1], sys.argv[2]
-    m0 = load_file(m0)
-    m1 = load_file(m1)
-
-    ks = [k for k in m0.keys() if 'bestrq' in k]
-    for k in ks:
-        print(k, (m0[k] - m1[k]).abs().sum())
-
\ No newline at end of file
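The tokenizer change above relies on `add_tokens(..., special_tokens=True)` growing the vocabulary before the embedding table is sized. A sketch of that behavior; "Qwen/Qwen2-0.5B" is an assumed public checkpoint used purely for illustration:

# Sketch only: demonstrates the transformers add_tokens contract the patch depends on.
from transformers import Qwen2Tokenizer

tok = Qwen2Tokenizer.from_pretrained("Qwen/Qwen2-0.5B")  # placeholder checkpoint
before = len(tok.get_vocab())
added = tok.add_tokens(['[Musicality-very-high]', '[Musicality-very-low]'], special_tokens=True)
after = len(tok.get_vocab())
# New ids are appended at the end of the vocab, so the nn.Embedding built from
# len(get_vocab()) must be constructed *after* add_tokens, as the patched __init__ does.
print(before, added, after)  # e.g. 151646, 2, 151648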
diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py
deleted file mode 100644
index 6922a42d8016477448e92890f9eec84a8eb8d9d9..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-import numpy as np
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_septoken import Tango as Tango_sep
-from generate_2rvq import Tango as Tango_1x2
-import kaldiio
-from kaldiio import WriteHelper
-from audio import AudioFile
-
-from demucs.models.pretrained import get_model_from_yaml
-from filelock import FileLock
-
-# os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml")
-class Separator:
-    def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
-        if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
-            self.device = torch.device(f"cuda:{gpu_id}")
-        else:
-            self.device = torch.device("cpu")
-        self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
-
-    def init_demucs_model(self, model_path, config_path):
-        model = get_model_from_yaml(config_path, model_path)
-        model.to(self.device)
-        model.eval()
-        return model
-
-    def load_audio(self, f):
-        a, fs = torchaudio.load(f)
-        if (fs != 48000):
-            a = torchaudio.functional.resample(a, fs, 48000)
-        # if a.shape[-1] >= 48000*10:
-        #     a = a[..., :48000*10]
-        # else:
-        #     a = torch.cat([a, a], -1)
-        # return a[:, 0:48000*10]
-        return a
-
-    def run(self, audio_path, output_dir='demucs/test_output', ext=".flac"):
-        name, _ = os.path.splitext(os.path.split(audio_path)[-1])
-        output_paths = []
-        # lock_path = os.path.join(output_dir, f"{name}.lock")
-        # with FileLock(lock_path):  # add a lock to avoid deadlock under multi-GPU access
-        for stem in self.demucs_model.sources:
-            output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
-            if os.path.exists(output_path):
-                output_paths.append(output_path)
-        if len(output_paths) == 1:  # 4
-            # drums_path, bass_path, other_path, vocal_path = output_paths
-            vocal_path = output_paths[0]
-        else:
-            lock_path = os.path.join(output_dir, f"{name}_separate.lock")
-            with FileLock(lock_path):
-                drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
-        full_audio = self.load_audio(audio_path)
-        vocal_audio = self.load_audio(vocal_path)
-        minlen = min(full_audio.shape[-1], vocal_audio.shape[-1])
-        # bgm_audio = full_audio[:, 0:minlen] - vocal_audio[:, 0:minlen]
-        bgm_audio = self.load_audio(drums_path) + self.load_audio(bass_path) + self.load_audio(other_path)
-        for path in [drums_path, bass_path, other_path, vocal_path]:
-            os.remove(path)
-        return full_audio, vocal_audio, bgm_audio
-
-def read_wav(fname, sample_rate=48_000):
-    try:
-        orig_samples, fs = torchaudio.load(fname)
-    except:
-        af = AudioFile(fname)
-        orig_samples = af.read()
-        fs = af.samplerate()
-        orig_samples = orig_samples[0]
-    if(fs!=sample_rate):
-        orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
-        fs = sample_rate
-    if orig_samples.shape[0] == 1:
-        orig_samples = torch.cat([orig_samples, orig_samples], 0)
-    return orig_samples
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango_sep = Tango_sep(model_path="./saved/model_septoken/model_2.safetensors")
-    tango_1x2 = Tango_1x2(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
-    separator = Separator()
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    first_time = True
-    for item in tqdm(mus_infos):
-        if(os.path.exists(item['path'])):
-            full_path = item['path']
-        else:
-            full_path = '/mnt/share/' + item['path']
-
-        full_tensor, vocal_tensor, bgm_tensor = separator.run(full_path)
-
-        # full_tensor = read_wav(full_path)
-        # vocal_tensor = read_wav(vocal_path)
-        # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
-        # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
-        # bgm_tensor = full_tensor - vocal_tensor
-        codes_1x2 = tango_1x2.sound2code(full_tensor)
-        codes_vocal, codes_bgm = tango_sep.sound2code(vocal_tensor, bgm_tensor)
-        codes = torch.cat([codes_1x2[:,[0],:], codes_vocal, codes_bgm], 1).cpu().numpy()
-        save_path = full_path.replace('.wav', '.1x1_and_sep.npy').replace('.mp3', '.1x1_and_sep.npy').replace('.flac', '.1x1_and_sep.npy').replace('.ogg', '.1x1_and_sep.npy')
-        assert save_path != full_path, (save_path, full_path)
-        np.save(save_path, codes)
-
-        if(first_time):
-            first_time = False
-            print(codes_vocal.shape, codes_bgm.shape)
diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py
deleted file mode 100644
index b46d6afef8297764b7f3ca9b3b652202b12dfc7c..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_septoken import Tango
-import kaldiio
-from kaldiio import WriteHelper
-from audio import AudioFile
-
-def read_wav(fname, sample_rate=48_000):
-    try:
-        orig_samples, fs = torchaudio.load(fname)
-    except:
-        af = AudioFile(fname)
-        orig_samples = af.read()
-        fs = af.samplerate()
-        orig_samples = orig_samples[0]
-    if(fs!=sample_rate):
-        orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate)
-        fs = sample_rate
-    if orig_samples.shape[0] == 1:
-        orig_samples = torch.cat([orig_samples, orig_samples], 0)
-    return orig_samples
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-    outdir = sys.argv[2]
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango = Tango(model_path="./saved/model_septoken/model_2.safetensors")
-
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    first_time = True
-    with WriteHelper('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir), write_function="pickle") as writer_vocal, WriteHelper('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir), write_function="pickle") as writer_bgm:
-        print('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir))
-        print('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir))
-        for item in tqdm(mus_infos):
-            try:
-            # if True:
-                idx = item['idx']
-                # print(idx)
-                if(os.path.exists(item['path'])):
-                    full_path = item['path']
-                else:
-                    full_path = '/mnt/share/' + item['path']
-                if(os.path.exists(item['vocal_path'])):
-                    vocal_path = item['vocal_path']
-                    bgm_paths = item['bgm_path']
-                else:
-                    vocal_path = '/mnt/share/' + item['vocal_path']
-                    bgm_paths = ['/mnt/share/' + p for p in item['bgm_path']]
-                vocal_tensor = read_wav(vocal_path)
-                # full_tensor = read_wav(full_path)
-                # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
-                # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
-                # bgm_tensor = full_tensor - vocal_tensor
-                bgm_tensor = sum([read_wav(p) for p in bgm_paths])
-                codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
-                writer_vocal(str(idx), codes_vocal.cpu())
-                writer_bgm(str(idx), codes_bgm.cpu())
-                if(first_time):
-                    first_time = False
-                    print(codes_vocal.shape, codes_bgm.shape)
-            except:
-                print(item['vocal_path'])
-                print(item['bgm_path'])
-                continue
-
-            # idx = item['idx']
-            # # print(idx)
-            # full_path = item['path']
-            # vocal_path = item['vocal_path']
-            # bgm_paths = item['bgm_path']
-            # full_tensor = read_wav(full_path)
-            # vocal_tensor = read_wav(vocal_path)
-            # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1])
-            # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length]
-            # bgm_tensor = full_tensor - vocal_tensor
-            # codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor)
-            # writer_vocal(str(idx), codes_vocal.cpu())
-            # writer_bgm(str(idx), codes_bgm.cpu())
-            # if(first_time):
-            #     first_time = False
-            #     print(codes_vocal.shape, codes_bgm.shape)
-
diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py
deleted file mode 100644
index 4069277c87dd58a05ea9fdf964af9713cfab205c..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_2rvq import Tango
-import kaldiio
-from kaldiio import WriteHelper
-import torch
-import subprocess
-import time
-import sys
-
-def get_gpu_memory():
-    _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
-
-    ACCEPTABLE_AVAILABLE_MEMORY = 1024
-    COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
-    memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
-    memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
-    return memory_free_values
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-    outdir = sys.argv[2]
-
-    gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
-    while True:
-        free_mem = get_gpu_memory()
-        free_mem = free_mem[gpu_idx]
-        if(free_mem > 25_000):
-            print("GPU memory {}, run matrix cal".format(free_mem))
-            break
-        else:
-            print("GPU memory {}, sleep 1min".format(free_mem))
-            time.sleep(60)
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango = Tango(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2)
-
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
-        print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
-        for item in tqdm(mus_infos):
-            try:
-            # if True:
-                idx = item['idx']
-                # print(idx)
-                with torch.autocast(device_type="cuda", dtype=torch.float16):
-                    if(os.path.exists(item['path'])):
-                        codes = tango.file2code(item['path'])
-                    else:
-                        codes = tango.file2code('/mnt/share/' + item['path'])
-                writer(str(idx), codes.cpu())
-            except:
-                print(item['path'])
-                continue
-            # idx = item['idx']
-            # # print(idx)
-            # with torch.autocast(device_type="cuda", dtype=torch.float16):
-            #     codes = tango.file2code(item['path'])
-            #     writer(str(idx), codes.cpu())
\ No newline at end of file
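The deleted extraction scripts gate start-up on free GPU memory by shelling out to `nvidia-smi`. A sketch of the same polling idea using torch's own query instead of parsing CLI output; this is a named alternative, not what the deleted code did, and the 25 GB threshold simply mirrors the script above:

# Sketch only: wait until a device has enough free memory, then proceed.
import time
import torch

def wait_for_free_memory(device: int = 0, min_free_mb: int = 25_000) -> None:
    while True:
        free_bytes, _total = torch.cuda.mem_get_info(device)  # (free, total) in bytes
        if free_bytes // (1024 * 1024) >= min_free_mb:
            return
        time.sleep(60)  # poll once a minute, as the deleted script did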
diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py
deleted file mode 100644
index 5116d4bc1e3bf5ef7c553d348992e0dfc119f303..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_4rvq import Tango
-import kaldiio
-from kaldiio import WriteHelper
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-    outdir = sys.argv[2]
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
-
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
-        print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
-        for item in tqdm(mus_infos):
-            try:
-            # if True:
-                idx = item['idx']
-                # print(idx)
-                with torch.autocast(device_type="cuda", dtype=torch.float16):
-                    if(os.path.exists(item['path'])):
-                        codes = tango.file2code(item['path'])
-                    else:
-                        codes = tango.file2code('/mnt/share/' + item['path'])
-                writer(str(idx), codes.cpu())
-            except:
-                print(item['path'])
-                continue
-            # idx = item['idx']
-            # # print(idx)
-            # with torch.autocast(device_type="cuda", dtype=torch.float16):
-            #     codes = tango.file2code(item['path'])
-            #     writer(str(idx), codes.cpu())
\ No newline at end of file
diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py
deleted file mode 100644
index 8df3f06af60cfd9aa020dd8ac0e50a0d89898b88..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import torch,torchaudio
-import os,sys,json
-from tqdm import tqdm
-
-#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango
-from generate_4rvq import Tango
-import kaldiio
-from kaldiio import WriteHelper
-import torch
-import subprocess
-import time
-import sys
-
-def get_gpu_memory():
-    _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1]
-
-    ACCEPTABLE_AVAILABLE_MEMORY = 1024
-    COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
-    memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:]
-    memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
-    return memory_free_values
-
-if __name__ == "__main__":
-    # Define Model
-    json_path = sys.argv[1]
-    outdir = sys.argv[2]
-    ds = int(sys.argv[3])
-
-    gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES'])
-    while True:
-        free_mem = get_gpu_memory()
-        free_mem = free_mem[gpu_idx]
-        if(free_mem > 25_000):
-            print("GPU memory {}, run matrix cal".format(free_mem))
-            break
-        else:
-            print("GPU memory {}, sleep 1min".format(free_mem))
-            time.sleep(60)
-
-    mus_infos = []
-    with open(json_path) as f:
-        for line in f:
-            item = json.loads(line)
-            mus_infos.append(item)
-
-    tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4)
-
-
-    # Feature extraction loop
-    # for i in tqdm(range(2000)):
-    with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer:
-        print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir))
-        bar = torch.zeros(4, 16384)
-        for item_idx, item in tqdm(enumerate(mus_infos)):
-            try:
-            # if True:
-                idx = item['idx']
-                # print(idx)
-                with torch.autocast(device_type="cuda", dtype=torch.float16):
-                    if(os.path.exists(item['path'])):
-                        codes = tango.file2code_ds(item['path'], ds)
-                    else:
-                        codes = tango.file2code_ds('/mnt/share/' + item['path'], ds)
-                codes = codes.cpu()
-                writer(str(idx), codes)
-                for i0 in range(codes.shape[-1]):
-                    bar[0, codes[0, 0, i0]] += 1
-                    bar[1, codes[0, 1, i0]] += 1
-                    bar[2, codes[0, 2, i0]] += 1
-                    bar[3, codes[0, 3, i0]] += 1
-            except Exception as e:
-                print(item['path'])
-                # print(e.message, e.args)
-                # exit(1)
-                continue
-
-            if(item_idx % 1000 == 0):
-                print("=========")
-                print(1 - (bar[0]==0).sum() / bar.shape[-1])
-                print("=========")
-
-            # idx = item['idx']
-            # # print(idx)
-            # with torch.autocast(device_type="cuda", dtype=torch.float16):
-            #     codes = tango.file2code(item['path'])
-            #     writer(str(idx), codes.cpu())
\ No newline at end of file
diff --git a/codeclm/tokenizer/Flow1dVAE/generate_1rvq.py b/codeclm/tokenizer/Flow1dVAE/generate_1rvq.py
index b491cdc66c7905a7a88bd71eb8ec374084318376..b866899f951dd61a3edad90dda82a86078e63400 100644
--- a/codeclm/tokenizer/Flow1dVAE/generate_1rvq.py
+++ b/codeclm/tokenizer/Flow1dVAE/generate_1rvq.py
@@ -8,7 +8,6 @@ import librosa
 import os
 import math
 import numpy as np
-from tools.get_1dvae_large import get_model
 import tools.torch_tools as torch_tools
 from safetensors.torch import load_file
 
@@ -24,9 +23,9 @@ class Tango:
         scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
         self.device = device
 
-        self.vae = get_model(vae_config, vae_model)
-        self.vae = self.vae.to(device)
-        self.vae=self.vae.eval()
+        # self.vae = get_model(vae_config, vae_model)
+        # self.vae = self.vae.to(device)
+        # self.vae=self.vae.eval()
         self.layer_num = layer_num
 
         self.MAX_DURATION = 360
@@ -254,37 +253,9 @@ class Tango:
         # print(fname, wave.shape)
         return wave
 
-    @torch.no_grad()
-    def sound2sound_vae(self, sound, prompt=None, steps=50, disable_progress=False):
-        min_samples = int(40 * 25) # 40ms per frame
-        hop_samples = min_samples // 4 * 3
-        ovlp_samples = min_samples - hop_samples
-        dur = 20
-
-        latent_list = []
-        for i in range(0, sound.shape[-1], dur*48000):
-            if(i+dur*2*48000 > sound.shape[-1]):
-                latent = tango.vae.encode_audio(sound.cuda()[None,:,i:])
-                break
-            else:
-                latent = tango.vae.encode_audio(sound.cuda()[None,:,i:i+dur*48000])
-            latent_list.append(latent)
-
-        output = None
-        for i in range(len(latent_list)):
-            print(i)
-            latent = latent_list[i]
-            cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
-            if output is None:
-                output = cur_output
-            else:
-                output = torch.cat([output, cur_output], -1)
-        return output
-
     def to(self, device=None, dtype=None, non_blocking=False):
         if device is not None:
             self.device = device
             self.model.device = device
-        self.vae = self.vae.to(device, dtype, non_blocking)
         self.model = self.model.to(device, dtype, non_blocking)
         return self
diff --git a/codeclm/tokenizer/Flow1dVAE/generate_2rvq.py b/codeclm/tokenizer/Flow1dVAE/generate_2rvq.py
deleted file mode 100644
index c02fd1ddc0f76db726045acbd6c26f9739fbc857..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/generate_2rvq.py
+++ /dev/null
@@ -1,293 +0,0 @@
-import json
-import torch
-from tqdm import tqdm
-from model_2rvq import PromptCondAudioDiffusion
-from diffusers import DDIMScheduler, DDPMScheduler
-import torchaudio
-import librosa
-import os
-import math
-import numpy as np
-# from tools.get_mulan import get_mulan
-from tools.get_1dvae_large import get_model
-import tools.torch_tools as torch_tools
-from safetensors.torch import load_file
-from audio import AudioFile
-import kaldiio
-
-class Tango:
-    def __init__(self, \
-        model_path, \
-        layer_num=6, \
-        rvq_num=1, \
-        device="cuda:0"):
-
-        self.sample_rate = 48000
-        scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
-        self.device = device
-
-        self.vae = get_model()
-        self.vae = self.vae.to(device)
-        self.vae=self.vae.eval()
-        self.layer_num = layer_num
-
-        self.MAX_DURATION = 360
-        main_config = {
-            "num_channels":32,
-            "unet_model_name":None,
-            "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
-            "snr_gamma":None,
-        }
-        self.rvq_num = rvq_num
-        # print("rvq_num: ", self.rvq_num)
-        # exit()
-        self.model = PromptCondAudioDiffusion(**main_config).to(device)
-        if model_path.endswith(".safetensors"):
-            main_weights = load_file(model_path)
-        else:
-            main_weights = torch.load(model_path, map_location=device)
-        self.model.load_state_dict(main_weights, strict=False)
-        print ("Successfully loaded checkpoint from:", model_path)
-
-        self.model.eval()
-        self.model.init_device_dtype(torch.device(device), torch.float32)
-
-        # self.scheduler = DDIMScheduler.from_pretrained( \
-        #     scheduler_name, subfolder="scheduler")
-        # self.scheduler = DDPMScheduler.from_pretrained( \
-        #     scheduler_name, subfolder="scheduler")
-        print("Successfully loaded inference scheduler from {}".format(scheduler_name))
-
-
-
-    @torch.no_grad()
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
-    def sound2code(self, orig_samples, batch_size=8):
-        if(orig_samples.ndim == 2):
-            audios = orig_samples.unsqueeze(0).to(self.device)
-        elif(orig_samples.ndim == 3):
-            audios = orig_samples.to(self.device)
-        else:
-            assert orig_samples.ndim in (2,3), orig_samples.shape
-        audios = self.preprocess_audio(audios)
-        audios = audios.squeeze(0)
-        orig_length = audios.shape[-1]
-        min_samples = int(40 * self.sample_rate)
-        # 40 seconds corresponds to 10 tokens
-        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
-        # print("output_len: ", output_len)
-
-        while(audios.shape[-1] < min_samples):
-            audios = torch.cat([audios, audios], -1)
-        int_max_len=audios.shape[-1]//min_samples+1
-        audios = torch.cat([audios, audios], -1)
-        audios=audios[:,:int(int_max_len*(min_samples))]
-        codes_list=[]
-
-        audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
-
-        for audio_inx in range(0, audio_input.shape[0], batch_size):
-            # import pdb; pdb.set_trace()
-            codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
-            # print("codes",codes[0].shape)
-
-            codes_list.append(torch.cat(codes, 1))
-            # print("codes_list",codes_list[0].shape)
-
-        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
-        codes=codes[:,:,:output_len]
-
-        return codes
-
-    @torch.no_grad()
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
-    def sound2code_ds(self, orig_samples, ds, batch_size=8):
-        if(orig_samples.ndim == 2):
-            audios = orig_samples.unsqueeze(0).to(self.device)
-        elif(orig_samples.ndim == 3):
-            audios = orig_samples.to(self.device)
-        else:
-            assert orig_samples.ndim in (2,3), orig_samples.shape
-        audios = self.preprocess_audio(audios)
-        audios = audios.squeeze(0)
-        orig_length = audios.shape[-1]
-        min_samples = int(40 * self.sample_rate)
-        # 40 seconds corresponds to 10 tokens
-        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
-        # print("output_len: ", output_len)
-
-        while(audios.shape[-1] < min_samples):
-            audios = torch.cat([audios, audios], -1)
-        int_max_len=audios.shape[-1]//min_samples+1
-        audios = torch.cat([audios, audios], -1)
-        audios=audios[:,:int(int_max_len*(min_samples))]
-        codes_list=[]
-
-        audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
-
-        for audio_inx in range(0, audio_input.shape[0], batch_size):
-            # import pdb; pdb.set_trace()
-            codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
-            # print("codes",codes[0].shape)
-
-            codes_list.append(torch.cat(codes, 1))
-            # print("codes_list",codes_list[0].shape)
-
-        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
-        codes=codes[:,:,:output_len]
-
-        return codes
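The length bookkeeping in the sound2code methods above, pulled out as plain arithmetic: the token rate is 25 frames/s at 48 kHz, audio is tiled to whole 40 s windows before encoding, and the code sequence is then cut back to output_len. A self-contained sketch (constants are from the code above):

# Sketch only: mirrors output_len = int(orig_length / float(sample_rate) * 25) + 1
SAMPLE_RATE = 48_000
TOKENS_PER_SEC = 25

def expected_code_len(num_samples: int) -> int:
    return int(num_samples / SAMPLE_RATE * TOKENS_PER_SEC) + 1

assert expected_code_len(10 * SAMPLE_RATE) == 251    # 10 s clip -> 251 codes
assert expected_code_len(40 * SAMPLE_RATE) == 1001   # one full 40 s window -> 1001 codes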
-
-    @torch.no_grad()
-    def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
-        codes = codes.to(self.device)
-
-        min_samples = duration * 25 # 40ms per frame
-        hop_samples = min_samples // 4 * 3
-        ovlp_samples = min_samples - hop_samples
-        hop_frames = hop_samples
-        ovlp_frames = ovlp_samples
-        first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
-        first_latent_length = 0
-        first_latent_codes_length = 0
-
-        if(isinstance(prompt, torch.Tensor)):
-            # prepare prompt
-            prompt = prompt.to(self.device)
-            if(prompt.ndim == 3):
-                assert prompt.shape[0] == 1, prompt.shape
-                prompt = prompt[0]
-            elif(prompt.ndim == 1):
-                prompt = prompt.unsqueeze(0).repeat(2,1)
-            elif(prompt.ndim == 2):
-                if(prompt.shape[0] == 1):
-                    prompt = prompt.repeat(2,1)
-
-            if(prompt.shape[-1] < int(30 * self.sample_rate)):
-                # if less than 30s, just choose the first 10s
-                prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
-            else:
-                # else choose from 20.48s which might includes verse or chorus
-                prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
-
-            true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
-            # print("true_latent.shape", true_latent.shape)
-            # print("first_latent.shape", first_latent.shape)
-            #true_latent.shape torch.Size([1, 250, 64])
-            # first_latent.shape torch.Size([1, 1000, 64])
-
-            first_latent[:,0:true_latent.shape[1],:] = true_latent
-            first_latent_length = true_latent.shape[1]
-            first_latent_codes = self.sound2code(prompt)
-            first_latent_codes_length = first_latent_codes.shape[-1]
-            codes = torch.cat([first_latent_codes, codes], -1)
-
-        codes_len= codes.shape[-1]
-        target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
-        # target_len = int(codes_len / 100 * 4 * self.sample_rate)
-        # code repeat
-        if(codes_len < min_samples):
-            while(codes.shape[-1] < min_samples):
-                codes = torch.cat([codes, codes], -1)
-            codes = codes[:,:,0:min_samples]
-            codes_len = codes.shape[-1]
-        if((codes_len - ovlp_samples) % hop_samples > 0):
-            len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
-            while(codes.shape[-1] < len_codes):
-                codes = torch.cat([codes, codes], -1)
-            codes = codes[:,:,0:len_codes]
-        latent_length = min_samples
-        latent_list = []
-        spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
-        with torch.autocast(device_type="cuda", dtype=torch.float16):
-            for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
-                codes_input=[]
-                codes_input.append(codes[:,:,sinx:sinx+min_samples])
-                if(sinx == 0):
-                    # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
-                    incontext_length = first_latent_length
-                    latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
-                    latent_list.append(latents)
-                else:
-                    # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
-                    true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
-                    print("true_latent.shape", true_latent.shape)
-                    len_add_to_1000 = 1000 - true_latent.shape[-2]
-                    # print("len_add_to_1000", len_add_to_1000)
-                    # exit()
-                    incontext_length = true_latent.shape[-2]
-                    true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
-                    latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
-                    latent_list.append(latents)
-
-        latent_list = [l.float() for l in latent_list]
-        latent_list[0] = latent_list[0][:,:,first_latent_length:]
-        min_samples = int(min_samples * self.sample_rate // 1000 * 40)
-        hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
-        ovlp_samples = min_samples - hop_samples
-        with torch.no_grad():
-            output = None
-            for i in range(len(latent_list)):
-                latent = latent_list[i]
-                cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
-
-                if output is None:
-                    output = cur_output
-                else:
-                    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
-                    ov_win = torch.cat([ov_win, 1 - ov_win], -1)
-                    print("output.shape", output.shape)
-                    print("ov_win.shape", ov_win.shape)
-                    output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
-                    output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
-            output = output[:, 0:target_len]
-        return output
-
-    @torch.no_grad()
-    def preprocess_audio(self, input_audios, threshold=0.8):
-        assert len(input_audios.shape) == 3, input_audios.shape
-        nchan = input_audios.shape[1]
-        input_audios = input_audios.reshape(input_audios.shape[0], -1)
-        norm_value = torch.ones_like(input_audios[:,0])
-        max_volume = input_audios.abs().max(dim=-1)[0]
-        norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
-        return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
-
-    @torch.no_grad()
-    def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
-        codes = self.sound2code(sound)
-        # print(codes.shape)
-        # exit()
-        wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
-        # print(fname, wave.shape)
-        return wave
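When code2sound above stitches decoded windows together, it applies a linear crossfade over the overlap region: the tail of the running output is faded out while the head of the new window is faded in. The same scheme as a standalone sketch:

# Sketch only: linear overlap-add crossfade, as used in code2sound's decode loop.
import torch

def crossfade(prev: torch.Tensor, nxt: torch.Tensor, ovlp: int) -> torch.Tensor:
    fade_in = torch.linspace(0.0, 1.0, ovlp)   # applied to the head of the new window
    fade_out = 1.0 - fade_in                   # applied to the tail of the previous audio
    mixed = prev[..., -ovlp:] * fade_out + nxt[..., :ovlp] * fade_in
    return torch.cat([prev[..., :-ovlp], mixed, nxt[..., ovlp:]], dim=-1)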
-
-    def file2code(self, fname):
-        try:
-            orig_samples, fs = torchaudio.load(fname)
-        except:
-            af = AudioFile(fname)
-            orig_samples = af.read()
-            fs = af.samplerate()
-            orig_samples = orig_samples[0]
-        if(fs!=self.sample_rate):
-            orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
-            fs = self.sample_rate
-        if orig_samples.shape[0] == 1:
-            orig_samples = torch.cat([orig_samples, orig_samples], 0)
-        return self.sound2code(orig_samples)
-
-    def file2code_ds(self, fname, ds):
-        try:
-            orig_samples, fs = torchaudio.load(fname)
-        except:
-            af = AudioFile(fname)
-            orig_samples = af.read()
-            fs = af.samplerate()
-            orig_samples = orig_samples[0]
-        if(fs!=self.sample_rate):
-            orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
-            fs = self.sample_rate
-        if orig_samples.shape[0] == 1:
-            orig_samples = torch.cat([orig_samples, orig_samples], 0)
-        return self.sound2code_ds(orig_samples, ds)
diff --git a/codeclm/tokenizer/Flow1dVAE/generate_4rvq.py b/codeclm/tokenizer/Flow1dVAE/generate_4rvq.py
deleted file mode 100644
index d2502d5bb5cadce37c9aee63e399a1dd110f6689..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/generate_4rvq.py
+++ /dev/null
@@ -1,292 +0,0 @@
-import json
-import torch
-from tqdm import tqdm
-from model_4rvq import PromptCondAudioDiffusion
-from diffusers import DDIMScheduler, DDPMScheduler
-import torchaudio
-import librosa
-import os
-import math
-import numpy as np
-# from tools.get_mulan import get_mulan
-from tools.get_1dvae_large import get_model
-import tools.torch_tools as torch_tools
-from safetensors.torch import load_file
-from audio import AudioFile
-
-class Tango:
-    def __init__(self, \
-        model_path, \
-        layer_num=6, \
-        rvq_num=1, \
-        device="cuda:0"):
-
-        self.sample_rate = 48000
-        scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json"
-        self.device = device
-
-        self.vae = get_model()
-        self.vae = self.vae.to(device)
-        self.vae=self.vae.eval()
-        self.layer_num = layer_num
-
-        self.MAX_DURATION = 360
-        main_config = {
-            "num_channels":32,
-            "unet_model_name":None,
-            "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json",
-            "snr_gamma":None,
-        }
-        self.rvq_num = rvq_num
-        # print("rvq_num: ", self.rvq_num)
-        # exit()
-        self.model = PromptCondAudioDiffusion(**main_config).to(device)
-        if model_path.endswith(".safetensors"):
-            main_weights = load_file(model_path)
-        else:
-            main_weights = torch.load(model_path, map_location=device)
-        self.model.load_state_dict(main_weights, strict=False)
-        print ("Successfully loaded checkpoint from:", model_path)
-
-        self.model.eval()
-        self.model.init_device_dtype(torch.device(device), torch.float32)
-
-        # self.scheduler = DDIMScheduler.from_pretrained( \
-        #     scheduler_name, subfolder="scheduler")
-        # self.scheduler = DDPMScheduler.from_pretrained( \
-        #     scheduler_name, subfolder="scheduler")
-        print("Successfully loaded inference scheduler from {}".format(scheduler_name))
-
-
-
-    @torch.no_grad()
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
-    def sound2code(self, orig_samples, batch_size=8):
-        if(orig_samples.ndim == 2):
-            audios = orig_samples.unsqueeze(0).to(self.device)
-        elif(orig_samples.ndim == 3):
-            audios = orig_samples.to(self.device)
-        else:
-            assert orig_samples.ndim in (2,3), orig_samples.shape
-        audios = self.preprocess_audio(audios)
-        audios = audios.squeeze(0)
-        orig_length = audios.shape[-1]
-        min_samples = int(40 * self.sample_rate)
-        # 40 seconds corresponds to 10 tokens
-        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
-        # print("output_len: ", output_len)
-
-        while(audios.shape[-1] < min_samples):
-            audios = torch.cat([audios, audios], -1)
-        int_max_len=audios.shape[-1]//min_samples+1
-        audios = torch.cat([audios, audios], -1)
-        audios=audios[:,:int(int_max_len*(min_samples))]
-        codes_list=[]
-
-        audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
-
-        for audio_inx in range(0, audio_input.shape[0], batch_size):
-            # import pdb; pdb.set_trace()
-            codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num)
-            # print("codes",codes[0].shape)
-
-            codes_list.append(torch.cat(codes, 1))
-            # print("codes_list",codes_list[0].shape)
-
-        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
-        codes=codes[:,:,:output_len]
-
-        return codes
-
-    @torch.no_grad()
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
-    def sound2code_ds(self, orig_samples, ds, batch_size=6):
-        if(orig_samples.ndim == 2):
-            audios = orig_samples.unsqueeze(0).to(self.device)
-        elif(orig_samples.ndim == 3):
-            audios = orig_samples.to(self.device)
-        else:
-            assert orig_samples.ndim in (2,3), orig_samples.shape
-        audios = self.preprocess_audio(audios)
-        audios = audios.squeeze(0)
-        orig_length = audios.shape[-1]
-        min_samples = int(40 * self.sample_rate)
-        # 40 seconds corresponds to 10 tokens
-        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
-        # print("output_len: ", output_len)
-
-        while(audios.shape[-1] < min_samples):
-            audios = torch.cat([audios, audios], -1)
-        int_max_len=audios.shape[-1]//min_samples+1
-        audios = torch.cat([audios, audios], -1)
-        audios=audios[:,:int(int_max_len*(min_samples))]
-        codes_list=[]
-
-        audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples)
-
-        for audio_inx in range(0, audio_input.shape[0], batch_size):
-            # import pdb; pdb.set_trace()
-            codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds)
-            # print("codes",codes[0].shape)
-
-            codes_list.append(torch.cat(codes, 1))
-            # print("codes_list",codes_list[0].shape)
-
-        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T
-        codes=codes[:,:,:output_len]
-
-        return codes
-
-    @torch.no_grad()
-    def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False):
-        codes = codes.to(self.device)
-
-        min_samples = duration * 25 # 40ms per frame
-        hop_samples = min_samples // 4 * 3
-        ovlp_samples = min_samples - hop_samples
-        hop_frames = hop_samples
-        ovlp_frames = ovlp_samples
-        first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device)
-        first_latent_length = 0
-        first_latent_codes_length = 0
-
-        if(isinstance(prompt, torch.Tensor)):
-            # prepare prompt
-            prompt = prompt.to(self.device)
-            if(prompt.ndim == 3):
-                assert prompt.shape[0] == 1, prompt.shape
-                prompt = prompt[0]
-            elif(prompt.ndim == 1):
-                prompt = prompt.unsqueeze(0).repeat(2,1)
-            elif(prompt.ndim == 2):
-                if(prompt.shape[0] == 1):
-                    prompt = prompt.repeat(2,1)
-
-            if(prompt.shape[-1] < int(30 * self.sample_rate)):
-                # if less than 30s, just choose the first 10s
-                prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24
-            else:
-                # else choose from 20.48s which might includes verse or chorus
-                prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24
-
-            true_latent = self.vae.encode_audio(prompt).permute(0,2,1)
-            # print("true_latent.shape", true_latent.shape)
-            # print("first_latent.shape", first_latent.shape)
-            #true_latent.shape torch.Size([1, 250, 64])
-            # first_latent.shape torch.Size([1, 1000, 64])
-
-            first_latent[:,0:true_latent.shape[1],:] = true_latent
-            first_latent_length = true_latent.shape[1]
-            first_latent_codes = self.sound2code(prompt)
-            first_latent_codes_length = first_latent_codes.shape[-1]
-            codes = torch.cat([first_latent_codes, codes], -1)
-
-        codes_len= codes.shape[-1]
-        target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)
-        # target_len = int(codes_len / 100 * 4 * self.sample_rate)
-        # code repeat
-        if(codes_len < min_samples):
-            while(codes.shape[-1] < min_samples):
-                codes = torch.cat([codes, codes], -1)
-            codes = codes[:,:,0:min_samples]
-            codes_len = codes.shape[-1]
-        if((codes_len - ovlp_samples) % hop_samples > 0):
-            len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
-            while(codes.shape[-1] < len_codes):
-                codes = torch.cat([codes, codes], -1)
-            codes = codes[:,:,0:len_codes]
-        latent_length = min_samples
-        latent_list = []
-        spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
-        with torch.autocast(device_type="cuda", dtype=torch.float16):
-            for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
-                codes_input=[]
-                codes_input.append(codes[:,:,sinx:sinx+min_samples])
-                if(sinx == 0):
-                    # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
-                    incontext_length = first_latent_length
-                    latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
-                    latent_list.append(latents)
-                else:
-                    # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate))
-                    true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1)
-                    print("true_latent.shape", true_latent.shape)
-                    len_add_to_1000 = 1000 - true_latent.shape[-2]
-                    # print("len_add_to_1000", len_add_to_1000)
-                    # exit()
-                    incontext_length = true_latent.shape[-2]
-                    true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2)
-                    latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
-                    latent_list.append(latents)
-
-        latent_list = [l.float() for l in latent_list]
-        latent_list[0] = latent_list[0][:,:,first_latent_length:]
-        min_samples = int(min_samples * self.sample_rate // 1000 * 40)
-        hop_samples = int(hop_samples * self.sample_rate // 1000 * 40)
-        ovlp_samples = min_samples - hop_samples
-        with torch.no_grad():
-            output = None
-            for i in range(len(latent_list)):
-                latent = latent_list[i]
-                cur_output = self.vae.decode_audio(latent)[0].detach().cpu()
-
-                if output is None:
-                    output = cur_output
-                else:
-                    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
-                    ov_win = torch.cat([ov_win, 1 - ov_win], -1)
-                    print("output.shape", output.shape)
-                    print("ov_win.shape", ov_win.shape)
-                    output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
-                    output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
-            output = output[:, 0:target_len]
-        return output
-
-    @torch.no_grad()
-    def preprocess_audio(self, input_audios, threshold=0.8):
-        assert len(input_audios.shape) == 3, input_audios.shape
-        nchan = input_audios.shape[1]
-        input_audios = input_audios.reshape(input_audios.shape[0], -1)
-        norm_value = torch.ones_like(input_audios[:,0])
-        max_volume = input_audios.abs().max(dim=-1)[0]
-        norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
-        return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)
-
-    @torch.no_grad()
-    def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False):
-        codes = self.sound2code(sound)
-        # print(codes.shape)
-        # exit()
-        wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
-        # print(fname, wave.shape)
-        return wave
-
-    def file2code(self, fname):
-        try:
-            orig_samples, fs = torchaudio.load(fname)
-        except:
-            af = AudioFile(fname)
-            orig_samples = af.read()
-            fs = af.samplerate()
-            orig_samples = orig_samples[0]
-        if(fs!=self.sample_rate):
-            orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
-            fs = self.sample_rate
-        if orig_samples.shape[0] == 1:
-            orig_samples = torch.cat([orig_samples, orig_samples], 0)
-        return self.sound2code(orig_samples)
-
-    def file2code_ds(self, fname, ds):
-        try:
-            orig_samples, fs = torchaudio.load(fname)
-        except:
-            af = AudioFile(fname)
-            orig_samples = af.read()
-            fs = af.samplerate()
-            orig_samples = orig_samples[0]
-        if(fs!=self.sample_rate):
-            orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
-            fs = self.sample_rate
-        if orig_samples.shape[0] == 1:
-            orig_samples = torch.cat([orig_samples, orig_samples], 0)
-        return self.sound2code_ds(orig_samples, ds)
diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py
deleted file mode 100644
index f0af158226ee61ac9b3f268cc402779d9ae1ea00..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py
+++ /dev/null
@@ -1,1278 +0,0 @@
-from torch.utils.data import Dataset
-from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List, Union
-from beartype import beartype
-from beartype.door import is_bearable
-import random
-import pandas as pd
-import os
-from torchaudio.functional import resample
-import torch
-import typing as tp
-from pathlib import Path
-import torchaudio as ta
-import torch.nn.functional as F
-import numpy as np
-import json
-import yaml
-import torchaudio
-import math
-import re
-from loguru import logger
-import ffmpeg
-
-class Read_and_PadCrop_Normalized_T(torch.nn.Module):
-    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
-
-        super().__init__()
-
-        self.n_samples = n_samples
-        self.sample_rate = sample_rate
-        self.randomize = randomize
-
-    def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
-        if self.n_samples < 0: #means not clip
-            chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
-            t_start = 0.
-            t_end = 1.0
-            offset = 0
-        else:
-            if(duration<(float(self.n_samples)/self.sample_rate+1)):
-                # print(duration,(float(self.n_samples)/self.sample_rate+1))
-                chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
-                t_start = 0.
-                t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
-                offset = 0
-                # print('c1:',chunk.shape)
-            else:
-                offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
-                t_start = offset / float(cur_sample_rate) / duration
-                t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
-                chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
-                # print('offset:',offset)
-                # print('c0:',chunk.shape)
-        # Pad with silence if necessary.
-        if(chunk.shape[0]>1):
-            chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
-        else:
-            chunk = chunk[[0],:].float()
-        if(cur_sample_rate!=self.sample_rate):
-            # print('a:',cur_sample_rate,chunk.shape)
-            chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
-            # print('b:',self.sample_rate,chunk.shape)
-
-        if self.n_samples > 0:
-            if chunk.shape[-1] < self.n_samples:
-                chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
-            else:
-                chunk = chunk[:,0:self.n_samples]
-        seconds_start = math.floor(offset / cur_sample_rate)
-        seconds_total = math.floor(duration)
-
-        return (
-            chunk,
-            t_start,
-            t_end,
-            seconds_start,
-            seconds_total
-        )
-
-class Read_and_PadCrop_Normalized_T_Avoid_Watermark(torch.nn.Module):
-    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True, w_start = 0, w_interval = 11.3):
-
-        super().__init__()
-
-        self.n_samples = n_samples
-        self.sample_rate = sample_rate
-        self.randomize = randomize
-
-        self.w_start = w_start
-        self.w_interval = w_interval
-
-    def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
-        if self.n_samples < 0: #means not clip
-            chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
-            t_start = 0.
-            t_end = 1.0
-            offset = 0
-        else:
-            if(duration<(float(self.n_samples)/self.sample_rate+1)):
-                # print(duration,(float(self.n_samples)/self.sample_rate+1))
-                chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
-                t_start = 0.
-                t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
-                offset = 0
-                # print('c1:',chunk.shape)
-            else:
-                n_offset_option = (duration - self.w_start) // self.w_interval
-                if n_offset_option <= 1:
-                    offset = 0
-                else:
-                    offset = int((random.randint(0,n_offset_option-1) * self.w_interval + self.w_start) * cur_sample_rate)
-                # offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
-                t_start = offset / float(cur_sample_rate) / duration
-                t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
-                chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
-                # print('offset:',offset)
-                # print('c0:',chunk.shape)
-        # Pad with silence if necessary.
-        if(chunk.shape[0]>1):
-            chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
-        else:
-            chunk = chunk[[0],:].float()
-        if(cur_sample_rate!=self.sample_rate):
-            # print('a:',cur_sample_rate,chunk.shape)
-            chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
-            # print('b:',self.sample_rate,chunk.shape)
-
-        if self.n_samples > 0:
-            if chunk.shape[-1] < self.n_samples:
-                chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
-            else:
-                chunk = chunk[:,0:self.n_samples]
-        seconds_start = math.floor(offset / cur_sample_rate)
-        seconds_total = math.floor(duration)
-
-        return (
-            chunk,
-            t_start,
-            t_end,
-            seconds_start,
-            seconds_total
-        )
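The watermark-avoiding reader above restricts crop starts to the grid w_start + k * w_interval (defaults 0 and 11.3 s), so every crop begins at the same phase of a periodically injected watermark. The offset selection as a standalone sketch (the int cast on the option count is made explicit here; the deleted code kept a float from `//` on floats):

# Sketch only: offset selection from Read_and_PadCrop_Normalized_T_Avoid_Watermark.
import random

def pick_offset(duration_s: float, sr: int, w_start: float = 0.0, w_interval: float = 11.3) -> int:
    n_options = int((duration_s - w_start) // w_interval)
    if n_options <= 1:
        return 0
    k = random.randint(0, n_options - 1)       # choose one grid point at random
    return int((k * w_interval + w_start) * sr)  # convert seconds to samples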
-
-USE_DUMMY_AUDIO = False  # when testing code, set this to True: no real data is read and generated silent audio is used instead
-if USE_DUMMY_AUDIO:
-    logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!")
-
-class SafeAudioReader:
-    """
-    This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data.
-    """
-    def __init__(self,
-                 duration: float,  # length of the returned audio
-                 sample_rate: int,  # sample rate of the returned audio; if it differs from the source rate, resample
-                 randomize: bool = True,
-                 use_avoid_watermark_policy = False,
-                 ):
-        self.n_samples = int(sample_rate * duration)
-        self.reader = (
-            Read_and_PadCrop_Normalized_T_Avoid_Watermark if use_avoid_watermark_policy \
-            else Read_and_PadCrop_Normalized_T
-        )(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
-
-    # NOTE: this is the core function; every dataset reads audio through it!
-    def __call__(self,
-                 filepath: os.PathLike,  # audio path
-                 origin_sample_rate: Optional[int] = None,  # actual sample rate from the json file; if not given, read from the file header
-                 origin_duration: float = None,  # actual duration from the json file; if not given, read from the file header
-                 ) -> torch.Tensor:
-        if USE_DUMMY_AUDIO:
-            wav = torch.zeros(self.n_samples, dtype=torch.float32)
-            return wav
-        try:
-            if origin_sample_rate is None or origin_duration is None:
-                # audio_info = torchaudio.info(filepath)
-                # origin_sample_rate = audio_info.sample_rate
-                # origin_duration = audio_info.num_frames / origin_sample_rate
-                info = ffmpeg.probe(filepath)
-                origin_duration = float(info['format']['duration'])
-                origin_sample_rate = int(info['streams'][0]['sample_rate'])
-            wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate)
-            wav = wav.squeeze_(0)
-        except Exception as e:
-            logger.error(f"Error reading {filepath}: {e}")
-            wav = torch.zeros(self.n_samples, dtype=torch.float32)
-        return wav
-
-
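SafeAudioReader falls back to `ffmpeg.probe` to get duration and sample rate without decoding the file. A minimal sketch of that call; it requires the ffmpeg-python package plus an ffmpeg binary on PATH, and "song.flac" is a placeholder path:

# Sketch only: header probing as done in SafeAudioReader.__call__ above.
import ffmpeg

info = ffmpeg.probe("song.flac")
duration_s = float(info['format']['duration'])
sample_rate = int(info['streams'][0]['sample_rate'])
# streams[0] is assumed to be the audio stream, which holds for audio-only files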
-class PromptTemplate:
-    def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'):
-        self.template_text = template_text
-        self.tag_map = tag_map
-        self.lang = lang
-
-    @property
-    def tags(self):
-        return tuple(self.tag_map.keys())
-
-    def apply(self, **kwargs):
-        for tag in list(kwargs.keys()):
-            if kwargs[tag] == '':
-                kwargs.pop(tag)
-        for tag in self.tags:
-            if tag in kwargs:
-                kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
-            else:
-                kwargs[tag] = ''
-        prompt = self.template_text.format(**kwargs)
-
-        return self.beautify(prompt)
-
-    def beautify(self, text):
-        if self.lang == 'en':
-            return self._beautify_en(text)
-        elif self.lang == 'zh':
-            return self._beautify_zh(text)
-        else:
-            raise ValueError(f'Unknown language {self.lang}')
-
-    @staticmethod
-    def _beautify_en(text):
-        # no continuous commas without content between them
-        text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
-        # no continuous whitespace
-        text = re.sub(r'\s+', ' ', text)
-        # the comma is NOT followed by whitespace, and should be followed by ONE whitespace
-        text = re.sub(r'\s+,', r',', text)
-        text = re.sub(r',\s+', r', ', text)
-        # no whitespace before the full stop
-        text = re.sub(r'\s+\.', r'.', text)
-        # strip whitespace, comma, and replace ',.'
-        text = text.strip(' ,')
-        text = text.replace(',.', '.')
-        return text
-
-    @staticmethod
-    def _beautify_zh(text):
-        # no continuous commas without content between them
-        text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
-        text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
-        # assume there should be NO whitespace in Chinese
-        text = re.sub(r'\s+', r'', text)
-        # strip whitespace, comma, and replace ',。'
-        text = text.strip(', 、')
-        text = text.replace(',。', '。')
-        return text
-
-    def __repr__(self):
-        return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
-
-    __str__ = __repr__
-
-def parse_prompt_template(prompt_template_text, lang='en'):
-    span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
-    tag_pattern = re.compile(r'{.+?}', re.DOTALL)
-
-    template_text = prompt_template_text.strip()
-    span_texts = span_pattern.findall(prompt_template_text)
-    tag_map = {}
-    for span_text in span_texts:
-        tag = tag_pattern.findall(span_text)[0].strip('{}')
-        tag_map[tag] = span_text
-        template_text = template_text.replace(span_text, '{'+tag+'}')
-
-    return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
-
-def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]:
-    with open(path, 'r') as f:
-        lines = f.readlines()
-    cnt = 0
-    pts = []
-    for line in lines:
-        pt = parse_prompt_template(line, lang=lang)
-        cnt += 1
-        if len(pt.tags) < num:
-            logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
-        pts.append(pt)
-
-    return pts
-
-
-def get_base_dir_file(key: os.PathLike):
-    base = os.path.basename(key)
-    dirname = os.path.basename(os.path.dirname(key))
-    return os.path.join(dirname, base)
-
-def read_jsonlike(path: os.PathLike):
-    #json or jsonl
-    if str(path).endswith(".json"):
-        with open(path, 'r', encoding='utf8') as f:
-            data = json.load(f)
-        return data
-    elif str(path).endswith(".jsonl"):
-        with open(path, 'r', encoding='utf8') as f:
-            data = [json.loads(line) for line in f.readlines()]
-        return data
-    else:
-        raise ValueError("Unknown file format")
-
-dist_prob_map = {
-    1: (1.0,),
-    2: (0.5, 0.5),
-    3: (0.3, 0.4, 0.3),
-    4: (0.2, 0.3, 0.3, 0.2),
-    5: (0.2, 0.2, 0.3, 0.2, 0.1),
-    6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15),
-    7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
-    8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
-    9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
-    10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
-}
-
-'''
-# variant biased more toward short text
-dist_prob_map = {
-    1: (1.0,),
-    2: (0.7, 0.3),
-    3: (0.7, 0.2, 0.1),
-    4: (0.6, 0.2, 0.1, 0.1),
-    5: (0.6, 0.2, 0.1, 0.05, 0.05),
-    6: (0.6, 0.15, 0.1, 0.05, 0.05, 0.05),
-    7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1),
-    8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12),
-    9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08),
-    10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09)
-}
-'''
-
-# variant that uses all tags
-# dist_prob_map = {
-#     1: (1.0,),
-#     2: (0, 1.0),
-#     3: (0, 0, 1.0),
-#     4: (0, 0, 0, 1.0),
-#     5: (0, 0, 0, 0, 1.0),
-#     6: (0, 0, 0, 0, 0, 1.0),
-#     7: (0, 0, 0, 0, 0, 0, 1.0),
-#     8: (0, 0, 0, 0, 0, 0, 0, 1.0),
-#     9: (0, 0, 0, 0, 0, 0, 0, 0, 1.0),
-#     10: (0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0)
-# }
-
-dist_prob_map_low = {
-    1: (1.0,),
-    2: (0.8, 0.2),
-    3: (0.8, 0.1, 0.1),
-    4: (0.7, 0.1, 0.1, 0.1),
-    5: (0.7, 0.1, 0.1, 0.05, 0.05),
-    6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05),
-}
-
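How dist_prob_map above is consumed (cf. gen_plain_prompt and tags_to_desc below): for N candidate tags, its N-entry tuple gives the probability of keeping 1..N of them; random.choices draws the count, then a shuffle picks which tags survive. A self-contained sketch with the 4-tag row inlined:

# Sketch only: weighted tag-count sampling as used by this module.
import random

dist_prob_map_4 = (0.2, 0.3, 0.3, 0.2)  # the 4-tag row of dist_prob_map above
tags = ['rock', 'guitar', 'energetic', 'male vocal']
num = random.choices(range(1, len(tags) + 1), weights=dist_prob_map_4)[0]
random.shuffle(tags)
print(', '.join(tags[:num]))  # e.g. "guitar, rock"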
'60-66'), - (76, '66-76'), - (108, '76-108'), - (120, '108-120'), - (168, '120-168'), - (176, '168-176'), - (200, '176-200') -) -_bpm_desc_map = { - '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"), - '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"), - '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'), - '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'), - '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"), - '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'), - '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'), - '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'), - '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'), - '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo') -} -_bpm_desc_map_zh = { - '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"), - '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"), - '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'), - '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'), - '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"), - '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'), - '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'), - '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'), - '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'), - '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板') -} -def get_bpm_range(bpm): - bpm = int(bpm) - for right, tag in _bpm_range_rights: - if bpm <= right: - return tag - return '>200' - -def gen_bpm_descript(bpm, lang='en'): - bpm_range = get_bpm_range(bpm) - if lang == 'en': - return random.choice(_bpm_desc_map[bpm_range]) - elif lang == 'zh': - return random.choice(_bpm_desc_map_zh[bpm_range]) - else: - raise ValueError(f"Unknown language {lang}") - -def read_translate(translate: Union[Dict[str, os.PathLike], os.PathLike, None]): - if translate is None: - return None - if isinstance(translate, str): - return read_jsonlike(translate) - return {k: read_jsonlike(path) for k, path in translate.items()} - - -def gen_plain_prompt(key_list, sep=', '): - if len(key_list) == 0: - return 'none' - - key_list = [k.strip() for k in key_list] - - if len(key_list) > 10: - random.shuffle(key_list) - key_list = key_list[:10] - - probs = dist_prob_map[len(key_list)] - - num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0] - - random.shuffle(key_list) - tags = key_list[:num_tags] - tags_str = sep.join(tags) - return tags_str - - -class MagnaTagATuneDataset(Dataset): - def __init__(self): - pass - - -def tags_to_desc(tag_list, sep=',') -> str: - if not isinstance(tag_list, Sequence): - return str(tag_list) - if isinstance(tag_list, str): - return tag_list - if len(tag_list) <= 0: - return '' - elif len(tag_list) <= 5: - probs = dist_prob_map[len(tag_list)] - tags_num = random.choices(range(1, len(tag_list)+1), probs)[0] - random.shuffle(tag_list) - tag_list = 
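# Quick illustration of the BPM bucketing and weighted tag sampling above,
# assuming get_bpm_range / gen_bpm_descript / gen_plain_prompt are in scope;
# both helpers draw randomly, so the commented outputs are examples only.
print(get_bpm_range(128))        # '120-168' (first bucket whose right edge >= bpm)
print(gen_bpm_descript(128))     # e.g. 'brisk pace' or 'Allegro'
# gen_plain_prompt keeps at most 10 tags and samples how many to join using the
# dist_prob_map weights, biasing prompts toward mid-sized tag subsets:
print(gen_plain_prompt(['rock', 'guitar', 'energetic']))  # e.g. 'guitar, rock'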
tag_list[:tags_num] - return sep.join(tag_list) - else: - probs = dist_prob_map[5] - tags_num = random.choices(range(1, 6), probs)[0] - random.shuffle(tag_list) - tag_list = tag_list[:tags_num] - return sep.join(tag_list) - -def get_sr_and_duration_info(item): - return item.get('sample_rate', None), item.get('duration', None) - -class MtgJamendoDatasetFromJson(Dataset): - def __init__(self, - data_dir:str, - json_path:str, - duration:float=10, - sr:int = 0, - lang = 'en', - plain_rate = 0, - return_audio = True, - return_path = False, - prompt_template_path: os.PathLike = None, - tag_types = [], - translate:Optional[Dict[str, os.PathLike]] = None, - use_literal_none = True, - ): - self.audio_reader = SafeAudioReader(duration, sr) - - self.data_dir = data_dir - self._load_metadata_json(json_path) - self.sr = sr - self.duration = duration - self.plain_rate = plain_rate - self.return_audio = return_audio - self.return_path = return_path - self.use_literal_none = use_literal_none - self.lang = lang - - self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0 - if self.use_dynamic_prompt: - self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types)) - self.tag_types = tag_types - - self.translate = read_translate(translate) - - #这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示 - WEAK_TAG_LIST = ["title", "artist"] - - def _load_metadata_json(self, json_path): - with open(json_path) as fp: - self.data = json.load(fp) - - def convert_key_to_path(self, key): - return os.path.join(self.data_dir, get_base_dir_file(key)) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - path = self.convert_key_to_path(item['key']) - description = self.generate_description(item) - - if self.return_audio: - sr, duration = get_sr_and_duration_info(item) - audio = self.audio_reader(path, sr, duration) - else: - audio = None - - if self.return_path: - return audio, description, path - return audio, description - - def tags_to_desc(self, tag_list, tag_type) -> str: - if self.lang == 'en': - return tags_to_desc(tag_list) - elif self.lang == 'zh': - translator = self.translate[tag_type] - translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] - return tags_to_desc(translated_tag_list, sep='、') - - def generate_description(self, item): - if random.random() > self.plain_rate: - # dynamically generate prompt from given prompt template - prompt_template = random.choice(self.prompt_templates) - description = self.generate_description_dynamic(item, prompt_template) - else: - # use plain prompt, i.e. 
tags sequence separated by comma - description = self.generate_description_plain(item) - return description - - def generate_description_dynamic(self, data, prompt_template: PromptTemplate): - exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] - exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag)) - exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag)) - - if len(exists_strong_tag) > 0: - probs = dist_prob_map[len(exists_strong_tag)] - tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0] - random.shuffle(exists_strong_tag) - tags = exists_strong_tag[:tags_num] - weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1] - weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0] - random.shuffle(exists_weak_tag) - weak_tags = exists_weak_tag[:weak_tags_num] - tags += weak_tags - tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} - prompt = prompt_template.apply(**tags_args) - else: - # no strong tags, use all weak tags instead - tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag} - prompt = prompt_template.apply(**tags_args) - - if self.use_literal_none and len(tags_args) == 0: - return 'none' - - return prompt - - def generate_description_plain(self, item): - keywords = [] - for tag_t in self.tag_types: - this_key = item[tag_t] - if this_key is None: - continue - if isinstance(this_key, str): - this_key = [this_key] - if self.lang != 'en': - this_key = [self.get_translation(tag_t, k) for k in this_key] - keywords += this_key - return gen_plain_prompt(keywords, sep=self.keysep) - - def get_translation(self, tag_t, k): - k = k.strip() - if k in self.translate[tag_t]: - return self.translate[tag_t][k] - else: - return k - - @property - def keysep(self): - if self.lang == 'zh': - return ',' if random.random() > 0.5 else '、' - elif self.lang == 'en': - return ', ' - -class AudioStockDataset(Dataset): - def __init__(self, - metadata_path:str, - duration:float=10, - sr:int = 0, - plain_rate = 0, - return_path = False, - return_audio = True, - prompt_template_path: os.PathLike = None, - tag_types = [], - lang = 'en', - translate:Optional[Dict[str, os.PathLike]] = None, - use_literal_none = True, - ): - self.audio_reader = SafeAudioReader(duration, sr) - - self._load_metadata(metadata_path) - self.sr = sr - self.duration = duration - self.plain_rate = plain_rate - self.return_path = return_path - self.return_audio = return_audio - self.use_literal_none = use_literal_none - - self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0 - if self.use_dynamic_prompt: - self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang) - self.tag_types = tag_types - - self.lang = lang - self.translate = read_translate(translate) - - def _load_metadata(self, metadata_path): - with open(metadata_path) as fp: - lines = fp.readlines() - self.data = [] - for line in lines: - item = json.loads(line) - self.data.append(item) - self.is_info_recorded = bool('Tags' in self.data[0]) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - path:str = self.data[idx]["path"] - json_path = path[:path.rfind('.')] + ".json" - if self.is_info_recorded: - item = self.data[idx] - else: - try: - with open(json_path) as fp: - item:dict = json.load(fp) - except Exception as e: - print(f"Error loading json file {json_path} :\n{e}") - item = {} - 
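# Sketch of the weak/strong split in generate_description_dynamic above: weak
# tags ('title', 'artist') only pad a prompt that already has at least one
# strong tag (genre, moods, ...); items with no strong tags fall back to weak
# tags alone. The item dict is illustrative.
item = {'title': 'Foo', 'artist': 'Bar', 'genre': ['jazz'], 'moods': []}
WEAK = ('title', 'artist')
exists = [k for k, v in item.items() if v]
strong = [k for k in exists if k not in WEAK]
weak = [k for k in exists if k in WEAK]
print(strong, weak)  # ['genre'] ['title', 'artist']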
description = self.generate_description(item) - if self.return_audio: - sr, duration = get_sr_and_duration_info(item) - audio = self.audio_reader(path, sr, duration) - else: - audio = None - if self.return_path: - return audio, description, path - return audio, description - - def generate_description(self, item): - if random.random() > self.plain_rate: - # dynamically generate prompt from given prompt template - prompt_template = random.choice(self.prompt_templates) - description = self.generate_description_dynamic(item, prompt_template) - else: - # use plain prompt, i.e. tags sequence separated by comma - description = self.generate_description_plain(item) - return description - - def generate_description_dynamic(self, data, prompt_template: PromptTemplate): - exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] - - if len(exists_tag) > 0: - probs = dist_prob_map[len(exists_tag)] - tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0] - random.shuffle(exists_tag) - tags = exists_tag[:tags_num] - tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} - tags_args = self.handle_BPM_tag(tags_args) - prompt = prompt_template.apply(**tags_args) - else: - return 'none' - - if self.use_literal_none and len(tags_args) == 0: - return 'none' - - return prompt - - def get_translation(self, tag_t, k): - k = k.strip() - if k in self.translate[tag_t]: - return self.translate[tag_t][k] - else: - return k - - def generate_description_plain(self, item): - keywords = [] - for tag_t in self.tag_types: - if tag_t == 'BPMDescript': - bpm = item['BPM'] - if bpm is None or bpm.strip() == '' or bpm.strip() == '0': - continue - this_key = gen_bpm_descript(bpm.strip(), lang=self.lang) - elif tag_t == 'BPM': - bpm = item['BPM'] - if bpm is None or bpm.strip() == '' or bpm.strip() == '0': - continue - this_key = f"{bpm.strip()} bpm" - else: - this_key = item[tag_t] - if this_key is None: - continue - if isinstance(this_key, str): - this_key = [this_key] - if self.lang != 'en': - this_key = [self.get_translation(tag_t, k) for k in this_key] - if this_key is None: - continue - if isinstance(this_key, str): - this_key = [this_key] - keywords += this_key - return gen_plain_prompt(keywords, sep=self.keysep) - - @property - def keysep(self): - if self.lang == 'zh': - return ',' if random.random() > 0.5 else '、' - elif self.lang == 'en': - return ', ' - - def tags_to_desc(self, tag_list, tag_type) -> str: - if self.lang == 'en': - return tags_to_desc(tag_list) - elif self.lang == 'zh': - if tag_type == 'BPM': - return tags_to_desc(tag_list, sep='、') - translator = self.translate[tag_type] - translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] - return tags_to_desc(translated_tag_list, sep='、') - - def handle_BPM_tag(self, tags_args): - if "BPM" in tags_args and 'BPMDescript' in self.tag_types: - bpm = tags_args["BPM"] - del tags_args["BPM"] - tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript'))) - for tag_type in tag_types_used: - tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang) - return tags_args - -def mp3_path_to_id(mp3_path): - return int( - mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.')] - ) - -class TmeDataset(Dataset): - def __init__(self, - data_index:str, - music_info:str = None, - duration:float = 10, - sr:int = 0, - plain_rate = 0, - return_path = False, - return_audio = True, - return_ID = False, - 
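# handle_BPM_tag above randomly renders tempo as the raw BPM, a tempo phrase,
# or both; a rough standalone equivalent, assuming gen_bpm_descript is in scope
# (the real method rewrites the tags_args dict that feeds the template):
import random

def render_bpm(bpm, lang='en'):
    out = {}
    for t in random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript'))):
        out[t] = bpm if t == 'BPM' else gen_bpm_descript(bpm, lang=lang)
    return out

print(render_bpm(96))  # e.g. {'BPMDescript': 'walking pace'} or {'BPM': 96, ...}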
prompt_format_path: os.PathLike = None, - tag_types = ['*'], - lang = 'zh', - translate: Optional[os.PathLike] = None, - prompt_dir: os.PathLike = None, #使用GPT生成的预有的prompt - ): - if plain_rate > 0: - print("Tme Dataset do not support plain rate > 0, use plain_rate = 0 instead.") - plain_rate = 0 - self.audio_reader = SafeAudioReader(duration, sr) - - self.sr = sr - self.duration = duration - self.plain_rate = plain_rate - self.return_path = return_path - self.return_audio = return_audio - self.return_ID = return_ID - self.lang = lang - - self.use_ready_prompt = prompt_dir is not None - - data_index = read_jsonlike(data_index) - self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index} - self.data_ids = list(self.data_index_dict.keys()) - - if not self.use_ready_prompt: - #读取音乐的信息文件 - music_info = read_jsonlike(music_info) - if 'music' in music_info: - music_info = music_info['music'] - self.music_info_dict = {d["歌曲ID"]:d for d in music_info} - self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict} - self.data_ids = list(self.data_index_dict.keys()) - - with open(prompt_format_path) as fp: - self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader) - - #加载tag types,并分成一般的tag_types和关键的key_tag_types - if '*' in tag_types: - self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag'] - else: - self.tag_types = tag_types - - self.key_tag_types = [] - if 'tag' in self.tag_types: - self.tag_types.remove('tag') - self.key_tag_types = list(self.prompt_formats['tag'].keys()) - - #加载translate翻译 - if translate is not None: - self.translator = read_jsonlike(translate) - else: - data_ids_set = set(self.data_ids) - self.prompts_dict = {} - for fname in os.listdir(prompt_dir): - items = read_jsonlike(os.path.join(prompt_dir, fname)) - for item in items: - if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']): - continue - if item['ID'] not in self.prompts_dict: - self.prompts_dict[item['ID']] = [] - self.prompts_dict[item['ID']].append(item['Text']) - self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict} - self.data_ids = list(self.data_index_dict.keys()) - - def tags_to_desc(self, tag_list) -> str: - if is_bearable(tag_list, int): - return str(tag_list) - if self.lang == 'zh': - return tags_to_desc(tag_list, sep=self.sep) - else: - translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ] - return tags_to_desc(translated_tag_list, sep=self.sep) - - def gen_desc_of_tag(self, formats, tags): - fmt = random.choice(formats) - return fmt.format(self.tags_to_desc(tags)) - - @staticmethod - def check_valid(value): - if isinstance(value, int) or isinstance(value, float): - return value > 0 - if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0): - return True - return False - - @staticmethod - def remove_repeat(data): - #若专辑名和歌曲名相同,则只使用后者 - album_name = data.get('专辑名', None) - if album_name is not None and album_name == data.get('歌曲名', None): - del data['专辑名'] - return data - - @property - def comma(self): - if self.lang == 'zh': - return ',' - elif self.lang == 'en': - return ', ' - - @property - def sep(self): - if self.lang == 'zh': - return '、' - elif self.lang == 'en': - return ', ' - - - def generate_description(self, item): - if random.random() > self.plain_rate: - # dynamically generate prompt from given prompt template - description = self.generate_description_dynamic(item) - else: - # use plain prompt, i.e. 
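# gen_desc_of_tag above picks one format string per tag type from the YAML at
# prompt_format_path and fills it with the joined tag values; the formats below
# are invented placeholders, not the repo's actual YAML content.
import random

prompt_formats = {'歌手名': ['performed by {}', 'a song by {}'],
                  'bpm': ['around {} beats per minute']}
fmt = random.choice(prompt_formats['歌手名'])
print(fmt.format('Singer A'))            # e.g. 'a song by Singer A'
print(prompt_formats['bpm'][0].format(96))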
tags sequence separated by comma - description = self.generate_description_plain(item) - return description - - def generate_description_dynamic(self, data): - data = self.remove_repeat(data) - - weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低 - - key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个 - - prompts = [] - if len(weak_tags) > 0: - probs = dist_prob_map_low[len(weak_tags)] - if len(key_tags) > 0: - tags_num = random.choices(range(0, len(weak_tags)), probs)[0] - else: - tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0] - random.shuffle(weak_tags) - tags = weak_tags[:tags_num] - for tag_type in tags: - tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type]) - prompts.append(tag_desc) - - if len(key_tags) > 0: - probs = dist_prob_map[len(key_tags)] - tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0] - random.shuffle(key_tags) - tags = key_tags[:tags_num] - for tag_type in tags: - tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type]) - prompts.append(tag_desc) - - random.shuffle(prompts) - return self.comma.join(prompts) - - def generate_description_plain(self, item): - keywords = item['tag'] - if self.lang != 'en': - keywords = [self.translator[k.strip()] for k in keywords] - return gen_plain_prompt(keywords, sep=self.keysep) - - @property - def keysep(self): - if self.lang == 'zh': - return ',' if random.random() > 0.5 else '、' - elif self.lang == 'en': - return ', ' - - def is_valid_prompt_text(self, text): - for bad in ('抱歉','sorry', 'Sorry'): - if bad in text: - return False - return True - - def get_ready_prompt(self, path): - sid = mp3_path_to_id(path) - return random.choice(self.prompts_dict[sid]) - - def __len__(self): - return len(self.data_ids) - - def __getitem__(self, idx): - data_id = self.data_ids[idx] - item = self.data_index_dict[data_id] - path = item['path'] - if not self.use_ready_prompt: - info = self.music_info_dict[data_id] - description = self.generate_description(info) - else: - description = self.get_ready_prompt(path) - if self.return_audio: - sr, duration = get_sr_and_duration_info(item) - audio = self.audio_reader(path, sr, duration) - else: - audio = None - if self.return_path: - if self.return_ID: - return audio, description, path, info['歌曲ID'] - return audio, description, path - if self.return_ID: - return audio, description, info['歌曲ID'] - return audio, description - - -class Pond5Dataset(Dataset): - MAX_PROMPT_LEN = 200 - def __init__(self, - metadata_path:str, - index_path:str, - duration:float=10, - sr:int = 0, - plain_rate = 0, - return_path = False, - return_audio = True, - lang = 'en', - translate:Optional[Dict[str, os.PathLike]] = None, - use_literal_none = True, - use_avoid_watermark_policy = None, - ): - - if use_avoid_watermark_policy is None: - raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type") - self.use_avoid_watermark_policy = use_avoid_watermark_policy - self.audio_reader = SafeAudioReader(duration, sr, use_avoid_watermark_policy=use_avoid_watermark_policy) - - self._load_metadata(metadata_path, index_path) - self.sr = sr - self.duration = duration - self.plain_rate = plain_rate - self.return_path = return_path - self.return_audio = return_audio - self.use_literal_none = 
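# The ready-prompt path above (prompt_dir) keeps only pre-generated captions
# that pass is_valid_prompt_text, i.e. contain no refusal/apology marker:
texts = ['A bright pop track with driving drums.', '抱歉,我无法提供帮助。', 'Sorry, I cannot.']
valid = [t for t in texts if not any(bad in t for bad in ('抱歉', 'sorry', 'Sorry'))]
print(valid)  # ['A bright pop track with driving drums.']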
use_literal_none - - self.lang = lang - self.translate = read_translate(translate) - - def _load_metadata(self, metadata_path, index_path): - data_index = read_jsonlike(index_path) - data_ids = set([item['id'] for item in data_index]) - - with open(metadata_path) as fp: - lines = fp.readlines() - - append_ids = set() - - self.data = [] - for line in lines: - item = json.loads(line) - if item['id'] in data_ids and item['id'] not in append_ids: - self.data.append(item) - append_ids.add(item['id']) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - path:str = item["path"] - description = self.generate_description(item) - if self.return_audio: - sr, duration = get_sr_and_duration_info(item) - audio = self.audio_reader(path, sr, duration) - else: - audio = None - if self.return_path: - return audio, description, path - return audio, description - - @property - def keysep(self): - if self.lang == 'zh': - return ',' if random.random() > 0.5 else '、' - elif self.lang == 'en': - return ', ' - - def generate_description(self, item): - if random.random() > self.plain_rate: - # dynamically generate prompt from given prompt template - description = self.generate_description_dynamic(item) - else: - # use plain prompt, i.e. tags sequence separated by comma - description = self.generate_description_plain(item) - return description - - def get_translation(self, k): - k = k.strip() - if k in self.translate: - return self.translate[k] - else: - return k - - def generate_description_plain(self, item): - keywords = item['keywords'] - if self.lang != 'en': - keywords = [self.get_translation(k) for k in keywords] - return gen_plain_prompt(keywords, sep=self.keysep) - - def generate_description_dynamic(self,item): - desc = item.get('desc', 'none') - if desc is None: - desc = 'none' - desc = desc.strip() - if len(desc) > self.MAX_PROMPT_LEN: - shorter_desc = desc[:self.MAX_PROMPT_LEN] - # find last stop - stop_idx = shorter_desc.rfind('.') - if stop_idx == -1: - stop_idx = shorter_desc.rfind('!') - if stop_idx == -1: - stop_idx = shorter_desc.rfind(',') - if stop_idx == -1: - stop_idx = self.MAX_PROMPT_LEN - 1 - desc = desc[:stop_idx+1] - return desc - -class SoundDataset(Dataset): - def __init__(self, - metadata_index: str, - duration:float = 10, - min_non_silent_duration:float = 3, - sr:int = 0, - return_path = False, - return_audio = True, - ): - self.data = read_jsonlike(metadata_index) - self.sr = sr - self.reader = SafeAudioReader(duration, sr) - self.duration = duration - self.min_non_silent_duration = min_non_silent_duration - self.return_audio = return_audio - self.return_path = return_path - - def __getitem__(self, index): - item = self.data[index] - if self.return_audio: - origin_duration = item['duration'] - if origin_duration < self.min_non_silent_duration: - audio = self.read_and_repeat_and_pad(item) - else: - audio = self.reader(item['path'], item['sample_rate'], origin_duration) - else: - audio = None - desc = item['caption'] - if self.return_path: - return audio, desc, item['path'] - else: - return audio, desc - - def __len__(self): - return len(self.data) - - def read_and_repeat_and_pad(self, item): - path = item['path'] - try: - # read - clip, sr = torchaudio.load(path) - if len(clip.shape) > 1: - clip = torch.mean(clip, dim=0, keepdim=True) - clip = resample(clip, sr, self.sr) - #repeat - n_repeats = math.ceil(self.min_non_silent_duration/item['duration']) - clip = torch.repeat_interleave(clip, n_repeats, dim=0).reshape(-1) - #pad - 
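# Sketch of the MAX_PROMPT_LEN clipping in Pond5's generate_description_dynamic
# above: cut to 200 chars, then back off to the last '.', '!' or ',' (in that
# order of preference) so the caption ends on a sentence or clause boundary.
MAX_PROMPT_LEN = 200
desc = 'x' * 150 + '. ' + 'y' * 100          # 252 chars with one sentence stop
short = desc[:MAX_PROMPT_LEN]
stop = short.rfind('.')
if stop == -1:
    stop = short.rfind('!')
if stop == -1:
    stop = short.rfind(',')
if stop == -1:
    stop = MAX_PROMPT_LEN - 1
print(len(desc[:stop + 1]))                  # 151 -> clipped right after the '.'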
n_samples = int(self.duration * self.sr) - if clip.shape[0] >= n_samples: - audio = clip[:n_samples] - else: - audio = torch.zeros(int(self.duration * self.sr), dtype=clip.dtype) - start_pos = np.random.randint(0, max(0,(n_samples - clip.shape[0]))) - audio[start_pos:start_pos+clip.shape[0]] = clip - return audio - - except Exception as e: - logger.error(f"Error reading {path}: {e}") - wav = torch.zeros(int(self.duration * self.sr), dtype=torch.float32) - return wav - -class CombinedDataset(Dataset): - @beartype - def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]): - self.datasets = datasets - self.datasets_index = [] - - for i,dataset in enumerate(datasets): - if dataset is None: - continue - for dup in range(ratios[i]): - for j in range(len(dataset)): - self.datasets_index.append((i,j)) - - def __len__(self): - return len(self.datasets_index) - - def __getitem__(self, idx): - index = self.datasets_index[idx] - i,j = index - return self.datasets[i][j] - -class CombinedDataset_random(Dataset): - @beartype - def __init__(self, num_examples:int, datasets: Sequence[Dataset], ratios: Sequence[int]): - self.datasets = datasets - self.datasets_index = [] - - for i,dataset in enumerate(datasets): - if dataset is None: - continue - for dup in range(ratios[i]): - for j in range(len(dataset)): - self.datasets_index.append((i,j)) - - if num_examples > 0: - self.random_choose = True - self.dataset_len = num_examples - else: - self.random_choose = False - self.dataset_len = len(self.datasets_index) - - def __len__(self): - return self.dataset_len - - def __getitem__(self, idx): - first_try = True - try_cnt = 0 - while True: - try: - if(self.random_choose or not first_try): - index2 = [] - index2.append(np.random.randint(0,len(self.datasets))) - index2.append(np.random.randint(0,len(self.datasets[index2[-1]]))) - else: - index2 = self.datasets_index[idx] - first_try = False - out = list(self.datasets[index2[0]][index2[1]]) - return out - except: - print("Error loadding ", index2) - try_cnt += 1 - if(try_cnt>10): - raise ValueError() - -class SoundMixedDataset(Dataset): - @staticmethod - def music_desc(desc): - return f'Music:<{desc}>' - @staticmethod - def sound_desc(desc): - return f'Effect:<{desc}>' - - def __init__(self, - music_dataset: Dataset, - sound_dataset: Dataset, - mixed_ratios: Tuple[float, float, float] = (0.3, 0.3, 0.4) # 只有音乐:只有音效:音乐音效混合 的比例 - ) -> None: - self.music_dataset = music_dataset - self.sound_dataset = sound_dataset - music_r, sound_r, mix_r = [r/sum(mixed_ratios) for r in mixed_ratios] #化为0-1间的比例 - #三个概率区间的左端点 - self.music_anchor = 0 - self.sound_anchor = music_r - self.mix_anchor = music_r + sound_r - - def __len__(self): - return len(self.music_dataset) - - def get_random_sound_data(self): - idx = random.randint(0, len(self.sound_dataset)-1) - return self.sound_dataset[idx] - - def __getitem__(self, idx): - p = random.random() - if p >= self.mix_anchor: - music, m_desc = self.music_dataset[idx] - sound, s_desc = self.get_random_sound_data() - audio = music + sound - if(audio.abs().max()>1.0): - music = music / audio.abs().max() * 0.95 - audio = audio / audio.abs().max() * 0.95 - desc = self.music_desc(m_desc) + self.sound_desc(s_desc) - return audio[None,:], music[None,:], desc - elif p >= self.sound_anchor: - audio, desc = self.get_random_sound_data() - return audio[None,:], torch.zeros_like(audio[None,:]), self.sound_desc(desc) - else: - audio, desc = self.music_dataset[idx] - return audio[None,:], audio[None,:], self.music_desc(desc) - - -class 
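# CombinedDataset above oversamples by materialising (dataset_idx, item_idx)
# pairs ratios[i] times each; a tiny illustration with list stand-ins for the
# Dataset objects:
datasets, ratios = [['a0', 'a1'], ['b0']], [1, 2]
index = []
for i, ds in enumerate(datasets):
    for _ in range(ratios[i]):
        index.extend((i, j) for j in range(len(ds)))
print(index)  # [(0, 0), (0, 1), (1, 0), (1, 0)] -> 'b0' drawn twice as often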
DecoTagDataset(Dataset): - '''这个类把普通的datatset包装成适用于标签解耦学习的dataset''' - - TAG_TYPES = ('genre', 'mood', 'insrument') - - def __init__(self, dataset_class: type, tag_map: Dict[str, str], *args, **kwargs): - self.datasets = [] - for i, tag_t in enumerate(self.TAG_TYPES): - kwargs['tag_types'] = [tag_map[tag_t]] - kwargs['return_audio'] = (i == 0) #只有第0个需要返回音频和文本,其余只需要返回文本 - self.datasets.append(dataset_class(*args, **kwargs)) - - def __len__(self): - return len(self.datasets[0]) - - def __getitem__(self, idx): - audio, text = self.datasets[0][idx] - texts = (text, self.datasets[1][idx][1], self.datasets[2][idx][1]) - return audio, texts - - -class DecoTagWrapper: - '''这是一个包装器,便于选择是否使用标签解耦学习''' - def __init__(self, dataset_class: Dataset, deco_tag_types: List[str] = list(), switch_on: bool = False): - self.dataset_class = dataset_class - self.tag_map = dict(zip(DecoTagDataset.TAG_TYPES, deco_tag_types)) - self.switch_on = switch_on - - def __call__(self, *args, **kwargs): - if self.switch_on: - return DecoTagDataset(self.dataset_class, self.tag_map, *args, **kwargs) - else: - return self.dataset_class(*args, **kwargs) diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py deleted file mode 100644 index 2ad93e9eaf529d30df7db3973d0c9822857df1a2..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py +++ /dev/null @@ -1,372 +0,0 @@ -import re -import sys -import json -from typing import List, Union - -from torch.utils.data import Dataset -import torchaudio -from torchaudio.functional import resample -import torch -import numpy as np - -from torch.nn.utils.rnn import pad_sequence - -PARAGRAPH_GAP = 6 -MIN_MUSIC_LEN = 3 - -def check_lryics(lyric): - _FILTER_STRING = [ - '作词', '作曲', '编曲', '【', '策划', - '录音', '混音', '母带', ':', '制作', - '版权', '校对', '演奏', '制作', '伴奏' - ] - for item in _FILTER_STRING: - if item in lyric: - return True - - return False - - - -def process_lyrics(lines): - lyric_part = [] - timestamp_part = [] - - timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]') - - for i, line in enumerate(lines): - - # 删除前几行的特定信息 - if i<10 and check_lryics(line): - continue - - # 检查是否包含有效的时间戳和歌词内容 - if timestamp_pattern.match(line): - timestamp_end = line.rfind(']') - lyrics = line[timestamp_end + 1:].strip() - timestamps = line[:timestamp_end + 1] - - if ':' in lyrics: - if len(lyrics.split(":")[0]) <=5: - lyrics = "".join(lyrics.split(":")[1:]) - # if lyrics: # 确保歌词部分不是空的 - # lyric_part.append(lyrics) - # timestamp_part.append(timestamps) - # print(processed_lyrics) - return timestamp_part, lyric_part - -def get_timestamps(timestamp_part): - - # 转换为秒 - - timestamps = [] - - for line in timestamp_part: - match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line) - if match: - minutes = int(match.group(1)) - seconds = float(match.group(2)) - millis = float(match.group(3)) if match.group(3) else 0 - total_seconds = minutes * 60 + seconds + millis - timestamps.append(total_seconds) - - - return timestamps - -def process_lyrics_lrc(lyrics): - timestamp_part, lyric_part = process_lyrics(lyrics) - # print(timestamp_part) - # print(lyric_part) - timestamps = get_timestamps(timestamp_part) - # print(timestamps) - if len(timestamps) == 0: - # print(f'{lyric_path}') - return [] - - slice_start = timestamps[0] - slice_start_idx = 0 - - output_list = [] - for i in range(1, len(timestamps)): - # 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉 - if timestamps[i] - slice_start > 30: - 
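# get_timestamps above converts each '[mm:ss.xx]' stamp to absolute seconds
# before the 30-second slicing; for example:
import re

m = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', '[01:23.45]')
seconds = int(m.group(1)) * 60 + float(m.group(2)) + (float(m.group(3)) if m.group(3) else 0)
print(seconds)  # 83.45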
output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) - - slice_start = timestamps[i] - slice_start_idx = i - - return output_list - - - -def process_lyrics_yrc(lyrics): - - timestamps, lyric_part = extract_lrc(lyrics) - - # timestamp_part, lyric_part = process_lyrics(lyrics) - # import pdb; pdb.set_trace() - # print(timestamp_part) - # print(lyric_part) - # timestamps = get_timestamps(timestamp_part) - # print(timestamps) - if len(timestamps) == 0: - # print(f'{lyric_path}') - return [] - - slice_start = timestamps[0] - slice_start_idx = 0 - - output_list = [] - for i in range(1, len(timestamps)): - # 如果累积时间超过30秒,则进行切分 - if timestamps[i] - slice_start > 30: - output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) - - slice_start = timestamps[i] - slice_start_idx = i - # import pdb; pdb.set_trace() - return output_list - -def extract_lrc(lyrics): - timestamp_part, lyric_part = [], [] - - for i, text in enumerate(lyrics): - # 提取中括号内的内容 - bracket_content = re.search(r'\[(.*?)\]', text).group(1) - bracket_content = bracket_content.split(',') - # 提取小括号内的内容 - parentheses_content = re.findall(r'\((.*?)\)', text) - # 提取其他内容 - other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip() - - # 数据怎么处理? - if i<10 and check_lryics(other_content): - continue - timestamp_part.append(float(bracket_content[0])/1000) - lyric_part.append(other_content) - return timestamp_part, lyric_part - - - -class WYYSongDataset(Dataset): - def __init__(self, - metadata_path: Union[str, List[str]], - sr:int = 0, - use_lang = ['en', 'zh-cn'], - num_examples = -1, - max_dur = 20, - min_dur=0, - add_music=False, - pad_to_max= True, - ): - - self.sr = sr - self.use_lang = use_lang - self.data = [] - if type(metadata_path) == str: - metadata_path = [metadata_path] - for _meta in metadata_path: - self._load_metadata(_meta) - self.max_dur = max_dur - self.min_dur = min_dur - self.pad_to_max = pad_to_max - self.add_music = add_music - - # buffer - self.lyric_buffer = {} - - if(num_examples<=0): - self.dataset_len = len(self.data) - self.random_slc = False - else: - self.dataset_len = num_examples - self.random_slc = True - - - # 读取jsonl文件 - def _load_metadata(self, metadata_path): - with open(metadata_path) as fp: - lines = fp.readlines() - for line in lines: - item = json.loads(line) - if '伴奏' not in item['path']: - # if "lang_type" in item and item['lang_type'] == 'en': - if "lang_type" in item: - self.data.append(item) - - - def __len__(self): - return self.dataset_len - - - def __getitem__(self, idx): - try_cnt = 0 - while True: - if(self.random_slc): - idx = np.random.randint(0, len(self.data)) - yrc_lyrics = [] - lrc_lyrics = [] - try: - info = self.data[idx] - - # audio path - path = info["path"] - lang_type = info["lang_type"] - lyrics = info['lyrics'] # chinese - # lyrics = info['lyrics_phone'] - - # 随机选取一个lyric段落 - - parsed_lyrics = [] - # st_idx = np.random.randint(0, len(lyrics)) - for ly_id in range(len(lyrics)): - lyric = lyrics[ly_id].strip() - st, et, lyric = self.parse_lyric(lyric) - - if et - st >= self.max_dur: - continue #TODO 前后外沿 [MUSIC] - - if parsed_lyrics != []: - if st - parsed_lyrics[-1][1] >= PARAGRAPH_GAP: # 大gap - parsed_lyrics.append((parsed_lyrics[-1][1], st, '[GAP]')) - elif self.add_music and st - parsed_lyrics[-1][1] >= MIN_MUSIC_LEN: - parsed_lyrics.append((parsed_lyrics[-1][1], st, '[MUSIC]')) - - lyric = lyric.replace("\xa0", " ") - lyric = " ".join(lyric.split()) - 
parsed_lyrics.append((st, et, lyric)) - - assert parsed_lyrics != [] - # if parsed_lyrics[-1][1] - parsed_lyrics[0][0] > self.max_dur: - # print(f"{parsed_lyrics[0][0]}-{parsed_lyrics[-1][1]} {parsed_lyrics}", file=open('tmp.txt', 'a')) - - parsed_lyrics = [(0, parsed_lyrics[0][0], '[GAP]')] + parsed_lyrics - - possible_starts = [e for e,i in enumerate(parsed_lyrics) if i[2]=='[GAP]'] - st_idx = np.random.choice(possible_starts) - - paraphrase = [] - for i in parsed_lyrics[st_idx+1:]: - if i[2] == '[GAP]': - break - paraphrase.append(i) - # print(paraphrase, lyrics) - - while paraphrase[-1][1] - paraphrase[0][0] > self.max_dur: - if np.random.rand() > 0.2: - paraphrase.pop(-1) # 大概率从后面截断 - else: - paraphrase.pop(0) # 小概率截前面 - - st, et, lyric = paraphrase[0][0], paraphrase[-1][1], ', '.join([i[2] for i in paraphrase]) # [SEP] - # print(st, et, lyric) - # import pdb; pdb.set_trace() - assert self.min_dur < et - st < self.max_dur, f"{st}-{et} {lyric}" - # print(et-st, lyric) - # import pdb; pdb.set_trace() - - if info["lang_type"] == 'en': - # print(len(lyric.split())/(et-st)) - char_num = sum([len(lrc[-1].split()) for lrc in paraphrase]) - assert 6 > char_num / (et-st) > 1 - else: - # print(len(lyric.split())/(et-st)) - char_num = sum([len(lrc[-1]) for lrc in paraphrase]) - assert 6 > char_num / (et-st) > 1 - - # 读取音频文件 - cur_sample_rate = torchaudio.info(path).sample_rate - offset = int(cur_sample_rate*st) - num_frames = int(cur_sample_rate * (et -st)) - chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames) - # chunk = torch.zeros(1, 48000*15) - if abs(chunk.shape[-1] - num_frames) > num_frames * 0.05: # 音频文件长度与歌词不一致 - print(f"fail to load {path} from {st} to {et} !") - raise FileNotFoundError - # 随机选取一个channel - if(chunk.shape[0]>1): - chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() - else: - chunk = chunk[[0],:].float() - - if(cur_sample_rate!=self.sr): - # print('a:',cur_sample_rate,chunk.shape) - chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr) - - if self.pad_to_max: - chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0) - - # print(self.sz_cnt) - return chunk, lyric, [st, et], path, lang_type - except (AssertionError, FileNotFoundError, RuntimeError) as e: # 其他Error不ok - # print("Error loadding ", info["path"]) - try_cnt += 1 - idx = np.random.randint(0, len(self.data)) - if(try_cnt>100): - raise e - - def parse_lyric(self, lyric): - pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)' - match = re.search(pattern, lyric) - - start_time = float(match.group(1)) - end_time = float(match.group(2)) - content = match.group(3) - return start_time, end_time, content - - def pad_2d_tensor(self, x, max_len, pad_id): - # 获取输入 tensor 的形状 - batch_size, seq_len = x.size() - max_len = max(max_len, seq_len) - # 计算需要填充的长度 - pad_len = max_len - seq_len - - # 如果需要填充 - if pad_len > 0: - # 创建填充 tensor - pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device) - - # 沿第二个维度(列)连接输入 tensor 和填充 tensor - padded_tensor = torch.cat([x, pad_tensor], dim=1) - else: - # 如果不需要填充,直接返回输入 tensor - padded_tensor = x - - return padded_tensor - -def collect_data(data_list): - audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2) - lyrics = [data[1] for data in data_list] - st_et = [data[2] for data in data_list] - paths = [data[3] for data in data_list] - lang_types = [data[4] for data in data_list] - return audios, lyrics, st_et - # return audios, lyrics, st_et - - -def 
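# parse_lyric above reads the pre-sliced '[start:end]lyric' span markers
# (seconds with two decimals) produced by the slicing functions earlier:
import re

m = re.search(r'\[(\d+\.\d+):(\d+\.\d+)\](.*)', '[12.50:18.20]some lyric line')
st, et, content = float(m.group(1)), float(m.group(2)), m.group(3)
print(st, et, content)  # 12.5 18.2 some lyric line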
build_dataset(train_jsonl_list, val_jsonl_list, min_dur=0, max_dur=20, add_music=False): - print(min_dur,max_dur) - print(train_jsonl_list) - # ["exp/wyy3_20240418_v2f.jsonl", - # "exp/tme_lyric_baokuan.jsonl"] - train_dataset = WYYSongDataset( - metadata_path = train_jsonl_list, - sr = 48000, - use_lang = ['zh-cn', 'en'], - num_examples = 10*10000, - min_dur=min_dur, - max_dur=max_dur, - add_music=add_music - ) - - valid_dataset = WYYSongDataset( - metadata_path = val_jsonl_list, - sr = 48000, - use_lang = ['zh-cn', 'en'], - num_examples = 500, - min_dur=min_dur, - max_dur=max_dur, - add_music=add_music - ) - print(train_jsonl_list, "\t total_song = ", len(train_dataset.data)) - return train_dataset, valid_dataset diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py deleted file mode 100644 index a1ec74a70b8491e7c973ed1dff68d843049c044d..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py +++ /dev/null @@ -1,830 +0,0 @@ -from torch.utils.data import Dataset -from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List -from beartype import beartype -from beartype.door import is_bearable -import random -import pandas as pd -import os -from torchaudio.functional import resample -import torch -import typing as tp -from pathlib import Path -import torchaudio as ta -import torch.nn.functional as F -import numpy as np -import json -import yaml -import torchaudio -import math -import re -from loguru import logger - -class Read_and_PadCrop_Normalized_T(torch.nn.Module): - def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): - - super().__init__() - - self.n_samples = n_samples - self.sample_rate = sample_rate - self.randomize = randomize - - def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: - if(duration<(float(self.n_samples)/self.sample_rate+1)): - # print(duration,(float(self.n_samples)/self.sample_rate+1)) - chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) - t_start = 0. - t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) - offset = 0 - # print('c1:',chunk.shape) - else: - offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) - t_start = offset / float(cur_sample_rate) / duration - t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration - chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) - # print('offset:',offset) - # print('c0:',chunk.shape) - # Pad with silence if necessary. 
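# The random crop above reports its window as fractions of the whole file;
# a sketch of that bookkeeping for a 10 s target at 48 kHz drawn from a 30 s
# source at 44.1 kHz (all numbers illustrative):
import numpy as np

duration, cur_sr, target_sr = 30.0, 44100, 48000
n_samples = 10 * target_sr
offset = np.random.randint(0, int(duration * cur_sr) - int(n_samples / target_sr * cur_sr))
t_start = offset / cur_sr / duration                   # fraction where the crop starts
t_end = t_start + (n_samples / target_sr) / duration   # fraction where it ends
print(round(t_end - t_start, 3))                       # 0.333 -> one third of the file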
- if(chunk.shape[0]>1): - chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() - else: - chunk = chunk[[0],:].float() - if(cur_sample_rate!=self.sample_rate): - # print('a:',cur_sample_rate,chunk.shape) - chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) - # print('b:',self.sample_rate,chunk.shape) - if chunk.shape[-1] < self.n_samples: - chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) - else: - chunk = chunk[:,0:self.n_samples] - seconds_start = math.floor(offset / cur_sample_rate) - seconds_total = math.floor(duration) - - return ( - chunk, - t_start, - t_end, - seconds_start, - seconds_total - ) - - -USE_DUMMY_AUDIO = False #当测试代码时,可以将其置为True,这样就不会读取实际数据,而是用生成的静默音频代替 -if USE_DUMMY_AUDIO: - logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!") - -class SafeAudioReader: - """ - This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data. - """ - def __init__(self, - duration: float, # 返回音频长度 - sample_rate: int, # 返回音频的采样率,如与实际音频采样率不同,会作resample - randomize: bool = True - ): - self.n_samples = int(sample_rate * max(duration, 0)) - self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize) - - #NOTE:这个是核心的函数,所有数据集读取音频都是调用的这个函数! - def __call__(self, - filepath: os.PathLike, # 音频路径 - origin_sample_rate: Optional[int] = None, # 从json文件中读取的实际采样率,如果不给定,则会从文件头中读取 - origin_duration: float = None, # 从json文件中读取的实际时长,如果不给定,则会从文件头中读取 - ) -> torch.Tensor: - if USE_DUMMY_AUDIO: - wav = torch.zeros(self.n_samples, dtype=torch.float32) - return wav - try: - if origin_sample_rate is None or origin_duration is None: - audio_info = torchaudio.info(filepath) - origin_sample_rate = audio_info.sample_rate - origin_duration = audio_info.num_frames / origin_sample_rate - wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate) - except Exception as e: - logger.error(f"Error reading {filepath}: {e}") - wav = torch.zeros(self.n_samples, dtype=torch.float32) - return wav - - -class PromptTemplate: - def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'): - self.template_text = template_text - self.tag_map = tag_map - self.lang = lang - - @property - def tags(self): - return tuple(self.tag_map.keys()) - - def apply(self, **kwargs): - for tag in list(kwargs.keys()): - if kwargs[tag] == '': - kwargs.pop(tag) - for tag in self.tags: - if tag in kwargs: - kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]') - else: - kwargs[tag] = '' - prompt = self.template_text.format(**kwargs) - - return self.beautify(prompt) - - def beautify(self, text): - if self.lang == 'en': - return self._beautify_en(text) - elif self.lang == 'zh': - return self._beautify_zh(text) - else: - raise ValueError(f'Unknown language {self.lang}') - - @staticmethod - def _beautify_en(text): - # no continuous commas without content between them - text = re.sub(r'[,\s]*,[,\s]*', r', ', text) - # no continuous whitespace - text = re.sub(r'\s+', ' ', text) - # the comma is NOT followed by whitespace, and should be followed by ONE whitespace - text = re.sub(r'\s+,', r',', text) - text = re.sub(r',\s+', r', ', text) - # no whitespace before the full stop - text = re.sub(r'\s+\.', r'.', text) - # strip whitespace, comma, and replace ',.' 
- text = text.strip(' ,') - text = text.replace(',.', '.') - return text - - @staticmethod - def _beautify_zh(text): - # no continuous commas without content between them - text = re.sub(r'[,、\s]*,[,、\s]*', r',', text) - text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text) - # assume there should be NO whitespace in Chinese - text = re.sub(r'\s+', r'', text) - # strip whitespace, comma, and replace ',。' - text = text.strip(', 、') - text = text.replace(',。', '。') - return text - - def __repr__(self): - return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})' - - __str__ = __repr__ - -def parse_prompt_template(prompt_template_text, lang='en'): - span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL) - tag_pattern = re.compile(r'{.+?}', re.DOTALL) - - template_text = prompt_template_text.strip() - span_texts = span_pattern.findall(prompt_template_text) - tag_map = {} - for span_text in span_texts: - tag = tag_pattern.findall(span_text)[0].strip('{}') - tag_map[tag] = span_text - template_text = template_text.replace(span_text, '{'+tag+'}') - - return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang) - -def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]: - with open(path, 'r') as f: - lines = f.readlines() - cnt = 0 - pts = [] - for line in lines: - pt = parse_prompt_template(line, lang=lang) - cnt += 1 - if len(pt.tags) < num: - logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}') - pts.append(pt) - - return pts - - -def get_base_dir_file(key: os.PathLike): - base = os.path.basename(key) - dirname = os.path.basename(os.path.dirname(key)) - return os.path.join(dirname, base) - -def read_jsonlike(path: os.PathLike): - #json or jsonl - if str(path).endswith(".json"): - with open(path, 'r', encoding='utf8') as f: - data = json.load(f) - return data - elif str(path).endswith(".jsonl"): - with open(path, 'r', encoding='utf8') as f: - data = [json.loads(line) for line in f.readlines()] - return data - else: - raise ValueError("Unknown file format") - -dist_prob_map = { - 1: (1.0,), - 2: (0.5, 0.5), - 3: (0.3, 0.4, 0.3), - 4: (0.2, 0.3, 0.3, 0.2), - 5: (0.2, 0.2, 0.3, 0.2, 0.1), - 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15), - 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1), - 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12), - 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08), - 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09) -} - -dist_prob_map_low = { - 1: (1.0,), - 2: (0.8, 0.2), - 3: (0.8, 0.1, 0.1), - 4: (0.7, 0.1, 0.1, 0.1), - 5: (0.7, 0.1, 0.1, 0.05, 0.05), - 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05), -} - -_bpm_range_rights = ( - (40, '20-40'), - (60, '40-60'), - (66, '60-66'), - (76, '66-76'), - (108, '76-108'), - (120, '108-120'), - (168, '120-168'), - (176, '168-176'), - (200, '176-200') -) -_bpm_desc_map = { - '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"), - '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"), - '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'), - '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'), - '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"), - '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'), - '120-168': ("quick and 
lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'), - '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'), - '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'), - '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo') -} -_bpm_desc_map_zh = { - '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"), - '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"), - '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'), - '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'), - '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"), - '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'), - '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'), - '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'), - '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'), - '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板') -} -def get_bpm_range(bpm): - bpm = int(bpm) - for right, tag in _bpm_range_rights: - if bpm <= right: - return tag - return '>200' - -def gen_bpm_descript(bpm, lang='en'): - bpm_range = get_bpm_range(bpm) - if lang == 'en': - return random.choice(_bpm_desc_map[bpm_range]) - elif lang == 'zh': - return random.choice(_bpm_desc_map_zh[bpm_range]) - else: - raise ValueError(f"Unknown language {lang}") - -def read_translate(translate: Optional[Dict[str, os.PathLike]]): - if translate is None: - return None - return {k: read_jsonlike(path) for k, path in translate.items()} - - -class MagnaTagATuneDataset(Dataset): - def __init__(self): - pass - - -def tags_to_desc(tag_list, sep=',') -> str: - if not isinstance(tag_list, Sequence): - return str(tag_list) - if isinstance(tag_list, str): - return tag_list - if len(tag_list) <= 0: - return '' - elif len(tag_list) <= 5: - probs = dist_prob_map[len(tag_list)] - tags_num = random.choices(range(1, len(tag_list)+1), probs)[0] - random.shuffle(tag_list) - tag_list = tag_list[:tags_num] - return sep.join(tag_list) - else: - probs = dist_prob_map[5] - tags_num = random.choices(range(1, 6), probs)[0] - random.shuffle(tag_list) - tag_list = tag_list[:tags_num] - return sep.join(tag_list) - -def get_sr_and_duration_info(item): - return item.get('sample_rate', None), item.get('duration', None) - -class MtgJamendoDatasetFromJson(Dataset): - def __init__(self, - data_dir:str, - json_path:str, - duration:float=10, - sr:int = 0, - *, - lang = 'en', - return_path = False, - prompt_template_path: os.PathLike = None, - tag_types = [], - translate:Optional[Dict[str, os.PathLike]] = None, - ): - self.audio_reader = SafeAudioReader(duration, sr) - - self.data_dir = data_dir - self._load_metadata_json(json_path) - self.sr = sr - self.duration = duration - self.return_path = return_path - self.lang = lang - - self.use_dynamic_prompt = prompt_template_path is not None - if self.use_dynamic_prompt: - self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types)) - self.tag_types = tag_types - - self.translate = read_translate(translate) - if not self.use_dynamic_prompt and self.lang != 'en': - raise NotImplementedError - - #这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示 - WEAK_TAG_LIST = ["title", "artist"] - - def _load_metadata_json(self, json_path): - with open(json_path) as fp: - self.data = json.load(fp) - - def convert_key_to_path(self, 
key): - return os.path.join(self.data_dir, get_base_dir_file(key)) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - path = self.convert_key_to_path(item['key']) - description = self.generate_description(item) - - sr, duration = get_sr_and_duration_info(item) - audio = self.audio_reader(path, sr, duration) - - if self.return_path: - return audio, description, path - return audio, description - - def tags_to_desc(self, tag_list, tag_type) -> str: - if self.lang == 'en': - return tags_to_desc(tag_list) - elif self.lang == 'zh': - translator = self.translate[tag_type] - translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] - return tags_to_desc(translated_tag_list, sep='、') - - def generate_description(self, item): - if self.use_dynamic_prompt: - # dynamically generate prompt from given prompt template - prompt_template = random.choice(self.prompt_templates) - description = self.generate_description_dynamic(item, prompt_template) - - else: - # use ordinary static prompt instead - description = self.generate_description_ordinary(item) - return description - - def generate_description_dynamic(self, data, prompt_template: PromptTemplate): - exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] - exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag)) - exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag)) - - if len(exists_strong_tag) > 0: - probs = dist_prob_map[len(exists_strong_tag)] - tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0] - random.shuffle(exists_strong_tag) - tags = exists_strong_tag[:tags_num] - weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1] - weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0] - random.shuffle(exists_weak_tag) - weak_tags = exists_weak_tag[:weak_tags_num] - tags += weak_tags - tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} - prompt = prompt_template.apply(**tags_args) - else: - # no strong tags, use all weak tags instead - tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag} - prompt = prompt_template.apply(**tags_args) - - return prompt - - def generate_description_ordinary(self, data, thresh = 0.3): - # Initialize the description with title and artist - description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}' - - # Add genre if available - if data["genre"] and random.random() > thresh: - genres = ', '.join(data["genre"]) - description += f', belonging to the {genres} genres' - - # Add moods if available - if data["moods"] and random.random() > thresh: - moods = ', '.join(data["moods"]) - description += f'. This track conveys a {moods} mood' - - # Add instruments if available - if data["instrument"] and random.random() > thresh: - instruments = ', '.join(data["instrument"]) - description += f', and primarily features the following instruments: {instruments}' - - # Add a period to end the description - description += '.' 
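# generate_description_ordinary above grows a static caption clause by clause,
# keeping each optional clause with probability 1 - thresh; a sketch with an
# illustrative item and two clauses gated as in the original:
import random

data = {'title': 'Song', 'artist': 'Artist', 'genre': ['ambient'], 'moods': ['calm']}
thresh = 0.3
description = f'a piece of music by {data["artist"]}'
if data['genre'] and random.random() > thresh:
    description += f', belonging to the {", ".join(data["genre"])} genres'
if data['moods'] and random.random() > thresh:
    description += f'. This track conveys a {", ".join(data["moods"])} mood'
description += '.'
print(description)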
- - return description - -class AudioStockDataset(Dataset): - def __init__(self, - metadata_path:str, - duration:float=10, - sr:int = 0, - return_path = False, - return_audio = True, - prompt_template_path: os.PathLike = None, - tag_types = [], - lang = 'en', - translate:Optional[Dict[str, os.PathLike]] = None - ): - self.audio_reader = SafeAudioReader(duration, sr) - - self._load_metadata(metadata_path) - self.sr = sr - self.duration = duration - self.return_path = return_path - self.return_audio = return_audio - - self.use_dynamic_prompt = prompt_template_path is not None - if self.use_dynamic_prompt: - self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang) - self.tag_types = tag_types - - self.lang = lang - self.translate = read_translate(translate) - - def _load_metadata(self, metadata_path): - with open(metadata_path) as fp: - lines = fp.readlines() - self.data = [] - for line in lines: - item = json.loads(line) - self.data.append(item) - self.is_info_recorded = bool('Tags' in self.data[0]) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - path:str = self.data[idx]["path"] - json_path = path[:path.rfind('.')] + ".json" - if self.is_info_recorded: - item = self.data[idx] - else: - try: - with open(json_path) as fp: - item:dict = json.load(fp) - except Exception as e: - print(f"Error loading json file {json_path} :\n{e}") - item = {} - description = self.generate_description(item) - if self.return_audio: - sr, duration = get_sr_and_duration_info(item) - audio = self.audio_reader(path, sr, duration) - else: - audio = None - if self.return_path: - return audio, description, path - return audio, description - - def generate_description(self, item): - if self.use_dynamic_prompt: - # dynamically generate prompt from given prompt template - prompt_template = random.choice(self.prompt_templates) - description = self.generate_description_dynamic(item, prompt_template) - else: - # use ordinary static prompt instead - description = self.generate_description_ordinary(item) - return description - - def generate_description_dynamic(self, data, prompt_template: PromptTemplate): - exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] - - if len(exists_tag) > 0: - probs = dist_prob_map[len(exists_tag)] - tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0] - random.shuffle(exists_tag) - tags = exists_tag[:tags_num] - tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} - tags_args = self.handle_BPM_tag(tags_args) - prompt = prompt_template.apply(**tags_args) - else: - # no strong tags, use all weak tags instead - prompt = prompt_template.apply() - - return prompt - - def tags_to_desc(self, tag_list, tag_type) -> str: - if self.lang == 'en': - return tags_to_desc(tag_list) - elif self.lang == 'zh': - if tag_type == 'BPM': - return tags_to_desc(tag_list, sep='、') - translator = self.translate[tag_type] - translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] - return tags_to_desc(translated_tag_list, sep='、') - - def handle_BPM_tag(self, tags_args): - if "BPM" in tags_args and 'BPMDescript' in self.tag_types: - bpm = tags_args["BPM"] - del tags_args["BPM"] - tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript'))) - for tag_type in tag_types_used: - tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang) - return tags_args - - def 
generate_description_ordinary(self, data, thresh = 0.3): - if self.lang != 'en': - raise ValueError(f'Language {self.lang} is not supported for ordinary description generation') - description = f'a piece of music by {data["Artist"]}' - - # Add genre if available - if data["Genre"] and random.random() > thresh: - genres = ', '.join(data["Genre"]) - description += f', belonging to the {genres} genres' - - # Add moods if available - if data["Tags"] and random.random() > thresh: - tags = ', '.join(data["Tags"]) - description += f'. This track contains the tags:{tags}' - - # Add moods if available - if data["Mood"] and random.random() > thresh: - moods = ', '.join(data["Mood"]) - description += f'. This track conveys a {moods} mood.' - - # Add instruments if available - if data["Instrument"] and random.random() > thresh: - instruments = ', '.join(data["Instrument"]) - description += f'. and primarily features the following instruments: {instruments}' - - # Add a period to end the description - description += '.' - - return description - -def mp3_path_to_id(mp3_path): - return int( - mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')] - ) - -class TmeDataset(Dataset): - def __init__(self, - data_index:str, - music_info:str = None, - duration:float = 10, - sr:int = 0, - return_path = False, - return_audio = True, - prompt_format_path: os.PathLike = None, - tag_types = ['*'], - lang = 'zh', - translate: Optional[os.PathLike] = None, - prompt_dir: os.PathLike = None, - ): - self.audio_reader = SafeAudioReader(duration, sr) - - self.sr = sr - self.duration = duration - self.return_path = return_path - self.return_audio = return_audio - self.lang = lang - - self.use_ready_prompt = prompt_dir is not None - - data_index = read_jsonlike(data_index) - self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index} - self.data_ids = list(self.data_index_dict.keys()) - - if not self.use_ready_prompt: - #读取音乐的信息文件 - music_info = read_jsonlike(music_info) - if 'music' in music_info: - music_info = music_info['music'] - self.music_info_dict = {d["歌曲ID"]:d for d in music_info} - self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict} - self.data_ids = list(self.data_index_dict.keys()) - - with open(prompt_format_path) as fp: - self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader) - - #加载tag types,并分成一般的tag_types和关键的key_tag_types - if '*' in tag_types: - self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag'] - else: - self.tag_types = tag_types - - self.key_tag_types = [] - if 'tag' in self.tag_types: - self.tag_types.remove('tag') - self.key_tag_types = list(self.prompt_formats['tag'].keys()) - - #加载translate翻译 - if translate is not None: - self.translator = read_jsonlike(translate) - else: - data_ids_set = set(self.data_ids) - self.prompts_dict = {} - for fname in os.listdir(prompt_dir): - items = read_jsonlike(os.path.join(prompt_dir, fname)) - for item in items: - if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']): - continue - if item['ID'] not in self.prompts_dict: - self.prompts_dict[item['ID']] = [] - self.prompts_dict[item['ID']].append(item['Text']) - self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict} - self.data_ids = list(self.data_index_dict.keys()) - - def tags_to_desc(self, tag_list) -> str: - if is_bearable(tag_list, int): - return str(tag_list) - if self.lang == 'zh': - return tags_to_desc(tag_list, sep=self.sep) - else: - translated_tag_list = 
[self.translator[tag] for tag in tag_list if tag in self.translator ] - return tags_to_desc(translated_tag_list, sep=self.sep) - - def gen_desc_of_tag(self, formats, tags): - fmt = random.choice(formats) - return fmt.format(self.tags_to_desc(tags)) - - @staticmethod - def check_valid(value): - if isinstance(value, int) or isinstance(value, float): - return value > 0 - if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0): - return True - return False - - @staticmethod - def remove_repeat(data): - #若专辑名和歌曲名相同,则只使用后者 - album_name = data.get('专辑名', None) - if album_name is not None and album_name == data.get('歌曲名', None): - del data['专辑名'] - return data - - @property - def comma(self): - if self.lang == 'zh': - return ',' - elif self.lang == 'en': - return ', ' - - @property - def sep(self): - if self.lang == 'zh': - return '、' - elif self.lang == 'en': - return ', ' - - def generate_description(self, data): - data = self.remove_repeat(data) - weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低 - - key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个 - - prompts = [] - if len(weak_tags) > 0: - probs = dist_prob_map_low[len(weak_tags)] - if len(key_tags) > 0: - tags_num = random.choices(range(0, len(weak_tags)), probs)[0] - else: - tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0] - random.shuffle(weak_tags) - tags = weak_tags[:tags_num] - for tag_type in tags: - tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type]) - prompts.append(tag_desc) - - if len(key_tags) > 0: - probs = dist_prob_map[len(key_tags)] - tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0] - random.shuffle(key_tags) - tags = key_tags[:tags_num] - for tag_type in tags: - tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type]) - prompts.append(tag_desc) - - random.shuffle(prompts) - return self.comma.join(prompts) - - def is_valid_prompt_text(self, text): - for bad in ('抱歉','sorry', 'Sorry'): - if bad in text: - return False - return True - - def get_ready_prompt(self, path): - sid = mp3_path_to_id(path) - return random.choice(self.prompts_dict[sid]) - - def __len__(self): - return len(self.data_ids) - - def __getitem__(self, idx): - data_id = self.data_ids[idx] - item = self.data_index_dict[data_id] - path = item['path'] - if not self.use_ready_prompt: - info = self.music_info_dict[data_id] - description = self.generate_description(info) - else: - description = self.get_ready_prompt(path) - if self.return_audio: - sr, duration = get_sr_and_duration_info(item) - audio = self.audio_reader(path, sr, duration) - else: - audio = None - if self.return_path: - return audio, description, path - return audio, description - -class CombinedDataset(Dataset): - @beartype - def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]): - self.datasets = datasets - self.datasets_index = [] - - for i,dataset in enumerate(datasets): - if dataset is None: - continue - for dup in range(ratios[i]): - for j in range(len(dataset)): - self.datasets_index.append((i,j)) - - def __len__(self): - return len(self.datasets_index) - - def __getitem__(self, idx): - index = self.datasets_index[idx] - i,j = index - return self.datasets[i][j] - -class CombinedDataset_random(Dataset): - @beartype - def __init__(self, - num_examples:int, - datasets: 
Sequence[Dataset], ratios: Sequence[int]
-                ):
-        self.datasets = datasets
-        self.datasets_index = []
-
-        for i,dataset in enumerate(datasets):
-            if dataset is None:
-                continue
-            for dup in range(ratios[i]):
-                for j in range(len(dataset)):
-                    self.datasets_index.append((i,j))
-        if num_examples > 0:
-            self.random_choose = True
-            self.dataset_len = num_examples
-        else:
-            self.random_choose = False
-            self.dataset_len = len(self.datasets_index)
-
-    def __len__(self):
-        return self.dataset_len
-
-    def __getitem__(self, idx):
-        first_try = True
-        try_cnt = 0
-        while True:
-            try:
-                if(self.random_choose or not first_try):
-                    index2 = []
-                    index2.append(np.random.randint(0,len(self.datasets)))
-                    index2.append(np.random.randint(0,len(self.datasets[index2[-1]])))
-                else:
-                    index2 = self.datasets_index[idx]
-                    first_try = False
-                out = self.datasets[index2[0]][index2[1]]
-                if(len(out[0].shape)==1):
-                    out[0] = out[0][None,:]
-                return out
-            except Exception:
-                print("Error loading ", index2)
-                try_cnt += 1
-                if(try_cnt>10):
-                    raise ValueError()
diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py
deleted file mode 100644
index 39044e6b5a6c945b86ddd0091b6e76775e0573d9..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py
+++ /dev/null
@@ -1,994 +0,0 @@
-from torch.utils.data import Dataset
-from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List
-from beartype import beartype
-from beartype.door import is_bearable
-import random
-import pandas as pd
-import os
-from torchaudio.functional import resample
-import torch
-import typing as tp
-from pathlib import Path
-import torchaudio as ta
-import torch.nn.functional as F
-import numpy as np
-import json
-import yaml
-import torchaudio
-import math
-import re
-from loguru import logger
-
-def gen_plain_prompt(key_list, sep=', '):
-    if len(key_list) == 0:
-        return 'none'
-    key_list = [k.strip() for k in key_list]
-    if len(key_list) > 10:
-        random.shuffle(key_list)
-        key_list = key_list[:10]
-    probs = dist_prob_map[len(key_list)]
-    num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0]
-    random.shuffle(key_list)
-    tags = key_list[:num_tags]
-    tags_str = sep.join(tags)
-    return tags_str
-
-class Read_and_PadCrop_Normalized_T(torch.nn.Module):
-
-    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
-        super().__init__()
-        self.n_samples = n_samples
-        self.sample_rate = sample_rate
-        self.randomize = randomize
-        # probability thresholds for cropping from the start / end of a clip
-        self.prob = {"is_start":0.2, "is_end":0.9}
-        self.shift_secs = 5
-
-    def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
-        if(duration<(float(self.n_samples)/self.sample_rate+1)):
-            raise ValueError(duration,float(self.n_samples),self.sample_rate)
-        # NOTE: the two branch conditions below were garbled in the source; this guard and the
-        # head-crop branch are reconstructed from the surrounding logic (the shift_secs margins
-        # and the self.prob thresholds) and should be read as a best-effort restoration.
-        if(duration<(float(self.n_samples)/self.sample_rate+2*self.shift_secs+1)):
-            # clip too short for a shifted random crop: take it from the head
-            chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
-            t_start = 0.
-            t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
-            offset = 0
-            is_start = True
-            is_end = True
-        else:
-            prob = random.uniform(0,1)
-            if(prob<self.prob['is_start']):
-                # reconstructed branch: crop from the very start of the clip
-                is_start = True
-                is_end = False
-                offset = 0
-            elif(prob>self.prob['is_end']):
-                is_start = False
-                is_end = True
-                offset = int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)
-            else:
-                is_start = False
-                is_end = False
-                offset = np.random.randint(self.shift_secs*cur_sample_rate,
-                    int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)-self.shift_secs*cur_sample_rate)
-            t_start = offset / float(cur_sample_rate) / duration
-            t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
-            chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
-        if(chunk.shape[0]>1):
-            # randomly keep one channel
-            chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
-        else:
-            chunk = chunk[[0],:].float()
-        if(cur_sample_rate!=self.sample_rate):
-            chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
-        if chunk.shape[-1] > self.n_samples:
-            # re-enabled from a commented-out line in the source so the head-crop branch passes the length check
-            chunk = chunk[:,0:self.n_samples]
-        if chunk.shape[-1] != self.n_samples:
-            raise ValueError(chunk.shape, self.n_samples, offset, int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
-        seconds_start = math.floor(offset / cur_sample_rate)
-        seconds_total = math.floor(duration)
-        return (
-            chunk,
-            t_start,
-            t_end,
-            seconds_start,
-            seconds_total,
-            is_start,
-            is_end,
-        )
-
-
-USE_DUMMY_AUDIO = False  # set to True only when testing the code: real files are then skipped and silent audio is returned instead
-if USE_DUMMY_AUDIO:
-    logger.warning("USE_DUMMY_AUDIO flag is True, don't use it for training or testing!")
-
-class SafeAudioReader:
-    """
-    An adaptor around Read_and_PadCrop_Normalized_T that makes reading audio data safe.
-    """
-    def __init__(self,
-                 duration: float,   # length (in seconds) of the returned audio
-                 sample_rate: int,  # sample rate of the returned audio; if it differs from the source file, the audio is resampled
-                 randomize: bool = True
-                ):
-        self.n_samples = int(sample_rate * max(duration, 0))
-        self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize)
-
-    # NOTE: this is the core entry point -- every dataset reads its audio through this call!
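-    # A worked example of the crop policy above (illustrative numbers, based on the reconstructed
-    # branches): with n_samples/sample_rate = 30 s, shift_secs = 5 and a 120 s clip, a draw of
-    # prob = 0.13 (< 0.2) crops [0 s, 30 s] and flags is_start; prob = 0.95 (> 0.9) crops
-    # [90 s, 120 s] and flags is_end; any other draw picks a random offset in roughly [5 s, 85 s],
-    # keeping a 5 s margin away from both edges.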
-    def __call__(self,
-                 filepath: os.PathLike,                    # path to the audio file
-                 origin_sample_rate: Optional[int] = None, # sample rate recorded in the json metadata; if not given, it is read from the file header
-                 origin_duration: float = None,            # duration recorded in the json metadata; if not given, it is read from the file header
-                ) -> Tuple[torch.Tensor, bool, bool]:
-        if USE_DUMMY_AUDIO:
-            wav = torch.zeros(self.n_samples, dtype=torch.float32)
-            return wav, True, True  # match the (wav, is_start, is_end) shape of the normal path
-        try:
-            # note: the json-provided values are overwritten here; the file header is always trusted
-            audio_info = torchaudio.info(filepath)
-            origin_sample_rate = audio_info.sample_rate
-            origin_duration = audio_info.num_frames / origin_sample_rate
-            wav, *ignored, is_start, is_end = self.reader(filepath, origin_duration, origin_sample_rate)
-        except Exception as e:
-            logger.error(f"Error reading {filepath}: {e}")
-            raise FileNotFoundError(filepath)
-        return wav, is_start, is_end
-
-
-class PromptTemplate:
-    def __init__(self, template_text: str, tag_map: Dict[str, str], lang: str = 'en'):
-        self.template_text = template_text
-        self.tag_map = tag_map
-        self.lang = lang
-
-    @property
-    def tags(self):
-        return tuple(self.tag_map.keys())
-
-    def apply(self, **kwargs):
-        for tag in list(kwargs.keys()):
-            if kwargs[tag] == '':
-                kwargs.pop(tag)
-        for tag in self.tags:
-            if tag in kwargs:
-                kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]')
-            else:
-                kwargs[tag] = ''
-        prompt = self.template_text.format(**kwargs)
-        return self.beautify(prompt)
-
-    def beautify(self, text):
-        if self.lang == 'en':
-            return self._beautify_en(text)
-        elif self.lang == 'zh':
-            return self._beautify_zh(text)
-        else:
-            raise ValueError(f'Unknown language {self.lang}')
-
-    @staticmethod
-    def _beautify_en(text):
-        # no consecutive commas without content between them
-        text = re.sub(r'[,\s]*,[,\s]*', r', ', text)
-        # no consecutive whitespace
-        text = re.sub(r'\s+', ' ', text)
-        # no whitespace before a comma, exactly one space after it
-        text = re.sub(r'\s+,', r',', text)
-        text = re.sub(r',\s+', r', ', text)
-        # no whitespace before the full stop
-        text = re.sub(r'\s+\.', r'.', text)
-        # strip whitespace, comma, and replace ',.'
- text = text.strip(' ,') - text = text.replace(',.', '.') - return text - - @staticmethod - def _beautify_zh(text): - # no continuous commas without content between them - text = re.sub(r'[,、\s]*,[,、\s]*', r',', text) - text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text) - # assume there should be NO whitespace in Chinese - text = re.sub(r'\s+', r'', text) - # strip whitespace, comma, and replace ',。' - text = text.strip(', 、') - text = text.replace(',。', '。') - return text - - def __repr__(self): - return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})' - - __str__ = __repr__ - -def parse_prompt_template(prompt_template_text, lang='en'): - span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL) - tag_pattern = re.compile(r'{.+?}', re.DOTALL) - - template_text = prompt_template_text.strip() - span_texts = span_pattern.findall(prompt_template_text) - tag_map = {} - for span_text in span_texts: - tag = tag_pattern.findall(span_text)[0].strip('{}') - tag_map[tag] = span_text - template_text = template_text.replace(span_text, '{'+tag+'}') - - return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang) - -def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]: - with open(path, 'r') as f: - lines = f.readlines() - cnt = 0 - pts = [] - for line in lines: - pt = parse_prompt_template(line, lang=lang) - cnt += 1 - if len(pt.tags) < num: - logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}') - pts.append(pt) - - return pts - - -def get_base_dir_file(key: os.PathLike): - base = os.path.basename(key) - dirname = os.path.basename(os.path.dirname(key)) - return os.path.join(dirname, base) - -def read_jsonlike(path: os.PathLike): - #json or jsonl - if str(path).endswith(".json"): - with open(path, 'r', encoding='utf8') as f: - data = json.load(f) - return data - elif str(path).endswith(".jsonl"): - with open(path, 'r', encoding='utf8') as f: - data = [json.loads(line) for line in f.readlines()] - return data - else: - raise ValueError("Unknown file format") - -dist_prob_map = { - 1: (1.0,), - 2: (0.5, 0.5), - 3: (0.3, 0.4, 0.3), - 4: (0.2, 0.3, 0.3, 0.2), - 5: (0.2, 0.2, 0.3, 0.2, 0.1), - 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15), - 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1), - 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12), - 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08), - 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09) -} - -dist_prob_map_low = { - 1: (1.0,), - 2: (0.8, 0.2), - 3: (0.8, 0.1, 0.1), - 4: (0.7, 0.1, 0.1, 0.1), - 5: (0.7, 0.1, 0.1, 0.05, 0.05), - 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05), -} - -_bpm_range_rights = ( - (40, '20-40'), - (60, '40-60'), - (66, '60-66'), - (76, '66-76'), - (108, '76-108'), - (120, '108-120'), - (168, '120-168'), - (176, '168-176'), - (200, '176-200') -) -_bpm_desc_map = { - '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"), - '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"), - '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'), - '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'), - '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"), - '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'), - '120-168': ("quick and 
lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'), - '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'), - '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'), - '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo') -} -_bpm_desc_map_zh = { - '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"), - '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"), - '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'), - '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'), - '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"), - '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'), - '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'), - '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'), - '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'), - '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板') -} -def get_bpm_range(bpm): - bpm = int(bpm) - for right, tag in _bpm_range_rights: - if bpm <= right: - return tag - return '>200' - -def gen_bpm_descript(bpm, lang='en'): - bpm_range = get_bpm_range(bpm) - if lang == 'en': - return random.choice(_bpm_desc_map[bpm_range]) - elif lang == 'zh': - return random.choice(_bpm_desc_map_zh[bpm_range]) - else: - raise ValueError(f"Unknown language {lang}") - -def read_translate(translate: Optional[Dict[str, os.PathLike]]): - if translate is None: - return None - if isinstance(translate, str): - return read_jsonlike(translate) - return {k: read_jsonlike(path) for k, path in translate.items()} - - -class MagnaTagATuneDataset(Dataset): - def __init__(self): - pass - - -def tags_to_desc(tag_list, sep=',') -> str: - if not isinstance(tag_list, Sequence): - return str(tag_list) - if isinstance(tag_list, str): - return tag_list - if len(tag_list) <= 0: - return '' - elif len(tag_list) <= 5: - probs = dist_prob_map[len(tag_list)] - tags_num = random.choices(range(1, len(tag_list)+1), probs)[0] - random.shuffle(tag_list) - tag_list = tag_list[:tags_num] - return sep.join(tag_list) - else: - probs = dist_prob_map[5] - tags_num = random.choices(range(1, 6), probs)[0] - random.shuffle(tag_list) - tag_list = tag_list[:tags_num] - return sep.join(tag_list) - -def get_sr_and_duration_info(item): - return item.get('sample_rate', None), item.get('duration', None) - -class MtgJamendoDatasetFromJson(Dataset): - def __init__(self, - data_dir:str, - json_path:str, - duration:float=10, - sr:int = 0, - *, - lang = 'en', - return_path = False, - prompt_template_path: os.PathLike = None, - tag_types = [], - translate:Optional[Dict[str, os.PathLike]] = None, - ): - self.audio_reader = SafeAudioReader(duration, sr) - - self.data_dir = data_dir - self._load_metadata_json(json_path) - self.sr = sr - self.duration = duration - self.return_path = return_path - self.lang = lang - - self.use_dynamic_prompt = prompt_template_path is not None - if self.use_dynamic_prompt: - self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types)) - self.tag_types = tag_types - - self.translate = read_translate(translate) - if not self.use_dynamic_prompt and self.lang != 'en': - raise NotImplementedError - - #这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示 - WEAK_TAG_LIST = ["title", "artist"] - - def _load_metadata_json(self, json_path): - with open(json_path) as fp: 
- self.data = json.load(fp) - - def convert_key_to_path(self, key): - return os.path.join(self.data_dir, get_base_dir_file(key)) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - path = self.convert_key_to_path(item['key']) - description = self.generate_description(item) - - sr, duration = get_sr_and_duration_info(item) - audio, is_start, is_end = self.audio_reader(path, sr, duration) - - if self.return_path: - return audio, description, path - return audio, description, is_start, is_end - - def tags_to_desc(self, tag_list, tag_type) -> str: - if self.lang == 'en': - return tags_to_desc(tag_list) - elif self.lang == 'zh': - translator = self.translate[tag_type] - translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] - return tags_to_desc(translated_tag_list, sep='、') - - def generate_description(self, item): - if self.use_dynamic_prompt: - # dynamically generate prompt from given prompt template - prompt_template = random.choice(self.prompt_templates) - description = self.generate_description_dynamic(item, prompt_template) - - else: - # use ordinary static prompt instead - description = self.generate_description_ordinary(item) - return description - - def generate_description_dynamic(self, data, prompt_template: PromptTemplate): - exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] - exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag)) - exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag)) - - if len(exists_strong_tag) > 0: - probs = dist_prob_map[len(exists_strong_tag)] - tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0] - random.shuffle(exists_strong_tag) - tags = exists_strong_tag[:tags_num] - weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1] - weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0] - random.shuffle(exists_weak_tag) - weak_tags = exists_weak_tag[:weak_tags_num] - tags += weak_tags - tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} - prompt = prompt_template.apply(**tags_args) - else: - # no strong tags, use all weak tags instead - tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag} - prompt = prompt_template.apply(**tags_args) - - return prompt - - def generate_description_ordinary(self, data, thresh = 0.3): - # Initialize the description with title and artist - description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}' - - # Add genre if available - if data["genre"] and random.random() > thresh: - genres = ', '.join(data["genre"]) - description += f', belonging to the {genres} genres' - - # Add moods if available - if data["moods"] and random.random() > thresh: - moods = ', '.join(data["moods"]) - description += f'. This track conveys a {moods} mood' - - # Add instruments if available - if data["instrument"] and random.random() > thresh: - instruments = ', '.join(data["instrument"]) - description += f', and primarily features the following instruments: {instruments}' - - # Add a period to end the description - description += '.' 
- - return description - -class AudioStockDataset(Dataset): - def __init__(self, - metadata_path:str, - duration:float=10, - sr:int = 0, - return_path = False, - return_audio = True, - prompt_template_path: os.PathLike = None, - tag_types = [], - lang = 'en', - translate:Optional[Dict[str, os.PathLike]] = None - ): - self.audio_reader = SafeAudioReader(duration, sr) - - self.duration = duration - self._load_metadata(metadata_path) - self.sr = sr - self.return_path = return_path - self.return_audio = return_audio - - self.use_dynamic_prompt = prompt_template_path is not None - if self.use_dynamic_prompt: - self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang) - self.tag_types = tag_types - - self.lang = lang - self.translate = read_translate(translate) - - def _load_metadata(self, metadata_path): - with open(metadata_path) as fp: - lines = fp.readlines() - self.data = [] - for line in lines: - item = json.loads(line) - if(item['duration']>self.duration+10): - self.data.append(item) - self.is_info_recorded = bool('Tags' in self.data[0]) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - path:str = self.data[idx]["path"] - json_path = path[:path.rfind('.')] + ".json" - if self.is_info_recorded: - item = self.data[idx] - else: - try: - with open(json_path) as fp: - item:dict = json.load(fp) - except Exception as e: - print(f"Error loading json file {json_path} :\n{e}") - item = {} - description = self.generate_description(item) - if self.return_audio: - sr, duration = get_sr_and_duration_info(item) - audio, is_start, is_end = self.audio_reader(path, sr, duration) - else: - audio = None - if self.return_path: - return audio, description, path, is_start, is_end - else: - return audio, description, is_start, is_end - - def generate_description(self, item): - if self.use_dynamic_prompt: - # dynamically generate prompt from given prompt template - prompt_template = random.choice(self.prompt_templates) - description = self.generate_description_dynamic(item, prompt_template) - else: - # use ordinary static prompt instead - description = self.generate_description_ordinary(item) - return description - - def generate_description_dynamic(self, data, prompt_template: PromptTemplate): - exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] - - if len(exists_tag) > 0: - probs = dist_prob_map[len(exists_tag)] - tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0] - random.shuffle(exists_tag) - tags = exists_tag[:tags_num] - tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} - tags_args = self.handle_BPM_tag(tags_args) - prompt = prompt_template.apply(**tags_args) - else: - # no strong tags, use all weak tags instead - prompt = prompt_template.apply() - - return prompt - - def tags_to_desc(self, tag_list, tag_type) -> str: - if self.lang == 'en': - return tags_to_desc(tag_list) - elif self.lang == 'zh': - if tag_type == 'BPM': - return tags_to_desc(tag_list, sep='、') - translator = self.translate[tag_type] - translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] - return tags_to_desc(translated_tag_list, sep='、') - - def handle_BPM_tag(self, tags_args): - if "BPM" in tags_args and 'BPMDescript' in self.tag_types: - bpm = tags_args["BPM"] - del tags_args["BPM"] - tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript'))) - for tag_type in tag_types_used: - tags_args[tag_type] = bpm if 
tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang) - return tags_args - - def generate_description_ordinary(self, data, thresh = 0.3): - if self.lang != 'en': - raise ValueError(f'Language {self.lang} is not supported for ordinary description generation') - description = f'a piece of music by {data["Artist"]}' - - # Add genre if available - if data["Genre"] and random.random() > thresh: - genres = ', '.join(data["Genre"]) - description += f', belonging to the {genres} genres' - - # Add moods if available - if data["Tags"] and random.random() > thresh: - tags = ', '.join(data["Tags"]) - description += f'. This track contains the tags:{tags}' - - # Add moods if available - if data["Mood"] and random.random() > thresh: - moods = ', '.join(data["Mood"]) - description += f'. This track conveys a {moods} mood.' - - # Add instruments if available - if data["Instrument"] and random.random() > thresh: - instruments = ', '.join(data["Instrument"]) - description += f'. and primarily features the following instruments: {instruments}' - - # Add a period to end the description - description += '.' - - return description - -def mp3_path_to_id(mp3_path): - return int( - mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')] - ) - -class TmeDataset(Dataset): - def __init__(self, - data_index:str, - music_info:str = None, - duration:float = 10, - sr:int = 0, - return_path = False, - return_audio = True, - prompt_format_path: os.PathLike = None, - tag_types = ['*'], - lang = 'zh', - translate: Optional[os.PathLike] = None, - prompt_dir: os.PathLike = None, - ): - self.audio_reader = SafeAudioReader(duration, sr) - - self.sr = sr - self.duration = duration - self.return_path = return_path - self.return_audio = return_audio - self.lang = lang - - self.use_ready_prompt = prompt_dir is not None - - data_index = read_jsonlike(data_index) - data_index = [d for d in data_index if d['duration']>self.duration+10] - self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index} - self.data_ids = list(self.data_index_dict.keys()) - - if not self.use_ready_prompt: - #读取音乐的信息文件 - music_info = read_jsonlike(music_info) - if 'music' in music_info: - music_info = music_info['music'] - self.music_info_dict = {d["歌曲ID"]:d for d in music_info} - self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict} - self.data_ids = list(self.data_index_dict.keys()) - - with open(prompt_format_path) as fp: - self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader) - - #加载tag types,并分成一般的tag_types和关键的key_tag_types - if '*' in tag_types: - self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag'] - else: - self.tag_types = tag_types - - self.key_tag_types = [] - if 'tag' in self.tag_types: - self.tag_types.remove('tag') - self.key_tag_types = list(self.prompt_formats['tag'].keys()) - - #加载translate翻译 - if translate is not None: - self.translator = read_jsonlike(translate) - else: - data_ids_set = set(self.data_ids) - self.prompts_dict = {} - for fname in os.listdir(prompt_dir): - items = read_jsonlike(os.path.join(prompt_dir, fname)) - for item in items: - if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']): - continue - if item['ID'] not in self.prompts_dict: - self.prompts_dict[item['ID']] = [] - self.prompts_dict[item['ID']].append(item['Text']) - self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict} - self.data_ids = list(self.data_index_dict.keys()) - - def tags_to_desc(self, tag_list) -> str: - 
if is_bearable(tag_list, int): - return str(tag_list) - if self.lang == 'zh': - return tags_to_desc(tag_list, sep=self.sep) - else: - translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ] - return tags_to_desc(translated_tag_list, sep=self.sep) - - def gen_desc_of_tag(self, formats, tags): - fmt = random.choice(formats) - return fmt.format(self.tags_to_desc(tags)) - - @staticmethod - def check_valid(value): - if isinstance(value, int) or isinstance(value, float): - return value > 0 - if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0): - return True - return False - - @staticmethod - def remove_repeat(data): - #若专辑名和歌曲名相同,则只使用后者 - album_name = data.get('专辑名', None) - if album_name is not None and album_name == data.get('歌曲名', None): - del data['专辑名'] - return data - - @property - def comma(self): - if self.lang == 'zh': - return ',' - elif self.lang == 'en': - return ', ' - - @property - def sep(self): - if self.lang == 'zh': - return '、' - elif self.lang == 'en': - return ', ' - - def generate_description(self, data): - data = self.remove_repeat(data) - weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低 - - key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个 - - prompts = [] - if len(weak_tags) > 0: - probs = dist_prob_map_low[len(weak_tags)] - if len(key_tags) > 0: - tags_num = random.choices(range(0, len(weak_tags)), probs)[0] - else: - tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0] - random.shuffle(weak_tags) - tags = weak_tags[:tags_num] - for tag_type in tags: - tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type]) - prompts.append(tag_desc) - - if len(key_tags) > 0: - probs = dist_prob_map[len(key_tags)] - tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0] - random.shuffle(key_tags) - tags = key_tags[:tags_num] - for tag_type in tags: - tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type]) - prompts.append(tag_desc) - - random.shuffle(prompts) - return self.comma.join(prompts) - - def is_valid_prompt_text(self, text): - for bad in ('抱歉','sorry', 'Sorry'): - if bad in text: - return False - return True - - def get_ready_prompt(self, path): - sid = mp3_path_to_id(path) - return random.choice(self.prompts_dict[sid]) - - def __len__(self): - return len(self.data_ids) - - def __getitem__(self, idx): - data_id = self.data_ids[idx] - item = self.data_index_dict[data_id] - path = item['path'] - if not self.use_ready_prompt: - info = self.music_info_dict[data_id] - description = self.generate_description(info) - else: - description = self.get_ready_prompt(path) - if self.return_audio: - sr, duration = get_sr_and_duration_info(item) - audio, is_start, is_end = self.audio_reader(path, sr, duration) - else: - audio = None - if self.return_path: - return audio, description, path, is_start, is_end - else: - return audio, description, is_start, is_end - -class Pond5Dataset(Dataset): - MAX_PROMPT_LEN = 200 - def __init__(self, - metadata_path:str, - index_path:str, - duration:float=10, - sr:int = 0, - plain_rate = 0, - return_path = False, - return_audio = True, - lang = 'en', - translate:Optional[Dict[str, os.PathLike]] = None, - use_literal_none = True, - use_avoid_watermark_policy = None, - ): - - if use_avoid_watermark_policy is None: - raise 
ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type") - self.use_avoid_watermark_policy = use_avoid_watermark_policy - assert self.use_avoid_watermark_policy is False - self.audio_reader = SafeAudioReader(duration, sr) - - self.duration = duration - self._load_metadata(metadata_path, index_path) - self.sr = sr - self.plain_rate = plain_rate - self.return_path = return_path - self.return_audio = return_audio - self.use_literal_none = use_literal_none - - self.lang = lang - self.translate = read_translate(translate) - - def _load_metadata(self, metadata_path, index_path): - data_index = read_jsonlike(index_path) - data_ids = set([item['id'] for item in data_index]) - - with open(metadata_path) as fp: - lines = fp.readlines() - - append_ids = set() - - self.data = [] - for line in lines: - item = json.loads(line) - if item['id'] in data_ids and item['id'] not in append_ids and item["details"]["duration"] is not None and item["details"]["duration"]>self.duration+10: - self.data.append(item) - append_ids.add(item['id']) - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - item = self.data[idx] - path:str = item["path"] - description = self.generate_description(item) - if self.return_audio: - sr, duration = get_sr_and_duration_info(item) - audio, is_start, is_end = self.audio_reader(path, sr, duration) - else: - audio = None - if self.return_path: - return audio, description, path - return audio, description, is_start, is_end - - @property - def keysep(self): - if self.lang == 'zh': - return ',' if random.random() > 0.5 else '、' - elif self.lang == 'en': - return ', ' - - def generate_description(self, item): - if random.random() > self.plain_rate: - # dynamically generate prompt from given prompt template - description = self.generate_description_dynamic(item) - else: - # use plain prompt, i.e. 
tags sequence separated by comma - description = self.generate_description_plain(item) - return description - - def get_translation(self, k): - k = k.strip() - if k in self.translate: - return self.translate[k] - else: - return k - - def generate_description_plain(self, item): - keywords = item['keywords'] - if self.lang != 'en': - keywords = [self.get_translation(k) for k in keywords] - return gen_plain_prompt(keywords, sep=self.keysep) - - def generate_description_dynamic(self,item): - desc = item.get('desc', 'none') - if desc is None: - desc = 'none' - desc = desc.strip() - if len(desc) > self.MAX_PROMPT_LEN: - shorter_desc = desc[:self.MAX_PROMPT_LEN] - # find last stop - stop_idx = shorter_desc.rfind('.') - if stop_idx == -1: - stop_idx = shorter_desc.rfind('!') - if stop_idx == -1: - stop_idx = shorter_desc.rfind(',') - if stop_idx == -1: - stop_idx = self.MAX_PROMPT_LEN - 1 - desc = desc[:stop_idx+1] - return desc - -class CombinedDataset(Dataset): - @beartype - def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]): - self.datasets = datasets - self.datasets_index = [] - - for i,dataset in enumerate(datasets): - if dataset is None: - continue - for dup in range(ratios[i]): - for j in range(len(dataset)): - self.datasets_index.append((i,j)) - - def __len__(self): - return len(self.datasets_index) - - def __getitem__(self, idx): - index = self.datasets_index[idx] - i,j = index - return self.datasets[i][j] - -class CombinedDataset_random(Dataset): - @beartype - def __init__(self, - num_examples:int, - datasets: Sequence[Dataset], ratios: Sequence[int] - ): - self.datasets = datasets - self.datasets_index = [] - - for i,dataset in enumerate(datasets): - if dataset is None: - continue - for dup in range(ratios[i]): - for j in range(len(dataset)): - self.datasets_index.append((i,j)) - if num_examples > 0: - self.random_choose = True - self.dataset_len = num_examples - else: - self.random_choose = False - self.dataset_len = len(self.datasets_index) - - def __len__(self): - return self.dataset_len - - def __getitem__(self, idx): - first_try = True - try_cnt = 0 - while True: - try: - if(self.random_choose or not first_try): - index2 = [] - index2.append(np.random.randint(0,len(self.datasets))) - index2.append(np.random.randint(0,len(self.datasets[index2[-1]]))) - else: - index2 = self.datasets_index[idx] - first_try = False - out = self.datasets[index2[0]][index2[1]] - if(len(out[0].shape)==1):out[0]=out[0][None,:] - return out - except: - print("Error loadding ", index2) - try_cnt += 1 - if(try_cnt>10): - raise FileNotFoundError() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py deleted file mode 100644 index 46d619c298718d4869dcdb54a420a1d080fac217..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py +++ /dev/null @@ -1,313 +0,0 @@ -import re -import sys -import json - -from torch.utils.data import Dataset -import torchaudio -from torchaudio.functional import resample -import torch -import numpy as np - -from torch.nn.utils.rnn import pad_sequence - - - -def check_lryics(lyric): - _FILTER_STRING = [ - '作词', '作曲', '编曲', '【', '策划', - '录音', '混音', '母带', ':', '制作', - '版权', '校对', '演奏', '制作', '伴奏' - ] - for item in _FILTER_STRING: - if item in lyric: - return True - - return False - - - -def process_lyrics(lines): - lyric_part = [] - timestamp_part = [] - - timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]') - - for i, line in 
enumerate(lines): - - # 删除前几行的特定信息 - if i<10 and check_lryics(line): - continue - - # 检查是否包含有效的时间戳和歌词内容 - if timestamp_pattern.match(line): - timestamp_end = line.rfind(']') - lyrics = line[timestamp_end + 1:].strip() - timestamps = line[:timestamp_end + 1] - - if ':' in lyrics: - if len(lyrics.split(":")[0]) <=5: - lyrics = "".join(lyrics.split(":")[1:]) - # if lyrics: # 确保歌词部分不是空的 - # lyric_part.append(lyrics) - # timestamp_part.append(timestamps) - # print(processed_lyrics) - return timestamp_part, lyric_part - -def get_timestamps(timestamp_part): - - # 转换为秒 - - timestamps = [] - - for line in timestamp_part: - match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line) - if match: - minutes = int(match.group(1)) - seconds = float(match.group(2)) - millis = float(match.group(3)) if match.group(3) else 0 - total_seconds = minutes * 60 + seconds + millis - timestamps.append(total_seconds) - - - return timestamps - -def process_lyrics_lrc(lyrics): - timestamp_part, lyric_part = process_lyrics(lyrics) - # print(timestamp_part) - # print(lyric_part) - timestamps = get_timestamps(timestamp_part) - # print(timestamps) - if len(timestamps) == 0: - # print(f'{lyric_path}') - return [] - - slice_start = timestamps[0] - slice_start_idx = 0 - - output_list = [] - for i in range(1, len(timestamps)): - # 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉 - if timestamps[i] - slice_start > 30: - output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) - - slice_start = timestamps[i] - slice_start_idx = i - - return output_list - - - -def process_lyrics_yrc(lyrics): - - timestamps, lyric_part = extract_lrc(lyrics) - - # timestamp_part, lyric_part = process_lyrics(lyrics) - # import pdb; pdb.set_trace() - # print(timestamp_part) - # print(lyric_part) - # timestamps = get_timestamps(timestamp_part) - # print(timestamps) - if len(timestamps) == 0: - # print(f'{lyric_path}') - return [] - - slice_start = timestamps[0] - slice_start_idx = 0 - - output_list = [] - for i in range(1, len(timestamps)): - # 如果累积时间超过30秒,则进行切分 - if timestamps[i] - slice_start > 30: - output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) - - slice_start = timestamps[i] - slice_start_idx = i - # import pdb; pdb.set_trace() - return output_list - -def extract_lrc(lyrics): - timestamp_part, lyric_part = [], [] - - for i, text in enumerate(lyrics): - # 提取中括号内的内容 - bracket_content = re.search(r'\[(.*?)\]', text).group(1) - bracket_content = bracket_content.split(',') - # 提取小括号内的内容 - parentheses_content = re.findall(r'\((.*?)\)', text) - # 提取其他内容 - other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip() - - # 数据怎么处理? 
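-        # Illustrative only (hypothetical input): for a yrc line like '[1530,480]Hello (x2)',
-        # bracket_content becomes ['1530', '480'], parentheses_content ['x2'], and other_content
-        # 'Hello'; the 1530 is in milliseconds, so the timestamp appended below is 1.53 s.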
-        if i<10 and check_lryics(other_content):
-            continue
-        timestamp_part.append(float(bracket_content[0])/1000)  # yrc timestamps are in milliseconds
-        lyric_part.append(other_content)
-    return timestamp_part, lyric_part
-
-
-
-class WYYSongDataset(Dataset):
-    def __init__(self,
-                 metadata_path:str,
-                 sr:int = 0,
-                 use_lang = ['en', 'zh-cn'],
-                 num_examples = -1,
-                ):
-        self.sr = sr
-        self.use_lang = use_lang
-        self._load_metadata(metadata_path)
-
-        # buffer for parsed lyrics
-        self.lyric_buffer = {}
-
-        if(num_examples<=0):
-            self.dataset_len = len(self.data)
-            self.random_slc = False
-        else:
-            self.dataset_len = num_examples
-            self.random_slc = True
-
-    # read the jsonl metadata file
-    def _load_metadata(self, metadata_path):
-        with open(metadata_path) as fp:
-            lines = fp.readlines()
-        self.data = []
-        for line in lines:
-            item = json.loads(line)
-            if 'lyrics' in item and 'lang_info' in item:
-                if len(item['lyrics']) > 0:
-                    for lang in self.use_lang:
-                        if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9:
-                            if '伴奏' not in item['path']:  # skip accompaniment-only ("伴奏") tracks
-                                self.data.append(item)
-
-    def __len__(self):
-        return self.dataset_len
-
-    def __getitem__(self, idx):
-        try_cnt = 0
-        while True:
-            if(self.random_slc):
-                idx = np.random.randint(0, len(self.data))
-            yrc_lyrics = []
-            lrc_lyrics = []
-            try:
-                info = self.data[idx]
-
-                # audio path
-                path:str = info["path"]
-
-                # load the lyric segments
-                if 'lyrics' not in info:
-                    if idx not in self.lyric_buffer:
-                        # word-level aligned lyrics (yrc)
-                        if info['yrc-lyric'] is not None:
-                            with open(info['yrc-lyric']) as f_in:
-                                yrc_lyric = json.load(f_in)
-                            yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1])
-                        # sentence-level aligned lyrics (lrc)
-                        if info['lrc-lyric'] is not None:
-                            with open(info['lrc-lyric']) as f_in:
-                                lrc_lyric = json.load(f_in)
-                            lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1])
-                        # prefer word-level aligned lyrics when available
-                        if len(yrc_lyrics) > 0:
-                            lyrics = yrc_lyrics
-                        else:
-                            lyrics = lrc_lyrics
-                        self.lyric_buffer[idx] = lyrics
-                        # TODO: filter each lyric segment by length, dropping segments that are too long or too short
-                    else:
-                        lyrics = self.lyric_buffer[idx]
-                else:
-                    lyrics = info['lyrics']
-
-                # randomly pick one lyric segment
-                ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
-                lyric = lyrics[ly_id]
-
-                st, et, lyric = self.parse_lyric(lyric)
-                assert et - st < 40
-
-                # text filtering (the checks below operated on the list `lyrics` in the source; the
-                # single string `lyric` is clearly what was intended, so it is used here)
-                lyric = re.sub(r'【.*?】', '', lyric)
-                if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8:
-                    assert 200 > len(lyric.replace(" ", "")) > 30
-                    # strip a short speaker/credit prefix before ':', applied twice for nested prefixes
-                    if ':' in lyric:
-                        if len(lyric.split(":")[0]) <= 5:
-                            lyric = "".join(lyric.split(":")[1:])
-                    if ':' in lyric:
-                        if len(lyric.split(":")[0]) <= 5:
-                            lyric = "".join(lyric.split(":")[1:])
-                if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8:
-                    assert 200 > len(lyric.split()) > 20
-                    if ':' in lyric:
-                        if len(lyric.split(":")[0].split()) <= 3:
-                            lyric = "".join(lyric.split(":")[1:])
-                    if ':' in lyric:
-                        if len(lyric.split(":")[0].split()) <= 3:
-                            lyric = "".join(lyric.split(":")[1:])
-
-                # load the audio segment for this lyric
-                cur_sample_rate = torchaudio.info(path).sample_rate
-                offset = int(cur_sample_rate*st)
-                num_frames = int(cur_sample_rate * (et - st))
-                chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
-
-                # randomly pick one channel
-                if(chunk.shape[0]>1):
-                    chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
else: - chunk = chunk[[0],:].float() - - if(cur_sample_rate!=self.sr): - # print('a:',cur_sample_rate,chunk.shape) - chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr) - - return chunk, lyric, [st, et], path - except: - print("Error loadding ", info["path"]) - try_cnt += 1 - idx = np.random.randint(0, len(self.data)) - if(try_cnt>10): - raise FileNotFoundError() - - def parse_lyric(self, lyric): - pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)' - match = re.search(pattern, lyric) - - start_time = float(match.group(1)) - end_time = float(match.group(2)) - content = match.group(3) - return start_time, end_time, content - -def collect_song(data_list): - audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2) - lyrics = [data[1] for data in data_list] - st_et = [data[2] for data in data_list] - paths = [data[3] for data in data_list] - return audios, lyrics, st_et diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py deleted file mode 100644 index 991c59786412dc7f2bd22c57c7e4a7e3d30e5776..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py +++ /dev/null @@ -1,313 +0,0 @@ -import re -import sys -import json - -from torch.utils.data import Dataset -import torchaudio -from torchaudio.functional import resample -import torch -import numpy as np - -from torch.nn.utils.rnn import pad_sequence - - - -def check_lryics(lyric): - _FILTER_STRING = [ - '作词', '作曲', '编曲', '【', '策划', - '录音', '混音', '母带', ':', '制作', - '版权', '校对', '演奏', '制作', '伴奏' - ] - for item in _FILTER_STRING: - if item in lyric: - return True - - return False - - - -def process_lyrics(lines): - lyric_part = [] - timestamp_part = [] - - timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]') - - for i, line in enumerate(lines): - - # 删除前几行的特定信息 - if i<10 and check_lryics(line): - continue - - # 检查是否包含有效的时间戳和歌词内容 - if timestamp_pattern.match(line): - timestamp_end = line.rfind(']') - lyrics = line[timestamp_end + 1:].strip() - timestamps = line[:timestamp_end + 1] - - if ':' in lyrics: - if len(lyrics.split(":")[0]) <=5: - lyrics = "".join(lyrics.split(":")[1:]) - # if lyrics: # 确保歌词部分不是空的 - # lyric_part.append(lyrics) - # timestamp_part.append(timestamps) - # print(processed_lyrics) - return timestamp_part, lyric_part - -def get_timestamps(timestamp_part): - - # 转换为秒 - - timestamps = [] - - for line in timestamp_part: - match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line) - if match: - minutes = int(match.group(1)) - seconds = float(match.group(2)) - millis = float(match.group(3)) if match.group(3) else 0 - total_seconds = minutes * 60 + seconds + millis - timestamps.append(total_seconds) - - - return timestamps - -def process_lyrics_lrc(lyrics): - timestamp_part, lyric_part = process_lyrics(lyrics) - # print(timestamp_part) - # print(lyric_part) - timestamps = get_timestamps(timestamp_part) - # print(timestamps) - if len(timestamps) == 0: - # print(f'{lyric_path}') - return [] - - slice_start = timestamps[0] - slice_start_idx = 0 - - output_list = [] - for i in range(1, len(timestamps)): - # 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉 - if timestamps[i] - slice_start > 30: - output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) - - slice_start = timestamps[i] - slice_start_idx = i - - return output_list - - - -def process_lyrics_yrc(lyrics): - - timestamps, lyric_part = extract_lrc(lyrics) 
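-    # Illustrative only: if extract_lrc returns timestamps [0.0, 12.4, 31.8, 65.2], the slicing
-    # loop below cuts a window once it exceeds 30 s, emitting e.g. '[0.0:31.8]line1, line2' and
-    # then starting a new slice at 31.8 s.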
- - # timestamp_part, lyric_part = process_lyrics(lyrics) - # import pdb; pdb.set_trace() - # print(timestamp_part) - # print(lyric_part) - # timestamps = get_timestamps(timestamp_part) - # print(timestamps) - if len(timestamps) == 0: - # print(f'{lyric_path}') - return [] - - slice_start = timestamps[0] - slice_start_idx = 0 - - output_list = [] - for i in range(1, len(timestamps)): - # 如果累积时间超过30秒,则进行切分 - if timestamps[i] - slice_start > 30: - output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) - - slice_start = timestamps[i] - slice_start_idx = i - # import pdb; pdb.set_trace() - return output_list - -def extract_lrc(lyrics): - timestamp_part, lyric_part = [], [] - - for i, text in enumerate(lyrics): - # 提取中括号内的内容 - bracket_content = re.search(r'\[(.*?)\]', text).group(1) - bracket_content = bracket_content.split(',') - # 提取小括号内的内容 - parentheses_content = re.findall(r'\((.*?)\)', text) - # 提取其他内容 - other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip() - - # 数据怎么处理? - # import pdb; pdb.set_trace() - if i<10 and check_lryics(other_content): - continue - - # import pdb; pdb.set_trace() - timestamp_part.append(float(bracket_content[0])/1000) - lyric_part.append(other_content) - # import pdb; pdb.set_trace() - return timestamp_part, lyric_part - - - -class WYYSongDataset(Dataset): - def __init__(self, - metadata_path:str, - sr:int = 0, - use_lang = ['en', 'zh-cn'], - num_examples = -1, - ): - - self.sr = sr - self.use_lang = use_lang - self._load_metadata(metadata_path) - - # buffer - self.lyric_buffer = {} - - if(num_examples<=0): - self.dataset_len = len(self.data) - self.random_slc = False - else: - self.dataset_len = num_examples - self.random_slc = True - - # 读取jsonl文件 - def _load_metadata(self, metadata_path): - with open(metadata_path) as fp: - lines = fp.readlines() - self.data = [] - for line in lines: - item = json.loads(line) - # if item['lrc-lyric'] is not None and item['yrc-lyric'] is not None: - if 'lyrics' in item and 'lang_info' in item: - if len(item['lyrics']) > 0: - for lang in self.use_lang: - if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9: - # if '伴奏' not in item['path'] and "cloud" in item['path']: - if '伴奏' not in item['path']: - self.data.append(item) - - - def __len__(self): - return self.dataset_len - - - def __getitem__(self, idx): - try_cnt = 0 - while True: - if(self.random_slc): - idx = np.random.randint(0, len(self.data)) - yrc_lyrics = [] - lrc_lyrics = [] - try: - info = self.data[idx] - - # audio path - path:str = info["path"] - - # 读取歌词段落 - if 'lyrics' not in info: - if idx not in self.lyric_buffer: - # 字级别align的歌词 - if info['yrc-lyric'] is not None: - with open(info['yrc-lyric']) as f_in: - yrc_lyric = json.load(f_in) - yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1]) - - # 句子级align的歌词 - if info['lrc-lyric'] is not None: - with open(info['lrc-lyric']) as f_in: - lrc_lyric = json.load(f_in) - lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1]) - - # 优先使用字级别align的歌词 - if len(yrc_lyrics) > 0: - lyrics = yrc_lyrics - else: - lyrics = lrc_lyrics - self.lyric_buffer[idx] = lyrics - - # TODO 每段歌词进行长度筛选,过滤掉太长和太短的歌曲 - else: - lyrics = self.lyric_buffer[idx] - else: - lyrics = info['lyrics'] - - # 随机选取一个lyric段落 - ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item() - # ly_id = 0 - - lyric = lyrics[ly_id] - - - - st, et, lyric = self.parse_lyric(lyric) - - assert et - st < 20 - - # 文本过滤 - - 
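-                # Illustrative only: the prefix-stripping below turns a credit line such as
-                # '演唱:月光下我们相遇' into '月光下我们相遇' (the part before ':' is dropped when it is at
-                # most 5 characters), and it is applied twice to catch nested prefixes.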
lyric = re.sub(r'【.*?】', '', lyric)
-                if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8:
-                    assert 100 > len(lyric.replace(" ", "")) > 5
-                    # strip a short "speaker:" style prefix; applied twice to handle nested prefixes
-                    if ':' in lyrics:
-                        if len(lyrics.split(":")[0]) <= 5:
-                            lyrics = "".join(lyrics.split(":")[1:])
-
-                    if ':' in lyrics:
-                        if len(lyrics.split(":")[0]) <= 5:
-                            lyrics = "".join(lyrics.split(":")[1:])
-
-                if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8:
-                    assert 100 > len(lyric.split()) > 5
-                    # same prefix stripping, keyed on word count for English
-                    if ':' in lyrics:
-                        if len(lyrics.split(":")[0].split()) <= 3:
-                            lyrics = "".join(lyrics.split(":")[1:])
-
-                    if ':' in lyrics:
-                        if len(lyrics.split(":")[0].split()) <= 3:
-                            lyrics = "".join(lyrics.split(":")[1:])
-
-                # load the audio segment
-                cur_sample_rate = torchaudio.info(path).sample_rate
-                offset = int(cur_sample_rate * st)
-                num_frames = int(cur_sample_rate * (et - st))
-                chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
-
-                # randomly pick one channel
-                if chunk.shape[0] > 1:
-                    chunk = chunk[torch.randint(chunk.shape[0], size=(1,)), :].float()
-                else:
-                    chunk = chunk[[0], :].float()
-
-                if cur_sample_rate != self.sr:
-                    chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
-
-                return chunk, lyric, [st, et], path
-            except Exception:
-                print("Error loading ", info["path"])
-                try_cnt += 1
-                idx = np.random.randint(0, len(self.data))
-                if try_cnt > 10:
-                    raise FileNotFoundError()
-
-    def parse_lyric(self, lyric):
-        pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
-        match = re.search(pattern, lyric)
-
-        start_time = float(match.group(1))
-        end_time = float(match.group(2))
-        content = match.group(3)
-        return start_time, end_time, content
-
-def collect_song(data_list):
-    audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1, 2)
-    lyrics = [data[1] for data in data_list]
-    st_et = [data[2] for data in data_list]
-    paths = [data[3] for data in data_list]
-    return audios, lyrics, st_et
diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py
deleted file mode 100644
index ab395273c6270912f5d84df71c70386f5eeab71b..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py
+++ /dev/null
@@ -1,313 +0,0 @@
-import re
-import sys
-import json
-
-from torch.utils.data import Dataset
-import torchaudio
-from torchaudio.functional import resample
-import torch
-import numpy as np
-
-from torch.nn.utils.rnn import pad_sequence
-
-
-def check_lryics(lyric):
-    # these (Chinese) markers flag production-credit lines, e.g. lyricist,
-    # composer, arranger, recording, mixing, mastering, copyright, accompaniment;
-    # such lines are metadata rather than lyrics
-    _FILTER_STRING = [
-        '作词', '作曲', '编曲', '【', '策划',
-        '录音', '混音', '母带', ':', '制作',
-        '版权', '校对', '演奏', '制作', '伴奏'
-    ]
-    for item in _FILTER_STRING:
-        if item in lyric:
-            return True
-
-    return False
-
-
-def process_lyrics(lines):
-    lyric_part = []
-    timestamp_part = []
-
-    timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]')
-
-    for i, line in enumerate(lines):
-        # drop credit/metadata lines that typically sit in the first few rows
-        if i < 10 and check_lryics(line):
-            continue
-
-        # keep only lines that carry a valid timestamp plus lyric content
-        if timestamp_pattern.match(line):
-            timestamp_end = line.rfind(']')
-            lyrics = line[timestamp_end + 1:].strip()
-            timestamps = line[:timestamp_end + 1]
-
-            if ':' in lyrics:
-                if len(lyrics.split(":")[0]) <= 5:
-                    lyrics = "".join(lyrics.split(":")[1:])
-            if lyrics:  # make sure the lyric part is not empty
-                lyric_part.append(lyrics)
-                timestamp_part.append(timestamps)
-    return timestamp_part, lyric_part
-
-def get_timestamps(timestamp_part):
-    # convert [mm:ss.xx] timestamps to seconds
-    timestamps = []
-    for line in timestamp_part:
-        match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line)
-        if match:
-            minutes = int(match.group(1))
-            seconds = float(match.group(2))
-            # group(3) is the fractional-second part, e.g. ".52"
-            millis = float(match.group(3)) if match.group(3) else 0
-            total_seconds = minutes * 60 + seconds + millis
-            timestamps.append(total_seconds)
-
-    return timestamps
-
-def process_lyrics_lrc(lyrics):
-    timestamp_part, lyric_part = process_lyrics(lyrics)
-    timestamps = get_timestamps(timestamp_part)
-    if len(timestamps) == 0:
-        return []
-
-    slice_start = timestamps[0]
-    slice_start_idx = 0
-
-    output_list = []
-    for i in range(1, len(timestamps)):
-        # split whenever the accumulated span exceeds 30 s; note that a song
-        # shorter than 30 s overall yields no slice at all
-        if timestamps[i] - slice_start > 30:
-            output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
-
-            slice_start = timestamps[i]
-            slice_start_idx = i
-
-    return output_list
-
-
-def process_lyrics_yrc(lyrics):
-    timestamps, lyric_part = extract_lrc(lyrics)
-    if len(timestamps) == 0:
-        return []
-
-    slice_start = timestamps[0]
-    slice_start_idx = 0
-
-    output_list = []
-    for i in range(1, len(timestamps)):
-        # split whenever the accumulated span exceeds 30 s
-        if timestamps[i] - slice_start > 30:
-            output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i]))
-
-            slice_start = timestamps[i]
-            slice_start_idx = i
-    return output_list
-
-def extract_lrc(lyrics):
-    timestamp_part, lyric_part = [], []
-
-    for i, text in enumerate(lyrics):
-        # content inside the leading square brackets: "start,duration" in milliseconds
-        bracket_content = re.search(r'\[(.*?)\]', text).group(1)
-        bracket_content = bracket_content.split(',')
-        # per-word timing inside the parentheses (currently unused)
-        parentheses_content = re.findall(r'\((.*?)\)', text)
-        # whatever remains is the lyric text itself
-        other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip()
-
-        # TODO: decide how this data should be handled
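-        # For reference, a karaoke-style YRC line looks roughly like
-        #     [13680,3160](13680,420,0)Hello(14100,380,0)world
-        # (the concrete layout is inferred from the parsing above, so treat it
-        # as an assumption): the leading bracket holds "start_ms,duration_ms"
-        # for the whole line and each parenthesised group times a single word,
-        # so extract_lrc() reduces this line to (13.68, "Helloworld").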
-        if i < 10 and check_lryics(other_content):
-            continue
-        timestamp_part.append(float(bracket_content[0]) / 1000)
-        lyric_part.append(other_content)
-    return timestamp_part, lyric_part
-
-
-class WYYSongDataset(Dataset):
-    def __init__(self,
-                 metadata_path: str,
-                 sr: int = 0,
-                 use_lang=['en', 'zh-cn'],
-                 num_examples=-1,
-                 max_dur=20,
-                 pad_to_max=True,
-                 ):
-        self.sr = sr
-        self.use_lang = use_lang
-        self._load_metadata(metadata_path)
-        self.max_dur = max_dur
-        self.pad_to_max = pad_to_max
-
-        # buffer
-        self.lyric_buffer = {}
-
-        if num_examples <= 0:
-            self.dataset_len = len(self.data)
-            self.random_slc = False
-        else:
-            self.dataset_len = num_examples
-            self.random_slc = True
-
-    # load the jsonl metadata file
-    def _load_metadata(self, metadata_path):
-        with open(metadata_path) as fp:
-            lines = fp.readlines()
-        self.data = []
-        for line in lines:
-            item = json.loads(line)
-            # skip accompaniment ('伴奏') tracks
-            if '伴奏' not in item['path']:
-                # if "lang_type" in item and item['lang_type'] == 'en':
-                if "lang_type" in item:
-                    self.data.append(item)
-
-    def __len__(self):
-        return self.dataset_len
-
-    def __getitem__(self, idx):
-        try_cnt = 0
-        while True:
-            if self.random_slc:
-                idx = np.random.randint(0, len(self.data))
-            yrc_lyrics = []
-            lrc_lyrics = []
-            try:
-                info = self.data[idx]
-
-                # audio path
-                path = info["path"]
-                lang_type = info["lang_type"]
-                if info["lang_type"] == 'en':
-                    lyrics = info['lyrics']
-                else:
-                    lyrics = info['lyrics_phone']
-
-                # randomly pick one lyric segment
-                ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item()
-                lyric = lyrics[ly_id].strip()
-
-                st, et, lyric = self.parse_lyric(lyric)
-                lyric = lyric.replace("\xa0", " ")
-                lyric = " ".join(lyric.split())
-
-                assert et - st < self.max_dur
-
-                # sanity-check the word rate (words per second)
-                if info["lang_type"] == 'en':
-                    assert 6 > len(lyric.split()) / (et - st) > 1
-                else:
-                    lyric = lyric.replace("-", "")
-                    assert 6 > len(lyric.split()) / (et - st) > 1
-
-                # load the audio segment
-                cur_sample_rate = torchaudio.info(path).sample_rate
-                offset = int(cur_sample_rate * st)
-                num_frames = int(cur_sample_rate * (et - st))
-                chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames)
-
-                # randomly pick one channel
-                if chunk.shape[0] > 1:
-                    chunk = chunk[torch.randint(chunk.shape[0], size=(1,)), :].float()
-                else:
-                    chunk = chunk[[0], :].float()
-
-                if cur_sample_rate != self.sr:
-                    chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr)
-
-                if self.pad_to_max:
-                    chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0)
-
-                return chunk, lyric, et - st, path, lang_type
-            except Exception:
-                # print("Error loading ", info["path"])
-                try_cnt += 1
-                idx = np.random.randint(0, len(self.data))
-                if try_cnt > 20:
-                    raise FileNotFoundError()
-
-    def parse_lyric(self, lyric):
-        pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)'
-        match = re.search(pattern, lyric)
-
-        start_time = float(match.group(1))
-        end_time = float(match.group(2))
-        content = match.group(3)
-        return start_time, end_time, content
-
-    def pad_2d_tensor(self, x, max_len, pad_id):
-        # input shape: (batch_size, seq_len)
-        batch_size, seq_len = x.size()
-        max_len = max(max_len, seq_len)
-        # how much padding is needed
-        pad_len = max_len - seq_len
-
-        if pad_len > 0:
-            # build the padding tensor and concatenate along the time axis
-            pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device)
-            padded_tensor = torch.cat([x, pad_tensor], dim=1)
-        else:
-            # nothing to pad, return the input unchanged
-            padded_tensor = x
-
-        return
padded_tensor - -def collect_data(data_list): - audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2) - lyrics = [data[1] for data in data_list] - st_et = [data[2] for data in data_list] - paths = [data[3] for data in data_list] - lang_types = [data[4] for data in data_list] - return audios, lyrics, st_et, lang_types - # return audios, lyrics, st_et - - -def build_dataset(): - train_dataset = WYYSongDataset( - metadata_path = "train.jsonl", - sr = 48000, - use_lang = ['zh-cn', 'en'], - num_examples = 10*10000 - ) - - valid_dataset = WYYSongDataset( - metadata_path = "valid.jsonl", - sr = 48000, - use_lang = ['zh-cn', 'en'], - num_examples = 500 - ) - - return train_dataset, valid_dataset diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py deleted file mode 100644 index 693efb42b76cf4c15ed2d045997e92695af30b3a..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py +++ /dev/null @@ -1,461 +0,0 @@ -from torch.utils.data import Dataset -from beartype.typing import Sequence, Callable, Optional, Dict, List -from beartype.door import is_bearable -import random -import os -from torchaudio.functional import resample -import torch -import typing as tp -from pathlib import Path -import torchaudio as ta -import torch.nn.functional as F -import soundfile -import numpy as np -import json -import yaml -import random -import librosa -from loguru import logger -import re - - -def _av_read(filepath, seek_time=0, duration=None): - if duration is not None: - sr = librosa.get_samplerate(filepath) - offset = seek_time - num_samples = int(duration * sr) - wav, _ = librosa.load(filepath, sr=sr, offset=offset, duration=duration) - else: - wav, sr = librosa.load(filepath, sr=None, offset=seek_time) - - return wav, sr - -def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., - duration: float = -1., pad: bool = True) -> tp.Tuple[torch.Tensor, int]: - """Read audio by picking the most appropriate backend tool based on the audio format. - - Args: - filepath (str or Path): Path to audio file to read. - seek_time (float): Time at which to start reading in the file. - duration (float): Duration to read from the file. If set to -1, the whole file is read. - pad (bool): Pad output audio if not reaching expected duration. - Returns: - tuple of torch.Tensor, int: Tuple containing audio data and sample rate. - """ - fp = Path(filepath) - if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg - # There is some bug with ffmpeg and reading flac - info = soundfile.info(filepath) - frames = -1 if duration <= 0 else int(duration * info.samplerate) - frame_offset = int(seek_time * info.samplerate) - wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32) - assert info.samplerate == sr, f"Mismatch of sample rates {info.samplerate} {sr}" - wav = torch.from_numpy(wav).t().contiguous() - if len(wav.shape) == 1: - wav = torch.unsqueeze(wav, 0) - elif ( - fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats() - and duration <= 0 and seek_time == 0 - ): - # Torchaudio is faster if we load an entire file at once. 
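-        # NB: despite the comment above, this fast path as committed decodes
-        # through librosa (mono, native sample rate); a direct
-        # torchaudio.load(fp) would be the drop-in alternative if the
-        # torchaudio backend is actually preferred here.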
- wav, sr = librosa.load(fp, sr=None, mono=True) - else: - wav, sr = _av_read(filepath, seek_time, duration) - if pad and duration > 0: - expected_frames = int(duration * sr) - wav = F.pad(torch.tensor(wav), (0, expected_frames - wav.shape[-1])) - if not isinstance(wav, torch.Tensor): - wav = torch.tensor(wav) - return wav, sr - -def random_seek_read(filepath, duration): - if duration > 0: - total_duration = librosa.get_duration(path=filepath) - acceptable_start = max(0, total_duration - duration) - wav, sr = audio_read(filepath, random.uniform(0, acceptable_start), duration, pad=True) - else: - wav, sr = audio_read(filepath, 0, -1, pad=False) - return wav, sr - -def safe_random_seek_read(filepath, duration, sample_rate): - try: - wav, sr = random_seek_read(filepath, duration) - if sr != sample_rate: - wav = resample(wav, sr, sample_rate) - sr = sample_rate - except Exception as e: - logger.error(f"Error reading {filepath}: {e}") - sr = sample_rate - wav = torch.zeros(sr * max(duration, 0), dtype=torch.float32) - return wav, sr - -def read_jsonlike(path: os.PathLike): - #json or jsonl - if str(path).endswith(".json"): - with open(path, 'r', encoding='utf8') as f: - data = json.load(f) - return data - elif str(path).endswith(".jsonl"): - with open(path, 'r', encoding='utf8') as f: - data = [json.loads(line) for line in f.readlines()] - return data - else: - raise ValueError("Unknown file format") - -dist_prob_map = { - 1: (1.0,), - 2: (0.5, 0.5), - 3: (0.3, 0.4, 0.3), - 4: (0.2, 0.3, 0.3, 0.2), - 5: (0.2, 0.2, 0.3, 0.2, 0.1), - 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15), - 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1), - 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12), - 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08), - 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09) -} - -dist_prob_map_low = { - 1: (1.0,), - 2: (0.8, 0.2), - 3: (0.8, 0.1, 0.1), - 4: (0.7, 0.1, 0.1, 0.1), - 5: (0.7, 0.1, 0.1, 0.05, 0.05), - 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05), -} - - -_bpm_range_rights = ( - (40, '20-40'), - (60, '40-60'), - (66, '60-66'), - (76, '66-76'), - (108, '76-108'), - (120, '108-120'), - (168, '120-168'), - (176, '168-176'), - (200, '176-200') -) -_bpm_desc_map = { - '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"), - '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"), - '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'), - '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'), - '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"), - '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'), - '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'), - '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'), - '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'), - '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo') -} -_bpm_desc_map_zh = { - '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"), - '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"), - '60-66': ("柔和的节奏", 
"悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'), - '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'), - '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"), - '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'), - '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'), - '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'), - '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'), - '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板') -} -def get_bpm_range(bpm): - bpm = int(bpm) - for right, tag in _bpm_range_rights: - if bpm <= right: - return tag - return '>200' - -def gen_bpm_descript(bpm, lang='en'): - bpm_range = get_bpm_range(bpm) - if lang == 'en': - return random.choice(_bpm_desc_map[bpm_range]) - elif lang == 'zh': - return random.choice(_bpm_desc_map_zh[bpm_range]) - else: - raise ValueError(f"Unknown language {lang}") - -def read_translate(translate: Optional[Dict[str, os.PathLike]]): - if translate is None: - return None - return {k: read_jsonlike(path) for k, path in translate.items()} - - -def tags_to_desc(tag_list, sep=',') -> str: - if not isinstance(tag_list, Sequence): - return str(tag_list) - if isinstance(tag_list, str): - return tag_list - if len(tag_list) <= 0: - return '' - elif len(tag_list) <= 5: - probs = dist_prob_map[len(tag_list)] - tags_num = random.choices(range(1, len(tag_list)+1), probs)[0] - random.shuffle(tag_list) - tag_list = tag_list[:tags_num] - return sep.join(tag_list) - else: - probs = dist_prob_map[5] - tags_num = random.choices(range(1, 6), probs)[0] - random.shuffle(tag_list) - tag_list = tag_list[:tags_num] - return sep.join(tag_list) - - -class PromptTemplate: - def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'): - self.template_text = template_text - self.tag_map = tag_map - self.lang = lang - - @property - def tags(self): - return tuple(self.tag_map.keys()) - - def apply(self, **kwargs): - for tag in list(kwargs.keys()): - if kwargs[tag] == '': - kwargs.pop(tag) - for tag in self.tags: - if tag in kwargs: - kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]') - else: - kwargs[tag] = '' - prompt = self.template_text.format(**kwargs) - - return self.beautify(prompt) - - def beautify(self, text): - if self.lang == 'en': - return self._beautify_en(text) - elif self.lang == 'zh': - return self._beautify_zh(text) - else: - raise ValueError(f'Unknown language {self.lang}') - - @staticmethod - def _beautify_en(text): - # no continuous commas without content between them - text = re.sub(r'[,\s]*,[,\s]*', r', ', text) - # no continuous whitespace - text = re.sub(r'\s+', ' ', text) - # the comma is NOT followed by whitespace, and should be followed by ONE whitespace - text = re.sub(r'\s+,', r',', text) - text = re.sub(r',\s+', r', ', text) - # no whitespace before the full stop - text = re.sub(r'\s+\.', r'.', text) - # strip whitespace, comma, and replace ',.' 
-        text = text.strip(' ,')
-        text = text.replace(',.', '.')
-        return text
-
-    @staticmethod
-    def _beautify_zh(text):
-        # no continuous commas without content between them
-        text = re.sub(r'[,、\s]*,[,、\s]*', r',', text)
-        text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text)
-        # assume there should be NO whitespace in Chinese
-        text = re.sub(r'\s+', r'', text)
-        # strip whitespace/comma, and replace ',。'
-        text = text.strip(', 、')
-        text = text.replace(',。', '。')
-        return text
-
-    def __repr__(self):
-        return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})'
-
-    __str__ = __repr__
-
-def parse_prompt_template(prompt_template_text, lang='en'):
-    span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL)
-    tag_pattern = re.compile(r'{.+?}', re.DOTALL)
-
-    template_text = prompt_template_text.strip()
-    span_texts = span_pattern.findall(prompt_template_text)
-    tag_map = {}
-    for span_text in span_texts:
-        tag = tag_pattern.findall(span_text)[0].strip('{}')
-        tag_map[tag] = span_text
-        template_text = template_text.replace(span_text, '{' + tag + '}')
-
-    return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang)
-
-def load_prompt_templates(path, num=5, lang='en') -> List[PromptTemplate]:
-    with open(path, 'r') as f:
-        lines = f.readlines()
-    cnt = 0
-    pts = []
-    for line in lines:
-        pt = parse_prompt_template(line, lang=lang)
-        cnt += 1
-        if len(pt.tags) < num:
-            logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}')
-        pts.append(pt)
-
-    return pts
-
-
-class AudioStockDataset(Dataset):
-    def __init__(self,
-                 num_examples: int,
-                 metadata_path: str,
-                 duration: float = 60,
-                 sr: int = 0,
-                 return_path=False,
-                 return_audio=True,
-                 prompt_template_path: os.PathLike = None,
-                 tag_types=[],
-                 lang='en',
-                 translate: Optional[Dict[str, os.PathLike]] = None
-                 ):
-        self.duration = duration
-        self.MAX_DURATION = 360
-        self._load_metadata(metadata_path)
-        if num_examples > 0:
-            self.random_choose = True
-            self.dataset_len = num_examples
-        else:
-            self.random_choose = False
-            self.dataset_len = len(self.data)
-        self.sr = sr
-        self.return_path = return_path
-        self.return_audio = return_audio
-
-        self.use_dynamic_prompt = prompt_template_path is not None
-        if self.use_dynamic_prompt:
-            self.prompt_templates = load_prompt_templates(prompt_template_path, num=len(tag_types), lang=lang)
-        self.tag_types = tag_types
-
-        self.lang = lang
-        self.translate = read_translate(translate)
-
-    def _load_metadata(self, metadata_path):
-        total_len = 0; valid_len = 0
-        with open(metadata_path) as fp:
-            lines = fp.readlines()
-        self.data = []
-        for line in lines:
-            item = json.loads(line)
-            total_len += 1
-            if item['duration'] > self.duration and item['duration'] < self.MAX_DURATION:
-                self.data.append(item)
-                valid_len += 1
-        # whether the tag info is stored inline in the metadata (vs. a sidecar .json)
-        self.is_info_recorded = bool(len(self.data) > 0 and 'Tags' in self.data[0])
-
-    def __len__(self):
-        return self.dataset_len
-
-    def __getitem__(self, idx):
-        try_cnt = 0
-        while True:
-            if self.random_choose:
-                idx = np.random.randint(0, len(self.data))
-            try:
-                return self.getitem_main(idx)
-            except Exception:
-                try_cnt += 1
-                idx = np.random.randint(0, len(self.data))
-                if try_cnt > 10:
-                    raise ValueError()
-
-    def getitem_main(self, idx):
-        path: str = self.data[idx]["path"]
-        json_path = path[:path.rfind('.')] + ".json"
-        if self.is_info_recorded:
-            item = self.data[idx]
-        else:
-            with open(json_path) as fp:
-                item: dict = json.load(fp)
-        description = self.generate_description(item)
-        if self.return_audio:
-            audio, sr = safe_random_seek_read(path, duration=self.duration, sample_rate=self.sr)
-        else:
-            audio = None
-        if self.return_path:
-            return audio, description, path
-        return audio, description
-
-    def generate_description(self, item):
-        if self.use_dynamic_prompt:
-            # dynamically generate the prompt from a randomly chosen template
-            prompt_template = random.choice(self.prompt_templates)
-            description = self.generate_description_dynamic(item, prompt_template)
-        else:
-            # use the ordinary static prompt instead
-            description = self.generate_description_ordinary(item)
-        return description
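-    # A minimal usage sketch of the dynamic templating above, assuming a
-    # hypothetical one-line template entry (not taken from a real template file):
-    #     pt = parse_prompt_template('A [{Genre} piece] [with {Instrument}] [at a {BPMDescript} pace].')
-    #     pt.tags                                         # ('Genre', 'Instrument', 'BPMDescript')
-    #     pt.apply(Genre='jazz', BPMDescript='moderate')  # 'A jazz piece at a moderate pace.'
-    # Spans whose tag is not supplied, here [with {Instrument}], are blanked
-    # out, and the beautify pass collapses the leftover whitespace.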
-    def generate_description_dynamic(self, data, prompt_template: PromptTemplate):
-        exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)]
-
-        if len(exists_tag) > 0:
-            probs = dist_prob_map[len(exists_tag)]
-            tags_num = random.choices(range(1, len(exists_tag) + 1), probs)[0]
-            random.shuffle(exists_tag)
-            tags = exists_tag[:tags_num]
-            tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags}
-            tags_args = self.handle_BPM_tag(tags_args)
-            prompt = prompt_template.apply(**tags_args)
-        else:
-            # no strong tags, use all weak tags instead
-            prompt = prompt_template.apply()
-
-        return prompt
-
-    def tags_to_desc(self, tag_list, tag_type) -> str:
-        if self.lang == 'en':
-            return tags_to_desc(tag_list)
-        elif self.lang == 'zh':
-            if tag_type == 'BPM':
-                return tags_to_desc(tag_list, sep='、')
-            translator = self.translate[tag_type]
-            translated_tag_list = [translator[tag] for tag in tag_list if tag in translator]
-            return tags_to_desc(translated_tag_list, sep='、')
-
-    def handle_BPM_tag(self, tags_args):
-        if "BPM" in tags_args and 'BPMDescript' in self.tag_types:
-            bpm = tags_args["BPM"]
-            del tags_args["BPM"]
-            tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript')))
-            for tag_type in tag_types_used:
-                tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang)
-        return tags_args
-
-    def generate_description_ordinary(self, data, thresh=0.3):
-        if self.lang != 'en':
-            raise ValueError(f'Language {self.lang} is not supported for ordinary description generation')
-        description = f'a piece of music by {data["Artist"]}'
-
-        # Add genre if available
-        if data["Genre"] and random.random() > thresh:
-            genres = ', '.join(data["Genre"])
-            description += f', belonging to the {genres} genres'
-
-        # Add tags if available
-        if data["Tags"] and random.random() > thresh:
-            tags = ', '.join(data["Tags"])
-            description += f'. This track contains the tags: {tags}'
-
-        # Add moods if available
-        if data["Mood"] and random.random() > thresh:
-            moods = ', '.join(data["Mood"])
-            description += f'. This track conveys a {moods} mood.'
-
-        # Add instruments if available
-        if data["Instrument"] and random.random() > thresh:
-            instruments = ', '.join(data["Instrument"])
-            description += f'. It primarily features the following instruments: {instruments}'
-
-        # Add a period to end the description
-        description += '.'
- - return description - diff --git a/codeclm/tokenizer/Flow1dVAE/model_1rvq.py b/codeclm/tokenizer/Flow1dVAE/model_1rvq.py index 7a3f708b223c6e47117d528a058de173ef26145c..3850618665ad9d35edbfe96795bccbc35ca6a48a 100644 --- a/codeclm/tokenizer/Flow1dVAE/model_1rvq.py +++ b/codeclm/tokenizer/Flow1dVAE/model_1rvq.py @@ -270,8 +270,6 @@ class PromptCondAudioDiffusion(nn.Module): hubert_layer=None, ssl_layer=None, uncondition=True, - out_paint=False, - ssl_path='ckpt/encode-s12k.pt' ): super().__init__() diff --git a/codeclm/tokenizer/Flow1dVAE/model_2rvq.py b/codeclm/tokenizer/Flow1dVAE/model_2rvq.py deleted file mode 100644 index d9f3644d88a28d798527a1b6de19ee81a2d24ddb..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/model_2rvq.py +++ /dev/null @@ -1,774 +0,0 @@ -import yaml -import random -import inspect -import numpy as np -from tqdm import tqdm -import typing as tp -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchaudio - -from einops import repeat -from tools.torch_tools import wav_to_fbank - -import diffusers -from diffusers.utils.torch_utils import randn_tensor -from diffusers import DDPMScheduler -from models.transformer_2d_flow import Transformer2DModel -from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel -# from tools.get_mulan import get_mulan -from third_party.wespeaker.extract_embd import XVECModel -# from libs.rvq2 import RVQEmbedding -from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize - -from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model -from models_gpt.models.gpt2_config import GPT2Config - -from torch.cuda.amp import autocast - - -from our_MERT_BESTRQ.test import load_model - -class HubertModelWithFinalProj(HubertModel): - def __init__(self, config): - super().__init__(config) - - # The final projection layer is only used for backward compatibility. - # Following https://github.com/auspicious3000/contentvec/issues/6 - # Remove this layer is necessary to achieve the desired outcome. 
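-        # (The content-vec-best checkpoint appears to ship final_proj weights,
-        # so the layer is re-declared below mainly so that from_pretrained()
-        # can load the checkpoint without unexpected-key errors.)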
- print("hidden_size:",config.hidden_size) - print("classifier_proj_size:",config.classifier_proj_size) - self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) - - -class SampleProcessor(torch.nn.Module): - def project_sample(self, x: torch.Tensor): - """Project the original sample to the 'space' where the diffusion will happen.""" - """Project back from diffusion space to the actual sample space.""" - return z - -class Feature1DProcessor(SampleProcessor): - def __init__(self, dim: int = 100, power_std = 1., \ - num_samples: int = 100_000, cal_num_frames: int = 600): - super().__init__() - - self.num_samples = num_samples - self.dim = dim - self.power_std = power_std - self.cal_num_frames = cal_num_frames - self.register_buffer('counts', torch.zeros(1)) - self.register_buffer('sum_x', torch.zeros(dim)) - self.register_buffer('sum_x2', torch.zeros(dim)) - self.register_buffer('sum_target_x2', torch.zeros(dim)) - self.counts: torch.Tensor - self.sum_x: torch.Tensor - self.sum_x2: torch.Tensor - - @property - def mean(self): - mean = self.sum_x / self.counts - if(self.counts < 10): - mean = torch.zeros_like(mean) - return mean - - @property - def std(self): - std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() - if(self.counts < 10): - std = torch.ones_like(std) - return std - - @property - def target_std(self): - return 1 - - def project_sample(self, x: torch.Tensor): - assert x.dim() == 3 - if self.counts.item() < self.num_samples: - self.counts += len(x) - self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) - self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) - rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size - x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) - return x - - def return_sample(self, x: torch.Tensor): - assert x.dim() == 3 - rescale = (self.std / self.target_std) ** self.power_std - # print(rescale, self.mean) - x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) - return x - -def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): - if(prior_text_encoder_hidden_states.shape[1] 1.0): - - model_input = torch.cat([ \ - torch.cat([latent_mask_input, latent_mask_input], 0), \ - torch.cat([incontext_x, incontext_x], 0), \ - torch.cat([torch.zeros_like(mu), mu], 0), \ - torch.cat([x, x], 0), \ - ], 2) - timestep=t.unsqueeze(-1).repeat(2) - - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) - dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) - else: - model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) - timestep=t.unsqueeze(-1) - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - - dphi_dt = dphi_dt[: ,:, -x.shape[2]:] - # print("dphi_dt.shape:",dphi_dt.shape) - # print("x.shape:",x.shape) - - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - - return sol[-1] - - def projection_loss(self,hidden_proj, bestrq_emb): - bsz = hidden_proj.shape[0] - - hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) - bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) - - proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) - proj_loss = 1+proj_loss.mean() - - return proj_loss - - def 
compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): - """Computes diffusion loss - - Args: - x1 (torch.Tensor): Target - shape: (batch_size, n_channels, mel_timesteps, n_feats) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_channels, mel_timesteps, n_feats) - - Returns: - loss: conditional flow matching loss - y: conditional flow - shape: (batch_size, n_channels, mel_timesteps, n_feats) - """ - b = mu[0].shape[0] - len_x = x1.shape[2] - # random timestep - if(validation_mode): - t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 - else: - t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) - # sample noise p(x_0) - z = torch.randn_like(x1) - - y = (1 - (1 - self.sigma_min) * t) * z + t * x1 - u = x1 - (1 - self.sigma_min) * z - # print("y.shape:",y.shape) - #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state - model_input = torch.cat([*mu,y], 2) - t=t.squeeze(-1).squeeze(-1) - # print("model_input.shape:",model_input.shape) - # print("attention_mask.shape:",attention_mask.shape) - out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) - hidden_layer = out.hidden_states[self.ssl_layer] - hidden_proj = self.mlp(hidden_layer) - # print("hidden_proj.shape:",hidden_proj.shape) - # print("mert_emb.shape:",mert_emb.shape) - # exit() - - - out = out.last_hidden_state - - out=out[:,:,-len_x:] - # out=self.proj_out(out) - - weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 - # print("out.shape",out.shape) - # print("u.shape",u.shape) - loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() - # print("hidden_proj.shape:",hidden_proj.shape) - # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) - loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) - loss = loss_re + loss_cos * 0.5 - # print("loss_cos:",loss_cos,loss_cos.device) - print("loss:",loss,loss.device) - # exit() - return loss, loss_re, loss_cos - -class PromptCondAudioDiffusion(nn.Module): - def __init__( - self, - num_channels, - unet_model_name=None, - unet_model_config_path=None, - snr_gamma=None, - hubert_layer=None, - ssl_layer=None, - uncondition=True, - out_paint=False, - ): - super().__init__() - - assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" - - self.unet_model_name = unet_model_name - self.unet_model_config_path = unet_model_config_path - self.snr_gamma = snr_gamma - self.uncondition = uncondition - self.num_channels = num_channels - self.hubert_layer = hubert_layer - self.ssl_layer = ssl_layer - - # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview - self.normfeat = Feature1DProcessor(dim=64) - - self.sample_rate = 48000 - self.num_samples_perseg = self.sample_rate * 20 // 1000 - self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) - self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) - # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - self.bestrq = load_model( - 
model_dir='path/to/our-MERT/mert_fairseq', - checkpoint_dir='checkpoint-120000.pt', - ) - self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) - self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) - for v in self.bestrq.parameters():v.requires_grad = False - self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) - # for v in self.rvq_bestrq_emb.parameters(): - # print(v) - freeze_parameters='quantizers.0' - for name, param in self.rvq_bestrq_emb.named_parameters(): - if freeze_parameters in name: - param.requires_grad = False - print("Freezing RVQ parameters:", name) - self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") - for v in self.hubert.parameters():v.requires_grad = False - self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) - # self.xvecmodel = XVECModel() - config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) - unet = GPT2Model(config) - mlp = nn.Sequential( - nn.Linear(1200, 1024), - nn.SiLU(), - nn.Linear(1024, 1024), - nn.SiLU(), - nn.Linear(1024, 768) - ) - self.set_from = "random" - self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) - self.mask_emb = torch.nn.Embedding(3, 48) - print("Transformer initialized from pretrain.") - torch.cuda.empty_cache() - # self.unet.set_attn_processor(AttnProcessor2_0()) - # self.unet.set_use_memory_efficient_attention_xformers(True) - - # self.start_embedding = nn.Parameter(torch.randn(1,1024)) - # self.end_embedding = nn.Parameter(torch.randn(1,1024)) - - def compute_snr(self, timesteps): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = self.noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. - # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. 
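-        # With x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,
-        # alpha and sigma above are exactly those two coefficients, so
-        # SNR(t) = (alpha / sigma)^2 = alpha_bar_t / (1 - alpha_bar_t).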
- snr = (alpha / sigma) ** 2 - return snr - - def preprocess_audio(self, input_audios, threshold=0.9): - assert len(input_audios.shape) == 2, input_audios.shape - norm_value = torch.ones_like(input_audios[:,0]) - max_volume = input_audios.abs().max(dim=-1)[0] - norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold - return input_audios/norm_value.unsqueeze(-1) - - def extract_wav2vec_embeds(self, input_audios,output_len): - wav2vec_stride = 2 - - wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 - # print(wav2vec_embeds) - # print("audio.shape:",input_audios.shape) - wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] - # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) - wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) - return wav2vec_embeds_last - - def extract_mert_embeds(self, input_audios): - prompt_stride = 3 - inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") - input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) - prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 - mert_emb= prompt_embeds[-1] - mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) - - return mert_emb - - def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): - self.bestrq.eval() - # print("audio shape:",input_audio_0.shape) - input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 - # print("input_wav_mean.shape:",input_wav_mean.shape) - # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) - input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) - layer_results = input_wav_mean['layer_results'] - # print("layer_results.shape:",layer_results[layer].shape) - bestrq_emb = layer_results[layer] - bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() - #[b,t,1024] t=t/960 - #35.84s->batch,896,1024 - return bestrq_emb - - - def extract_spk_embeds(self, input_audios): - spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) - spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) - return spk_embeds - - def extract_lyric_feats(self, lyric): - with torch.no_grad(): - try: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) - except: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) - text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) - text_mask = text_mask.to(self.device) - text_encoder_hidden_states, text_mask, text_prompt_embeds = \ - pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) - text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() - return text_encoder_hidden_states, text_mask - - def extract_energy_bar(self, input_audios): - if(input_audios.shape[-1] % self.num_samples_perseg > 0): - energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) - else: - energy_bar = 
input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) - energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T - energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() - energy_embedding = self.energy_embedding(energy_bar) - energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t - return energy_embedding - - def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ - additional_feats = ['spk', 'lyric'], \ - train_rvq=True, train_ssl=False,layer=5): - if not hasattr(self,"device"): - self.device = input_audios.device - if not hasattr(self,"dtype"): - self.dtype = input_audios.dtype - device = self.device - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 - # energy_embedding = self.extract_energy_bar(input_audios) - # print("energy_embedding.shape:",energy_embedding.shape) - # with autocast(enabled=False): - if(train_ssl): - self.wav2vec.train() - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) - self.clap_embd_extractor.train() - prompt_embeds = self.extract_mert_embeds(input_audios) - if('spk' in additional_feats): - self.xvecmodel.train() - spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) - else: - with torch.no_grad(): - with autocast(enabled=False): - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - # mert_emb = self.extract_mert_embeds(input_audios_mert) - - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) - - bestrq_emb = bestrq_emb.detach() - if('lyric' in additional_feats): - text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) - else: - text_encoder_hidden_states, text_mask = None, None - - - if(train_rvq): - random_num=random.random() - if(random_num<0.6): - rvq_layer = 1 - elif(random_num<0.8): - rvq_layer = 2 - else: - rvq_layer = 4 - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t - else: - bestrq_emb = bestrq_emb.float() - self.rvq_bestrq_emb.eval() - # with autocast(enabled=False): - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() - codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() - quantized_bestrq_emb = quantized_bestrq_emb.detach() - - commitment_loss = commitment_loss_bestrq_emb - codebook_loss = codebook_loss_bestrq_emb - - - alpha=1 - quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) - - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - scenario = np.random.choice(['start_seg', 'other_seg']) - if(scenario == 'other_seg'): - for binx in range(input_audios.shape[0]): - # latent_masks[binx,0:64] = 1 - latent_masks[binx,0:random.randint(64,128)] = 1 - quantized_bestrq_emb = 
quantized_bestrq_emb.permute(0,2,1).contiguous() - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - - - - - if self.uncondition: - mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] - if len(mask_indices) > 0: - quantized_bestrq_emb[mask_indices] = 0 - # print("latents.shape:",latents.shape) - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.project_sample(latents) - latents = latents.permute(0,2,1).contiguous() - incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - # print("incontext_latents.shape:",incontext_latents.shape) - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - latent_mask_input = self.mask_emb(latent_masks) - #64+48+64+1024 - loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) - return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() - - def init_device_dtype(self, device, dtype): - self.device = device - self.dtype = dtype - - @torch.no_grad() - def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1): - input_audio_0 = input_audios[[0],:] - input_audio_1 = input_audios[[1],:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1): - 
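-        # Batched variant of fetch_codes() above: takes stereo audio shaped
-        # (B, 2, T) instead of a single (2, T) example, runs the same
-        # BEST-RQ -> RVQ pipeline, and keeps only the first rvq_num codebooks
-        # of the returned codes.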
input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250): - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds) - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return 
[codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, - guidance_scale=2, num_steps=20, - disable_progress=True, scenario='start_seg'): - classifier_free_guidance = guidance_scale > 1.0 - device = self.device - dtype = self.dtype - # codes_bestrq_middle, codes_bestrq_last = codes - codes_bestrq_emb = codes[0] - - - batch_size = codes_bestrq_emb.shape[0] - - - quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() - print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - - if('spk' in additional_feats): - spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() - - num_frames = quantized_bestrq_emb.shape[1] - - num_channels_latents = self.num_channels - shape = (batch_size, num_frames, 64) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - - - - latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) - latent_masks[:,0:latent_length] = 2 - if(scenario=='other_seg'): - latent_masks[:,0:incontext_length] = 1 - - - - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - true_latents = true_latents.permute(0,2,1).contiguous() - true_latents = self.normfeat.project_sample(true_latents) - true_latents = true_latents.permute(0,2,1).contiguous() - incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] - - - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - latent_mask_input = self.mask_emb(latent_masks) - - if('spk' in additional_feats): - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) - additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) - else: - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) - additional_model_input = torch.cat([quantized_bestrq_emb],1) - - temperature = 1.0 - t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) - latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) - - latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.return_sample(latents) - # latents = latents.permute(0,2,1).contiguous() - return latents - - @torch.no_grad() - def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - 
disable_progress=True,layer=5,scenario='start_seg',rvq_num=1): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num) - - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents - - @torch.no_grad() - def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - disable_progress=True,layer=5,scenario='start_seg'): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) - import time - start = time.time() - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents,time.time()-start - - def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): - divisor = 4 - shape = (batch_size, num_channels_latents, num_frames, 32) - if(num_frames%divisor>0): - num_frames = round(num_frames/float(divisor))*divisor - shape = (batch_size, num_channels_latents, num_frames, 32) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - return latents - - diff --git a/codeclm/tokenizer/Flow1dVAE/model_4rvq.py b/codeclm/tokenizer/Flow1dVAE/model_4rvq.py deleted file mode 100644 index 09f61d5f589a51853110504c9ebb396093836ef9..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/model_4rvq.py +++ /dev/null @@ -1,774 +0,0 @@ -import yaml -import random -import inspect -import numpy as np -from tqdm import tqdm -import typing as tp -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchaudio - -from einops import repeat -from tools.torch_tools import wav_to_fbank - -import diffusers -from diffusers.utils.torch_utils import randn_tensor -from diffusers import DDPMScheduler -from models.transformer_2d_flow import Transformer2DModel -from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel -# from tools.get_mulan import get_mulan -from third_party.wespeaker.extract_embd import XVECModel -# from libs.rvq2 import RVQEmbedding -from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize - -from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model -from models_gpt.models.gpt2_config import GPT2Config - -from torch.cuda.amp import autocast - - -from our_MERT_BESTRQ.test import load_model - -class HubertModelWithFinalProj(HubertModel): - def __init__(self, config): - super().__init__(config) - - # The final projection layer is only used for backward compatibility. - # Following https://github.com/auspicious3000/contentvec/issues/6 - # Remove this layer is necessary to achieve the desired outcome. 
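-        # (This file tracks model_2rvq.py almost line for line; the substantive
-        # difference is the ResidualVectorQuantize below being built with
-        # n_codebooks = 4 instead of 2, hence the _4rvq suffix.)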
- print("hidden_size:",config.hidden_size) - print("classifier_proj_size:",config.classifier_proj_size) - self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) - - -class SampleProcessor(torch.nn.Module): - def project_sample(self, x: torch.Tensor): - """Project the original sample to the 'space' where the diffusion will happen.""" - """Project back from diffusion space to the actual sample space.""" - return z - -class Feature1DProcessor(SampleProcessor): - def __init__(self, dim: int = 100, power_std = 1., \ - num_samples: int = 100_000, cal_num_frames: int = 600): - super().__init__() - - self.num_samples = num_samples - self.dim = dim - self.power_std = power_std - self.cal_num_frames = cal_num_frames - self.register_buffer('counts', torch.zeros(1)) - self.register_buffer('sum_x', torch.zeros(dim)) - self.register_buffer('sum_x2', torch.zeros(dim)) - self.register_buffer('sum_target_x2', torch.zeros(dim)) - self.counts: torch.Tensor - self.sum_x: torch.Tensor - self.sum_x2: torch.Tensor - - @property - def mean(self): - mean = self.sum_x / self.counts - if(self.counts < 10): - mean = torch.zeros_like(mean) - return mean - - @property - def std(self): - std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() - if(self.counts < 10): - std = torch.ones_like(std) - return std - - @property - def target_std(self): - return 1 - - def project_sample(self, x: torch.Tensor): - assert x.dim() == 3 - if self.counts.item() < self.num_samples: - self.counts += len(x) - self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) - self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) - rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size - x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) - return x - - def return_sample(self, x: torch.Tensor): - assert x.dim() == 3 - rescale = (self.std / self.target_std) ** self.power_std - # print(rescale, self.mean) - x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) - return x - -def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): - if(prior_text_encoder_hidden_states.shape[1] 1.0): - - model_input = torch.cat([ \ - torch.cat([latent_mask_input, latent_mask_input], 0), \ - torch.cat([incontext_x, incontext_x], 0), \ - torch.cat([torch.zeros_like(mu), mu], 0), \ - torch.cat([x, x], 0), \ - ], 2) - timestep=t.unsqueeze(-1).repeat(2) - - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) - dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) - else: - model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) - timestep=t.unsqueeze(-1) - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - - dphi_dt = dphi_dt[: ,:, -x.shape[2]:] - print("dphi_dt.shape:",dphi_dt.shape) - print("x.shape:",x.shape) - - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - - return sol[-1] - - def projection_loss(self,hidden_proj, bestrq_emb): - bsz = hidden_proj.shape[0] - - hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) - bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) - - proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) - proj_loss = 1+proj_loss.mean() - - return proj_loss - - def 
compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): - """Computes diffusion loss - - Args: - x1 (torch.Tensor): Target - shape: (batch_size, n_channels, mel_timesteps, n_feats) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_channels, mel_timesteps, n_feats) - - Returns: - loss: conditional flow matching loss - y: conditional flow - shape: (batch_size, n_channels, mel_timesteps, n_feats) - """ - b = mu[0].shape[0] - len_x = x1.shape[2] - # random timestep - if(validation_mode): - t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 - else: - t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) - # sample noise p(x_0) - z = torch.randn_like(x1) - - y = (1 - (1 - self.sigma_min) * t) * z + t * x1 - u = x1 - (1 - self.sigma_min) * z - # print("y.shape:",y.shape) - #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state - model_input = torch.cat([*mu,y], 2) - t=t.squeeze(-1).squeeze(-1) - # print("model_input.shape:",model_input.shape) - # print("attention_mask.shape:",attention_mask.shape) - out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) - hidden_layer = out.hidden_states[self.ssl_layer] - hidden_proj = self.mlp(hidden_layer) - # print("hidden_proj.shape:",hidden_proj.shape) - # print("mert_emb.shape:",mert_emb.shape) - # exit() - - - out = out.last_hidden_state - - out=out[:,:,-len_x:] - # out=self.proj_out(out) - - weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 - # print("out.shape",out.shape) - # print("u.shape",u.shape) - loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() - # print("hidden_proj.shape:",hidden_proj.shape) - # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) - loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) - loss = loss_re + loss_cos * 0.5 - # print("loss_cos:",loss_cos,loss_cos.device) - print("loss:",loss,loss.device) - # exit() - return loss, loss_re, loss_cos - -class PromptCondAudioDiffusion(nn.Module): - def __init__( - self, - num_channels, - unet_model_name=None, - unet_model_config_path=None, - snr_gamma=None, - hubert_layer=None, - ssl_layer=None, - uncondition=True, - out_paint=False, - ): - super().__init__() - - assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" - - self.unet_model_name = unet_model_name - self.unet_model_config_path = unet_model_config_path - self.snr_gamma = snr_gamma - self.uncondition = uncondition - self.num_channels = num_channels - self.hubert_layer = hubert_layer - self.ssl_layer = ssl_layer - - # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview - self.normfeat = Feature1DProcessor(dim=64) - - self.sample_rate = 48000 - self.num_samples_perseg = self.sample_rate * 20 // 1000 - self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) - self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) - # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - self.bestrq = load_model( - 
model_dir='path/to/our-MERT/mert_fairseq', - checkpoint_dir='checkpoint-120000.pt', - ) - self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) - self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) - for v in self.bestrq.parameters():v.requires_grad = False - self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) - # for v in self.rvq_bestrq_emb.parameters(): - # print(v) - freeze_parameters='quantizers.0' - for name, param in self.rvq_bestrq_emb.named_parameters(): - if freeze_parameters in name: - param.requires_grad = False - print("Freezing RVQ parameters:", name) - self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") - for v in self.hubert.parameters():v.requires_grad = False - self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) - # self.xvecmodel = XVECModel() - config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) - unet = GPT2Model(config) - mlp = nn.Sequential( - nn.Linear(1200, 1024), - nn.SiLU(), - nn.Linear(1024, 1024), - nn.SiLU(), - nn.Linear(1024, 768) - ) - self.set_from = "random" - self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) - self.mask_emb = torch.nn.Embedding(3, 48) - print("Transformer initialized from pretrain.") - torch.cuda.empty_cache() - # self.unet.set_attn_processor(AttnProcessor2_0()) - # self.unet.set_use_memory_efficient_attention_xformers(True) - - # self.start_embedding = nn.Parameter(torch.randn(1,1024)) - # self.end_embedding = nn.Parameter(torch.randn(1,1024)) - - def compute_snr(self, timesteps): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = self.noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. - # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. 
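# A self-contained sketch (not part of this repo) of the two pieces implemented
# by the deleted compute_loss/solve_euler above: the conditional flow-matching
# objective and Euler sampling with classifier-free guidance. `estimator` is a
# stand-in for the GPT2-based backbone (whose real signature takes
# inputs_embeds/attention_mask/time_step); layouts are simplified to (B, T, C).
import torch
import torch.nn.functional as F

def cfm_loss(estimator, x1, cond, sigma_min=1e-4):
    # Sample a point y_t on the straight path from noise z to data x1 and
    # regress the model output onto the constant velocity u = x1 - (1 - sigma_min) * z.
    t = torch.rand(x1.shape[0], 1, 1, device=x1.device, dtype=x1.dtype)
    z = torch.randn_like(x1)
    y = (1 - (1 - sigma_min) * t) * z + t * x1
    u = x1 - (1 - sigma_min) * z
    pred = estimator(torch.cat([cond, y], dim=-1), t.flatten())
    return F.mse_loss(pred, u)

@torch.no_grad()
def euler_cfg_sample(estimator, x, cond, num_steps=20, guidance_scale=2.0):
    # Run the conditional and unconditional branches in one doubled batch, then
    # extrapolate v = v_uncond + g * (v_cond - v_uncond) before each Euler step.
    t_span = torch.linspace(0, 1, num_steps + 1, device=x.device)
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        model_in = torch.cat([torch.cat([torch.zeros_like(cond), cond], dim=0),
                              torch.cat([x, x], dim=0)], dim=-1)
        v = estimator(model_in, t0.expand(2 * x.shape[0]))
        v_uncond, v_cond = v.chunk(2, dim=0)
        x = x + (t1 - t0) * (v_uncond + guidance_scale * (v_cond - v_uncond))
    return x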
- snr = (alpha / sigma) ** 2 - return snr - - def preprocess_audio(self, input_audios, threshold=0.9): - assert len(input_audios.shape) == 2, input_audios.shape - norm_value = torch.ones_like(input_audios[:,0]) - max_volume = input_audios.abs().max(dim=-1)[0] - norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold - return input_audios/norm_value.unsqueeze(-1) - - def extract_wav2vec_embeds(self, input_audios,output_len): - wav2vec_stride = 2 - - wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 - # print(wav2vec_embeds) - # print("audio.shape:",input_audios.shape) - wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] - # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) - wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) - return wav2vec_embeds_last - - def extract_mert_embeds(self, input_audios): - prompt_stride = 3 - inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") - input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) - prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 - mert_emb= prompt_embeds[-1] - mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) - - return mert_emb - - def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): - self.bestrq.eval() - # print("audio shape:",input_audio_0.shape) - input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 - # print("input_wav_mean.shape:",input_wav_mean.shape) - # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) - input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) - layer_results = input_wav_mean['layer_results'] - # print("layer_results.shape:",layer_results[layer].shape) - bestrq_emb = layer_results[layer] - bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() - #[b,t,1024] t=t/960 - #35.84s->batch,896,1024 - return bestrq_emb - - - def extract_spk_embeds(self, input_audios): - spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) - spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) - return spk_embeds - - def extract_lyric_feats(self, lyric): - with torch.no_grad(): - try: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) - except: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) - text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) - text_mask = text_mask.to(self.device) - text_encoder_hidden_states, text_mask, text_prompt_embeds = \ - pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) - text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() - return text_encoder_hidden_states, text_mask - - def extract_energy_bar(self, input_audios): - if(input_audios.shape[-1] % self.num_samples_perseg > 0): - energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) - else: - energy_bar = 
input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) - energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T - energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() - energy_embedding = self.energy_embedding(energy_bar) - energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t - return energy_embedding - - def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ - additional_feats = ['spk', 'lyric'], \ - train_rvq=True, train_ssl=False,layer=5): - if not hasattr(self,"device"): - self.device = input_audios.device - if not hasattr(self,"dtype"): - self.dtype = input_audios.dtype - device = self.device - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 - # energy_embedding = self.extract_energy_bar(input_audios) - # print("energy_embedding.shape:",energy_embedding.shape) - # with autocast(enabled=False): - if(train_ssl): - self.wav2vec.train() - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) - self.clap_embd_extractor.train() - prompt_embeds = self.extract_mert_embeds(input_audios) - if('spk' in additional_feats): - self.xvecmodel.train() - spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) - else: - with torch.no_grad(): - with autocast(enabled=False): - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - # mert_emb = self.extract_mert_embeds(input_audios_mert) - - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) - - bestrq_emb = bestrq_emb.detach() - if('lyric' in additional_feats): - text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) - else: - text_encoder_hidden_states, text_mask = None, None - - - if(train_rvq): - random_num=random.random() - if(random_num<0.6): - rvq_layer = 1 - elif(random_num<0.8): - rvq_layer = 2 - else: - rvq_layer = 4 - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t - else: - bestrq_emb = bestrq_emb.float() - self.rvq_bestrq_emb.eval() - # with autocast(enabled=False): - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() - codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() - quantized_bestrq_emb = quantized_bestrq_emb.detach() - - commitment_loss = commitment_loss_bestrq_emb - codebook_loss = codebook_loss_bestrq_emb - - - alpha=1 - quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) - - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - scenario = np.random.choice(['start_seg', 'other_seg']) - if(scenario == 'other_seg'): - for binx in range(input_audios.shape[0]): - # latent_masks[binx,0:64] = 1 - latent_masks[binx,0:random.randint(64,128)] = 1 - quantized_bestrq_emb = 
quantized_bestrq_emb.permute(0,2,1).contiguous() - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - - - - - if self.uncondition: - mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] - if len(mask_indices) > 0: - quantized_bestrq_emb[mask_indices] = 0 - # print("latents.shape:",latents.shape) - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.project_sample(latents) - latents = latents.permute(0,2,1).contiguous() - incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - # print("incontext_latents.shape:",incontext_latents.shape) - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - latent_mask_input = self.mask_emb(latent_masks) - #64+48+64+1024 - loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) - return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() - - def init_device_dtype(self, device, dtype): - self.device = device - self.dtype = dtype - - @torch.no_grad() - def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1): - input_audio_0 = input_audios[[0],:] - input_audio_1 = input_audios[[1],:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1): - 
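# Batched variant of fetch_codes: takes stereo input of shape (B, 2, T),
# peak-normalizes each channel, averages the channels inside
# extract_bestrq_embeds, quantizes the frozen SSL features with the
# eval-mode RVQ, and keeps only the first rvq_num codebooks of the
# returned code tensor.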
input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250): - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds) - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return 
[codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, - guidance_scale=2, num_steps=20, - disable_progress=True, scenario='start_seg'): - classifier_free_guidance = guidance_scale > 1.0 - device = self.device - dtype = self.dtype - # codes_bestrq_middle, codes_bestrq_last = codes - codes_bestrq_emb = codes[0] - - - batch_size = codes_bestrq_emb.shape[0] - - - quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() - print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - - if('spk' in additional_feats): - spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() - - num_frames = quantized_bestrq_emb.shape[1] - - num_channels_latents = self.num_channels - shape = (batch_size, num_frames, 64) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - - - - latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) - latent_masks[:,0:latent_length] = 2 - if(scenario=='other_seg'): - latent_masks[:,0:incontext_length] = 1 - - - - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - true_latents = true_latents.permute(0,2,1).contiguous() - true_latents = self.normfeat.project_sample(true_latents) - true_latents = true_latents.permute(0,2,1).contiguous() - incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] - - - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - latent_mask_input = self.mask_emb(latent_masks) - - if('spk' in additional_feats): - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) - additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) - else: - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) - additional_model_input = torch.cat([quantized_bestrq_emb],1) - - temperature = 1.0 - t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) - latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) - - latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.return_sample(latents) - # latents = latents.permute(0,2,1).contiguous() - return latents - - @torch.no_grad() - def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - 
disable_progress=True,layer=5,scenario='start_seg',rvq_num=1): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num) - - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents - - @torch.no_grad() - def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - disable_progress=True,layer=5,scenario='start_seg'): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) - import time - start = time.time() - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents,time.time()-start - - def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): - divisor = 4 - shape = (batch_size, num_channels_latents, num_frames, 32) - if(num_frames%divisor>0): - num_frames = round(num_frames/float(divisor))*divisor - shape = (batch_size, num_channels_latents, num_frames, 32) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - return latents - - diff --git a/codeclm/tokenizer/Flow1dVAE/model_septoken.py b/codeclm/tokenizer/Flow1dVAE/model_septoken.py index 2d4588e0826b73e709da6479c672fd140b28512c..d5790fcae2155696ef1b516264303afc8f0cd626 100644 --- a/codeclm/tokenizer/Flow1dVAE/model_septoken.py +++ b/codeclm/tokenizer/Flow1dVAE/model_septoken.py @@ -252,8 +252,6 @@ class PromptCondAudioDiffusion(nn.Module): unet_model_config_path=None, snr_gamma=None, uncondition=True, - out_paint=False, - ssl_path='ckpt/encode-s12k.pt' ): super().__init__() diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml deleted file mode 100644 index cc60e5b3492fad2c7bb3d793ad3705a1f2086e36..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml +++ /dev/null @@ -1,122 +0,0 @@ -# @package _group_ - -common: - fp16: true - log_format: json - log_interval: 200 - tensorboard_logdir: tb - min_loss_scale: 1e-6 - fp16_no_flatten_grads: true - user_dir: ${env:PWD} - seed: 1 - -checkpoint: - save_interval: 1 - save_interval_updates: 10000 - keep_interval_updates: 1 - no_epoch_checkpoints: true - -task: - _name: mae_image_pretraining - data: unbalanced_train - rebuild_batches: true - key: source - precompute_mask_config: {} - downsr_16hz: true - audio_mae: true - h5_format: false - target_length: 1024 - flexible_mask: false - -dataset: - num_workers: 10 - batch_size: 12 - skip_invalid_size_inputs_valid_test: true - required_batch_size_multiple: 1 - disable_validation: true - -distributed_training: - distributed_world_size: 4 - ddp_backend: c10d - -criterion: - _name: model - log_keys: - - ema_decay - - target_var - - pred_var - - model_norm - - ema_norm - - masked_pct - -optimization: - max_update: 400000 - lr: [ 0.0005 ] - debug_param_names: true - clip_norm: 4 - -optimizer: - _name: composite - dynamic_groups: true - groups: - default: - lr_float: 0.0005 - optimizer: - _name: adam - adam_betas: [0.9,0.95] - weight_decay: 0.05 - lr_scheduler: - _name: cosine - warmup_updates: 
53333 - -lr_scheduler: pass_through - -model: - _name: data2vec_multi - - ema_decay: 0.9998 - ema_end_decay: 0.99999 - ema_anneal_end_step: 100000 - instance_norm_target_layer: true - layer_norm_target_layer: false - layer_norm_targets: true - end_of_block_targets: false - - depth: 12 - average_top_k_layers: 12 - clone_batch: 16 - - norm_eps: 1e-6 - - min_target_var: 0 - min_pred_var: 0 - - encoder_dropout: 0 - post_mlp_drop: 0 - attention_dropout: 0 - activation_dropout: 0 - - supported_modality: IMAGE - cls_loss: 1 - - ema_encoder_only: false - - modalities: - image: - in_chans: 1 - inverse_mask: true - mask_prob: 0.8 - mask_prob_adjust: 0.07 - mask_length: 5 - mask_noise_std: 0.01 - prenet_depth: 0 - ema_local_encoder: true - num_extra_tokens: 1 - init_extra_token_zero: false - use_alibi_encoder: false - decoder: - decoder_dim: 768 - decoder_groups: 16 - decoder_kernel: 3 - decoder_layers: 6 - input_dropout: 0 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml deleted file mode 100644 index 92318155dcb8ab93a395dbb49d9b99144b534ded..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml +++ /dev/null @@ -1,125 +0,0 @@ -# @package _group_ - -common: - fp16: true - log_format: json - log_interval: 200 - tensorboard_logdir: tb - min_loss_scale: 1e-6 - fp16_no_flatten_grads: true - user_dir: ${env:PWD} - seed: 1 - -checkpoint: - save_interval: 1 - save_interval_updates: 10000 - keep_interval_updates: 1000 - no_epoch_checkpoints: true - -task: - _name: mae_image_pretraining - data: music4all_sh/ - rebuild_batches: true - key: source - precompute_mask_config: {} - downsr_16hz: false - audio_mae: true - h5_format: false - target_length: 752 - flexible_mask: false - sample_rate: 24000 - fixed_duration: 30 - -dataset: - num_workers: 10 - batch_size: 12 - skip_invalid_size_inputs_valid_test: true - required_batch_size_multiple: 1 - disable_validation: true - -distributed_training: - distributed_world_size: 4 - ddp_backend: c10d - -criterion: - _name: model - log_keys: - - ema_decay - - target_var - - pred_var - - model_norm - - ema_norm - - masked_pct - -optimization: - max_update: 400000 - lr: [ 0.0001 ] - # debug_param_names: true - clip_norm: 4 - -optimizer: - _name: composite - # dynamic_groups: true - groups: - default: - lr_float: 0.0005 - optimizer: - _name: adam - adam_betas: [0.9,0.95] - weight_decay: 0.05 - lr_scheduler: - _name: cosine - warmup_updates: 10000 # 53333 - -lr_scheduler: pass_through - -model: - _name: data2vec_multi - - ema_decay: 0.9998 - ema_end_decay: 0.99999 - ema_anneal_end_step: 100000 - instance_norm_target_layer: true - layer_norm_target_layer: false - layer_norm_targets: true - end_of_block_targets: false - - depth: 12 - average_top_k_layers: 12 - clone_batch: 16 - - norm_eps: 1e-6 - - min_target_var: 0 - min_pred_var: 0 - - encoder_dropout: 0 - post_mlp_drop: 0 - attention_dropout: 0 - activation_dropout: 0 - - supported_modality: IMAGE - cls_loss: 1 - - ema_encoder_only: false - - modalities: - image: - in_chans: 1 - inverse_mask: true - mask_prob: 0.8 - mask_prob_adjust: 0.07 - mask_length: 5 - mask_noise_std: 0.01 - prenet_depth: 0 - ema_local_encoder: true - num_extra_tokens: 1 - init_extra_token_zero: false - use_alibi_encoder: false - decoder: - 
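    # shallow convolutional decoder (dimensions below) used only during
    # data2vec-style pretraining to predict teacher targets at masked
    # positions; it is dropped when the encoder is reused for feature extraction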
decoder_dim: 768 - decoder_groups: 16 - decoder_kernel: 3 - decoder_layers: 6 - input_dropout: 0 - target_length: 752 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml deleted file mode 100644 index 900a60b1ce41b671db2ac77a6d0cd4290dc8ff2f..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml +++ /dev/null @@ -1,137 +0,0 @@ -# @package _group_ -common: - fp16: false - log_format: json - log_interval: 100 - seed: 1337 - - # tensorboard_logdir: tblog_proj_name - # wandb_project: wandb_proj_name - -checkpoint: - save_interval_updates: 5000 - keep_interval_updates: -1 - no_epoch_checkpoints: true - - -distributed_training: - ddp_backend: no_c10d - distributed_backend: 'nccl' - distributed_world_size: 64 - nprocs_per_node: 8 - find_unused_parameters: true - # reset-dataloader: true - -task: - _name: mert_pretraining - data: ??? - label_dir: ??? - labels: ??? - label_rate: ${model.label_rate} - sharding_data: -1 #数据分块 - load_random_data_shard: false - sample_rate: 24000 - # crop to 5s - # max_sample_size: 120000 - # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. - max_sample_size: 122880 - min_sample_size: 72000 - - pad_audio: false - random_crop: true - # normalize: true # must be consistent with extractor_mode: layer_norm - normalize: false # must be consistent with extractor_mode: default (groupnorm) - - -dataset: - num_workers: 6 - max_tokens: 900000 - skip_invalid_size_inputs_valid_test: true - validate_interval: 1 - validate_interval_updates: 10000 - -criterion: - _name: hubert - pred_masked_weight: 1.0 - pred_nomask_weight: 0.0 - loss_weights: [10, 1] - -optimization: - max_update: 1000000 - lr: [0.0015] - clip_norm: 1.0 - update_freq: [8] - -optimizer: - _name: adam - adam_betas: (0.9,0.98) - adam_eps: 1e-06 - weight_decay: 0.01 - -lr_scheduler: - _name: polynomial_decay - warmup_updates: 32000 - -model: - _name: mert - label_rate: ??? - skip_masked: false - skip_nomask: true - mask_prob: 0.8 - mask_length: 5 - - logit_temp: 0.1 - - - # ----- mixture ------ - mixture_prob: 0.5 - inbatch_noise_augment_len_range: "[12000, 36000]" - inbatch_noise_augment_number_range: "[1, 3]" - inbatch_noise_augment_volume: 1.0 - # ------------------------ - - # ---- cqt reconstruction, need to add loss weight --- - audio_cqt_loss_m: true - audio_cqt_bins: 336 - - final_dim: 128 - encoder_layers: 24 - encoder_embed_dim: 1024 - encoder_ffn_embed_dim: 4096 - encoder_attention_heads: 16 - # default refers to group norm - extractor_mode: default - # extractor_mode: layer_norm - conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' - encoder_layerdrop: 0.0 - dropout_input: 0.0 - dropout_features: 0.0 - dropout: 0.0 - attention_dropout: 0.0 - - layer_norm_first: true - feature_grad_mult: 1.0 - - untie_final_proj: true - activation_dropout: 0.0 - - deepnorm: false - attention_relax: 32.0 - - - -hydra: - job: - config: - override_dirname: - kv_sep: '-' - item_sep: '__' - exclude_keys: - - run - - task.data - - task.label_dir - run: - dir: ??? - sweep: - dir: ??? 
- subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml deleted file mode 100644 index f6ddce0e95f235300e69375bed680234a015112e..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml +++ /dev/null @@ -1,139 +0,0 @@ -# @package _group_ -common: - fp16: false - log_format: json - log_interval: 100 - seed: 1337 - # model_parallel_size: 8 - # amp: true - - # tensorboard_logdir: tblog_proj_name - # wandb_project: wandb_proj_name - -checkpoint: - save_interval_updates: 5000 - keep_interval_updates: -1 - no_epoch_checkpoints: true - - -distributed_training: - ddp_backend: c10d - distributed_backend: 'nccl' - distributed_world_size: 64 - nprocs_per_node: 8 - find_unused_parameters: true - # reset-dataloader: true - -task: - _name: mert_pretraining - data: ??? - label_dir: ??? - labels: ??? - label_rate: ${model.label_rate} - sharding_data: -1 #数据分块 - load_random_data_shard: false - sample_rate: 24000 - # crop to 5s - # max_sample_size: 120000 - # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. - max_sample_size: 122880 - min_sample_size: 72000 - - pad_audio: false - random_crop: true - # normalize: true # must be consistent with extractor_mode: layer_norm - normalize: false # must be consistent with extractor_mode: default (groupnorm) - - -dataset: - num_workers: 6 - max_tokens: 900000 - skip_invalid_size_inputs_valid_test: true - validate_interval: 1 - validate_interval_updates: 10000 - -criterion: - _name: hubert - pred_masked_weight: 1.0 - pred_nomask_weight: 0.0 - loss_weights: [10, 1] - -optimization: - max_update: 1000000 - lr: [0.0015] - clip_norm: 1.0 - update_freq: [8] - -optimizer: - _name: adam - adam_betas: (0.9,0.98) - adam_eps: 1e-06 - weight_decay: 0.01 - -lr_scheduler: - _name: polynomial_decay - warmup_updates: 32000 - -model: - _name: mert - label_rate: ??? 
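# label_rate is resolved from the task config at launch (??? marks a
# mandatory override). skip_masked: false / skip_nomask: true below restrict
# the HuBERT-style prediction loss to masked frames; with fairseq span
# masking, mask_prob: 0.8 and mask_length: 5 place roughly 80% of frames
# inside a masked span before overlapping spans are merged.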
- skip_masked: false - skip_nomask: true - mask_prob: 0.8 - mask_length: 5 - - logit_temp: 0.1 - - - # ----- mixture ------ - mixture_prob: 0.5 - inbatch_noise_augment_len_range: "[12000, 36000]" - inbatch_noise_augment_number_range: "[1, 3]" - inbatch_noise_augment_volume: 1.0 - # ------------------------ - - # ---- cqt reconstruction, need to add loss weight --- - audio_cqt_loss_m: true - audio_cqt_bins: 336 - - final_dim: 128 - encoder_layers: 24 - encoder_embed_dim: 1024 - encoder_ffn_embed_dim: 4096 - encoder_attention_heads: 16 - # default refers to group norm - extractor_mode: default - # extractor_mode: layer_norm - conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' - encoder_layerdrop: 0.0 - dropout_input: 0.0 - dropout_features: 0.0 - dropout: 0.0 - attention_dropout: 0.0 - - layer_norm_first: true - feature_grad_mult: 1.0 - - untie_final_proj: true - activation_dropout: 0.0 - - deepnorm: false - attention_relax: 32.0 - - - -hydra: - job: - config: - override_dirname: - kv_sep: '-' - item_sep: '__' - exclude_keys: - - run - - task.data - - task.label_dir - run: - dir: run - sweep: - dir: sweep - subdir: subdir diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml deleted file mode 100644 index e65613fd792771729436d20f48578fc88a5e7ad1..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml +++ /dev/null @@ -1,138 +0,0 @@ -# @package _group_ -common: - fp16: false - log_format: json - log_interval: 100 - seed: 1337 - # amp: true - - # tensorboard_logdir: tblog_proj_name - # wandb_project: wandb_proj_name - -checkpoint: - save_interval_updates: 5000 - keep_interval_updates: -1 - no_epoch_checkpoints: true - - -distributed_training: - ddp_backend: c10d - distributed_backend: 'nccl' - distributed_world_size: 64 - nprocs_per_node: 8 - find_unused_parameters: true - # reset-dataloader: true - -task: - _name: mert_pretraining - data: ??? - label_dir: ??? - labels: ??? - label_rate: ${model.label_rate} - sharding_data: -1 #数据分块 - load_random_data_shard: false - sample_rate: 24000 - # crop to 5s - # max_sample_size: 120000 - # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. - max_sample_size: 122880 - min_sample_size: 72000 - - pad_audio: false - random_crop: true - # normalize: true # must be consistent with extractor_mode: layer_norm - normalize: false # must be consistent with extractor_mode: default (groupnorm) - - -dataset: - num_workers: 6 - max_tokens: 900000 - skip_invalid_size_inputs_valid_test: true - validate_interval: 1 - validate_interval_updates: 10000 - -criterion: - _name: hubert - pred_masked_weight: 1.0 - pred_nomask_weight: 0.0 - loss_weights: [10, 1] - -optimization: - max_update: 1000000 - lr: [0.0015] - clip_norm: 1.0 - update_freq: [8] - -optimizer: - _name: adam - adam_betas: (0.9,0.98) - adam_eps: 1e-06 - weight_decay: 0.01 - -lr_scheduler: - _name: polynomial_decay - warmup_updates: 32000 - -model: - _name: mert - label_rate: ??? 
- skip_masked: false - skip_nomask: true - mask_prob: 0.8 - mask_length: 5 - - logit_temp: 0.1 - - - # ----- mixture ------ - mixture_prob: 0.5 - inbatch_noise_augment_len_range: "[12000, 36000]" - inbatch_noise_augment_number_range: "[1, 3]" - inbatch_noise_augment_volume: 1.0 - # ------------------------ - - # ---- cqt reconstruction, need to add loss weight --- - audio_cqt_loss_m: true - audio_cqt_bins: 336 - - final_dim: 128 - encoder_layers: 24 - encoder_embed_dim: 1024 - encoder_ffn_embed_dim: 4096 - encoder_attention_heads: 16 - # default refers to group norm - extractor_mode: default - # extractor_mode: layer_norm - conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' - encoder_layerdrop: 0.0 - dropout_input: 0.0 - dropout_features: 0.0 - dropout: 0.0 - attention_dropout: 0.0 - - layer_norm_first: true - feature_grad_mult: 1.0 - - untie_final_proj: true - activation_dropout: 0.0 - - deepnorm: false - attention_relax: 32.0 - - - -hydra: - job: - config: - override_dirname: - kv_sep: '-' - item_sep: '__' - exclude_keys: - - run - - task.data - - task.label_dir - run: - dir: run - sweep: - dir: sweep - subdir: subdir diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml deleted file mode 100644 index de0ca018b11daebbee3a6b82c5356bf69d0bd839..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml +++ /dev/null @@ -1,139 +0,0 @@ -# @package _group_ -common: - fp16: false - log_format: json - log_interval: 100 - seed: 1337 - model_parallel_size: 8 - # amp: true - - # tensorboard_logdir: tblog_proj_name - # wandb_project: wandb_proj_name - -checkpoint: - save_interval_updates: 5000 - keep_interval_updates: -1 - no_epoch_checkpoints: true - - -distributed_training: - ddp_backend: c10d - distributed_backend: 'nccl' - distributed_world_size: 64 - nprocs_per_node: 8 - find_unused_parameters: true - # reset-dataloader: true - -task: - _name: mert_pretraining - data: ??? - label_dir: ??? - labels: ??? - label_rate: ${model.label_rate} - sharding_data: -1 #数据分块 - load_random_data_shard: false - sample_rate: 24000 - # crop to 5s - # max_sample_size: 120000 - # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. - max_sample_size: 122880 - min_sample_size: 72000 - - pad_audio: false - random_crop: true - # normalize: true # must be consistent with extractor_mode: layer_norm - normalize: false # must be consistent with extractor_mode: default (groupnorm) - - -dataset: - num_workers: 6 - max_tokens: null - skip_invalid_size_inputs_valid_test: true - validate_interval: 1 - validate_interval_updates: 10000 - -criterion: - _name: hubert - pred_masked_weight: 1.0 - pred_nomask_weight: 0.0 - loss_weights: [10, 1] - -optimization: - max_update: 1000000 - lr: [0.0015] - clip_norm: 1.0 - update_freq: [8] - -optimizer: - _name: adam - adam_betas: (0.9,0.98) - adam_eps: 1e-06 - weight_decay: 0.01 - -lr_scheduler: - _name: polynomial_decay - warmup_updates: 32000 - -model: - _name: mert - label_rate: ??? 
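# two-node debug twin of the multinode recipe: common.model_parallel_size: 8
# and dataset.max_tokens: null above appear to be the only changes; this
# model section matches MERT_RVQ-VAE_CQT_330M_multinodes.yaml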
- skip_masked: false - skip_nomask: true - mask_prob: 0.8 - mask_length: 5 - - logit_temp: 0.1 - - - # ----- mixture ------ - mixture_prob: 0.5 - inbatch_noise_augment_len_range: "[12000, 36000]" - inbatch_noise_augment_number_range: "[1, 3]" - inbatch_noise_augment_volume: 1.0 - # ------------------------ - - # ---- cqt reconstruction, need to add loss weight --- - audio_cqt_loss_m: true - audio_cqt_bins: 336 - - final_dim: 128 - encoder_layers: 24 - encoder_embed_dim: 1024 - encoder_ffn_embed_dim: 4096 - encoder_attention_heads: 16 - # default refers to group norm - extractor_mode: default - # extractor_mode: layer_norm - conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' - encoder_layerdrop: 0.0 - dropout_input: 0.0 - dropout_features: 0.0 - dropout: 0.0 - attention_dropout: 0.0 - - layer_norm_first: true - feature_grad_mult: 1.0 - - untie_final_proj: true - activation_dropout: 0.0 - - deepnorm: false - attention_relax: 32.0 - - - -hydra: - job: - config: - override_dirname: - kv_sep: '-' - item_sep: '__' - exclude_keys: - - run - - task.data - - task.label_dir - run: - dir: run - sweep: - dir: sweep - subdir: subdir diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml deleted file mode 100644 index f1beb97fd08917202a5574535ad8badabbeae72c..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml +++ /dev/null @@ -1,135 +0,0 @@ -# @package _group_ -common: - fp16: true - log_format: json - log_interval: 100 - seed: 1337 - # tensorboard_logdir: tblog_proj_name - # wandb_project: wandb_proj_name - -checkpoint: - save_interval_updates: 5000 - keep_interval_updates: -1 - no_epoch_checkpoints: true - - -distributed_training: - ddp_backend: no_c10d - distributed_backend: 'nccl' - distributed_world_size: 64 - nprocs_per_node: 8 - find_unused_parameters: true - -task: - _name: mert_pretraining - data: ??? - label_dir: ??? - labels: ??? - label_rate: ${model.label_rate} - sharding_data: 6 - load_random_data_shard: false - sample_rate: 24000 - # crop to 5s - # max_sample_size: 120000 - # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. - max_sample_size: 122880 - min_sample_size: 72000 - - pad_audio: false - random_crop: true - # normalize: true # must be consistent with extractor_mode: layer_norm - normalize: false # must be consistent with extractor_mode: default (groupnorm) - - -dataset: - num_workers: 6 - max_tokens: 900000 - skip_invalid_size_inputs_valid_test: true - validate_interval: 1 - validate_interval_updates: 10000 - -criterion: - _name: hubert - pred_masked_weight: 1.0 - pred_nomask_weight: 0.0 - loss_weights: [10, 1] - -optimization: - max_update: 400000 - lr: [0.0015] - clip_norm: 1.0 - update_freq: [8] - -optimizer: - _name: adam - adam_betas: (0.9,0.98) - adam_eps: 1e-06 - weight_decay: 0.01 - -lr_scheduler: - _name: polynomial_decay - warmup_updates: 32000 - -model: - _name: mert - label_rate: ??? 
- skip_masked: false - skip_nomask: true - mask_prob: 0.8 - mask_length: 5 - - logit_temp: 0.1 - - - # ----- mixture ------ - mixture_prob: 0.5 - inbatch_noise_augment_len_range: "[12000, 36000]" - inbatch_noise_augment_number_range: "[1, 3]" - inbatch_noise_augment_volume: 1.0 - # ------------------------ - - # ---- cqt reconstruction, need to add loss weight --- - audio_cqt_loss_m: true - audio_cqt_bins: 336 - - final_dim: 128 - encoder_layers: 24 - encoder_embed_dim: 1024 - encoder_ffn_embed_dim: 4096 - encoder_attention_heads: 16 - # default refers to group norm - extractor_mode: default - # extractor_mode: layer_norm - conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' - encoder_layerdrop: 0.0 - dropout_input: 0.0 - dropout_features: 0.0 - dropout: 0.0 - attention_dropout: 0.0 - - layer_norm_first: true - feature_grad_mult: 1.0 - - untie_final_proj: true - activation_dropout: 0.0 - - deepnorm: false - attention_relax: 32.0 - - - -hydra: - job: - config: - override_dirname: - kv_sep: '-' - item_sep: '__' - exclude_keys: - - run - - task.data - - task.label_dir - run: - dir: ??? - sweep: - dir: ??? - subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml deleted file mode 100644 index 82adc82cd824b65e0423221de7565307c893cd75..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml +++ /dev/null @@ -1,137 +0,0 @@ -# @package _group_ -common: - fp16: true - log_format: json - log_interval: 100 - seed: 1337 - - # tensorboard_logdir: tblog_proj_name - # wandb_project: wandb_proj_name - -checkpoint: - save_interval_updates: 5000 - keep_interval_updates: -1 - no_epoch_checkpoints: true - - -distributed_training: - ddp_backend: no_c10d - distributed_backend: 'nccl' - distributed_world_size: 64 - nprocs_per_node: 8 - find_unused_parameters: true - # reset-dataloader: true - -task: - _name: mert_pretraining - data: ??? - label_dir: ??? - labels: ??? - label_rate: ${model.label_rate} - sharding_data: -1 #数据分块 - load_random_data_shard: false - sample_rate: 24000 - # crop to 5s - # max_sample_size: 120000 - # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. - max_sample_size: 122880 - min_sample_size: 72000 - - pad_audio: false - random_crop: true - # normalize: true # must be consistent with extractor_mode: layer_norm - normalize: false # must be consistent with extractor_mode: default (groupnorm) - - -dataset: - num_workers: 6 - max_tokens: 900000 - skip_invalid_size_inputs_valid_test: true - validate_interval: 1 - validate_interval_updates: 10000 - -criterion: - _name: hubert - pred_masked_weight: 1.0 - pred_nomask_weight: 0.0 - loss_weights: [10, 1] - -optimization: - max_update: 400000 - lr: [0.0015] - clip_norm: 1.0 - update_freq: [8] - -optimizer: - _name: adam - adam_betas: (0.9,0.98) - adam_eps: 1e-06 - weight_decay: 0.01 - -lr_scheduler: - _name: polynomial_decay - warmup_updates: 32000 - -model: - _name: mert - label_rate: ??? 
- skip_masked: false - skip_nomask: true - mask_prob: 0.8 - mask_length: 5 - # freeze_parameters:true - logit_temp: 0.1 - - - # ----- mixture ------ - mixture_prob: 0.5 - inbatch_noise_augment_len_range: "[12000, 36000]" - inbatch_noise_augment_number_range: "[1, 3]" - inbatch_noise_augment_volume: 1.0 - # ------------------------ - - # ---- cqt reconstruction, need to add loss weight --- - audio_cqt_loss_m: true - audio_cqt_bins: 336 - - final_dim: 128 - encoder_layers: 24 - encoder_embed_dim: 1024 - encoder_ffn_embed_dim: 4096 - encoder_attention_heads: 16 - # default refers to group norm - extractor_mode: default - # extractor_mode: layer_norm - conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' - encoder_layerdrop: 0.0 - dropout_input: 0.0 - dropout_features: 0.0 - dropout: 0.0 - attention_dropout: 0.0 - - layer_norm_first: true - feature_grad_mult: 1.0 - - untie_final_proj: true - activation_dropout: 0.0 - - deepnorm: false - attention_relax: 32.0 - - - -hydra: - job: - config: - override_dirname: - kv_sep: '-' - item_sep: '__' - exclude_keys: - - run - - task.data - - task.label_dir - run: - dir: ??? - sweep: - dir: ??? - subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml deleted file mode 100644 index 1f5cea6050278a6af45f9e4d8f2b4c20476ab5a6..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml +++ /dev/null @@ -1,116 +0,0 @@ -# @package _group_ -common: - fp16: false - log_format: json - log_interval: 200 - seed: 1337 - # tensorboard_logdir: tblog_proj_name - # wandb_project: wandb_proj_name - -checkpoint: - save_interval_updates: 25000 - keep_interval_updates: -1 - no_epoch_checkpoints: true - - -distributed_training: - ddp_backend: no_c10d - distributed_backend: 'nccl' - distributed_world_size: 64 - nprocs_per_node: 8 - find_unused_parameters: true - -task: - _name: mert_pretraining - data: ??? - label_dir: ??? - labels: ??? - label_rate: ${model.label_rate} - sample_rate: 24000 - # crop to 5s - max_sample_size: 120000 - min_sample_size: 72000 - - pad_audio: false - random_crop: true - normalize: false # must be consistent with extractor - - -dataset: - num_workers: 6 - max_tokens: 2000000 - skip_invalid_size_inputs_valid_test: true - validate_interval: 1 - validate_interval_updates: 10000 - -criterion: - _name: hubert - pred_masked_weight: 1.0 - pred_nomask_weight: 0.0 - loss_weights: [10, 1] - -optimization: - max_update: 400000 - lr: [0.0005] - clip_norm: 10.0 - -optimizer: - _name: adam - adam_betas: (0.9,0.98) - adam_eps: 1e-06 - weight_decay: 0.01 - -lr_scheduler: - _name: polynomial_decay - warmup_updates: 32000 - -model: - _name: mert - label_rate: ??? 
- skip_masked: false - skip_nomask: true - mask_prob: 0.8 - mask_length: 5 - - logit_temp: 0.1 - - # ----- mixture ------ - mixture_prob: 0.5 - inbatch_noise_augment_len_range: "[12000, 24000]" - inbatch_noise_augment_number_range: "[1, 3]" - inbatch_noise_augment_volume: 1.0 - # ------------------------ - extractor_mode: default - audio_extract_type: w2v_conv - conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' - - # ---- cqt reconstruction, need to add loss weight --- - audio_cqt_loss_m: true - audio_cqt_bins: 336 - # ----------- - final_dim: 64 - encoder_layerdrop: 0.05 - dropout_input: 0.1 - dropout_features: 0.1 - dropout: 0.1 - attention_dropout: 0.1 - feature_grad_mult: 0.1 - untie_final_proj: true - activation_dropout: 0.0 - - -hydra: - job: - config: - override_dirname: - kv_sep: '-' - item_sep: '__' - exclude_keys: - - run - - task.data - - task.label_dir - run: - dir: ??? - sweep: - dir: ??? - subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml deleted file mode 100644 index bdb39618e657a1f2f2663943e94b4d5bc176cc75..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml +++ /dev/null @@ -1,125 +0,0 @@ -# @package _group_ -common: - fp16: false - log_format: json - log_interval: 200 - seed: 1337 - # tensorboard_logdir: tblog_proj_name - # wandb_project: wandb_proj_name - -checkpoint: - save_interval_updates: 25000 - keep_interval_updates: -1 - no_epoch_checkpoints: true - - -distributed_training: - ddp_backend: no_c10d - distributed_backend: 'nccl' - distributed_world_size: 8 # 64 - nprocs_per_node: 8 - find_unused_parameters: true - -task: - _name: mert_pretraining - data: ??? - label_dir: ??? - labels: ??? - label_rate: ${model.label_rate} - sample_rate: 24000 - # crop to 5s - max_sample_size: 120000 - min_sample_size: 72000 - - pad_audio: false - random_crop: true - normalize: false # must be consistent with extractor - - -dataset: - num_workers: 6 - max_tokens: 2000000 - skip_invalid_size_inputs_valid_test: true - validate_interval: 1 - validate_interval_updates: 10000 - -criterion: - _name: hubert - pred_masked_weight: 1.0 - pred_nomask_weight: 0.0 - loss_weights: [10, 1] - -optimization: - max_update: 400000 - lr: [0.0005] - clip_norm: 10.0 - -optimizer: - _name: adam - adam_betas: (0.9,0.98) - adam_eps: 1e-06 - weight_decay: 0.01 - -lr_scheduler: - _name: polynomial_decay - warmup_updates: 32000 - -model: - _name: mert - label_rate: ??? 
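# BEST-RQ variant of the 95M recipe: the audio_rq_loss_* keys below define a
# frozen random projection (seed 42) into a single 8192-entry, 16-dim
# codebook whose indices serve as mask-prediction targets, and the input
# frontend switches from the w2v conv stack to a 120-bin mel spectrogram
# (12 bins per octave)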
- skip_masked: false - skip_nomask: true - mask_prob: 0.8 - mask_length: 5 - - logit_temp: 0.1 - - # ----- mixture ------ - mixture_prob: 0.5 - inbatch_noise_augment_len_range: "[12000, 24000]" - inbatch_noise_augment_number_range: "[1, 3]" - inbatch_noise_augment_volume: 1.0 - # ------------------------ - extractor_mode: default - audio_extract_type: melspec # use melspec (instead of `w2v_conv`) - melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave - conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' - - # best-rq loss - audio_rq_loss_m: true - audio_rq_loss_embed_dim: 16 - audio_rq_loss_num_codebooks: 1 - audio_rq_loss_num_embeds: 8192 - audio_rq_loss_seed: 42 - audio_rq_loss_use_norm: true - - # ---- cqt reconstruction, need to add loss weight --- - audio_cqt_loss_m: true - audio_cqt_bins: 336 - # ----------- - final_dim: 64 - encoder_layerdrop: 0.05 - dropout_input: 0.1 - dropout_features: 0.1 - dropout: 0.1 - attention_dropout: 0.1 - feature_grad_mult: 0.1 - untie_final_proj: true - activation_dropout: 0.0 - - -hydra: - job: - config: - override_dirname: - kv_sep: '-' - item_sep: '__' - exclude_keys: - - run - - task.data - - task.label_dir - run: - dir: ??? - sweep: - dir: ??? - subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml deleted file mode 100644 index 3ea092f03084c911848080d58b470fe5dcb00f8d..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml +++ /dev/null @@ -1,128 +0,0 @@ -# @package _group_ -common: - fp16: false - log_format: json - log_interval: 200 - seed: 1337 - # tensorboard_logdir: tblog_proj_name - # wandb_project: wandb_proj_name - -checkpoint: - save_interval_updates: 12500 - keep_interval_updates: -1 - no_epoch_checkpoints: true - - -distributed_training: - ddp_backend: no_c10d - distributed_backend: 'nccl' - distributed_world_size: 64 - nprocs_per_node: 8 - find_unused_parameters: true - -task: - _name: mert_pretraining - data: ??? - label_dir: ??? - labels: ??? - label_rate: ${model.label_rate} - sample_rate: 24000 - # crop to 5s - max_sample_size: 120000 - min_sample_size: 72000 - - pad_audio: false - random_crop: true - normalize: false # must be consistent with extractor - - -dataset: - num_workers: 6 - max_tokens: 2000000 - skip_invalid_size_inputs_valid_test: true - validate_interval: 1 - validate_interval_updates: 10000 - -criterion: - _name: hubert - pred_masked_weight: 1.0 - pred_nomask_weight: 0.0 - loss_weights: [10, 1] - -optimization: - max_update: 400000 - lr: [0.0005] - clip_norm: 10.0 - update_freq: [4] - -optimizer: - _name: adam - adam_betas: (0.9,0.98) - adam_eps: 1e-06 - weight_decay: 0.01 - -lr_scheduler: - _name: polynomial_decay - warmup_updates: 32000 - -model: - _name: mert - label_rate: ??? 
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml
deleted file mode 100644
index 3ea092f03084c911848080d58b470fe5dcb00f8d..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml
+++ /dev/null
@@ -1,128 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, i.e., 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-  audio_rq_loss_seed: 42
-  audio_rq_loss_use_norm: true
-  audio_rq_loss_use_chroma: true
-  audio_rq_loss_seed_chroma: 123
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 32
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml
deleted file mode 100644
index c7471c8e8482ad82bcadcdc7909513a6b17efd88..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml
+++ /dev/null
@@ -1,126 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, i.e., 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-  audio_rq_loss_seed: 42
-  audio_rq_loss_use_norm: true
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml
deleted file mode 100644
index ce6f750f1cc7e4bccb947713fa6eef280673e53a..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml
+++ /dev/null
@@ -1,128 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, i.e., 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-  audio_rq_loss_seed: 42
-  audio_rq_loss_use_norm: true
-  audio_rq_loss_use_chroma: false
-  audio_rq_loss_seed_chroma: 123
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml
deleted file mode 100644
index b296cdc55d0c69e4e6630e29a12ba7acb0bb6727..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml
+++ /dev/null
@@ -1,128 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0 # 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 80 # 120 # for melspec we use 120, i.e., 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-  audio_rq_loss_seed: 42
-  audio_rq_loss_use_norm: true
-  audio_rq_loss_use_chroma: false
-  audio_rq_loss_seed_chroma: 123
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: false
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
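In this speech variant mixture_prob is set to 0, which disables the in-batch mixture augmentation the other configs enable. The inbatch_noise_augment_* keys suggest segments of 12,000 to 24,000 samples (0.5 to 1 s at 24 kHz), 1 to 3 of them per clip, mixed in at unit volume from other items in the batch. A minimal sketch under that reading; the helper name and the exact mixing rule are assumptions, not the repository's code:

# Editor's sketch, not part of the diff. One plausible reading of the
# inbatch_noise_augment_* keys above; names and mixing rule are assumed.
import random
import torch

def inbatch_noise_augment(batch, len_range=(12000, 24000),
                          number_range=(1, 3), volume=1.0):
    """batch: (B, T) waveforms; returns a copy with in-batch noise mixed in."""
    out = batch.clone()
    bsz, T = batch.shape
    if bsz < 2:
        return out  # nothing to borrow noise from
    for i in range(bsz):
        for _ in range(random.randint(*number_range)):
            j = random.choice([k for k in range(bsz) if k != i])  # another item
            seg_len = min(random.randint(*len_range), T)
            src = random.randint(0, T - seg_len)
            dst = random.randint(0, T - seg_len)
            out[i, dst:dst + seg_len] += volume * batch[j, src:src + seg_len]
    return out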
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml
deleted file mode 100644
index 4c898b49f42845afeb8df7823e76d69f66a53ce5..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml
+++ /dev/null
@@ -1,121 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: w2v_conv
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # ---- codec target
-  audio_codec_type: rvq
-  audio_codec_ckpt_path: RVQ_3000.pth
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac.yaml
deleted file mode 100644
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml
deleted file mode 100644
index 52d274734f088016f071c5c13b9d5c2add7af536..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml
+++ /dev/null
@@ -1,121 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: w2v_conv
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # ---- codec target
-  audio_codec_type: dac
-  audio_codec_dac_model_path: weights_24khz_8kbps_0.0.4.pth #nj
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml
deleted file mode 100644
index 35d50add44d5354045a8272d18fba4e151724bee..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml
+++ /dev/null
@@ -1,125 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, i.e., 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: true
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 64 # 32
-  audio_rq_loss_num_embeds: 1024
-  audio_rq_loss_seed: 42
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 16 # 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml
deleted file mode 100644
index 83e0d8b62c7afa92a7f9931c8dd686cff2fff737..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml
+++ /dev/null
@@ -1,124 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # crop to 5s
-  max_sample_size: 120000
-  min_sample_size: 72000
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: hubert
-  pred_masked_weight: 1.0
-  pred_nomask_weight: 0.0
-  loss_weights: [10, 1]
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [4]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: mert
-  label_rate: ???
-  skip_masked: false
-  skip_nomask: true
-  mask_prob: 0.8
-  mask_length: 5
-
-  logit_temp: 0.1
-
-  # ----- mixture ------
-  mixture_prob: 0.5
-  inbatch_noise_augment_len_range: "[12000, 24000]"
-  inbatch_noise_augment_number_range: "[1, 3]"
-  inbatch_noise_augment_volume: 1.0
-  # ------------------------
-  extractor_mode: default
-  audio_extract_type: melspec # use melspec (instead of `w2v_conv`)
-  melspec_n_bins: 120 # for melspec we use 120, i.e., 12 bins per octave
-  conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
-
-  # best-rq loss
-  audio_rq_loss_m: false
-  audio_rq_loss_embed_dim: 16
-  audio_rq_loss_num_codebooks: 1
-  audio_rq_loss_num_embeds: 8192
-
-  # ---- cqt reconstruction, need to add loss weight ---
-  audio_cqt_loss_m: true
-  audio_cqt_bins: 336
-  # -----------
-  final_dim: 64
-  encoder_layerdrop: 0.05
-  dropout_input: 0.1
-  dropout_features: 0.1
-  dropout: 0.1
-  attention_dropout: 0.1
-  feature_grad_mult: 0.1
-  untie_final_proj: true
-  activation_dropout: 0.0
-
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml
deleted file mode 100644
index 9a5fa87c965ce2aab465f0df5cd563a99a4ae20c..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml
+++ /dev/null
@@ -1,108 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # # crop to 5s
-  # max_sample_size: 120000
-  # min_sample_size: 72000
-
-  # crop to 30s
-  max_sample_size: 720000
-  min_sample_size: 432000
-  clip_secs: 30
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: model
-  # log_keys:
-  #   - accuracies
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [1]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: musicfm
-  label_rate: 25
-  num_codebooks: 1
-  codebook_dim: 16
-  codebook_size: 8192 # 4096
-  features: ["melspec_2048"]
-  hop_length: 240
-  n_mels: 128
-  conv_dim: 512
-  encoder_dim: 1024
-  encoder_depth: 12
-  mask_hop: 0.4
-  mask_prob: 0.6
-  is_flash: false
-
-  stat_path: msd_stats.json
-  model_path: null
-  w2v2_config_path: our-MERT/data/models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa
-  use_rvq_target: true
-  rvq_ckpt_path: RVQ_4000.pth
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml
deleted file mode 100644
index fd38c04f61129d7866b231df2bb6ec2cbf606d78..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml
+++ /dev/null
@@ -1,105 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 12500
-  keep_interval_updates: -1
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # # crop to 5s
-  # max_sample_size: 120000
-  # min_sample_size: 72000
-
-  # crop to 30s
-  max_sample_size: 720000
-  min_sample_size: 432000
-  clip_secs: 30
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-
-criterion:
-  _name: model
-  # log_keys:
-  #   - accuracies
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [1]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: musicfm
-  label_rate: 25
-  num_codebooks: 1
-  codebook_dim: 16
-  codebook_size: 4096
-  features: ["melspec_2048"]
-  hop_length: 240
-  n_mels: 128
-  conv_dim: 512
-  encoder_dim: 1024
-  encoder_depth: 12
-  mask_hop: 0.4
-  mask_prob: 0.6
-  is_flash: false
-  stat_path: msd_stats.json
-  model_path: pretrained_msd.pt
-  w2v2_config_path: models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
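A quick editorial sanity check of the arithmetic these MusicFM configs encode; all numbers below come straight from the keys above, only the assertions are added:

# Editor's note, not part of the diff: the 30 s crop values are self-consistent.
sample_rate = 24000
clip_secs = 30
label_rate = 25                            # model.label_rate: 25 target frames/s
assert sample_rate * clip_secs == 720000   # task.max_sample_size
assert clip_secs * label_rate == 750       # target tokens per 30 s clip
assert 240 / sample_rate == 0.01           # hop_length 240 -> one mel frame per 10 ms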
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml
deleted file mode 100644
index 0cc0d2a03c6f937ddbdf39fc46eb532eef98fd2f..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml
+++ /dev/null
@@ -1,106 +0,0 @@
-# @package _group_
-common:
-  fp16: false
-  log_format: json
-  log_interval: 200
-  seed: 1337
-  # tensorboard_logdir: tblog_proj_name
-  # wandb_project: wandb_proj_name
-
-checkpoint:
-  save_interval_updates: 2500
-  keep_interval_updates: 10000
-  no_epoch_checkpoints: true
-
-
-distributed_training:
-  ddp_backend: no_c10d
-  distributed_backend: 'nccl'
-  distributed_world_size: 64
-  nprocs_per_node: 8
-  find_unused_parameters: true
-
-task:
-  _name: mert_pretraining
-  data: ???
-  label_dir: ???
-  labels: ???
-  label_rate: ${model.label_rate}
-  sample_rate: 24000
-  # # crop to 5s
-  # max_sample_size: 120000
-  # min_sample_size: 72000
-
-  # crop to 30s
-  max_sample_size: 720000
-  min_sample_size: 12000
-  # clip_secs: 30
-
-  pad_audio: false
-  random_crop: true
-  normalize: false # must be consistent with extractor
-
-
-dataset:
-  num_workers: 6
-  max_tokens: 2000000
-  skip_invalid_size_inputs_valid_test: true
-  validate_interval: 1
-  validate_interval_updates: 10000
-  disable_validation: true
-
-criterion:
-  _name: model
-  # log_keys:
-  #   - accuracies
-
-optimization:
-  max_update: 400000
-  lr: [0.0005]
-  clip_norm: 10.0
-  update_freq: [1]
-
-optimizer:
-  _name: adam
-  adam_betas: (0.9,0.98)
-  adam_eps: 1e-06
-  weight_decay: 0.01
-
-lr_scheduler:
-  _name: polynomial_decay
-  warmup_updates: 32000
-
-model:
-  _name: musicfm
-  label_rate: 25
-  num_codebooks: 1
-  codebook_dim: 16
-  codebook_size: 4096
-  features: ["melspec_2048"]
-  hop_length: 240
-  n_mels: 128
-  conv_dim: 512
-  encoder_dim: 1024
-  encoder_depth: 12
-  mask_hop: 0.4
-  mask_prob: 0.6
-  is_flash: false
-  stat_path: msd_stats.json
-  model_path: null
-  w2v2_config_path: models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa
-
-hydra:
-  job:
-    config:
-      override_dirname:
-        kv_sep: '-'
-        item_sep: '__'
-        exclude_keys:
-          - run
-          - task.data
-          - task.label_dir
-  run:
-    dir: ???
-  sweep:
-    dir: ???
-    subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml
deleted file mode 100644
index 46c979cd2835fe026b0a532a54533904d1001e54..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# @package _global_
-
-hydra:
-  launcher:
-    cpus_per_task: 8
-    gpus_per_node: 8
-    tasks_per_node: ${hydra.launcher.gpus_per_node}
-    nodes: 4
-    comment: null
-    mem_gb: 384
-    timeout_min: 4320
-    max_num_timeout: 100
-    constraint: volta32gb
-    name: ${hydra.job.config_name}/${hydra.job.override_dirname}
-    submitit_folder: ${hydra.sweep.dir}/submitit/%j
-
-distributed_training:
-  distributed_world_size: 32
-  distributed_port: 29671
-  nprocs_per_node: 8
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py
deleted file mode 100644
index 17079090f01ca01bf4bab30fa298fd9762752fc8..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .mert_dataset import MERTDataset
-from .eat_data import *
\ No newline at end of file
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py
deleted file mode 100644
index 47e8b8507d823305152a2f296432498afc8946c1..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import logging
-import torch
-import torch.nn.functional as F
-from fairseq.data.audio.raw_audio_dataset import RawAudioDataset
-from typing import Tuple
-try:
-    import kaldiio
-except:
-    kaldiio = None
-import warnings
-
-logger = logging.getLogger(__name__)
-
-class ArkDataset(RawAudioDataset):
-    def __init__(
-        self,
-        wav_scp,
-        dur_scp,
-        sr = 24000,
-        max_dur = 20,
-        num_buckets=0,
-        normalize=False,
-    ):
-        super().__init__(
-            sample_rate=sr,
-            max_sample_size=max_dur*sr,
-            min_sample_size=1200,
-            shuffle=True,
-            pad=True,
-            normalize=normalize,
-            compute_mask=False,
-        )
-        self.sr = sr
-        self.max_dur = max_dur
-        self.normalize = normalize
-
-        logger.info("Loading Kaldi scp files from {}".format(wav_scp))
-
-        self.wav_data = kaldiio.load_scp(wav_scp)
-        self.keys = list(self.wav_data.keys())
-        dur_data = {}
-        keys_set = set(self.keys)
-
-        with open(dur_scp, 'r') as f:
-            for line in f:
-                line = line.strip().split()
-                if line[0] in keys_set:
-                    dur_data[line[0]] = float(line[-1])
-        self.sizes = [int(dur_data[k]*self.sr/100) for k in self.keys]
-
-        logger.info("Loading Kaldi scp files done")
-
-        self.dataset_len = len(self.keys)
-        self.set_bucket_info(num_buckets)
-
-    def __len__(self):
-        return self.dataset_len
-
-    def __getitem__(self, idx):
-        # print("getitem idx: ", idx)
-        try_cnt = 0
-        while True:
-            idx = idx + try_cnt
-            try:
-                with warnings.catch_warnings():
-                    warnings.simplefilter("ignore")
-                    key = self.keys[idx]
-                    # print(self.wav_data[key].keys())
-                    wav = self.wav_data[key]['wav']
-
-                    wav = torch.from_numpy(wav).float()
-                    wav = self.postprocess(wav)
-                    # print("success load", idx, " shape =", wav.shape)
-                    return {"id": idx, "source": wav}
-            except Exception as e:
-                # from traceback import print_exc
-                # print_exc()
-                # print("Error loading ", idx)
-                # return {"id": idx, "source": None}
-                try_cnt += 1
-                if try_cnt > 50:
-                    return {"id": idx, "source": None}
-                continue
-
-    def size(self, idx):
-        return self.sizes[idx]
-
-    def postprocess(self, wav):
-        if wav.dim() == 2:
-            wav = wav.mean(-1)
-        assert wav.dim() == 1, wav.dim()
-
-        if self.normalize:
-            with torch.no_grad():
-                wav = F.layer_norm(wav, wav.shape)
-        return wav
-
-    def collater(self, samples):
-        # print("collate from:", [s['source'].shape for s in samples if s['source'] is not None])
-        return super().collater(samples)
-
-if __name__ == '__main__':
-    import torch
-    raw_tensor_str = torch.Tensor.__repr__
-    torch.Tensor.__str__ = torch.Tensor.__repr__ = lambda self: f'Tensor{{Size({[*self.shape]}) {self.device} {str(self.dtype)[6]}{str(self.dtype)[-2:]}}}' if self.numel() > 10 else raw_tensor_str(self)
-
-    ds = ArkDataset(
-        wav_scp='data/ark_demo/wav_ark.scp',
-        dur_scp='data/ark_demo/dur_ark.scp',
-        sr=24000,
-    )
-
-    for i in range(len(ds)):
-        print(ds[i])
\ No newline at end of file
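Two details of the deleted ArkDataset are easy to miss. The sizes computation int(dur * sr / 100) only makes sense if dur.scp stores durations in hundredths of a second, and F.layer_norm over the whole waveform shape is per-utterance mean/variance normalization. A small check of both; the centisecond unit is an inference from the code, not something documented in the diff:

# Editor's sketch, not part of the diff.
import torch
import torch.nn.functional as F

sr = 24000
dur_centiseconds = 1234.0                  # as read from dur.scp (i.e. 12.34 s, assumed unit)
size_in_samples = int(dur_centiseconds * sr / 100)
assert size_in_samples == 296160

wav = torch.randn(48000)
normed = F.layer_norm(wav, wav.shape)      # (wav - mean) / sqrt(var + eps) over the whole clip
assert torch.allclose(normed.mean(), torch.tensor(0.0), atol=1e-4)
assert torch.allclose(normed.std(unbiased=False), torch.tensor(1.0), atol=1e-3)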
-try:
-    from .mae_image_dataset import MaeImageDataset
-    from .raw_audio_dataset import FileAudioDataset
-except:
-    import sys, os
-    sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '.'))
-    from mae_image_dataset import MaeImageDataset
-    from raw_audio_dataset import FileAudioDataset
-
-__all__ = [
-    "MaeImageDataset",
-    "FileAudioDataset",
-]
\ No newline at end of file
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/add_class_target_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/add_class_target_dataset.py
deleted file mode 100644
index 1ea93918d00fe2c5d4582bbc3d5b9e6035ccdc94..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/add_class_target_dataset.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-
-from fairseq.data import BaseWrapperDataset
-
-# add labels for audio clips in fine-tuning
-class AddClassTargetDataset(BaseWrapperDataset):
-    def __init__(
-        self,
-        dataset,
-        labels,
-        multi_class,
-        num_classes=None,
-        label_indices=None,
-        add_to_input=True,
-    ):
-        super().__init__(dataset)
-
-        self.label_indices = label_indices
-        self.labels = labels
-        self.multi_class = multi_class
-        self.add_to_input = add_to_input
-        if num_classes is None and multi_class:
-            assert self.label_indices is not None
-            num_classes = len(self.label_indices)
-
-        self.num_classes = num_classes
-
-    def __getitem__(self, index):
-        item = self.dataset[index]
-        item_labels = self.labels[index]
-        if self.multi_class:
-            item["label"] = torch.zeros(self.num_classes)
-            for il in item_labels:
-                if self.label_indices is not None:
-                    il = self.label_indices[il]
-                item["label"][int(il)] = 1.0
-        else:
-            item["label"] = torch.tensor(
-                self.labels[index]
-                if self.label_indices is None
-                else self.label_indices[self.labels[index]]
-            )
-
-        return item
-
-    def collater(self, samples):
-        collated = self.dataset.collater(samples)
-        if len(collated) == 0:
-            return collated
-
-        indices = set(collated["id"].tolist())
-        target = [s["label"] for s in samples if s["id"] in indices]
-        collated["label"] = torch.stack(target, dim=0)
-
-        if self.add_to_input:
-            collated["net_input"]["label"] = collated["label"]
-
-        return collated
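For reference, the multi_class branch of the deleted AddClassTargetDataset builds a multi-hot target exactly like this toy rerun of its loop; all values below are invented for illustration:

# Editor's sketch, not part of the diff. Toy values only.
import torch

num_classes = 5
label_indices = {"dog": 0, "siren": 3}
item_labels = ["dog", "siren"]             # labels[index] for one clip

target = torch.zeros(num_classes)
for il in item_labels:
    target[label_indices[il]] = 1.0
print(target)                              # tensor([1., 0., 0., 1., 0.])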
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/mae_image_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/mae_image_dataset.py
deleted file mode 100644
index cce5086ca31263e8eb5d66c084b0edc4a57f63fe..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/mae_image_dataset.py
+++ /dev/null
@@ -1,296 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-from functools import partial
-import logging
-import random
-import time
-import numpy as np
-import os
-import torch
-
-from fairseq.data import FairseqDataset
-try:
-    from ..utils.data_utils import compute_block_mask_1d, compute_block_mask_2d
-except:
-    import sys, os
-    sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
-    from utils.data_utils import compute_block_mask_1d, compute_block_mask_2d
-try:
-    from .raw_audio_dataset import FileAudioDataset
-except:
-    import sys, os
-    sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '.'))
-    from raw_audio_dataset import FileAudioDataset
-
-from shutil import copyfile
-
-logger = logging.getLogger(__name__)
-
-
-def load(path, loader, cache):
-    if hasattr(caching_loader, "cache_root"):
-        cache = caching_loader.cache_root
-
-    cached_path = cache + path
-
-    num_tries = 3
-    for curr_try in range(num_tries):
-        try:
-            if curr_try == 2:
-                return loader(path)
-            if not os.path.exists(cached_path) or curr_try > 0:
-                os.makedirs(os.path.dirname(cached_path), exist_ok=True)
-                copyfile(path, cached_path)
-                os.chmod(cached_path, 0o777)
-            return loader(cached_path)
-        except Exception as e:
-            logger.warning(str(e))
-            if "Errno 13" in str(e):
-                caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}"
-                logger.warning(f"setting cache root to {caching_loader.cache_root}")
-                cached_path = caching_loader.cache_root + path
-            if curr_try == (num_tries - 1):
-                raise
-            time.sleep(2)
-
-
-def caching_loader(cache_root: str, loader):
-    if cache_root is None:
-        return loader
-
-    if cache_root == "slurm_tmpdir":
-        cache_root = os.environ["SLURM_TMPDIR"]
-        assert len(cache_root) > 0
-
-    if not cache_root.endswith("/"):
-        cache_root += "/"
-
-    return partial(load, loader=loader, cache=cache_root)
-
-
-class MaeImageDataset(FairseqDataset):
-    def __init__(
-        self,
-        root: str,
-        split: str,
-        input_size,
-        shuffle=True,
-        key="imgs",
-        compute_mask=False,
-        patch_size: int = 16,
-        mask_prob: float = 0.75,
-        mask_prob_adjust: float = 0,
-        mask_length: int = 1,
-        inverse_mask: bool = False,
-        expand_adjacent: bool = False,
-        mask_dropout: float = 0,
-        non_overlapping: bool = False,
-        require_same_masks: bool = True,
-        clone_batch: int = 1,
-        audio_mae:bool = False,
-        h5_format:bool = False,
-        downsr_16hz:bool = False,
-        target_length:int = 1024,
-        esc50_eval:bool = False,
-        spcv2_eval:bool = False,
-        roll_aug: bool = False,
-        noise: bool = False,
-        dataset_type: str = "imagefolder",
-        num_samples: int = 200000,
-        replacement: bool = False,
-        AS2M_finetune: bool = False,
-        spcv1_finetune: bool =False,
-        weights_file: str="",
-        flexible_mask: bool = False,
-        sample_rate=24000,
-        fixed_duration=10,
-    ):
-        FairseqDataset.__init__(self)
-
-        self.shuffle = shuffle
-        self.key = key
-        self.audio_mae = audio_mae
-        if self.audio_mae:
-            self.h5_format = h5_format
-            self.downsr_16hz = downsr_16hz
-            self.target_length = target_length
-            self.esc50_eval = esc50_eval
-            self.spcv2_eval = spcv2_eval
-            self.noise = noise
-            self.num_samples = num_samples
-            self.replacement = replacement
-            self.split = split
-            self.AS2M_finetune = AS2M_finetune
-            self.spcv1_finetune= spcv1_finetune
-            self.weights_file = weights_file
-            self.flexible_mask = flexible_mask
-
-        self.transform_source = None
-        self.transform_target = None
-        self.img_shape = None
-        self.roll_aug = roll_aug
-
-        # load wav files
-        mask_args = {}
-        if self.audio_mae:
-            min_sample_size = 10000
-
-            input_size = (self.target_length,128)
"{}.jsonl".format(split)) - self.dataset = FileAudioDataset( - manifest_path=manifest_path, - sample_rate=sample_rate, - fixed_duration=fixed_duration, - max_sample_size=sample_rate*fixed_duration, - min_sample_size=min_sample_size, - pad=False, - normalize=True, - num_buckets=0, - compute_mask=False, - h5_format=self.h5_format, - downsr_16hz=self.downsr_16hz, - wav2fbank=True, - target_length=self.target_length, - esc50_eval=self.esc50_eval, - spcv2_eval=self.spcv2_eval, - roll_mag_aug=self.roll_aug, - train_mode=split, - noise=self.noise, - **mask_args, - ) - self.skipped_indices = self.dataset.skipped_indices - - else: - raise Exception(f"invalid dataset type {dataset_type}") - - - logger.info(f"loaded {len(self.dataset)} examples") - - self.is_compute_mask = compute_mask - - if type(input_size) == tuple: - self.patches = (input_size[0] // patch_size ) * ( input_size[1] // patch_size ) - self.img_shape = (input_size[0] // patch_size,input_size[1] // patch_size ) - - else: - self.patches = (input_size // patch_size) ** 2 - self.mask_prob = mask_prob - self.mask_prob_adjust = mask_prob_adjust - self.mask_length = mask_length - self.inverse_mask = inverse_mask - self.expand_adjacent = expand_adjacent - self.mask_dropout = mask_dropout - self.non_overlapping = non_overlapping - self.require_same_masks = require_same_masks - self.clone_batch = clone_batch - - def __getitem__(self, index): - if self.audio_mae: - img = self.dataset[index]['source'] - else: - img, _ = self.dataset[index] - - source = None - target = None - - v = {"id": index, self.key: source if source is not None else img} - if target is not None: - v["target"] = target - - # inverse block mask on audio patches - if self.is_compute_mask: - if self.mask_length == 1: - mask = compute_block_mask_1d( - shape=(self.clone_batch, self.patches), - mask_prob=self.mask_prob, - mask_length=self.mask_length, - mask_prob_adjust=self.mask_prob_adjust, - inverse_mask=self.inverse_mask, - require_same_masks=True, - ) - else: # mask_length==5 - mask = compute_block_mask_2d( - shape=(self.clone_batch, self.patches), - mask_prob=self.mask_prob, - mask_length=self.mask_length, - mask_prob_adjust=self.mask_prob_adjust, - inverse_mask=self.inverse_mask, - require_same_masks=True, - expand_adjcent=self.expand_adjacent, - mask_dropout=self.mask_dropout, - non_overlapping=self.non_overlapping, - img_shape=self.img_shape, - flexible_mask=self.flexible_mask - ) - - if mask.shape[1] < self.patches: - padding = torch.zeros((mask.shape[0], self.patches - mask.shape[1])) - mask = torch.cat((mask, padding), dim=1) - - v["precomputed_mask"] = mask - - return v - - def __len__(self): - return len(self.dataset) - - def collater(self, samples): - if len(samples) == 0: - return {} - - collated_img = torch.stack([s[self.key] for s in samples], dim=0) - - res = { - "id": torch.LongTensor([s["id"] for s in samples]), - "net_input": { - self.key: collated_img, - }, - } - - if "target" in samples[0]: - collated_target = torch.stack([s["target"] for s in samples], dim=0) - res["net_input"]["target"] = collated_target - - if "precomputed_mask" in samples[0]: - collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0) - res["net_input"]["precomputed_mask"] = collated_mask - - return res - - def num_tokens(self, index): - return 1 - - def size(self, index): - return 1 - - @property - def sizes(self): - return np.full((len(self),), 1) - - # shuffle data (for pre-training and fine-tuning) - def ordered_indices(self): - """Return an ordered list of indices. 
diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/raw_audio_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/raw_audio_dataset.py
deleted file mode 100644
index d8475ec02fcc3c0de9cebf16c8232c685037d6c1..0000000000000000000000000000000000000000
--- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/raw_audio_dataset.py
+++ /dev/null
@@ -1,545 +0,0 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import logging
-import os
-import sys
-import time
-import io
-try:
-    import h5py
-except:
-    h5py = None
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-import torchaudio
-
-from fairseq.data import FairseqDataset
-try:
-    from ..utils.data_utils import compute_block_mask_1d, get_buckets, get_bucketed_sizes
-except:
-    import sys, os
-    sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
-    from utils.data_utils import compute_block_mask_1d, get_buckets, get_bucketed_sizes
-from fairseq.data.audio.audio_utils import (
-    parse_path,
-    read_from_stored_zip,
-    is_sf_audio_data,
-)
-from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel
-
-import math
-from typing import Tuple
-import json
-
-logger = logging.getLogger(__name__)
-
-def load_audio_by_json(json_path, max_keep, min_keep):
-    # read json file
-    n_long, n_short = 0, 0
-    datas = []
-    inds = []
-    sizes = []
-    with open(json_path) as fp:
-        for ind,line in enumerate(fp):
-            data = json.loads(line)
-            sz = int(data['duration'] * data['sample_rate'])
-            if min_keep is not None and sz < min_keep:
-                n_short += 1
-            elif max_keep is not None and sz > max_keep:
-                n_long += 1
-            else:
-                datas.append(data)
-                inds.append(ind)
-                sizes.append(sz)
-    tot = ind + 1
-    logger.info(
-        (
-            f"json_path={json_path}, "
-            f"max_keep={max_keep}, min_keep={min_keep}, "
-            f"loaded {len(datas)}, skipped {n_short} short and {n_long} long, "
-            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
-        )
-    )
-    return datas, inds, tot, sizes
-
-class RawAudioDataset(FairseqDataset):
-    def __init__(
-        self,
-        sample_rate,
-        max_sample_size=None,
-        min_sample_size=0,
-        shuffle=True,
-        pad=False,
-        normalize=False,
-        compute_mask=False,
-        feature_encoder_spec: str = "None",
-        mask_prob: float = 0.75,
-        mask_prob_adjust: float = 0,
-        mask_length: int = 1,
-        inverse_mask: bool = False,
-        require_same_masks: bool = True,
-        clone_batch: int = 1,
-        expand_adjacent: bool = False,
-        mask_dropout: float = 0,
-        non_overlapping: bool = False,
-        corpus_key=None,
-    ):
-        super().__init__()
-
-        self.sample_rate = sample_rate
-        self.sizes = []
-        self.max_sample_size = (
-            max_sample_size if max_sample_size is not None else sys.maxsize
-        )
-        self.min_sample_size = min_sample_size
-        self.pad = pad
-        self.shuffle = shuffle
-        self.normalize = normalize
-
-        self.is_compute_mask = compute_mask
-        self.feature_encoder_spec = eval(feature_encoder_spec)
-        self._features_size_map = {}
-        self.mask_prob = mask_prob
-        self.mask_prob_adjust = mask_prob_adjust
-        self.mask_length = mask_length
-        self.inverse_mask = inverse_mask
-        self.require_same_masks = require_same_masks
-        self.clone_batch = clone_batch
-        self.expand_adjacent = expand_adjacent
-        self.mask_dropout = mask_dropout
-        self.non_overlapping = non_overlapping
-        self.corpus_key = corpus_key
-
-    def __getitem__(self, index):
-        raise NotImplementedError()
-
-    def __len__(self):
-        return len(self.sizes)
-
-    def _roll_mag_aug(self, waveform):
-        waveform=waveform.numpy()
-        idx=np.random.randint(len(waveform))
-        rolled_waveform=np.roll(waveform,idx)
-        mag = np.random.beta(10, 10) + 0.5
-        return torch.Tensor(rolled_waveform*mag)
-
-
-    def postprocess(self, feats, curr_sample_rate, roll_aug = False):
-        if feats.dim() == 2:
-            feats = feats.mean(-1)
-
-        if curr_sample_rate != self.sample_rate:
-            raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
-
-        assert feats.dim() == 1, feats.dim()
-        # if self.normalize:
-        #     with torch.no_grad():
-        #         feats = F.layer_norm(feats, feats.shape)
-        feats = feats - feats.mean()
-
-        if roll_aug:
-            feats = self._roll_mag_aug(feats)
-
-        return feats
-
-    def crop_to_max_size(self, t, target_size, dim=0):
-        size = t.size(dim)
-        diff = size - target_size
-        if diff <= 0:
-            return t
-
-        start = np.random.randint(0, diff + 1)
-        end = size - diff + start
-
-        slices = []
-        for d in range(dim):
-            slices.append(slice(None))
-        slices.append(slice(start, end))
-
-        return t[slices]
-
-    @staticmethod
-    def _bucket_tensor(tensor, num_pad, value):
-        return F.pad(tensor, (0, num_pad), value=value)
-
-    def collater(self, samples):
-        samples = [s for s in samples if s["source"] is not None]
-        if len(samples) == 0:
-            return {}
-
-        sources = [s["source"] for s in samples]
-        sizes = [len(s) for s in sources]
-
-        if self.pad:
-            target_size = min(max(sizes), self.max_sample_size)
-        else:
-            target_size = min(min(sizes), self.max_sample_size)
-
-        collated_sources = sources[0].new_zeros(len(sources), target_size)
-        padding_mask = (
-            torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
-        )
-        for i, (source, size) in enumerate(zip(sources, sizes)):
-            diff = size - target_size
-            if diff == 0:
-                collated_sources[i] = source
-            elif diff < 0:
-                assert self.pad
-                collated_sources[i] = torch.cat(
-                    [source, source.new_full((-diff,), 0.0)]
-                )
-                padding_mask[i, diff:] = True
-            else:
-                collated_sources[i] = self.crop_to_max_size(source, target_size)
-
-        input = {"source": collated_sources}
-        if self.corpus_key is not None:
-            input["corpus_key"] = [self.corpus_key] * len(sources)
-        out = {"id": torch.LongTensor([s["id"] for s in samples])}
-        if self.pad:
-            input["padding_mask"] = padding_mask
-
-        if hasattr(self, "num_buckets") and self.num_buckets > 0:
-            assert self.pad, "Cannot bucket without padding first."
-            bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
-            num_pad = bucket - collated_sources.size(-1)
-            if num_pad:
-                input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
-                input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)
-
-        if "precomputed_mask" in samples[0]:
-            target_size = self._get_mask_indices_dims(target_size)
-            collated_mask = torch.cat(
-                [
-                    self.crop_to_max_size(s["precomputed_mask"], target_size, dim=1)
-                    for s in samples
-                ],
-                dim=0,
-            )
-            input["precomputed_mask"] = collated_mask
-
-        out["net_input"] = input
-        return out
-
-    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
-        if size not in self.feature_encoder_spec:
-            L_in = size
-            for (_, kernel_size, stride) in self.feature_encoder_spec:
-                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
-                L_out = 1 + L_out // stride
-                L_in = L_out
-            self._features_size_map[size] = L_out
-        return self._features_size_map[size]
-
-    def num_tokens(self, index):
-        return self.size(index)
-
-    def size(self, index):
-        """Return an example's size as a float or tuple. This value is used when
-        filtering a dataset with ``--max-positions``."""
-        if self.pad:
-            return self.sizes[index]
-        return min(self.sizes[index], self.max_sample_size)
-
-    def ordered_indices(self):
-        """Return an ordered list of indices. Batches will be constructed based
-        on this order."""
-
-        if self.shuffle:
-            order = [np.random.permutation(len(self))]
-            order.append(
-                np.minimum(
-                    np.array(self.sizes),
-                    self.max_sample_size,
-                )
-            )
-            return np.lexsort(order)[::-1]
-        else:
-            return np.arange(len(self))
-
-    def set_bucket_info(self, num_buckets):
-        self.num_buckets = num_buckets
-        if self.num_buckets > 0:
-            self._collated_sizes = np.minimum(
-                np.array(self.sizes),
-                self.max_sample_size,
-            )
-            self.buckets = get_buckets(
-                self._collated_sizes,
-                self.num_buckets,
-            )
-            self._bucketed_sizes = get_bucketed_sizes(
-                self._collated_sizes, self.buckets
-            )
-            logger.info(
-                f"{len(self.buckets)} bucket(s) for the audio dataset: "
-                f"{self.buckets}"
-            )
-
-    def filter_indices_by_size(self, indices, max_sizes):
-        return indices, []
-
-class Read_and_PadCrop_Normalized_T(torch.nn.Module):
-    def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
-        super().__init__()
-
-        self.n_samples = n_samples
-        self.sample_rate = sample_rate
-        self.randomize = randomize
-
-
-    def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]:
-        if(duration<(float(self.n_samples)/self.sample_rate+1)):
-            # print(duration,(float(self.n_samples)/self.sample_rate+1))
-            chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1)
-            t_start = 0.
-            t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration)
-            offset = 0
-            # print('c1:',chunk.shape)
-        else:
-            offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
-            t_start = offset / float(cur_sample_rate) / duration
-            t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration
-            chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate))
-            # print('offset:',offset)
-            # print('c0:',chunk.shape)
-        # Pad with silence if necessary.
-        if(chunk.shape[0]>1):
-            chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float()
-        else:
-            chunk = chunk[[0],:].float()
-        if(cur_sample_rate!=self.sample_rate):
-            # print('a:',cur_sample_rate,chunk.shape)
-            chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate)
-            # print('b:',self.sample_rate,chunk.shape)
-        if chunk.shape[-1] < self.n_samples:
-            chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1)
-        else:
-            chunk = chunk[:,0:self.n_samples]
-        seconds_start = math.floor(offset / cur_sample_rate)
-        seconds_total = math.floor(duration)
-
-        return (
-            chunk,
-            t_start,
-            t_end,
-            seconds_start,
-            seconds_total
-        )
-
-class FileAudioDataset(RawAudioDataset):
-    def __init__(
-        self,
-        manifest_path,
-        sample_rate,
-        fixed_duration=None,
-        max_sample_size=None,
-        min_sample_size=0,
-        shuffle=True,
-        pad=False,
-        normalize=False,
-        num_buckets=0,
-        compute_mask=False,
-        text_compression_level=TextCompressionLevel.none,
-        h5_format=False,
-        downsr_16hz=False,
-        wav2fbank=False,
-        target_length=1024,
-        esc50_eval=False,
-        spcv2_eval=False,
-        roll_mag_aug=False,
-        noise=False,
-        train_mode='train',
-        **mask_compute_kwargs,
-    ):
-        super().__init__(
-            sample_rate=sample_rate,
-            max_sample_size=max_sample_size,
-            min_sample_size=min_sample_size,
-            shuffle=shuffle,
-            pad=pad,
-            normalize=normalize,
-            compute_mask=compute_mask,
-            **mask_compute_kwargs,
-        )
-
-        self.text_compressor = TextCompressor(level=text_compression_level)
-        self.h5_format = h5_format
-        self.downsr_16hz = downsr_16hz
-        self.wav2fbank = wav2fbank
-        self.target_length = target_length
-        self.esc50_eval = esc50_eval
-        self.spcv2_eval = spcv2_eval
-        self.roll_mag_aug = roll_mag_aug
-        self.noise = noise
-        self.train_mode = train_mode
-        self.reader = Read_and_PadCrop_Normalized_T(n_samples = int(fixed_duration*sample_rate), sample_rate = sample_rate)
-
-        skipped = 0
-        self.fnames = []
-        sizes = []
-        self.skipped_indices = set()
-
-        # exclude data not in sample rate range 10.h5/****.wav 320000
-        self.durations = []
-        self.raw_srs = []
-        datas, inds, tot, sizes = load_audio_by_json(manifest_path, max_keep=None, min_keep=None)
-        for data in datas:
-            self.fnames.append(self.text_compressor.compress(data['path']))
-            self.durations.append(data['duration'])
-            self.raw_srs.append(data['sample_rate'])
-
-        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
-
-        if self.esc50_eval:
-            task_dataset = "ESC-50"
-        elif self.spcv2_eval:
-            task_dataset = "SPC-2"
-        else:
-            task_dataset = "AS"
-
-        logger.info(
-            f"sample rate: {sample_rate}\t"
-            f"target length: {self.target_length}\t"
-            f"current task: {task_dataset}\t"
-        )
-
-        # self.sizes = np.array(sizes, dtype=np.int64)
-        self.sizes = np.array(sizes, dtype=np.int64)
-        self.durations = np.array(self.durations)
-        self.raw_srs = np.array(self.raw_srs)
-
-        try:
-            import pyarrow
-
-            self.fnames = pyarrow.array(self.fnames)
-        except:
-            logger.debug(
-                "Could not create a pyarrow array. Please install pyarrow for better performance"
-            )
-            pass
-
-        self.set_bucket_info(num_buckets)
-        # print("skipped_index: {}".format(self.skipped_indices))
-        # print(len(self.skipped_indices))
-
h5_format = true -> .h5(.hdf5) ; h5_format = false -> .wav
-    def __getitem__(self, index):
-        import soundfile as sf
-
-        fn = self.fnames[index]
-        fn = fn if isinstance(self.fnames, list) else fn.as_py()
-        fn = self.text_compressor.decompress(fn)
-        path_or_fp = fn  # os.path.join(self.root_dir, fn)
-        _path, slice_ptr = parse_path(path_or_fp)
-        if len(slice_ptr) == 2:
-            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])  # root/10.h5/***.wav
-            assert is_sf_audio_data(byte_data)
-            path_or_fp = io.BytesIO(byte_data)
-
-        retry = 3
-        wav = None
-        for i in range(retry):
-            try:
-                # if self.h5_format and self.train_mode == 'train':
-                #     parts = path_or_fp.split("/")
-                #     path_or_fp = "/".join(parts[:-1])
-                #     path_or_fp = h5py.File(path_or_fp,'r')
-                #     wav = path_or_fp[parts[-1]][:]
-                #     curr_sample_rate = 32000
-                #     break
-                # else:
-                #     wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")
-                #     break
-                wav, *ignored = self.reader(fn, self.durations[index], self.raw_srs[index])
-                curr_sample_rate = self.reader.sample_rate
-            except Exception as e:
-                logger.warning(
-                    f"Failed to read {path_or_fp}: {e}. Sleeping for {1 * i}"
-                )
-                time.sleep(1 * i)
-
-        if wav is None:
-            raise Exception(f"Failed to load {path_or_fp}")
-
-        if self.h5_format:
-            feats = torch.tensor(wav).float()
-        else:
-            if not isinstance(wav, torch.Tensor):
-                feats = torch.from_numpy(wav).float()
-            else:
-                feats = wav
-            if len(feats.shape) == 2:
-                feats = feats.squeeze(dim=0)
-
-        if self.downsr_16hz:
-            feats = torchaudio.functional.resample(feats, orig_freq=curr_sample_rate, new_freq=16000)
-            curr_sample_rate = 16000
-            self.sample_rate = curr_sample_rate
-
-        # whether to use roll augmentation on waveform
-        use_roll = self.roll_mag_aug and self.train_mode == 'train'
-
-        feats = self.postprocess(feats, curr_sample_rate, use_roll)
-
-        # convert waveform to spectrogram
-        if self.wav2fbank:  # here, convert the waveform to a mel spectrogram
-            feats = feats.unsqueeze(dim=0)
-            feats = torchaudio.compliance.kaldi.fbank(feats, htk_compat=True, sample_frequency=curr_sample_rate, use_energy=False,
-                                                      window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=40).unsqueeze(dim=0)
-
-            # padding
-            n_frames = feats.shape[1]
-            diff = self.target_length - n_frames  # pad the time axis up to target_length (1024)
-            if diff > 0:
-                m = torch.nn.ZeroPad2d((0, 0, 0, diff))
-                feats = m(feats)
-            elif diff < 0:
-                feats = feats[0:self.target_length, :]
-
-            # global normalization for AS
-            self.norm_mean = -4.268
-            self.norm_std = 4.569
-
-            # global normalization for ESC-50
-            if self.esc50_eval:
-                self.norm_mean = -6.627
-                self.norm_std = 5.359
-
-            # global normalization for spcv2
-            if self.spcv2_eval:
-                self.norm_mean = -6.846
-                self.norm_std = 5.565
-
-            feats = (feats - self.norm_mean) / (self.norm_std * 2)
-
-            if self.noise and self.train_mode == 'train':
-                feats = feats + torch.rand(feats.shape[1], feats.shape[2]) * np.random.rand() / 10  # depending on the situation, this way of adding noise may need to be swapped for in-batch noise
-                feats = torch.roll(feats, np.random.randint(-10, 10), 1)
-
-        v = {"id": index, "source": feats}
-
-        if self.is_compute_mask:
-            T = self._get_mask_indices_dims(feats.size(-1))
-            mask = compute_block_mask_1d(
-                shape=(self.clone_batch, T),
-                mask_prob=self.mask_prob,
-                mask_length=self.mask_length,
-                mask_prob_adjust=self.mask_prob_adjust,
-                inverse_mask=self.inverse_mask,
-                require_same_masks=True,
-                expand_adjcent=self.expand_adjacent,
-                mask_dropout=self.mask_dropout,
-                non_overlapping=self.non_overlapping,
-            )
-
-            v["precomputed_mask"] = mask
-
-        return v
diff --git
a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/mert_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/mert_dataset.py deleted file mode 100644 index b05f757859fc83a8cd0555ba53a161a4bac99c08..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/mert_dataset.py +++ /dev/null @@ -1,639 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import itertools -import logging -import os -import sys -from typing import Any, List, Optional, Union - -import numpy as np -from typing import Tuple -import torch -import torch.nn.functional as F -from fairseq.data import data_utils -from fairseq.data.fairseq_dataset import FairseqDataset -from fairseq.data.audio.audio_utils import ( - parse_path, - read_from_stored_zip, -) - -import math -import io -import torchaudio -# this is in the user_dir -from nnAudio import features as nnAudioFeatures - -# from tqdm import tqdm -import tqdm -import json -import random -import traceback -# from scripts.prepare_codecs_from_manifest import * - -logger = logging.getLogger(__name__) - -class model_cqt_pred(torch.nn.Module): - def __init__(self, n_bins=84, sr=16000, freq=50): - super().__init__() - self.epsilon=1e-10 - # Getting Mel Spectrogram on the fly - self.spec_layer = nnAudioFeatures.cqt.CQT(sr=sr, hop_length=sr//freq, fmin=32.7, - fmax=None, n_bins=n_bins, bins_per_octave=n_bins//7, - filter_scale=1, norm=1, window='hann', center=True, - pad_mode='constant', trainable=False, - output_format='Magnitude', verbose=True) - - # self.fc = nn.Linear(input_dim, n_bins) - - # self.criterion = nn.MSELoss() - self.forward_dict = { - # 'masked_transformer_output': self.plain_forward - 'compute_cqt': self.compute_cqt - } - def compute_cqt(self, x): - ''' - convert waveform to CQT -> [batch, bins, len] -> transpose - ''' - # align with the padding of HuBERT model, - # the truncation is calculated by bruteforce search since the nnAudio padding strategy and fairseq models are different - # x = x[..., :-560] - return torch.transpose(self.spec_layer(x), -1, -2) - - def forward(self, x, forward_type='masked_transformer_output'): - ''' - take input from transformer hidden states: [batch, len_seq, channel] - output: [batch, len_seq, n_bins] - ''' - - return self.forward_dict[forward_type](x) -# def audio2label(wav,sr): -# wav = convert_audio(wav, sr, model.sample_rate, model.channels) -# wav = wav.unsqueeze(0) -# wav = wav.to(device) -# with torch.no_grad(): -# encoded_frames = model.encode(wav) -# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T] -# codes = codes.to('cpu')[0] - -# # for i in range(args.n_codebook): -# # f_codecs[i].write(' '.join([str(x) for x in codes[i].numpy()]) + '\n') -def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate, clip_secs=5): - # read json file - print(json_path) - datas = [] - inds = [] - sizes = [] - with open(json_path) as fp: - for ind,line in enumerate(fp): - data = json.loads(line) - datas.append(data) - inds.append(ind) - # sz = int(data['duration'] * data['sample_rate']) - sz = int(tgt_sample_rate * clip_secs) - sizes.append(sz) - tot = ind + 1 - return datas,inds,tot,sizes -def load_audio(manifest_path, max_keep, min_keep): #读取tsv文件(原本) - print(manifest_path) - - n_long, n_short = 0, 0 - names, inds, sizes = [], [], [] - with open(manifest_path) as f: - root = 
f.readline().strip() - for ind, line in enumerate(f): - items = line.strip().split("\t") - assert len(items) == 2, line - sz = int(items[1]) - if min_keep is not None and sz < min_keep: - n_short += 1 - elif max_keep is not None and sz > max_keep: - n_long += 1 - else: - names.append(items[0]) - inds.append(ind) - sizes.append(sz) - tot = ind + 1 - logger.info( - ( - f"max_keep={max_keep}, min_keep={min_keep}, " - f"loaded {len(names)}, skipped {n_short} short and {n_long} long, " - f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" - ) - ) - return root, names, inds, tot, sizes - - -def load_label(label_path, inds, tot): - with open(label_path) as f: - labels = [] - for line in tqdm.tqdm(f): - labels.append(line.rstrip()) - # labels = [line.rstrip() ] - assert ( - len(labels) == tot - ), f"number of labels does not match ({len(labels)} != {tot})" - labels = [labels[i] for i in inds] - return labels - -def load_numpy_label(label_path, inds, tot): - labels = np.load(label_path, mmap_mode='r') - assert (labels.shape[0] == tot), f"number of labels does not match ({labels.shape[0]} != {tot})" - return labels - - -# def load_label_offset(label_path, inds, tot): -# with open(label_path) as f: -# code_lengths = [len(line.encode("utf-8")) for line in f] -# assert ( -# len(code_lengths) == tot -# ), f"number of labels does not match ({len(code_lengths)} != {tot})" -# offsets = list(itertools.accumulate([0] + code_lengths)) -# offsets = [(offsets[i], offsets[i + 1]) for i in inds] -# return offsets - - -def verify_label_lengths( - audio_sizes, - audio_rate, - label_path, - label_rate, - inds, - tot, - tol=0.1, # tolerance in seconds -): - if label_rate < 0: - logger.info(f"{label_path} is sequence label. skipped") - return - - with open(label_path) as f: - lengths = [] - for line in tqdm.tqdm(f): - lengths.append(len(line.rstrip().split())) - assert len(lengths) == tot - lengths = [lengths[i] for i in inds] - num_invalid = 0 - for i, ind in enumerate(inds): - dur_from_audio = audio_sizes[i] / audio_rate - dur_from_label = lengths[i] / label_rate - if abs(dur_from_audio - dur_from_label) > tol: - logger.warning( - ( - f"audio and label duration differ too much " - f"(|{dur_from_audio} - {dur_from_label}| > {tol}) " - f"in line {ind+1} of {label_path}. Check if `label_rate` " - f"is correctly set (currently {label_rate}). " - f"num. of samples = {audio_sizes[i]}; " - f"label length = {lengths[i]}" - ) - ) - num_invalid += 1 - if num_invalid > 0: - logger.warning( - f"total {num_invalid} (audio, label) pairs with mismatched lengths" - ) - -class Read_and_PadCrop_Normalized_T(torch.nn.Module): - def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): - - super().__init__() - - self.n_samples = n_samples - self.sample_rate = sample_rate - self.randomize = randomize - - - def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: - if(duration<(float(self.n_samples)/self.sample_rate+1)): - # print(duration,(float(self.n_samples)/self.sample_rate+1)) - chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) - t_start = 0. 
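# Editor's note (illustrative numbers, not from the repo): verify_label_lengths
# above compares the duration implied by each stream. For one (audio, label) pair,
#   160000 samples / 16000 Hz = 10.00 s of audio
#   498 labels     /    50 Hz =  9.96 s of labels
# |10.00 - 9.96| = 0.04 s <= tol (0.1 s), so the pair passes without a warning.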
- t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) - offset = 0 - # print('c1:',chunk.shape) - else: - offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) - t_start = offset / float(cur_sample_rate) / duration - t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration - chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) - # print('offset:',offset) - # print('c0:',chunk.shape) - # Pad with silence if necessary. - if(chunk.shape[0]>1): - chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() - else: - chunk = chunk[[0],:].float() - if(cur_sample_rate!=self.sample_rate): - # print('a:',cur_sample_rate,chunk.shape) - chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) - # print('b:',self.sample_rate,chunk.shape) - if chunk.shape[-1] < self.n_samples: - chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) - else: - chunk = chunk[:,0:self.n_samples] - seconds_start = math.floor(offset / cur_sample_rate) - seconds_total = math.floor(duration) - - return ( - chunk, - t_start, - t_end, - seconds_start, - seconds_total - ) - - -class MERTDataset(FairseqDataset): - def __init__( - self, - manifest_path: str, - sample_rate: float, - label_paths: List[str], - label_rates: Union[List[float], float], # -1 for sequence labels - pad_list: List[str], - eos_list: List[str], - label_processors: Optional[List[Any]] = None, - max_keep_sample_size: Optional[int] = None, - min_keep_sample_size: Optional[int] = None, - max_sample_size: Optional[int] = None, - shuffle: bool = True, - pad_audio: bool = False, - normalize: bool = False, - store_labels: bool = True, - npmemmap: bool = False, - random_crop: bool = False, - single_target: bool = False, - augmentation_effects: List[str] = [], - augmentation_probs: List[float] = [], - inbatch_noise_augment_len_range: List[int] = [8000, 24000], - inbatch_noise_augment_number_range: List[int] = [1, 3], - inbatch_noise_augment_volume: float = 1.0, - cqt_prediction_bin: int = -1, - dataset_len:int = 128*3000, - clip_secs = 5, - ): - self.sample_rate = sample_rate - self.shuffle = shuffle - self.random_crop = random_crop - self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path,max_keep_sample_size,min_keep_sample_size, self.sample_rate, clip_secs) - - self.num_labels = len(label_paths) - self.pad_list = pad_list - self.eos_list = eos_list - self.label_processors = label_processors - self.single_target = single_target - self.label_rates = ( - [label_rates for _ in range(len(label_paths))] - if isinstance(label_rates, float) - else label_rates - ) - self.store_labels = store_labels - self.npmemmap = npmemmap - - # self.dataset_len = dataset_len - self.dataset_len = len(self.datas) - logger.info('preparing labels') - logger.info('========dataset len: {}=========='.format(self.dataset_len)) - if store_labels: - if self.npmemmap: - self.label_list = [load_numpy_label(p+'.npy', inds, tot) for p in label_paths] - else: - self.label_list = [load_label(p, inds, tot) for p in label_paths] - else: - self.label_paths = label_paths - - assert label_processors is None or len(label_processors) == self.num_labels - - self.max_sample_size = ( - max_sample_size if max_sample_size is not None else sys.maxsize - ) - self.pad_audio = pad_audio - self.normalize = normalize - logger.info( - f"pad_audio={pad_audio}, 
random_crop={random_crop}, "
-            f"normalize={normalize}, max_sample_size={self.max_sample_size}"
-        )
-
-        self.augmentation_effects = augmentation_effects
-        self.augmentation_probs = augmentation_probs
-
-        self.inbatch_noise_augment_len_range = inbatch_noise_augment_len_range
-        self.inbatch_noise_augment_number_range = inbatch_noise_augment_number_range
-        self.inbatch_noise_augment_volume = inbatch_noise_augment_volume
-
-        self.cqt_prediction_bin = cqt_prediction_bin
-        if self.cqt_prediction_bin > 0:
-            self.encoder_cqt_model = model_cqt_pred(n_bins=self.cqt_prediction_bin)
-            logger.info('preparing cqt loss objective in dataloader with cpu')
-
-        self.epoch = -1
-
-        self.reader = Read_and_PadCrop_Normalized_T(n_samples=clip_secs * sample_rate, sample_rate=self.sample_rate)
-
-    @property
-    def can_reuse_epoch_itr_across_epochs(self):
-        """
-        Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for
-        this dataset across epochs.
-
-        This needs to return ``False`` if the sample sizes can change across
-        epochs, in which case we may need to regenerate batches at each epoch.
-        If your dataset relies on ``set_epoch`` then you should consider setting
-        this to ``False``.
-        """
-        return False
-
-    def set_epoch(self, epoch):
-        """Will receive the updated epoch number at the beginning of the epoch."""
-        self.epoch = epoch
-
-    def inbatch_noise_augment(self,
-                              target_audio: torch.Tensor, target_audio_idx: int,
-                              batch_audios: torch.Tensor,  # [bsz, audio_lengths]
-                              noise_len_min: int, noise_len_max: int,
-                              n_noise_min: int, n_noise_max: int,
-                              noise_vol: float = 1.0):
-        '''
-        augmentation that leverages in-batch noise audios.
-        noise_len_min and noise_len_max are the range of the lengths of noises (counted in samples);
-        n_noise_min and n_noise_max are the range of the number of noise spans.
-        '''
-        # assert noise_len_max <= target_audio.shape[0] and noise_len_min >= 1  # should assert this outside?
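# Editor's note: in brief, the method below treats every *other* clip in the
# batch as a free noise corpus -- it concatenates them into one pool, samples a
# few random spans from that pool, and adds each span onto the target clip at a
# random position, scaled by noise_vol. No external noise dataset is needed, at
# the cost that the "noise" is always in-domain audio.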
-
-        augmented_audio = torch.clone(target_audio)
-
-        # exclude the target audio and use the rest as noise candidates
-        noise_pool = torch.cat(batch_audios[:target_audio_idx] + batch_audios[target_audio_idx + 1:], dim=0).view(-1)
-
-        n_noise = np.random.randint(n_noise_min, n_noise_max)
-        # n_noise
-        random_start_idxs = np.random.randint(0, noise_pool.shape[0] - noise_len_max, size=(n_noise,))
-        random_durations = np.random.randint(noise_len_min, noise_len_max, size=(n_noise,))
-
-        for noise_idx in range(n_noise):
-            augmentation_position = np.random.randint(0, target_audio.shape[0] - random_durations[noise_idx], size=None)
-            # assign noise to the original audio
-            augmented_audio[augmentation_position:augmentation_position + random_durations[noise_idx]] += \
-                noise_vol * noise_pool[random_start_idxs[noise_idx]: random_start_idxs[noise_idx] + random_durations[noise_idx]]
-
-        return augmented_audio
-
-    def get_audio_by_slice(self, index):
-        wav_path = self.datas[index]['path']
-        audio_info = torchaudio.info(wav_path)
-        origin_sample_rate = audio_info.sample_rate
-        origin_duration = audio_info.num_frames / origin_sample_rate
-
-        wav, *ignored = self.reader(wav_path, origin_duration, origin_sample_rate)
-        wav = wav.float()
-
-        wav = wav.permute(1, 0)
-        wav = self.postprocess(wav, self.sample_rate)  # downmix to mono, check the sample rate, normalize
-        return wav
-
-    def get_audio(self, index):
-        import soundfile as sf
-        wav_path = self.audio_names[index]
-        _path, slice_ptr = parse_path(wav_path)
-        if len(slice_ptr) == 0:
-            wav, cur_sample_rate = sf.read(_path)
-        else:
-            assert _path.endswith(".zip")
-            data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
-            f = io.BytesIO(data)
-            wav, cur_sample_rate = sf.read(f)
-            wav = torch.from_numpy(wav).float()
-
-        wav = self.postprocess(wav, cur_sample_rate)  # downmix to mono, check the sample rate, normalize
-        # print(wav.shape)
-        return wav
-
-    def get_label(self, index, label_idx):
-        if self.store_labels and (not self.npmemmap):
-            label = self.label_list[label_idx][index]
-        elif self.store_labels and self.npmemmap:
-            label = self.label_list[label_idx][index]
-        else:
-            with open(self.label_paths[label_idx]) as f:
-                offset_s, offset_e = self.label_offsets_list[label_idx][index]
-                f.seek(offset_s)
-                label = f.read(offset_e - offset_s)
-
-        if self.label_processors is not None:
-            label = self.label_processors[label_idx](label)
-        return label
-
-    def get_labels(self, index):
-        return [self.get_label(index, i) for i in range(self.num_labels)]
-
-    # modified here: process raw_data fully on the fly and keep it inline; if it has already been processed, just read it back
-    def __getitem__(self, i):
-        # WORLD_SIZE = int(torch.distributed.get_world_size())
-        # WORLD_RANK = int(torch.distributed.get_rank())
-        # np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i)
-        # index = random.randint(0,len(self.sizes) - 1)
-        index = i
-        item = None
-        while item is None:
-            try:
-                wav = self.get_audio_by_slice(index)
-                # labels = self.get_labels(index)  # this needs to change
-                # labels = None
-                # item = {"id": index, "source": wav, "label_list": labels}
-                item = {"id": index, "source": wav}
-            except Exception as e:
-                # print(e)
-                traceback.print_exc()
-                print(f'skip damaged data {index}')
-                index = np.random.randint(0, len(self.sizes) - 1)
-        return item
-
-    def __len__(self):
-        return self.dataset_len
-
-    def crop_to_max_size(self, wav, target_size):
-        size = len(wav)
-        diff = size - target_size
-        if diff <= 0:
-            return wav, 0
-
-        start, end = 0, target_size
-        if self.random_crop:
-            start = np.random.randint(0, diff + 1)
-            end = size - diff + start
-        return wav[start:end], start
-
-    def collater(self, samples):
-        # this method acts like a collate_fn
-        samples =
[s for s in samples if s["source"] is not None] - if len(samples) == 0: - return {} - - audios = [s["source"] for s in samples] - audio_sizes = [len(s) for s in audios] - if self.pad_audio: - audio_size = min(max(audio_sizes), self.max_sample_size) - else: - audio_size = min(min(audio_sizes), self.max_sample_size) - collated_audios, padding_mask, audio_starts, collated_cqt_labels = self.collater_audio( - audios, audio_size - ) - - # targets_by_label = [ - # [s["label_list"][i] for s in samples] for i in range(self.num_labels) - # ] - # targets_list, lengths_list, ntokens_list = self.collater_label( - # targets_by_label, audio_size, audio_starts - # ) - - net_input = {"source": collated_audios, "padding_mask": padding_mask, "cqt_labels": collated_cqt_labels} - - batch = { - "id": torch.LongTensor([s["id"] for s in samples]), - "net_input": net_input, - } - - if self.single_target: - batch["target_lengths"] = None - batch["ntokens"] = None - batch["target"] = None - else: - batch["target_lengths_list"] = None - batch["ntokens_list"] = None - batch["target_list"] = None - return batch - - def collater_audio(self, audios, audio_size): - collated_audios = audios[0].new_zeros(len(audios), audio_size) - padding_mask = ( - torch.BoolTensor(collated_audios.shape).fill_(False) - # if self.pad_audio else None - ) - audio_starts = [0 for _ in audios] - - for i, audio in enumerate(audios): - diff = len(audio) - audio_size - if diff == 0: - collated_audios[i] = audio - elif diff < 0: - assert self.pad_audio - collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) - padding_mask[i, diff:] = True - else: - collated_audios[i], audio_starts[i] = self.crop_to_max_size( - audio, audio_size - ) - - cqt_labels = None - if self.cqt_prediction_bin > 0: - cqt_labels = self.encoder_cqt_model(collated_audios.float(), forward_type='compute_cqt') - - for i, _ in enumerate(audios): - if len(self.augmentation_effects) > 0: - with torch.no_grad(): - for effect, prob in zip(self.augmentation_effects, self.augmentation_probs): - if torch.rand(1).item() > prob: - if effect == 'composed_augmentation_v1': - # collated_audios[i] = self.composed_augment_v1(collated_audios[i]) - pass - elif effect == 'inbatch_noise_augment': - assert len(audios) > 1 - collated_audios[i] = self.inbatch_noise_augment( - target_audio = collated_audios[i], target_audio_idx = i, batch_audios = audios, - noise_len_min = self.inbatch_noise_augment_len_range[0], noise_len_max = self.inbatch_noise_augment_len_range[1], - n_noise_min = self.inbatch_noise_augment_number_range[0], n_noise_max = self.inbatch_noise_augment_number_range[1], - noise_vol = self.inbatch_noise_augment_volume) - else: - raise NotImplementedError() - - - return collated_audios, padding_mask, audio_starts, cqt_labels - - def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad): - assert label_rate > 0 - s2f = label_rate / self.sample_rate - frm_starts = [int(round(s * s2f)) for s in audio_starts] - frm_size = int(round(audio_size * s2f)) - if not self.pad_audio: - rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] - frm_size = min(frm_size, *rem_size) - targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)] - logger.debug(f"audio_starts={audio_starts}") - logger.debug(f"frame_starts={frm_starts}") - logger.debug(f"frame_size={frm_size}") - - lengths = torch.LongTensor([len(t) for t in targets]) - ntokens = lengths.sum().item() - targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) - return targets, 
lengths, ntokens - - def collater_seq_label(self, targets, pad): - lengths = torch.LongTensor([len(t) for t in targets]) - ntokens = lengths.sum().item() - targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) - return targets, lengths, ntokens - - def collater_label(self, targets_by_label, audio_size, audio_starts): - targets_list, lengths_list, ntokens_list = [], [], [] - itr = zip(targets_by_label, self.label_rates, self.pad_list) - for targets, label_rate, pad in itr: - if label_rate == -1.0: - targets, lengths, ntokens = self.collater_seq_label(targets, pad) - else: - targets, lengths, ntokens = self.collater_frm_label( - targets, audio_size, audio_starts, label_rate, pad - ) - targets_list.append(targets) - lengths_list.append(lengths) - ntokens_list.append(ntokens) - return targets_list, lengths_list, ntokens_list - - def num_tokens(self, index): - return self.size(index) - - def size(self, index): - if self.pad_audio: - return self.sizes[index] - return min(self.sizes[index], self.max_sample_size) - - # def ordered_indices(self): - # if self.shuffle: - # order = [np.random.permutation(len(self.sizes))] - # else: - # order = [np.arange(len(self.sizes))] - - # order.append(self.sizes) - # return np.lexsort(order)[::-1] - - def ordered_indices(self): - if self.shuffle: - try: - print("========Local rank :",torch.distributed.get_rank(),"========") - WORLD_SIZE = int(torch.distributed.get_world_size()) - WORLD_RANK = int(torch.distributed.get_rank()) - np.random.seed(self.epoch * WORLD_SIZE + WORLD_RANK) - order = np.random.permutation(len(self.sizes)) - print("==================multinode multigpu shuffle==================") - except: - print("==================singlenode shuffle==================") - order = np.random.permutation(len(self.sizes)) - else: - order = np.arange(len(self.sizes)) - - return order - - def postprocess(self, wav, cur_sample_rate): - if wav.dim() == 2: - wav = wav.mean(-1) - assert wav.dim() == 1, wav.dim() - - if cur_sample_rate != self.sample_rate: - raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") - - if self.normalize: - with torch.no_grad(): - wav = F.layer_norm(wav, wav.shape) - return wav diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/_test.ipynb b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/_test.ipynb deleted file mode 100644 index 283e1e94ebba98991f9e3ed3f4f05d9956eaf585..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/_test.ipynb +++ /dev/null @@ -1,143 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(1034, (94, 11))" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from data_utils import compute_block_mask_2d\n", - "\n", - "\n", - "input_size = (752, 88) # 94 * 11 # (1024, 128) # 64 * 8\n", - "patch_size = 8\n", - "patches = (input_size[0] // patch_size ) * ( input_size[1] // patch_size )\n", - "img_shape = (input_size[0] // patch_size,input_size[1] // patch_size )\n", - "mask = compute_block_mask_2d( \n", - " shape=(16, patches),\n", - " mask_prob=0.8,\n", - " mask_length=2,\n", - " mask_prob_adjust=0.07,\n", - " inverse_mask=True,\n", - " require_same_masks=True,\n", - " expand_adjcent=False,\n", - " mask_dropout=0.0,\n", - " non_overlapping=False,\n", - " img_shape=img_shape,\n", - " flexible_mask=False\n", - ")\n", - "patches, 
img_shape" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([16, 1034])\n", - "tensor([828., 828., 828., 828., 828., 828., 828., 828., 828., 828., 828., 828.,\n", - " 828., 828., 828., 828.])\n" - ] - } - ], - "source": [ - "print(mask.shape)\n", - "print(mask.sum(dim=1))" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# library 导入库\n", - "import seaborn as sns\n", - "import pandas as pd\n", - "import numpy as np\n", - "# jupyter notebook显示多行输出\n", - "from IPython.core.interactiveshell import InteractiveShell \n", - "InteractiveShell.ast_node_interactivity = 'all'\n" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "
" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAABXgAAADLCAYAAADUbftLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA46klEQVR4nO3da3xU1b3/8e8kJJOAJ6CmSQjIzRsiCJVAGhBpaypaDxbbI6jUIFosChZJRQgoUTkaPFYup4IoFFFbBe1RpIVCMYVWJTXlJqVVLoKlBRLgpVzkkoSZ9X/gPykDk5nZM2smmeTz9rUfZM/M/q4985s1a5abNS5jjBEAAAAAAAAAIO4kNHQDAAAAAAAAAADhYYIXAAAAAAAAAOIUE7wAAAAAAAAAEKeY4AUAAAAAAACAOMUELwAAAAAAAADEKSZ4AQAAAAAAACBOMcELAAAAAAAAAHGKCV4AAAAAAAAAiFNM8AIAAAAAAABAnGKCFwAAAAAAAADiFBO8AAAAAAAAABChP/3pTxo8eLCys7Plcrm0dOnSoI9Zu3atrr76arndbl1yySVatGiR41wmeAEAAAAAAAAgQsePH1fPnj01Z86ckO6/e/du3XTTTfrWt76lzZs368EHH9SPfvQjrVq1ylGuyxhjwmkwAAAAAAAAAOBcLpdLb7/9toYMGVLvfSZOnKjly5dr69atdftuu+02HT58WCtXrgw5iyt4AQAAAAAAAMCPqqoqHT161GerqqqycuyysjLl5+f77Bs0aJDKysocHaeFldZY0CK5nePHnNz3XhRagniVmj3A8WOoIQDxKJz+Llz0k01XrOqIGgLOxfsPkaKGANQnKb1LQzch7tQc2hXw9pLnXtHjjz/us6+4uFiPPfZYxNkVFRXKzMz02ZeZmamjR4/q5MmTSk1NDek4jWaCFwAAAAAAAABiylMT8OaioiIVFhb67HO73dFskWNM8AIAAAAAAABonrzegDe73e6oTehmZWWpsrLSZ19lZaXS0tJCvnpXYoIXAAAAAAAAQDNlPKcbLDsvL08rVqzw2bd69Wrl5eU5Og4/sgYAAAAAAACgefLUBN4c+PLLL7V582Zt3rxZkrR7925t3rxZe/bskfTVcg8FBQV19x89erR27dqlhx9+WJ988onmzp2rN954Q+PHj3eU6/gK3kOHDmnhwoUqKytTRUWFpK8uJ+7Xr5/uuusufe1rX3N6SAAAAAAAAACIPRN4iQYn1q9fr29961t1f9eu3TtixAgtWrRI+/fvr5vslaTOnTtr+fLlGj9+vGbPnq327dtrwYIFGjRokKNcRxO8f/nLXzRo0CC1bNlS+fn5uuyyyyR9tTbE//7v/2r69OlatWqVcnJyAh6nqqpKVVVVPvuMMXK5XI4aDwAAAAAAAADhsrlEwze/+U0ZY+q9fdGiRX4fs2nTpohyHU3wPvDAA7r11ls1b968cyZjjTEaPXq0HnjgAZWVlQU8TklJiR5//HGffa6E8+RKTHPSHAAAAAAAAAAIXwOuwWuLozV4P/roI40fP97vlbYul0vjx4+vW2MikKKiIh05csRncyX8h5OmAAAAAAAAAEBkvJ7AWxxwdAVvVlaWysvL1bVrV7+3l5eXKzMzM+hx3G633G63zz6WZwAAAAAAAAAQU03gCl5HE7wPPfSQ7r33Xm3YsEHXXXdd3WRuZWWlSktLNX/+fP3sZz+LSkMBAAAAAAAAwCqLP7LWUBxN8I4ZM0bp6emaOXOm5s6dK4/nq8uUExMT1bt3by1atEhDhw6NSkMBAAAAAAAAwCbjqWnoJkTM0QSvJA0bNkzDhg1TTU2NDh06JElKT09XUlKS9cYBAAAAAAAAQNQ0tyUazpSUlKS2bdvabAsAAAAAAAAAxE4TWKLBZYwxDd0ISao5tKuhm9AopGYPcPyYk/vei0JL4k84z104Gvvz3RRrKFavbTga+3MXDmoofI39eQAAm+hbEammOOZoamI5Due1bZqoodhLSu/S0E2IO6fK3wx4e0rfW2PUkvCFfQUvAAAAAAAAAMS15rxEAwAAAAAAAADENW/8L9HABC8AAAAAAACAZsl4ahq6CRFjghcAAAAAAABA88QSDQAAAAAAAAAQp1iiAQAAAAAAAADiFFfwAgAAAAAAAECcYoIXAAAAAAAAAOIUSzQAAAAAAAAAQJziCl4AAAAAAAAAiFNM8AIAAAAAAABAnDIs0QAAAAAAAAAA8el0/F/Bm9DQDQAAAAAAAACABuHxBN7CMGfOHHXq1EkpKSnKzc1VeXl5wPvPmjVLl19+uVJTU3XRRRdp/PjxOnXqVMh5XMHbyJzc915DNyFuNcXnLjV7gOPHNMXnoSmeU6xQQ19piucExIum2A81xXNqzMJ5vsMR7msUq/aFo7HXXTjta8zPtxS757yxvy9ihf74K9RD+KghSJK8dpdoWLJkiQoLCzVv3jzl5uZq1qxZGjRokLZt26aMjIxz7v/aa69p0qRJWrhwofr166ft27frrrvuksvl0owZM0LK5ApeAAAAAAAAAM2T53TgzaEZM2Zo1KhRGjlypLp166Z58+apZcuWWrhwod/7r1u3Tv3799cdd9yhTp066frrr9ftt98e9KrfMzHBCwAAAAAAAKB5CrJEQ1VVlY4ePeqzVVVV+T1UdXW1NmzYoPz8/Lp9CQkJys/PV1lZmd/H9OvXTxs2bKib0N21a5dWrFih7373uyGfAhO8AAAAAAAAAJonrzfgVlJSotatW/tsJSUlfg916NAheTweZWZm+uzPzMxURUWF38fccccdeuKJJ3TNNdcoKSlJF198sb75zW9q8uTJIZ8CE7wAAAAAAAAAmiXj8QTcioqKdOTIEZ+tqKjIWv7atWv11FNPae7cudq4caPeeustLV++XNOmTQv5GPzIGgAAAAAAAIDmKcg6u263W263O6RDpaenKzExUZWVlT77KysrlZWV5fcxjz76qO6880796Ec/kiT16NFDx48f17333qspU6YoISH49bmOr+A9efKk3n//ff39738/57ZTp07plVdeCXoMJ2tXAAAAAAAAAEBUeE3gzYHk5GT17t1bpaWl/z6816vS0lLl5eX5fcyJEyfOmcRNTEyUJBkTWr6jCd7t27friiuu0LXXXqsePXpo4MCB2r9/f93tR44c0ciRI4Mex9/aFU/PnuekKQAAAAAAAAAQmdOnA28OFRYWav78+Xr55Zf18ccf67777tPx48fr5kwLCgp8lngYPHiwnn/+eS1evFi7d+/W6tWr9eijj2rw4MF1E73BOFqiYeLE
ierevbvWr1+vw4cP68EHH1T//v21du1adejQIeTjFBUVqbCw0GdfwrG9TpoCAAAAAAAAAJHxeKwebtiwYTp48KCmTp2qiooK9erVSytXrqz74bU9e/b4XLH7yCOPyOVy6ZFHHtHevXv1ta99TYMHD9aTTz4ZcqajCd5169bp3XffVXp6utLT0/Wb3/xG999/vwYMGKA1a9aoVatWIR3H39oVNdWHnDQFAAAAAAAAACLjcBmGUIwdO1Zjx471e9vatWt9/m7RooWKi4tVXFwcdp6jJRpOnjypFi3+PSfscrn0/PPPa/DgwRo4cKC2b98edkMAAAAAAAAAIKY8nsBbHHB0BW/Xrl21fv16XXHFFT77n3vuOUnSzTffbK9lAAAAAAAAABBF5nR8TOIG4ugK3ltuuUWvv/6639uee+453X777SH/uhsAAAAAAAAANCjjDbzFAUcTvEVFRVqxYkW9t8+dO1deb3ycOAAAAAAAAIBm7rQn8BYHXKaRXHJbc2iX48ekZg+IQkvOdXLfe44fE6u2hSucc0LT1ZjfS4gMr23T1Ng/Y8JBDcUe/UPTFM7rymuEM8XyM4baa/wa+5iDGmr8qKHYS0rv0tBNiDvHHx0a8PZW096IUUvC52gNXgAAAAAAAABoMryN4trXiDDBCwAAAAAAAKBZago/ssYELwAAAAAAAIDmycMELwAAAAAAAADEJ5ZoAAAAAAAAAID4ZE57G7oJEWOCFwAAAAAAAEDzxBq8AAAAAAAAABCnWKIBAAAAAAAAAOKT8bBEAwAAAAAAAADEJdbgBQAAAAAAAIB4xRINAAAAAAAAABCfzGkmeAEAAAAAAAAgPjHBCwAAAAAAAADxyTSBJRpcxphGcRY1h3Y1dBOaldTsATHJObnvvZjkIPbCqSHqITK8bxEpagjxglpFpBinAIB99K3hC3dsE87zl5TeJays5uzzWwYGvP2Ct/8Yo5aEjyt4AQAAAAAAADRL5nRDtyByCQ3dAAAAAAAAAABoEN4gWxjmzJmjTp06KSUlRbm5uSovLw94/8OHD2vMmDFq27at3G63LrvsMq1YsSLkPK7gBQAAAAAAANAs2b6Cd8mSJSosLNS8efOUm5urWbNmadCgQdq2bZsyMjLOuX91dbW+853vKCMjQ7/+9a/Vrl07/eMf/1CbNm1CzrQywWuMkcvlsnEoAAAAAAAAAIgJr+UJ3hkzZmjUqFEaOXKkJGnevHlavny5Fi5cqEmTJp1z/4ULF+rzzz/XunXrlJSUJEnq1KmTo0wrSzS43W59/PHHNg4FAAAAAAAAADFhvIG3qqoqHT161Gerqqrye6zq6mpt2LBB+fn5dfsSEhKUn5+vsrIyv49ZtmyZ8vLyNGbMGGVmZqp79+566qmn5PF4Qj4HR1fwFhYW+t3v8Xg0ffp0XXjhhZK+mqkOpKqq6pwnIqGqSm6320lzAAAAAAAAACBsxhN4VYKSkhI9/vjjPvuKi4v12GOPnXPfQ4cOyePxKDMz02d/ZmamPvnkE7/H37Vrl/7whz9o+PDhWrFihXbu3Kn7779fNTU1Ki4uDukcHE3wzpo1Sz179jxnDQhjjD7++GO1atUqpKUa/D0xj0z4iaY+PM5JcwAAAAAAAAAgbN7Tgecyi4qKzrno1eZFql6vVxkZGXrxxReVmJio3r17a+/evXrmmWeiM8H71FNP6cUXX9Szzz6rb3/723X7k5KStGjRInXr1i2k4/h7YhKO7XXSFAAAAAAAAACIiDGBJ3jdbnfIE7rp6elKTExUZWWlz/7KykplZWX5fUzbtm2VlJSkxMTEun1XXHGFKioqVF1dreTk5KC5jtbgnTRpkpYsWaL77rtPDz30kGpqapw8vI7b7VZaWprPxvIMAAAAAAAAAGLJe9oVcHMiOTlZvXv3Vmlp6b+P7/WqtLRUeXl5fh/Tv39/7dy5U16vt27f9u3b1bZt25Amd6UwfmStT58+2rBhgw4ePKicnBxt3bo1pGUZAAAAAAAAAKAx8XpcATenCgsLNX/+fL388sv6+OOPdd999+n48eMaOXKkJKmgoEBFRUV197/vvvv0+eefa9y4cdq+fbuWL1+up556SmPGjAk509ESDbXOO+88vfzyy1q8eLHy8/Md/aobAAAAAAAAADQGxmv3wtVhw4bp4MGDmjp1qioqKtSrVy+tXLmy7ofX9uzZo4SEf19ze9FFF2nVqlUaP368rrrqKrVr107jxo3TxIkTQ84Ma4K31m233aZrrrlGGzZsUMeOHSM5FAAAAAAAAADEVDhX6QYzduxYjR071u9ta9euPWdfXl6e/vznP4edF9EEryS1b99e7du3j/QwAAAAAAAAABBT0ZjgjTWXMcY0dCMkqebQroZuQr1Sswc4fszJfe9FoSX+hdO+cIRzTrFqmxTb57ypacw1FEs8D4gUNYSGwGctIhVuDVEPAAA0PknpXRq6CXFn+xU3BLz9so9Xxqgl4Yv4Cl4AAAAAAAAAiEdeT0LwOzVyTPACAAAAAAAAaJaawhINTPACAAAAAAAAaJa8hgleAAAAAAAAAIhLXi8TvAAAAAAAAAAQlzxe1uAFAAAAAAAAgLhkTEO3IHJM8AIAAAAAAABolriCFwAAAAAAAADilIcfWQMAAAAAAACA+ORlghcAAAAAAAAA4hNX8AIAAAAAAABAnGKCFwAAAAAAAADilFdM8AIAAAAAAABAXPIwwds8nNz3XkM3IaDG3L7G3Db8G6/TV3geEClqCJFKzR4Qs6ymVq88d1+J5fPQ1FBDAIBQ8HnR9DDBCwAAAAAAAABxytvQDbCACV4AAAAAAAAAzZLHxRW8AAAAAAAAABCXmsKPrCU0dAMAAAAAAAAAoCF4gmzhmDNnjjp16qSUlBTl5uaqvLw8pMctXrxYLpdLQ4YMcZTHBC8AAAAAAACAZsnjcgXcnFqyZIkKCwtVXFysjRs3qmfPnho0aJAOHDgQ8HGfffaZHnroIQ0Y4PyH/BxN8G7cuFG7d++u+/vVV19V//79ddFFF+maa67R4sWLQzpOVVWVjh496rNVVVU5azkAAAAAAAAARMAbZHNqxowZGjVqlEaOHKlu3bpp3rx5atmypRYuXFjvYzwej4YPH67HH39cXbp0cZzpaIJ35MiR+vTTTyVJCxYs0I9//GPl5ORoypQp6tOnj0aNGhWwsbVKSkrUunVrn+3p2fMcNx4AAAAAAAAAwnXa5Qq4OVFdXa0NGzYoPz+/bl9CQoLy8/NVVlZW7+OeeOIJZWRk6J577gnrHBz9yNqOHTt06aWXSpLmzp2r2bNna9SoUXW39+nTR08++aTuvvvugMcpKipSYWGhz76EY3udNAUAAAAAAAAAIuIJModbVVV1zsoDbrdbbrf7nPseOnRIHo9HmZmZPvszMzP1ySef+D3++++/r1/84hfavHmzo3afydEVvC1bttShQ4ckSXv37lXfvn19bs/NzfVZwqE+brdbaWlpPpu/JwUAAAAAAAAAoiXYEg3+ViIoKSmxkn3s2DHdeeedmj9/vtLT08M+jqMreG+88UY9//z
zWrBggQYOHKhf//rX6tmzZ93tb7zxhi655JKwGwMAAAAAAAAAsRLsCl5/KxHUd6Fqenq6EhMTVVlZ6bO/srJSWVlZ59z/008/1WeffabBgwfX7fN6v1r5t0WLFtq2bZsuvvjioOfgaIL36aefVv/+/TVw4EDl5OTo2Wef1dq1a3XFFVdo27Zt+vOf/6y3337bySEBAAAAAAAAoEGcDnJ7fcsx+JOcnKzevXurtLRUQ4YMkfTVhG1paanGjh17zv27du2qv/71rz77HnnkER07dkyzZ8/WRRddFFKuowne7Oxsbdq0SdOnT9dvfvMbGWNUXl6uf/7zn+rfv78++OAD5eTkODkkAAAAAAAAADQI4+x31IIqLCzUiBEjlJOTo759+2rWrFk6fvy4Ro4cKUkqKChQu3btVFJSopSUFHXv3t3n8W3atJGkc/YH4miCtzZk+vTpmj59utOHAgAAAAAAAECjEewKXqeGDRumgwcPaurUqaqoqFCvXr20cuXKuh9e27NnjxISHP0sWlAuY4yxesQw1RzaFZOc1OwBjh9zct97UWhJ/AnnuWvswnltqaH4EKt6jeVrS+0hUk2tH6e+m66m2Ic3RY35dYplf0cdha+pfS6FixqKrcbcd4WrKZ4TwpeU3qWhmxB3Znf4YcDbx+35ZYxaEj7HV/ACAAAAAAAAQFPgbegGWMAELwAAAAAAAIBmydPQDbCACV4AAAAAAAAAzdJpyz+y1hCY4AUAAAAAAADQLDWKHyeLEBO8AAAAAAAAAJql001gipcJXgAAAAAAAADNEmvwAgAAAAAAAECc8rIGLwAAAAAAAADEJw9LNAAAAAAAAABAfGINXgAAAAAAAACIU/E/vcsELwAAAAAAAIBmiit4AQAAAAAAACBOeRq6ARYwwQsAAAAAAACgWTJN4ApelzGmUZxFi+R2Dd2ERuHkvvcauglxKzV7QEM3oVGghmIrlnXX1F5bnjs0lFjVHnXXdFFDAJoLxmvxoTF/LjXFGmrs55SU3iUKLWna7u80NODtcz97I0YtCR9X8AIAAAAAAABoljxN4ApeJngBAAAAAAAANEvehm6ABUzwAgAAAAAAAGiWuIIXAAAAAAAAAOIUE7wAAAAAAAAAEKe8Jv4neBOcPuC5555TQUGBFi9eLEl69dVX1a1bN3Xt2lWTJ0/W6dOngx6jqqpKR48e9dlME3gyAQAAAAAAAMQPj0zALR44muD97//+b02ePFknTpzQ+PHj9fTTT2v8+PEaPny4RowYoQULFmjatGlBj1NSUqLWrVv7bMZ7LOyTAAAAAAAAAACnojHBO2fOHHXq1EkpKSnKzc1VeXl5vfedP3++BgwYoPPPP1/nn3++8vPzA97fH0cTvIsWLdKiRYv061//WitXrtSUKVM0e/ZsTZkyRUVFRXrhhRf02muvBT1OUVGRjhw54rO5Ev7DUcMBAAAAAAAAIBJemYCbU0uWLFFhYaGKi4u1ceNG9ezZU4MGDdKBAwf83n/t2rW6/fbbtWbNGpWVlemiiy7S9ddfr71794ac6WiCd9++fcrJyZEk9ezZUwkJCerVq1fd7VdffbX27dsX9Dhut1tpaWk+m8vlctIUAAAAAAAAAIiI7St4Z8yYoVGjRmnkyJHq1q2b5s2bp5YtW2rhwoV+7/+rX/1K999/v3r16qWuXbtqwYIF8nq9Ki0tDTnT0QRvVlaW/v73v0uSduzYIY/HU/e3JP3tb39TRkaGk0MCAAAAAAAAQIPwGG/Azd9viVVVVfk9VnV1tTZs2KD8/Py6fQkJCcrPz1dZWVlI7Tlx4oRqamp0wQUXhHwOjiZ4hw8froKCAo0aNUqDBg3Sww8/rIceekjz5s3TCy+8oNGjR+uWW25xckgAAAAAAAAAaBDeIJu/3xIrKSnxe6xDhw7J4/EoMzPTZ39mZqYqKipCas/EiROVnZ3tM0kcTIuQ7ynp8ccfV2pqqsrKyjRq1ChNmjRJPXv21MMPP6wTJ05o8ODBIf3IGgAAAAAAAAA0NI+8AW8vKipSYWGhzz632x2VtkyfPl2LFy/W2rVrlZKSEvLjHE3wJiQkaPLkyT77brvtNt12221ODgMAAAAAAAAADc5jAq+z63a7Q57QTU9PV2JioiorK332V1ZWKisrK+Bjf/azn2n69Ol69913ddVVV4WUV8vRBG80ndz3nuPHpGYPiEkOvhLO8x2ucF6nWL22sXwemppwnzte26apsffHTbEeGvtzHis8D+FrzO+LWL6uTe1zKZbPHTUUW419/I7Gj9c1PjTm16kxty1cTfGcmjsTxg+p1Sc5OVm9e/dWaWmphgwZIkl1P5g2duzYeh/3P//zP3ryySe1atUq5eTkOM5tNBO8AAAAAAAAABBLHhN4iQanCgsLNWLECOXk5Khv376aNWuWjh8/rpEjR0qSCgoK1K5du7p1fJ9++mlNnTpVr732mjp16lS3Vu95552n8847L6RMJngBAAAAAAAANEvB1uB1atiwYTp48KCmTp2qiooK9erVSytXrqz74bU9e/YoISGh7v7PP/+8qqur9V//9V8+xykuLtZjjz0WUiYTvAAAAAAAAACaJW+QNXjDMXbs2HqXZFi7dq3P35999lnEeUzwAgAAAAAAAGiWPBbX4G0oTPACAAAAAAAAaJZsr8HbEJjgBQAAAAAAANAsebmCFwAAAAAAAADik5creAEAAAAAAAAgPrFEAwAAAAAAAADEKZZoAAAAAAAAAIA4xRW8AAAAAAAAABCnmOAFAAAAAAAAgDhlWKIBAAAAAAAAAOJTU7iC12WMaRTT1DWHdjV0E6xKzR4Qs6yT+96LWRZiK1Z1RA3FB+oBtRr7Z0w47aPuEE9i+R50ivdSfKCGmi7Ga4gUNYRIJaV3aegmxJ0u6V8PePuuQ5ti1JLwcQUvAAAAAAAAgGbJNIEreMOa4K2urtbSpUtVVlamiooKSVJWVpb69eun733ve0pOTrbaSAAAAAAAAACwrSks0ZDg9AE7d+7UFVdcoREjRmjTpk3yer3yer3atGmTCgoKdOWVV2rnzp3RaCsAAAAAAAAAWOMx3oBbPHB8Be99992nHj16aNOmTUpLS/O57ejRoyooKNCYMWO0atUqa40EAAAAAAAAANu8jePnySLieIL3gw8+UHl5+TmTu5KUlpamadOmKTc310rjAAAAAAAAACBavHFylW4gjid427Rpo88++0zdu3f3e/tnn32mNm3aBDxGVVWVqqqqfPYlVFXJ7XY7bQ4AAAAAAAAAhCVelmEIxPEavD/60Y9UUFCgmTNnasuWLaqsrFRlZaW2bNmimTNn6q677tK9994b8BglJSVq3bq1z/b07HlhnwQAAAAAAAAAOOU1JuAWDxxfwfvEE0+oVatWeuaZZ/TTn/5ULpdLkmSMUVZWliZOnKiHH3444DGKiopUWFjosy/h2F6nTQEAAAAAAACAsDWFK3gdT/BK0sSJEzVx4kTt3r1bFRUVkqSsrCx17tw5pMe73e5zlmOoqT4UTlMAAAAAAAAAICweb/xP8DpeouFMnTt3Vl
5envLy8uomd//5z3/q7rvvttI4AAAAAAAAAIgWE+S/cMyZM0edOnVSSkqKcnNzVV5eHvD+b775prp27aqUlBT16NFDK1ascJQX0QSvP59//rlefvll24cFAAAAAAAAAKs8Xm/AzaklS5aosLBQxcXF2rhxo3r27KlBgwbpwIEDfu+/bt063X777brnnnu0adMmDRkyREOGDNHWrVtDznS8RMOyZcsC3r5r1y6nhwQAAAAAAACAmPNaXoN3xowZGjVqlEaOHClJmjdvnpYvX66FCxdq0qRJ59x/9uzZuuGGGzRhwgRJ0rRp07R69Wo999xzmjdvXkiZjid4hwwZIpfLJRPgV+Rqf3gNAAAAAAAAABqrQHOcklRVVaWqqiqfff5+X0ySqqurtWHDBhUVFdXtS0hIUH5+vsrKyvwev6ysTIWFhT77Bg0apKVLl4Z4BpKMQ9nZ2Wbp0qX13r5p0yaTkJDg9LB+nTp1yhQXF5tTp05ZOV5jyGpqObHM4pwaf04sszinxp8Ty6ymlhPLLM6p8efEMotzavw5scxqajmxzOKcGn9OLLOaWk4sszinxp8TyyzOCQ2luLjYSPLZiouL/d537969RpJZt26dz/4JEyaYvn37+n1MUlKSee2113z2zZkzx2RkZITcRscTvIMHDzaPPvpovbdv3rzZuFwup4f168iRI0aSOXLkiJXjNYasppYTyyzOqfHnxDKLc2r8ObHMamo5sczinBp/TiyzOKfGnxPLrKaWE8sszqnx58Qyq6nlxDKLc2r8ObHM4pzQUE6dOmWOHDnis9U3Kd9QE7yOl2iYMGGCjh8/Xu/tl1xyidasWeP0sAAAAAAAAADQqNS3HIM/6enpSkxMVGVlpc/+yspKZWVl+X1MVlaWo/v7kxDyPf+/AQMG6IYbbqj39latWmngwIFODwsAAAAAAAAAcSs5OVm9e/dWaWlp3T6v16vS0lLl5eX5fUxeXp7P/SVp9erV9d7fH8dX8AIAAAAAAAAAzlVYWKgRI0YoJydHffv21axZs3T8+HGNHDlSklRQUKB27dqppKREkjRu3DgNHDhQzz77rG666SYtXrxY69ev14svvhhyZqOe4HW73SouLg75Muh4yGpqObHM4pwaf04sszinxp8Ty6ymlhPLLM6p8efEMotzavw5scxqajmxzOKcGn9OLLOaWk4sszinxp8TyyzOCfFi2LBhOnjwoKZOnaqKigr16tVLK1euVGZmpiRpz549Skj496IK/fr102uvvaZHHnlEkydP1qWXXqqlS5eqe/fuIWe6jDHG+pkAAAAAAAAAAKLO8Rq8AAAAAAAAAIDGgQleAAAAAAAAAIhTTPACAAAAAAAAQJxighcAAAAAAAAA4lSjnuCdM2eOOnXqpJSUFOXm5qq8vNzq8UtKStSnTx/9x3/8hzIyMjRkyBBt27bNaoY/06dPl8vl0oMPPhiV4+/du1c//OEPdeGFFyo1NVU9evTQ+vXrrWZ4PB49+uij6ty5s1JTU3XxxRdr2rRpsvGbfX/60580ePBgZWdny+VyaenSpT63G2M0depUtW3bVqmpqcrPz9eOHTus5tTU1GjixInq0aOHWrVqpezsbBUUFGjfvn1ROaczjR49Wi6XS7NmzYpKzscff6ybb75ZrVu3VqtWrdSnTx/t2bPHas6XX36psWPHqn379kpNTVW3bt00b948x+cTynv01KlTGjNmjC688EKdd955+sEPfqDKykrrWZ9//rkeeOABXX755UpNTVWHDh30k5/8REeOHLF+TrWMMbrxxhuD1kwkOWVlZfr2t7+tVq1aKS0tTddee61OnjxpPauiokJ33nmnsrKy1KpVK1199dX6v//7P0c5zz//vK666iqlpaUpLS1NeXl5+t3vfld3u61aCJZlqxZCOadakdRCqDk2aiGULBu14I+/z1abNVFfjs16CJZ1pkhrIpQcWzURKMdWPTz22GNyuVw+W9euXetut1kLgbJs1kOwc6oVaS2EkmOrFoJl2ewfgo2HbY0nA+XYHk86GeNHMp4MJcfGeDJYjq3xZKdOnc6pO5fLpTFjxkiy1z8EyrH9WRHsnGpF2j+EkmOrfwiWZat/COV7rI3+IViOzf7B6XfzcPuHUHMi7R9CybHVP0jSsWPH9OCDD6pjx45KTU1Vv3799Je//KXudlufF4FybH9eoJkyjdTixYtNcnKyWbhwofnb3/5mRo0aZdq0aWMqKyutZQwaNMi89NJLZuvWrWbz5s3mu9/9runQoYP58ssvrWWcrby83HTq1MlcddVVZty4cdaP//nnn5uOHTuau+66y3z44Ydm165dZtWqVWbnzp1Wc5588klz4YUXmt/+9rdm9+7d5s033zTnnXeemT17dsTHXrFihZkyZYp56623jCTz9ttv+9w+ffp007p1a7N06VLz0UcfmZtvvtl07tzZnDx50lrO4cOHTX5+vlmyZIn55JNPTFlZmenbt6/p3bt3VM6p1ltvvWV69uxpsrOzzcyZM63n7Ny501xwwQVmwoQJZuPGjWbnzp3mnXfecfy+CpYzatQoc/HFF5s1a9aY3bt3mxdeeMEkJiaad955x1FOKO/R0aNHm4suusiUlpaa9evXm2984xumX79+jnJCyfrrX/9qvv/975tly5aZnTt3mtLSUnPppZeaH/zgB9bPqdaMGTPMjTfeGLBmIslZt26dSUtLMyUlJWbr1q3mk08+MUuWLDGnTp2ynvWd73zH9OnTx3z44Yfm008/NdOmTTMJCQlm48aNIecsW7bMLF++3Gzfvt1s27bNTJ482SQlJZmtW7caY+zVQrAsW7UQyjnViqQWQsmxVQuhZNmohbPV99lqsybqy7FZD6GcU61IayJYjs2aCJRjqx6Ki4vNlVdeafbv31+3HTx4sO52m7UQKMtmPQQ7p1qR1kKwHJu1ECzLVj2EMh62MZ4MlmNzPOlkjB/JeDKUHBvjyVBybI0nDxw44FNzq1evNpLMmjVrjDH2+odAObY/K4KdU61I+4dgOTb7h2BZtvqHUL7H2ugfguXY7B+cfDePpH8IJcdG/xBKjq3+wRhjhg4darp162b++Mc/mh07dpji4mKTlpZm/vWvfxlj7M0/BMqxPf+A5qnRTvD27dvXjBkzpu5vj8djsrOzTUlJSdQyDxw4YCSZP/7xj1E5/rFjx8yll15qVq9ebQYOHBiVCd6JEyeaa665xvpxz3bTTTeZu+++22ff97//fTN8+HCrOWcPRrxer8nKyjLPPPNM3b7Dhw8bt9ttXn/9dWs5/pSXlxtJ5h//+EfYOYGy/vWvf5l27dqZrVu3mo4dO4Y1wRssZ9iwYeaHP/xhRMcNJefKK680TzzxhM++q6++2kyZMiWirLPfo4cPHzZJSUnmzTffrLvPxx9/bCSZsrIyq1n+vPHGGyY5OdnU1NRYz9m0aZNp166d2b9/f8QTOPXl5ObmmkceeSSi44aa1apVK/PKK6/43O+CCy4w8+fPjyjr/PPPNwsWLIhqLZyd5Y+NWqgvx3Yt+MuJVi34y7JdC/V9t
tquCSef4ZHWQ7AsWzURKMdmTQTKsVUPxcXFpmfPnn5vs10LgbL8CbceQsmxUQvBcmzWQrAsW/UQbDxsazwZzrg73PFkqFmRjidDybExngwlJ1rjyXHjxpmLL77YeL3eqI4fzszxx+bYwV9WNMYPZ+dEc/xwdpat/iHY91hb/UM435fD7R9CzYq0fwglx0b/EEqOrf7hxIkTJjEx0fz2t7/1eyxb9RAsxx9b8w9oPhrlEg3V1dXasGGD8vPz6/YlJCQoPz9fZWVlUcut/ScyF1xwQVSOP2bMGN10000+52XbsmXLlJOTo1tvvVUZGRn6+te/rvnz51vP6devn0pLS7V9+3ZJ0kcffaT3339fN954o/WsM+3evVsVFRU+z2Hr1q2Vm5sb1dqQvqoPl8ulNm3aWD+21+vVnXfeqQkTJujKK6+0fvzajOXLl+uyyy7ToEGDlJGRodzc3Ij/aa8//fr107Jly7R3714ZY7RmzRpt375d119/fUTHPfs9umHDBtXU1PjUQ9euXdWhQ4eI6yGU/uDIkSNKS0tTixYtrOacOHFCd9xxh+bMmaOsrKywjx0o58CBA/rwww+VkZGhfv36KTMzUwMHDtT7779vPUv6qiaWLFmizz//XF6vV4sXL9apU6f0zW9+M6wMj8ejxYsX6/jx48rLy4tqLZyd5Y+NWvCXE41aODsnmrXg75xs10J9n622a8LJZ3ik9RAoy2ZN1JdjuyYCnY/NetixY4eys7PVpUsXDR8+vO6fg0ajf6gvy59I6iFQjs1aqC8nGv1DoHOyVQ/BxsO2xpPhjLvDHU+GkmVjPBksx9Z4MpTzicZ4srq6Wr/85S919913y+VyRW38cHaOPzbGDvVlRWP8cHZONMcP/s7JVv8Q7Husrf4hnO/L4fYPoWTZ6B+C5djqH0I5H1v9w+nTp+XxeJSSkuKzPzU1Ve+//761egiW40805x/QRDXwBLNfe/fuNZLMunXrfPZPmDDB9O3bNyqZHo/H3HTTTaZ///5ROf7rr79uunfvXncZf7Su4HW73cbtdpuioiKzceNG88ILL5iUlBSzaNEiqzkej8dMnDjRuFwu06JFC+NyucxTTz1lNcOYc68O/eCDD4wks2/fPp/73XrrrWbo0KHWcs528uRJc/XVV5s77rgj7IxAWU899ZT5zne+U/d/qKNxBW/t/71v2bKlmTFjhtm0aZMpKSkxLpfLrF271lqOMcacOnXKFBQUGEmmRYsWJjk52bz88sthZxjj/z36q1/9yiQnJ59z3z59+piHH37YatbZDh48aDp06GAmT55sPefee+8199xzT93fweoznJyysjIjyVxwwQVm4cKFZuPGjebBBx80ycnJZvv27VazjDHmiy++MNdff31dTaSlpZlVq1Y5Pv6WLVtMq1atTGJiomndurVZvny5MSY6tVBf1tkirYVAOTZrob6caNRCoHOyVQvGBP5stVkTTj7DI62HYFm2aiJQjs2aCHY+tuphxYoV5o033jAfffSRWblypcnLyzMdOnQwR48etd4/BMo6WyT1ECzHVi0EyrHdPwQ7J1v1EGw8bGs86XTcHcl4MpQsG+PJYDm2xpOhnE80xpNLliwxiYmJZu/evcaY6I0lz845m41xZKAs22NJfznRGkv6yzLGXv8Q7Husrf7B6fflSPqHULJs9A/Bcmz1D6Gcj83+IS8vzwwcONDs3bvXnD592rz66qsmISHBXHbZZVbnHwLlnM3m/AOaDyZ4/7/Ro0ebjh07mn/+85/Wj71nzx6TkZFhPvroo7p90ZrgTUpKMnl5eT77HnjgAfONb3zDas7rr79u2rdvb15//XWzZcsW88orr5gLLrjA+kRyY5jgra6uNoMHDzZf//rXzZEjR8LOqC9r/fr1JjMz02cAE40J3tr31e233+5zv8GDB5vbbrvNWo4xxjzzzDPmsssuM8uWLTMfffSR+fnPf27OO+88s3r16rBz/L1HozUoD9YfHDlyxPTt29fccMMNprq62mrOO++8Yy655BJz7Nixun2RDsr95dS+l4qKinzu26NHDzNp0iSrWcYYM3bsWNO3b1/z7rvvms2bN5vHHnvMtG7d2mzZssXR8auqqsyOHTvM+vXrzaRJk0x6err529/+FpVaqC/rTDZqob4c27VQX040aiHQc2erFoJ9ttqqCSef4ZHWQ7AsWzURLMdWTYTy3Nmqh7N98cUXJi0tzSxYsCBqnxX+ss5k67PCX040Piv85UTrs8JfljH26iHYeNjWeNLJuDvS8WSwLFvjyWA5tsaToTx30RhPXn/99eY///M/6/6OVv9wds6ZbPcNZ2dFq384Oyea/YO/589W/xDse6yt/sHJ9+VI+4dgWbb6h2A5tvqHUJ47m/3Dzp07zbXXXmskmcTERNOnTx8zfPhw07VrV6vzD4FyzmR7/gHNR6Oc4K2qqjKJiYnnfAgVFBSYm2++2XremDFjTPv27c2uXbusH9sYY95+++26N3HtJsm4XC6TmJhoTp8+bS2rQ4cOPv+31hhj5s6da7Kzs61lGGNM+/btzXPPPeezb9q0aebyyy+3mnP2YOTTTz81ksymTZt87nfttdean/zkJ9ZyalVXV5shQ4aYq666yhw6dCjs4wfKmjlzZl0tnFkfCQkJpmPHjtZyqqqqTIsWLcy0adN87vfwww9H9KNDZ+ecOHHCJCUlnbO+0D333GMGDRoUVkZ979HS0lIjyXzxxRc++zt06GBmzJhhNavW0aNHTV5enrnuuuscL6wfSs64cePqrYeBAwday9m1a5eRZF599VWf/UOHDg37/xTXl7Vz504j6ZwfDrvuuuvMj3/847CyzjzGvffeG5VaqC+rlq1aqC/Hdi3UlxONWqgvy2YtBPtsfffdd63URKif4TbqIVjW2LFjrdREsJza1ynSmgg1Jxp9gzHG5OTkmEmTJsWkf6jNqhWt/qE2J9r9Q21OLPqH2iyb9RBsPGxrPBnquNvGeDJYlq3xZLAcW+PJYDnRGE9+9tlnJiEhwSxdurRuXzT6B385tWz3Df6yotE/+MuJVv/gL8tm/xDse6yt/iHU78s2+odgWbb6h2A5tvqHYDnR6B+MMebLL7+sm8gdOnSo+e53vxuV+Qd/ObWiMf+A5qNRrsGbnJys3r17q7S0tG6f1+tVaWlpvesfhsMYo7Fjx+rtt9/WH/7wB3Xu3Nnasc903XXX6a9//as2b95ct+Xk5Gj48OHavHmzEhMTrWX1799f27Zt89m3fft2dezY0VqG9NW6TgkJvuWTmJgor9drNedsnTt3VlZWlk9tHD16VB9++KHV2pCkmpoaDR06VDt27NC7776rCy+80Orxa915553asmWLT31kZ2drwoQJWrVqlbWc5ORk9enTJ+r1UVNTo5qaGiv1Eew92rt3byUlJfnUw7Zt27Rnzx7H9RBKf3D06FFdf/31Sk5O1rJly85ZQ8lGzqRJk86pB0maOXOmXnrpJWs5nTp1UnZ2tpV6CJZ14sQJSYpKn+H1elVVVWW1FoJlSXZqIViOrVoIlmOzFoJl
2ayFYJ+tOTk5VmoilM9wW/UQLGvKlClWaiJYTpcuXazURLCcaPYNX375pT799FO1bds26v3DmVlS9PqHM3Oi2T+cmRPt/uHMLJv1EGw8bGs8Gcq429Z4MliWrfFksBxb48lgOTbHk7VeeuklZWRk6KabbqrbF43+wV+OFJ2+wV9WNPoHfznR6h/8ZdnsH4J9j7XVP4TyfdlW/xAsy1b/ECzHVv8QLCca/YMktWrVSm3bttUXX3yhVatW6Xvf+15U5h/85dSeVyzmH9CENeTsciCLFy82brfbLFq0yPz973839957r2nTpo2pqKiwlnHfffeZ1q1bm7Vr15r9+/fXbSdOnLCWUZ9oLdFQXl5uWrRoYZ588kmzY8cO86tf/cq0bNnS/PKXv7SaM2LECNOuXTvz29/+1uzevdu89dZbJj093co/dTx27JjZtGmT2bRpk5FUt35P7a9HTp8+3bRp08a88847ZsuWLeZ73/ue6dy5s+P/Cx4op7q62tx8882mffv2ZvPmzT71UVVVZf2czhbuEg3Bct566y2TlJRkXnzxRbNjxw7z85//3CQmJpr33nvPas7AgQPNlVdeadasWWN27dplXnrpJZOSkmLmzp3rKCeU9+jo0aNNhw4dzB/+8Aezfv16k5eXd84/+bORdeTIEZObm2t69Ohhdu7c6XMfJ1fhh9PvKIx/VhdKzsyZM01aWpp58803zY4dO8wjjzxiUlJSzM6dO61mVVdXm0suucQMGDDAfPjhh2bnzp3mZz/7mXG5XPWua+vPpEmTzB//+Eeze/dus2XLFjNp0iTjcrnM73//e2OMvVoIlmWrFkI5p7OFUwuh5NiqhWBZtmqhPmd/ttqsifpybNZDsCx/wq2JYDk2a6K+HJv18NOf/tSsXbvW7N6923zwwQcmPz/fpKenmwMHDhhj7NZCoCyb9RDsnM4Wbi0Ey7FZC4GybNZDKONhG+PJYDk2x5PhjPHDGU+GkmNjPBlKjq3xpDFfrenZoUMHM3HixHNus9k/1JcTjc+KQOd0tkg+KwLl2P6sqC/LZv8QyvdYG/1DsByb/UM4383D6R9CybHRP4SSY7N/WLlypfnd735ndu3aZX7/+9+bnj17mtzc3LolVGzNPwTKsT3/gOap0U7wGmPMz3/+c9OhQweTnJxs+vbta/785z9bPb4kv9tLL71kNcefaE3wGmPMb37zG9O9e3fjdrtN165dzYsvvmg94+jRo2bcuHGmQ4cOJiUlxXTp0sVMmTLFSuezZs0av6/LiBEjjDHGeL1e8+ijj5rMzEzjdrvNddddZ7Zt22Y1Z/fu3fXWx5o1a6yf09nCneANJecXv/iFueSSS0xKSorp2bOn338+FmnO/v37zV133WWys7NNSkqKufzyy82zzz5bt6h/qEJ5j548edLcf//95vzzzzctW7Y0t9xyi9m/f7/jcwqWVd85SzK7d++2ek7+HuN0UB5qTklJiWnfvr1p2bKlycvLczzZH2rW9u3bzfe//32TkZFhWrZsaa666irzyiuvOMq5++67TceOHU1ycrL52te+Zq677jqfiVBbtRAsy1YthHJOZwv3C1ooOTZqIZQsG7VQn7M/W23WRH05NushWJY/0ZrgNcZeTQTKsVUPw4YNM23btjXJycmmXbt2ZtiwYT4TDDZrIVCWzXoIdk5nC7cWQsmxVQvBsmz2D8HGw7bGk4FybI8nnY7xwx1PhpJjYzwZLMfWeNIYY1atWmUk+X2NbfYP9eVE47Mi0DmdLZLPimA5Nj8rAmXZ6h9C+R5ro38IlmOzfwjnu3k4/UOoOZH2D6Hk2OwflixZYrp06WKSk5NNVlaWGTNmjDl8+HDd7bY+LwLl2P68QPPkMsYYAQAAAAAAAADiTqNcgxcAAAAAAAAAEBwTvAAAAAAAAAAQp5jgBQAAAAAAAIA4xQQvAAAAAAAAAMQpJngBAAAAAAAAIE4xwQsAAAAAAAAAcYoJXgAAAAAAAACIU0zwAgAAAAAAAECcYoIXAAAAAAAAAOIUE7wAAAAAAAAAEKeY4AUAAAAAAACAOMUELwAAAAAAAADEqf8HB1LbLLNtYSkAAAAASUVORK5CYII=", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "import matplotlib.pyplot as plt\n", - "plt.figure(figsize=(20, 2))\n", - "# sns.heatmap(mask.numpy())\n", - "sns.heatmap(mask[2].reshape(11, -1).numpy())" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "fairseq", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/data_utils.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/data_utils.py deleted file mode 100644 index f0234692a83f93272e868452e8ba13743264ce6d..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/data_utils.py +++ /dev/null @@ -1,535 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import logging -import math -import numpy as np -import torch - -from typing import Optional, Tuple - - - -logger = logging.getLogger(__name__) - - - -def compute_mask_indices( - shape: Tuple[int, int], - padding_mask: Optional[torch.Tensor], - mask_prob: float, - mask_length: int, - mask_type: str = "static", - mask_other: float = 0.0, - min_masks: int = 0, - no_overlap: bool = False, - min_space: int = 0, - require_same_masks: bool = True, - mask_dropout: float = 0.0, - add_masks: bool = False, - seed: Optional[int] = None, - epoch: Optional[int] = None, - indices: Optional[torch.Tensor] = None, - idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset - num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset -) -> np.ndarray: - """ - Computes random mask spans for a given shape - - Args: - shape: the the shape for which to compute masks. - should be of size 2 where first element is batch size and 2nd is timesteps - padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements - mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by - number of timesteps divided by length of mask span to mask approximately this percentage of all elements. - however due to overlaps, the actual number will be smaller (unless no_overlap is True) - mask_type: how to compute mask lengths - static = fixed size - uniform = sample from uniform distribution [mask_other, mask_length*2] - normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element
-            poisson = sample from poisson distribution with lambda = mask length
-        min_masks: minimum number of masked spans
-        no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
-        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
-        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
-        mask_dropout: randomly dropout this percentage of masks in each example
-    """
-
-    bsz, all_sz = shape
-    mask = np.full((bsz, all_sz), False)
-
-    if num_mask_ver == 1:
-        all_num_mask = int(
-            # add a random number for probabilistic rounding
-            mask_prob * all_sz / float(mask_length)
-            + np.random.rand()
-        )
-        all_num_mask = max(min_masks, all_num_mask)
-
-    mask_idcs = []
-    for i in range(bsz):
-        if seed is not None and epoch is not None and indices is not None:
-            seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6)
-        else:
-            seed_i = None
-
-        rng = np.random.default_rng(seed_i)
-
-        if padding_mask is not None:
-            sz = all_sz - padding_mask[i].long().sum().item()
-            assert sz >= 0, sz
-        else:
-            sz = all_sz
-
-        if num_mask_ver == 1:
-            if padding_mask is not None:
-                num_mask = int(
-                    # add a random number for probabilistic rounding
-                    mask_prob * sz / float(mask_length)
-                    + np.random.rand()
-                )
-                num_mask = max(min_masks, num_mask)
-            else:
-                num_mask = all_num_mask
-        elif num_mask_ver == 2:
-            num_mask = int(
-                # add a random number for probabilistic rounding
-                mask_prob * sz / float(mask_length)
-                + rng.random()
-            )
-            num_mask = max(min_masks, num_mask)
-        else:
-            raise ValueError()
-
-        if mask_type == "static":
-            lengths = np.full(num_mask, mask_length)
-        elif mask_type == "uniform":
-            lengths = rng.integers(mask_other, mask_length * 2 + 1, size=num_mask)
-        elif mask_type == "normal":
-            lengths = rng.normal(mask_length, mask_other, size=num_mask)
-            lengths = [max(1, int(round(x))) for x in lengths]
-        elif mask_type == "poisson":
-            lengths = rng.poisson(mask_length, size=num_mask)
-            lengths = [int(round(x)) for x in lengths]
-        else:
-            raise Exception("unknown mask selection " + mask_type)
-
-        if sum(lengths) == 0:
-            if mask_type == "static":
-                raise ValueError("this should never happen")
-            else:
-                lengths = [min(mask_length, sz - 1)]
-
-        if no_overlap:
-            mask_idc = []
-
-            def arrange(s, e, length, keep_length):
-                span_start = rng.integers(s, e - length)
-                mask_idc.extend(span_start + i for i in range(length))
-
-                new_parts = []
-                if span_start - s - min_space >= keep_length:
-                    new_parts.append((s, span_start - min_space + 1))
-                if e - span_start - length - min_space > keep_length:
-                    new_parts.append((span_start + length + min_space, e))
-                return new_parts
-
-            parts = [(0, sz)]
-            min_length = min(lengths)
-            for length in sorted(lengths, reverse=True):
-                lens = np.fromiter(
-                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
-                    int,
-                )
-                l_sum = np.sum(lens)
-                if l_sum == 0:
-                    break
-                probs = lens / np.sum(lens)
-                c = rng.choice(len(parts), p=probs)
-                s, e = parts.pop(c)
-                parts.extend(arrange(s, e, length, min_length))
-            mask_idc = np.asarray(mask_idc)
-        else:
-            if idc_select_ver == 1:
-                min_len = min(lengths)
-                if sz - min_len <= num_mask:
-                    min_len = sz - num_mask - 1
-                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
-            elif idc_select_ver == 2:
-                mask_idc = rng.choice(sz, num_mask, replace=False)
-            else:
-                raise ValueError()
-
-        mask_idc = np.asarray(
-            [
-                mask_idc[j] + offset
-                for j in
range(len(mask_idc)) - for offset in range(lengths[j]) - ] - ) - - mask_idc = np.unique(mask_idc[mask_idc < sz]) - if len(mask_idc) >= sz: - raise ValueError( - ( - f"the entire sequence is masked. " - f"sz={sz}; mask_idc[mask_idc]; " - f"index={indices[i] if indices is not None else None}" - ) - ) - mask_idcs.append(mask_idc) - - target_len = None - if require_same_masks: - if add_masks: - target_len = max([len(m) for m in mask_idcs]) - else: - target_len = min([len(m) for m in mask_idcs]) - - for i, mask_idc in enumerate(mask_idcs): - if target_len is not None and len(mask_idc) > target_len: - mask_idc = rng.choice(mask_idc, target_len, replace=False) - - mask[i, mask_idc] = True - - if target_len is not None and len(mask_idc) < target_len: - unmasked = np.flatnonzero(~mask[i]) - to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) - mask[i, to_mask] = True - - if mask_dropout > 0: - masked = np.flatnonzero(mask[i]) - num_holes = np.rint(len(masked) * mask_dropout).astype(int) - to_drop = rng.choice(masked, num_holes, replace=False) - mask[i, to_drop] = False - - return mask - - -def compute_block_mask_2d( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - mask_prob_adjust: float = 0, - inverse_mask: bool = False, - require_same_masks: bool = True, - expand_adjcent: bool = False, - mask_dropout: float = 0, - non_overlapping: bool = False, - img_shape: tuple = None, # For the situation when d[0] != d[1], especially in audio spce ways - flexible_mask: bool = False, -) -> torch.Tensor: - - assert mask_length > 1 - - B, L = shape - - d = (int(L**0.5),int(L**0.5)) - - if img_shape: - d = (img_shape[0],img_shape[1]) - - if flexible_mask: - index = np.random.randint(0,3) - block_size_options = np.array([(6, 4), (5, 5), (8, 3)]) - block_size = block_size_options[index] - - if inverse_mask: - mask_prob = 1 - mask_prob - - if flexible_mask: - mask = torch.zeros((B, d[0], d[1])) - mask_inds = torch.randint( - 0, - L, - size=( - B, - int( - L - * ((mask_prob + mask_prob_adjust) / (block_size[0]*block_size[1])) - * (1 + mask_dropout) - ), - ), - ) - mask.view(B, -1).scatter_(1, mask_inds, 1) - centers = mask.nonzero(as_tuple=True) - - inds = ([], [], []) - - offset = mask_length // 2 - for i in range(block_size[0]): - for j in range(block_size[1]): - k1 = i - offset - k2 = j - offset - inds[0].append(centers[0]) - inds[1].append(centers[1] + k1) - inds[2].append(centers[2] + k2) - - i0 = torch.cat(inds[0]) - i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1) - i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1) - - mask[(i0, i1, i2)] = 1 - - elif non_overlapping: - sz = math.ceil(d[0] / mask_length) - inp_len = sz * sz - - inp = torch.zeros((B, 1, sz, sz)) - w = torch.ones((1, 1, mask_length, mask_length)) - - mask_inds = torch.multinomial( - 1 - inp.view(B, -1), - int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)), - replacement=False, - ) - inp.view(B, -1).scatter_(1, mask_inds, 1) - - mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze( - 1 - ) - if mask.size(-1) > d[0]: - mask = mask[..., :d, :d] - else: - mask = torch.zeros((B, d[0], d[1])) - mask_inds = torch.randint( - 0, - L, - size=( - B, - int( - L - * ((mask_prob + mask_prob_adjust) / mask_length**2) - * (1 + mask_dropout) - ), - ), - ) - mask.view(B, -1).scatter_(1, mask_inds, 1) - centers = mask.nonzero(as_tuple=True) - - inds = ([], [], []) - - offset = mask_length // 2 - for i in range(mask_length): - for j in range(mask_length): - k1 = i - offset - k2 = 
j - offset - inds[0].append(centers[0]) - inds[1].append(centers[1] + k1) - inds[2].append(centers[2] + k2) - - i0 = torch.cat(inds[0]) - i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1) - i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1) - - mask[(i0, i1, i2)] = 1 - - def get_nbs(b, m, w): - all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same") - all_nbs = all_nbs.clamp_max_(1).view(b, -1) - return all_nbs - - if require_same_masks and expand_adjcent: - w = torch.zeros((1, 1, 3, 3)) - w[..., 0, 1] = 1 - w[..., 2, 1] = 1 - w[..., 1, 0] = 1 - w[..., 1, 2] = 1 - - all_nbs = get_nbs(B, mask, w) - - mask = mask.reshape(B, -1) - - if require_same_masks: - n_masks = mask.sum(dim=-1) - final_target_len = int(L * (mask_prob)) - target_len = int(final_target_len * (1 + mask_dropout)) - - for i in range(len(mask)): - n = n_masks[i] - m = mask[i] - r = 0 - while expand_adjcent and n < target_len: - if r == 0: - nbs = all_nbs[i] - else: - nbs = get_nbs(1, m.view(1, d[0], d[1]), w).flatten() - - cands = (1 - m + nbs) > 1 - cand_sz = int(cands.sum().item()) - - assert cand_sz > 0, f"{nbs} {cand_sz}" - - to_mask = torch.multinomial( - cands.float(), min(cand_sz, int(target_len - n)), replacement=False - ) - m[to_mask] = 1 - assert to_mask.numel() > 0 - n += to_mask.numel() - r += 1 - - if n > final_target_len: - to_unmask = torch.multinomial( - m, int(n - final_target_len), replacement=False - ) - m[to_unmask] = 0 - elif n < final_target_len: - to_mask = torch.multinomial( - (1 - m), int(final_target_len - n), replacement=False - ) - m[to_mask] = 1 - - if inverse_mask: - mask = 1 - mask - - return mask - - -def compute_block_mask_1d( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - mask_prob_adjust: float = 0, - inverse_mask: bool = False, - require_same_masks: bool = True, - expand_adjcent: bool = False, - mask_dropout: float = 0, - non_overlapping: bool = False, -) -> torch.Tensor: - - B, L = shape - - if inverse_mask: - mask_prob = 1 - mask_prob - - if non_overlapping: - sz = math.ceil(L / mask_length) - - inp = torch.zeros((B, 1, sz)) - w = torch.ones((1, 1, mask_length)) - - mask_inds = torch.multinomial( - 1 - inp.view(B, -1), - int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)), - replacement=False, - ) - inp.view(B, -1).scatter_(1, mask_inds, 1) - - mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze( - 1 - ) - if mask.size(-1) > L: - mask = mask[..., :L] - - else: - mask = torch.zeros((B, L)) - mask_inds = torch.randint( - 0, - L, - size=( - B, - int( - L - * ((mask_prob + mask_prob_adjust) / mask_length) - * (1 + mask_dropout) - ), - ), - ) - - mask.view(B, -1).scatter_(1, mask_inds, 1) - centers = mask.nonzero(as_tuple=True) - - inds = ([], []) - - offset = mask_length // 2 - for i in range(mask_length): - k1 = i - offset - inds[0].append(centers[0]) - inds[1].append(centers[1] + k1) - - i0 = torch.cat(inds[0]) - i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1) - - mask[(i0, i1)] = 1 - - def get_nbs(b, m, w): - all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same") - all_nbs = all_nbs.clamp_max_(1).view(b, -1) - return all_nbs - - if require_same_masks and expand_adjcent: - w = torch.ones((1, 1, 3)) - w[..., 1] = 0 - all_nbs = get_nbs(B, mask, w) - - mask = mask.view(B, -1) - - if require_same_masks: - n_masks = mask.sum(dim=-1) - final_target_len = int(L * (mask_prob)) - target_len = int(final_target_len * (1 + mask_dropout)) - - for i in range(len(mask)): - n = n_masks[i] - m = mask[i] - r = 0 - 
while expand_adjcent and n < target_len: - if r == 0: - nbs = all_nbs[i] - else: - nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0) - - cands = (1 - m + nbs) > 1 - cand_sz = int(cands.sum().item()) - - assert cand_sz > 0, f"{nbs} {cand_sz}" - - to_mask = torch.multinomial( - cands.float(), min(cand_sz, int(target_len - n)), replacement=False - ) - m[to_mask] = 1 - assert to_mask.numel() > 0 - n += to_mask.numel() - r += 1 - - if n > final_target_len: - to_unmask = torch.multinomial( - m, int(n - final_target_len), replacement=False - ) - m[to_unmask] = 0 - elif n < final_target_len: - to_mask = torch.multinomial( - (1 - m), int(final_target_len - n), replacement=False - ) - m[to_mask] = 1 - - if inverse_mask: - mask = 1 - mask - - return mask - - -def get_buckets(sizes, num_buckets): - buckets = np.unique( - np.percentile( - sizes, - np.linspace(0, 100, num_buckets + 1), - interpolation="lower", - )[1:] - ) - return buckets - - -def get_bucketed_sizes(orig_sizes, buckets): - sizes = np.copy(orig_sizes) - assert np.min(sizes) >= 0 - start_val = -1 - for end_val in buckets: - mask = (sizes > start_val) & (sizes <= end_val) - sizes[mask] = end_val - start_val = end_val - return sizes - - diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/mixup.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/mixup.py deleted file mode 100644 index 9cd0d2c333a8973e1c994bc686d935117854ef40..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/mixup.py +++ /dev/null @@ -1,220 +0,0 @@ -""" Mixup and Cutmix - -Papers: -mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) - -CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) - -Code Reference: -CutMix: https://github.com/clovaai/CutMix-PyTorch - -Hacked together by / Copyright 2019, Ross Wightman -""" -import numpy as np -import torch - - -def one_hot(x, num_classes, on_value=1., off_value=0.): - x = x.long().view(-1, 1) - return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value) - -# adapted from using one_hot to directly using target values -def mixup_target(target, num_classes, lam=1., smoothing=0.0): - # off_value = smoothing / num_classes - # on_value = 1. - smoothing + off_value - # y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value) - # y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value) - y1 = target - y2 = target.flip(0) - return y1 * lam + y2 * (1. - lam) - - -def rand_bbox(img_shape, lam, margin=0., count=None): - """ Standard CutMix bounding-box - Generates a random square bbox based on lambda value. This impl includes - support for enforcing a border margin as percent of bbox dimensions. 
- - Args: - img_shape (tuple): Image shape as tuple - lam (float): Cutmix lambda value - margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) - count (int): Number of bbox to generate - """ - ratio = np.sqrt(1 - lam) - img_h, img_w = img_shape[-2:] - cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) - margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) - cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) - cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) - yl = np.clip(cy - cut_h // 2, 0, img_h) - yh = np.clip(cy + cut_h // 2, 0, img_h) - xl = np.clip(cx - cut_w // 2, 0, img_w) - xh = np.clip(cx + cut_w // 2, 0, img_w) - return yl, yh, xl, xh - - -def rand_bbox_minmax(img_shape, minmax, count=None): - """ Min-Max CutMix bounding-box - Inspired by Darknet cutmix impl, generates a random rectangular bbox - based on min/max percent values applied to each dimension of the input image. - - Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. - - Args: - img_shape (tuple): Image shape as tuple - minmax (tuple or list): Min and max bbox ratios (as percent of image size) - count (int): Number of bbox to generate - """ - assert len(minmax) == 2 - img_h, img_w = img_shape[-2:] - cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) - cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) - yl = np.random.randint(0, img_h - cut_h, size=count) - xl = np.random.randint(0, img_w - cut_w, size=count) - yu = yl + cut_h - xu = xl + cut_w - return yl, yu, xl, xu - - -def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): - """ Generate bbox and apply lambda correction. - """ - if ratio_minmax is not None: - yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) - else: - yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) - if correct_lam or ratio_minmax is not None: - bbox_area = (yu - yl) * (xu - xl) - lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) - return (yl, yu, xl, xu), lam - - -class Mixup: - """ Mixup/Cutmix that applies different params to each element or whole batch - - Args: - mixup_alpha (float): mixup alpha value, mixup is active if > 0. - cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. - cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 
- prob (float): probability of applying mixup or cutmix per batch or element - switch_prob (float): probability of switching to cutmix instead of mixup when both are active - mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)) - correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders - label_smoothing (float): apply label smoothing to the mixed target tensor - num_classes (int): number of classes for target - """ - def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, - mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): - self.mixup_alpha = mixup_alpha - self.cutmix_alpha = cutmix_alpha - self.cutmix_minmax = cutmix_minmax - if self.cutmix_minmax is not None: - assert len(self.cutmix_minmax) == 2 - # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe - self.cutmix_alpha = 1.0 - self.mix_prob = prob - self.switch_prob = switch_prob - self.label_smoothing = label_smoothing - self.num_classes = num_classes - self.mode = mode - self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix - self.mixup_enabled = True # set to false to disable mixing (intended to be set by the train loop) - - def _params_per_elem(self, batch_size): - lam = np.ones(batch_size, dtype=np.float32) - use_cutmix = np.zeros(batch_size, dtype=bool) - if self.mixup_enabled: - if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: - use_cutmix = np.random.rand(batch_size) < self.switch_prob - lam_mix = np.where( - use_cutmix, - np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), - np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) - elif self.mixup_alpha > 0.: - lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) - elif self.cutmix_alpha > 0.: - use_cutmix = np.ones(batch_size, dtype=bool) - lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) - else: - assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." - lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) - return lam, use_cutmix - - def _params_per_batch(self): - lam = 1. - use_cutmix = False - if self.mixup_enabled and np.random.rand() < self.mix_prob: - if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: - use_cutmix = np.random.rand() < self.switch_prob - lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ - np.random.beta(self.mixup_alpha, self.mixup_alpha) - elif self.mixup_alpha > 0.: - lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) - elif self.cutmix_alpha > 0.: - use_cutmix = True - lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) - else: - assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." 
- lam = float(lam_mix) - return lam, use_cutmix - - def _mix_elem(self, x): - batch_size = len(x) - lam_batch, use_cutmix = self._params_per_elem(batch_size) - x_orig = x.clone() # need to keep an unmodified original for mixing source - for i in range(batch_size): - j = batch_size - i - 1 - lam = lam_batch[i] - if lam != 1.: - if use_cutmix[i]: - (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( - x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) - x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] - lam_batch[i] = lam - else: - x[i] = x[i] * lam + x_orig[j] * (1 - lam) - return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) - - def _mix_pair(self, x): - batch_size = len(x) - lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) - x_orig = x.clone() # need to keep an unmodified original for mixing source - for i in range(batch_size // 2): - j = batch_size - i - 1 - lam = lam_batch[i] - if lam != 1.: - if use_cutmix[i]: - (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( - x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) - x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] - x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] - lam_batch[i] = lam - else: - x[i] = x[i] * lam + x_orig[j] * (1 - lam) - x[j] = x[j] * lam + x_orig[i] * (1 - lam) - lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) - return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) - - def _mix_batch(self, x): - lam, use_cutmix = self._params_per_batch() - if lam == 1.: - return 1. - if use_cutmix: - (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( - x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) - x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] - else: - x_flipped = x.flip(0).mul_(1. - lam) - x.mul_(lam).add_(x_flipped) - return lam - - def __call__(self, x, target): - assert len(x) % 2 == 0, 'Batch size should be even when using this' - if self.mode == 'elem': - lam = self._mix_elem(x) - elif self.mode == 'pair': - lam = self._mix_pair(x) - else: - lam = self._mix_batch(x) - target = mixup_target(target, self.num_classes, lam, self.label_smoothing) - return x, target \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/mert_pretraining.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/mert_pretraining.py deleted file mode 100644 index 11d1d46adcdc67b635c7e88d9c65bcdd8e87b27b..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/mert_pretraining.py +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. -# -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. 
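Before the pretraining task below, a note on the Mixup helper deleted above: in 'batch' mode, _mix_batch draws a single lam ~ Beta(alpha, alpha) per batch, switches to CutMix with probability switch_prob, and mixes each sample with its flipped counterpart (index B-1-i); mixup_target in this adapted version blends the raw target tensors directly rather than one-hot encodings, so it expects float (e.g. multi-hot) labels. A minimal usage sketch, assuming the deleted mixup.py were kept importable as `mixup` (shapes and hyper-parameters are illustrative, not taken from this repo):

import torch
from mixup import Mixup  # hypothetical import of the file deleted above

# 'batch' mode: one lam per batch; CutMix is chosen over mixup with prob switch_prob
mixer = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, switch_prob=0.5,
              mode='batch', label_smoothing=0.0, num_classes=527)
spec = torch.randn(8, 1, 128, 1024)             # (B, C, F, T) spectrogram-like batch; B must be even
target = torch.randint(0, 2, (8, 527)).float()  # multi-hot labels, mixed directly (no one-hot)
spec, target = mixer(spec, target)              # target -> lam * y + (1 - lam) * y.flip(0)

Note that _mix_elem/_mix_pair/_mix_batch modify x in place and return lam; when lam == 1. the batch passes through unchanged.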
- -import logging -import os -import sys -from typing import Dict, List, Optional, Tuple - -import numpy as np -import torch - -from dataclasses import dataclass, field -from fairseq.data import Dictionary, HubertDataset -from fairseq.dataclass.configs import FairseqDataclass -from fairseq.tasks import register_task -from fairseq.tasks.fairseq_task import FairseqTask -from omegaconf import MISSING - -from ..data.mert_dataset import MERTDataset -from ..data.ark_dataset import ArkDataset - -logger = logging.getLogger(__name__) - - -class LabelEncoder(object): - def __init__(self, dictionary: Dictionary) -> None: - self.dictionary = dictionary - - def __call__(self, label: str) -> List[str]: - return self.dictionary.encode_line( - label, - append_eos=False, - add_if_not_exist=False, - ) -class PaddedNumpyLabelEncoder(object): - def __init__(self): - # self.dictionary = dictionary - pass - - def __call__(self, label): - t = torch.IntTensor(np.asarray(label)) - t = t[t>=0] # remove padded -1 values at the end - return t - -@dataclass -class MERTPretrainingConfig(FairseqDataclass): - data: str = field(default=MISSING, metadata={"help": "path to data directory"}) - sharding_data: int = field( - default=-1, - metadata={ - "help": "set this param >1 to use a sharding dataset to prevent OOM" - "prepare data tsv and label files by adding a postfix for sharding 64 like:" - "train_28_64.tsv and train_28_64.encodec_6" - }, - ) - load_random_data_shard: bool = field( - default=True, - metadata={ - "help": "whether to load shards randomly or in order when using sharding_data" - }, - ) - fine_tuning: bool = field( - default=False, metadata={"help": "set to true if fine-tuning Hubert"} - ) - labels: List[str] = field( - default_factory=lambda: ["ltr"], - metadata={ - "help": ( - "extension of the label files to load, frame-level labels for" - " pre-training, and sequence-level label for fine-tuning" - ) - }, - ) - label_dir: Optional[str] = field( - default=None, - metadata={ - "help": "if set, looks for labels in this directory instead", - }, - ) - label_rate: float = field( - default=-1.0, - metadata={"help": "label frame rate. -1.0 for sequence label"}, - ) - sample_rate: int = field( - default=16_000, - metadata={ - "help": "target sample rate. 
audio files will be up/down " - "sampled to this rate" - }, - ) - normalize: bool = field( - default=False, - metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, - ) - enable_padding: bool = field( - default=False, - metadata={"help": "pad shorter samples instead of cropping"}, - ) - max_keep_size: Optional[int] = field( - default=None, - metadata={"help": "exclude samples longer than this"}, - ) - max_sample_size: Optional[int] = field( - default=None, - metadata={"help": "max sample size to crop to for batching"}, - ) - min_sample_size: Optional[int] = field( - default=None, - metadata={"help": "min sample size to crop to for batching"}, - ) - single_target: Optional[bool] = field( - default=False, - metadata={ - "help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset" - }, - ) - random_crop: Optional[bool] = field( - default=True, - metadata={"help": "always crop from the beginning if false"}, - ) - pad_audio: Optional[bool] = field( - default=False, - metadata={"help": "pad audio to the longest one in the batch if true"}, - ) - - store_labels: Optional[bool] = field( - default=False, - metadata={"help": "whether to load all of the labels into memory"}, - ) - - numpy_memmap_label: Optional[bool] = field( - default=False, - metadata={"help": "whether the label file is saved as a numpy file; each line ends with -1 padding"}, - ) - - augmentation_effects: Optional[str] = field( - default="[]", - metadata={ - "help": ( - "a list of effects that might apply to the audios" - "example: \"['random_mute', 'random_Gaussian', 'reverse_polarity']\" " - "supported: random_mute," - "todo: " - ) - }, - ) - augmentation_probs: Optional[str] = field( - default="[]", - metadata={ - "help": ( - "the corresponding probabilities for the data augmentation effects" - "example: \"[0.1, 0.5, 0.8]\" " - "the sum does not need to be 1.0, and multiple effects can be applied to the same audio" - ) - }, - ) - - # inbatch_noise_augment_len_range: Optional[List[int]] = field( - # default_factory=lambda: [8000, 24000], - # default = [8000, 24000], - inbatch_noise_augment_len_range: Optional[str] = field( - default = "[8000, 24000]", - metadata={ - "help": ( - "the range of lengths of the mix-up noise augmentation, in samples" - ) - }, - ) - # inbatch_noise_augment_number_range: Optional[List[int]] = field( - # default_factory=lambda: [1, 3], - # default = [1, 3], - inbatch_noise_augment_number_range: Optional[str] = field( - default = "[1, 3]", - metadata={ - "help": ( - "the range of numbers of the mix-up noise augmentation" - ) - }, - ) - inbatch_noise_augment_volume: float = field( - default = 1.0, - metadata={ - "help": ( - "the coefficient used to modify the volume of the noise audio wavs" - ) - }, - ) - dynamic_crops: Optional[str] = field( - default="[]", - metadata={ - "help": ( - "used to set the maximum audio length for training" - "example: \"[1, 2, 3, 4, 5, 10]\" " - ) - }, - ) - dynamic_crops_epoches: Optional[str] = field( - default="[]", - metadata={ - "help": ( - "used to set the training epochs at which the maximum audio length changes" - "example: \"[1, 10, 20, 40, 80, 160,]\" " - "its length must equal len(dynamic_crops)" - ) - }, - ) - - cqt_loss_bin_dataloader: Optional[int] = field( - default=-1, - metadata={ - "help": ( - "use this parameter to prepare the CQT prediction objective in the dataloader" - ) - }, - ) - - clip_secs: int = field( - default=5, - metadata={ - "help": "clip secs for each audio" - } - ) - - 
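Taken together, dynamic_crops and dynamic_crops_epoches define a step schedule for the maximum crop length: entry i of dynamic_crops (in seconds) takes effect at epoch dynamic_crops_epoches[i] and stays in force until the next change point. A small self-contained sketch of how the task below (see set_dynamic_crop_max_sample) resolves the crop in effect at a given epoch, using the example values from the help strings above and the config's default 16 kHz sample rate (the helper name is ours, for illustration only):

# illustrative values copied from the help strings above
dynamic_crops = [1, 2, 3, 4, 5, 10]               # max audio length, in seconds
dynamic_crops_epoches = [1, 10, 20, 40, 80, 160]  # epoch at which each crop takes effect
sample_rate = 16_000

def max_sample_size_for(epoch: int) -> int:
    # last change point whose epoch has been reached; both lists must be equally long
    idx = max(i for i, e in enumerate(dynamic_crops_epoches) if epoch >= e)
    return dynamic_crops[idx] * sample_rate

assert max_sample_size_for(1) == 16_000     # 1 s crops for epochs 1-9
assert max_sample_size_for(25) == 48_000    # 3 s crops for epochs 20-39
assert max_sample_size_for(200) == 160_000  # 10 s crops from epoch 160 on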
-@register_task("mert_pretraining", dataclass=MERTPretrainingConfig) -class MERTPretrainingTask(FairseqTask): - - cfg: MERTPretrainingConfig - - def __init__( - self, - cfg: MERTPretrainingConfig, - ) -> None: - super().__init__(cfg) - - logger.info(f"current directory is {os.getcwd()}") - logger.info(f"MERTPretrainingTask Config {cfg}") - - self.cfg = cfg - self.fine_tuning = cfg.fine_tuning - - if cfg.fine_tuning: - self.state.add_factory("target_dictionary", self.load_dictionaries) - else: - self.state.add_factory("dictionaries", self.load_dictionaries) - - self.blank_symbol = "" - self.augmentation_effects = eval(self.cfg.augmentation_effects) - self.augmentation_probs = eval(self.cfg.augmentation_probs) - if len(self.augmentation_effects) > 0: - assert len(self.augmentation_effects) == len(self.augmentation_probs) - logger.info(f"Applying audio augmentation {self.augmentation_effects}, probabilities: {self.augmentation_probs}") - - self.inbatch_noise_augment_number_range = eval(self.cfg.inbatch_noise_augment_number_range) - self.inbatch_noise_augment_len_range = eval(self.cfg.inbatch_noise_augment_len_range) - - self.max_sample_size = self.cfg.max_sample_size - - self.dynamic_crops = eval(self.cfg.dynamic_crops) - self.dynamic_crops_epoches = eval(self.cfg.dynamic_crops_epoches) - assert len(self.dynamic_crops) == len(self.dynamic_crops_epoches) - if len(self.dynamic_crops) > 0: - assert self.dynamic_crops_epoches[0] == 1 - - self.cqt_loss_bin_dataloader = self.cfg.cqt_loss_bin_dataloader - - self.numpy_memmap_label = self.cfg.numpy_memmap_label - self.store_labels = self.cfg.store_labels - if self.numpy_memmap_label: - assert self.store_labels - - @property - def source_dictionary(self) -> Optional[Dictionary]: - return None - - @property - def target_dictionary(self) -> Optional[Dictionary]: - return self.state.target_dictionary - - @property - def dictionaries(self) -> List[Dictionary]: - return self.state.dictionaries - - @classmethod - def setup_task( - cls, cfg: MERTPretrainingConfig, **kwargs - ) -> "MERTPretrainingTask": - return cls(cfg) - - def load_dictionaries(self): - label_dir = self.cfg.data if (self.cfg.label_dir is None or self.cfg.label_dir == '') else self.cfg.label_dir - print(label_dir) - dictionaries = [ - Dictionary.load(f"{label_dir}/dict.{label}.txt") - for label in self.cfg.labels - ] - return dictionaries[0] if self.cfg.fine_tuning else dictionaries - - def get_label_dir(self) -> str: - if self.cfg.label_dir is None or self.cfg.label_dir=='': - return self.cfg.data - return self.cfg.label_dir - - def is_force_load_dataset(self, epoch, training_restore=False): - # find the threshold that holds epoch \in [threshold, next_threshold) - return (epoch in self.dynamic_crops_epoches) or training_restore or (self.cfg.sharding_data > 1) - # for idx in range(len(self.dynamic_crops_epoches)): - # if (idx == len(self.dynamic_crops_epoches)-1) or \ - # (epoch >= self.dynamic_crops_epoches[idx] and epoch < self.dynamic_crops_epoches[idx+1]): - # return True - # return False - - def set_dynamic_crop_max_sample(self, epoch): - """ force to set the max_sample_size config for the dynamic cropping function""" - if epoch in self.dynamic_crops_epoches: - for idx in range(len(self.dynamic_crops_epoches)): - if (idx == len(self.dynamic_crops_epoches)-1) or \ - (epoch >= self.dynamic_crops_epoches[idx] and epoch < self.dynamic_crops_epoches[idx+1]): - # set new cropping parameters and end loop - self.max_sample_size = self.dynamic_crops[idx]*self.cfg.sample_rate - 
self.cfg.max_sample_size = self.dynamic_crops[idx]*self.cfg.sample_rate - logger.info(f"epoch {epoch} forcibly set new maximum audio length as {self.dynamic_crops[idx]}s == {self.max_sample_size} samples") - break - # logger.info(f'reloading dataset for changing the sequence length') - # self.load_dataset('train') - #TODO: modify the data path - def load_dataset(self, split: str, **kwargs) -> None: - if len(list(filter(lambda s: s.endswith('.scp'), os.listdir(self.cfg.data)))) > 0: - return self.load_dataset_ark(split, **kwargs) - else: - return self.load_dataset_mert(split, **kwargs) - - def load_dataset_ark(self, split, **kwargs): - if 'train' not in split: - logger.info(f'split {split} is skipped; this dataset is only used for training') - # raise ValueError(f"No support for split: {split}") - else: - self.datasets[split] = ArkDataset( - wav_scp=os.path.join(self.cfg.data, f"wav_ark.scp"), - dur_scp=os.path.join(self.cfg.data, f"dur_ark.scp"), - sr=self.cfg.sample_rate, - ) - - def load_dataset_mert(self, split: str, **kwargs) -> None: - if 'train' in split: - epoch = kwargs['epoch'] - # the epoch to change crops - if self.is_force_load_dataset(epoch): - self.set_dynamic_crop_max_sample(epoch) - - # load all training data - if self.cfg.sharding_data <= 1: - # manifest = f"{self.cfg.data}/{split}.tsv" - manifest = f"{self.cfg.data}/{split}.json" - - paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels] - # load part of the training data - else: - if self.cfg.load_random_data_shard: - data_shard_idx = np.random.randint(self.cfg.sharding_data) - else: - data_shard_idx = (epoch-1) % self.cfg.sharding_data # epoch starts from 1 - assert data_shard_idx < self.cfg.sharding_data - logger.info(f'loading shard {data_shard_idx} of {self.cfg.sharding_data} training data for epoch {epoch}') - - # manifest = f"{self.cfg.data}/{split}_{data_shard_idx}_{self.cfg.sharding_data}.tsv" - manifest = f"{self.cfg.data}/{split}_{data_shard_idx}_{self.cfg.sharding_data}.json" - - paths = [f"{self.get_label_dir()}/{split}_{data_shard_idx}_{self.cfg.sharding_data}.{l}" for l in self.cfg.labels] - else: - # manifest = f"{self.cfg.data}/{split}.tsv" - manifest = f"{self.cfg.data}/{split}.json" - - paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels] - - dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries - pad_list = [dict.pad() for dict in dicts] - eos_list = [dict.eos() for dict in dicts] - - if self.numpy_memmap_label: - procs = [PaddedNumpyLabelEncoder() for dict in dicts] - else: - procs = [LabelEncoder(dict) for dict in dicts] - - self.datasets[split] = MERTDataset( - manifest, - sample_rate=self.cfg.sample_rate, - label_paths=paths, # this contains the ensemble label sequence names - label_rates=self.cfg.label_rate, - pad_list=pad_list, - eos_list=eos_list, - label_processors=procs, - max_keep_sample_size=self.cfg.max_keep_size, - min_keep_sample_size=self.cfg.min_sample_size, - max_sample_size=self.max_sample_size, - pad_audio=self.cfg.pad_audio, - normalize=self.cfg.normalize, - store_labels=self.store_labels, - npmemmap=self.numpy_memmap_label, - random_crop=self.cfg.random_crop, - single_target=self.cfg.single_target, - augmentation_effects=self.augmentation_effects, - augmentation_probs=self.augmentation_probs, - inbatch_noise_augment_len_range=self.inbatch_noise_augment_len_range, - inbatch_noise_augment_number_range=self.inbatch_noise_augment_number_range, - inbatch_noise_augment_volume=self.cfg.inbatch_noise_augment_volume, - 
cqt_prediction_bin=self.cqt_loss_bin_dataloader, - clip_secs=self.cfg.clip_secs, - ) - - def max_positions(self) -> Tuple[int, int]: - return (sys.maxsize, sys.maxsize) - - def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array: - return indices diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/pretraining_AS2M.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/pretraining_AS2M.py deleted file mode 100644 index 4073a364ca1e361f088a53d7301089ce70d1dfea..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/pretraining_AS2M.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. -# -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. - -import logging -import sys - -from typing import Optional, List -from dataclasses import dataclass, field -from omegaconf import MISSING, II - -from fairseq.dataclass import FairseqDataclass -from fairseq.tasks import FairseqTask, register_task - -try: - from ..data.eat_data import MaeImageDataset -except: - import sys, os - sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) - from data.eat_data.mae_image_dataset import MaeImageDataset - -logger = logging.getLogger(__name__) - - -@dataclass -class ImageMaskingConfig: - patch_size: int = II("model.modalities.image.patch_size") - mask_prob: float = II("model.modalities.image.mask_prob") - mask_prob_adjust: float = II("model.modalities.image.mask_prob_adjust") - mask_length: int = II("model.modalities.image.mask_length") - inverse_mask: bool = II("model.modalities.image.inverse_mask") - mask_dropout: float = II("model.modalities.image.mask_dropout") - clone_batch: int = II("model.clone_batch") - expand_adjacent: bool = False - non_overlapping: bool = False - - -@dataclass -class MaeImagePretrainingConfig(FairseqDataclass): - data: str = field(default=MISSING, metadata={"help": "path to data directory"}) - multi_data: Optional[List[str]] = None - input_size: int = 224 - local_cache_path: Optional[str] = None - key: str = "imgs" - beit_transforms: bool = False - target_transform: bool = False - no_transform: bool = False - - rebuild_batches: bool = True - precompute_mask_config: Optional[ImageMaskingConfig] = None - subsample: float = 1 - seed: int = II("common.seed") - dataset_type: str = "imagefolder" - - audio_mae: bool = field(default=False,metadata={"help": "if set, we use the image_mae approach to deal with audio files."}) - h5_format: bool = field(default=False,metadata={"help": "if set, the dataset will read data files in hdf5 format."}) - downsr_16hz: bool = field(default=False,metadata={"help": "if set, wav files' sample rate will be reduced to 16kHz."}) - target_length: int = field(default=1024,metadata={"help": "This setting will pad the audio spectrogram with zeros."}) - flexible_mask: bool = field(default=False, metadata={"help": "if true, we will use the flexible inverse block mask method."}) - - esc50_eval: bool = field(default=False, metadata={"help": "if true, the task is to finetune the model on the esc50 dataset."}) - spcv2_eval: bool = field(default=False, metadata={"help": "if true, the task is to finetune the model on the speech commands v2 dataset."}) - AS2M_finetune: bool = field(default=False, metadata={"help": "if true, the task is to finetune the model on 
Audioset 2M with weighted sample."}) - spcv1_finetune: bool = field(default=False, metadata={"help": "if true, the task is to finetune model on speech commands v1 with weighted sample."}) - roll_aug: bool = field(default=False, metadata={"help": "if true, we will use roll aug in fine-tuning."}) - noise: bool = field(default=False, metadata={"help": "if true, we will add gaussian noise as augmentation during fine-tuning."}) - weights_file : str = field(default="", metadata={"help": "the path of weighted sample file"}) - num_samples: int = field(default=200000, metadata={"help": "this setting will determine the number of samples in each epoch, usually used in unbalanced training."}) - is_finetuning: bool = field(default=False, metadata={"help": "this property has been deprecated"}) - - sample_rate: int = field(default=24000) - fixed_duration: float = field(default=30.0) - - - - -@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig) -class MaeImagePretrainingTask(FairseqTask): - """ """ - - cfg: MaeImagePretrainingConfig - - @classmethod - def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs): - """Setup the task (e.g., load dictionaries). - - Args: - cfg (AudioPretrainingConfig): configuration of this task - """ - - return cls(cfg) - - def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): - data_path = self.cfg.data - cfg = task_cfg or self.cfg - - - compute_mask = cfg.precompute_mask_config is not None - mask_args = {} - if compute_mask: - mask_args = cfg.precompute_mask_config - - self.datasets[split] = MaeImageDataset( - root=data_path if cfg.multi_data is None else cfg.multi_data, - split=split, - input_size=cfg.input_size, - key=cfg.key, - compute_mask=compute_mask, - dataset_type=cfg.dataset_type, - audio_mae=cfg.audio_mae, - downsr_16hz=cfg.downsr_16hz, - h5_format=cfg.h5_format, - esc50_eval=cfg.esc50_eval, - spcv2_eval=cfg.spcv2_eval, - roll_aug=cfg.roll_aug and split == 'train', - target_length=cfg.target_length, - noise=cfg.noise, - AS2M_finetune=cfg.AS2M_finetune, - spcv1_finetune=cfg.spcv1_finetune, - num_samples=cfg.num_samples, - weights_file=cfg.weights_file, - flexible_mask=cfg.flexible_mask, - sample_rate=cfg.sample_rate, - fixed_duration=cfg.fixed_duration, - **mask_args, - ) - - @property - def source_dictionary(self): - return None - - @property - def target_dictionary(self): - return None - - def max_positions(self): - """Maximum input length supported by the encoder.""" - return sys.maxsize, sys.maxsize diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/modify_env.md b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/modify_env.md deleted file mode 100644 index 417c6e0cab4f072e297d371950f2534105ddedec..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/modify_env.md +++ /dev/null @@ -1,4 +0,0 @@ -cp -r fairseq/fairseq/model_parallel/megatron /opt/conda/envs/map/lib/python3.8/site-packages/fairseq/model_parallel/ -vi /opt/conda/envs/map/lib/python3.8/site-packages/apex/amp/_initialize.py # string_classes = str -vi /opt/conda/envs/map/lib/python3.8/site-packages/fairseq/modules/layer_norm.py -vi /opt/conda/envs/map/lib/python3.8/site-packages/fairseq/distributed/utils.py # import datetime; timeout=datetime.timedelta(seconds=51200); logger.info("add nccl time to 51200") diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_eat.sh b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_eat.sh deleted file mode 100644 index 
d7cf006b58dec68fa9c93b83684e03393f231d87..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_eat.sh +++ /dev/null @@ -1,72 +0,0 @@ -WORKER_RANK=${1:-$INDEX} -PLATFORM=${2:-'shef'} -YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'} -TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'} -MASTER_PROC_ADD=${5:-$CHIEF_IP} -DIST_PORT=${6:-'25520'} -# echo $PATH -# export PATH=$PATH:./ -echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}" - -MAP_PROJ_DIR=$(pwd) -echo $MAP_PROJ_DIR - -NNODS=1 -BATCH_SIZE=12 -NUM_WOKERS=6 - -run_command_prefix=' ' -# Loading folders -# 1. tsv files for audio paths -# DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv -DATA_DIR=${MAP_PROJ_DIR}/data/music4all_sh #audio_manifest -# 2. working folder for saving checkpoints and loading config files -CONFIG_DIR=/${MAP_PROJ_DIR}/mert_fairseq/config/pretrain -# 3. clustering labels for training data -LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset - -FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq; -SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/ - -case $YAML_NAME_WITHOUT_EXT in - EAT_pretraining_music_multinodes) - NNODS=4 - NPROCES_PER_NODE=8 - LABEL_RATE=25 - BATCH_SIZE=12 - ;; - *) - echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT}" - exit 1 - ;; - esac - -echo running $YAML_NAME_WITHOUT_EXT .. - -mkdir -p ${SAVE_DIR} -echo "checkpoint save at: ${SAVE_DIR}" -cd ${SAVE_DIR} - -DISTRIBUTED_WORLD_SIZE=`expr ${NNODS} \* ${NPROCES_PER_NODE}` -ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}` -echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}" - -DATE_SUFFIX=`date +"%Y-%m-%d_%H-%M"` - -OMP_NUM_THREADS=6 ${run_command_prefix} \ -python -u ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \ ---config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT} \ -common.user_dir=${MAP_PROJ_DIR}/mert_fairseq \ -common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS} \ -checkpoint.save_dir=${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT} \ -distributed_training.distributed_rank=${ACTUAL_WORKER_RANK} \ -distributed_training.distributed_world_size=${DISTRIBUTED_WORLD_SIZE} \ -distributed_training.distributed_num_procs=${DISTRIBUTED_WORLD_SIZE} \ -distributed_training.nprocs_per_node=${NPROCES_PER_NODE} \ -distributed_training.distributed_init_method="tcp://${CHIEF_IP}:${DIST_PORT}" \ -task.data=${DATA_DIR} \ -dataset.num_workers=${NUM_WOKERS} \ -dataset.batch_size=${BATCH_SIZE} \ -dataset.disable_validation=true \ - -# pip install h5py timm -i https://mirrors.tencent.com/pypi/simple/ \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_mulNodes_wotorchdist_womodelparsize.sh b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_mulNodes_wotorchdist_womodelparsize.sh deleted file mode 100644 index 238d6a28678f4b272c097ce7b91f8161e22543ff..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_mulNodes_wotorchdist_womodelparsize.sh +++ /dev/null @@ -1,177 +0,0 @@ -# bash run_training_mulNodes_wotorchdist.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes -# bash run_training_mulNodes_wotorchdist.sh 1 dummy MERT_RVQ-VAE_CQT_330M_multinodes -# bash run_training_mulNodes_wotorchdist.sh 2 dummy MERT_RVQ-VAE_CQT_330M_multinodes -# bash run_training_mulNodes_wotorchdist.sh 3 dummy 
MERT_RVQ-VAE_CQT_330M_multinodes - -# the rank of distributed node worker -# If I use two nodes, 4 gpus per each, then WORKER_RANK for the two node should be 0, 4, i.e. the starting indice of the GPU. -WORKER_RANK=${1:-$INDEX} -PLATFORM=${2:-'shef'} -YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'} -TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'} -MASTER_PROC_ADD=${5:-$CHIEF_IP} -DIST_PORT=${6:-'25520'} -DATASET_NAME=${7:-'dataindex'} -# echo $PATH -# export PATH=$PATH:./ -echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}" - -MAP_PROJ_DIR=$(pwd) -echo $MAP_PROJ_DIR - -NNODS=1 -# MAX_TOKENS=1000000 # set for 80GB A100 batchsize -NUM_WOKERS=6 - -run_command_prefix=' ' -# Loading folders -# 1. tsv files for audio paths -# DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv -DATA_DIR=${MAP_PROJ_DIR}/data/${DATASET_NAME} #audio_manifest -# 2. working folder for saving checkpoints and loading config files -CONFIG_DIR=/${MAP_PROJ_DIR}/mert_fairseq/config/pretrain -# 3. clustering labels for training data -LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset - -FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq; -SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/ - -# set 75 for the RVQ-VAE model -LABEL_RATE=75 - -case $YAML_NAME_WITHOUT_EXT in - MERT_RVQ-VAE_CQT_95M) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=1 - LABEL_RATE=75 - MAX_TOKENS=1800000 - ;; - MERT_RVQ-VAE_CQT_95M_mel_multinodes) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=4 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=1200000 - ;; - MERT_RVQ-VAE_CQT_95M_bestrq_multinodes) - TASK_LABELS_POSTFIX='["rq_0"]' - NNODS=4 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=1200000 - ;; - MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes) - TASK_LABELS_POSTFIX='["rq_0"]' - NNODS=4 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=1600000 - ;; - MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes) - TASK_LABELS_POSTFIX='["rq_0"]' - NNODS=4 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=1600000 - ;; - MERT_RVQ-VAE_CQT_95M_dac_multinodes) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=4 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=1600000 - ;; - MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes) - TASK_LABELS_POSTFIX='["grq_0"]' - NNODS=4 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=1600000 - ;; - MusicFM_95M_multinodes) - TASK_LABELS_POSTFIX='["grq_0"]' - NNODS=4 - LABEL_RATE=25 - NPROCES_PER_NODE=8 - MAX_TOKENS=4800000 - ;; - MusicFM_95M_bestrvq_multinodes) - TASK_LABELS_POSTFIX='["grq_0"]' - NNODS=4 - LABEL_RATE=25 - NPROCES_PER_NODE=8 - MAX_TOKENS=4800000 - ;; - MusicFM_95M_speech_multinodes) - TASK_LABELS_POSTFIX='[]' - NNODS=4 - LABEL_RATE=25 - NPROCES_PER_NODE=8 - MAX_TOKENS=1200000 - ;; - MERT_RVQ-VAE_CQT_330M) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=1 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=720000 - ;; - MERT_RVQ-VAE_CQT_330M_multinodes) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=4 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=600000 - ;; - MERT_RVQ-VAE_CQT_330M_multinodes_debug2node) - 
TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=2 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=600000 - ;; - MERT_RVQ-VAE_CQT_330M_multinodes_debug1node) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=1 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=600000 - ;; - *) - echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT} = ${YAML_NAME_WITHOUT_EXT}" - exit 1 - ;; - esac - - echo running $YAML_NAME_WITHOUT_EXT .. - - mkdir -p ${SAVE_DIR} - echo "checkpoint save at: ${SAVE_DIR}" - cd ${SAVE_DIR} - - echo "NPROCES_PER_NODE is ${NPROCES_PER_NODE}" - - DISTRIBUTED_WORLD_SIZE=`expr ${NNODS} \* ${NPROCES_PER_NODE}` - ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}` - echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}" - - DATE_SUFFIX=`date +"%Y-%m-%d_%H-%M"` - - OMP_NUM_THREADS=6 ${run_command_prefix} \ - python -u ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \ - --config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT} \ - common.user_dir=${MAP_PROJ_DIR}/mert_fairseq \ - common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS} \ - checkpoint.save_dir=${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT} \ - distributed_training.distributed_rank=${ACTUAL_WORKER_RANK} \ - distributed_training.distributed_world_size=${DISTRIBUTED_WORLD_SIZE} \ - distributed_training.distributed_num_procs=${DISTRIBUTED_WORLD_SIZE} \ - distributed_training.nprocs_per_node=${NPROCES_PER_NODE} \ - distributed_training.distributed_init_method="tcp://${CHIEF_IP}:${DIST_PORT}" \ - task.data=${DATA_DIR} \ - task.label_dir=${LABEL_DIR} \ - task.labels=${TASK_LABELS_POSTFIX} \ - dataset.num_workers=${NUM_WOKERS} \ - dataset.max_tokens=${MAX_TOKENS} \ - dataset.disable_validation=true \ - model.label_rate=${LABEL_RATE} \ diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_orig.sh b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_orig.sh deleted file mode 100644 index 889ae0efcd041defce188823f2d822ce84417f55..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_orig.sh +++ /dev/null @@ -1,79 +0,0 @@ -# the rank of distributed node worker -# If I use two nodes, 4 gpus per each, then WORKER_RANK for the two node should be 0, 4, i.e. the starting indice of the GPU. -WORKER_RANK=${1:-'0'} -PLATFORM=${2:-'shef'} -YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'} -TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'} -MASTER_PROC_ADD=${5:-'127.0.0.1'} -DIST_PORT=${6:-'39683'} - -echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}" - -MAP_PROJ_DIR=$HOME/MERT - -DISTRIBUTED_WORLD_SIZE=2 -NPROCES_PER_NODE=2 -MAX_TOKENS=1000000 # set for 80GB A100 -NUM_WOKERS=6 - -run_command_prefix=' ' -# Loading folders -# 1. tsv files for audio paths -DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv -# 2. working folder for saving checkpoints and loading config files -CONFIG_DIR=/${MAP_PROJ_DIR}/mert_fairseq/config/pretrain -# 3. 
clustering labels for training data -LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/labels - - -FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq; -SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/ - -# set 75 for the RVQ-VAE model -LABEL_RATE=75 - -case $YAML_NAME_WITHOUT_EXT in - MERT_RVQ-VAE_CQT_95M) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - DISTRIBUTED_WORLD_SIZE=8 - NPROCES_PER_NODE=1 - LABEL_RATE=75 - MAX_TOKENS=1800000 - ;; - MERT_RVQ-VAE_CQT_330M) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - DISTRIBUTED_WORLD_SIZE=64 - NPROCES_PER_NODE=8 - LABEL_RATE=75 - MAX_TOKENS=920000 - ;; - *) - echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT}" - exit 1 - ;; - esac - - echo running $YAML_NAME_WITHOUT_EXT .. - - mkdir -p ${SAVE_DIR} - echo "checkpoint save at: ${SAVE_DIR}" - cd ${SAVE_DIR} - - ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}` - echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}" - - OMP_NUM_THREADS=6 ${run_command_prefix} python -u ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \ - --config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT} \ - common.user_dir=${MAP_PROJ_DIR}/mert_fairseq \ - common.wandb_project=pretrain_${TRAINING_SETTING} \ - checkpoint.save_dir=${SAVE_DIR}/ckpt_${TRAINING_SETTING}/${YAML_NAME_WITHOUT_EXT} \ - distributed_training.distributed_rank=${ACTUAL_WORKER_RANK} \ - distributed_training.distributed_world_size=${DISTRIBUTED_WORLD_SIZE} \ - distributed_training.nprocs_per_node=${NPROCES_PER_NODE} \ - distributed_training.distributed_init_method="tcp://${MASTER_PROC_ADD}:${DIST_PORT}" \ - task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} \ - task.labels=${TASK_LABELS_POSTFIX} \ - dataset.num_workers=${NUM_WOKERS} \ - dataset.max_tokens=${MAX_TOKENS} \ - dataset.disable_validation=true \ - model.label_rate=${LABEL_RATE} \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_sglNodes.sh b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_sglNodes.sh deleted file mode 100644 index 6a9f6d873f888e942474abb141c9ad56e01e5b7f..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_sglNodes.sh +++ /dev/null @@ -1,115 +0,0 @@ -# bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug1node - -# the rank of distributed node worker -# If I use two nodes, 4 gpus per each, then WORKER_RANK for the two node should be 0, 4, i.e. the starting indice of the GPU. -WORKER_RANK=${1:-'0'} -PLATFORM=${2:-'shef'} -YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'} -TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'} -MASTER_PROC_ADD=${5:-'127.0.0.1'} -DIST_PORT=${6:-'39685'} -# echo $PATH -# export PATH=$PATH:./ -echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}" - -MAP_PROJ_DIR=$(pwd) -echo $MAP_PROJ_DIR - -NNODS=1 -MAX_TOKENS=1000000 # set for 80GB A100 batchsize -NUM_WOKERS=0 - -run_command_prefix=' ' -# Loading folders -# 1. tsv files for audio paths -# DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv -DATA_DIR=${MAP_PROJ_DIR}/data/music4all_sh #audio_manifest -# 2. working folder for saving checkpoints and loading config files -CONFIG_DIR=/${MAP_PROJ_DIR}/mert_fairseq/config/pretrain -# 3. 
clustering labels for training data -LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset - -FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq; -SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/ - -# set 75 for the RVQ-VAE model -LABEL_RATE=75 - -case $YAML_NAME_WITHOUT_EXT in - MERT_RVQ-VAE_CQT_95M) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=1 - LABEL_RATE=75 - MAX_TOKENS=1800000 - ;; - MERT_RVQ-VAE_CQT_95M_bestrq) - TASK_LABELS_POSTFIX='["rq_0"]' - NNODS=1 - LABEL_RATE=75 - MAX_TOKENS=1200000 - ;; - MERT_RVQ-VAE_CQT_330M) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=1 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=720000 - ;; - MERT_RVQ-VAE_CQT_330M_multinodes) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=4 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=600000 - ;; - MERT_RVQ-VAE_CQT_330M_multinodes_debug2node) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=2 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=600000 - ;; - MERT_RVQ-VAE_CQT_330M_multinodes_debug1node) - TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' - NNODS=1 - LABEL_RATE=75 - NPROCES_PER_NODE=8 - MAX_TOKENS=600000 - ;; - *) - echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT}" - exit 1 - ;; - esac - - echo running $YAML_NAME_WITHOUT_EXT .. - - mkdir -p ${SAVE_DIR} - echo "checkpoint save at: ${SAVE_DIR}" - cd ${SAVE_DIR} - - DISTRIBUTED_WORLD_SIZE=`expr ${NNODS} \* ${NPROCES_PER_NODE}` - ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}` - echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}" - - DATE_SUFFIX=`date +"%Y-%m-%d_%H-%M"` - CKPT_SAVE_DIR="${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT}" - - OMP_NUM_THREADS=6 ${run_command_prefix} \ - python -u -m torch.distributed.launch --use_env \ - --nproc_per_node=8 --nnodes=${NNODS} --node_rank=${INDEX} \ - --master_addr=${CHIEF_IP} --master_port=25521 \ - ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py -m \ - --config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT}\ - common.user_dir=${MAP_PROJ_DIR}/mert_fairseq \ - common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS} \ - task.data=${DATA_DIR}\ - task.label_dir=${LABEL_DIR} \ - task.labels=${TASK_LABELS_POSTFIX} \ - dataset.num_workers=${NUM_WOKERS} \ - dataset.max_tokens=${MAX_TOKENS} \ - dataset.disable_validation=true \ - model.label_rate=${LABEL_RATE}\ - checkpoint.save_dir=${CKPT_SAVE_DIR} \ - checkpoint.restore_file="checkpoint_last.pt" - \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/test.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/test.py deleted file mode 100644 index 993b0597e8e2e8bcffafe90ad88f742b82550739..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/test.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from dataclasses import dataclass -from logging import getLogger -import torch.nn.functional as F -import fairseq.utils -from fairseq.checkpoint_utils import load_model_ensemble_and_task - -logger = 
getLogger(__name__) - -@dataclass -class UserDirModule: - user_dir: str - -def load_model(model_dir, checkpoint_dir): - '''Load Fairseq SSL model''' - - # import the code module where the model is located - model_path = UserDirModule(model_dir) - fairseq.utils.import_user_module(model_path) - - # load the model checkpoint - model, cfg, task = load_model_ensemble_and_task([checkpoint_dir], strict=False) - model = model[0] - - return model diff --git a/codeclm/tokenizer/Flow1dVAE/tools/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/codeclm/tokenizer/Flow1dVAE/tools/check_stereo.py b/codeclm/tokenizer/Flow1dVAE/tools/check_stereo.py deleted file mode 100644 index be7c2ff166a0f2c655c0b4cbbd110e8c469f2a37..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/check_stereo.py +++ /dev/null @@ -1,59 +0,0 @@ -''' -TEMPLATE = { - "path": "" - "duration": "" - "sample_rate": "" - "amplitude": null, - "weight": null, - "info_path": null -} -''' -import torchaudio -import json -from tqdm import tqdm - -import torchaudio -import numpy as np -import torch, torch.nn as nn, random -from torchaudio import transforms -import os -import argparse -from tqdm import tqdm -import torchaudio -from torchaudio.transforms import Resample -from multiprocessing import Pool - -def preprocess(args, wav_json, thread_id): - # f = open("pretrain_tme_20230927.scp").readlines() - f = open("out.{}".format(thread_id), 'w') - for line in tqdm(wav_json): - try: - # import pdb; pdb.set_trace() - line = line.strip() - wav_info = json.loads(line) - meta = torchaudio.info(wav_info["path"]) - - wav_info["num_channels"] = meta.num_channels - json_string = json.dumps(wav_info) - # print(json_string) - f.write("{}\n".format(json_string)) - except: - print(line) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser(description='Deep Speaker Embedding Inference') - parser.add_argument('--wav_json', type=str) - parser.add_argument('--num_thread', default=10, type=int, help='random seed') - args = parser.parse_args() - - wav_json_total = open(args.wav_json).readlines() - args.num_thread = min(len(wav_json_total), args.num_thread) - wav_json_list = np.array_split(wav_json_total, args.num_thread) - - p = Pool(args.num_thread) - for thread_id, wav_json in enumerate(wav_json_list): - r = p.apply_async(preprocess, (args, wav_json, thread_id)) - p.close() - p.join() - r.get() diff --git a/codeclm/tokenizer/Flow1dVAE/tools/compare_2models.py b/codeclm/tokenizer/Flow1dVAE/tools/compare_2models.py deleted file mode 100644 index 7dd0385c8cb8027a99311ff9a0f816bc672ff6a0..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/compare_2models.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch -import sys - -if __name__=="__main__": - m1, m2 = sys.argv[1:3] - m1 = torch.load(m1, map_location = 'cpu') - m2 = torch.load(m2, map_location = 'cpu') - m1_keys = set(m1.keys()) - m2_keys = set(m2.keys()) - - m1_uniq_keys = m1_keys - m2_keys - m2_uniq_keys = m2_keys - m1_keys - m12_shared_keys = m1_keys & m2_keys - - print("m1_uniq_keys: ", m1_uniq_keys) - print("m2_uniq_keys: ", m2_uniq_keys) - print("m12_shared_keys but different: ") - for k in m12_shared_keys: - if(m1[k].numel() != m2[k].numel()): - print(k,m1[k].shape,m2[k].shape) diff --git a/codeclm/tokenizer/Flow1dVAE/tools/creat_jsonl.py b/codeclm/tokenizer/Flow1dVAE/tools/creat_jsonl.py deleted file mode 100644 index 
7f8bb4ae22c431955a5be6e98edef64a1844ba4b..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/creat_jsonl.py +++ /dev/null @@ -1,71 +0,0 @@ -''' -TEMPLATE = { - "path": "" - "duration": "" - "sample_rate": "" - "amplitude": null, - "weight": null, - "info_path": null -} -''' -import torchaudio -import json -from tqdm import tqdm - -import torchaudio -import numpy as np -import torch, torch.nn as nn, random -from torchaudio import transforms -import os -import argparse -from tqdm import tqdm -import torchaudio -from torchaudio.transforms import Resample -from multiprocessing import Pool - -def preprocess(args, wav_scp, thread_id): - # f = open("pretrain_tme_20230927.scp").readlines() - f = open("out.{}".format(thread_id), 'w') - for line in tqdm(wav_scp): - try: - # import pdb; pdb.set_trace() - line = line.strip() - meta = torchaudio.info(line) - duration = meta.num_frames / float(meta.sample_rate) - sr = meta.sample_rate - - # json_path = line.replace(".flac", ".json") - # with open(json_path, encoding='utf-8') as fh: - # data = json.load(fh) - # duration = data['duration'] - wav_info = { - "path": line, - "duration": duration, - "sample_rate": sr, - "amplitude": None, - "weight": None, - "info_path": None - } - json_string = json.dumps(wav_info) - # print(json_string) - f.write("{}\n".format(json_string)) - except: - print(line) - -if __name__ == "__main__": - - parser = argparse.ArgumentParser(description='Deep Speaker Embedding Inference') - parser.add_argument('--wav_scp', type=str) - parser.add_argument('--num_thread', default=10, type=int, help='random seed') - args = parser.parse_args() - - wav_scp_total = open(args.wav_scp).readlines() - args.num_thread = min(len(wav_scp_total), args.num_thread) - wav_scp_list = np.array_split(wav_scp_total, args.num_thread) - - p = Pool(args.num_thread) - for thread_id, wav_scp in enumerate(wav_scp_list): - r = p.apply_async(preprocess, (args, wav_scp, thread_id)) - p.close() - p.join() - r.get() diff --git a/codeclm/tokenizer/Flow1dVAE/tools/extract_rvq.py b/codeclm/tokenizer/Flow1dVAE/tools/extract_rvq.py deleted file mode 100644 index 6a58f76ae5b7c8a3a51751a499d115b1caa40214..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/extract_rvq.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -import sys - -if __name__=="__main__": - p = sys.argv[1] - bd = '/'.join(p.split('/')[:-1]) - bn = p.split('/')[-1] - - d = {} - m = torch.load(p, map_location='cpu') - for k in m.keys(): - if('rvq' in k): - d[k] = m[k] - - torch.save(d, '{}/rvq.bin'.format(bd)) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae.py b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae.py deleted file mode 100644 index ef7af1b8470bc18050db3e0091cea93214bafbbb..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -from tqdm import tqdm -import torchaudio -from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import create_autoencoder_from_config -import numpy as np -import os -import json - -def get_model(model_config, path): - with open(model_config) as f: - model_config = json.load(f) - state_dict = torch.load(path) - model = create_autoencoder_from_config(model_config) - model.load_state_dict(state_dict['state_dict']) - return model \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_1920.py b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_1920.py deleted 
file mode 100644 index ef7af1b8470bc18050db3e0091cea93214bafbbb..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_1920.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -from tqdm import tqdm -import torchaudio -from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import create_autoencoder_from_config -import numpy as np -import os -import json - -def get_model(model_config, path): - with open(model_config) as f: - model_config = json.load(f) - state_dict = torch.load(path) - model = create_autoencoder_from_config(model_config) - model.load_state_dict(state_dict['state_dict']) - return model \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large.py b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large.py old mode 100644 new mode 100755 index a80fcec2c9e579f415d274dc46a0447e4c6477ee..1dcbec313f68c7ec97140b239a30f2587e4a5512 --- a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large.py +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large.py @@ -1,9 +1,5 @@ import torch -from tqdm import tqdm -import torchaudio from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import create_autoencoder_from_config -import numpy as np -import os import json def get_model(model_config, path): diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large_melvae.py b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large_melvae.py deleted file mode 100644 index 86e701e994151e18a5c824e9c9feb0d575837f20..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large_melvae.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch -from tqdm import tqdm -import torchaudio -from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import create_autoencoder_from_config -import numpy as np -import os -import json - -def get_model(model_config, path): - with open(model_config) as f: - model_config = json.load(f) - state_dict = torch.load(path, map_location='cpu') - model = create_autoencoder_from_config(model_config) - model.load_state_dict(state_dict['state_dict'], strict=False) - return model \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae.py b/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae.py deleted file mode 100644 index a850487b5219187a5335c760e1ed0220570798a6..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae.py +++ /dev/null @@ -1,395 +0,0 @@ -"""! -@author Yi Luo (oulyluo) -@copyright Tencent AI Lab -""" - -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from torch.utils.checkpoint import checkpoint_sequential -from thop import profile, clever_format - -class RMVN(nn.Module): - """ - Rescaled MVN. 
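get_1dvae.py, get_1dvae_1920.py, and get_1dvae_large_melvae.py (all deleted above) are byte-for-byte the same loader except for `map_location` and `load_state_dict` strictness. A single parameterized version covering all three, assuming the same `third_party` import path the originals use:

```python
import json

import torch
from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import (
    create_autoencoder_from_config,
)

def get_model(model_config_path, ckpt_path, map_location=None, strict=True):
    # map_location='cpu', strict=False reproduces the melvae variant;
    # the defaults reproduce the other two deleted files.
    with open(model_config_path) as f:
        model_config = json.load(f)
    state_dict = torch.load(ckpt_path, map_location=map_location)
    model = create_autoencoder_from_config(model_config)
    model.load_state_dict(state_dict["state_dict"], strict=strict)
    return model
```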
- """ - def __init__(self, dimension, groups=1): - super(RMVN, self).__init__() - - self.mean = nn.Parameter(torch.zeros(dimension)) - self.std = nn.Parameter(torch.ones(dimension)) - self.groups = groups - self.eps = torch.finfo(torch.float32).eps - - def forward(self, input): - # input size: (B, N, T) - B, N, T = input.shape - assert N % self.groups == 0 - - input = input.view(B, self.groups, -1, T) - input_norm = (input - input.mean(2).unsqueeze(2)) / (input.var(2).unsqueeze(2) + self.eps).sqrt() - input_norm = input_norm.view(B, N, T) * self.std.view(1, -1, 1) + self.mean.view(1, -1, 1) - - return input_norm - -class ConvActNorm1d(nn.Module): - def __init__(self, in_channel, hidden_channel, kernel=7, causal=False): - super(ConvActNorm1d, self).__init__() - - self.in_channel = in_channel - self.kernel = kernel - self.causal = causal - if not causal: - self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel-1)//2), - RMVN(in_channel), - nn.Conv1d(in_channel, hidden_channel*2, 1), - nn.GLU(dim=1), - nn.Conv1d(hidden_channel, in_channel, 1) - ) - else: - self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel-1), - RMVN(in_channel), - nn.Conv1d(in_channel, hidden_channel*2, 1), - nn.GLU(dim=1), - nn.Conv1d(hidden_channel, in_channel, 1) - ) - - def forward(self, input): - - output = self.conv(input) - if self.causal: - output = output[...,:-self.kernel+1].contiguous() - return input + output - -class ICB(nn.Module): - def __init__(self, in_channel, kernel=7, causal=False): - super(ICB, self).__init__() - - self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), - ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), - ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal) - ) - - def forward(self, input): - - return self.blocks(input) - -class ResRNN(nn.Module): - def __init__(self, input_size, hidden_size, bidirectional=False): - super(ResRNN, self).__init__() - - self.input_size = input_size - self.hidden_size = hidden_size - self.eps = torch.finfo(torch.float32).eps - - self.norm = RMVN(input_size) - self.rnn = nn.LSTM(input_size, hidden_size, 1, batch_first=True, bidirectional=bidirectional) - - self.proj = nn.Linear(hidden_size*(int(bidirectional)+1), input_size) - - def forward(self, input, use_head=1): - # input shape: batch, dim, seq - - B, N, T = input.shape - - rnn_output, _ = self.rnn(self.norm(input).transpose(1,2).contiguous()) - - output = self.proj(rnn_output.contiguous().view(-1, rnn_output.shape[2])) - output = output.view(B, T, -1).transpose(1,2).contiguous() - - return input + output - -class BSNet(nn.Module): - def __init__(self, feature_dim, kernel=7, causal=False): - super(BSNet, self).__init__() - - self.feature_dim = feature_dim - - self.seq_net = ICB(self.feature_dim, kernel=kernel, causal=causal) - self.band_net = ResRNN(self.feature_dim, self.feature_dim*2, bidirectional=True) - - def forward(self, input): - # input shape: B, nband, N, T - - B, nband, N, T = input.shape - - band_output = self.seq_net(input.view(B*nband, N, T)).view(B, nband, -1, T) - - # band comm - band_output = band_output.permute(0,3,2,1).contiguous().view(B*T, -1, nband) - output = self.band_net(band_output).view(B, T, -1, nband).permute(0,3,2,1).contiguous() - - return output.view(B, nband, N, T) - -# https://github.com/bshall/VectorQuantizedVAE/blob/master/model.py -class VQEmbeddingEMA(nn.Module): - def __init__(self, num_code, code_dim, decay=0.99, layer=0): - super(VQEmbeddingEMA, 
self).__init__() - - self.num_code = num_code - self.code_dim = code_dim - self.decay = decay - self.layer = layer - self.stale_tolerance = 100 - self.eps = torch.finfo(torch.float32).eps - - embedding = torch.empty(num_code, code_dim).normal_() / ((layer+1) * code_dim) - self.register_buffer("embedding", embedding) - self.register_buffer("ema_weight", self.embedding.clone()) - self.register_buffer("ema_count", torch.zeros(self.num_code)) - self.register_buffer("stale_counter", torch.zeros(self.num_code)) - - def forward(self, input): - - B, N, T = input.shape - assert N == self.code_dim - - input_detach = input.detach().mT.contiguous().view(B*T, N) # B*T, dim - - # distance - eu_dis = input_detach.pow(2).sum(-1).unsqueeze(-1) + self.embedding.pow(2).sum(-1).unsqueeze(0) # B*T, num_code - eu_dis = eu_dis - 2 * input_detach.mm(self.embedding.T) # B*T, num_code - - # best codes - indices = torch.argmin(eu_dis, dim=-1) # B*T - quantized = torch.gather(self.embedding, 0, indices.unsqueeze(-1).expand(-1, self.code_dim)) # B*T, dim - quantized = quantized.view(B, T, N).mT.contiguous() # B, N, T - - # calculate perplexity - encodings = F.one_hot(indices, self.num_code).float() # B*T, num_code - avg_probs = encodings.mean(0) # num_code - perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)).mean() - indices = indices.view(B, T) - - if self.training: - # EMA update for codebook - - self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0) # num_code - - update_direction = encodings.T.mm(input_detach) # num_code, dim - self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * update_direction # num_code, dim - - # Laplace smoothing on the counters - # make sure the denominator will never be zero - n = torch.sum(self.ema_count, dim=-1, keepdim=True) # 1 - self.ema_count = (self.ema_count + self.eps) / (n + self.num_code * self.eps) * n # num_code - - self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1) - - # calculate code usage - stale_codes = (encodings.sum(0) == 0).float() # num_code - self.stale_counter = self.stale_counter * stale_codes + stale_codes - - # random replace codes that haven't been used for a while - replace_code = (self.stale_counter == self.stale_tolerance).float() # num_code - if replace_code.sum(-1).max() > 0: - random_input_idx = torch.randperm(input_detach.shape[0]) - random_input = input_detach[random_input_idx].view(input_detach.shape) - if random_input.shape[0] < self.num_code: - random_input = torch.cat([random_input]*(self.num_code // random_input.shape[0] + 1), 0) - random_input = random_input[:self.num_code].contiguous() # num_code, dim - - self.embedding = self.embedding * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) - self.ema_weight = self.ema_weight * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) - self.ema_count = self.ema_count * (1 - replace_code) - self.stale_counter = self.stale_counter * (1 - replace_code) - - return quantized, indices, perplexity - -class RVQEmbedding(nn.Module): - def __init__(self, code_dim, decay=0.99, bit=[10]): - super(RVQEmbedding, self).__init__() - - self.code_dim = code_dim - self.decay = decay - self.eps = torch.finfo(torch.float32).eps - - self.VQEmbedding = nn.ModuleList([]) - for i in range(len(bit)): - self.VQEmbedding.append(VQEmbeddingEMA(2**bit[i], code_dim, decay, layer=i)) - - def forward(self, input): - quantized = [] - indices = [] - ppl = [] - - residual_input = input - for 
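For readers skimming the deleted VQEmbeddingEMA above: the core of its forward pass is a nearest-neighbour lookup with the squared distance expanded as |x|² + |e|² − 2·x·e, followed by a gather from the codebook. A standalone sketch of just that step:

```python
import torch
import torch.nn.functional as F

def nearest_codes(x, embedding):
    # x: (M, D) flattened inputs; embedding: (K, D) codebook.
    d = x.pow(2).sum(-1, keepdim=True) + embedding.pow(2).sum(-1) - 2 * x @ embedding.T
    idx = d.argmin(-1)  # (M,)
    return F.embedding(idx, embedding), idx

x = torch.randn(8, 16)
codebook = torch.randn(512, 16)
quantized, indices = nearest_codes(x, codebook)
print(quantized.shape, indices.shape)  # torch.Size([8, 16]) torch.Size([8])
```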
i in range(len(self.VQEmbedding)): - this_quantized, this_indices, this_perplexity = self.VQEmbedding[i](residual_input) - indices.append(this_indices) - ppl.append(this_perplexity) - residual_input = residual_input - this_quantized - if i == 0: - quantized.append(this_quantized) - else: - quantized.append(quantized[-1] + this_quantized) - - quantized = torch.stack(quantized, -1) - indices = torch.stack(indices, -1) - ppl = torch.stack(ppl, -1) - latent_loss = 0 - for i in range(quantized.shape[-1]): - latent_loss = latent_loss + F.mse_loss(input, quantized.detach()[...,i]) - - return quantized, indices, ppl, latent_loss - -class Codec(nn.Module): - def __init__(self, nch=1, sr=44100, win=100, feature_dim=128, vae_dim=2, enc_layer=12, dec_layer=12, bit=[8]*5, causal=True): - super(Codec, self).__init__() - - self.nch = nch - self.sr = sr - self.win = int(sr / 1000 * win) - self.stride = self.win // 2 - self.enc_dim = self.win // 2 + 1 - self.feature_dim = feature_dim - self.vae_dim = vae_dim - self.bit = bit - self.eps = torch.finfo(torch.float32).eps - - # 0-1k (50 hop), 1k-4k (100 hop), 4k-8k (200 hop), 8k-12k (400 hop), 12k-22k (500 hop) - # 100 bands - bandwidth_50 = int(np.floor(50 / (sr / 2.) * self.enc_dim)) - bandwidth_100 = int(np.floor(100 / (sr / 2.) * self.enc_dim)) - bandwidth_200 = int(np.floor(200 / (sr / 2.) * self.enc_dim)) - bandwidth_400 = int(np.floor(400 / (sr / 2.) * self.enc_dim)) - bandwidth_500 = int(np.floor(500 / (sr / 2.) * self.enc_dim)) - self.band_width = [bandwidth_50]*20 - self.band_width += [bandwidth_100]*30 - self.band_width += [bandwidth_200]*20 - self.band_width += [bandwidth_400]*10 - self.band_width += [bandwidth_500]*19 - self.band_width.append(self.enc_dim - np.sum(self.band_width)) - self.nband = len(self.band_width) - print(self.band_width, self.nband) - - self.VAE_BN = nn.ModuleList([]) - for i in range(self.nband): - self.VAE_BN.append(nn.Sequential(RMVN((self.band_width[i]*2+1)*self.nch), - nn.Conv1d(((self.band_width[i]*2+1)*self.nch), self.feature_dim, 1)) - ) - - self.VAE_encoder = [] - for _ in range(enc_layer): - self.VAE_encoder.append(BSNet(self.feature_dim, kernel=7, causal=causal)) - self.VAE_encoder = nn.Sequential(*self.VAE_encoder) - - self.vae_FC = nn.Sequential(RMVN(self.nband*self.feature_dim, groups=self.nband), - nn.Conv1d(self.nband*self.feature_dim, self.nband*self.vae_dim*2, 1, groups=self.nband) - ) - self.codebook = RVQEmbedding(self.nband*self.vae_dim*2, bit=bit) - self.vae_reshape = nn.Conv1d(self.nband*self.vae_dim, self.nband*self.feature_dim, 1, groups=self.nband) - - self.VAE_decoder = [] - for _ in range(dec_layer): - self.VAE_decoder.append(BSNet(self.feature_dim, kernel=7, causal=causal)) - self.VAE_decoder = nn.Sequential(*self.VAE_decoder) - - self.VAE_output = nn.ModuleList([]) - for i in range(self.nband): - self.VAE_output.append(nn.Sequential(RMVN(self.feature_dim), - nn.Conv1d(self.feature_dim, self.band_width[i]*4*self.nch, 1), - nn.GLU(dim=1)) - ) - - def spec_band_split(self, input): - - B, nch, nsample = input.shape - - spec = torch.stft(input.view(B*nch, nsample).float(), n_fft=self.win, hop_length=self.stride, - window=torch.hann_window(self.win).to(input.device), return_complex=True) - - subband_spec = [] - subband_spec_norm = [] - subband_power = [] - band_idx = 0 - for i in range(self.nband): - this_spec = spec[:,band_idx:band_idx+self.band_width[i]] - subband_spec.append(this_spec) # B, BW, T - subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)) # B, 1, T - 
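RVQEmbedding's forward above is plain residual quantization: stage i quantizes what stages 0..i-1 left unexplained, and cumulative sums are stacked so taking the first i codebooks gives a progressively better reconstruction. A toy sketch with random codebooks:

```python
import torch

def quantize(x, codebook):
    d = x.pow(2).sum(-1, keepdim=True) + codebook.pow(2).sum(-1) - 2 * x @ codebook.T
    return codebook[d.argmin(-1)]

x = torch.randn(8, 16)
codebooks = [torch.randn(256, 16) for _ in range(5)]
residual, acc, stages = x, torch.zeros_like(x), []
for cb in codebooks:
    q = quantize(residual, cb)
    residual = residual - q   # next stage sees only the leftover error
    acc = acc + q             # cumulative reconstruction through this stage
    stages.append(acc.clone())
print(torch.stack(stages, -1).shape)  # torch.Size([8, 16, 5])
```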
subband_spec_norm.append([this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1]]) # B, BW, T - band_idx += self.band_width[i] - subband_power = torch.cat(subband_power, 1) # B, nband, T - - return subband_spec, subband_spec_norm, subband_power - - def feature_extractor(self, input): - - _, subband_spec_norm, subband_power = self.spec_band_split(input) - - # normalization and bottleneck - subband_feature = [] - for i in range(self.nband): - concat_spec = torch.cat([subband_spec_norm[i][0], subband_spec_norm[i][1], torch.log(subband_power[:,i].unsqueeze(1))], 1) - concat_spec = concat_spec.view(-1, (self.band_width[i]*2+1)*self.nch, concat_spec.shape[-1]) - subband_feature.append(self.VAE_BN[i](concat_spec.type(input.type()))) - subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T - - return subband_feature - - def vae_sample(self, input): - - B, nch, _ = input.shape - - subband_feature = self.feature_extractor(input) - - # encode - enc_output = checkpoint_sequential(self.VAE_encoder, len(self.VAE_encoder), subband_feature) - enc_output = self.vae_FC(enc_output.view(B, self.nband*self.feature_dim, -1)).view(B, self.nband, 2, self.vae_dim, -1) - mu = enc_output[:,:,0].contiguous() - logvar = enc_output[:,:,1].contiguous() - - # vae - reparam_feature = mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar) - vae_loss = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(2)).mean() - - # quantization - mu_var = torch.stack([mu, logvar], 1).view(B, self.nband*self.vae_dim*2, -1) - quantized_emb, indices, ppl, latent_loss = self.codebook(mu_var.detach()) - - return reparam_feature, quantized_emb, mu_var, indices, ppl, latent_loss, vae_loss - - def vae_decode(self, vae_feature, nsample=None): - B = vae_feature.shape[0] - dec_input = self.vae_reshape(vae_feature.contiguous().view(B, self.nband*self.vae_dim, -1)) - output = checkpoint_sequential(self.VAE_decoder, len(self.VAE_decoder), dec_input.view(B, self.nband, self.feature_dim, -1)) - - est_spec = [] - for i in range(self.nband): - this_RI = self.VAE_output[i](output[:,i]).view(B*self.nch, 2, self.band_width[i], -1) - est_spec.append(torch.complex(this_RI[:,0].float(), this_RI[:,1].float())) - est_spec = torch.cat(est_spec, 1) - if nsample is not None: - output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride, - window=torch.hann_window(self.win).to(vae_feature.device), length=nsample).view(B, self.nch, -1) - else: - output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride, - window=torch.hann_window(self.win).to(vae_feature.device)).view(B, self.nch, -1) - - return output.type(vae_feature.type()) - - def forward(self, input): - - B, nch, nsample = input.shape - assert nch == self.nch - - vae_feature, quantized_emb, mu_var, indices, ppl, latent_loss, vae_loss = self.vae_sample(input) - output = self.vae_decode(vae_feature, nsample=nsample).view(input.shape) - - - return output # , vae_feature, quantized_emb, mu_var, indices, ppl, latent_loss, vae_loss - -def get_bsrnnvae(ckpt): - nch = 1 - model = Codec(nch = nch, \ - win = 100, \ - feature_dim = 128, \ - vae_dim = 8, \ - bit = [14]*5, \ - causal = True) - weight = torch.load(ckpt, map_location='cpu') - model.load_state_dict(weight) - return model.eval() diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae_old.py b/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae_old.py deleted file mode 100644 index 010606804f85edcc3fc43850c0ee3e943080218d..0000000000000000000000000000000000000000 --- 
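The "100 bands" comment in the deleted Codec above checks out numerically: a 100 ms window at 44.1 kHz gives a 4410-sample FFT with 2206 rfft bins, and the 50/100/200/400/500 Hz band plan plus one remainder band tiles them exactly. Verification sketch:

```python
import numpy as np

sr, win_ms = 44100, 100
win = int(sr / 1000 * win_ms)  # 4410-sample window
enc_dim = win // 2 + 1         # 2206 rfft bins

def hz_to_bins(hz):
    return int(np.floor(hz / (sr / 2.0) * enc_dim))

band_width = ([hz_to_bins(50)] * 20 + [hz_to_bins(100)] * 30
              + [hz_to_bins(200)] * 20 + [hz_to_bins(400)] * 10
              + [hz_to_bins(500)] * 19)
band_width.append(enc_dim - int(np.sum(band_width)))  # remainder band
assert sum(band_width) == enc_dim and len(band_width) == 100
```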
a/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae_old.py +++ /dev/null @@ -1,427 +0,0 @@ -"""! -@author Yi Luo (oulyluo) -@copyright Tencent AI Lab -""" - -from __future__ import print_function - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from torch.utils.checkpoint import checkpoint_sequential -from thop import profile, clever_format - -class RMVN(nn.Module): - """ - Rescaled MVN. - """ - def __init__(self, dimension, groups=1): - super(RMVN, self).__init__() - - self.mean = nn.Parameter(torch.zeros(dimension)) - self.std = nn.Parameter(torch.ones(dimension)) - self.groups = groups - self.eps = torch.finfo(torch.float32).eps - - def forward(self, input): - # input size: (B, N, T) - B, N, T = input.shape - assert N % self.groups == 0 - - input = input.view(B, self.groups, -1, T) - input_norm = (input - input.mean(2).unsqueeze(2)) / (input.var(2).unsqueeze(2) + self.eps).sqrt() - input_norm = input_norm.view(B, N, T) * self.std.view(1, -1, 1) + self.mean.view(1, -1, 1) - - return input_norm - -class ConvActNorm1d(nn.Module): - def __init__(self, in_channel, hidden_channel, kernel=7, causal=False): - super(ConvActNorm1d, self).__init__() - - self.in_channel = in_channel - self.kernel = kernel - self.causal = causal - if not causal: - self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel-1)//2), - RMVN(in_channel), - nn.Conv1d(in_channel, hidden_channel*2, 1), - nn.GLU(dim=1), - nn.Conv1d(hidden_channel, in_channel, 1) - ) - else: - self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel-1), - RMVN(in_channel), - nn.Conv1d(in_channel, hidden_channel*2, 1), - nn.GLU(dim=1), - nn.Conv1d(hidden_channel, in_channel, 1) - ) - - def forward(self, input): - - output = self.conv(input) - if self.causal: - output = output[...,:-self.kernel+1].contiguous() - return input + output - -class ICB(nn.Module): - def __init__(self, in_channel, kernel=7, causal=False): - super(ICB, self).__init__() - - self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), - ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), - ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal) - ) - - def forward(self, input): - - return self.blocks(input) - -class ResRNN(nn.Module): - def __init__(self, input_size, hidden_size, bidirectional=False): - super(ResRNN, self).__init__() - - self.input_size = input_size - self.hidden_size = hidden_size - self.eps = torch.finfo(torch.float32).eps - - self.norm = RMVN(input_size) - self.rnn = nn.LSTM(input_size, hidden_size, 1, batch_first=True, bidirectional=bidirectional) - - self.proj = nn.Linear(hidden_size*(int(bidirectional)+1), input_size) - - def forward(self, input, use_head=1): - # input shape: batch, dim, seq - - B, N, T = input.shape - - rnn_output, _ = self.rnn(self.norm(input).transpose(1,2).contiguous()) - - output = self.proj(rnn_output.contiguous().view(-1, rnn_output.shape[2])) - output = output.view(B, T, -1).transpose(1,2).contiguous() - - return input + output - -class BSNet(nn.Module): - def __init__(self, feature_dim, kernel=7, causal=False): - super(BSNet, self).__init__() - - self.feature_dim = feature_dim - - self.seq_net = ICB(self.feature_dim, kernel=kernel, causal=causal) - self.band_net = ResRNN(self.feature_dim, self.feature_dim*2, bidirectional=True) - - def forward(self, input): - # input shape: B, nband, N, T - - B, nband, N, T = input.shape - - band_output = self.seq_net(input.view(B*nband, N, 
T)).view(B, nband, -1, T) - - # band comm - band_output = band_output.permute(0,3,2,1).contiguous().view(B*T, -1, nband) - output = self.band_net(band_output).view(B, T, -1, nband).permute(0,3,2,1).contiguous() - - return output.view(B, nband, N, T) - -# https://github.com/bshall/VectorQuantizedVAE/blob/master/model.py -class VQEmbeddingEMA(nn.Module): - def __init__(self, num_code, code_dim, decay=0.99, layer=0): - super(VQEmbeddingEMA, self).__init__() - - self.num_code = num_code - self.code_dim = code_dim - self.decay = decay - self.layer = layer - self.stale_tolerance = 100 - self.eps = torch.finfo(torch.float32).eps - - embedding = torch.empty(num_code, code_dim).normal_() / ((layer+1) * code_dim) - self.register_buffer("embedding", embedding) - self.register_buffer("ema_weight", self.embedding.clone()) - self.register_buffer("ema_count", torch.zeros(self.num_code)) - self.register_buffer("stale_counter", torch.zeros(self.num_code)) - - def forward(self, input): - - B, N, T = input.shape - assert N == self.code_dim - - input_detach = input.detach().mT.contiguous().view(B*T, N) # B*T, dim - - # distance - eu_dis = input_detach.pow(2).sum(-1).unsqueeze(-1) + self.embedding.pow(2).sum(-1).unsqueeze(0) # B*T, num_code - eu_dis = eu_dis - 2 * input_detach.mm(self.embedding.T) # B*T, num_code - - # best codes - indices = torch.argmin(eu_dis, dim=-1) # B*T - quantized = torch.gather(self.embedding, 0, indices.unsqueeze(-1).expand(-1, self.code_dim)) # B*T, dim - quantized = quantized.view(B, T, N).mT.contiguous() # B, N, T - - # calculate perplexity - encodings = F.one_hot(indices, self.num_code).float() # B*T, num_code - avg_probs = encodings.mean(0) # num_code - perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)).mean() - indices = indices.view(B, T) - - if self.training: - # EMA update for codebook - - self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0) # num_code - - update_direction = encodings.T.mm(input_detach) # num_code, dim - self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * update_direction # num_code, dim - - # Laplace smoothing on the counters - # make sure the denominator will never be zero - n = torch.sum(self.ema_count, dim=-1, keepdim=True) # 1 - self.ema_count = (self.ema_count + self.eps) / (n + self.num_code * self.eps) * n # num_code - - self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1) - - # calculate code usage - stale_codes = (encodings.sum(0) == 0).float() # num_code - self.stale_counter = self.stale_counter * stale_codes + stale_codes - - # random replace codes that haven't been used for a while - replace_code = (self.stale_counter == self.stale_tolerance).float() # num_code - if replace_code.sum(-1).max() > 0: - random_input_idx = torch.randperm(input_detach.shape[0]) - random_input = input_detach[random_input_idx].view(input_detach.shape) - if random_input.shape[0] < self.num_code: - random_input = torch.cat([random_input]*(self.num_code // random_input.shape[0] + 1), 0) - random_input = random_input[:self.num_code].contiguous() # num_code, dim - - self.embedding = self.embedding * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) - self.ema_weight = self.ema_weight * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) - self.ema_count = self.ema_count * (1 - replace_code) - self.stale_counter = self.stale_counter * (1 - replace_code) - - return quantized, indices, perplexity - -class 
RVQEmbedding(nn.Module): - def __init__(self, code_dim, decay=0.99, bit=[10]): - super(RVQEmbedding, self).__init__() - - self.code_dim = code_dim - self.decay = decay - self.eps = torch.finfo(torch.float32).eps - - self.VQEmbedding = nn.ModuleList([]) - for i in range(len(bit)): - self.VQEmbedding.append(VQEmbeddingEMA(2**bit[i], code_dim, decay, layer=i)) - - def forward(self, input): - quantized = [] - indices = [] - ppl = [] - - residual_input = input - for i in range(len(self.VQEmbedding)): - this_quantized, this_indices, this_perplexity = self.VQEmbedding[i](residual_input) - indices.append(this_indices) - ppl.append(this_perplexity) - residual_input = residual_input - this_quantized - if i == 0: - quantized.append(this_quantized) - else: - quantized.append(quantized[-1] + this_quantized) - - quantized = torch.stack(quantized, -1) - indices = torch.stack(indices, -1) - ppl = torch.stack(ppl, -1) - latent_loss = 0 - for i in range(quantized.shape[-1]): - latent_loss = latent_loss + F.mse_loss(input, quantized.detach()[...,i]) - - return quantized, indices, ppl, latent_loss - -class Codec(nn.Module): - def __init__(self, nch=1, sr=44100, win=80, feature_dim=128, vae_dim=2, enc_layer=12, dec_layer=12, bit=[8]*5, causal=False): - super(Codec, self).__init__() - - self.nch = nch - self.sr = sr - self.win = int(sr / 1000 * win) - self.stride = self.win // 2 - self.enc_dim = self.win // 2 + 1 - self.feature_dim = feature_dim - self.vae_dim = vae_dim - self.bit = bit - self.eps = torch.finfo(torch.float32).eps - - # 0-1k (50 hop), 1k-2k (100 hop), 2k-4k (250 hop), 4k-8k (500 hop), 8k-12k (1k hop), 12k-20k (2k hop), 20k-inf - # 55 bands - bandwidth_50 = int(np.floor(50 / (sr / 2.) * self.enc_dim)) - bandwidth_100 = int(np.floor(100 / (sr / 2.) * self.enc_dim)) - bandwidth_250 = int(np.floor(250 / (sr / 2.) * self.enc_dim)) - bandwidth_500 = int(np.floor(500 / (sr / 2.) * self.enc_dim)) - bandwidth_1k = int(np.floor(1000 / (sr / 2.) * self.enc_dim)) - bandwidth_2k = int(np.floor(2000 / (sr / 2.) 
* self.enc_dim)) - self.band_width = [bandwidth_50]*20 - self.band_width += [bandwidth_100]*10 - self.band_width += [bandwidth_250]*8 - self.band_width += [bandwidth_500]*8 - self.band_width += [bandwidth_1k]*4 - self.band_width += [bandwidth_2k]*4 - self.band_width.append(self.enc_dim - np.sum(self.band_width)) - self.nband = len(self.band_width) - print(self.band_width, self.nband) - - self.VAE_BN = nn.ModuleList([]) - for i in range(self.nband): - self.VAE_BN.append(nn.Sequential(RMVN((self.band_width[i]*2+1)*self.nch), - nn.Conv1d(((self.band_width[i]*2+1)*self.nch), self.feature_dim, 1)) - ) - - self.VAE_encoder = [] - for _ in range(enc_layer): - self.VAE_encoder.append(BSNet(self.feature_dim, kernel=7, causal=causal)) - self.VAE_encoder = nn.Sequential(*self.VAE_encoder) - - self.vae_FC = nn.Sequential(RMVN(self.nband*self.feature_dim, groups=self.nband), - nn.Conv1d(self.nband*self.feature_dim, self.nband*self.vae_dim*2, 1, groups=self.nband) - ) - self.codebook = RVQEmbedding(self.nband*self.vae_dim*2, bit=bit) - self.vae_reshape = nn.Conv1d(self.nband*self.vae_dim, self.nband*self.feature_dim, 1, groups=self.nband) - - self.VAE_decoder = [] - for _ in range(dec_layer): - self.VAE_decoder.append(BSNet(self.feature_dim, kernel=7, causal=causal)) - self.VAE_decoder = nn.Sequential(*self.VAE_decoder) - - self.VAE_output = nn.ModuleList([]) - for i in range(self.nband): - self.VAE_output.append(nn.Sequential(RMVN(self.feature_dim), - nn.Conv1d(self.feature_dim, self.band_width[i]*4*self.nch, 1), - nn.GLU(dim=1)) - ) - - def spec_band_split(self, input): - - B, nch, nsample = input.shape - - spec = torch.stft(input.view(B*nch, nsample).float(), n_fft=self.win, hop_length=self.stride, - window=torch.hann_window(self.win).to(input.device), return_complex=True) - - subband_spec = [] - subband_spec_norm = [] - subband_power = [] - band_idx = 0 - for i in range(self.nband): - this_spec = spec[:,band_idx:band_idx+self.band_width[i]] - subband_spec.append(this_spec) # B, BW, T - subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)) # B, 1, T - subband_spec_norm.append([this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1]]) # B, BW, T - band_idx += self.band_width[i] - subband_power = torch.cat(subband_power, 1) # B, nband, T - - return subband_spec, subband_spec_norm, subband_power - - def feature_extractor(self, input): - - _, subband_spec_norm, subband_power = self.spec_band_split(input) - - # normalization and bottleneck - subband_feature = [] - for i in range(self.nband): - concat_spec = torch.cat([subband_spec_norm[i][0], subband_spec_norm[i][1], torch.log(subband_power[:,i].unsqueeze(1))], 1) - concat_spec = concat_spec.view(-1, (self.band_width[i]*2+1)*self.nch, concat_spec.shape[-1]) - subband_feature.append(self.VAE_BN[i](concat_spec.type(input.type()))) - subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T - - return subband_feature - - def vae_sample(self, input): - - B, nch, _ = input.shape - - subband_feature = self.feature_extractor(input) - - # encode - enc_output = checkpoint_sequential(self.VAE_encoder, len(self.VAE_encoder), subband_feature) - enc_output = self.vae_FC(enc_output.view(B, self.nband*self.feature_dim, -1)).view(B, self.nband, 2, self.vae_dim, -1) - mu = enc_output[:,:,0].contiguous() - logvar = enc_output[:,:,1].contiguous() - - # vae - reparam_feature = mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar) - - return reparam_feature.view(B, nch, self.nband, self.vae_dim, -1) - - def 
vae_decode(self, vae_feature): - B = vae_feature.shape[0] - dec_input = self.vae_reshape(vae_feature.contiguous().view(B, self.nband*self.vae_dim, -1)) - output = checkpoint_sequential(self.VAE_decoder, len(self.VAE_decoder), dec_input.view(B, self.nband, self.feature_dim, -1)) - - est_spec = [] - for i in range(self.nband): - this_RI = self.VAE_output[i](output[:,i]).view(B*self.nch, 2, self.band_width[i], -1) - est_spec.append(torch.complex(this_RI[:,0].float(), this_RI[:,1].float())) - est_spec = torch.cat(est_spec, 1) - - output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride, - window=torch.hann_window(self.win).to(vae_feature.device)).view(B, self.nch, -1) - - return output.type(vae_feature.type()) - - def forward(self, input): - - B, nch, nsample = input.shape - assert nch == self.nch - - vae_feature = self.vae_sample(input) - output = self.vae_decode(vae_feature).view(B, nch, -1) - if(output.shape[-1] > nsample): - output = output[:,:,0:nsample] - elif(output.shape[-1] < nsample): - output = torch.cat([output, torch.zeros(B, nch, nsample - output.shape[-1], device= output.device, dtype=output.dtype)],-1) - - return output - - def encode(self, input, do_sample=True): - assert do_sample, do_sample - B, nch, nsample = input.shape - assert nch == self.nch - - vae_feature = self.vae_sample(input) - return vae_feature - -def get_bsrnnvae(ckpt): - nch = 1 - model = Codec(nch = nch, \ - win = 100, \ - feature_dim = 128, \ - vae_dim = 2, \ - bit = [14]*5, \ - causal = True) - weight = torch.load(ckpt, map_location='cpu') - model.load_state_dict(weight) - return model.eval() - -if __name__ == '__main__': - model = Codec(causal=True) - x = torch.empty(1, 1, 44100).uniform_(-1, 1) - - s = 0 - for param in model.parameters(): - s += np.product(param.size()) - print('# of parameters: '+str(s/1e6)+" M") - - output = model(x) - print(output.shape) - - macs, params = profile(model, inputs=(x,)) - macs, params = clever_format([macs, params], "%.3f") - print(macs, params) - - import torchaudio - model = get_bsrnnvae() - inp, fs = torchaudio.load('769000.mp3') - inp = inp[[0],:] - if(fs!=44100): - inp = torchaudio.functional.resample(inp, fs, 44100) - fs = 44100 - inp = inp[:,0:30*44100] - out = model(inp[None,:,:]).detach() - torchaudio.save('out.flac', out[0], fs) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_melvaehifigan48k.py b/codeclm/tokenizer/Flow1dVAE/tools/get_melvaehifigan48k.py deleted file mode 100644 index 9e94c9bb83bb0a85c3ec9c3d895b1c67ec02544c..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/get_melvaehifigan48k.py +++ /dev/null @@ -1,1578 +0,0 @@ - -import soundfile as sf -import os -from librosa.filters import mel as librosa_mel_fn -import sys -import tools.torch_tools as torch_tools -import torch.nn as nn -import torch -import numpy as np -from einops import rearrange -from scipy.signal import get_window -from librosa.util import pad_center, tiny -import librosa.util as librosa_util - -class AttrDict(dict): - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - -LRELU_SLOPE = 0.1 - -class ResBlock(torch.nn.Module): - def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): - 
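Two notes on the deleted `__main__` block of get_bsrnnvae_old.py above: it calls `get_bsrnnvae()` with no checkpoint path, which would raise a TypeError as defined, and it profiles with thop. The parameter/MAC counting pattern in isolation, on a stand-in module rather than the full Codec:

```python
import numpy as np
import torch
from thop import clever_format, profile

model = torch.nn.Conv1d(1, 8, 7, padding=3)   # stand-in for the deleted Codec
x = torch.empty(1, 1, 44100).uniform_(-1, 1)

n_params = sum(int(np.prod(p.shape)) for p in model.parameters())
print(f"# of parameters: {n_params / 1e6} M")

macs, params = profile(model, inputs=(x,))
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)
```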
super(ResBlock, self).__init__() - self.h = h - self.convs1 = nn.ModuleList( - [ - torch.nn.utils.weight_norm( - nn.Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - torch.nn.utils.weight_norm( - nn.Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - torch.nn.utils.weight_norm( - nn.Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - self.convs1.apply(init_weights) - - self.convs2 = nn.ModuleList( - [ - torch.nn.utils.weight_norm( - nn.Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - torch.nn.utils.weight_norm( - nn.Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - torch.nn.utils.weight_norm( - nn.Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - self.convs2.apply(init_weights) - - def forward(self, x): - for c1, c2 in zip(self.convs1, self.convs2): - xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - torch.nn.utils.remove_weight_norm(l) - for l in self.convs2: - torch.nn.utils.remove_weight_norm(l) - - -class Generator_old(torch.nn.Module): - def __init__(self, h): - super(Generator_old, self).__init__() - self.h = h - self.num_kernels = len(h.resblock_kernel_sizes) - self.num_upsamples = len(h.upsample_rates) - self.conv_pre = torch.nn.utils.weight_norm( - nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) - ) - resblock = ResBlock - - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): - self.ups.append( - torch.nn.utils.weight_norm( - nn.ConvTranspose1d( - h.upsample_initial_channel // (2**i), - h.upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = h.upsample_initial_channel // (2 ** (i + 1)) - for j, (k, d) in enumerate( - zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) - ): - self.resblocks.append(resblock(h, ch, k, d)) - - self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3)) - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - - def forward(self, x): - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) - x = self.ups[i](x) - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - x = torch.nn.functional.leaky_relu(x) - x = self.conv_post(x) - x = torch.tanh(x) - - return x - - def remove_weight_norm(self): - # print("Removing weight norm...") - for l in self.ups: - torch.nn.utils.remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - torch.nn.utils.remove_weight_norm(self.conv_pre) - torch.nn.utils.remove_weight_norm(self.conv_post) - - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm( - num_groups=num_groups, 
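Generator_old's forward above fuses `num_kernels` parallel ResBlocks per upsampling stage by summing their outputs and dividing by `num_kernels` (HiFi-GAN's multi-receptive-field fusion). The pattern in isolation, using plain convs with the config's kernel sizes [3, 7, 11, 15] as stand-ins for the ResBlocks:

```python
import torch
import torch.nn as nn

blocks = nn.ModuleList(nn.Conv1d(8, 8, k, padding=k // 2) for k in (3, 7, 11, 15))
x = torch.randn(1, 8, 100)
fused = sum(b(x) for b in blocks) / len(blocks)  # xs / self.num_kernels in the original
print(fused.shape)  # torch.Size([1, 8, 100])
```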
num_channels=in_channels, eps=1e-6, affine=True - ) - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # Do time downsampling here - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class DownsampleTimeStride4(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # Do time downsampling here - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) - return x - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class UpsampleTimeStride4(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=5, stride=1, padding=2 - ) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h * w).contiguous() - q = q.permute(0, 2, 1).contiguous() # b,hw,c - k = k.reshape(b, c, h * w).contiguous() # b,c,hw - w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h * w).contiguous() - w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm( - v, w_ - ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b, c, h, w).contiguous() - - h_ = self.proj_out(h_) - - return x + h_ - - -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" - # print(f"making attention of type 
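Downsample above pads only the right/bottom edge before a stride-2, padding-0 conv, because torch's Conv2d cannot express asymmetric "same" padding directly; the result is floor(H/2) × floor(W/2). Quick shape check:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 8, 5, 7)
y = nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=0)(F.pad(x, (0, 1, 0, 1)))
print(y.shape)  # torch.Size([1, 8, 2, 3]) == floor(5/2), floor(7/2)
```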
'{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - return AttnBlock(in_channels) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - raise ValueError(attn_type) - - -class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - else: - self.nin_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - downsample_time_stride4_levels=[], - **ignore_kwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.downsample_time_stride4_levels = downsample_time_stride4_levels - - if len(self.downsample_time_stride4_levels) > 0: - assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( - "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" - % str(self.num_resolutions) - ) - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - if i_level in self.downsample_time_stride4_levels: - down.downsample = DownsampleTimeStride4(block_in, 
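AttnBlock above is single-head dot-product attention over the flattened h·w grid, written out with `bmm` and a c^-0.5 scale. The same computation in compact form, useful for cross-checking the shape gymnastics:

```python
import torch

b, c, h, w = 2, 16, 8, 8
q, k, v = (torch.randn(b, c, h * w) for _ in range(3))

attn = torch.softmax(q.transpose(1, 2) @ k * c ** -0.5, dim=2)  # (b, hw, hw)
out = (v @ attn.transpose(1, 2)).view(b, c, h, w)               # h_ in the original
print(out.shape)  # torch.Size([2, 16, 8, 8])
```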
resamp_with_conv) - else: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) - - def forward(self, x): - # timestep embedding - temb = None - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - downsample_time_stride4_levels=[], - attn_type="vanilla", - **ignorekwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - self.downsample_time_stride4_levels = downsample_time_stride4_levels - - if len(self.downsample_time_stride4_levels) > 0: - assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( - "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" - % str(self.num_resolutions) - ) - - # compute in_ch_mult, block_in and curr_res at lowest res - (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - # print( - # "Working with z of shape {} = {} dimensions.".format( - # self.z_shape, np.prod(self.z_shape) - # ) - # ) - - # z to block_in - self.conv_in = torch.nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, 
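One detail worth noting in the Encoder/Decoder pair above: with `ch_mult` of length L (and no stride-4 levels) the encoder halves the grid L-1 times, so the decoder's starting resolution is `resolution // 2 ** (L - 1)` and each up level doubles it back. For example:

```python
resolution, ch_mult = 256, (1, 2, 4, 8)
num_resolutions = len(ch_mult)
curr_res = resolution // 2 ** (num_resolutions - 1)
print(curr_res)  # 32: the spatial size at which the decoder's conv_in operates
```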
- dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - if i_level - 1 in self.downsample_time_stride4_levels: - up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) - else: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, z): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - if self.tanh_out: - h = torch.tanh(h) - return h - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device - ) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to( - device=self.parameters.device - ) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.mean( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.mean( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims, - ) - - def mode(self): - return self.mean - -def get_vocoder_config_48k(): - return { - "resblock": "1", - "num_gpus": 8, - "batch_size": 128, - "learning_rate": 0.0001, - "adam_b1": 0.8, - "adam_b2": 0.99, - "lr_decay": 0.999, - "seed": 1234, - - "upsample_rates": [6,5,4,2,2], - "upsample_kernel_sizes": [12,10,8,4,4], - "upsample_initial_channel": 1536, - "resblock_kernel_sizes": [3,7,11,15], - "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5], [1,3,5]], - - "segment_size": 15360, - "num_mels": 256, - "n_fft": 2048, - "hop_size": 480, - "win_size": 2048, - - "sampling_rate": 48000, - - "fmin": 20, - "fmax": 24000, - "fmax_for_loss": None, - - "num_workers": 8, - - "dist_config": { - "dist_backend": "nccl", - "dist_url": "tcp://localhost:18273", - "world_size": 1 - } - } - -def get_vocoder(config, device, mel_bins): - name = "HiFi-GAN" - speaker = "" - if name == "MelGAN": - if speaker == "LJSpeech": - vocoder = 
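A sanity check tying get_vocoder_config_48k above together: the generator's total upsampling factor must equal the mel hop size so each mel frame maps to exactly `hop_size` samples, and 6·5·4·2·2 = 480 matches the 48 kHz config:

```python
import numpy as np

upsample_rates = [6, 5, 4, 2, 2]
hop_size, sampling_rate = 480, 48000
assert int(np.prod(upsample_rates)) == hop_size
print(f"{sampling_rate / hop_size:.1f} mel frames per second")  # 100.0
```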
torch.hub.load( - "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" - ) - elif speaker == "universal": - vocoder = torch.hub.load( - "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" - ) - vocoder.mel2wav.eval() - vocoder.mel2wav.to(device) - elif name == "HiFi-GAN": - if(mel_bins == 256): - config = get_vocoder_config_48k() - config = AttrDict(config) - vocoder = Generator_old(config) - # print("Load hifigan/g_01080000") - # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) - # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) - # ckpt = torch_version_orig_mod_remove(ckpt) - # vocoder.load_state_dict(ckpt["generator"]) - vocoder.eval() - vocoder.remove_weight_norm() - vocoder = vocoder.to(device) - # vocoder = vocoder.half() - else: - raise ValueError(mel_bins) - return vocoder - -def vocoder_infer(mels, vocoder, lengths=None): - with torch.no_grad(): - wavs = vocoder(mels).squeeze(1) - - #wavs = (wavs.cpu().numpy() * 32768).astype("int16") - wavs = (wavs.cpu().numpy()) - - if lengths is not None: - wavs = wavs[:, :lengths] - - # wavs = [wav for wav in wavs] - - # for i in range(len(mels)): - # if lengths is not None: - # wavs[i] = wavs[i][: lengths[i]] - - return wavs - -@torch.no_grad() -def vocoder_chunk_infer(mels, vocoder, lengths=None): - chunk_size = 256*4 - shift_size = 256*1 - ov_size = chunk_size-shift_size - # import pdb;pdb.set_trace() - - for cinx in range(0, mels.shape[2], shift_size): - if(cinx==0): - wavs = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).float() - num_samples = int(wavs.shape[-1]/chunk_size)*chunk_size - wavs = wavs[:,0:num_samples] - ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size) - ov_win = torch.linspace(0, 1, ov_sample, device="cuda").unsqueeze(0) - ov_win = torch.cat([ov_win,1-ov_win],-1) - if(cinx+chunk_size>=mels.shape[2]): - break - else: - cur_wav = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1)[:,0:num_samples].float() - wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * ov_win[:,-ov_sample:] + cur_wav[:,0:ov_sample] * ov_win[:,0:ov_sample] - # wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * 1.0 + cur_wav[:,0:ov_sample] * 0.0 - wavs = torch.cat([wavs, cur_wav[:,ov_sample:]],-1) - if(cinx+chunk_size>=mels.shape[2]): - break - # print(wavs.shape) - - wavs = (wavs.cpu().numpy()) - - if lengths is not None: - wavs = wavs[:, :lengths] - # print(wavs.shape) - return wavs - -def synth_one_sample(mel_input, mel_prediction, labels, vocoder): - if vocoder is not None: - - wav_reconstruction = vocoder_infer( - mel_input.permute(0, 2, 1), - vocoder, - ) - wav_prediction = vocoder_infer( - mel_prediction.permute(0, 2, 1), - vocoder, - ) - else: - wav_reconstruction = wav_prediction = None - - return wav_reconstruction, wav_prediction - - -class AutoencoderKL(nn.Module): - def __init__( - self, - ddconfig=None, - lossconfig=None, - batchsize=None, - embed_dim=None, - time_shuffle=1, - subband=1, - sampling_rate=16000, - ckpt_path=None, - reload_from_ckpt=None, - ignore_keys=[], - image_key="fbank", - colorize_nlabels=None, - monitor=None, - base_learning_rate=1e-5, - scale_factor=1 - ): - super().__init__() - self.automatic_optimization = False - assert ( - "mel_bins" in ddconfig.keys() - ), "mel_bins is not specified in the Autoencoder config" - num_mel = ddconfig["mel_bins"] - self.image_key = image_key - self.sampling_rate = sampling_rate - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - - self.loss = None - self.subband = int(subband) - - if self.subband > 1: - print("Use subband 
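vocoder_chunk_infer above stitches chunked vocoder output with a linear crossfade: the running waveform's tail fades out while the new chunk's head fades in, then the non-overlapping remainder is appended. The blend step in isolation (`ov` is the overlap length in samples, a stand-in name):

```python
import torch

def crossfade_append(acc, chunk, ov):
    ramp = torch.linspace(0, 1, ov)
    acc[..., -ov:] = acc[..., -ov:] * (1 - ramp) + chunk[..., :ov] * ramp
    return torch.cat([acc, chunk[..., ov:]], dim=-1)

a, b = torch.randn(1, 1000), torch.randn(1, 1000)
out = crossfade_append(a.clone(), b, ov=256)
print(out.shape)  # torch.Size([1, 1744]) == 1000 + (1000 - 256)
```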
decomposition %s" % self.subband) - - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - - if self.image_key == "fbank": - self.vocoder = get_vocoder(None, torch.device("cuda"), num_mel) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels) == int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - self.learning_rate = float(base_learning_rate) - # print("Initial learning rate %s" % self.learning_rate) - - self.time_shuffle = time_shuffle - self.reload_from_ckpt = reload_from_ckpt - self.reloaded = False - self.mean, self.std = None, None - - self.feature_cache = None - self.flag_first_run = True - self.train_step = 0 - - self.logger_save_dir = None - self.logger_exp_name = None - self.scale_factor = scale_factor - - print("Num parameters:") - print("Encoder : ", sum(p.numel() for p in self.encoder.parameters())) - print("Decoder : ", sum(p.numel() for p in self.decoder.parameters())) - print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters())) - - def get_log_dir(self): - if self.logger_save_dir is None and self.logger_exp_name is None: - return os.path.join(self.logger.save_dir, self.logger._project) - else: - return os.path.join(self.logger_save_dir, self.logger_exp_name) - - def set_log_dir(self, save_dir, exp_name): - self.logger_save_dir = save_dir - self.logger_exp_name = exp_name - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - - def encode(self, x): - # x = self.time_shuffle_operation(x) - # x = self.freq_split_subband(x) - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - # bs, ch, shuffled_timesteps, fbins = dec.size() - # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins) - # dec = self.freq_merge_subband(dec) - return dec - - def decode_to_waveform(self, dec): - - if self.image_key == "fbank": - dec = dec.squeeze(1).permute(0, 2, 1) - wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder) - elif self.image_key == "stft": - dec = dec.squeeze(1).permute(0, 2, 1) - wav_reconstruction = self.wave_decoder(dec) - return wav_reconstruction - - def mel_spectrogram_to_waveform( - self, mel, savepath=".", bs=None, name="outwav", save=True - ): - # Mel: [bs, 1, t-steps, fbins] - if len(mel.size()) == 4: - mel = mel.squeeze(1) - mel = mel.permute(0, 2, 1) - waveform = self.vocoder(mel) - waveform = waveform.cpu().detach().numpy() - #if save: - # self.save_waveform(waveform, savepath, name) - return waveform - - @torch.no_grad() - def encode_first_stage(self, x): - return self.encode(x) - - @torch.no_grad() - def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, "b h w c -> 
b c h w").contiguous() - - z = 1.0 / self.scale_factor * z - return self.decode(z) - - def decode_first_stage_withgrad(self, z): - z = 1.0 / self.scale_factor * z - return self.decode(z) - - def get_first_stage_encoding(self, encoder_posterior, use_mode=False): - if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode: - z = encoder_posterior.sample() - elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode: - z = encoder_posterior.mode() - elif isinstance(encoder_posterior, torch.Tensor): - z = encoder_posterior - else: - raise NotImplementedError( - f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" - ) - return self.scale_factor * z - - def visualize_latent(self, input): - import matplotlib.pyplot as plt - - # for i in range(10): - # zero_input = torch.zeros_like(input) - 11.59 - # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59 - - # posterior = self.encode(zero_input) - # latent = posterior.sample() - # avg_latent = torch.mean(latent, dim=1)[0] - # plt.imshow(avg_latent.cpu().detach().numpy().T) - # plt.savefig("%s.png" % i) - # plt.close() - - np.save("input.npy", input.cpu().detach().numpy()) - # zero_input = torch.zeros_like(input) - 11.59 - time_input = input.clone() - time_input[:, :, :, :32] *= 0 - time_input[:, :, :, :32] -= 11.59 - - np.save("time_input.npy", time_input.cpu().detach().numpy()) - - posterior = self.encode(time_input) - latent = posterior.sample() - np.save("time_latent.npy", latent.cpu().detach().numpy()) - avg_latent = torch.mean(latent, dim=1) - for i in range(avg_latent.size(0)): - plt.imshow(avg_latent[i].cpu().detach().numpy().T) - plt.savefig("freq_%s.png" % i) - plt.close() - - freq_input = input.clone() - freq_input[:, :, :512, :] *= 0 - freq_input[:, :, :512, :] -= 11.59 - - np.save("freq_input.npy", freq_input.cpu().detach().numpy()) - - posterior = self.encode(freq_input) - latent = posterior.sample() - np.save("freq_latent.npy", latent.cpu().detach().numpy()) - avg_latent = torch.mean(latent, dim=1) - for i in range(avg_latent.size(0)): - plt.imshow(avg_latent[i].cpu().detach().numpy().T) - plt.savefig("time_%s.png" % i) - plt.close() - - def get_input(self, batch): - fname, text, label_indices, waveform, stft, fbank = ( - batch["fname"], - batch["text"], - batch["label_vector"], - batch["waveform"], - batch["stft"], - batch["log_mel_spec"], - ) - # if(self.time_shuffle != 1): - # if(fbank.size(1) % self.time_shuffle != 0): - # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle) - # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len)) - - ret = {} - - ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = ( - fbank.unsqueeze(1), - stft.unsqueeze(1), - fname, - waveform.unsqueeze(1), - ) - - return ret - - def save_wave(self, batch_wav, fname, save_dir): - os.makedirs(save_dir, exist_ok=True) - - for wav, name in zip(batch_wav, fname): - name = os.path.basename(name) - - sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate) - - def get_last_layer(self): - return self.decoder.conv_out.weight - - @torch.no_grad() - def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs): - log = dict() - x = batch.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - log["samples"] = self.decode(posterior.sample()) - log["reconstructions"] = xrec - - log["inputs"] = x - wavs = self._log_img(log, train=train, index=0, waveform=waveform) - return wavs - - def _log_img(self, log, train=True, index=0, waveform=None): 
- images_input = self.tensor2numpy(log["inputs"][index, 0]).T - images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T - images_samples = self.tensor2numpy(log["samples"][index, 0]).T - - if train: - name = "train" - else: - name = "val" - - if self.logger is not None: - self.logger.log_image( - "img_%s" % name, - [images_input, images_reconstruct, images_samples], - caption=["input", "reconstruct", "samples"], - ) - - inputs, reconstructions, samples = ( - log["inputs"], - log["reconstructions"], - log["samples"], - ) - - if self.image_key == "fbank": - wav_original, wav_prediction = synth_one_sample( - inputs[index], - reconstructions[index], - labels="validation", - vocoder=self.vocoder, - ) - wav_original, wav_samples = synth_one_sample( - inputs[index], samples[index], labels="validation", vocoder=self.vocoder - ) - wav_original, wav_samples, wav_prediction = ( - wav_original[0], - wav_samples[0], - wav_prediction[0], - ) - elif self.image_key == "stft": - wav_prediction = ( - self.decode_to_waveform(reconstructions)[index, 0] - .cpu() - .detach() - .numpy() - ) - wav_samples = ( - self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy() - ) - wav_original = waveform[index, 0].cpu().detach().numpy() - - if self.logger is not None: - self.logger.experiment.log( - { - "original_%s" - % name: wandb.Audio( - wav_original, caption="original", sample_rate=self.sampling_rate - ), - "reconstruct_%s" - % name: wandb.Audio( - wav_prediction, - caption="reconstruct", - sample_rate=self.sampling_rate, - ), - "samples_%s" - % name: wandb.Audio( - wav_samples, caption="samples", sample_rate=self.sampling_rate - ), - } - ) - - return wav_original, wav_prediction, wav_samples - - def tensor2numpy(self, tensor): - return tensor.cpu().detach().numpy() - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = torch.nn.functional.conv2d(x, weight=self.colorize) - x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 - return x - - -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff - super().__init__() - - def encode(self, x, *args, **kwargs): - return x - - def decode(self, x, *args, **kwargs): - return x - - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x - - def forward(self, x, *args, **kwargs): - return x - - -def window_sumsquare( - window, - n_frames, - hop_length, - win_length, - n_fft, - dtype=np.float32, - norm=None, -): - """ - # from librosa 0.6 - Compute the sum-square envelope of a window function at a given hop length. - - This is used to estimate modulation effects induced by windowing - observations in short-time fourier transforms. - - Parameters - ---------- - window : string, tuple, number, callable, or list-like - Window specification, as in `get_window` - - n_frames : int > 0 - The number of analysis frames - - hop_length : int > 0 - The number of samples to advance between frames - - win_length : [optional] - The length of the window function. By default, this matches `n_fft`. - - n_fft : int > 0 - The length of each analysis frame. 
- - dtype : np.dtype - The data type of the output - - Returns - ------- - wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` - The sum-squared envelope of the window function - """ - if win_length is None: - win_length = n_fft - - n = n_fft + hop_length * (n_frames - 1) - x = np.zeros(n, dtype=dtype) - - # Compute the squared window at the desired length - win_sq = get_window(window, win_length, fftbins=True) - win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 - win_sq = librosa_util.pad_center(win_sq, n_fft) - - # Fill the envelope - for i in range(n_frames): - sample = i * hop_length - x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] - return x - -def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return normalize_fun(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression(x, C=1): - """ - PARAMS - ------ - C: compression factor used to compress - """ - return torch.exp(x) / C - - -class STFT(torch.nn.Module): - """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" - - def __init__(self, filter_length, hop_length, win_length, window="hann"): - super(STFT, self).__init__() - self.filter_length = filter_length - self.hop_length = hop_length - self.win_length = win_length - self.window = window - self.forward_transform = None - scale = self.filter_length / self.hop_length - fourier_basis = np.fft.fft(np.eye(self.filter_length)) - - cutoff = int((self.filter_length / 2 + 1)) - fourier_basis = np.vstack( - [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] - ) - - forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) - inverse_basis = torch.FloatTensor( - np.linalg.pinv(scale * fourier_basis).T[:, None, :] - ) - - if window is not None: - assert filter_length >= win_length - # get window and zero center pad it to filter_length - fft_window = get_window(window, win_length, fftbins=True) - fft_window = pad_center(fft_window, size=filter_length) - fft_window = torch.from_numpy(fft_window).float() - - # window the bases - forward_basis *= fft_window - inverse_basis *= fft_window - - self.register_buffer("forward_basis", forward_basis.float()) - self.register_buffer("inverse_basis", inverse_basis.float()) - - def transform(self, input_data): - - device = self.forward_basis.device - input_data = input_data.to(device) - - num_batches = input_data.size(0) - num_samples = input_data.size(1) - - self.num_samples = num_samples - - # similar to librosa, reflect-pad the input - input_data = input_data.view(num_batches, 1, num_samples) - input_data = torch.nn.functional.pad( - input_data.unsqueeze(1), - (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), - mode="reflect", - ) - input_data = input_data.squeeze(1) - - forward_transform = torch.nn.functional.conv1d( - input_data, - torch.autograd.Variable(self.forward_basis, requires_grad=False), - stride=self.hop_length, - padding=0, - )#.cpu() - - cutoff = int((self.filter_length / 2) + 1) - real_part = forward_transform[:, :cutoff, :] - imag_part = forward_transform[:, cutoff:, :] - - magnitude = torch.sqrt(real_part**2 + imag_part**2) - phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) - - return magnitude, phase - - def inverse(self, magnitude, phase): - - device = self.forward_basis.device - magnitude, phase = magnitude.to(device), phase.to(device) - - recombine_magnitude_phase = torch.cat( - [magnitude * 
torch.cos(phase), magnitude * torch.sin(phase)], dim=1 - ) - - inverse_transform = torch.nn.functional.conv_transpose1d( - recombine_magnitude_phase, - torch.autograd.Variable(self.inverse_basis, requires_grad=False), - stride=self.hop_length, - padding=0, - ) - - if self.window is not None: - window_sum = window_sumsquare( - self.window, - magnitude.size(-1), - hop_length=self.hop_length, - win_length=self.win_length, - n_fft=self.filter_length, - dtype=np.float32, - ) - # remove modulation effects - approx_nonzero_indices = torch.from_numpy( - np.where(window_sum > tiny(window_sum))[0] - ) - window_sum = torch.autograd.Variable( - torch.from_numpy(window_sum), requires_grad=False - ) - window_sum = window_sum - inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ - approx_nonzero_indices - ] - - # scale by hop ratio - inverse_transform *= float(self.filter_length) / self.hop_length - - inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] - inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] - - return inverse_transform - - def forward(self, input_data): - self.magnitude, self.phase = self.transform(input_data) - reconstruction = self.inverse(self.magnitude, self.phase) - return reconstruction - - -class TacotronSTFT(torch.nn.Module): - def __init__( - self, - filter_length, - hop_length, - win_length, - n_mel_channels, - sampling_rate, - mel_fmin, - mel_fmax, - ): - super(TacotronSTFT, self).__init__() - self.n_mel_channels = n_mel_channels - self.sampling_rate = sampling_rate - self.stft_fn = STFT(filter_length, hop_length, win_length) - mel_basis = librosa_mel_fn( - sr = sampling_rate, n_fft = filter_length, n_mels = n_mel_channels, fmin = mel_fmin, fmax = mel_fmax - ) - mel_basis = torch.from_numpy(mel_basis).float() - self.register_buffer("mel_basis", mel_basis) - - def spectral_normalize(self, magnitudes, normalize_fun): - output = dynamic_range_compression(magnitudes, normalize_fun) - return output - - def spectral_de_normalize(self, magnitudes): - output = dynamic_range_decompression(magnitudes) - return output - - def mel_spectrogram(self, y, normalize_fun=torch.log): - """Computes mel-spectrograms from a batch of waves - PARAMS - ------ - y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] - - RETURNS - ------- - mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) - """ - assert torch.min(y.data) >= -1, torch.min(y.data) - assert torch.max(y.data) <= 1, torch.max(y.data) - - magnitudes, phases = self.stft_fn.transform(y) - magnitudes = magnitudes.data - mel_output = torch.matmul(self.mel_basis, magnitudes) - mel_output = self.spectral_normalize(mel_output, normalize_fun) - energy = torch.norm(magnitudes, dim=1) - - log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) - - return mel_output, log_magnitudes, energy - - -def build_pretrained_models(ckpt): - checkpoint = torch.load(ckpt, map_location="cpu") - scale_factor = checkpoint["state_dict"]["scale_factor"].item() - print("scale_factor: ", scale_factor) - - vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." 
in k} - - config = { - "preprocessing": { - "audio": { - "sampling_rate": 48000, - "max_wav_value": 32768, - "duration": 10.24 - }, - "stft": { - "filter_length": 2048, - "hop_length": 480, - "win_length": 2048 - }, - "mel": { - "n_mel_channels": 256, - "mel_fmin": 20, - "mel_fmax": 24000 - } - }, - "model": { - "params": { - "first_stage_config": { - "params": { - "sampling_rate": 48000, - "batchsize": 4, - "monitor": "val/rec_loss", - "image_key": "fbank", - "subband": 1, - "embed_dim": 16, - "time_shuffle": 1, - "lossconfig": { - "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator", - "params": { - "disc_start": 50001, - "kl_weight": 1000, - "disc_weight": 0.5, - "disc_in_channels": 1 - } - }, - "ddconfig": { - "double_z": True, - "mel_bins": 256, - "z_channels": 16, - "resolution": 256, - "downsample_time": False, - "in_channels": 1, - "out_ch": 1, - "ch": 128, - "ch_mult": [ - 1, - 2, - 4, - 8 - ], - "num_res_blocks": 2, - "attn_resolutions": [], - "dropout": 0 - } - } - }, - } - } - } - vae_config = config["model"]["params"]["first_stage_config"]["params"] - vae_config["scale_factor"] = scale_factor - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(vae_state_dict) - - fn_STFT = TacotronSTFT( - config["preprocessing"]["stft"]["filter_length"], - config["preprocessing"]["stft"]["hop_length"], - config["preprocessing"]["stft"]["win_length"], - config["preprocessing"]["mel"]["n_mel_channels"], - config["preprocessing"]["audio"]["sampling_rate"], - config["preprocessing"]["mel"]["mel_fmin"], - config["preprocessing"]["mel"]["mel_fmax"], - ) - - vae.eval() - fn_STFT.eval() - return vae, fn_STFT - - -if __name__=="__main__": - vae, stft = build_pretrained_models() - vae, stft = vae.cuda(), stft.cuda() - - json_file="outputs/wav.scp" - out_path="outputs/Music_inverse" - - wavform = torch.randn(2,int(48000*10.24)) - mel, _, waveform = torch_tools.wav_to_fbank2(wavform, target_length=-1, fn_STFT=stft) - mel = mel.unsqueeze(1).cuda() - print(mel.shape) - # true_latent = torch.cat([vae.get_first_stage_encoding(vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0) - # print(true_latent.shape) - true_latent = vae.get_first_stage_encoding(vae.encode_first_stage(mel)) - print(true_latent.shape) - true_latent = true_latent.reshape(true_latent.shape[0]//2, -1, true_latent.shape[2], true_latent.shape[3]).detach() - - true_latent = true_latent.reshape(true_latent.shape[0]*2,-1,true_latent.shape[2],true_latent.shape[3]) - print("111", true_latent.size()) - - mel = vae.decode_first_stage(true_latent) - print("222", mel.size()) - audio = vae.decode_to_waveform(mel) - print("333", audio.shape) - - # out_file = out_path + "/" + os.path.basename(fname.strip()) - # sf.write(out_file, audio[0], samplerate=48000) diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_mert_embedding.py b/codeclm/tokenizer/Flow1dVAE/tools/get_mert_embedding.py deleted file mode 100644 index a7fdd45820a6a1126b5c2d1559abdad61470747f..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/get_mert_embedding.py +++ /dev/null @@ -1,120 +0,0 @@ -import hydra -import librosa -import torch -import yaml -from prodict import Prodict -import torchaudio - -from musiclm_pytorch import AudioSpectrogramTransformerPretrained, TextTransformerPretrained, MuLaN, MuLaNEmbedder -from omegaconf import DictConfig -import os - -def get_pretrained_config(root, name): - if root is None: - return name - path = os.path.join(root, name) - # get the directory under the snapshots directory - config_dir = os.path.join(path, 
'snapshots') - config_files = os.listdir(config_dir) - assert len(config_files) == 1 - config_path = os.path.join(config_dir, config_files[0]) - return config_path - -def create_MuLaN_from_config(config: DictConfig): - """ - Create a MuLaN model from a configuration file. - """ - pretraind_root = config.model.pretraind_root - - audio_model_name = get_pretrained_config(pretraind_root, config.model.audio_model.name) - audio_transformer = AudioSpectrogramTransformerPretrained( - model_name = audio_model_name, - model_dim = config.model.audio_model.model_dim, - use_layer_idx = config.model.audio_model.use_layer_idx, - **config.model.audio_transformer - ) - text_model_name = get_pretrained_config(pretraind_root, config.model.text_model.name) - text_transformer = TextTransformerPretrained( - model_name = text_model_name, - **config.model.text_transformer - ) - - mulan = MuLaN( - audio_transformer = audio_transformer, - text_transformer = text_transformer, - **config.model.mulan - ) - - return mulan - - -def create_CLAP_model( model_kwargs = {}, ckpt_path = None ): - from musiclm_pytorch import SoftmaxContrastiveLearning - import laion_clap - - from torch import nn - import torch - from torchaudio.functional import resample - - import numpy as np - - from functools import partial - - # quantization - def int16_to_float32(x): - return (x / 32767.0).float() - - def float32_to_int16(x): - x = torch.clip(x, min=-1., max=1.) - return (x * 32767.).int() - - model = laion_clap.CLAP_Module(enable_fusion=False, **model_kwargs) - if ckpt_path is not None: - model.load_ckpt(ckpt_path) - else: - model.load_ckpt() - - class CLAP_Model(nn.Module): - def __init__(self, model, sr = 24000, decoupled_contrastive_learning = True): - super().__init__() - self.model = model - self.model.eval() - self.orig_sr = sr - - klass = partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning) - self.contrast = klass() - - - def forward(self, wavs, raw_texts): - with torch.no_grad(): - wavs = int16_to_float32(float32_to_int16(resample(wavs, self.orig_sr, 48000))) - audio_latents = self.model.get_audio_embedding_from_data(x = wavs, use_tensor=True).float() - text_latents = model.get_text_embedding(raw_texts, use_tensor=True) - cl_loss = self.contrast(audio_latents, text_latents) - return cl_loss - - clap = CLAP_Model(model) - return clap - -def get_mulan(config): - with open(config, "r") as stream: - mulan_config = yaml.safe_load(stream) - mulan_config = Prodict.from_dict(mulan_config) - ckpt_path = mulan_config.checkpoint_path - mulan = create_MuLaN_from_config(mulan_config) - mulan_embedder = MuLaNEmbedder(mulan, checkpoint_path = ckpt_path) - mulan_embedder.eval() - - return mulan_embedder - -def extract_mert_embeds(mulan_embd_extractor, layer_num, filename): - input_audios, fs = torchaudio.load(filename) - mulan_sr = 24000 - if(fs!=mulan_sr): - input_audios = torchaudio.functional.resample(input_audios, fs, mulan_sr) - fs = mulan_sr - # print(input_audios.shape) - inputs = mulan_embd_extractor.mulan.audio.processor(input_audios, sampling_rate=mulan_embd_extractor.mulan.audio.sr, return_tensors="pt") - input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) - prompt_embeds = mulan_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states[layer_num] # batch_size, Time steps, 1024 feature_dim - return prompt_embeds diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_mulan.py 
b/codeclm/tokenizer/Flow1dVAE/tools/get_mulan.py deleted file mode 100644 index d51f000dfcb94db35bc7061ee97a79d3bf0d3947..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/get_mulan.py +++ /dev/null @@ -1,108 +0,0 @@ -from musiclm_pytorch import MuLaNEmbedder -import hydra -import librosa -import torch -import yaml -from prodict import Prodict - -from musiclm_pytorch import AudioSpectrogramTransformerPretrained, TextTransformerPretrained, MuLaN -from omegaconf import DictConfig -import os - -def get_pretrained_config(root, name): - if root is None: - return name - path = os.path.join(root, name) - # get the directory under the snapshots directory - config_dir = os.path.join(path, 'snapshots') - config_files = os.listdir(config_dir) - assert len(config_files) == 1 - config_path = os.path.join(config_dir, config_files[0]) - return config_path - -def create_MuLaN_from_config(config: DictConfig): - """ - Create a MuLaN model from a configuration file. - """ - pretraind_root = config.model.pretraind_root - - audio_model_name = get_pretrained_config(pretraind_root, config.model.audio_model.name) - audio_transformer = AudioSpectrogramTransformerPretrained( - model_name = audio_model_name, - model_dim = config.model.audio_model.model_dim, - use_layer_idx = config.model.audio_model.use_layer_idx, - **config.model.audio_transformer - ) - text_model_name = get_pretrained_config(pretraind_root, config.model.text_model.name) - text_transformer = TextTransformerPretrained( - model_name = text_model_name, - **config.model.text_transformer - ) - - mulan = MuLaN( - audio_transformer = audio_transformer, - text_transformer = text_transformer, - **config.model.mulan - ) - - return mulan - - -def create_CLAP_model( model_kwargs = {}, ckpt_path = None ): - from musiclm_pytorch import SoftmaxContrastiveLearning - import laion_clap - - from torch import nn - import torch - from torchaudio.functional import resample - - import numpy as np - - from functools import partial - - # quantization - def int16_to_float32(x): - return (x / 32767.0).float() - - def float32_to_int16(x): - x = torch.clip(x, min=-1., max=1.) 
- return (x * 32767.).int() - - model = laion_clap.CLAP_Module(enable_fusion=False, **model_kwargs) - if ckpt_path is not None: - model.load_ckpt(ckpt_path) - else: - model.load_ckpt() - - class CLAP_Model(nn.Module): - def __init__(self, model, sr = 24000, decoupled_contrastive_learning = True): - super().__init__() - self.model = model - self.model.eval() - self.orig_sr = sr - - klass = partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning) - self.contrast = klass() - - - def forward(self, wavs, raw_texts): - with torch.no_grad(): - wavs = int16_to_float32(float32_to_int16(resample(wavs, self.orig_sr, 48000))) - audio_latents = self.model.get_audio_embedding_from_data(x = wavs, use_tensor=True).float() - text_latents = model.get_text_embedding(raw_texts, use_tensor=True) - cl_loss = self.contrast(audio_latents, text_latents) - return cl_loss - - clap = CLAP_Model(model) - return clap - -def get_mulan(config): - with open(config, "r") as stream: - mulan_config = yaml.safe_load(stream) - mulan_config = Prodict.from_dict(mulan_config) - ckpt_path = mulan_config.checkpoint_path - mulan = create_MuLaN_from_config(mulan_config) - mulan_embedder = MuLaNEmbedder(mulan, checkpoint_path = ckpt_path) - mulan_embedder.eval() - - return mulan_embedder diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_whisper_encoder.py b/codeclm/tokenizer/Flow1dVAE/tools/get_whisper_encoder.py deleted file mode 100644 index 7a113a269abe5dd7bfbfbb2c9ec3409f22ea6b7a..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/get_whisper_encoder.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -from transformers import WhisperProcessor, WhisperForConditionalGeneration - -def get_whisper_encoder(): - processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") - model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3").model.encoder - return processor, model.eval() - -if __name__=="__main__": - import numpy as np - processor, model = get_whisper_encoder() - model = model.cuda() - - with torch.no_grad(): - input_features = processor(np.random.rand(16000*30,), sampling_rate=16000, return_tensors="pt").input_features.cuda() - print(input_features.shape) - out = model(input_features.repeat(10,1,1)) - import pdb;pdb.set_trace() - print(list(out.values())[0].shape) diff --git a/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py b/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py deleted file mode 100644 index ec8a54ff9c66bb99e646b60758c363fc10c0aebe..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py +++ /dev/null @@ -1,47 +0,0 @@ -import json -import torch -from tqdm import tqdm -import torchaudio -import librosa -import os -import math -import numpy as np -from tools.get_bsrnnvae import get_bsrnnvae -import tools.torch_tools as torch_tools - -class Tango: - def __init__(self, \ - device="cuda:0"): - - self.sample_rate = 44100 - self.device = device - - self.vae = get_bsrnnvae() - self.vae = self.vae.eval().to(device) - - def sound2sound_generate_longterm(self, fname, batch_size=1, duration=15.36, steps=200, disable_progress=False): - """ Generate audio without condition. """ - num_frames = math.ceil(duration * 100. 
/ 8) - with torch.no_grad(): - orig_samples, fs = torchaudio.load(fname) - if(fs!=44100): - orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100) - fs = 44100 - if(orig_samples.shape[-1]<int(duration*44100)): - orig_samples = torch.cat([orig_samples, torch.zeros([orig_samples.shape[0], int(duration*44100)-orig_samples.shape[-1]], dtype=orig_samples.dtype)], -1) - init_audio, _ = librosa.load(fname, sr=self.sample_rate, mono=False) - if(init_audio.ndim>1):init_audio = init_audio[0] - init_audio = torch.from_numpy(init_audio)[None,None,:].to(self.device) - init_audio = init_audio[:,:,int(0*self.sample_rate):int(10.24*3*self.sample_rate)] - if(init_audio.shape[-1]<int(10.24*3*self.sample_rate)): - init_audio = torch.cat([init_audio, torch.zeros([1, 1, int(10.24*3*self.sample_rate)-init_audio.shape[-1]], device=self.device)], -1) - init_audio, _ = librosa.load(fname, sr=self.sample_rate, mono=False) - if(init_audio.ndim>1):init_audio = init_audio[0] - init_audio = torch.from_numpy(init_audio)[None,None,:].to(self.device) - init_audio = init_audio[:,:,0:int(10.24*2*self.sample_rate)] - if(init_audio.shape[-1]<int(10.24*2*self.sample_rate)): - init_audio = torch.cat([init_audio, torch.zeros([1, 1, int(10.24*2*self.sample_rate)-init_audio.shape[-1]], device=self.device)], -1) [... rest of infer_bsrnnvae441k.py and the next deleted file's diff header lost ...] -def window_sumsquare( - window, - n_frames, - hop_length, - win_length, - n_fft, - dtype=np.float32, - norm=None, -): - """ - # from librosa 0.6 - Compute the sum-square envelope of a window function at a given hop length. - - This is used to estimate modulation effects induced by windowing - observations in short-time fourier transforms. - - Parameters - ---------- - window : string, tuple, number, callable, or list-like - Window specification, as in `get_window` - - n_frames : int > 0 - The number of analysis frames - - hop_length : int > 0 - The number of samples to advance between frames - - win_length : [optional] - The length of the window function. By default, this matches `n_fft`. - - n_fft : int > 0 - The length of each analysis frame. - - dtype : np.dtype - The data type of the output - - Returns - ------- - wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` - The sum-squared envelope of the window function - """ - if win_length is None: - win_length = n_fft - - n = n_fft + hop_length * (n_frames - 1) - x = np.zeros(n, dtype=dtype) - - # Compute the squared window at the desired length - win_sq = get_window(window, win_length, fftbins=True) - win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 - win_sq = librosa_util.pad_center(win_sq, n_fft) - - # Fill the envelope - for i in range(n_frames): - sample = i * hop_length - x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] - return x - - -def griffin_lim(magnitudes, stft_fn, n_iters=30): - """ - PARAMS - ------ - magnitudes: spectrogram magnitudes - stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods - """ - - angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) - angles = angles.astype(np.float32) - angles = torch.autograd.Variable(torch.from_numpy(angles)) - signal = stft_fn.inverse(magnitudes, angles).squeeze(1) - - for i in range(n_iters): - _, angles = stft_fn.transform(signal) - signal = stft_fn.inverse(magnitudes, angles).squeeze(1) - return signal - - -def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return normalize_fun(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression(x, C=1): - """ - PARAMS - ------ - C: compression factor used to compress - """ - return torch.exp(x) / C - - -class STFT(torch.nn.Module): - """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" - - def __init__(self, filter_length, hop_length, win_length, window="hann"): - super(STFT, self).__init__() - self.filter_length = filter_length - self.hop_length = hop_length - self.win_length = win_length - self.window = window - self.forward_transform = None - scale = self.filter_length / self.hop_length - fourier_basis = np.fft.fft(np.eye(self.filter_length)) - - cutoff = int((self.filter_length / 2 + 1)) - fourier_basis = np.vstack( - [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] - ) - - forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) - inverse_basis = torch.FloatTensor( - np.linalg.pinv(scale * fourier_basis).T[:, None, :] - ) - - if window is not None: - assert filter_length >= win_length - # get window and zero center pad it to filter_length - fft_window = get_window(window, win_length, fftbins=True) - fft_window = pad_center(fft_window, size=filter_length) - fft_window = 
torch.from_numpy(fft_window).float() - - # window the bases - forward_basis *= fft_window - inverse_basis *= fft_window - - self.register_buffer("forward_basis", forward_basis.float()) - self.register_buffer("inverse_basis", inverse_basis.float()) - - def transform(self, input_data): - num_batches = input_data.size(0) - num_samples = input_data.size(1) - - self.num_samples = num_samples - - # similar to librosa, reflect-pad the input - input_data = input_data.view(num_batches, 1, num_samples) - input_data = F.pad( - input_data.unsqueeze(1), - (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), - mode="reflect", - ) - input_data = input_data.squeeze(1) - - forward_transform = F.conv1d( - input_data, - torch.autograd.Variable(self.forward_basis, requires_grad=False), - stride=self.hop_length, - padding=0, - ) - - cutoff = int((self.filter_length / 2) + 1) - real_part = forward_transform[:, :cutoff, :] - imag_part = forward_transform[:, cutoff:, :] - - magnitude = torch.sqrt(real_part**2 + imag_part**2) - phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) - - return magnitude, phase - - def inverse(self, magnitude, phase): - recombine_magnitude_phase = torch.cat( - [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 - ) - - inverse_transform = F.conv_transpose1d( - recombine_magnitude_phase, - torch.autograd.Variable(self.inverse_basis, requires_grad=False), - stride=self.hop_length, - padding=0, - ) - - if self.window is not None: - window_sum = window_sumsquare( - self.window, - magnitude.size(-1), - hop_length=self.hop_length, - win_length=self.win_length, - n_fft=self.filter_length, - dtype=np.float32, - ) - # remove modulation effects - approx_nonzero_indices = torch.from_numpy( - np.where(window_sum > tiny(window_sum))[0] - ) - window_sum = torch.autograd.Variable( - torch.from_numpy(window_sum), requires_grad=False - ) - window_sum = window_sum - inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ - approx_nonzero_indices - ] - - # scale by hop ratio - inverse_transform *= float(self.filter_length) / self.hop_length - - inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] - inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] - - return inverse_transform - - def forward(self, input_data): - self.magnitude, self.phase = self.transform(input_data) - reconstruction = self.inverse(self.magnitude, self.phase) - return reconstruction - - -class TacotronSTFT(torch.nn.Module): - def __init__( - self, - filter_length=1024, - hop_length=160, - win_length=1024, - n_mel_channels=64, - sampling_rate=16000, - mel_fmin=0, - mel_fmax=8000., - ): - super(TacotronSTFT, self).__init__() - self.n_mel_channels = n_mel_channels - self.sampling_rate = sampling_rate - self.stft_fn = STFT(filter_length, hop_length, win_length) - mel_basis = librosa_mel_fn( - sr = sampling_rate, n_fft = filter_length, n_mels = n_mel_channels, fmin = mel_fmin, fmax = mel_fmax - ) - mel_basis = torch.from_numpy(mel_basis).float() - self.register_buffer("mel_basis", mel_basis) - - def spectral_normalize(self, magnitudes, normalize_fun): - output = dynamic_range_compression(magnitudes, normalize_fun) - return output - - def spectral_de_normalize(self, magnitudes): - output = dynamic_range_decompression(magnitudes) - return output - - def mel_spectrogram(self, y, normalize_fun=torch.log): - """Computes mel-spectrograms from a batch of waves - PARAMS - ------ - y: Variable(torch.FloatTensor) with shape (B, T) in range 
[-1, 1] - - RETURNS - ------- - mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) - """ - assert torch.min(y.data) >= -1, torch.min(y.data) - assert torch.max(y.data) <= 1, torch.max(y.data) - - magnitudes, phases = self.stft_fn.transform(y) - magnitudes = magnitudes.data - mel_output = torch.matmul(self.mel_basis, magnitudes) - mel_output = self.spectral_normalize(mel_output, normalize_fun) - energy = torch.norm(magnitudes, dim=1) - - log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) - - return mel_output, log_magnitudes, energy diff --git a/codeclm/tokenizer/Flow1dVAE/tools/torch_tools.py b/codeclm/tokenizer/Flow1dVAE/tools/torch_tools.py old mode 100644 new mode 100755 diff --git a/codeclm/tokenizer/Flow1dVAE/tools/transmodelnorm.py b/codeclm/tokenizer/Flow1dVAE/tools/transmodelnorm.py deleted file mode 100644 index 0c6f47d3042e5d37da4cd2e6a904ade2b61a06db..0000000000000000000000000000000000000000 --- a/codeclm/tokenizer/Flow1dVAE/tools/transmodelnorm.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch - -if __name__=="__main__": - src_ckpt = 'saved/train_mulan_v3_48k_everything3/latest/pytorch_model_2.bin' - tgt_ckpt = 'saved/train_mulan_v3_48k_everything3_sepnorm/src_pytorch_model_2.bin' - # src_ckpt = 'saved/train_enhcodec2D_again/latest/pytorch_model_3.bin' - # tgt_ckpt = 'saved/train_enhcodec2D_again_sepnorm/pytorch_model_3.bin' - - ckpt = torch.load(src_ckpt, map_location='cpu') - - ckpt['normfeat.sum_x'] = torch.ones(16, 32, dtype=ckpt['normfeat.sum_x'].dtype) * ckpt['normfeat.sum_x'] / ckpt['normfeat.counts'] - ckpt['normfeat.sum_x2'] = torch.ones(16, 32, dtype=ckpt['normfeat.sum_x2'].dtype) * ckpt['normfeat.sum_x2'] / ckpt['normfeat.counts'] - ckpt['normfeat.sum_target_x2'] = torch.ones(16, 32, dtype=ckpt['normfeat.sum_target_x2'].dtype) * ckpt['normfeat.sum_target_x2'] / ckpt['normfeat.counts'] - ckpt['normfeat.counts'] = torch.ones_like(ckpt['normfeat.counts']) - torch.save(ckpt, tgt_ckpt) - \ No newline at end of file diff --git a/generate.py b/generate.py index 81b6e5685eb702c27e2cccb2fac3306c3973e75e..d55ba978d45e1dcedd5af38863dbce8de045b944 100644 --- a/generate.py +++ b/generate.py @@ -14,10 +14,24 @@ import gc from codeclm.trainer.codec_song_pl import CodecLM_PL from codeclm.models import CodecLM from third_party.demucs.models.pretrained import get_model_from_yaml - +import re auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto'] +def check_language_by_text(text): + chinese_pattern = re.compile(r'[\u4e00-\u9fff]') + chinese_count = len(re.findall(chinese_pattern, text)) + # treat the lyric as Chinese when at least 20% of its characters are CJK + # ideographs; otherwise default to English (also guards against empty text) + if len(text) == 0: + return "en" + if chinese_count / len(text) >= 0.2: + return "zh" + return "en" + class Separator: def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None: if torch.cuda.is_available() and gpu_id < torch.cuda.device_count(): @@ -80,7 +94,8 @@ def parse_args(): help='Whether to use low memory mode (default: False)') return parser.parse_args() -def generate(args): +def generate(args, version = 'v1.0'): + torch.set_num_threads(1) ckpt_path = args.ckpt_path input_jsonl = args.input_jsonl save_dir = args.save_dir @@ -95,10 +110,9 @@ def generate(args): separator = Separator() - auto_prompt 
= torch.load('ckpt/prompt.pt') + auto_prompt = torch.load('tools/new_auto_prompt.pt') audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg) audio_tokenizer = audio_tokenizer.eval().cuda() - merge_prompt = [item for sublist in auto_prompt.values() for item in sublist] with open(input_jsonl, "r") as fp: lines = fp.readlines() @@ -145,8 +159,9 @@ def generate(args): melody_is_wav = False elif "auto_prompt_audio_type" in item: assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found" - if item["auto_prompt_audio_type"] == "Auto": - prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))] + if item['auto_prompt_audio_type'] == 'Auto': + lang = check_language_by_text(item['gt_lyric']) + prompt_token = auto_prompt['Auto'][lang][np.random.randint(0, len(auto_prompt['Auto'][lang]))] else: prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))] pmt_wav = prompt_token[:,[0],:] @@ -168,7 +183,7 @@ def generate(args): del audio_tokenizer del separator - + torch.cuda.empty_cache() if "audio_tokenizer_checkpoint_sep" in cfg.keys(): @@ -187,7 +202,7 @@ def generate(args): item['bgm_wav'] = bgm_wav torch.cuda.empty_cache() - audiolm = builders.get_lm_model(cfg) + audiolm = builders.get_lm_model(cfg, version=version) checkpoint = torch.load(ckpt_path, map_location='cpu') audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')} audiolm.load_state_dict(audiolm_state_dict, strict=False) @@ -216,7 +231,11 @@ def generate(args): for item in new_items: lyric = item["gt_lyric"] - descriptions = item["descriptions"] if "descriptions" in item else None + if version == 'v1.0': + descriptions = item["descriptions"] if "descriptions" in item else None + else: + descriptions = item["descriptions"] if "descriptions" in item else '.' 
+ descriptions = '[Musicality-very-high]' + ', ' + descriptions pmt_wav = item['pmt_wav'] vocal_wav = item['vocal_wav'] bgm_wav = item['bgm_wav'] @@ -280,6 +299,7 @@ def generate(args): fw.writelines(json.dumps(item, ensure_ascii=False)+"\n") def generate_lowmem(args): + torch.set_num_threads(1) ckpt_path = args.ckpt_path input_jsonl = args.input_jsonl save_dir = args.save_dir @@ -304,8 +324,7 @@ def generate_lowmem(args): separator = Separator() audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg) audio_tokenizer = audio_tokenizer.eval().cuda() - auto_prompt = torch.load('ckpt/prompt.pt') - merge_prompt = [item for sublist in auto_prompt.values() for item in sublist] + auto_prompt = torch.load('tools/new_auto_prompt.pt') new_items = [] for line in lines: item = json.loads(line) @@ -345,10 +364,7 @@ def generate_lowmem(args): melody_is_wav = False elif "auto_prompt_audio_type" in item: assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found" - if item["auto_prompt_audio_type"] == "Auto": - prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))] - else: - prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))] + if item['auto_prompt_audio_type'] == 'Auto': + lang = check_language_by_text(item['gt_lyric']) + prompt_token = auto_prompt['Auto'][lang][np.random.randint(0, len(auto_prompt['Auto'][lang]))] + else: + prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))] pmt_wav = prompt_token[:,[0],:] vocal_wav = prompt_token[:,[1],:] bgm_wav = prompt_token[:,[2],:] @@ -471,7 +487,8 @@ def generate_lowmem(args): seperate_tokenizer.model.model.device = torch.device(device) seperate_tokenizer = seperate_tokenizer.eval() - offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False + # offload_wav_tokenizer_diffusion = True if 'offload' in cfg.keys() and 'wav_tokenizer_diffusion' in cfg.offload else False + offload_wav_tokenizer_diffusion = False if offload_wav_tokenizer_diffusion: sep_offload_param = OffloadParamParse.parse_config(seperate_tokenizer, cfg.offload.wav_tokenizer_diffusion) sep_offload_param.show() @@ -548,9 +565,9 @@ if __name__ == "__main__": res_mem = (total - reserved) / 1024 / 1024 / 1024 print(f"reserved memory: {res_mem}GB") - model_name = args.ckpt_path.split("/")[-1] - assert model_name in ['songgeneration_base'], f'{model_name} is not supported, currently only songgeneration_base is supported' - if model_name == 'songgeneration_base': + model_name = args.ckpt_path.split("/")[-1].lower().replace('-', '_') + assert model_name in ['songgeneration_base', 'songgeneration_base_new', 'songgeneration_base_full', 'songgeneration_large', 'songgeneration_new_small', 'songgeneration_new_large', 'songgeneration_new_medium'], f'{model_name} is not supported; supported checkpoints are songgeneration_base, songgeneration_base_new, songgeneration_base_full, songgeneration_large, songgeneration_new_small, songgeneration_new_large and songgeneration_new_medium. Please download correct files and rename the folder to the corresponding version name.' 
+ if model_name == 'songgeneration_base' or model_name == 'songgeneration_base_new' or model_name == 'songgeneration_base_full': if res_mem > 24 and not args.low_mem: print("use generate") generate(args) @@ -558,8 +575,20 @@ if __name__ == "__main__": from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse print("use generate_lowmem") generate_lowmem(args) + elif model_name == 'songgeneration_large': + if res_mem > 36 and not args.low_mem: + print("use generate") + generate(args) + else: + print("use generate_lowmem") + from codeclm.utils.offload_profiler import OffloadProfiler, OffloadParamParse + generate_lowmem(args) + elif model_name == 'songgeneration_new_small' or model_name == 'songgeneration_new_large' or model_name == 'songgeneration_new_medium': + print("use generate") + generate(args, version = 'v1.5') + else: print("CUDA is not available") exit() - \ No newline at end of file + diff --git a/levo_inference.py b/levo_inference.py index 28471d6c16bd6966a70679324dba7033c5a939b8..62a5120c61fa5f956f241919ded02960c7cc5a22 100644 --- a/levo_inference.py +++ b/levo_inference.py @@ -12,7 +12,7 @@ import json import numpy as np from omegaconf import OmegaConf -from codeclm.trainer.codec_song_pl import CodecLM_PL +from codeclm.models import builders from codeclm.models import CodecLM from separator import Separator @@ -36,20 +36,24 @@ class LeVoInference(torch.nn.Module): self.max_duration = self.cfg.max_dur # Define model or load pretrained model - model_light = CodecLM_PL(self.cfg, pt_path) + audiolm = builders.get_lm_model(self.cfg, version='v1.5') + checkpoint = torch.load(pt_path, map_location='cpu') + audiolm_state_dict = {k.replace('audiolm.', ''): v for k, v in checkpoint.items() if k.startswith('audiolm')} + audiolm.load_state_dict(audiolm_state_dict, strict=False) + audiolm = audiolm.eval() + audiolm = audiolm.cuda().to(torch.float16) - model_light = model_light.eval().cuda() - model_light.audiolm.cfg = self.cfg + audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg) + audio_tokenizer = audio_tokenizer.eval() - self.model_lm = model_light.audiolm - self.model_audio_tokenizer = model_light.audio_tokenizer - self.model_seperate_tokenizer = model_light.seperate_tokenizer + seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg) + seperate_tokenizer = seperate_tokenizer.eval() self.model = CodecLM(name = "tmp", - lm = self.model_lm, - audiotokenizer = self.model_audio_tokenizer, + lm = audiolm, + audiotokenizer = audio_tokenizer, max_duration = self.max_duration, - seperate_tokenizer = self.model_seperate_tokenizer, + seperate_tokenizer = seperate_tokenizer, ) self.separator = Separator()
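
A note on the auto-prompt flow this patch introduces: when auto_prompt_audio_type is 'Auto', generate() now detects the lyric language and draws a reference prompt from a language-specific list instead of the old merge_prompt pool. The sketch below is a minimal, self-contained illustration of that selection logic; the layout of the prompt bank (genre -> list of (1, 3, T) token tensors, with 'Auto' further split into 'zh'/'en' lists) and all tensor shapes are assumptions inferred from this diff, not the actual contents of tools/new_auto_prompt.pt.

import re

import numpy as np
import torch

def check_language_by_text(text):
    # >= 20% CJK ideographs -> treat the lyric as Chinese, else default to English
    chinese_count = len(re.findall(r'[\u4e00-\u9fff]', text))
    if len(text) > 0 and chinese_count / len(text) >= 0.2:
        return "zh"
    return "en"

# Hypothetical stand-in for torch.load('tools/new_auto_prompt.pt'):
# genre -> list of (1, 3, T) prompt-token tensors; 'Auto' is split by language.
auto_prompt = {
    "Pop": [torch.zeros(1, 3, 8, dtype=torch.long) for _ in range(4)],
    "Auto": {
        "zh": [torch.zeros(1, 3, 8, dtype=torch.long) for _ in range(4)],
        "en": [torch.ones(1, 3, 8, dtype=torch.long) for _ in range(4)],
    },
}

lyric = "[verse] 夜色里的旋律慢慢升起"
genre = "Auto"
if genre == "Auto":
    lang = check_language_by_text(lyric)
    bank = auto_prompt["Auto"][lang]
else:
    bank = auto_prompt[genre]
prompt_token = bank[np.random.randint(0, len(bank))]

# split the three token streams the same way generate() does
pmt_wav = prompt_token[:, [0], :]
vocal_wav = prompt_token[:, [1], :]
bgm_wav = prompt_token[:, [2], :]
print(lang, pmt_wav.shape)  # zh torch.Size([1, 1, 8])

Keeping separate 'zh' and 'en' lists avoids the failure mode of the old merge_prompt draw, where a Chinese lyric could be paired with a reference prompt from any genre or language; on the v1.5 path, generate() additionally prepends '[Musicality-very-high]' to the description string before sampling.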