Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from typing import Tuple, Union | |
| from torch.utils.data import Dataset, TensorDataset, DataLoader | |
| from src.utils.writer import Writer | |
| from src.pipelines.pipeline import Pipeline | |
| from src.loss.adversarial import AdversarialLoss | |
| from src.loss.auxiliary import AuxiliaryLoss | |
| ################################################################################ | |
| # Base class for offline adversarial attacks | |
| ################################################################################ | |
class OfflineAttack:
    """
    Base class for offline adversarial attacks.

    Offline attacks optimize perturbations of benign inputs without
    real-time performance constraints. Optimization is performed using a
    stored Pipeline object. Subclasses must override `_evaluate_batch()`.
    """

    def __init__(self,
                 pipeline: Pipeline,
                 adv_loss: AdversarialLoss,
                 aux_loss: AuxiliaryLoss = None,
                 batch_size: int = 32,
                 rand_evals: int = 0,
                 writer: Writer = None,
                 **kwargs
                 ):
        """
        Base class for offline attacks. Subclasses must override the
        `_evaluate_batch()` method.

        Offline attacks optimize perturbations of benign inputs without
        real-time performance constraints. Optimization is performed using a
        stored Pipeline object, encompassing a victim model, acoustic
        simulation, and adversarial defenses.

        :param pipeline: a Pipeline object wrapping a (defended) classifier
        :param adv_loss: AdversarialLoss object encapsulating attacker objective
        :param aux_loss: optional AuxiliaryLoss object encapsulating
                         some perceptibility objective
        :param batch_size: batch size for attack
        :param rand_evals: randomly-resampled simulated evaluations per each
                           final generated attack; if 0, each attack is
                           evaluated exactly once without resampling
        :param writer: optional Writer object used to log attack progress
        :param kwargs: optional data-loading arguments; `pin_memory`
                       (default False) and `num_workers` (default 0) are
                       forwarded to the DataLoader built in `evaluate()`
        """
        self.pipeline = pipeline

        # ensure gradients flow through PyTorch RNN layers
        self._pipeline_rnn_grad()

        self.adv_loss = adv_loss
        self.aux_loss = aux_loss
        self.batch_size = batch_size
        self.rand_evals = rand_evals

        # log attack progress
        self.writer = writer

        # track batch inputs
        self._batch_id = 0
        self._iter_id = 0

        # optional data-loading arguments
        self.pin_memory = kwargs.get('pin_memory', False)
        self.num_workers = kwargs.get('num_workers', 0)

        self._check_loss()

    def _pipeline_rnn_grad(self):
        """
        PyTorch requires any recurrent modules be placed in `train` mode to
        enable backpropagation through the pipeline.
        """
        for m in self.pipeline.modules():
            if isinstance(m, nn.RNNBase):
                m.train()

    def _check_loss(self):
        """
        Validate adversarial and auxiliary losses: an adversarial loss must be
        given, and every loss must return unreduced (per-example) scores.

        :raises AssertionError: if a loss is missing or applies a reduction
        """
        assert self.adv_loss is not None, 'Must provide adversarial loss'
        assert self.adv_loss.reduction in ['none', None], \
            'All losses must provide unreduced scores'
        assert self.aux_loss is None or \
            self.aux_loss.reduction in ['none', None], \
            'All losses must provide unreduced scores'

    @staticmethod
    def _create_dataset(x: torch.Tensor, y: torch.Tensor):
        """
        If attack inputs are given as tensors, create a simple dataset.

        Declared static because it does not use instance state; it is invoked
        as `self._create_dataset(x, y)` from `evaluate()`.

        :param x: inputs with a leading batch dimension
        :param y: targets with a leading batch dimension
        :return: TensorDataset holding float32 copies of `x` and `y`
        """
        # require batch dimension
        assert x.ndim >= 2

        return TensorDataset(
            x.type(torch.float32),
            y.type(torch.float32),
        )

    def _compute_detection_array(self, x_adv, *args, **kwargs):
        """
        Pass attack audio through any detection defenses in stored Pipeline, and
        return boolean detection flags for each input
        """
        # detection scores are not needed here, only the flags
        flags, _ = self.pipeline.detect(x_adv)
        return flags

    def _compute_success_array(self,
                               x: torch.Tensor,
                               y: torch.Tensor,
                               x_adv: torch.Tensor,
                               *args,
                               **kwargs
                               ):
        """
        Pass attack audio through stored Pipeline and determine adversarial
        success for each input

        :param x: batch of original inputs
        :param y: batch of targets (used only for targeted attacks)
        :param x_adv: batch of adversarial inputs
        :return: per-input success indicators as produced by
                 `pipeline.match_predict` (negated for untargeted attacks)
        """
        # obtain 'clean' and adversarial predictions of stored Pipeline
        preds = self.pipeline(x.detach())
        adv_preds = self.pipeline(x_adv.detach())

        # for a targeted attack, attempt to match given targets
        if self.adv_loss.targeted:
            attack_success = self.pipeline.match_predict(adv_preds, y)

        # for an untargeted attack, attempt to evade clean predictions
        else:
            attack_success = ~self.pipeline.match_predict(adv_preds, preds)

        return attack_success

    def _log_step(self,
                  x: torch.Tensor,
                  x_adv: torch.Tensor,
                  y: torch.Tensor,
                  adv_loss: Union[float, torch.Tensor] = None,
                  det_loss: Union[float, torch.Tensor] = None,
                  aux_loss: Union[float, torch.Tensor] = None,
                  success_rate: Union[float, torch.Tensor] = None,
                  detection_rate: Union[float, torch.Tensor] = None,
                  idx: int = 0,
                  tag: str = None,
                  *args,
                  **kwargs
                  ):
        """
        Log attack progress. Does nothing if no writer is stored, or if the
        current iteration is not a multiple of the writer's `log_iter`.

        :param x: batch of original inputs
        :param x_adv: batch of adversarial inputs
        :param y: batch of adversarial targets
        :param adv_loss: adversarial loss value
        :param det_loss: detection loss value
        :param aux_loss: auxiliary loss value
        :param success_rate: adversarial success rate
        :param detection_rate: adversarial defense detection rate
        :param idx: batch index for logging individual examples
        :param tag: prefix for logged entries; defaults to a string built from
                    the class name and current batch index
        """
        # only log every `log_iter` iterations (nonzero remainder -> skip)
        if self.writer is None or self._iter_id % self.writer.log_iter:
            return

        if tag is None:
            tag = f'{self.__class__.__name__}-batch-{self._batch_id}'

        x = x.clone().detach()
        x_adv = x_adv.clone().detach()

        # compute losses and simulated audio; values are only logged, so no
        # gradients are required
        with torch.no_grad():

            outputs_adv = self.pipeline(x_adv)
            simulated = self.pipeline.simulate(x)
            simulated_adv = self.pipeline.simulate(x_adv)

            # if adversarial loss is not provided, compute
            if adv_loss is None:
                adv_loss = self.adv_loss(outputs_adv, y).mean()

            # if detector loss or rate is not provided, compute (both values
            # are recomputed together from a single detector pass)
            if det_loss is None or detection_rate is None:
                flags, scores = self.pipeline.detect(x_adv)
                det_loss = scores.mean()
                detection_rate = torch.mean(1.0 * flags)

            # if auxiliary loss is not provided, compute
            if aux_loss is None:
                aux_loss = 0.0 if self.aux_loss is None else self.aux_loss(
                    x_adv, x
                ).mean()

            # if adversarial success rate is not provided, compute
            if success_rate is None:
                success = self._compute_success_array(
                    x=x,
                    x_adv=x_adv,
                    y=y
                )
                success_rate = torch.mean(1.0 * success)

        # audio examples: original / adversarial, raw and simulated
        audio_entries = (
            (x[idx], "original"),
            (simulated[idx], "simulated-original"),
            (x_adv[idx], "adversarial"),
            (simulated_adv[idx], "simulated-adversarial"),
        )
        for audio, name in audio_entries:
            self.writer.log_audio(
                audio,
                f"{tag}/{name}",
                global_step=self._iter_id
            )

        # scalar metrics: losses and rates
        scalar_entries = (
            (adv_loss, "adversarial-loss"),
            (det_loss, "detector-loss"),
            (aux_loss, "auxiliary-loss"),
            (success_rate, "success-rate"),
            (detection_rate, "detection-rate"),
        )
        for value, name in scalar_entries:
            self.writer.log_scalar(
                value,
                f"{tag}/{name}",
                global_step=self._iter_id
            )

    def _evaluate_batch(self,
                        x: torch.Tensor,
                        y: torch.Tensor,
                        **kwargs
                        ):
        """
        Perform attack on a batch of inputs. Must be overridden by subclasses.

        :param x: input tensor of shape (n_batch, ...)
        :param y: targets tensor of shape (n_batch, ...) in case of targeted
                  attack; original labels tensor of shape (n_batch, ...) in
                  case of untargeted attack
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError()

    def evaluate(self,
                 x: torch.Tensor = None,
                 y: torch.Tensor = None,
                 dataset: Dataset = None,
                 **kwargs
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Perform attack given input-target pairs, optionally in the form of a
        Dataset object. Random evaluations will then be conducted on all
        generated attacks.

        :param x: audio input, shape (n_batch, ..., signal_length)
        :param y: targets, shape (n_batch, ...)
        :param dataset: optionally, provide inputs and targets as dataset;
                        takes precedence over `x` and `y` if given
        :param kwargs: forwarded to `_evaluate_batch()`
        :return: tuple holding
                 * adversarial audio (n_batch, ..., signal_length)
                 * adversarial success rate per attack (n_batch,), float
                 * adversarial detection rate per attack (n_batch,), float
        """
        assert (x is not None and y is not None) or dataset is not None

        # prepare batched data-loading, store original device; `x_ref` is a
        # single example (no batch dimension) used to size the output buffer
        if dataset is None:
            orig_device = x.device
            dataset = self._create_dataset(x, y)
            x_ref = x[0].clone().detach()
        else:
            ref_sample = next(iter(dataset))
            if isinstance(ref_sample, (tuple, list)):
                x_ref = ref_sample[0]
            elif isinstance(ref_sample, dict):
                x_ref = ref_sample['x']
            else:
                x_ref = ref_sample
            orig_device = x_ref.device

        device = self.pipeline.device
        n_total = len(dataset)

        # prepare to compute attack success and detection rates
        attack_success = torch.zeros(
            n_total, dtype=torch.float, device=device)
        attack_detection = torch.zeros(
            n_total, dtype=torch.float, device=device)

        # prepare to store attack outputs
        adv_x = torch.zeros((n_total, *x_ref.shape), device=device)

        data_loader = DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory
        )

        # compute attacks with batching
        for batch_id, batch_all in enumerate(data_loader):

            self._batch_id = batch_id

            # allow for different dataset formats; the default collate
            # function returns a list for tuple-style samples
            if isinstance(batch_all, (tuple, list)):
                batch_all = {
                    'x': batch_all[0],
                    'y': batch_all[1]
                }

            # match devices
            batch_all = {k: v.to(device) for k, v in batch_all.items()}

            # slice bounds may overrun on the final batch; slicing clips them
            start = batch_id * self.batch_size
            stop = (batch_id + 1) * self.batch_size

            # compute attacks for given batch
            adversarial_batch = self._evaluate_batch(
                **batch_all, **kwargs
            )

            # if no random trials, evaluate once
            if not self.rand_evals:

                # compute and store success rates for batch
                attack_success_batch = self._compute_success_array(
                    **batch_all,
                    x_adv=adversarial_batch
                ).reshape(-1).type(torch.float32)

                # compute and store detection rates for batch
                attack_detection_batch = self._compute_detection_array(
                    x_adv=adversarial_batch
                ).reshape(-1).type(torch.float32)

            # otherwise, perform multiple random evaluations per attack
            else:

                # track batch success and detection over random evaluation;
                # each entry is a column of shape (n_batch, 1)
                success_trials = []
                detection_trials = []

                for _ in range(self.rand_evals):

                    # randomly sample simulation parameters
                    self.pipeline.sample_params()

                    # compute and store success rates for batch
                    success_trials.append(
                        self._compute_success_array(
                            **batch_all,
                            x_adv=adversarial_batch
                        ).reshape(-1, 1)
                    )

                    # compute and store detection rates for batch
                    detection_trials.append(
                        self._compute_detection_array(
                            x_adv=adversarial_batch
                        ).reshape(-1, 1)
                    )

                # average results over all trials -> shape (n_batch,)
                attack_success_batch = torch.cat(
                    success_trials, dim=-1
                ).type(torch.float32).mean(dim=-1)
                attack_detection_batch = torch.cat(
                    detection_trials, dim=-1
                ).type(torch.float32).mean(dim=-1)

            # store generated attack audio
            adv_x[start:stop] = adversarial_batch

            # store success rate per generated attack
            attack_success[start:stop] = attack_success_batch

            # store detection rate per generated attack
            attack_detection[start:stop] = attack_detection_batch

        return (adv_x.to(orig_device),
                attack_success.to(orig_device),
                attack_detection.to(orig_device))