import os
import pdb
import copy
import torch
import argparse
import loralib as lora
import transformers.models.wavlm.modeling_wavlm as wavlm

from speechbrain.nnet.normalization import LayerNorm
from speechbrain.lobes.models.huggingface_transformers.huggingface import make_padding_masks
from torch import nn
from torch.nn import functional as F
from transformers import Wav2Vec2FeatureExtractor
from transformers import WavLMModel

import sys
from pathlib import Path
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))

from revgrad import RevGrad
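
# The two encoder-layer classes below mirror the Hugging Face WavLM encoder layers
# (post-layer-norm for the base checkpoints, stable/pre-layer-norm for wavlm-large).
# In the upper half of the encoder they swap the feed-forward dense projections for
# loralib.Linear so LoRA adapters can be trained while the backbone stays frozen.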
class WavLMEncoderLayer(nn.Module):
    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = wavlm.WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

        if layer_idx > config.num_hidden_layers // 2:
            if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
                self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
                self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
        attn_residual = hidden_states
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
            index=index,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(hidden_states)
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states, position_bias)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

class WavLMEncoderLayerStableLayerNorm(nn.Module):
    def __init__(self, layer_idx, config, has_relative_position_bias: bool = True):
        super().__init__()
        self.attention = wavlm.WavLMAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            num_buckets=config.num_buckets,
            max_distance=config.max_bucket_distance,
            has_relative_position_bias=has_relative_position_bias,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = wavlm.WavLMFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.config = config

        if layer_idx > config.num_hidden_layers // 2:
            if self.config.finetune_method == "lora" or self.config.finetune_method == "combined":
                self.feed_forward.intermediate_dense = lora.Linear(config.hidden_size, config.intermediate_size, r=config.lora_rank)
                self.feed_forward.output_dense = lora.Linear(config.intermediate_size, config.hidden_size, r=config.lora_rank)

    def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights, position_bias = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            output_attentions=output_attentions,
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

        outputs = (hidden_states, position_bias)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs

class WavLMWrapper(nn.Module):
    def __init__(
        self,
        pretrain_model="wavlm_large",
        hidden_dim=256,
        finetune_method="lora",
        lora_rank=16,
        freeze_params=True,
        output_class_num=4,
        use_conv_output=True,
        apply_gradient_reversal=False,
        num_dataset=4
    ):
        super(WavLMWrapper, self).__init__()
        # 1. Load the backbone model with pretrained weights
        if pretrain_model == "wavlm":
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-base-plus",
                output_hidden_states=True,
            )
        elif pretrain_model == "wavlm_large":
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-large')
            self.backbone_model = WavLMModel.from_pretrained(
                "microsoft/wavlm-large",
                output_hidden_states=True,
            )
        self.pretrain_model = pretrain_model
        self.finetune_method = finetune_method
        self.apply_gradient_reversal = apply_gradient_reversal
        self.use_conv_output = use_conv_output
        state_dict = self.backbone_model.state_dict()

        # 2. Read the model config
        self.model_config = self.backbone_model.config
        self.model_config.finetune_method = finetune_method
        self.model_config.lora_rank = lora_rank

        # 3. Replace the encoder layers with the LoRA-capable variants defined above
        if self.pretrain_model == "wavlm":
            self.backbone_model.encoder.layers = nn.ModuleList(
                [WavLMEncoderLayer(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
            )
        elif self.pretrain_model == "wavlm_large":
            self.backbone_model.encoder.layers = nn.ModuleList(
                [WavLMEncoderLayerStableLayerNorm(i, self.model_config, has_relative_position_bias=(i == 0)) for i in range(self.model_config.num_hidden_layers)]
            )

        # 4. Load the pretrained weights back into the replaced layers
        msg = self.backbone_model.load_state_dict(state_dict, strict=False)

        # 5. Freeze the backbone weights (with LoRA, only the adapter weights stay trainable)
        self.freeze_params = freeze_params
        if self.freeze_params and self.finetune_method != "lora":
            for _, p in self.backbone_model.named_parameters():
                p.requires_grad = False
        elif self.freeze_params and self.finetune_method == "lora":
            for name, p in self.backbone_model.named_parameters():
                if name in msg.missing_keys:
                    p.requires_grad = True
                else:
                    p.requires_grad = False
        else:
            for _, p in self.backbone_model.named_parameters():
                p.requires_grad = True

        # 6. Downstream models
        self.model_seq = nn.Sequential(
            nn.Conv1d(self.model_config.hidden_size, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0),
            nn.ReLU(),
            nn.Dropout(p=0.1),
            nn.Conv1d(hidden_dim, hidden_dim, 1, padding=0)
        )

        if self.use_conv_output:
            num_layers = self.model_config.num_hidden_layers + 1  # transformer layers + input embeddings
            self.weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        else:
            num_layers = self.model_config.num_hidden_layers
            self.weights = nn.Parameter(torch.zeros(num_layers))

        if apply_gradient_reversal:
            self.dataset_layer = nn.Sequential(
                RevGrad(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, num_dataset),
            )

        self.out_layer = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_class_num),
        )

    def forward(self, x, length=None, return_feature=False):
        # 1. Feature extraction and projections
        if self.pretrain_model == "wavlm_large":
            with torch.no_grad():
                signal = list()
                if length is not None:
                    attention_mask = make_padding_masks(x, wav_len=length / length.max()).to(x.device)
                else:
                    attention_mask = make_padding_masks(x, wav_len=torch.tensor([1]).to(x.device)).to(x.device)
                for idx in range(len(x)):
                    input = self.processor(x[idx], sampling_rate=16_000, return_tensors="pt", padding=True)
                    signal.append(input["input_values"][0].to(x.device))
                signal = torch.stack(signal)

        # 2. Convert sample lengths to feature-frame lengths
        if length is not None:
            length = self.get_feat_extract_output_lengths(length.detach().cpu())
            length = length.to(x.device)

        # 3. Transformer encoding: collect the hidden states of all layers
        if self.pretrain_model == "wavlm":
            x = self.backbone_model(
                x, output_hidden_states=True
            ).hidden_states
        else:
            x = self.backbone_model(
                signal,
                attention_mask=attention_mask,
                output_hidden_states=True
            ).hidden_states

        # 4. Stack the layer outputs: [num_layers, B, T, D]
        if self.use_conv_output:
            stacked_feature = torch.stack(x, dim=0)
        else:
            stacked_feature = torch.stack(x, dim=0)[1:]

        # 5. Weighted sum over layers
        _, *origin_shape = stacked_feature.shape
        if self.use_conv_output:
            stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers + 1, -1)
        else:
            stacked_feature = stacked_feature.view(self.backbone_model.config.num_hidden_layers, -1)
        norm_weights = F.softmax(self.weights, dim=-1)
        weighted_feature = (norm_weights.unsqueeze(-1) * stacked_feature).sum(dim=0)
        features = weighted_feature.view(*origin_shape)

        # 6. Pass the weighted average to the point-wise 1D conv stack
        # B x T x D
        features = features.transpose(1, 2)
        features = self.model_seq(features)
        features = features.transpose(1, 2)

        # 7. Pooling over time
        if length is not None:
            mean, std = list(), list()
            for snt_id in range(features.shape[0]):
                # Avoid padded time steps
                actual_size = length[snt_id]
                mean.append(torch.mean(features[snt_id, 0:actual_size, ...], dim=0))
            features = torch.stack(mean)
        else:
            features = torch.mean(features, dim=1)

        # 8. Output predictions
        # B x D
        predicted = self.out_layer(features)
        if self.apply_gradient_reversal:
            dataset_predicted = self.dataset_layer(features)
            if return_feature:
                return predicted, dataset_predicted, features
            return predicted, dataset_predicted
        if return_feature:
            return predicted, features
        return predicted

    # From huggingface
    def get_feat_extract_output_lengths(self, input_length):
        """
        Computes the output length of the convolutional layers
        """
        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.backbone_model.config.conv_kernel, self.backbone_model.config.conv_stride):
            input_length = _conv_out_length(input_length, kernel_size, stride)
        return input_length

def prepare_mask(length, shape, dtype):
    # Modified from huggingface
    mask = torch.zeros(
        shape, dtype=dtype
    )
    # These two operations make sure that all values
    # before the output length indices are attended to
    mask[(torch.arange(mask.shape[0]), length.cpu() - 1)] = 1
    mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()
    return mask
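
# Minimal usage sketch: instantiates the wrapper with the base-plus backbone and
# runs a forward pass on random audio. It assumes the pretrained checkpoints can
# be downloaded and that `revgrad` is importable; the shapes below are illustrative.
if __name__ == "__main__":
    model = WavLMWrapper(
        pretrain_model="wavlm",   # base-plus path, which skips the Wav2Vec2FeatureExtractor preprocessing
        finetune_method="lora",
        lora_rank=16,
        output_class_num=4,
    )
    model.eval()

    # Two one-second utterances at 16 kHz
    wav = torch.randn(2, 16000)
    with torch.no_grad():
        logits = model(wav)
    print(logits.shape)  # torch.Size([2, 4])

    # The conv feature extractor downsamples by roughly a factor of 320,
    # e.g. 16000 samples -> 49 frames with the default kernel/stride config.
    print(model.get_feat_extract_output_lengths(torch.tensor([16000])))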