Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset | |
| from typing import Union | |
| import warnings | |
| from src.attacks.offline.trainable import TrainableAttack | |
| from src.attacks.offline.perturbation.perturbation import Perturbation | |
| from src.attacks.offline.perturbation.kenansville import KenansvillePerturbation | |
| from src.pipelines.pipeline import Pipeline | |
| from src.loss.adversarial import AdversarialLoss | |
| from src.loss.auxiliary import AuxiliaryLoss | |
| from src.utils.writer import Writer | |
| ################################################################################ | |
| # Untargeted, black-box signal-processing attack | |
| ################################################################################ | |
class KenansvilleAttack(TrainableAttack):
    """
    Untargeted, black-box signal-processing attack that perturbs inputs by
    removing low-energy frequency content.
    """
    def __init__(self,
                 pipeline: Pipeline,
                 adv_loss: AdversarialLoss,
                 threshold_db_low: float = 1.0,
                 threshold_db_high: float = 100.0,
                 step_size: float = 10.0,
                 search: str = 'bisection',
                 min_success_rate: float = 0.9,
                 win_type: str = 'hann',
                 win_length: int = 2048,
                 **kwargs
                 ):
        """
        Untargeted black-box spectral-bin-removal attack proposed by Abdullah
        et al. (https://arxiv.org/abs/1910.05262). Code adapted from
        https://bit.ly/31K4Efy.

        :param pipeline: a Pipeline object
        :param adv_loss: an AdversarialLoss object; must be untargeted
        :param threshold_db_low: lower bound of the candidate energy-threshold
                                 range, in dB relative to spectral peak energy;
                                 frequency bins below threshold are removed
        :param threshold_db_high: upper bound of the candidate threshold range
        :param step_size: spacing in dB between candidate threshold values
        :param search: strategy for finding the highest (least perceptible)
                       threshold that still reaches ``min_success_rate``; one
                       of 'bisection', 'linear', 'none', or None. If disabled
                       (or the candidate range is degenerate),
                       ``threshold_db_low`` is used directly
        :param min_success_rate: minimum acceptable untargeted success rate
                                 when optimizing threshold
        :param win_type: window type; must be one of 'rectangular' or 'hann'.
                         For Hann window, audio is framed with 50% overlap
        :param win_length: frame length in samples
        """
        self.threshold_db_low = threshold_db_low
        self.threshold_db_high = threshold_db_high
        self.step_size = step_size
        self.search = search
        self.min_success_rate = min_success_rate

        super().__init__(
            pipeline=pipeline,
            adv_loss=adv_loss,
            perturbation=KenansvillePerturbation(
                threshold_db=threshold_db_low,
                win_type=win_type,
                win_length=win_length
            ),
            **kwargs
        )

    @staticmethod
    def _split_batch(batch):
        """
        Extract ``(inputs, targets)`` from a dict-format batch (keys `x`/`y`)
        or a sequence-format batch (first two elements; extras ignored).
        """
        if isinstance(batch, dict):
            return batch['x'], batch['y']
        x, y, *_ = batch
        return x, y

    def _threshold_success_rate(self, loader) -> float:
        """
        Compute the untargeted success rate of the current perturbation
        setting over every batch in ``loader``, logging per-batch statistics
        and advancing the batch/iteration counters.
        """
        successes = 0
        n = 0
        self._batch_id = 0
        for batch in loader:
            x, y = self._split_batch(batch)
            x = x.to(self.pipeline.device)
            y = y.to(self.pipeline.device)
            n += len(x)

            x_adv = self.perturbation(x)
            outputs = self.pipeline(x_adv)
            adv_loss = self.adv_loss(outputs, y).mean()
            batch_successes = (1.0 * self._compute_success_array(
                x, y, x_adv)).sum().item()
            successes += batch_successes

            self._log_step(
                x,
                x_adv,
                y,
                adv_loss,
                success_rate=batch_successes / len(x)
            )
            self._batch_id += 1
            self._iter_id += 1

        return successes / n

    def _run_validation(self, loader_val):
        """
        Evaluate the final perturbation on validation data and, if a writer
        is configured, log aggregate loss values and success/detection rates.
        """
        adv_scores = []
        aux_scores = []
        det_scores = []
        success_indicators = []
        detection_indicators = []
        self.perturbation.eval()

        for batch in loader_val:
            # randomize simulation for each validation batch
            self.pipeline.sample_params()

            x_orig, targets = self._split_batch(batch)
            n_batch = x_orig.shape[0]
            x_orig = x_orig.to(self.pipeline.device)
            targets = targets.to(self.pipeline.device)

            # set reference for auxiliary loss
            self._set_loss_reference(x_orig)

            with torch.no_grad():
                # compute adversarial loss
                x_adv = self._evaluate_batch(x_orig, targets)
                outputs = self.pipeline(x_adv)
                adv_scores.append(self.adv_loss(outputs, targets).flatten())

                # compute adversarial success rate
                success_indicators.append(
                    1.0 * self._compute_success_array(
                        x_orig, targets, x_adv
                    ).flatten())

                # compute defense loss and detection indicators
                def_results = self.pipeline.detect(x_adv)
                detection_indicators.append(1.0 * def_results[0].flatten())
                det_scores.append(def_results[1].flatten())

                # compute auxiliary loss; placeholder zeros when absent
                if self.aux_loss is not None:
                    aux_scores.append(
                        self._compute_aux_loss(x_adv).flatten())
                else:
                    aux_scores.append(torch.zeros(n_batch))

        tag = f'{self.__class__.__name__}-' \
              f'{self.aux_loss.__class__.__name__}'

        if self.writer is not None:
            with self.writer.force_logging():
                # aggregate each metric over the full validation set
                for scores, label in [
                    (adv_scores, "adversarial-loss-val"),
                    (det_scores, "detector-loss-val"),
                    (aux_scores, "auxiliary-loss-val"),
                    (success_indicators, "success-rate-val"),
                    (detection_indicators, "detection-rate-val"),
                ]:
                    self.writer.log_scalar(
                        torch.cat(scores, dim=0).mean(),
                        f"{tag}/{label}",
                        global_step=self._iter_id
                    )

    def train(self,
              x_train: torch.Tensor = None,
              y_train: torch.Tensor = None,
              data_train: Dataset = None,
              x_val: torch.Tensor = None,
              y_val: torch.Tensor = None,
              data_val: Dataset = None,
              *args,
              **kwargs
              ):
        """
        Select the highest (least perceptible) removal threshold achieving
        the minimum untargeted success rate on the training data, then
        evaluate the resulting perturbation on validation data.
        """
        loader_train, loader_val = self._prepare_data(
            x_train,
            y_train,
            data_train,
            x_val,
            y_val,
            data_val)

        # match devices and set reference if necessary
        ref_batch = next(iter(loader_train))
        # NOTE: default DataLoader collation yields lists, so accept both
        # tuple- and list-format batches here (the bare-tensor fallback
        # previously swallowed list batches)
        if isinstance(ref_batch, (tuple, list)):
            x_ref = ref_batch[0]
            warnings.warn('Warning: provided dataset yields batches in tuple '
                          'format; the first two tensors of each batch will be '
                          'interpreted as inputs and targets, respectively, '
                          'and any remaining tensors will be ignored. To pass '
                          'additional named tensor arguments, use a dictionary '
                          'batch format with keys `x` and `y` for inputs and '
                          'targets, respectively.')
        elif isinstance(ref_batch, dict):
            x_ref = ref_batch['x']
        else:
            x_ref = ref_batch
        if hasattr(self.perturbation, "set_reference"):
            try:
                self.perturbation.set_reference(
                    x_ref.to(self.pipeline.device))
            except AttributeError:
                pass

        # enumerate possible threshold values for search (endpoint-exclusive)
        threshold_values = torch.arange(
            self.threshold_db_low, self.threshold_db_high, self.step_size)

        # track iterations
        self._iter_id = 0
        self._batch_id = 0
        self._epoch_id = 0

        # avoid unnecessary search when the range is degenerate or disabled
        if self.threshold_db_low == self.threshold_db_high \
                or len(threshold_values) < 2 \
                or self.search in ['none', None]:
            self.perturbation.set_threshold(self.threshold_db_low)
        else:
            threshold_best = self.threshold_db_low

            if self.search == 'bisection':
                # bisection search for the maximum threshold achieving the
                # minimum success rate (assumes success is monotonically
                # non-increasing in the threshold)
                i_min = 0
                i_max = len(threshold_values)
                while i_min < i_max:
                    i_mid = (i_min + i_max) // 2
                    threshold = threshold_values[i_mid]
                    self.perturbation.set_threshold(threshold)
                    if self._threshold_success_rate(loader_train) \
                            >= self.min_success_rate:
                        threshold_best = threshold
                        i_min = i_mid + 1
                    else:
                        i_max = i_mid
            elif self.search == 'linear':
                # exhaustive scan; keeps the last candidate that succeeds
                # (no monotonicity assumption, so no early exit)
                for threshold in threshold_values:
                    self.perturbation.set_threshold(threshold)
                    if self._threshold_success_rate(loader_train) \
                            >= self.min_success_rate:
                        threshold_best = threshold
            else:
                raise ValueError(f'Invalid search method {self.search}')

            # set final threshold value
            self.perturbation.set_threshold(threshold_best)

        # perform validation
        self._run_validation(loader_val)

        # freeze model parameters
        self.perturbation.eval()
        for p in self.perturbation.parameters():
            p.requires_grad = False

        # save model/perturbation
        self._checkpoint()

    def _evaluate_batch(self,
                        x: torch.Tensor,
                        y: torch.Tensor,
                        **kwargs
                        ):
        """
        Remove low-energy frequency content from inputs.

        :param x: batched input tensor (batch dimension required)
        :param y: targets; unused here, accepted for interface compatibility
        :return: perturbed inputs of the same shape as ``x``
        """
        # require batch dimension
        assert x.ndim >= 2
        return self.perturbation(x)