| |
| __author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' |
|
|
|
|
| import os |
| import random |
| import numpy as np |
| import torch |
| import soundfile as sf |
| import pickle |
| import time |
| import itertools |
| import multiprocessing |
| from tqdm.auto import tqdm |
| from glob import glob |
| import audiomentations as AU |
| import pedalboard as PB |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
|
|
def load_chunk(path, length, chunk_size, offset=None):
    """Read ``chunk_size`` frames from an audio file, zero-padding short reads.

    Args:
        path: Audio file readable by ``soundfile``.
        length: Track length in frames as recorded in the metadata. It only
            decides whether a window is cut out of the file or the whole file
            is read; it is not trusted as the real file length.
        chunk_size: Number of frames the returned chunk must contain.
        offset: First frame of the window. When ``None`` and the track is long
            enough, a random offset is drawn with ``np.random.randint``.
            Ignored when ``chunk_size > length`` (the whole file is read).

    Returns:
        float32 array of shape ``(channels, chunk_size)``; mono input is
        returned with a single channel row.
    """
    if chunk_size <= length:
        if offset is None:
            offset = np.random.randint(length - chunk_size + 1)
        x = sf.read(path, dtype='float32', start=offset, frames=chunk_size)[0]
    else:
        # The file can hold more frames than `length` says (e.g. `length` is the
        # shortest stem of a track), so cut it down before padding.
        x = sf.read(path, dtype='float32')[0][:chunk_size]

    # Pad by what was actually read: a file shorter than the metadata claims
    # would otherwise produce a chunk of the wrong size.
    missing = chunk_size - x.shape[0]
    if missing > 0:
        # Same dtype as the audio so the result is not upcast to float64.
        pad = np.zeros((missing,) + x.shape[1:], dtype=x.dtype)
        x = np.concatenate([x, pad], axis=0)

    if x.ndim == 1:
        # Mono: add the channel axis so the transpose yields (1, chunk_size).
        x = np.expand_dims(x, axis=1)
    return x.T
|
|
|
|
def get_track_set_length(params):
    """Measure the stems of one track folder and return the shortest length.

    Args:
        params: Tuple ``(path, instruments, file_types)``. For every instrument
            the file ``<path>/<instrument>.<extension>`` is looked up, trying
            the extensions of ``file_types`` in order and using the first hit.

    Returns:
        Tuple ``(path, min_length)`` where ``min_length`` is the smallest frame
        count among the stems that were found.

    Raises:
        ValueError: If none of the instruments has a file in ``path``.
    """
    path, instruments, file_types = params

    lengths_arr = []
    for instr in instruments:
        candidates = (path + '/{}.{}'.format(instr, extension) for extension in file_types)
        stem_path = next((c for c in candidates if os.path.isfile(c)), None)
        if stem_path is None:
            print('Cant find file "{}" in folder {}'.format(instr, path))
            continue
        lengths_arr.append(len(sf.read(stem_path)[0]))

    if not lengths_arr:
        # min() of an empty array fails with an opaque NumPy message; name the
        # offending folder instead.
        raise ValueError(
            'No stems found in folder {} for instruments {} (file types: {})'.format(
                path, list(instruments), list(file_types)
            )
        )

    lengths_arr = np.array(lengths_arr)
    shortest, longest = lengths_arr.min(), lengths_arr.max()
    if shortest != longest:
        print('Warning: lengths of stems are different for path: {}. ({} != {})'.format(
            path,
            shortest,
            longest)
        )

    return path, shortest
|
|
|
|
| |
def get_track_length(params):
    """Return ``(path, frame_count)`` for the audio file whose path is ``params``.

    The file is read with ``soundfile`` and the length of the returned sample
    array is reported.
    """
    samples = sf.read(params)[0]
    return params, len(samples)
