# ALeLacheur's picture
# Voiceblock demo: Attempt 8
# 957e2dc
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import librosa as li
from typing import Union
from src.attacks.offline.trainable import TrainableAttack
from src.pipelines import Pipeline
from src.loss.adversarial import AdversarialLoss
from src.attacks.offline.perturbation import AdditivePerturbation
from src.data import DataProperties
################################################################################
# Implementation of universal additive attack of Li et al.
################################################################################
class AdvPulseAttack(TrainableAttack):
    """
    Universal additive attack (after Li et al.) built on an
    `AdditivePerturbation`, with an optional "mimic" sound template.

    When a mimic sound is given, it is length-matched to the perturbation and
    used as the reference for the auxiliary loss (instead of the input audio),
    and can optionally be used to initialize the perturbation itself.
    """

    def __init__(self,
                 pipeline: Pipeline,
                 adv_loss: AdversarialLoss,
                 mimic_sound: Union[torch.Tensor, str, None] = None,
                 init_mimic: bool = False,
                 tile_reference: bool = False,
                 normalize: bool = False,
                 eps: float = 0.05,
                 pgd_norm: Union[str, int, float] = float('inf'),
                 length: Union[int, float] = 0.5,
                 align: str = 'start',
                 loop: bool = False,
                 **kwargs
                 ):
        """
        Build the attack and, if requested, load the mimic sound template.

        :param pipeline: forwarded to the parent constructor; its `device`
                         attribute determines where the template is stored
        :param adv_loss: forwarded to the parent constructor
        :param mimic_sound: optional sound template. Either a tensor with at
                            least two dimensions (a 2-D tensor gets a singleton
                            dimension inserted at index 1; dimension 1 is then
                            averaged to produce mono audio), or a path to an
                            audio file that is loaded with librosa
        :param init_mimic: if True and a template is available, replace the
                           perturbation's `delta` parameter with the
                           length-matched template
        :param tile_reference: if True, templates shorter than the target
                               length are repeated with cross-fading;
                               otherwise they are zero-padded
        :param normalize: forwarded to `AdditivePerturbation`
        :param eps: forwarded to `AdditivePerturbation`
        :param pgd_norm: forwarded to `AdditivePerturbation` as
                         `projection_norm`
        :param length: forwarded to `AdditivePerturbation`
        :param align: forwarded to `AdditivePerturbation`
        :param loop: forwarded to `AdditivePerturbation`
        :param kwargs: forwarded to the parent constructor
        :raises ValueError: if `mimic_sound` is a tensor with fewer than two
                            dimensions, or is neither None, a tensor, nor a
                            string
        """
        super().__init__(
            pipeline=pipeline,
            adv_loss=adv_loss,
            perturbation=AdditivePerturbation(
                eps=eps,
                projection_norm=pgd_norm,
                length=length,
                align=align,
                loop=loop,
                normalize=normalize
            ),
            **kwargs
        )

        # determine whether to repeat template to match perturbation length
        self.tile_reference = tile_reference

        if mimic_sound is None:
            self.mimic_sound = None

        elif isinstance(mimic_sound, torch.Tensor):

            # require at least two dimensions; raise explicitly rather than
            # `assert` so the check survives `python -O`
            if mimic_sound.ndim < 2:
                raise ValueError(
                    f'Mimic sound tensor must have at least 2 dimensions, '
                    f'got {mimic_sound.ndim}'
                )

            # insert a channel dimension if absent
            if mimic_sound.ndim == 2:
                mimic_sound = mimic_sound.unsqueeze(1)

            # convert to mono audio by averaging over the channel dimension
            self.mimic_sound = mimic_sound.mean(
                dim=1, keepdim=True
            ).to(self.pipeline.device)

        # load from file path
        elif isinstance(mimic_sound, str):

            # load the given file as mono audio at the configured sample rate
            mimic_sound_np, _ = li.load(
                mimic_sound,
                sr=DataProperties.get('sample_rate'),
                mono=True
            )
            mimic_sound = torch.as_tensor(mimic_sound_np)

            # trim to the configured signal length (presumably a no-op slice
            # if no length is configured and `None` is returned - verify)
            max_len = DataProperties.get('signal_length')
            self.mimic_sound = mimic_sound[..., :max_len].reshape(
                1, 1, -1
            ).to(self.pipeline.device)

        else:
            raise ValueError(f'Invalid mimic sound type {type(mimic_sound)}')

        # if specified, initialize adversarial perturbation to match template
        if self.mimic_sound is not None and init_mimic:
            # assumes `self.perturbation.length` is an integer sample count
            # at this point - TODO confirm.
            # NOTE(review): the zeros reference is created with default
            # dtype/device and `_match_signal_length` casts its result to
            # match the reference, so the new parameter is not necessarily on
            # `self.pipeline.device` - confirm this is intended.
            self.perturbation.delta = nn.Parameter(
                self._match_signal_length(
                    self.mimic_sound,
                    torch.zeros(1, self.perturbation.length)
                )
            )

    @staticmethod
    def _crossfade(sig, fade_len):
        """
        Apply linear fade-in and fade-out to the ends of a signal.

        :param sig: signal tensor; fades are applied along the last dimension
        :param fade_len: fade duration as a fraction of the signal length
        :return: faded copy of the signal (the input is not modified)
        """
        sig = sig.clone()
        n_fade = int(fade_len * sig.shape[-1])

        # a zero-length fade must be skipped: `sig[..., -0:]` would select the
        # entire signal and fail to broadcast against an empty ramp
        if n_fade <= 0:
            return sig

        fade_in = torch.linspace(0, 1, n_fade).to(sig)
        fade_out = torch.linspace(1, 0, n_fade).to(sig)

        sig[..., :n_fade] *= fade_in
        sig[..., -n_fade:] *= fade_out

        return sig

    def _match_signal_length(self, sig: torch.Tensor, ref: torch.Tensor):
        """
        Match length of signal to reference, either by trimming, zero-padding,
        or (if `self.tile_reference` is set) repeating and cross-fading.

        :param sig: signal to resize; flattened to a single row
        :param ref: reference tensor; its total number of elements is the
                    target length, and the result is cast to its dtype/device
        :return: tensor of shape (1, 1, target_length)
        """
        sig = sig.reshape(1, -1)
        ref = ref.reshape(1, -1)

        signal_length = ref.shape[-1]
        sig_length = sig.shape[-1]

        # long enough already: trim
        if sig_length >= signal_length:
            return sig[..., :signal_length].reshape(1, 1, -1).to(ref)

        # too short and tiling disabled: zero-pad at the end
        if not self.tile_reference:
            return F.pad(
                sig, (0, signal_length - sig_length)
            ).reshape(1, 1, -1).to(ref)

        # cross-fade length, as a fraction of the template length
        overlap = 0.05

        # hop between successive copies; consecutive copies overlap by the
        # remaining samples so that fade-out and fade-in regions sum
        step = math.ceil(sig_length * (1 - overlap))
        n_repeat = math.ceil(signal_length / step)

        # the faded template is identical for every copy; compute it once
        faded = self._crossfade(sig, overlap)

        # buffer large enough to hold all copies; since step <= sig_length it
        # is always at least `signal_length` samples long
        padded = torch.zeros(
            1, step * (n_repeat - 1) + sig_length,
            dtype=sig.dtype, device=sig.device
        )

        # overlap-add the faded copies at multiples of `step`
        for j in range(n_repeat):
            start = j * step
            padded[..., start:start + sig_length] += faded

        return padded[..., :signal_length].reshape(1, 1, -1).to(ref)

    def _set_loss_reference(self, x: torch.Tensor):
        """
        Pass reference audio to auxiliary loss to avoid re-computing expensive
        intermediate representations. For AdvPulse attack, optionally use the
        mimic sound template, length-matched to the perturbation, as the
        reference in place of the input audio.

        :param x: input audio; used as the reference when no mimic sound is
                  set
        """
        if self.aux_loss is not None:
            if self.mimic_sound is not None:
                reference = self._match_signal_length(
                    self.mimic_sound,
                    self.perturbation.delta
                )
            else:
                reference = x

            self.aux_loss.set_reference(reference)

    def _compute_aux_loss(self,
                          x_adv: torch.Tensor,
                          x_ref: torch.Tensor = None):
        """
        Compute auxiliary loss, optionally on the perturbation itself: when a
        mimic sound is set, the loss is evaluated on `self.perturbation.delta`
        rather than on the adversarial audio.

        :param x_adv: adversarial audio; used only when no mimic sound is set
        :param x_ref: reference passed through to the auxiliary loss
        :return: value returned by `self.aux_loss`
        """
        if self.mimic_sound is not None:
            return self.aux_loss(self.perturbation.delta, x_ref)
        else:
            return self.aux_loss(x_adv, x_ref)

    def _log_step(self,
                  x: torch.Tensor,
                  x_adv: torch.Tensor,
                  y: torch.Tensor,
                  adv_loss: Union[float, torch.Tensor] = None,
                  det_loss: Union[float, torch.Tensor] = None,
                  aux_loss: Union[float, torch.Tensor] = None,
                  success_rate: Union[float, torch.Tensor] = None,
                  detection_rate: Union[float, torch.Tensor] = None,
                  idx: int = 0,
                  tag: str = None,
                  *args,
                  **kwargs
                  ):
        """
        Log the current step via the parent implementation and additionally
        log the mimic sound template as audio, if one is set.

        Does nothing when no writer is configured or when the current
        iteration is not a multiple of the writer's `log_iter`. If `tag` is
        omitted it defaults to '<attack class name>-<aux loss class name>'.
        Extra positional and keyword arguments are accepted but not forwarded.
        """
        # skip when logging is disabled or this is not a logging iteration
        if self.writer is None or self._iter_id % self.writer.log_iter:
            return

        if tag is None:
            tag = f'{self.__class__.__name__}-' \
                  f'{self.aux_loss.__class__.__name__}'

        super()._log_step(
            x,
            x_adv,
            y,
            adv_loss=adv_loss,
            det_loss=det_loss,
            aux_loss=aux_loss,
            success_rate=success_rate,
            detection_rate=detection_rate,
            idx=idx,
            tag=tag
        )

        # add audio logging for mimic sound
        if self.mimic_sound is not None:
            self.writer.log_audio(
                self.mimic_sound,
                f'{tag}/sound-template',
                global_step=self._iter_id
            )