Spaces:
Sleeping
Sleeping
| import math | |
| from copy import deepcopy | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import random | |
| import hashlib | |
| from torch.utils.data import Dataset, DataLoader, TensorDataset | |
| from torch.nn.utils.rnn import pad_sequence | |
| from typing import Union, List, Sized, Iterable | |
| from tqdm import tqdm | |
| from src.utils.filesystem import ensure_dir | |
| from src.utils.device import DataParallelWrapper | |
| from src.pipelines.pipeline import Pipeline | |
| from src.models.speaker.speaker import SpeakerVerificationModel | |
| from src.models.speech.speech import SpeechRecognitionModel | |
| from src.constants import * | |
| ################################################################################ | |
| # Data-loading utilities | |
| ################################################################################ | |
class DatasetWrapper(Dataset):
    """
    Most data utilities here involve re-assigning or computing targets to
    train or evaluate adversarial attacks. This class wraps an existing
    dataset to overwrite its stored inputs and targets as necessary.

    Items of the wrapped dataset are either tuples `(x, y, *features)` or
    mappings with `x` and `y` entries. The item format is detected once, from
    the first item yielded by the wrapped dataset; anything that is not a
    tuple is treated as a mapping.

    Parameters:
    -----------
    dataset: wrapped dataset; supplies the length and any additional
             per-item features
    inputs:  indexable container whose entries replace `x`
    targets: indexable container whose entries replace `y`
    """

    def __init__(self, dataset, inputs, targets):
        super().__init__()
        self.dataset = dataset
        self.inputs = inputs
        self.targets = targets

        # peek at a single item to determine the item format
        ref_batch = next(iter(dataset))
        if isinstance(ref_batch, tuple):
            self.format = 'tuple'
        else:
            self.format = 'dict'

    def __len__(self):
        # length is taken from the wrapped dataset, not from `inputs`
        return len(self.dataset)

    def __getitem__(self, idx):
        if self.format == 'tuple':
            # discard the stored input and target, keep remaining features
            _, _, *features = self.dataset[idx]
            return (self.inputs[idx], self.targets[idx], *features)

        from copy import copy as shallow_copy

        # work on a shallow copy so that a wrapped dataset which hands out
        # its stored mapping objects is not overwritten in place
        batch = shallow_copy(self.dataset[idx])
        batch['x'] = self.inputs[idx]
        batch['y'] = self.targets[idx]
        return batch
def pad_batch_power_2(batch):
    """
    Given a batch of tensors, zero-pad the final dimension of every input to
    the smallest power of 2 that is at least the maximum length in the batch.
    Used as a `collate_fn` argument to Pytorch `DataLoader` objects.

    Parameters:
    -----------
    batch: iterable of `(x, y)` pairs, where `x` is a tensor and `y` is
           either a string or a tensor

    Returns:
    --------
    inputs:  for two or more items, a zero-padded tensor of shape
             `(n_batch, *x.shape[1:-1], next_pow_2)`; an empty tensor for an
             empty batch
    targets: stacked tensor of targets, or the tuple of targets unchanged if
             they are strings; `None` for an empty batch
    """

    batch = list(batch)

    # an empty batch cannot be unzipped into inputs and targets, so it has to
    # be handled before unpacking
    if len(batch) < 1:
        return torch.Tensor([]), None

    # get tensors
    (x, y) = zip(*batch)
    n_batch = len(x)

    if n_batch == 1:
        # NOTE(review): a single-item batch is returned un-padded, as a
        # one-element tuple of inputs and a one-element tuple of targets;
        # kept as-is because callers may rely on it -- confirm whether
        # padding/stacking is wanted here as well
        return x[0:1], y

    if not isinstance(y[0], str):
        y = torch.stack(y, dim=0)

    # compute maximum length
    dtype, device = x[0].dtype, x[0].device
    lengths = [x_i.shape[-1] for x_i in x]
    max_len = max(lengths)
    next_pow_2 = 2**(max_len - 1).bit_length()

    # pad inputs; only the dimensions between the first and last of each
    # input are carried over to the padded batch
    shape = x[0].shape[1:-1]
    batch_padded = torch.zeros(
        (n_batch, *shape, next_pow_2),
        dtype=dtype,
        device=device
    )
    for i, (x_i, length) in enumerate(zip(x, lengths)):
        batch_padded[i, ..., :length] = x_i

    return batch_padded, y
