import os
from time import perf_counter

import datasets
import numpy as np
import pandas as pd
import torch
from torch.utils.data import (
    Dataset as TorchDataset,
    DistributedSampler,
    WeightedRandomSampler,
)

from data_util.audioset_classes import as_strong_train_classes
from data_util.transforms import (
    Mp3DecodeTransform,
    SequentialTransform,
    AddPseudoLabelsTransform,
    strong_label_transform,
    target_transform,
)
logger = datasets.logging.get_logger(__name__)

def init_hf_config(max_shard_size="2GB", verbose=True, in_mem_max=None):
    datasets.config.MAX_SHARD_SIZE = max_shard_size
    if verbose:
        datasets.logging.set_verbosity_info()
    if in_mem_max is not None:
        datasets.config.IN_MEMORY_MAX_SIZE = in_mem_max
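# Note (illustrative, values are assumptions): this mutates global `datasets`
# config, e.g. init_hf_config(max_shard_size="500MB") caps the shard size used
# by subsequent save_to_disk calls, and in_mem_max sets IN_MEMORY_MAX_SIZE.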

def get_hf_local_path(path, local_datasets_path=None):
    # resolve a dataset name to a local folder; falls back to a "local" dir next
    # to the HF cache (assumes HF_DATASETS_CACHE is set if HF_DATASETS_LOCAL is not)
    if local_datasets_path is None:
        local_datasets_path = os.environ.get(
            "HF_DATASETS_LOCAL",
            os.path.join(os.environ.get("HF_DATASETS_CACHE"), "../local"),
        )
    path = os.path.join(local_datasets_path, path)
    return path
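# Illustrative resolution (paths are hypothetical): with HF_DATASETS_LOCAL=/data/hf_local,
# get_hf_local_path("audioset_strong") returns "/data/hf_local/audioset_strong";
# without it, the path resolves relative to the HF cache, e.g.
# "$HF_DATASETS_CACHE/../local/audioset_strong".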

class catchtime:
    # context manager to measure loading time:
    # https://stackoverflow.com/questions/33987060/python-context-manager-that-measures-time
    def __init__(self, debug_print="Time", logger=logger):
        self.debug_print = debug_print
        self.logger = logger

    def __enter__(self):
        self.start = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.time = perf_counter() - self.start
        readout = f"{self.debug_print}: {self.time:.3f} seconds"
        self.logger.info(readout)
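# Usage sketch for catchtime (illustrative only):
#     with catchtime("Loading audioset_strong"):
#         ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))
# logs e.g. "Loading audioset_strong: 12.345 seconds" at INFO level on exit.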

def merge_overlapping_events(sample):
    # merge overlapping (or touching) annotations that share the same event label
    events = pd.DataFrame(sample["events"][0])
    events = events.sort_values(by="onset")
    sample["events"] = [None]
    for label in events["event_label"].unique():
        rows = []
        for _, r in events.loc[events["event_label"] == label].iterrows():
            if len(rows) == 0 or rows[-1]["offset"] < r["onset"]:
                rows.append(r)
            else:
                onset = min(rows[-1]["onset"], r["onset"])
                offset = max(rows[-1]["offset"], r["offset"])
                rows[-1]["onset"] = onset
                rows[-1]["offset"] = offset
        if sample["events"][0] is None:
            sample["events"][0] = pd.DataFrame(rows)
        else:
            sample["events"][0] = pd.concat([sample["events"][0], pd.DataFrame(rows)])
    return sample
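# Worked example (hypothetical events, for illustration): two "Speech" rows with
# (onset, offset) = (0.0, 2.0) and (1.5, 3.0) overlap, so they merge into a single
# (0.0, 3.0) row; a "Dog" row at (2.5, 4.0) carries a different label and stays
# separate even though it overlaps the merged "Speech" span.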

def get_training_dataset(
    label_encoder,
    audio_length=10.0,
    sample_rate=16000,
    wavmix_p=0.0,
    pseudo_labels_file=None,
):
    init_hf_config()
    decode_transform = Mp3DecodeTransform(
        sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
    )
    ds_list = []
    with catchtime("Loading audioset_strong"):
        as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))

    # label encode transformation
    if label_encoder is not None:
        # set list of label names to be encoded
        label_encoder.labels = as_strong_train_classes
        encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
    else:
        encode_label_fun = lambda x: x

    as_transforms = [
        decode_transform,
        merge_overlapping_events,
        encode_label_fun,
        target_transform,
    ]
    if pseudo_labels_file:
        as_transforms.append(
            AddPseudoLabelsTransform(pseudo_labels_file=pseudo_labels_file).add_pseudo_label_transform
        )
    as_ds.set_transform(SequentialTransform(as_transforms))
    ds_list.append(as_ds["balanced_train"])
    ds_list.append(as_ds["unbalanced_train"])
    dataset = torch.utils.data.ConcatDataset(ds_list)
    if wavmix_p > 0:
        print("Using Wavmix!")
        dataset = MixupDataset(dataset, rate=wavmix_p)
    return dataset
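# Sketch of how the returned dataset is consumed (mirrors the __main__ block below;
# DataLoader/collation wiring is assumed, not shown):
#     train_ds = get_training_dataset(encoder, wavmix_p=0.5)
#     sample = train_ds[0]  # set_transform runs the transform chain lazily on access
#     sample["audio"], sample["strong"]  # decoded waveform, encoded strong labels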

def get_eval_dataset(
    label_encoder,
    audio_length=10.0,
    sample_rate=16000,
):
    init_hf_config()
    ds_list = []
    decode_transform = Mp3DecodeTransform(
        sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
    )
    with catchtime("Loading audioset_strong"):
        as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))

    # label encode transformation
    if label_encoder is not None:
        label_encoder.labels = as_strong_train_classes
        encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
    else:
        encode_label_fun = lambda x: x

    as_transforms = [
        decode_transform,
        merge_overlapping_events,
        encode_label_fun,
        target_transform,
    ]
    as_ds.set_transform(SequentialTransform(as_transforms))
    ds_list.append(as_ds["eval"])
    dataset = torch.utils.data.ConcatDataset(ds_list)
    return dataset

def get_full_dataset(label_encoder, audio_length=10.0, sample_rate=16000):
    init_hf_config()
    decode_transform = Mp3DecodeTransform(
        sample_rate=sample_rate, max_length=audio_length, debug_info_key="filename"
    )
    with catchtime("Loading audioset_strong"):
        as_ds = datasets.load_from_disk(get_hf_local_path("audioset_strong"))

    # label encode transformation
    if label_encoder is not None:
        label_encoder.labels = as_strong_train_classes
        encode_label_fun = lambda x: strong_label_transform(x, strong_label_encoder=label_encoder)
    else:
        encode_label_fun = lambda x: x

    as_transforms = [
        decode_transform,
        merge_overlapping_events,
        encode_label_fun,
    ]
    as_ds.set_transform(SequentialTransform(as_transforms))
    ds_list = [
        as_ds["balanced_train"],
        as_ds["unbalanced_train"],
        as_ds["eval"],
    ]
    dataset = torch.utils.data.ConcatDataset(ds_list)
    return dataset

def get_uniform_sample_weights(dataset):
    """
    :return: float tensor of shape len(dataset) with a uniform weight per sample.
    """
    return torch.ones(len(dataset)).float()

def get_temporal_count_balanced_sample_weights(dataset, sample_weight_offset=30,
                                               save_folder="/share/rk8/shared/as_strong"):
    """
    :return: float tensor of shape len(dataset) representing the weight of each sample.
    """
    # the order of the concatenated splits (balanced_train, unbalanced_train) is
    # important: it should match the order used in get_training_dataset
    os.makedirs(save_folder, exist_ok=True)
    save_file = os.path.join(save_folder, f"weights_temporal_count_offset_{sample_weight_offset}.pt")
    if os.path.exists(save_file):
        return torch.load(save_file)
    from tqdm import tqdm

    all_y = []
    for sample in tqdm(dataset, desc="Calculating sample weights."):
        all_y.append(sample["event_count"])
    all_y = torch.from_numpy(np.stack(all_y, axis=0))
    per_class = all_y.long().sum(0).float().reshape(1, -1)  # frequencies per class
    per_class = sample_weight_offset + per_class  # offset low-frequency classes
    if sample_weight_offset > 0:
        print(f"Warning: sample_weight_offset={sample_weight_offset} min={per_class.min()}")
    per_class_weights = 1000. / per_class
    all_weight = all_y * per_class_weights
    all_weight = all_weight.sum(dim=1)
    torch.save(all_weight, save_file)
    return all_weight
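# In effect, each sample i is weighted by
#     w_i = sum_c event_count[i, c] * 1000 / (sample_weight_offset + N_c),
# where N_c is the total event count of class c over the whole dataset, so samples
# containing rare classes are drawn more often by the weighted sampler below.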

class MixupDataset(TorchDataset):
    """Mixes up the waveforms of random pairs of samples."""

    def __init__(self, dataset, beta=2, rate=0.5):
        self.beta = beta
        self.rate = rate
        self.dataset = dataset
        print(f"Mixing up waveforms from dataset of len {len(dataset)}")

    def __getitem__(self, index):
        if torch.rand(1) < self.rate:
            batch1 = self.dataset[index]
            idx2 = torch.randint(len(self.dataset), (1,)).item()
            batch2 = self.dataset[idx2]
            x1, x2 = batch1["audio"], batch2["audio"]
            y1, y2 = batch1["strong"], batch2["strong"]
            if "pseudo_strong" in batch1:
                p1, p2 = batch1["pseudo_strong"], batch2["pseudo_strong"]
            lam = np.random.beta(self.beta, self.beta)  # mixing coefficient
            lam = max(lam, 1. - lam)
            x1 = x1 - x1.mean()
            x2 = x2 - x2.mean()
            x = x1 * lam + x2 * (1. - lam)
            x = x - x.mean()
            batch1["audio"] = x
            batch1["strong"] = y1 * lam + y2 * (1. - lam)
            if "pseudo_strong" in batch1:
                batch1["pseudo_strong"] = p1 * lam + p2 * (1. - lam)
            return batch1
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)
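# Mixup recap: lam ~ Beta(beta, beta), folded into [0.5, 1], then
#     audio = lam * x1 + (1 - lam) * x2,   strong = lam * y1 + (1 - lam) * y2,
# so the returned sample keeps batch1's other fields but carries mean-centered
# mixed audio and correspondingly mixed targets.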

class DistributedSamplerWrapper(DistributedSampler):
    def __init__(
        self, sampler, dataset, num_replicas=None, rank=None, shuffle: bool = True
    ):
        super(DistributedSamplerWrapper, self).__init__(
            dataset, num_replicas, rank, shuffle
        )
        # source: @awaelchli https://github.com/PyTorchLightning/pytorch-lightning/issues/3238
        self.sampler = sampler

    def __iter__(self):
        if self.sampler.generator is None:
            self.sampler.generator = torch.Generator()
        self.sampler.generator.manual_seed(self.seed + self.epoch)
        indices = list(self.sampler)
        if self.epoch < 2:
            logger.info(
                f"\n DistributedSamplerWrapper (rank {self.rank}) : {indices[:3]} \n\n"
            )
        indices = indices[self.rank : self.total_size : self.num_replicas]
        return iter(indices)
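# Each rank re-seeds the wrapped sampler with seed + epoch, draws the same full
# index list, then keeps every num_replicas-th entry starting at its own rank,
# so the ranks see disjoint shards of one consistent weighted sample.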

def get_weighted_sampler(
    samples_weights,
    epoch_len=100_000,
    sampler_replace=False,
):
    num_nodes = int(os.environ.get("WORLD_SIZE", 1))
    ddp = int(os.environ.get("DDP", 1))
    num_nodes = max(ddp, num_nodes)
    rank = int(os.environ.get("NODE_RANK", 0))
    return DistributedSamplerWrapper(
        sampler=WeightedRandomSampler(
            samples_weights, num_samples=epoch_len, replacement=sampler_replace
        ),
        dataset=range(epoch_len),
        num_replicas=num_nodes,
        rank=rank,
    )
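# Usage sketch (DataLoader parameters are illustrative assumptions):
#     weights = get_temporal_count_balanced_sample_weights(train_ds)
#     sampler = get_weighted_sampler(weights, epoch_len=100_000)
#     loader = torch.utils.data.DataLoader(train_ds, batch_size=16, sampler=sampler)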

if __name__ == "__main__":
    from helpers.encode import ManyHotEncoder

    encoder = ManyHotEncoder([], 10., 160, net_pooling=4, fs=16_000)
    train_ds = get_training_dataset(encoder, audio_length=10.0, sample_rate=16_000)
    valid_ds = get_eval_dataset(encoder, audio_length=10.0, sample_rate=16_000)
    print("Len train dataset: ", len(train_ds))
    print("Len valid dataset: ", len(valid_ds))