| |
| |
| |
| |
|
|
| from dataclasses import dataclass |
| from enum import Enum, auto |
| import math |
| import numpy as np |
| from typing import Tuple, List, Optional, Dict |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import autograd |
|
|
| from fairseq import checkpoint_utils, utils |
| from fairseq.dataclass import FairseqDataclass |
| from fairseq.models import BaseFairseqModel, register_model |
| from fairseq.modules import ( |
| SamePad, |
| TransposeLast, |
| ) |
|
|
|
|
| class SegmentationType(Enum): |
| NONE = auto() |
| RANDOM = auto() |
| UNIFORM_RANDOM = auto() |
| UNIFORM_RANDOM_JOIN = auto() |
| JOIN = auto() |
|
|
|
|
| @dataclass |
| class SegmentationConfig(FairseqDataclass): |
| type: SegmentationType = SegmentationType.NONE |
| subsample_rate: float = 0.25 |
| mean_pool: bool = True |
| mean_pool_join: bool = False |
| remove_zeros: bool = False |
|
|
|
|
| @dataclass |
| class Wav2vec_UConfig(FairseqDataclass): |
|
|
| discriminator_kernel: int = 3 |
| discriminator_dilation: int = 1 |
| discriminator_dim: int = 256 |
| discriminator_causal: bool = True |
| discriminator_linear_emb: bool = False |
| discriminator_depth: int = 1 |
| discriminator_max_pool: bool = False |
| discriminator_act_after_linear: bool = False |
| discriminator_dropout: float = 0.0 |
| discriminator_spectral_norm: bool = False |
| discriminator_weight_norm: bool = False |
|
|
| generator_kernel: int = 4 |
| generator_dilation: int = 1 |
| generator_stride: int = 1 |
| generator_bias: bool = False |
| generator_dropout: float = 0.0 |
|
|
| blank_weight: float = 0 |
| blank_mode: str = "add" |
| blank_is_sil: bool = False |
| no_softmax: bool = False |
|
|
| smoothness_weight: float = 0.0 |
| smoothing: float = 0.0 |
| smoothing_one_sided: bool = False |
| gradient_penalty: float = 0.0 |
| probabilistic_grad_penalty_slicing: bool = False |
| code_penalty: float = 0.0 |
| gumbel: bool = False |
| hard_gumbel: bool = True |
| temp: Tuple[float, float, float] = (2, 0.1, 0.99995) |
| input_dim: int = 128 |
|
|
| segmentation: SegmentationConfig = SegmentationConfig() |
|
|
|
|
| class Segmenter(nn.Module): |
| cfg: SegmentationConfig |
|
|
| def __init__(self, cfg: SegmentationConfig): |
| super().__init__() |
| self.cfg = cfg |
| self.subsample_rate = cfg.subsample_rate |
|
|
| def pre_segment(self, dense_x, dense_padding_mask): |
| return dense_x, dense_padding_mask |
|
|
| def logit_segment(self, logits, padding_mask): |
| return logits, padding_mask |
|
|
|
|
| class RandomSegmenter(Segmenter): |
| def pre_segment(self, dense_x, dense_padding_mask): |
| target_num = math.ceil(dense_x.size(1) * self.subsample_rate) |
| ones = torch.ones(dense_x.shape[:-1], device=dense_x.device) |
| indices, _ = ones.multinomial(target_num).sort(dim=-1) |
| indices_ld = indices.unsqueeze(-1).expand(-1, -1, dense_x.size(-1)) |
| dense_x = dense_x.gather(1, indices_ld) |
| dense_padding_mask = dense_padding_mask.gather(1, index=indices) |
| return dense_x, dense_padding_mask |
|
|
|
|
| class UniformRandomSegmenter(Segmenter): |
| def pre_segment(self, dense_x, dense_padding_mask): |
| bsz, tsz, fsz = dense_x.shape |
|
|
| target_num = math.ceil(tsz * self.subsample_rate) |
|
|
| rem = tsz % target_num |
|
|
| if rem > 0: |
| dense_x = F.pad(dense_x, [0, 0, 0, target_num - rem]) |
| dense_padding_mask = F.pad( |
| dense_padding_mask, [0, target_num - rem], value=True |
| ) |
|
|
| dense_x = dense_x.view(bsz, target_num, -1, fsz) |
| dense_padding_mask = dense_padding_mask.view(bsz, target_num, -1) |
|
|
| if self.cfg.mean_pool: |
| dense_x = dense_x.mean(dim=-2) |
| dense_padding_mask = dense_padding_mask.all(dim=-1) |
| else: |
| ones = torch.ones((bsz, dense_x.size(2)), device=dense_x.device) |
| indices = ones.multinomial(1) |
| indices = indices.unsqueeze(-1).expand(-1, target_num, -1) |
| indices_ld = indices.unsqueeze(-1).expand(-1, -1, -1, fsz) |
| dense_x = dense_x.gather(2, indices_ld).reshape(bsz, -1, fsz) |
| dense_padding_mask = dense_padding_mask.gather(2, index=indices).reshape( |
| bsz, -1 |
| ) |
| return dense_x, dense_padding_mask |
|
|
|
|
| class JoinSegmenter(Segmenter): |
| def logit_segment(self, logits, padding_mask): |
| preds = logits.argmax(dim=-1) |
|
|
| if padding_mask.any(): |
| preds[padding_mask] = -1 |
| uniques = [] |
|
|
| bsz, tsz, csz = logits.shape |
|
|
| for p in preds: |
| uniques.append( |
| p.cpu().unique_consecutive(return_inverse=True, return_counts=True) |
| ) |
|
|
| new_tsz = max(u[0].numel() for u in uniques) |
| new_logits = logits.new_zeros(bsz, new_tsz, csz) |
| new_pad = padding_mask.new_zeros(bsz, new_tsz) |
|
|
| for b in range(bsz): |
| u, idx, c = uniques[b] |
| keep = u != -1 |
|
|
| if self.cfg.remove_zeros: |
| keep.logical_and_(u != 0) |
|
|
| if self.training and not self.cfg.mean_pool_join: |
| u[0] = 0 |
| u[1:] = c.cumsum(0)[:-1] |
| m = c > 1 |
| r = torch.rand(m.sum()) |
| o = (c[m] * r).long() |
| u[m] += o |
| new_logits[b, : u.numel()] = logits[b, u] |
| else: |
| new_logits[b].index_add_( |
| dim=0, index=idx.to(new_logits.device), source=logits[b] |
| ) |
| new_logits[b, : c.numel()] /= c.unsqueeze(-1).to(new_logits.device) |
|
|
| new_sz = keep.sum() |
| if not keep.all(): |
| kept_logits = new_logits[b, : c.numel()][keep] |
| new_logits[b, :new_sz] = kept_logits |
|
|
| if new_sz < new_tsz: |
| pad = new_tsz - new_sz |
| new_logits[b, -pad:] = 0 |
| new_pad[b, -pad:] = True |
|
|
| return new_logits, new_pad |
|
|
|
|
| class UniformRandomJoinSegmenter(UniformRandomSegmenter, JoinSegmenter): |
| pass |
|
|
|
|
| SEGMENT_FACTORY = { |
| SegmentationType.NONE: Segmenter, |
| SegmentationType.RANDOM: RandomSegmenter, |
| SegmentationType.UNIFORM_RANDOM: UniformRandomSegmenter, |
| SegmentationType.UNIFORM_RANDOM_JOIN: UniformRandomJoinSegmenter, |
| SegmentationType.JOIN: JoinSegmenter, |
| } |
|
|
|
|
| class Discriminator(nn.Module): |
| def __init__(self, dim, cfg: Wav2vec_UConfig): |
| super().__init__() |
|
|
| inner_dim = cfg.discriminator_dim |
| kernel = cfg.discriminator_kernel |
| dilation = cfg.discriminator_dilation |
| self.max_pool = cfg.discriminator_max_pool |
|
|
| if cfg.discriminator_causal: |
| padding = kernel - 1 |
| else: |
| padding = kernel // 2 |
|
|
| def make_conv(in_d, out_d, k, p=0, has_dilation=True): |
| conv = nn.Conv1d( |
| in_d, |
| out_d, |
| kernel_size=k, |
| padding=p, |
| dilation=dilation if has_dilation else 1, |
| ) |
| if cfg.discriminator_spectral_norm: |
| conv = nn.utils.spectral_norm(conv) |
| elif cfg.discriminator_weight_norm: |
| conv = nn.utils.weight_norm(conv) |
| return conv |
|
|
| inner_net = [ |
| nn.Sequential( |
| make_conv(inner_dim, inner_dim, kernel, padding), |
| SamePad(kernel_size=kernel, causal=cfg.discriminator_causal), |
| nn.Dropout(cfg.discriminator_dropout), |
| nn.GELU(), |
| ) |
| for _ in range(cfg.discriminator_depth - 1) |
| ] + [ |
| make_conv(inner_dim, 1, kernel, padding, has_dilation=False), |
| SamePad(kernel_size=kernel, causal=cfg.discriminator_causal), |
| ] |
|
|
| if cfg.discriminator_linear_emb: |
| emb_net = [make_conv(dim, inner_dim, 1)] |
| else: |
| emb_net = [ |
| make_conv(dim, inner_dim, kernel, padding), |
| SamePad(kernel_size=kernel, causal=cfg.discriminator_causal), |
| ] |
|
|
| if cfg.discriminator_act_after_linear: |
| emb_net.append(nn.GELU()) |
|
|
| self.net = nn.Sequential( |
| *emb_net, |
| nn.Dropout(cfg.discriminator_dropout), |
| *inner_net, |
| ) |
|
|
| def forward(self, x, padding_mask): |
| x = x.transpose(1, 2) |
| x = self.net(x) |
| x = x.transpose(1, 2) |
| x_sz = x.size(1) |
| if padding_mask is not None and padding_mask.any() and padding_mask.dim() > 1: |
| padding_mask = padding_mask[:, : x.size(1)] |
| x[padding_mask] = float("-inf") if self.max_pool else 0 |
| x_sz = x_sz - padding_mask.sum(dim=-1) |
| x = x.squeeze(-1) |
| if self.max_pool: |
| x, _ = x.max(dim=-1) |
| else: |
| x = x.sum(dim=-1) |
| x = x / x_sz |
| return x |
|
|
|
|
| class Generator(nn.Module): |
| def __init__(self, input_dim, output_dim, cfg: Wav2vec_UConfig): |
| super().__init__() |
|
|
| self.cfg = cfg |
| self.output_dim = output_dim |
| self.stride = cfg.generator_stride |
| self.dropout = nn.Dropout(cfg.generator_dropout) |
|
|
| padding = cfg.generator_kernel // 2 |
| self.proj = nn.Sequential( |
| TransposeLast(), |
| nn.Conv1d( |
| input_dim, |
| output_dim, |
| kernel_size=cfg.generator_kernel, |
| stride=cfg.generator_stride, |
| dilation=cfg.generator_dilation, |
| padding=padding, |
| bias=cfg.generator_bias, |
| ), |
| TransposeLast(), |
| ) |
|
|
| def forward(self, dense_x, tokens, dense_padding_mask): |
| dense_x = self.dropout(dense_x) |
|
|
| dense_x = self.proj(dense_x) |
| if self.stride > 1: |
| dense_padding_mask = dense_padding_mask[:, :: self.stride] |
|
|
| if dense_padding_mask.size(1) != dense_x.size(1): |
| new_padding = dense_padding_mask.new_zeros(dense_x.shape[:-1]) |
| diff = new_padding.size(1) - dense_padding_mask.size(1) |
| assert ( |
| diff > 0 |
| ), f"{new_padding.shape}, {dense_padding_mask.shape}, {dense_x.shape}, {diff}" |
| if diff > 0: |
| new_padding[:, diff:] = dense_padding_mask |
| else: |
| assert diff < 0 |
| new_padding = dense_padding_mask[:, :diff] |
|
|
| dense_padding_mask = new_padding |
|
|
| result = {} |
|
|
| token_x = None |
| if tokens is not None: |
| token_x = dense_x.new_zeros(tokens.numel(), self.output_dim) |
| token_x.scatter_(1, tokens.view(-1, 1).long(), 1) |
| token_x = token_x.view(tokens.shape + (self.output_dim,)) |
|
|
| result["dense_x"] = dense_x |
| result["token_x"] = token_x |
| result["dense_padding_mask"] = dense_padding_mask |
|
|
| return result |
|
|
|
|
| @register_model("wav2vec_u", dataclass=Wav2vec_UConfig) |
| class Wav2vec_U(BaseFairseqModel): |
| def calc_gradient_penalty(self, real_data, fake_data): |
|
|
| b_size = min(real_data.size(0), fake_data.size(0)) |
| t_size = min(real_data.size(1), fake_data.size(1)) |
|
|
| if self.cfg.probabilistic_grad_penalty_slicing: |
|
|
| def get_slice(data, dim, target_size): |
|
|
| size = data.size(dim) |
| diff = size - target_size |
| if diff <= 0: |
| return data |
|
|
| start = np.random.randint(0, diff + 1) |
| return data.narrow(dim=dim, start=start, length=target_size) |
|
|
| real_data = get_slice(real_data, 0, b_size) |
| real_data = get_slice(real_data, 1, t_size) |
| fake_data = get_slice(fake_data, 0, b_size) |
| fake_data = get_slice(fake_data, 1, t_size) |
|
|
| else: |
| real_data = real_data[:b_size, :t_size] |
| fake_data = fake_data[:b_size, :t_size] |
|
|
| alpha = torch.rand(real_data.size(0), 1, 1) |
| alpha = alpha.expand(real_data.size()) |
| alpha = alpha.to(real_data.device) |
|
|
| interpolates = alpha * real_data + ((1 - alpha) * fake_data) |
|
|
| disc_interpolates = self.discriminator(interpolates, None) |
|
|
| gradients = autograd.grad( |
| outputs=disc_interpolates, |
| inputs=interpolates, |
| grad_outputs=torch.ones(disc_interpolates.size(), device=real_data.device), |
| create_graph=True, |
| retain_graph=True, |
| only_inputs=True, |
| )[0] |
|
|
| gradient_penalty = (gradients.norm(2, dim=1) - 1) ** 2 |
| return gradient_penalty |
|
|
| def set_num_updates(self, num_updates): |
| super().set_num_updates(num_updates) |
| self.update_num = num_updates |
| self.curr_temp = max( |
| self.max_temp * self.temp_decay ** num_updates, self.min_temp |
| ) |
|
|
| def discrim_step(self, num_updates): |
| return num_updates % 2 == 1 |
|
|
| def get_groups_for_update(self, num_updates): |
| return "discriminator" if self.discrim_step(num_updates) else "generator" |
|
|
| def __init__(self, cfg: Wav2vec_UConfig, target_dict): |
| super().__init__() |
|
|
| self.cfg = cfg |
| self.zero_index = target_dict.index("<SIL>") if "<SIL>" in target_dict else 0 |
| self.smoothness_weight = cfg.smoothness_weight |
|
|
| output_size = len(target_dict) |
| self.pad = target_dict.pad() |
| self.eos = target_dict.eos() |
| self.smoothing = cfg.smoothing |
| self.smoothing_one_sided = cfg.smoothing_one_sided |
| self.no_softmax = cfg.no_softmax |
| self.gumbel = cfg.gumbel |
| self.hard_gumbel = cfg.hard_gumbel |
| self.last_acc = None |
|
|
| self.gradient_penalty = cfg.gradient_penalty |
| self.code_penalty = cfg.code_penalty |
| self.blank_weight = cfg.blank_weight |
| self.blank_mode = cfg.blank_mode |
| self.blank_index = target_dict.index("<SIL>") if cfg.blank_is_sil else 0 |
| assert self.blank_index != target_dict.unk() |
|
|
| self.discriminator = Discriminator(output_size, cfg) |
| for p in self.discriminator.parameters(): |
| p.param_group = "discriminator" |
|
|
| self.pca_A = self.pca_b = None |
| d = cfg.input_dim |
|
|
| self.segmenter = SEGMENT_FACTORY[cfg.segmentation.type](cfg.segmentation) |
|
|
| self.generator = Generator(d, output_size, cfg) |
|
|
| for p in self.generator.parameters(): |
| p.param_group = "generator" |
|
|
| for p in self.segmenter.parameters(): |
| p.param_group = "generator" |
|
|
| self.max_temp, self.min_temp, self.temp_decay = cfg.temp |
| self.curr_temp = self.max_temp |
| self.update_num = 0 |
|
|
| @classmethod |
| def build_model(cls, cfg, task): |
| return cls(cfg, task.target_dictionary) |
|
|
| def get_logits( |
| self, |
| net_output: Optional[Dict[str, List[Optional[torch.Tensor]]]], |
| normalize: bool = False, |
| ): |
| logits = net_output["logits"] |
|
|
| if self.blank_weight != 0: |
| if self.blank_mode == "add": |
| logits[..., self.blank_index] += self.blank_weight |
| elif self.blank_mode == "set": |
| logits[..., self.blank_index] = self.blank_weight |
| else: |
| raise Exception(f"invalid blank mode {self.blank_mode}") |
|
|
| padding = net_output["padding_mask"] |
| if padding.any(): |
| logits[padding] = float("-inf") |
| logits[padding][..., self.blank_index] = float("inf") |
|
|
| if normalize: |
| logits = utils.log_softmax(logits.float(), dim=-1) |
|
|
| return logits.transpose(0, 1) |
|
|
| def get_normalized_probs( |
| self, |
| net_output: Tuple[ |
| torch.Tensor, Optional[Dict[str, List[Optional[torch.Tensor]]]] |
| ], |
| log_probs: bool, |
| sample: Optional[Dict[str, torch.Tensor]] = None, |
| ): |
| logits = self.get_logits(net_output) |
|
|
| probs = super().get_normalized_probs(logits, log_probs, sample) |
| |
| probs = probs.transpose(0, 1) |
| return probs |
|
|
| def normalize(self, dense_x): |
|
|
| bsz, tsz, csz = dense_x.shape |
|
|
| if dense_x.numel() == 0: |
| raise Exception(dense_x.shape) |
| _, k = dense_x.max(-1) |
| hard_x = ( |
| dense_x.new_zeros(bsz * tsz, csz) |
| .scatter_(-1, k.view(-1, 1), 1.0) |
| .view(-1, csz) |
| ) |
| hard_probs = torch.mean(hard_x.float(), dim=0) |
| code_perplexity = torch.exp( |
| -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1) |
| ) |
|
|
| avg_probs = torch.softmax(dense_x.reshape(-1, csz).float(), dim=-1).mean(dim=0) |
| prob_perplexity = torch.exp( |
| -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) |
| ) |
|
|
| if not self.no_softmax: |
| if self.training and self.gumbel: |
| dense_x = F.gumbel_softmax( |
| dense_x.float(), tau=self.curr_temp, hard=self.hard_gumbel |
| ).type_as(dense_x) |
| else: |
| dense_x = dense_x.softmax(-1) |
|
|
| return dense_x, code_perplexity, prob_perplexity |
|
|
| def forward( |
| self, |
| features, |
| padding_mask, |
| random_label=None, |
| dense_x_only=False, |
| segment=True, |
| ): |
| if segment: |
| features, padding_mask = self.segmenter.pre_segment(features, padding_mask) |
|
|
| orig_size = features.size(0) * features.size(1) - padding_mask.sum() |
|
|
| gen_result = self.generator(features, random_label, padding_mask) |
|
|
| orig_dense_x, token_x = gen_result["dense_x"], gen_result["token_x"] |
| orig_dense_padding_mask = gen_result["dense_padding_mask"] |
|
|
| if segment: |
| dense_x, dense_padding_mask = self.segmenter.logit_segment( |
| orig_dense_x, orig_dense_padding_mask |
| ) |
| else: |
| dense_x = orig_dense_x |
| dense_padding_mask = orig_dense_padding_mask |
|
|
| dense_logits = dense_x |
| prob_perplexity = None |
| code_perplexity = None |
|
|
| if not (self.no_softmax and dense_x_only): |
| dense_x, code_perplexity, prob_perplexity = self.normalize(dense_logits) |
|
|
| if dense_x_only or self.discriminator is None: |
| return { |
| "logits": dense_x, |
| "padding_mask": dense_padding_mask, |
| } |
|
|
| token_padding_mask = random_label == self.pad |
|
|
| dense_y = self.discriminator(dense_x, dense_padding_mask) |
| token_y = self.discriminator(token_x, token_padding_mask) |
|
|
| sample_size = features.size(0) |
|
|
| d_step = self.discrim_step(self.update_num) |
|
|
| fake_smooth = self.smoothing |
| real_smooth = self.smoothing |
| if self.smoothing_one_sided: |
| fake_smooth = 0 |
|
|
| zero_loss = None |
| smoothness_loss = None |
| code_pen = None |
|
|
| if d_step: |
| loss_dense = F.binary_cross_entropy_with_logits( |
| dense_y, |
| dense_y.new_ones(dense_y.shape) - fake_smooth, |
| reduction="sum", |
| ) |
| loss_token = F.binary_cross_entropy_with_logits( |
| token_y, |
| token_y.new_zeros(token_y.shape) + real_smooth, |
| reduction="sum", |
| ) |
| if self.training and self.gradient_penalty > 0: |
| grad_pen = self.calc_gradient_penalty(token_x, dense_x) |
| grad_pen = grad_pen.sum() * self.gradient_penalty |
| else: |
| grad_pen = None |
| else: |
| grad_pen = None |
| loss_token = None |
| loss_dense = F.binary_cross_entropy_with_logits( |
| dense_y, |
| dense_y.new_zeros(dense_y.shape) + fake_smooth, |
| reduction="sum", |
| ) |
| num_vars = dense_x.size(-1) |
| if prob_perplexity is not None: |
| code_pen = (num_vars - prob_perplexity) / num_vars |
| code_pen = code_pen * sample_size * self.code_penalty |
|
|
| if self.smoothness_weight > 0: |
| smoothness_loss = F.mse_loss( |
| dense_logits[:, :-1], dense_logits[:, 1:], reduction="none" |
| ) |
| smoothness_loss[dense_padding_mask[:, 1:]] = 0 |
| smoothness_loss = ( |
| smoothness_loss.mean() * sample_size * self.smoothness_weight |
| ) |
|
|
| result = { |
| "losses": { |
| "grad_pen": grad_pen, |
| "code_pen": code_pen, |
| "smoothness": smoothness_loss, |
| }, |
| "temp": self.curr_temp, |
| "code_ppl": code_perplexity, |
| "prob_ppl": prob_perplexity, |
| "d_steps": int(d_step), |
| "sample_size": sample_size, |
| } |
|
|
| suff = "_d" if d_step else "_g" |
| result["losses"]["dense" + suff] = loss_dense |
| result["losses"]["token" + suff] = loss_token |
|
|
| return result |
|
|