#!/usr/bin/env python3

import torch
from torch import nn

import torch_model


class MultiHeadedAttention(nn.Module):
    """
    This class is copied and modified from
    https://github.com/modelscope/FunASR/blob/main/funasr/models/transformer/attention.py
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self, query, key, value):
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = -float("inf")
            # min_value = float(
            #     np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
            # )
            scores = scores.masked_fill(mask, min_value)
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        # scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.d_k ** (-0.5)
        return self.forward_attention(v, scores, mask)


class EncoderLayer(nn.Module):
    """
    This class is copied and modified from
    https://github.com/modelscope/FunASR/blob/main/funasr/models/transformer/encoder.py
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-12)
        self.norm2 = nn.LayerNorm(size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask=None, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(
                self.self_attn(x_q, x, x, mask)
            )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask


class Transformer(nn.Module):
    # This class is copied and modified from
    # https://github.com/modelscope/FunASR/blob/main/funasr/models/llm_asr/adaptor.py
    def __init__(
        self,
        downsample_rate=1,
        encoder_dim=512,
        llm_dim=512,
        ffn_dim: int = 2048,
        n_layer: int = 5,
        **kwargs,
    ):
        super().__init__()
        assert downsample_rate == 1, downsample_rate
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)

        self.blocks = None
        if n_layer > 0:
            self.blocks = nn.ModuleList(
                [
                    EncoderLayer(
                        llm_dim,
                        MultiHeadedAttention(
                            kwargs.get("attention_heads", 8),
                            llm_dim,
                            kwargs.get("attention_dropout_rate", 0.0),
                        ),
                        torch_model.PositionwiseFeedForward(
                            llm_dim,
                            llm_dim // 4,
                            kwargs.get("dropout_rate", 0.0),
                        ),
                        kwargs.get("dropout_rate", 0.0),
                    )
                    for _ in range(n_layer)
                ]
            )

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)

        masks = None
        if self.blocks is not None:
            for block in self.blocks:
                x, masks = block(x, masks)

        return x