ALeLacheur's picture
Voiceblock demo: Attempt 8
957e2dc
import math
import torch
import pandas as pd
import numpy as np
from typing import List, Union
from pesq import pesq
from pystoi import stoi
from tqdm import tqdm
from src.data import DataProperties
from src.utils.plotting import tensor_to_np
from src.models.speech import Wav2Vec2, GreedyCTCDecoder
################################################################################
# Utilities for analyzing attack results
################################################################################
@torch.no_grad()
def run_perceptual_evaluation(x: torch.Tensor,
                              x_adv: torch.Tensor,
                              batch_size: int = 1,
                              device: Union[str, torch.cuda.device] = 'cpu',
                              tag: str = None,
                              **kwargs
                              ):
    """
    Compute perceptual-quality and transcription metrics on pairs of clean
    and adversarial audio.

    Parameters
    ----------
    x (Tensor):       clean audio; the first dimension is the batch dimension
                      (NOTE(review): exact layout, e.g. (n, 1, n_samples), is
                      not fixed here -- confirm against callers)
    x_adv (Tensor):   adversarial audio with the same number of dimensions
                      as `x`, paired with `x` along the batch dimension
    batch_size (int): number of examples passed to the ASR model at once
    device (str):     device on which the ASR model is run
    tag (str):        optional prefix for result keys ('foo' -> 'foo-SNR')
    **kwargs:         accepted for interface compatibility and ignored

    Returns
    -------
    dict mapping metric name to either a list with one score per example
    (waveform norms, PESQ, STOI, BSS-eval) or a single corpus-level float
    (ASR-WER, ASR-CER). Waveform norms are only present when `x` and `x_adv`
    have identical shapes. PESQ-Wideband entries are NaN for 8 kHz audio.
    """
    # check for compatible audio dimensions
    assert x.ndim == x_adv.ndim
    # require batch dimension
    assert x.ndim >= 2

    n_batch = x.shape[0]
    results = {}
    tag = '' if tag is None else f'{tag}-'

    ############################################################################
    # WAVEFORM P-NORM DISTANCE
    ############################################################################
    # if dimensions match, measure L-2 and L-inf distance between waveforms
    if x.shape == x_adv.shape:
        reduce_dims = tuple(range(1, x.ndim))
        diff = x - x_adv
        results[tag + 'L2-Waveform'] = diff.norm(
            p=2, dim=reduce_dims).flatten().tolist()
        results[tag + 'Linf-Waveform'] = diff.norm(
            p=float('inf'), dim=reduce_dims).flatten().tolist()

    ############################################################################
    # PESQ OBJECTIVE MEASURE (DEPRECATED)
    ############################################################################
    sample_rate = DataProperties.get('sample_rate')
    assert sample_rate in [16000, 8000], \
        f"Cannot perform PESQ evaluation with sample rate " \
        f"{sample_rate}Hz; must be 8000Hz or 16000Hz"

    # the `pesq` package only defines wideband mode for 16 kHz audio and
    # raises for mode 'wb' at 8 kHz, so record NaN instead of crashing
    wideband = sample_rate == 16000
    wb_scores, nb_scores = [], []
    for i in tqdm(range(n_batch), desc='computing PESQ scores'):
        ref = tensor_to_np(x[i]).flatten()
        deg = tensor_to_np(x_adv[i]).flatten()
        wb_scores.append(
            pesq(sample_rate, ref, deg, 'wb') if wideband else float('nan')
        )
        nb_scores.append(pesq(sample_rate, ref, deg, 'nb'))
    results[tag + 'PESQ-Wideband'] = wb_scores
    results[tag + 'PESQ-Narrowband'] = nb_scores

    ############################################################################
    # STOI OBJECTIVE MEASURE (DEPRECATED)
    ############################################################################
    cl_scores, ex_scores = [], []
    for i in tqdm(range(n_batch), desc='computing STOI scores'):
        ref = tensor_to_np(x[i]).flatten()
        deg = tensor_to_np(x_adv[i]).flatten()
        # extended=False -> classical STOI; extended=True -> extended STOI
        cl_scores.append(stoi(ref, deg, sample_rate, extended=False))
        ex_scores.append(stoi(ref, deg, sample_rate, extended=True))
    results[tag + 'STOI-Extended'] = ex_scores
    results[tag + 'STOI-Classical'] = cl_scores

    ############################################################################
    # BSS-EVAL SIGNAL METRICS
    ############################################################################
    si_sdr, sd_sdr, snr, srr = [], [], [], []
    for i in tqdm(range(n_batch), desc='computing BSS-EVAL metrics'):
        # the adversarial audio is the estimate; the clean audio the reference
        si_sdr_i, sd_sdr_i, snr_i, srr_i = _bss_eval(
            tensor_to_np(x_adv[i]).flatten(),
            tensor_to_np(x[i]).flatten())
        si_sdr.append(si_sdr_i)
        sd_sdr.append(sd_sdr_i)
        snr.append(snr_i)
        srr.append(srr_i)
    results[tag + 'SI-SDR'] = si_sdr
    results[tag + 'SD-SDR'] = sd_sdr
    results[tag + 'SNR'] = snr
    results[tag + 'SRR'] = srr

    ############################################################################
    # ASR TRANSCRIPTION METRICS
    ############################################################################
    # initialize ASR model / decoder
    model = Wav2Vec2()
    decoder = GreedyCTCDecoder(labels=model.labels)
    # obtain delimiter (word-separator) token
    delimiter = decoder.get_labels()[decoder.get_sep_idx()]
    model.to(device)

    # store original and adversarial transcriptions
    transcriptions = []
    transcriptions_adv = []
    n_batches = math.ceil(len(x) / batch_size)
    for i in tqdm(range(n_batches), desc='computing WER/CER'):
        batch = slice(batch_size * i, batch_size * (i + 1))
        emit_batch = model(x[batch].to(device))
        emit_adv_batch = model(x_adv[batch].to(device))
        # NOTE(review): the decoder is assumed to return a sequence whose
        # first element is the list of per-example transcription strings
        transcriptions.extend(decoder(emit_batch)[0])
        transcriptions_adv.extend(decoder(emit_adv_batch)[0])

    # clean transcriptions serve as the reference for adversarial ones
    results[tag + 'ASR-WER'] = compute_wer(
        transcriptions, transcriptions_adv, delimiter)
    results[tag + 'ASR-CER'] = compute_cer(
        transcriptions, transcriptions_adv, delimiter)

    return results