def text_to_tensor(
        text: Union[str, List[str]],
        labels: list,
        return_lengths: bool = True,
        max_length: int = None,
        padding_value: int = -1):
    """
    Convert one or more string transcripts to padded tensor form (character
    indices), and optionally return sequence lengths. Matching between
    transcript characters and labels is case-insensitive (both are
    upper-cased).

    Parameters:
    -----------
    text (str):            a string or list of string transcripts
    labels (list):         list of characters, ordered by index
    return_lengths (bool): if True, return sequence lengths
    max_length (int):      if given, trim/pad all sequences to length
    padding_value (int):   value with which to perform length padding

    Returns:
    --------
    sequences (Tensor): tensor containing padded index sequences, of shape
                        (n_batch, max_seq_len)
    lengths (Tensor):   tensor containing sequence lengths; only returned if
                        `return_lengths` is True. Lengths never exceed the
                        width of `sequences`

    Raises:
    -------
    KeyError: if a transcript contains a character not present in `labels`
    """

    if isinstance(text, str):
        text = [text]

    # convert from characters to token indices
    char_to_idx = {label.upper(): i for i, label in enumerate(labels)}

    lengths = []
    tensors = []
    for t in text:
        token_indices = [char_to_idx[c] for c in t.upper()]

        # measure the token sequence rather than the raw string: upper-casing
        # can change the number of characters
        lengths.append(len(token_indices))
        tensors.append(
            torch.as_tensor(token_indices, dtype=torch.long)
        )

    # pad and return
    tensors = pad_sequence(
        tensors,
        batch_first=True,
        padding_value=padding_value
    )  # (n_batch, max_seq_len)
    lengths = torch.as_tensor(lengths, dtype=torch.long)

    if max_length is not None:
        if tensors.shape[-1] > max_length:
            tensors = tensors[..., :max_length]

            # keep reported lengths consistent with the trimmed sequences
            lengths = torch.clamp(lengths, max=max_length)
        elif tensors.shape[-1] < max_length:
            tensors = F.pad(
                tensors,
                (0, max_length - tensors.shape[-1]),
                value=padding_value)

    if return_lengths:
        return tensors, lengths
    return tensors
def padded_transcript_length(
        transcript: torch.Tensor,
        padding_value: int = -1):
    """
    Recover un-padded lengths of index-sequence transcripts by locating the
    first padding entry along the final dimension.

    Parameters:
    -----------
    transcript (Tensor): tensor containing one or more index sequences
    padding_value (int): value used to pad sequence tensors

    Returns:
    --------
    lengths (Tensor): tensor containing un-padded length of each sequence
    """

    # flag padded positions, then find where the first flag sits per sequence
    is_padding = transcript.eq(padding_value)
    has_padding, first_padding = is_padding.max(dim=-1)

    # sequences without any padding span the full final dimension
    full_length = torch.full_like(first_padding, transcript.shape[-1])
    lengths = torch.where(has_padding, first_padding, full_length)

    return lengths.long()
def move_to_device_recursive(d: dict, device: Union[str, torch.device]):
    """
    Move all tensors in a dictionary object to given device.

    The dictionary is updated in place and also returned. Values are handled
    as follows:
      * nested dictionaries are processed recursively
      * tensors are moved to `device`
      * tuples and lists holding at least one tensor are rebuilt (as a plain
        tuple or list) with every tensor element moved; non-tensor elements
        are carried over unchanged
      * all other values are left untouched

    Parameters:
    -----------
    d (dict):             dictionary to process
    device (str, device): destination device

    Returns:
    --------
    d (dict): the same dictionary object, with tensors moved
    """

    def _move(item):
        # only tensors are moved; anything else passes through unchanged
        if isinstance(item, torch.Tensor):
            return item.to(device)
        return item

    for k, v in d.items():
        if isinstance(v, dict):
            d[k] = move_to_device_recursive(v, device)
        elif isinstance(v, torch.Tensor):
            d[k] = v.to(device)
        elif isinstance(v, (tuple, list)) and any(
                isinstance(v_i, torch.Tensor) for v_i in v):
            # check each element individually: sequences may mix tensors
            # with other values (e.g. strings), which have no `.to()`
            moved = [_move(v_i) for v_i in v]
            d[k] = tuple(moved) if isinstance(v, tuple) else moved

    return d
def dataset_to_device(data: Dataset, device: Union[str, torch.device]):
    """
    Relocate every tensor held in a dataset's attribute dictionary to the
    given device. May cause memory issues!
    """
    attributes = data.__dict__
    data.__dict__ = move_to_device_recursive(attributes, device)
