# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class SpeechEncoderPostnet(nn.Module):
    """Postnet for masked speech-unit prediction (HuBERT-style pretraining).

    Projects encoder outputs into the label-embedding space and scores them
    against positive and negative label embeddings with a temperature-scaled
    cosine-similarity (NCE) objective.

    Args:
        dictionaries (list): target dictionaries, one per label set; entries
            may be None when the module is only used for fine-tuning
        args: config namespace providing final_dim, encoder_embed_dim,
            logit_temp, target_glu, skip_masked, skip_nomask and
            untie_final_proj
    """
    def __init__(self, dictionaries, args):
        super(SpeechEncoderPostnet, self).__init__()
        # modules below are not needed during fine-tuning
        self.target_glu = args.target_glu
        self.skip_masked = args.skip_masked
        self.skip_nomask = args.skip_nomask
        self.logit_temp = args.logit_temp
        final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim
        if any([d is None for d in dictionaries]):
            logger.info("cannot find dictionary. assume will be used for fine-tuning")
        else:
            self.num_classes = [len(d) for d in dictionaries]
            self.label_embs_concat = nn.Parameter(
                torch.FloatTensor(sum(self.num_classes), final_dim)
            )
            nn.init.uniform_(self.label_embs_concat)
        self.untie_final_proj = args.untie_final_proj
        if self.untie_final_proj:
            self.final_proj = nn.Linear(
                args.encoder_embed_dim, final_dim * len(dictionaries)
            )
        else:
            self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim)

    def compute_nce(self, x, pos, negs):
        # x: (S, D) projected features; pos: (S, D) positive label
        # embeddings; negs: (Neg, S, D) negative candidates
        neg_is_pos = (pos == negs).all(-1)
        pos = pos.unsqueeze(0)
        targets = torch.cat([pos, negs], dim=0)  # (Neg+1, S, D)
        logits = torch.cosine_similarity(
            x.float(), targets.float(), dim=-1
        ).type_as(x)
        logits /= self.logit_temp
        if neg_is_pos.any():
            # mask out negatives that coincide with the positive
            logits[1:][neg_is_pos] = float("-inf")
        logits = logits.transpose(0, 1)  # (num_x, num_cls+1)
        return logits
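
    # Editor's note: compute_nce puts the positive (true label) in column 0
    # and the negatives in the remaining columns, so a cross-entropy loss
    # with target index 0 yields an InfoNCE-style contrastive objective.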

    def forward(self, x, padding_mask, mask_indices, target_list):
        def compute_pred(proj_x, target, label_embs):
            # compute logits for the i-th label set
            y = torch.index_select(label_embs, 0, target.long())
            negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
            if self.target_glu:
                y = self.target_glu(y)
                negs = self.target_glu(negs)
            # proj_x: (S, D)
            # y: (S, D)
            # negs: (Neg, S, D)
            return self.compute_nce(proj_x, y, negs)

        label_embs_list = self.label_embs_concat.split(self.num_classes, 0)

        if not self.skip_masked:
            # logits over frames that were masked during pretraining
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = self.final_proj(x[masked_indices])
            if self.untie_final_proj:
                proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1)
            else:
                proj_x_m_list = [proj_x_m for _ in range(len(target_list))]
            logit_m_list = [
                compute_pred(proj_x_m, t[masked_indices], label_embs_list[i])
                for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list))
            ]
        else:
            logit_m_list = [None for _ in target_list]

        if not self.skip_nomask:
            # logits over frames that were left unmasked
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = self.final_proj(x[nomask_indices])
            if self.untie_final_proj:
                proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1)
            else:
                proj_x_u_list = [proj_x_u for _ in range(len(target_list))]
            logit_u_list = [
                compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i])
                for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list))
            ]
        else:
            logit_u_list = [None for _ in target_list]

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
        }
        return result
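

# A minimal shape-check sketch (not part of the original file): builds the
# postnet with a hypothetical config via argparse.Namespace and runs one
# forward pass on random tensors. The dictionaries only need to support
# len(), so plain lists stand in for fairseq dictionaries here; all
# hyperparameter values below are illustrative assumptions.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        target_glu=False,
        skip_masked=False,
        skip_nomask=False,
        logit_temp=0.1,
        final_dim=256,
        encoder_embed_dim=768,
        untie_final_proj=False,
    )
    dictionaries = [list(range(500))]  # one label set with 500 classes
    postnet = SpeechEncoderPostnet(dictionaries, args)

    B, T = 2, 50
    x = torch.randn(B, T, args.encoder_embed_dim)  # encoder outputs
    padding_mask = torch.zeros(B, T, dtype=torch.bool)  # no padded frames
    mask_indices = torch.rand(B, T) < 0.5  # random mask pattern
    target_list = [torch.randint(0, 500, (B, T))]  # one target per frame

    out = postnet(x, padding_mask, mask_indices, target_list)
    # (num_masked_frames, num_classes + 1): column 0 is the positive
    print(out["logit_m_list"][0].shape)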