import torch

from src.attacks.offline.trainable import TrainableAttack
from src.attacks.offline.perturbation import VoiceBox
from src.loss.auxiliary import AuxiliaryLoss

from typing import Union


################################################################################
# VoiceBox online filtering-based attack
################################################################################


class VoiceBoxAttack(TrainableAttack):
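    """
    Trainable attack wrapping the VoiceBox filtering perturbation. The
    perturbation is optimized jointly against an adversarial loss, a detector
    loss, an optional auxiliary loss, and an optional loss on the filter
    control signal.
    """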

    def __init__(self,
                 voicebox_kwargs: dict,
                 control_loss: AuxiliaryLoss = None,
                 **kwargs):

        # additional (optional) auxiliary loss on filter controls
        self.control_loss = control_loss

        super().__init__(
            perturbation=VoiceBox(**voicebox_kwargs),
            **kwargs)

    def _log_step(self,
                  x: torch.Tensor,
                  x_adv: torch.Tensor,
                  y: torch.Tensor,
                  adv_loss: Union[float, torch.Tensor] = None,
                  det_loss: Union[float, torch.Tensor] = None,
                  aux_loss: Union[float, torch.Tensor] = None,
                  control_loss: Union[float, torch.Tensor] = None,
                  success_rate: Union[float, torch.Tensor] = None,
                  detection_rate: Union[float, torch.Tensor] = None,
                  idx: int = 0,
                  tag: str = None,
                  *args,
                  **kwargs
                  ):
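        """Log batch statistics, including the control-signal loss."""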
        if self.writer is None or self._iter_id % self.writer.log_iter:
            return

        if tag is None:
            tag = f'{self.__class__.__name__}-' \
                  f'{self.aux_loss.__class__.__name__}'

        super()._log_step(
            x,
            x_adv,
            y,
            adv_loss=adv_loss,
            det_loss=det_loss,
            aux_loss=aux_loss,
            success_rate=success_rate,
            detection_rate=detection_rate,
            idx=idx,
            tag=tag
        )

        # log control-signal loss
        self.writer.log_scalar(
            control_loss,
            f"{tag}/control-signal-loss",
            global_step=self._iter_id
        )

    def _train_batch(self,
                     x: torch.Tensor,
                     y: torch.Tensor,
                     *args,
                     **kwargs):
        """Optimize stored model/perturbation over a batch of inputs"""
        # require batch dimension
        assert x.ndim >= 2
        n_batch = x.shape[0]
        x = x.detach()

        # set reference for auxiliary loss to avoid re-computing
        self._set_loss_reference(x)

        # randomly sample simulation parameters
        if self.eot_iter and not self._iter_id % self.eot_iter:
            self.pipeline.sample_params()
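        # (presumably expectation-over-transformation: simulation parameters
        # are re-sampled every `eot_iter` iterations so the perturbation is
        # optimized across varied acoustic conditions)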

        def closure():

            # placeholder for final model/perturbation gradients
            model_gradients = \
                self._retrieve_parameter_gradients(self.perturbation)
            grad_total = torch.zeros_like(model_gradients)

            # apply adversarial perturbation to batch and obtain predictions
            perturbed = self.perturbation(x, y=y, *args, **kwargs)
            outputs = self.pipeline(perturbed)

            # reset parameter gradients, using `None` for performance boost
            self.perturbation.zero_grad(set_to_none=True)

            # compute flattened parameter gradients w.r.t. adversarial loss
            adv_scores = self.adv_loss(outputs, y)
            adv_loss = torch.mean(adv_scores)
            adv_loss.backward(retain_graph=True)
            adv_loss_grad = self._retrieve_parameter_gradients(
                self.perturbation
            ).detach()

            # reset parameter gradients, using `None` for performance boost
            self.perturbation.zero_grad(set_to_none=True)

            # compute flattened parameter gradients w.r.t. detector loss
            detector_flags, detector_scores = self.pipeline.detect(perturbed)
            detector_loss = torch.mean(detector_scores)
            detector_loss.backward(retain_graph=True)
            detector_loss_grad = self._retrieve_parameter_gradients(
                self.perturbation
            ).detach()

            # reset parameter gradients, using `None` for performance boost
            self.perturbation.zero_grad(set_to_none=True)

            # compute flattened parameter gradients w.r.t. auxiliary loss
            if self.aux_loss is not None:
                aux_scores = self._compute_aux_loss(perturbed)
                aux_loss = torch.mean(aux_scores)
                aux_loss.backward()
                aux_loss_grad = self._retrieve_parameter_gradients(
                    self.perturbation
                ).detach()
            else:  # if no auxiliary loss, do not penalize
                aux_scores = torch.zeros(n_batch).to(x.device)
                aux_loss = torch.mean(aux_scores)
                aux_loss_grad = torch.zeros_like(adv_loss_grad).detach()
            # obtain filter controls for given inputs
            get_controls = callable(
                getattr(self.perturbation, "get_controls", None))
            if self.control_loss is not None and get_controls:

                # reset parameter gradients so the control-signal gradient is
                # not mixed with the auxiliary gradient accumulated above
                self.perturbation.zero_grad(set_to_none=True)

                # compute slowness / sparsity loss on control signal,
                # down-weighted by a fixed factor of 0.01
                controls = self.perturbation.get_controls(
                    x, *args, **kwargs)
                control_scores = self.control_loss(controls)
                control_loss = torch.mean(control_scores) * 0.01

                # backpropagate
                control_loss.backward()

                # retrieve parameter gradients
                control_loss_grad = self._retrieve_parameter_gradients(
                    self.perturbation
                ).detach()

                # fold control-signal gradient into the auxiliary gradient
                aux_loss_grad = aux_loss_grad + control_loss_grad
            else:
                control_loss = 0.0
            ################################################################

            # classifier evasion indicator (losses here are scalar batch
            # means, so the indicator broadcasts over the flat gradients)
            adv_success = (adv_loss <= self.adv_success_thresh) * 1.0

            # detector evasion indicator
            detector_success = (detector_loss <= self.det_success_thresh) * 1.0
            # perform standard, orthogonal, or selective gradient
            # accumulation
            if self.pgd_variant is None or self.pgd_variant == 'none':

                # for standard PGD, sum loss gradients
                grad_total += adv_loss_grad + \
                    detector_loss_grad + \
                    aux_loss_grad

            elif self.pgd_variant == 'orthogonal':

                # for orthogonal PGD, orthogonalize loss gradients and
                # select one for update; optionally, orthogonalize only
                # every kth step
                if self.k and self._batch_id % self.k:
                    adv_loss_grad_proj = adv_loss_grad
                    detector_loss_grad_proj = detector_loss_grad
                    aux_loss_grad_proj = aux_loss_grad
                else:
                    adv_loss_grad_proj = self._component_orthogonal(
                        adv_loss_grad,
                        detector_loss_grad,
                        aux_loss_grad
                    )
                    detector_loss_grad_proj = self._component_orthogonal(
                        detector_loss_grad,
                        adv_loss_grad,
                        aux_loss_grad
                    )
                    aux_loss_grad_proj = self._component_orthogonal(
                        aux_loss_grad,
                        detector_loss_grad,
                        adv_loss_grad
                    )
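
                # (each projected gradient is presumably the original loss
                # gradient with its components along the other two loss
                # gradients removed, in the spirit of orthogonal PGD, so a
                # step along one objective does not undo the others)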
                # update 'along' a single loss gradient per iteration
                grad_total += adv_loss_grad_proj * (1 - adv_success)
                grad_total += detector_loss_grad_proj * adv_success \
                    * (1 - detector_success)
                grad_total += aux_loss_grad_proj * adv_success * \
                    detector_success

            elif self.pgd_variant == 'selective':

                # only consider a single loss per iteration, without
                # ensuring orthogonality to remaining loss gradients
                grad_total += adv_loss_grad * (1 - adv_success)
                grad_total += detector_loss_grad * adv_success \
                    * (1 - detector_success)
                grad_total += aux_loss_grad * adv_success * detector_success

            else:
                raise ValueError(f'Invalid attack mode {self.pgd_variant}')
            # regularize gradients via p-norm projection
            if self.scale_grad in [2, float(2), "2"]:
                grad_norms = torch.norm(
                    grad_total, p=2, dim=-1
                ) + 1e-20
                grad_total = grad_total / grad_norms
            elif self.scale_grad in [float("inf"), "inf"]:
                grad_total = torch.sign(grad_total)
            elif self.scale_grad in ['none', None]:
                pass
            else:
                raise ValueError(f'Invalid gradient regularization norm '
                                 f'{self.scale_grad}')
            # set final parameter gradients
            self._set_parameter_gradients(
                grad_total.flatten(),
                self.perturbation
            )

            # log results
            if self.writer is not None:
                self._log_step(
                    x=x,
                    x_adv=perturbed,
                    y=y,
                    adv_loss=adv_loss,
                    det_loss=detector_loss,
                    aux_loss=aux_loss,
                    control_loss=control_loss,
                    detection_rate=torch.mean(1.0 * detector_flags)
                )

            # return placeholder loss
            return adv_loss + detector_loss + aux_loss
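
        # note: `Optimizer.step(closure)` evaluates the closure before the
        # update; parameter gradients are set manually inside it, so the
        # returned loss serves only as a diagnostic "placeholder"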
        # optimizer step, using stored gradients
        self.optimizer.step(closure)

        # project perturbation to feasible region, if supported
        if hasattr(self.perturbation, "project_valid"):
            self.perturbation.project_valid()

        # update total iteration count
        self._iter_id += 1

    def _evaluate_batch(self,
                        x: torch.Tensor,
                        y: torch.Tensor,
                        *args,
                        **kwargs
                        ):
        """Evaluate batch of inputs by passing through model/perturbation"""
        x_orig = x.clone().detach()
        x_adv = self.perturbation(x_orig, y=y, *args, **kwargs)
        return x_adv
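

################################################################################
# Illustrative usage (hypothetical sketch)
################################################################################
# A minimal sketch of how this attack might be driven. It assumes that
# `TrainableAttack.__init__` accepts the remaining keyword arguments implied
# by `_train_batch` (pipeline, losses, success thresholds, optimizer, writer,
# etc.); the exact constructor signature lives in
# `src.attacks.offline.trainable` and may differ.
#
#   attack = VoiceBoxAttack(
#       voicebox_kwargs={...},   # forwarded to the VoiceBox perturbation
#       control_loss=None,       # optional AuxiliaryLoss on filter controls
#       **trainable_kwargs,      # forwarded to TrainableAttack.__init__
#   )
#   for x, y in loader:          # batches of audio inputs and labels
#       attack._train_batch(x, y)
#   x_adv = attack._evaluate_batch(x, y)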