# Ke Chen
# knutchen@ucsd.edu
# Zero-shot Audio Source Separation via Query-based Learning from Weakly-labeled Data
# The Main Script
import os
# limit the thread count so that the SDR calculation does not occupy all CPUs
# (these variables must be set before numpy/torch are imported to take effect)
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "6"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "6"
import sys
import librosa
import numpy as np
import argparse
import logging
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from utils import collect_fn, dump_config, create_folder, prepprocess_audio
import musdb
from models.asp_model import ZeroShotASP, SeparatorModel, AutoTaggingWarpper, WhitingWarpper
from data_processor import LGSPDataset, MusdbDataset
import config
import htsat_config
from models.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from htsat_utils import process_idc
import warnings
warnings.filterwarnings("ignore")
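
# Usage overview (the script name "main.py" is assumed; each mode below maps to a
# subcommand defined in main(), and assumes config.py / htsat_config.py are filled in):
#   python main.py train           # train the ZeroShotASP separator
#   python main.py test            # evaluate on the preprocessed musdb test set
#   python main.py musdb_process   # resample musdb to 32000 Hz and save as npy
#   python main.py save_idc        # build index caches for the training data
#   python main.py weight_average  # average all checkpoints in a folder
#   python main.py inference       # separate one track given a query folder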

class data_prep(pl.LightningDataModule):
    def __init__(self, train_dataset, eval_dataset, device_num, config):
        super().__init__()
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.device_num = device_num
        self.config = config

    def train_dataloader(self):
        train_sampler = DistributedSampler(self.train_dataset, shuffle=False) if self.device_num > 1 else None
        train_loader = DataLoader(
            dataset=self.train_dataset,
            num_workers=self.config.num_workers,
            batch_size=self.config.batch_size // self.device_num,
            shuffle=False,
            sampler=train_sampler,
            collate_fn=collect_fn
        )
        return train_loader

    def val_dataloader(self):
        eval_sampler = DistributedSampler(self.eval_dataset, shuffle=False) if self.device_num > 1 else None
        eval_loader = DataLoader(
            dataset=self.eval_dataset,
            num_workers=self.config.num_workers,
            batch_size=self.config.batch_size // self.device_num,
            shuffle=False,
            sampler=eval_sampler,
            collate_fn=collect_fn
        )
        return eval_loader

    def test_dataloader(self):
        test_sampler = DistributedSampler(self.eval_dataset, shuffle=False) if self.device_num > 1 else None
        test_loader = DataLoader(
            dataset=self.eval_dataset,
            num_workers=self.config.num_workers,
            batch_size=self.config.batch_size // self.device_num,
            shuffle=False,
            sampler=test_sampler,
            collate_fn=collect_fn
        )
        return test_loader
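
# note: the dataloaders above keep shuffle=False and build their own DistributedSampler;
# the trainer in train() below sets replace_sampler_ddp=False, so these samplers are
# used as-is under DDP instead of being replaced by Lightning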

def save_idc():
    train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
    eval_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", "eval.h5")
    process_idc(train_index_path, config.classes_num, config.index_type + "_idc.npy")
    process_idc(eval_index_path, config.classes_num, "eval_idc.npy")

# resample the musdb tracks from the original 44100 Hz to 32000 Hz
def process_musdb():
    # use musdb as the test set
    test_data = musdb.DB(
        root=config.musdb_path,
        download=False,
        subsets="test",
        is_wav=True
    )
    print(len(test_data.tracks))
    mus_tracks = []
    # in musdb, the sample rate of every track is the same (44100 Hz)
    orig_fs = test_data.tracks[0].rate
    print(orig_fs)
    for track in test_data.tracks:
        temp = {}
        mixture = prepprocess_audio(
            track.audio,
            orig_fs, config.sample_rate,
            config.test_type
        )
        temp["mixture"] = mixture
        for dickey in config.test_key:
            source = prepprocess_audio(
                track.targets[dickey].audio,
                orig_fs, config.sample_rate,
                config.test_type
            )
            temp[dickey] = source
        print(track.audio.shape, len(temp.keys()), temp["mixture"].shape)
        mus_tracks.append(temp)
    print(len(mus_tracks))
    # save the tracks to an npy file
    np.save("musdb-32000fs.npy", mus_tracks)
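
# the saved file can be reloaded with np.load("musdb-32000fs.npy", allow_pickle=True),
# which is how test() below reads config.testset_path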

# weight averaging is performed over the given folder:
# it outputs one model checkpoint that averages the weights of all models in the folder
def weight_average():
    model_ckpt = []
    model_files = os.listdir(config.wa_model_folder)
    wa_ckpt = {
        "state_dict": {}
    }
    for model_file in model_files:
        model_file = os.path.join(config.wa_model_folder, model_file)
        model_ckpt.append(torch.load(model_file, map_location="cpu")["state_dict"])
    keys = model_ckpt[0].keys()
    for key in keys:
        model_ckpt_key = torch.cat([d[key].float().unsqueeze(0) for d in model_ckpt])
        model_ckpt_key = torch.mean(model_ckpt_key, dim=0)
        assert model_ckpt_key.shape == model_ckpt[0][key].shape, \
            "the shape is unmatched " + str(model_ckpt_key.shape) + " " + str(model_ckpt[0][key].shape)
        wa_ckpt["state_dict"][key] = model_ckpt_key
    torch.save(wa_ckpt, config.wa_model_path)
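
# the averaged checkpoint is written to config.wa_model_path; it can then be pointed to
# by config.resume_checkpoint when running test() or inference()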

# use the model to quickly separate a track given a query
# it requires four variables in config.py:
#   inference_file: the track you want to separate
#   inference_query: a **folder** containing samples that all come from the same source
#   test_key: ["name"], the source name (only used to label the final output)
#   wave_output_path: the output folder
# make sure the query folder contains only samples from the same source:
# on each run, the model separates one source from the track;
# to separate multiple sources, change the query folder or write a script to do so
# (a hypothetical config example follows)
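# for example (hypothetical paths, adjust to your setup):
#   inference_file = "examples/mixture.wav"
#   inference_query = "examples/query_vocal/"   # a folder of .wav query samples
#   test_key = ["vocal"]
#   wave_output_path = "outputs/"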
def inference():
    # set exp settings
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    assert config.test_key is not None, "a separation target key should be given in config.test_key"
    create_folder(config.wave_output_path)
    test_track, fs = librosa.load(config.inference_file, sr=None)
    test_track = test_track[:, None]
    print(test_track.shape)
    print(fs)
    # convert the track into the 32000 Hz sample rate
    test_track = prepprocess_audio(
        test_track,
        fs, config.sample_rate,
        config.test_type
    )
    test_tracks = []
    temp = [test_track]
    for dickey in config.test_key:
        temp.append(test_track)
    temp = np.array(temp)
    test_tracks.append(temp)
    dataset = MusdbDataset(tracks=test_tracks)  # the input layout matches the musdb data, so MusdbDataset is reused
    loader = DataLoader(
        dataset=dataset,
        num_workers=1,
        batch_size=1,
        shuffle=False
    )
    # obtain the samples for query
    queries = []
    for query_file in os.listdir(config.inference_query):
        f_path = os.path.join(config.inference_query, query_file)
        if query_file.endswith(".wav"):
            temp_q, fs = librosa.load(f_path, sr=None)
            temp_q = temp_q[:, None]
            temp_q = prepprocess_audio(
                temp_q,
                fs, config.sample_rate,
                config.test_type
            )
            temp = [temp_q]
            for dickey in config.test_key:
                temp.append(temp_q)
            temp = np.array(temp)
            queries.append(temp)
    assert config.resume_checkpoint is not None, "there should be a saved model when inferring"
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model=sed_model,
        config=htsat_config,
        dataset=None
    )
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])
    trainer = pl.Trainer(
        gpus=1
    )
    avg_at = None
    # obtain the latent embedding as query
    if config.infer_type == "mean":
        avg_dataset = MusdbDataset(tracks=queries)
        avg_loader = DataLoader(
            dataset=avg_dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False
        )
        at_wrapper = AutoTaggingWarpper(
            at_model=at_model,
            config=config,
            target_keys=config.test_key
        )
        trainer.test(at_wrapper, test_dataloaders=avg_loader)
        avg_at = at_wrapper.avg_at
    # load the separation model
    model = ZeroShotASP(
        channels=1, config=config,
        at_model=at_model,
        dataset=dataset
    )
    # resume checkpoint
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"], strict=False)
    exp_model = SeparatorModel(
        model=model,
        config=config,
        target_keys=config.test_key,
        avg_at=avg_at,
        using_wiener=False,
        calc_sdr=False,
        output_wav=True
    )
    trainer.test(exp_model, test_dataloaders=loader)
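
# with output_wav=True, the separated waveform is written under config.wave_output_path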

# test the separation model, mainly on musdb
def test():
    # set exp settings
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    assert config.test_key is not None, "a separation target key should be given in config.test_key"
    create_folder(config.wave_output_path)
    # use musdb as the test set
    test_data = np.load(config.testset_path, allow_pickle=True)
    print(len(test_data))
    mus_tracks = []
    # in musdb, the sample rate of every track is the same (44100 Hz)
    # load the dataset
    for track in test_data:
        temp = []
        mixture = track["mixture"]
        temp.append(mixture)
        for dickey in config.test_key:
            source = track[dickey]
            temp.append(source)
        temp = np.array(temp)
        print(temp.shape)
        mus_tracks.append(temp)
    print(len(mus_tracks))
    dataset = MusdbDataset(tracks=mus_tracks)
    loader = DataLoader(
        dataset=dataset,
        num_workers=1,
        batch_size=1,
        shuffle=False
    )
    assert config.resume_checkpoint is not None, "there should be a saved model when inferring"
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model=sed_model,
        config=htsat_config,
        dataset=None
    )
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])
    trainer = pl.Trainer(
        gpus=1
    )
    avg_at = None
    # obtain the query of the four stems from the training set
    if config.infer_type == "mean":
        avg_data = np.load(config.testavg_path, allow_pickle=True)[:90]
        print(len(avg_data))
        avgmus_tracks = []
        # in musdb, the sample rate of every track is the same (44100 Hz)
        # load the dataset
        for track in avg_data:
            temp = []
            mixture = track["mixture"]
            temp.append(mixture)
            for dickey in config.test_key:
                source = track[dickey]
                temp.append(source)
            temp = np.array(temp)
            print(temp.shape)
            avgmus_tracks.append(temp)
        print(len(avgmus_tracks))
        avg_dataset = MusdbDataset(tracks=avgmus_tracks)
        avg_loader = DataLoader(
            dataset=avg_dataset,
            num_workers=1,
            batch_size=1,
            shuffle=False
        )
        at_wrapper = AutoTaggingWarpper(
            at_model=at_model,
            config=config,
            target_keys=config.test_key
        )
        trainer.test(at_wrapper, test_dataloaders=avg_loader)
        avg_at = at_wrapper.avg_at
    model = ZeroShotASP(
        channels=1, config=config,
        at_model=at_model,
        dataset=dataset
    )
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    model.load_state_dict(ckpt["state_dict"], strict=False)
    exp_model = SeparatorModel(
        model=model,
        config=config,
        target_keys=config.test_key,
        avg_at=avg_at,
        using_wiener=config.using_wiener
    )
    trainer.test(exp_model, test_dataloaders=loader)

def train():
    # set exp settings
    # device_name = "cuda" if torch.cuda.is_available() else "cpu"
    # device = torch.device("cuda")
    device_num = torch.cuda.device_count()
    print("each batch size:", config.batch_size // device_num)
    train_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", config.index_type + ".h5")
    train_idc = np.load(os.path.join(config.idc_path, config.index_type + "_idc.npy"), allow_pickle=True)
    eval_index_path = os.path.join(config.dataset_path, "hdf5s", "indexes", "eval.h5")
    eval_idc = np.load(os.path.join(config.idc_path, "eval_idc.npy"), allow_pickle=True)
    # set exp folder
    exp_dir = os.path.join(config.workspace, "results", config.exp_name)
    checkpoint_dir = os.path.join(config.workspace, "results", config.exp_name, "checkpoint")
    if not config.debug:
        create_folder(os.path.join(config.workspace, "results"))
        create_folder(exp_dir)
        create_folder(checkpoint_dir)
        dump_config(config, os.path.join(exp_dir, config.exp_name), False)
    # load data with LGSPDataset (the latent general source separation dataset)
    dataset = LGSPDataset(
        index_path=train_index_path,
        idc=train_idc,
        config=config,
        factor=0.05,
        eval_mode=False
    )
    eval_dataset = LGSPDataset(
        index_path=eval_index_path,
        idc=eval_idc,
        config=config,
        factor=0.05,
        eval_mode=True
    )
    audioset_data = data_prep(train_dataset=dataset, eval_dataset=eval_dataset, device_num=device_num, config=config)
    checkpoint_callback = ModelCheckpoint(
        monitor="mixture_sdr",
        filename='l-{epoch:d}-{mixture_sdr:.3f}-{clean_sdr:.3f}-{silence_sdr:.3f}',
        save_top_k=10,
        mode="max"
    )
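    # the top-10 checkpoints (ranked by mixture_sdr) saved here can later be combined
    # with the weight_average mode above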
    # load the audio tagging (AT) model used for inference
    sed_model = HTSAT_Swin_Transformer(
        spec_size=htsat_config.htsat_spec_size,
        patch_size=htsat_config.htsat_patch_size,
        in_chans=1,
        num_classes=htsat_config.classes_num,
        window_size=htsat_config.htsat_window_size,
        config=htsat_config,
        depths=htsat_config.htsat_depth,
        embed_dim=htsat_config.htsat_dim,
        patch_stride=htsat_config.htsat_stride,
        num_heads=htsat_config.htsat_num_head
    )
    at_model = SEDWrapper(
        sed_model=sed_model,
        config=htsat_config,
        dataset=None
    )
    # load the checkpoint
    ckpt = torch.load(htsat_config.resume_checkpoint, map_location="cpu")
    at_model.load_state_dict(ckpt["state_dict"])
    trainer = pl.Trainer(
        deterministic=True,
        default_root_dir=checkpoint_dir,
        gpus=device_num,
        val_check_interval=0.2,
        # check_val_every_n_epoch = 1,
        max_epochs=config.max_epoch,
        auto_lr_find=True,
        sync_batchnorm=True,
        callbacks=[checkpoint_callback],
        accelerator="ddp" if device_num > 1 else None,
        resume_from_checkpoint=None,  # config.resume_checkpoint,
        replace_sampler_ddp=False,
        gradient_clip_val=1.0,
        num_sanity_val_steps=0,
    )
    model = ZeroShotASP(
        channels=1, config=config,
        at_model=at_model,
        dataset=dataset
    )
    if config.resume_checkpoint is not None:
        ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])
    # trainer.test(model, datamodule = audioset_data)
    trainer.fit(model, audioset_data)

def main():
    parser = argparse.ArgumentParser(description="latent general source separation parser")
    subparsers = parser.add_subparsers(dest="mode")
    parser_train = subparsers.add_parser("train")
    parser_test = subparsers.add_parser("test")
    parser_musdb = subparsers.add_parser("musdb_process")
    parser_saveidc = subparsers.add_parser("save_idc")
    parser_wa = subparsers.add_parser("weight_average")
    parser_infer = subparsers.add_parser("inference")
    args = parser.parse_args()
    # default settings
    logging.basicConfig(level=logging.INFO)
    pl.utilities.seed.seed_everything(seed=config.random_seed)
    if args.mode == "train":
        train()
    elif args.mode == "test":
        test()
    elif args.mode == "musdb_process":
        process_musdb()
    elif args.mode == "weight_average":
        weight_average()
    elif args.mode == "save_idc":
        save_idc()
    elif args.mode == "inference":
        inference()
    else:
        raise Exception("Error Mode!")

if __name__ == '__main__':
    main()