# Voiceblock demo: Attempt 8 (commit 957e2dc)
import torch
from src.attacks.offline.perturbation import Perturbation, VoiceBox
from src.attacks.offline.trainable import TrainableAttack
from src.loss.auxiliary import AuxiliaryLoss
from typing import Union
################################################################################
# VoiceBox online filtering-based attack
################################################################################
class VoiceBoxAttack(TrainableAttack):
def __init__(self,
voicebox_kwargs: dict,
control_loss: AuxiliaryLoss = None,
**kwargs):
# additional (optional) auxiliary loss on filter controls
self.control_loss = control_loss
super().__init__(
perturbation=VoiceBox(**voicebox_kwargs),
**kwargs)
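
    # `voicebox_kwargs` is forwarded verbatim to the VoiceBox constructor,
    # which implements the filtering-based perturbation applied by this attack
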
def _log_step(self,
x: torch.Tensor,
x_adv: torch.Tensor,
y: torch.Tensor,
adv_loss: Union[float, torch.Tensor] = None,
det_loss: Union[float, torch.Tensor] = None,
aux_loss: Union[float, torch.Tensor] = None,
control_loss: Union[float, torch.Tensor] = None,
success_rate: Union[float, torch.Tensor] = None,
detection_rate: Union[float, torch.Tensor] = None,
idx: int = 0,
tag: str = None,
*args,
**kwargs
):
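        # log only once every `log_iter` iterations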
if self.writer is None or self._iter_id % self.writer.log_iter:
return
if tag is None:
tag = f'{self.__class__.__name__}-' \
f'{self.aux_loss.__class__.__name__}'
super()._log_step(
x,
x_adv,
y,
adv_loss=adv_loss,
det_loss=det_loss,
aux_loss=aux_loss,
success_rate=success_rate,
detection_rate=detection_rate,
idx=idx,
tag=tag
)
# log control-signal loss
self.writer.log_scalar(
control_loss,
f"{tag}/control-signal-loss",
global_step=self._iter_id
)

    def _train_batch(self,
x: torch.Tensor,
y: torch.Tensor,
*args,
**kwargs):
"""Optimize stored model/perturbation over a batch of inputs"""
# require batch dimension
assert x.ndim >= 2
n_batch = x.shape[0]
x = x.detach()
# set reference for auxiliary loss to avoid re-computing
self._set_loss_reference(x)
# randomly sample simulation parameters
if self.eot_iter and not self._iter_id % self.eot_iter:
self.pipeline.sample_params()
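        # (EOT-style resampling: refreshing the simulation parameters every
        # `eot_iter` iterations averages the attack over varied conditions)
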
def closure():
            # allocate a gradient accumulator shaped like the flattened
            # perturbation parameter gradients
            model_gradients = \
                self._retrieve_parameter_gradients(self.perturbation)
            grad_total = torch.zeros_like(model_gradients)
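            # each loss term below gets its own backward pass so the three
            # gradient vectors can be combined per the selected PGD variant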
# apply adversarial perturbation to batch and obtain predictions
            perturbed = self.perturbation(x, *args, y=y, **kwargs)
outputs = self.pipeline(perturbed)
# reset parameter gradients, using `None` for performance boost
self.perturbation.zero_grad(set_to_none=True)
# compute flattened parameter gradients w.r.t. adversarial loss
adv_scores = self.adv_loss(outputs, y)
adv_loss = torch.mean(adv_scores)
adv_loss.backward(retain_graph=True)
adv_loss_grad = self._retrieve_parameter_gradients(
self.perturbation
).detach()
# reset parameter gradients, using `None` for performance boost
self.perturbation.zero_grad(set_to_none=True)
# compute flattened parameter gradients w.r.t. detector loss
detector_flags, detector_scores = self.pipeline.detect(perturbed)
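            # (`detect` is assumed to return per-input binary flags and
            # continuous detection scores)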
detector_loss = torch.mean(detector_scores)
detector_loss.backward(retain_graph=True)
detector_loss_grad = self._retrieve_parameter_gradients(
self.perturbation
).detach()
# reset parameter gradients, using `None` for performance boost
self.perturbation.zero_grad(set_to_none=True)
# compute flattened parameter gradients w.r.t. auxiliary loss
if self.aux_loss is not None:
aux_scores = self._compute_aux_loss(perturbed)
aux_loss = torch.mean(aux_scores)
aux_loss.backward()
aux_loss_grad = self._retrieve_parameter_gradients(
self.perturbation
).detach()
else: # if no auxiliary loss, do not penalize
                aux_scores = torch.zeros(n_batch, device=x.device)
aux_loss = torch.mean(aux_scores)
aux_loss_grad = torch.zeros_like(adv_loss_grad).detach()
            # check whether the perturbation exposes its filter controls
            has_controls = callable(
                getattr(self.perturbation, "get_controls", None))
            if self.control_loss is not None and has_controls:
# compute slowness / sparsity loss on control signal
controls = self.perturbation.get_controls(
x, *args, **kwargs)
control_scores = self.control_loss(controls)
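                # fixed 0.01 weight on the control-signal penalty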
control_loss = torch.mean(control_scores) * 0.01
# backpropagate
control_loss.backward()
# retrieve parameter gradients
control_loss_grad = self._retrieve_parameter_gradients(
self.perturbation
).detach()
# add to aux loss
aux_loss_grad = aux_loss_grad + control_loss_grad
else:
control_loss = 0.0
################################################################
            # classifier evasion indicator (scalar; broadcasts over the
            # flattened gradient vectors)
            adv_success = (adv_loss <= self.adv_success_thresh) * 1.0
            # detector evasion indicator (scalar; broadcasts over the
            # flattened gradient vectors)
            detector_success = (detector_loss <= self.det_success_thresh) * 1.0
# perform standard, orthogonal, or selective gradient
# accumulation
if self.pgd_variant is None or self.pgd_variant == 'none':
# for standard PGD, sum loss gradients
grad_total += adv_loss_grad + \
detector_loss_grad + \
aux_loss_grad
elif self.pgd_variant == 'orthogonal':
# for orthogonal PGD, orthogonalize loss gradients and
# select one for update; optionally, orthogonalize only
# every kth step
if self.k and self._batch_id % self.k:
adv_loss_grad_proj = adv_loss_grad
detector_loss_grad_proj = detector_loss_grad
aux_loss_grad_proj = aux_loss_grad
else:
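                    # `_component_orthogonal(g, a, b)` (inherited; assumed
                    # semantics) projects g onto the subspace orthogonal to
                    # a and b, Gram-Schmidt style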
adv_loss_grad_proj = self._component_orthogonal(
adv_loss_grad,
detector_loss_grad,
aux_loss_grad
)
detector_loss_grad_proj = self._component_orthogonal(
detector_loss_grad,
adv_loss_grad,
aux_loss_grad
)
aux_loss_grad_proj = self._component_orthogonal(
aux_loss_grad,
detector_loss_grad,
adv_loss_grad
)
# update 'along' a single loss gradient per iteration
grad_total += adv_loss_grad_proj * (1 - adv_success)
grad_total += detector_loss_grad_proj * adv_success \
* (1 - detector_success)
grad_total += aux_loss_grad_proj * adv_success * \
detector_success
elif self.pgd_variant == 'selective':
# only consider a single loss per iteration, without
# ensuring orthogonality to remaining loss gradients
grad_total += adv_loss_grad * (1 - adv_success)
grad_total += detector_loss_grad * adv_success \
* (1 - detector_success)
grad_total += aux_loss_grad * adv_success * detector_success
else:
raise ValueError(f'Invalid attack mode {self.pgd_variant}')
            # scale gradients to unit length under the chosen p-norm
            if self.scale_grad in [2, "2"]:
                grad_norms = torch.norm(
                    grad_total, p=2, dim=-1, keepdim=True
                ) + 1e-20
                grad_total = grad_total / grad_norms
elif self.scale_grad in [float("inf"), "inf"]:
grad_total = torch.sign(grad_total)
elif self.scale_grad in ['none', None]:
pass
else:
                raise ValueError(f'Invalid gradient regularization norm '
                                 f'{self.scale_grad}')
# set final parameter gradients
self._set_parameter_gradients(
grad_total.flatten(),
self.perturbation
)
# log results
if self.writer is not None:
self._log_step(
x=x,
x_adv=perturbed,
y=y,
adv_loss=adv_loss,
det_loss=detector_loss,
aux_loss=aux_loss,
control_loss=control_loss,
detection_rate=torch.mean(1.0 * detector_flags)
)
# return placeholder loss
return adv_loss + detector_loss + aux_loss
# optimizer step, using stored gradients
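        # (the optimizer invokes `closure` inside `step`, so the gradient
        # computation above runs once per call)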
self.optimizer.step(closure)
        # project perturbation parameters to the feasible region, if supported
        if hasattr(self.perturbation, "project_valid"):
            self.perturbation.project_valid()
# update total iteration count
self._iter_id += 1

    def _evaluate_batch(self,
x: torch.Tensor,
y: torch.Tensor,
*args,
**kwargs
):
"""Evaluate batch of inputs by passing through model/perturbation"""
x_orig = x.clone().detach()
        x_adv = self.perturbation(x_orig, *args, y=y, **kwargs)
return x_adv
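
################################################################################
# Example usage (illustrative sketch; `pipeline`, `pgd_variant`, and
# `scale_grad` are assumed to be TrainableAttack constructor arguments,
# and the VoiceBox keyword names are hypothetical)
################################################################################
# attack = VoiceBoxAttack(
#     voicebox_kwargs={...},            # forwarded to VoiceBox
#     control_loss=AuxiliaryLoss(...),  # optional penalty on filter controls
#     pipeline=pipeline,                # wrapped model + detector
#     pgd_variant='orthogonal',         # 'none' | 'orthogonal' | 'selective'
#     scale_grad='inf',                 # sign-of-gradient updates
# )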