import sys
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchaudio.transforms as trans
from collections import OrderedDict
from itertools import permutations
from models.transformer import TransformerEncoder
from .utils import UpstreamExpert


class GradMultiply(torch.autograd.Function):
    """Return the input unchanged in forward; scale its gradient by `scale` in backward."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None


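# GradMultiply is the usual trick for down-weighting gradients that flow back
# into a pre-trained feature extractor; a minimal sketch (the 0.1 scale is an
# assumed value, in this module it comes from feature_grad_mult):
#
#   feature = GradMultiply.apply(feature, 0.1)  # forward unchanged, grad * 0.1

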
| """ |
| P: number of permutation |
| T: number of frames |
| C: number of speakers (classes) |
| B: mini-batch size |
| """ |
|
|
|
|
def batch_pit_loss_parallel(outputs, labels, ilens=None):
    """Calculate the batch PIT loss over all speaker permutations in parallel.

    Args:
        outputs (torch.Tensor): B x T x C
        labels (torch.Tensor): B x T x C
        ilens (torch.Tensor): B

    Returns:
        loss: mean over the batch of the minimum permutation loss
        perm (torch.Tensor): best permutation for outputs (B, num_spk)
    """
    if ilens is None:
        mask, scale = 1.0, outputs.shape[1]
    else:
        scale = torch.unsqueeze(torch.LongTensor(ilens), 1).to(outputs.device)
        mask = outputs.new_zeros(outputs.size()[:-1])
        for i, chunk_len in enumerate(ilens):
            mask[i, :chunk_len] += 1.0
        mask /= scale

    def loss_func(output, label):
        # frame-level BCE, masked to valid frames and normalized by chunk length
        return torch.sum(
            F.binary_cross_entropy_with_logits(
                output, label, reduction='none') * mask,
            dim=-1)

    def pair_loss(outputs, labels, permutation):
        return sum(
            [loss_func(outputs[:, :, s], labels[:, :, t])
             for s, t in enumerate(permutation)]) / len(permutation)

    device = outputs.device
    num_spk = outputs.shape[-1]
    all_permutations = list(permutations(range(num_spk)))
    losses = torch.stack(
        [pair_loss(outputs, labels, p) for p in all_permutations], dim=1)
    loss, perm = torch.min(losses, dim=1)
    perm = torch.index_select(
        torch.tensor(all_permutations, device=device, dtype=torch.long), 0, perm)
    return torch.mean(loss), perm
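
# A minimal usage sketch with toy shapes (values assumed for illustration):
#
#   outs = torch.randn(2, 5, 2)                    # logits, B x T x C
#   refs = torch.randint(0, 2, (2, 5, 2)).float()  # 0/1 speaker activities
#   loss, perm = batch_pit_loss_parallel(outs, refs, ilens=[5, 3])
#   # perm[b][s] is the reference channel matched to output channel s.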


def fix_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('module.'):
            # strip the 'module.' prefix added by DataParallel
            k = k[7:]
        if k.startswith('net.'):
            # strip the 'net.' wrapper prefix
            k = k[4:]
        new_state_dict[k] = v
    return new_state_dict
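
# Typical checkpoint-loading sketch ('ckpt.pth' is a placeholder path and
# `model` an already-constructed TransformerDiarization):
#
#   state_dict = torch.load('ckpt.pth', map_location='cpu')
#   model.load_state_dict(fix_state_dict(state_dict))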


class TransformerDiarization(nn.Module):
    def __init__(self,
                 n_speakers,
                 all_n_speakers,
                 feat_dim,
                 n_units,
                 n_heads,
                 n_layers,
                 dropout_rate,
                 spk_emb_dim,
                 sr=8000,
                 frame_shift=256,
                 frame_size=1024,
                 context_size=0,
                 subsampling=1,
                 feat_type='fbank',
                 feature_selection='default',
                 interpolate_mode='linear',
                 update_extract=False,
                 feature_grad_mult=1.0
                 ):
        super(TransformerDiarization, self).__init__()
        self.context_size = context_size
        self.subsampling = subsampling
        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.sr = sr
        self.frame_shift = frame_shift
        self.interpolate_mode = interpolate_mode
        self.update_extract = update_extract
        self.feature_grad_mult = feature_grad_mult

        if feat_type == 'fbank':
            self.feature_extract = trans.MelSpectrogram(sample_rate=sr,
                                                        n_fft=frame_size,
                                                        win_length=frame_size,
                                                        hop_length=frame_shift,
                                                        f_min=0.0,
                                                        f_max=sr // 2,
                                                        pad=0,
                                                        n_mels=feat_dim)
        else:
            # s3prl-style upstream; disable fp32 attention where it exists
            self.feature_extract = UpstreamExpert(feat_type)
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
                    self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
                self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
            if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
                    self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
                self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
            self.feat_num = self.get_feat_num()
            # learnable layer-wise weights for combining upstream hidden states
            self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
            # upstream models expect 16 kHz input
            self.resample = trans.Resample(orig_freq=sr, new_freq=16000)

        if feat_type != 'fbank' and feat_type != 'mfcc':
            freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q',
                           'quantizer', 'spk_proj', 'layer_norm_for_extract']
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break
            if not self.update_extract:
                for param in self.feature_extract.parameters():
                    param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)

        # splicing stacks (2 * context_size + 1) neighboring frames per step
        feat_dim = feat_dim * (self.context_size * 2 + 1)
        self.enc = TransformerEncoder(
            feat_dim, n_layers, n_units, h=n_heads, dropout_rate=dropout_rate)
        self.linear = nn.Linear(n_units, n_speakers)

        # one embedding projection head per output channel: linear0 .. linear{n-1}
        for i in range(n_speakers):
            setattr(self, '{}{:d}'.format("linear", i), nn.Linear(n_units, spk_emb_dim))

        self.n_speakers = n_speakers
        self.embed = nn.Embedding(all_n_speakers, spk_emb_dim)
        self.alpha = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])
        self.beta = nn.Parameter(torch.rand(1)[0] + torch.Tensor([0.5])[0])

    def get_feat_num(self):
        # probe the upstream with one second of dummy audio to count the
        # hidden-state layers it returns for the selected feature
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def fix_except_embedding(self, requires_grad=False):
        # freeze (or unfreeze) every parameter except the speaker embedding table
        for name, param in self.named_parameters():
            if 'embed' not in name:
                param.requires_grad = requires_grad

    def modfy_emb(self, weight):
        # replace the speaker embedding table with the given pretrained weights
        self.embed = nn.Embedding.from_pretrained(weight)

    def splice(self, data, context_size):
        # data: B x D x T -> B x (D * (2 * context_size + 1)) x T,
        # stacking each frame with its left/right context via unfold
        data = torch.unsqueeze(data, -1)
        kernel_size = context_size * 2 + 1
        splice_data = F.unfold(data, kernel_size=(kernel_size, 1),
                               padding=(context_size, 0))
        return splice_data
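
    # A small shape check for splice (toy sizes assumed): with context_size=2,
    # each frame is stacked with two neighbours on either side.
    #
    #   x = torch.randn(8, 40, 100)      # B x D x T
    #   y = F.unfold(x.unsqueeze(-1), kernel_size=(5, 1), padding=(2, 0))
    #   assert y.shape == (8, 200, 100)  # D * (2 * 2 + 1) = 200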

    def get_feat(self, xs):
        # target number of encoder frames after frame shift and subsampling
        wav_len = xs.shape[-1]
        chunk_size = int(wav_len / self.frame_shift)
        chunk_size = int(chunk_size / self.subsampling)

        # keep the extractor in eval mode (no dropout) even when it is updated
        self.feature_extract.eval()
        if self.update_extract:
            xs = self.resample(xs)
            feature = self.feature_extract([sample for sample in xs])
        else:
            with torch.no_grad():
                if self.feat_type == 'fbank':
                    feature = self.feature_extract(xs) + 1e-6
                    feature = feature.log()
                else:
                    xs = self.resample(xs)
                    feature = self.feature_extract([sample for sample in xs])

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            feature = feature[self.feature_selection]
            if isinstance(feature, (list, tuple)):
                feature = torch.stack(feature, dim=0)
            else:
                feature = feature.unsqueeze(0)

            # weighted sum over upstream layers, then transpose to B x D x T
            norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            feature = (norm_weights * feature).sum(dim=0)
            feature = torch.transpose(feature, 1, 2) + 1e-6

        feature = self.instance_norm(feature)
        feature = self.splice(feature, self.context_size)
        feature = feature[:, :, ::self.subsampling]
        feature = F.interpolate(feature, chunk_size, mode=self.interpolate_mode)
        feature = torch.transpose(feature, 1, 2)

        if self.feature_grad_mult != 1.0:
            feature = GradMultiply.apply(feature, self.feature_grad_mult)

        return feature
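
    # Frame-count arithmetic (worked example with the defaults sr=8000 and
    # frame_shift=256): a 10 s chunk has 80000 samples, so
    # chunk_size = int(80000 / 256) = 312 frames, and F.interpolate maps the
    # upstream's frame axis onto exactly that length.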

    def forward(self, inputs):
        if isinstance(inputs, list):
            xs = inputs[0]
        else:
            xs = inputs
        feature = self.get_feat(xs)

        pad_shape = feature.shape
        emb = self.enc(feature)
        ys = self.linear(emb)
        ys = ys.reshape(pad_shape[0], pad_shape[1], -1)

        # one speaker-embedding sequence per output channel
        spksvecs = []
        for i in range(self.n_speakers):
            spkivecs = getattr(self, '{}{:d}'.format("linear", i))(emb)
            spkivecs = spkivecs.reshape(pad_shape[0], pad_shape[1], -1)
            spksvecs.append(spkivecs)

        return ys, spksvecs
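
    # A minimal forward sketch (toy sizes assumed; `model` is an instance of
    # this class): raw waveforms in, activity logits and per-channel embedding
    # sequences out.
    #
    #   xs = torch.randn(2, 8000 * 10)   # B x samples, 10 s at 8 kHz
    #   ys, spksvecs = model(xs)         # ys: B x T x n_speakers
    #   # spksvecs: list of n_speakers tensors, each B x T x spk_emb_dim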

    def get_loss(self, inputs, ys, spksvecs, cal_spk_loss=True):
        # inputs: (xs, ts, ss, ns, ilens)
        ts = inputs[1]
        ss = inputs[2]
        ns = inputs[3]
        ilens = inputs[4]
        ilens = [ilen.item() for ilen in ilens]

        pit_loss, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
        if cal_spk_loss:
            spk_loss = self.spk_loss_parallel(spksvecs, ys, ts, ss, sigmas, ns, ilens)
        else:
            spk_loss = torch.tensor(0.0).to(pit_loss.device)

        return {'spk_loss': spk_loss,
                'pit_loss': pit_loss}
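
    # Sketch of how a training step might consume this dict (the 0.1 weight is
    # an assumed value; how the two terms are combined is not defined here):
    #
    #   losses = model.get_loss(batch, ys, spksvecs)
    #   total = losses['pit_loss'] + 0.1 * losses['spk_loss']
    #   total.backward()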

    def batch_estimate(self, xs):
        out = self(xs)
        ys = out[0]
        spksvecs = out[1]
        # regroup per-channel lists into per-utterance tuples
        spksvecs = list(zip(*spksvecs))
        outputs = [
            self.estimate(spksvec, y)
            for (spksvec, y) in zip(spksvecs, ys)]
        outputs = list(zip(*outputs))

        return outputs

    def batch_estimate_with_perm(self, xs, ts, ilens=None):
        out = self(xs)
        ys = out[0]
        if ts[0].shape[1] > ys[0].shape[1]:
            # zero-pad the outputs with dummy speaker channels so that the
            # reference and estimate have the same number of speakers
            add_dim = ts[0].shape[1] - ys[0].shape[1]
            y_device = ys[0].device
            zeros = [torch.zeros(ts[0].shape).to(y_device)
                     for i in range(len(ts))]
            _ys = []
            for zero, y in zip(zeros, ys):
                _zero = zero
                _zero[:, :-add_dim] = y
                _ys.append(_zero)
            # stack back into a B x T x C tensor (the loss expects a tensor)
            _ys = torch.stack(_ys)
            _, sigmas = batch_pit_loss_parallel(_ys, ts, ilens)
        else:
            _, sigmas = batch_pit_loss_parallel(ys, ts, ilens)
        spksvecs = out[1]
        spksvecs = list(zip(*spksvecs))
        outputs = [self.estimate(spksvec, y)
                   for (spksvec, y) in zip(spksvecs, ys)]
        outputs = list(zip(*outputs))
        zs = outputs[0]

        if ts[0].shape[1] > ys[0].shape[1]:
            # pad the activity estimates with dummy speakers as well
            add_dim = ts[0].shape[1] - ys[0].shape[1]
            z_device = zs[0].device
            zeros = [torch.zeros(ts[0].shape).to(z_device)
                     for i in range(len(ts))]
            _zs = []
            for zero, z in zip(zeros, zs):
                _zero = zero
                _zero[:, :-add_dim] = z
                _zs.append(_zero)
            zs = _zs
            outputs[0] = zs
        outputs.append(sigmas)

        # outputs: [activities, per-speaker embeddings..., permutations]
        return outputs

    def estimate(self, spksvec, y):
        outputs = []
        z = torch.sigmoid(y.transpose(1, 0))

        outputs.append(z.transpose(1, 0))
        for spkid, spkvec in enumerate(spksvec):
            norm_spkvec_inv = 1.0 / torch.norm(spkvec, dim=1)
            # L2-normalize each frame-level embedding
            spkvec = torch.mul(
                spkvec.transpose(1, 0), norm_spkvec_inv
            ).transpose(1, 0)
            # weight frames by the speaker's activity, then average
            wavg_spkvec = torch.mul(
                spkvec.transpose(1, 0), z[spkid]
            ).transpose(1, 0)
            sum_wavg_spkvec = torch.sum(wavg_spkvec, dim=0)
            nmz_wavg_spkvec = sum_wavg_spkvec / torch.norm(sum_wavg_spkvec)
            outputs.append(nmz_wavg_spkvec)

        # outputs: [activity probabilities, one unit-norm embedding per speaker]
        return outputs

    def spk_loss_parallel(self, spksvecs, ys, ts, ss, sigmas, ns, ilens):
        '''
        spksvecs (List[torch.Tensor, ...]): [B x T x emb_dim, ...]
        ys (torch.Tensor): B x T x C (C = n_speakers)
        ts (torch.Tensor): B x T x C
        ss (torch.Tensor): B x C
        sigmas (torch.Tensor): B x C
        ns (torch.Tensor): B x total_spk_num x 1
        ilens (List): B
        '''
        chunk_spk_num = len(spksvecs)

        # mask out padded frames, then tile the mask over speaker channels
        len_mask = ys.new_zeros((ys.size()[:-1]))
        for i, len_val in enumerate(ilens):
            len_mask[i, :len_val] += 1.0
        ts = ts * len_mask.unsqueeze(-1)
        len_mask = len_mask.repeat((chunk_spk_num, 1))

        spk_vecs = torch.cat(spksvecs, dim=0)
        # unit-normalize the frame-level embeddings
        spk_vecs = F.normalize(spk_vecs, dim=-1)

        # activity-weighted average of the embeddings for each speaker channel
        ys = torch.permute(torch.sigmoid(ys), dims=(2, 0, 1))
        ys = ys.reshape(-1, ys.shape[-1]).unsqueeze(-1)

        weight_spk_vec = ys * spk_vecs
        weight_spk_vec *= len_mask.unsqueeze(-1)
        sum_spk_vec = torch.sum(weight_spk_vec, dim=1)
        norm_spk_vec = F.normalize(sum_spk_vec, dim=1)

        # squared-distance logits against the global speaker embedding table
        embeds = F.normalize(self.embed(ns[0]).squeeze(), dim=1)
        dist = torch.cdist(norm_spk_vec, embeds)
        logits = -1.0 * torch.add(
            torch.clamp(self.alpha, min=sys.float_info.epsilon) * torch.pow(dist, 2),
            self.beta)
        # reorder the ground-truth speaker ids by the best PIT permutation
        label = torch.gather(ss, 1, sigmas).transpose(0, 1).reshape(-1, 1).squeeze()
        label[label == -1] = 0
        # a channel is valid only if its (permuted) reference is ever active
        valid_spk_mask = torch.gather(torch.sum(ts, dim=1), 1, sigmas).transpose(0, 1)
        valid_spk_mask = (torch.flatten(valid_spk_mask) > 0).float()

        valid_spk_loss_num = torch.sum(valid_spk_mask).item()
        if valid_spk_loss_num > 0:
            loss = F.cross_entropy(
                logits, label, reduction='none') * valid_spk_mask / valid_spk_loss_num
            return torch.sum(loss)
        else:
            return torch.tensor(0.0).to(ys.device)
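

# End-to-end sketch (all hyperparameters below are assumed placeholder values,
# not this repo's configuration):
#
#   model = TransformerDiarization(
#       n_speakers=3, all_n_speakers=1000, feat_dim=80, n_units=256,
#       n_heads=4, n_layers=4, dropout_rate=0.1, spk_emb_dim=256)
#   xs = torch.randn(2, 8000 * 10)      # two 10 s chunks at 8 kHz
#   ys, spksvecs = model(xs)
#   outputs = model.batch_estimate(xs)  # [activity probs, per-speaker embeddings]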