def compute_wer(
        reference: List[str],
        transcription: List[str],
        delimiter: str = ' '):
    """
    Compute corpus-level word error rate (WER) between string transcriptions.

        WER = (Sw + Dw + Iw) / Nw

    where:
        Sw is the number of words substituted,
        Dw is the number of words deleted,
        Iw is the number of words inserted,
        Nw is the number of words in the reference

    Edit counts and reference word counts are summed over all pairs before
    dividing, so longer references carry proportionally more weight.

    Parameters
    ----------
    reference (list of str):     reference transcriptions
    transcription (list of str): hypothesis transcriptions, paired
                                 element-wise with `reference`
    delimiter (str):             word separator

    Returns
    -------
    float: total word-level edit distance / total reference word count

    Raises
    ------
    ValueError: if `_word_errors` reports a reference with zero words
    ZeroDivisionError: if both lists are empty
    """
    assert len(reference) == len(transcription)

    # accumulate word-level edit distance and reference length over all pairs
    total_edit_dist = 0
    total_ref_len = 0
    for r, t in zip(reference, transcription):
        edit_dist, ref_len = _word_errors(r, t, delimiter=delimiter)
        if ref_len == 0:
            raise ValueError(
                "Reference sentences must have nonzero word count")
        total_edit_dist += edit_dist
        total_ref_len += ref_len

    return float(total_edit_dist) / total_ref_len
def compute_cer(
        reference: List[str],
        transcription: List[str],
        delimiter: str = ' ',
        remove_delimiter: bool = False):
    """
    Compute corpus-level character error rate (CER) between transcriptions.

        CER = (Sc + Dc + Ic) / Nc

    where:
        Sc is the number of characters substituted,
        Dc is the number of characters deleted,
        Ic is the number of characters inserted,
        Nc is the number of characters in the reference

    Edit counts and reference character counts are summed over all pairs
    before dividing.

    Parameters
    ----------
    reference (list of str):     reference transcriptions
    transcription (list of str): hypothesis transcriptions, paired
                                 element-wise with `reference`
    delimiter (str):             word separator
    remove_delimiter (bool):     if True, strip word separators before
                                 comparing, so only non-separator characters
                                 are scored

    Returns
    -------
    float: total character-level edit distance / total reference length

    Raises
    ------
    ValueError: if `_char_errors` reports a reference of zero length
    ZeroDivisionError: if both lists are empty
    """
    assert len(reference) == len(transcription)

    # accumulate character-level edit distance and reference length
    total_edit_dist = 0
    total_ref_len = 0
    for r, t in zip(reference, transcription):
        edit_dist, ref_len = _char_errors(r,
                                          t,
                                          delimiter,
                                          remove_delimiter)
        if ref_len == 0:
            raise ValueError(
                "Reference sentences must have nonzero character count")
        total_edit_dist += edit_dist
        total_ref_len += ref_len

    return float(total_edit_dist) / total_ref_len