|
|
|
|
class MSSDataset(torch.utils.data.Dataset):
    """Dataset producing ``(stems, mixture)`` training pairs from audio stems.

    ``dataset_type`` selects how ``data_path`` is scanned:

    * 1 - folders of tracks, one file per instrument; stems are sampled
      independently (different tracks and offsets per instrument).
    * 2 - one sub-folder per instrument containing ``.wav``/``.flac`` files.
    * 3 - CSV file(s) with ``instrum`` and ``path`` columns.
    * 4 - same layout as 1, but all stems come from one track at one offset.

    Track lengths are cached in the pickle file ``metadata_path``.
    """

    # Sample rate handed to every augmentation call.
    _SAMPLE_RATE = 44100

    # audiomentations transforms, applied in this order. Each entry is
    # (probability key, transform class name,
    #  ((constructor argument, config key), ...), fixed constructor arguments).
    _AU_AUGS = (
        ('pitch_shift', 'PitchShift', (
            ('min_semitones', 'pitch_shift_min_semitones'),
            ('max_semitones', 'pitch_shift_max_semitones'),
        ), {}),
        ('seven_band_parametric_eq', 'SevenBandParametricEQ', (
            ('min_gain_db', 'seven_band_parametric_eq_min_gain_db'),
            ('max_gain_db', 'seven_band_parametric_eq_max_gain_db'),
        ), {}),
        ('tanh_distortion', 'TanhDistortion', (
            ('min_distortion', 'tanh_distortion_min'),
            ('max_distortion', 'tanh_distortion_max'),
        ), {}),
        ('mp3_compression', 'Mp3Compression', (
            ('min_bitrate', 'mp3_compression_min_bitrate'),
            ('max_bitrate', 'mp3_compression_max_bitrate'),
            ('backend', 'mp3_compression_backend'),
        ), {}),
        ('gaussian_noise', 'AddGaussianNoise', (
            ('min_amplitude', 'gaussian_noise_min_amplitude'),
            ('max_amplitude', 'gaussian_noise_max_amplitude'),
        ), {}),
        ('time_stretch', 'TimeStretch', (
            ('min_rate', 'time_stretch_min_rate'),
            ('max_rate', 'time_stretch_max_rate'),
        ), {'leave_length_unchanged': True}),
    )

    # pedalboard plugins, applied in this order. Each entry is
    # (probability key, plugin class name,
    #  ((constructor argument, config key stem), ...), fixed constructor arguments).
    # Every argument is drawn uniformly from [<stem>_min, <stem>_max].
    _PB_AUGS = (
        ('pedalboard_reverb', 'Reverb', (
            ('room_size', 'pedalboard_reverb_room_size'),
            ('damping', 'pedalboard_reverb_damping'),
            ('wet_level', 'pedalboard_reverb_wet_level'),
            ('dry_level', 'pedalboard_reverb_dry_level'),
            ('width', 'pedalboard_reverb_width'),
        ), {'freeze_mode': 0.0}),
        ('pedalboard_chorus', 'Chorus', (
            ('rate_hz', 'pedalboard_chorus_rate_hz'),
            ('depth', 'pedalboard_chorus_depth'),
            ('centre_delay_ms', 'pedalboard_chorus_centre_delay_ms'),
            ('feedback', 'pedalboard_chorus_feedback'),
            ('mix', 'pedalboard_chorus_mix'),
        ), {}),
        ('pedalboard_phazer', 'Phaser', (
            ('rate_hz', 'pedalboard_phazer_rate_hz'),
            ('depth', 'pedalboard_phazer_depth'),
            ('centre_frequency_hz', 'pedalboard_phazer_centre_frequency_hz'),
            ('feedback', 'pedalboard_phazer_feedback'),
            ('mix', 'pedalboard_phazer_mix'),
        ), {}),
        ('pedalboard_distortion', 'Distortion', (
            ('drive_db', 'pedalboard_distortion_drive_db'),
        ), {}),
        ('pedalboard_pitch_shift', 'PitchShift', (
            ('semitones', 'pedalboard_pitch_shift_semitones'),
        ), {}),
        ('pedalboard_resample', 'Resample', (
            ('target_sample_rate', 'pedalboard_resample_target_sample_rate'),
        ), {}),
        ('pedalboard_bitcrash', 'Bitcrush', (
            ('bit_depth', 'pedalboard_bitcrash_bit_depth'),
        ), {}),
        # The doubled name in the config key stem is what existing configs use.
        ('pedalboard_mp3_compressor', 'MP3Compressor', (
            ('vbr_quality', 'pedalboard_mp3_compressor_pedalboard_mp3_compressor'),
        ), {}),
    )

    def __init__(self, config, data_path, metadata_path="metadata.pkl", dataset_type=1, batch_size=None, verbose=True):
        """Collect track metadata and read the sampling settings from ``config``.

        Args:
            config: Configuration exposing ``training`` (``instruments``,
                ``batch_size``, ``num_steps``, ``target_instrument``), ``audio``
                (``chunk_size``, ``min_mean_abs``) and, optionally,
                ``augmentations``.
            data_path: Folder / CSV path, or a list of them (see class docstring).
            metadata_path: Pickle file used to cache track lengths.
            dataset_type: 1, 2, 3 or 4 (see class docstring).
            batch_size: Overrides ``config.training.batch_size`` when given.
            verbose: Print progress information.

        Raises:
            SystemExit: If no tracks are found (types 1 and 4) or the dataset
                type is unknown.
        """
        self.verbose = verbose
        self.config = config
        self.dataset_type = dataset_type
        self.data_path = data_path
        self.instruments = config.training.instruments
        if batch_size is None:
            batch_size = config.training.batch_size
        self.batch_size = batch_size
        self.file_types = ['wav', 'flac']
        self.metadata_path = metadata_path

        # Augmentations are used only when the block exists and is enabled.
        self.aug = False
        if 'augmentations' in config:
            if config['augmentations'].enable is True:
                if self.verbose:
                    print('Use augmentation for training')
                self.aug = True
        else:
            if self.verbose:
                print('There is no augmentations block in config. Augmentations disabled for training...')

        metadata = self.get_metadata()

        if self.dataset_type in [1, 4]:
            if len(metadata) > 0:
                if self.verbose:
                    print('Found tracks in dataset: {}'.format(len(metadata)))
            else:
                print('No tracks found for training. Check paths you provided!')
                # Non-zero status: this is an error exit.
                raise SystemExit(1)
        else:
            for instr in self.instruments:
                if self.verbose:
                    print('Found tracks for {} in dataset: {}'.format(instr, len(metadata[instr])))
        self.metadata = metadata
        self.chunk_size = config.audio.chunk_size
        self.min_mean_abs = config.audio.min_mean_abs

    def __len__(self):
        """Number of samples per epoch: ``num_steps * batch_size``."""
        return self.config.training.num_steps * self.batch_size

    def read_from_metadata_cache(self, track_paths, instr=None):
        """Split ``track_paths`` into cached entries and paths still to measure.

        Args:
            track_paths: Paths that belong to the dataset.
            instr: Instrument whose section of the cache is consulted; ``None``
                when the cache is a flat list (dataset types 1 and 4).

        Returns:
            Tuple ``(remaining_paths, metadata)``: the paths without a cache
            entry, in their original order, and the ``[path, length]`` entries
            taken from the cache.
        """
        metadata = []
        if not os.path.isfile(self.metadata_path):
            return track_paths, metadata

        if self.verbose:
            print('Found metadata cache file: {}'.format(self.metadata_path))
        # NOTE: pickle.load can execute code embedded in the file; metadata_path
        # must only ever point at a trusted file.
        try:
            with open(self.metadata_path, 'rb') as cache_file:
                old_metadata = pickle.load(cache_file)
        except (pickle.UnpicklingError, EOFError) as e:
            print('Cant read metadata cache file {}: {}'.format(self.metadata_path, e))
            return track_paths, metadata

        # A cache written for another dataset type or instrument list has a
        # different layout; treat it as empty instead of failing.
        if instr:
            old_metadata = old_metadata.get(instr, []) if isinstance(old_metadata, dict) else []
        elif isinstance(old_metadata, dict):
            old_metadata = []

        # Reuse cached lengths only for paths that are still part of the dataset.
        pending = set(track_paths)
        for old_path, file_size in old_metadata:
            if old_path in pending:
                metadata.append([old_path, file_size])
                pending.remove(old_path)
        # Keep the caller's (sorted) order so runs are reproducible.
        track_paths = [path for path in dict.fromkeys(track_paths) if path in pending]
        if len(metadata) > 0:
            print('Old metadata was used for {} tracks.'.format(len(metadata)))
        return track_paths, metadata

    @staticmethod
    def _collect_lengths(func, items, procs):
        """Apply ``func`` to every item with a progress bar, in a pool if ``procs > 1``."""
        items = list(items)
        if not items:
            return []
        if procs <= 1:
            return [func(item) for item in tqdm(items)]
        # The context manager shuts the worker processes down once all results
        # have been consumed.
        with multiprocessing.Pool(processes=procs) as pool:
            return list(tqdm(pool.imap(func, items), total=len(items)))

    def get_metadata(self):
        """Scan ``data_path``, measure uncached tracks and rewrite the cache.

        Returns:
            For dataset types 1 and 4 a list of ``(track_folder, length)``
            entries; for types 2 and 3 a dict mapping each instrument to a list
            of ``(file_path, length)`` entries.

        Raises:
            SystemExit: If ``dataset_type`` is not 1, 2, 3 or 4.
        """
        read_metadata_procs = multiprocessing.cpu_count()
        if 'read_metadata_procs' in self.config['training']:
            read_metadata_procs = int(self.config['training']['read_metadata_procs'])

        if self.verbose:
            print(
                'Dataset type:', self.dataset_type,
                'Processes to use:', read_metadata_procs,
                '\nCollecting metadata for', str(self.data_path),
            )

        if self.dataset_type in [1, 4]:
            track_paths = []
            if isinstance(self.data_path, list):
                for tp in self.data_path:
                    tracks_for_folder = sorted(glob(tp + '/*'))
                    if len(tracks_for_folder) == 0:
                        print('Warning: no tracks found in folder \'{}\'. Please check it!'.format(tp))
                    track_paths += tracks_for_folder
            else:
                track_paths += sorted(glob(self.data_path + '/*'))

            # Tracks are non-hidden directories.
            track_paths = [
                path for path in track_paths
                if not os.path.basename(path).startswith('.') and os.path.isdir(path)
            ]
            track_paths, metadata = self.read_from_metadata_cache(track_paths, None)
            metadata += self._collect_lengths(
                get_track_set_length,
                zip(track_paths, itertools.repeat(self.instruments), itertools.repeat(self.file_types)),
                read_metadata_procs,
            )

        elif self.dataset_type == 2:
            roots = self.data_path if isinstance(self.data_path, list) else [self.data_path]
            metadata = dict()
            for instr in self.instruments:
                track_paths = []
                for root in roots:
                    track_paths += sorted(glob(root + '/{}/*.wav'.format(instr)))
                    track_paths += sorted(glob(root + '/{}/*.flac'.format(instr)))

                track_paths, metadata[instr] = self.read_from_metadata_cache(track_paths, instr)
                metadata[instr] += self._collect_lengths(get_track_length, track_paths, read_metadata_procs)

        elif self.dataset_type == 3:
            import pandas as pd
            # A single CSV path is treated as a one-element list.
            csv_paths = self.data_path if isinstance(self.data_path, list) else [self.data_path]

            # Filled across all CSV files; resetting it per file would keep only
            # the tracks of the last one.
            metadata = {instr: [] for instr in self.instruments}
            for csv_path in csv_paths:
                if self.verbose:
                    print('Reading tracks from: {}'.format(csv_path))
                df = pd.read_csv(csv_path)

                skipped = 0
                for instr in self.instruments:
                    print('Tracks found for {}: {}'.format(instr, len(df[df['instrum'] == instr])))
                for instr in self.instruments:
                    part = df[df['instrum'] == instr]
                    track_paths = list(part['path'].values)
                    track_paths, cached = self.read_from_metadata_cache(track_paths, instr)
                    metadata[instr] += cached

                    for path in tqdm(track_paths):
                        if not os.path.isfile(path):
                            print('Cant find track: {}'.format(path))
                            skipped += 1
                            continue
                        try:
                            length = len(sf.read(path)[0])
                        except Exception:
                            # Unreadable file: report and skip it. KeyboardInterrupt
                            # and SystemExit still propagate.
                            print('Problem with path: {}'.format(path))
                            skipped += 1
                            continue
                        metadata[instr].append((path, length))
                if skipped > 0:
                    print('Missing tracks: {} from {}'.format(skipped, len(df)))
        else:
            print('Unknown dataset type: {}. Must be 1, 2, 3 or 4'.format(self.dataset_type))
            raise SystemExit(1)

        # Persist the lengths so the next run can skip measuring these files.
        with open(self.metadata_path, 'wb') as cache_file:
            pickle.dump(metadata, cache_file)
        return metadata

    def _safe_load_chunk(self, path, length, offset=None):
        """Load a chunk from ``path``; on a read error print it and return silence."""
        try:
            return load_chunk(path, length, self.chunk_size, offset=offset)
        except Exception as e:
            # Reading can fail for individual files; use a zero stem instead.
            print('Error: {} Path: {}'.format(e, path))
            return np.zeros((2, self.chunk_size), dtype=np.float32)

    def _read_stem(self, track_path, track_length, instr, offset=None):
        """Load a chunk of ``instr`` from a track folder, or ``None`` if it has no file."""
        for extension in self.file_types:
            path_to_audio_file = track_path + '/{}.{}'.format(instr, extension)
            if os.path.isfile(path_to_audio_file):
                return self._safe_load_chunk(path_to_audio_file, track_length, offset)
        return None

    def load_source(self, metadata, instr):
        """Sample one chunk of ``instr`` whose mean absolute value reaches ``min_mean_abs``.

        Args:
            metadata: Metadata as returned by ``get_metadata``.
            instr: Instrument to sample.

        Returns:
            float32 tensor with the (optionally augmented) chunk.

        Raises:
            FileNotFoundError: For dataset types 1 and 4, if ``instr`` has no
                file in any of the many track folders tried.
        """
        misses = 0
        # Generous bound: with at least one usable track it is practically never
        # reached, but a dataset without the stem fails instead of looping.
        max_misses = 100 * max(1, len(metadata))
        while True:
            if self.dataset_type in [1, 4]:
                track_path, track_length = random.choice(metadata)
                source = self._read_stem(track_path, track_length, instr)
                if source is None:
                    # This track has no such stem: try another track.
                    misses += 1
                    if misses >= max_misses:
                        raise FileNotFoundError(
                            'Cant find file "{}" in any sampled track folder'.format(instr)
                        )
                    continue
            else:
                track_path, track_length = random.choice(metadata[instr])
                source = self._safe_load_chunk(track_path, track_length)

            # Resample until the chunk is loud enough.
            if np.abs(source).mean() >= self.min_mean_abs:
                break
        if self.aug:
            source = self.augm_data(source, instr)
        return torch.tensor(source, dtype=torch.float32)

    def load_random_mix(self):
        """Sample every instrument independently and stack the chunks.

        With augmentations enabled and ``mixup`` set, extra chunks of the same
        instrument are drawn (one chance per entry of ``mixup_probs``), scaled
        by random loudness factors and averaged into a single stem.

        Returns:
            Tensor with one stem per instrument along the first dimension.
        """
        res = []
        for instr in self.instruments:
            s1 = self.load_source(self.metadata, instr)
            if self.aug:
                if 'mixup' in self.config['augmentations']:
                    if self.config['augmentations'].mixup:
                        mixup = [s1]
                        for prob in self.config.augmentations.mixup_probs:
                            if random.uniform(0, 1) < prob:
                                s2 = self.load_source(self.metadata, instr)
                                mixup.append(s2)
                        mixup = torch.stack(mixup, dim=0)
                        loud_values = np.random.uniform(
                            low=self.config.augmentations.loudness_min,
                            high=self.config.augmentations.loudness_max,
                            size=(len(mixup),)
                        )
                        loud_values = torch.tensor(loud_values, dtype=torch.float32)
                        mixup *= loud_values[:, None, None]
                        s1 = mixup.mean(dim=0, dtype=torch.float32)
            res.append(s1)
        res = torch.stack(res)
        return res

    def load_aligned_data(self):
        """Load all stems of one random track at a shared offset.

        Up to 10 offsets are tried until no loaded stem falls below
        ``min_mean_abs``. A stem without a file is replaced by silence and does
        not trigger a retry, since another offset cannot make it appear.

        Returns:
            float32 tensor with one stem per instrument along the first dimension.
        """
        track_path, track_length = random.choice(self.metadata)
        attempts = 10
        while attempts:
            if track_length >= self.chunk_size:
                common_offset = np.random.randint(track_length - self.chunk_size + 1)
            else:
                common_offset = None
            res = []
            silent_chunks = 0
            for instr in self.instruments:
                source = self._read_stem(track_path, track_length, instr, offset=common_offset)
                res.append(source)
                if source is not None and np.abs(source).mean() < self.min_mean_abs:
                    silent_chunks += 1
            if silent_chunks == 0:
                break

            attempts -= 1
            if attempts <= 0:
                print('Attempts max!', track_path)
            if common_offset is None:
                # The track is shorter than a chunk: every attempt reads the same data.
                break

        # Fill missing stems with silence shaped like a stem that was loaded.
        template = next((stem for stem in res if stem is not None), None)
        if template is None:
            silence = np.zeros((2, self.chunk_size), dtype=np.float32)
        else:
            silence = np.zeros_like(template)
        res = np.stack([silence if stem is None else stem for stem in res], axis=0)
        if self.aug:
            for i, instr in enumerate(self.instruments):
                res[i] = self.augm_data(res[i], instr)
        return torch.tensor(res, dtype=torch.float32)

    @staticmethod
    def _roll(augs, name):
        """Decide whether augmentation ``name`` fires for this call.

        A random number is drawn only when ``name`` is configured with a
        probability greater than zero.
        """
        return name in augs and augs[name] > 0 and random.uniform(0, 1) < augs[name]

    def augm_data(self, source, instr):
        """Apply the configured random augmentations to one stem.

        Settings come from ``config['augmentations']['all']`` overridden by
        ``config['augmentations'][instr]``. Every augmentation has a probability
        stored under its own name; its parameters use the keys listed in
        ``_AU_AUGS`` / ``_PB_AUGS``.

        Args:
            source: Stem array; the first axis is flipped by ``channel_shuffle``
                and the second by ``random_inverse``.
            instr: Instrument name used to look up overrides.

        Returns:
            The augmented array.
        """
        source_shape = source.shape
        aug_config = self.config['augmentations']

        # Merge into a fresh dict: writing the overrides into the 'all' section
        # itself would leak them into every other instrument.
        augs = dict()
        for section in ('all', instr):
            if section in aug_config:
                for el in aug_config[section]:
                    augs[el] = aug_config[section][el]

        if self._roll(augs, 'channel_shuffle'):
            source = source[::-1].copy()
        if self._roll(augs, 'random_inverse'):
            source = source[:, ::-1].copy()
        if self._roll(augs, 'random_polarity'):
            source = -source.copy()

        for name, transform_name, cfg_keys, fixed in self._AU_AUGS:
            if self._roll(augs, name):
                kwargs = {arg: augs[key] for arg, key in cfg_keys}
                apply_aug = getattr(AU, transform_name)(p=1.0, **kwargs, **fixed)
                source = apply_aug(samples=source, sample_rate=self._SAMPLE_RATE)

        # Cut the stem back if a transform above made it longer.
        if source_shape != source.shape:
            source = source[..., :source_shape[-1]]

        for name, plugin_name, params, fixed in self._PB_AUGS:
            if self._roll(augs, name):
                # Parameters are drawn in the listed order.
                kwargs = {
                    arg: random.uniform(augs[stem + '_min'], augs[stem + '_max'])
                    for arg, stem in params
                }
                board = PB.Pedalboard([getattr(PB, plugin_name)(**kwargs, **fixed)])
                source = board(source, self._SAMPLE_RATE)

        return source

    def __getitem__(self, index):
        """Build one training sample; ``index`` is ignored (sampling is random).

        Returns:
            Tuple ``(stems, mix)`` where ``mix`` is the sum of all stems. When
            ``config.training.target_instrument`` is set, ``stems`` contains
            only that instrument (first dimension of size 1).
        """
        if self.dataset_type in [1, 2, 3]:
            res = self.load_random_mix()
        else:
            res = self.load_aligned_data()

        # Random per-stem gain.
        if self.aug:
            if 'loudness' in self.config['augmentations']:
                if self.config['augmentations']['loudness']:
                    loud_values = np.random.uniform(
                        low=self.config['augmentations']['loudness_min'],
                        high=self.config['augmentations']['loudness_max'],
                        size=(len(res),)
                    )
                    loud_values = torch.tensor(loud_values, dtype=torch.float32)
                    res *= loud_values[:, None, None]

        mix = res.sum(0)

        # MP3 compression on the mixture only; the stems stay untouched.
        if self.aug:
            if 'mp3_compression_on_mixture' in self.config['augmentations']:
                apply_aug = AU.Mp3Compression(
                    min_bitrate=self.config['augmentations']['mp3_compression_on_mixture_bitrate_min'],
                    max_bitrate=self.config['augmentations']['mp3_compression_on_mixture_bitrate_max'],
                    backend=self.config['augmentations']['mp3_compression_on_mixture_backend'],
                    p=self.config['augmentations']['mp3_compression_on_mixture']
                )
                mix_conv = mix.cpu().numpy().astype(np.float32)
                required_shape = mix_conv.shape
                mix = apply_aug(samples=mix_conv, sample_rate=self._SAMPLE_RATE)
                # Cut the mixture back if the compression made it longer.
                if mix.shape != required_shape:
                    mix = mix[..., :required_shape[-1]]
                mix = torch.tensor(mix, dtype=torch.float32)

        if self.config.training.target_instrument is not None:
            target_idx = self.config.training.instruments.index(self.config.training.target_instrument)
            return res[target_idx:target_idx + 1], mix

        return res, mix
|
|