def create_embedding_dataset(data_train: Dataset,
                             data_test: Dataset,
                             pipeline: Pipeline,
                             select_train: str = 'random',
                             select_test: str = 'random',
                             targeted: bool = True,
                             target_class: int = None,
                             num_per_class_train: int = None,
                             num_per_class_test: int = None,
                             num_embeddings_train: int = 5,
                             exclude_class: Union[int, List] = None,
                             exclude_success: bool = False,
                             use_cache: bool = False,
                             shuffle: bool = False,
                             **kwargs
                             ):
    """
    Given training and test datasets holding audio and labels, compute
    embeddings according to (potentially reassigned) labels. It is assumed
    that the provided train and test targets both index the same set of classes,
    i.e. that label 0 in the train set refers to the same class as label 0 in
    the test set. This function considers three possible cases:

    1. `targeted` == True, `target_class` != None
       In this case, a single target class from the training set is assigned
       to all training and test instances. Embeddings corresponding to the
       target class are computed and assigned based on the `select` parameter.

    2. `targeted` == True, `target_class` == None
       In this case, targets are randomly reassigned within both the training
       and test sets. Embeddings corresponding to the targets are computed and
       assigned based on the `select` parameter.

    3. `targeted` == False
       In this case, targets remain unchanged and the embedding of each
       instance is computed directly.

    Parameters
    ----------
    data_train (Dataset):       a Dataset object holding audio and labels
    data_test (Dataset):        a Dataset object holding audio and labels
    pipeline (Pipeline):        a Pipeline object; must wrap a
                                SpeakerVerificationModel object
    select_train (str):         method of selecting target embeddings for train
                                set; must be one of `random`, `single`, `same`,
                                `centroid`, or `median`
    select_test (str):          method of selecting target embeddings for test
                                set; must be one of `random`, `single`, `same`,
                                `centroid`, or `median`
    targeted (bool):            if False, target classes will not be reassigned
    target_class (int):         if given, reassign all targets to this class
    num_per_class_train (int):  if given, perform stratified sampling with this
                                number of instances drawn per class
    num_per_class_test (int):   if given, perform stratified sampling with this
                                number of instances drawn per class
    num_embeddings_train (int): if given and attack is targeted, train on only
                                this many distinct embeddings of the target
                                speaker, and evaluate on other embeddings of the
                                target speaker
    exclude_class (int, list):  if given, exclude all instances from this class
                                or in this list of classes
    exclude_success (bool):     if True, drop or replace instances for which the
                                initial prediction achieves the desired
                                adversarial outcome (i.e. matches the target in
                                the case of a targeted attack, and evades the
                                target in the case of an untargeted attack).
                                Not yet implemented
    use_cache (bool):           if True, save to and load from disk using hash-
                                based lookup
    shuffle (bool):             if True, shuffle data and targets; may result in
                                mismatch if dataset contains other features (e.g.
                                pitch, periodicity)
    **kwargs:                   optional `batch_size` (default 1) used when
                                computing embeddings

    Returns
    -------
    train (dict): dictionary containing Dataset with audio and embedding
                  targets, example audio of each target class, original targets
                  indices, and reassigned target indices
    test (dict):  dictionary containing Dataset with audio and embedding
                  targets, example audio of each target class, original targets
                  indices, and reassigned target indices

    Raises
    ------
    NotImplementedError: if `exclude_success` is True
    AssertionError:      if arguments fail validation
    ValueError:          if datasets do not yield tuples or dictionaries
    RuntimeError:        if the model does not produce embeddings of shape
                         (1, n_segments, embedding_dim) for a single input
    """

    # imported locally so that this function does not depend on names
    # re-exported through star imports elsewhere in the module
    import warnings
    from pathlib import Path

    if not num_per_class_train:
        num_per_class_train = 0
    if not num_per_class_test:
        num_per_class_test = 0

    if targeted and target_class is not None and num_per_class_train:
        assert num_per_class_train > num_embeddings_train, \
            f'For targets drawn from training set, number of embeddings ' \
            f'reserved for training ({num_embeddings_train}) must be less ' \
            f'than number of embeddings computed per class ' \
            f'({num_per_class_train})'

    if exclude_success:
        raise NotImplementedError(f'Target correction not yet implemented; '
                                  f'use `NullAttack` to measure trivial '
                                  f'success rates for now')

    # ensure pipeline is capable of producing embeddings
    assert isinstance(pipeline.model, SpeakerVerificationModel)

    # match devices
    ref_batch_train = next(iter(data_train))
    ref_batch_test = next(iter(data_test))

    if isinstance(ref_batch_train, tuple):
        example_input, *_ = ref_batch_train
    elif isinstance(ref_batch_train, dict):
        example_input = ref_batch_train['x']
    else:
        raise ValueError(f'Dataset must provide batches in tuple or dictionary'
                         f' format')
    orig_device = example_input.device

    # check that model produces embeddings with valid shape
    try:
        embedding_shape = list(
            pipeline.model(example_input.to(pipeline.device)).shape
        )
        assert len(embedding_shape) == 3 and embedding_shape[0] == 1
    except AssertionError:
        raise RuntimeError(f'Speaker verification model must produce '
                           f'embeddings of shape '
                           f'(n_batch, n_segments, embedding_dim)')

    assert isinstance(data_train, Sized) and isinstance(data_test, Sized), \
        f"Datasets must have length attribute accessible via `len()`"

    # check embedding selection method
    assert select_train in ['random', 'single', 'centroid', 'same', 'median'], \
        f"invalid value for `select_train` {select_train}"
    assert select_test in ['random', 'single', 'centroid', 'same', 'median'], \
        f"invalid value for `select_test` {select_test}"
    assert not targeted or select_train != 'same', \
        f'`same` embedding selection only valid for untargeted mode'

    # check for optional `batch_size` argument; otherwise, use batch size of 1
    batch_size = kwargs.get('batch_size', 1)

    # creating embedding datasets is time-consuming; to avoid repeated
    # computation, we can store the generated dataset under a hash
    hash_str = str(pipeline.model)
    hash_str += str(data_train.__class__.__name__)
    hash_str += str(data_test.__class__.__name__)
    hash_str += select_train + select_test
    hash_str += str(targeted) + str(target_class)
    hash_str += str(num_per_class_train) + str(num_per_class_test)
    hash_str += str(exclude_class) + str(exclude_success)

    # these settings also change the generated dataset, so they must be part
    # of the cache key to avoid returning a dataset built for another setup
    hash_str += str(num_embeddings_train) + str(shuffle)

    # obtain hash and convert to filename; the hexadecimal digest contains
    # only filename-safe characters (the raw digest bytes may render as path
    # separators or other characters that are invalid in filenames)
    dataset_file = hashlib.md5(hash_str.encode()).hexdigest() + ".pt"

    # check whether a cached embedding dataset with matching hash exists
    embeddings_cache_dir = Path(CACHE_DIR) / 'embeddings'
    ensure_dir(embeddings_cache_dir)
    cached_datasets = embeddings_cache_dir.glob('*.pt')

    # if dataset is already cached, load and return
    if use_cache and dataset_file in [d.name for d in cached_datasets]:

        # NOTE: `torch.load` unpickles arbitrary objects; only load cache
        # files from a trusted cache directory
        dataset = torch.load(embeddings_cache_dir / dataset_file)

        # check for valid dataset structure
        if isinstance(dataset, dict) \
                and 'train' in dataset and 'test' in dataset:
            return move_to_device_recursive(
                dataset['train'], orig_device
            ), move_to_device_recursive(dataset['test'], orig_device)

        # warn and fall through, so that the dataset is actually re-computed
        # and the invalid cache entry overwritten below
        warnings.warn(f'Invalid dataset structure; will re-compute '
                      f'and overwrite existing dataset '
                      f'{dataset_file}', RuntimeWarning)

    # shuffle data
    rand_idx_train = torch.randperm(
        len(data_train)) if shuffle else torch.arange(len(data_train))
    rand_idx_test = torch.randperm(
        len(data_test)) if shuffle else torch.arange(len(data_test))

    # separate data and labels
    if isinstance(ref_batch_train, tuple):
        inputs_train, labels_train, *_ = data_train[:]
    else:
        inputs_train, labels_train = data_train[:]['x'], data_train[:]['y']
    inputs_train_shuffled = inputs_train[rand_idx_train]
    labels_train_shuffled = labels_train[rand_idx_train]

    if isinstance(ref_batch_test, tuple):
        inputs_test, labels_test, *_ = data_test[:]
    else:
        inputs_test, labels_test = data_test[:]['x'], data_test[:]['y']
    inputs_test_shuffled = inputs_test[rand_idx_test]
    labels_test_shuffled = labels_test[rand_idx_test]

    # if target is given, check that it is present in training data
    if target_class is not None:
        assert target_class in labels_train, \
            f'Target class {target_class} is not present in training data'

    # determine train and test labels
    unique_labels_train = [l.item() for l in torch.unique(labels_train)]
    unique_labels_test = [l.item() for l in torch.unique(labels_test)]

    # filter excluded classes (if given) from train and test sets
    if isinstance(exclude_class, list):
        unique_labels_train = [
            l for l in unique_labels_train if l not in exclude_class]
        unique_labels_test = [
            l for l in unique_labels_test if l not in exclude_class]
    elif exclude_class is not None:
        unique_labels_train = [
            l for l in unique_labels_train if not l == exclude_class]
        unique_labels_test = [
            l for l in unique_labels_test if not l == exclude_class]

    # prepare to store one example audio input per label (speaker)
    audio_train = {}
    audio_test = {}

    # prepare to store training and test embeddings by label
    embeddings_train = {}
    embeddings_test = {}

    def compute_embeddings_by_label(
            unique_labels: list,
            inputs: torch.Tensor,
            labels: torch.Tensor,
            saved_audio: dict,
            saved_embeddings: dict,
            num_per_class):
        """
        Compute an embedding for every instance in the given dataset and store
        by label in a dictionary; store one audio example per label in a
        dictionary.
        """

        # compute embeddings over dataset and sort by label
        for label in tqdm(
                unique_labels,
                total=len(unique_labels),
                desc="Computing embeddings for dataset"):

            # select instances of class, allowing for a limit on the number
            # of embeddings stored per class; never request more instances
            # than the class holds, so the model is not fed empty batches
            x_l = inputs[labels == label]
            n_l = min(num_per_class, len(x_l)) if num_per_class else len(x_l)

            # store one audio example per label
            saved_audio[label] = x_l[0:1]

            # store embeddings per label
            n_batches = math.ceil(n_l / batch_size)
            saved_embeddings[label] = []
            for i in range(n_batches):
                saved_embeddings[label].append(
                    pipeline.model(
                        x_l[i*batch_size:(i+1)*batch_size].to(pipeline.device)
                    ).to('cpu')  # store intermediate results on CPU
                )
            saved_embeddings[label] = torch.cat(
                saved_embeddings[label], dim=0)[:n_l]

    # compute embeddings over training and test datasets and store by label
    compute_embeddings_by_label(
        unique_labels_train,
        inputs_train_shuffled,
        labels_train_shuffled,
        audio_train,
        embeddings_train,
        num_per_class_train
    )
    compute_embeddings_by_label(
        unique_labels_test,
        inputs_test_shuffled,
        labels_test_shuffled,
        audio_test,
        embeddings_test,
        num_per_class_test
    )

    # filter datasets to remove excluded and target labels
    if targeted and target_class is not None:
        unique_labels_train = [
            l for l in unique_labels_train if not l == target_class]
        unique_labels_test = [
            l for l in unique_labels_test if not l == target_class]

    def reassign_labels(
            unique_labels: list,
            inputs: torch.Tensor,
            labels_orig: torch.Tensor,
            num_per_class):
        """
        Reassign targets, as detailed in documentation above.
        """

        # use a placeholder to allow for deletion of rows; overwrite with
        # valid labels and delete rows where -1 remains
        labels_new = torch.full(labels_orig.shape, -1, dtype=labels_orig.dtype)

        # reassign label-by-label
        for label in tqdm(
                unique_labels,
                total=len(unique_labels),
                desc="Reassigning labels for dataset"):

            # select all instances with label
            idx_l = labels_orig == label
            x_l = inputs[idx_l]

            # store original targets
            y_orig_l = torch.full((len(x_l), ), label)

            # per-class placeholder; rows left at -1 are dropped below
            y_new_l = torch.full((len(x_l), ), -1)

            # limit number of instances per class if specified, without
            # exceeding the number of instances available
            n_l = min(num_per_class, len(x_l)) if num_per_class else len(x_l)

            # targeted attacks require that the given targets be reassigned
            if targeted:

                # if target class is provided, reassign targets to given class
                if target_class is not None:
                    y_new_l[:n_l] = target_class

                # if no target class is given, randomly reassign targets such
                # that no target is unchanged
                else:
                    remaining_labels = [
                        l for l in unique_labels if l != label]
                    for j in range(n_l):
                        y_new_l[j] = random.choice(remaining_labels)

            # otherwise, classes remain unchanged
            else:
                y_new_l[:n_l] = y_orig_l[:n_l]

            labels_new[idx_l] = y_new_l

        # update data and labels, deleting rows corresponding to
        # extraneous inputs (according to `num_per_class`)
        keep_idx = labels_new != -1
        inputs = inputs[keep_idx]
        labels_orig = labels_orig[keep_idx]
        labels_new = labels_new[keep_idx]

        return inputs, labels_orig, labels_new, keep_idx

    # reassign training and test labels if necessary (see documentation
    # above); remove instances of target class and those not required by
    # `num_per_class`, if given
    (
        inputs_train_shuffled,
        labels_train_shuffled,
        labels_train_reassigned,
        select_idx_train
    ) = reassign_labels(
        unique_labels_train,
        inputs_train_shuffled,
        labels_train_shuffled,
        num_per_class_train
    )
    (
        inputs_test_shuffled,
        labels_test_shuffled,
        labels_test_reassigned,
        select_idx_test
    ) = reassign_labels(
        unique_labels_test,
        inputs_test_shuffled,
        labels_test_shuffled,
        num_per_class_test
    )

    # prepare to store target embeddings corresponding to reassigned labels
    embedding_targets_train = torch.empty(
        (len(labels_train_reassigned), *embedding_shape[1:]))
    embedding_targets_test = torch.empty(
        (len(labels_test_reassigned), *embedding_shape[1:]))

    def assign_embeddings(
            labels_new: torch.Tensor,
            embeddings_by_label: dict,
            embedding_targets: torch.Tensor,
            select: str,
            is_train: bool = True
    ):
        """
        Fill `embedding_targets` in place: every row whose reassigned label
        is `l` receives an embedding of class `l`, chosen according to
        `select`.
        """

        _, n_segments, embedding_dim = embedding_shape

        # iterate over dataset and associate embedding targets with
        # reassigned labels
        labels_to_assign = [l.item() for l in torch.unique(labels_new)]

        for label in labels_to_assign:

            # find indices for which embeddings of given label are to
            # be assigned
            idx_l = labels_new == label
            n_l = int(torch.sum(idx_l * 1).item())

            if n_l == 0:
                continue

            # obtain all embeddings corresponding to given label
            embeddings_l = embeddings_by_label[label]

            # if untargeted, assign ground-truth embeddings for each instance
            if select == 'same':
                y_emb_l = embeddings_l

            else:

                # separate train and test embeddings of given speaker
                if num_embeddings_train:

                    # for targeted attacks, allow training/testing on separate
                    # small subsets of a speaker's utterances
                    if targeted:
                        assert num_embeddings_train <= len(embeddings_l), \
                            f"`num_embeddings_train` {num_embeddings_train} " \
                            f"is greater than the number of utterances for " \
                            f"speaker {label}"

                        if is_train:
                            embeddings_l = embeddings_l[:num_embeddings_train]
                        else:
                            embeddings_l = embeddings_l[num_embeddings_train:]

                # using `select` parameter, assign embeddings
                if select == 'single':  # use single embedding
                    y_emb_l = embeddings_l[0:1].repeat(n_l, 1, 1)

                elif select == 'random':  # use random embeddings
                    y_emb_l = []
                    for _ in range(n_l):
                        emb_idx = random.randint(0, len(embeddings_l) - 1)
                        y_emb_l.append(embeddings_l[emb_idx:emb_idx+1])
                    y_emb_l = torch.cat(y_emb_l, dim=0)

                elif select == 'centroid':  # average over embeddings
                    # identical for every instance of the label, so compute
                    # once and duplicate over all segments and instances
                    centroid = embeddings_l.mean(dim=(0, 1)).reshape(
                        (1, 1, embedding_dim)
                    )
                    y_emb_l = centroid.repeat(n_l, n_segments, 1)

                elif select == 'median':  # median over embeddings
                    # flatten over utterances and segments; the number of
                    # available embeddings is independent of the number of
                    # instances `n_l` receiving a target
                    median = embeddings_l.reshape(
                        -1, embedding_dim
                    ).median(dim=0)[0].reshape(
                        (1, 1, embedding_dim)
                    )
                    y_emb_l = median.repeat(n_l, n_segments, 1)

                else:
                    raise ValueError(f'Invalid embedding selection method '
                                     f'{select}')

            embedding_targets[idx_l] = y_emb_l

    # with labels finalized, assign embedding targets
    assign_embeddings(
        labels_train_reassigned,
        embeddings_train,
        embedding_targets_train,
        select_train,
        True
    )

    # when a single target class is given, test targets are drawn from the
    # target's training-set embeddings
    assign_embeddings(
        labels_test_reassigned,
        embeddings_train if targeted and target_class is not None
        else embeddings_test,
        embedding_targets_test,
        select_test,
        False
    )

    # account for shuffling
    final_idx_train = rand_idx_train[select_idx_train]
    final_idx_test = rand_idx_test[select_idx_test]

    from src.data.dataset import VoiceBoxDataset

    if isinstance(data_train, VoiceBoxDataset):
        data_train_final = data_train.overwrite_dataset(
            inputs_train_shuffled,
            embedding_targets_train,
            final_idx_train
        )
    else:
        data_train_final = DatasetWrapper(
            data_train,
            inputs_train_shuffled,
            embedding_targets_train)

    if isinstance(data_test, VoiceBoxDataset):
        data_test_final = data_test.overwrite_dataset(
            inputs_test_shuffled,
            embedding_targets_test,
            final_idx_test
        )
    else:
        data_test_final = DatasetWrapper(
            data_test,
            inputs_test_shuffled,
            embedding_targets_test)

    # store data and embeddings, audio examples, original targets, and
    # reassigned targets
    train = {
        'dataset': data_train_final,
        'id_to_audio': audio_train,
        'true_id': labels_train_shuffled,
        'target_id': labels_train_reassigned
    }
    test = {
        'dataset': data_test_final,
        'id_to_audio': audio_test,
        'true_id': labels_test_shuffled,
        'target_id': labels_test_reassigned
    }

    if use_cache:
        dataset = {
            'train': train,
            'test': test
        }
        torch.save(dataset, embeddings_cache_dir / dataset_file)

    # restore device and return
    return move_to_device_recursive(
        train, orig_device
    ), move_to_device_recursive(
        test, orig_device
    )