def _word_errors(reference: str, transcription: str, delimiter: str = ' '):
    """
    Compute the Levenshtein distance between reference and transcription
    sequences at word level.

    Both strings are lower-cased and split on `delimiter`. Empty tokens are
    discarded. These come from leading, trailing or repeated delimiters, and
    keeping them would count them both as reference words and as edits.

    Returns
    -------
    (float, int): word-level edit distance, and number of reference words
    """
    ref_words = [w for w in reference.lower().split(delimiter) if w]
    tra_words = [w for w in transcription.lower().split(delimiter) if w]
    edit_distance = _levenshtein_distance(ref_words, tra_words)
    return float(edit_distance), len(ref_words)
def _char_errors(reference: str,
                 transcription: str,
                 delimiter: str = ' ',
                 remove_delimiter: bool = False
                 ):
    """
    Compute the Levenshtein distance between reference and transcription
    sequences at character level.

    Both strings are lower-cased and normalized the same way:
    - splitting on `delimiter` and dropping empty pieces discards leading,
      trailing and repeated delimiters;
    - the remaining words are re-joined with a single `delimiter`, or with
      nothing when `remove_delimiter` is True.

    Returns
    -------
    (float, int): character-level edit distance, and length of the
    normalized reference string
    """
    sep = '' if remove_delimiter else delimiter

    def _normalize(text: str) -> str:
        # keep only non-empty words, then glue them back together with `sep`
        words = [w for w in text.lower().split(delimiter) if w]
        return sep.join(words)

    ref_norm = _normalize(reference)
    tra_norm = _normalize(transcription)
    return float(_levenshtein_distance(ref_norm, tra_norm)), len(ref_norm)
def _levenshtein_distance(reference: Union[List[str], str],
transcription: Union[List[str], str]):
"""Levenshtein distance is a string metric for measuring the difference
between two sequences. Informally, the levenshtein disctance is defined as
the minimum number of single-character edits (substitutions, insertions or
deletions) required to change one word into the other. We can naturally
extend the edits to word level when calculate levenshtein disctance for
two sentences.
"""
m = len(reference)
n = len(transcription)
# special cases
if reference == transcription:
return 0
if m == 0:
return n
if n == 0:
return m
if m < n:
reference, transcription = transcription, reference
m, n = n, m
# use O(min(m, n)) space
distance = np.zeros((2, n + 1), dtype=np.int32)
# initialize distance matrix
for j in range(0, n + 1):
distance[0][j] = j
# calculate Levenshtein distance
for i in range(1, m + 1):
prev_row_idx = (i - 1) % 2
cur_row_idx = i % 2
distance[cur_row_idx][0] = i
for j in range(1, n + 1):
if reference[i - 1] == transcription[j - 1]:
distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
else:
s_num = distance[prev_row_idx][j - 1] + 1
i_num = distance[cur_row_idx][j - 1] + 1
d_num = distance[prev_row_idx][j] + 1
distance[cur_row_idx][j] = min(s_num, i_num, d_num)
return distance[m % 2][n]
def _bss_eval(x, x_ref):
x_ref_energy = (x_ref ** 2).sum()
alpha = (x_ref @ x / x_ref_energy)
e_true = x_ref
e_res = x - e_true
signal = (e_true ** 2).sum()
noise = (e_res ** 2).sum()
snr = 10 * np.log10(signal / noise)
e_true = x_ref * alpha
e_res = x - e_true
signal = (e_true ** 2).sum()
noise = (e_res ** 2).sum()
si_sdr = 10 * np.log10(signal / noise)
srr = -10 * np.log10((1 - (1/alpha)) ** 2)
sd_sdr = snr + 10 * np.log10(alpha ** 2)
return si_sdr, sd_sdr, snr, srr