def create_transcription_dataset(data_train: Dataset,
                                 data_test: Dataset,
                                 pipeline: Pipeline,
                                 targeted: bool = True,
                                 target_transcription: str = None,
                                 output_format: str = 'transcript',
                                 shuffle: bool = False,
                                 **kwargs):
    """
    Build transcription targets for speech recognition attacks from training
    and test datasets holding audio. Three cases are distinguished:

    1. `targeted` == True, `target_transcription` != None
       Every instance receives the same given transcription as its target.

    2. `targeted` == True, `target_transcription` == None
       Transcriptions are computed with the pipeline's model and then
       permuted within the training set and within the test set to serve as
       targets.

    3. `targeted` == False
       Transcriptions computed with the pipeline's model are used as targets
       directly.

    Parameters
    ----------
    data_train (Dataset):       a Dataset object holding audio
    data_test (Dataset):        a Dataset object holding audio
    pipeline (Pipeline):        a Pipeline object; must wrap a
                                SpeechRecognitionModel object
    targeted (bool):            if False, targets will not be reassigned
    target_transcription (str): if given, reassign all targets to the given
                                transcription string
    output_format (str):        `transcript` for index-sequence targets or
                                `emission` for model outputs as targets
    shuffle (bool):             if True, shuffle data and targets; may result in
                                mismatch if dataset contains other features (e.g.
                                pitch, periodicity)
    **kwargs:                   optional `batch_size` (default 1) used when
                                computing transcriptions

    Returns
    -------
    train (dict): dictionary holding the final dataset (`dataset`), the
                  targets (`targets`) and their lengths (`target_lengths`)
    test (dict):  dictionary with the same entries for the test data
    """

    # validate output format (string transcripts or frame-wise emissions)
    assert output_format in ['transcript', 'emission'], \
        f'Invalid output format; must be one of `transcript` or `emission`'

    # validate model type and dataset types
    assert isinstance(pipeline.model, SpeechRecognitionModel)
    assert isinstance(data_train, Sized) and isinstance(data_test, Sized), \
        f"Datasets must have length attribute accessible via `len()`"

    # inspect one item per dataset to learn the item format
    first_train = next(iter(data_train))
    first_test = next(iter(data_test))
    train_is_tuple = isinstance(first_train, tuple)
    test_is_tuple = isinstance(first_test, tuple)

    # remember the device on which the data lives, to restore it at the end
    if train_is_tuple:
        reference_audio, *_ = first_train
    elif isinstance(first_train, dict):
        reference_audio = first_train['x']
    else:
        raise ValueError(f'Dataset must provide batches in tuple or dictionary'
                         f' format')
    orig_device = reference_audio.device

    # optional `batch_size` argument; defaults to a batch size of 1
    batch_size = kwargs.get('batch_size', 1)

    # index orderings: random permutations if shuffling, identity otherwise
    if shuffle:
        order_train = torch.randperm(len(data_train))
        order_test = torch.randperm(len(data_test))
    else:
        order_train = torch.arange(len(data_train))
        order_test = torch.arange(len(data_test))

    def all_inputs(data, tuple_items: bool):
        """Fetch the inputs of an entire dataset in one slice."""
        if tuple_items:
            audio, *_ = data[:]
            return audio
        return data[:]['x']

    inputs_train_shuffled = all_inputs(data_train, train_is_tuple)[order_train]
    inputs_test_shuffled = all_inputs(data_test, test_is_tuple)[order_test]

    if targeted and target_transcription is not None:

        # a fixed transcription was supplied: assign it to every instance
        assert output_format == 'transcript', \
            f"Target transcript provided; cannot use emission targets"

        # the target may only use characters known to the model, with the
        # 'blank' character removed where it can be
        valid_characters = deepcopy(pipeline.model.get_labels())
        try:
            del valid_characters[pipeline.model.get_blank_idx()]
        except (IndexError, TypeError):
            pass

        assert all(c in valid_characters for c in target_transcription), \
            f'Target transcription contains invalid characters'

        single_target = text_to_tensor(
            target_transcription,
            pipeline.model.get_labels(),
            return_lengths=False
        )
        n_train = len(inputs_train_shuffled)
        n_test = len(inputs_test_shuffled)
        targets_train = single_target.repeat(n_train, 1)
        targets_test = single_target.repeat(n_test, 1)

    else:

        def transcribe(dataset: torch.Tensor):
            """Run the model over `dataset` batch-by-batch."""
            outputs = []
            n_batches = math.ceil(len(dataset) / batch_size)

            for batch_idx in tqdm(
                    range(n_batches),
                    total=n_batches,
                    desc="Computing transcriptions for dataset"):

                start = batch_idx * batch_size
                stop = start + batch_size
                x = dataset[start:stop].to(pipeline.device)

                if output_format == 'transcript':
                    outputs.extend(pipeline.model.transcribe(x))
                elif output_format == 'emission':
                    emissions = pipeline.model(x).to(orig_device)
                    outputs.extend(torch.split(emissions, 1, dim=0))

            if output_format == 'transcript':
                # convert string transcripts to padded index sequences
                return text_to_tensor(
                    outputs,
                    pipeline.model.get_labels(),
                    return_lengths=False)
            if output_format == 'emission':
                # pad to max emission length
                return pad_sequence(outputs, batch_first=True).squeeze(1)
            return outputs

        targets_train = transcribe(inputs_train_shuffled)
        targets_test = transcribe(inputs_test_shuffled)

        if targeted:

            def derange(x: torch.Tensor):
                """
                Permute rows, retrying a bounded number of times until no row
                equals the row originally at its position; the last
                permutation drawn is used either way.
                """
                max_iter = 10
                orig_shape = x.shape
                flat = x.reshape(x.shape[0], -1)

                for _ in range(max_iter):
                    rand_idx = torch.randperm(len(flat))
                    fixed_points = (flat == flat[rand_idx]).all(dim=-1)
                    if not fixed_points.any().item():
                        break

                return flat[rand_idx].reshape(orig_shape)

            targets_train = derange(targets_train)
            targets_test = derange(targets_test)

    # compute target lengths
    if output_format == 'transcript':
        lengths_train = padded_transcript_length(targets_train)
        lengths_test = padded_transcript_length(targets_test)
    elif output_format == 'emission':
        lengths_train = torch.full(
            size=(len(inputs_train_shuffled),),
            fill_value=targets_train.shape[1],
            dtype=torch.long
        )
        lengths_test = torch.full(
            size=(len(inputs_test_shuffled),),
            fill_value=targets_test.shape[1],
            dtype=torch.long
        )
    else:
        raise ValueError(f'Invalid value for `output_format`')

    from src.data.dataset import VoiceBoxDataset

    def finalize(data, inputs, targets, order):
        """Attach new inputs and targets to a dataset."""
        if isinstance(data, VoiceBoxDataset):
            return data.overwrite_dataset(inputs, targets, order)
        return DatasetWrapper(data, inputs, targets)

    data_train_final = finalize(
        data_train, inputs_train_shuffled, targets_train, order_train)
    data_test_final = finalize(
        data_test, inputs_test_shuffled, targets_test, order_test)

    train = {
        'dataset': data_train_final,
        'targets': targets_train,
        'target_lengths': lengths_train
    }
    test = {
        'dataset': data_test_final,
        'targets': targets_test,
        'target_lengths': lengths_test
    }

    # restore device and return
    train = move_to_device_recursive(train, orig_device)
    test = move_to_device_recursive(test, orig_device)
    return train, test