|
|
|
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
import torch_model |
|
|
|
|
|
|
|
|
class MultiHeadedAttention(nn.Module):
    """Multi-head scaled dot-product attention layer.

    This class is copied and modified from
    https://github.com/modelscope/FunASR/blob/main/funasr/models/transformer/attention.py
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct the attention layer.

        Args:
            n_head (int): Number of attention heads; must divide ``n_feat``.
            n_feat (int): Total feature dimension across all heads.
            dropout_rate (float): Dropout probability on attention weights.
        """
        super().__init__()
        assert n_feat % n_head == 0
        # Per-head dimension (d_v is assumed equal to d_k).
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)
        # Kept for interface compatibility; not populated in this version.
        self.attn = None

    def forward_qkv(self, query, key, value):
        """Project query, key and value and split them into heads.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        batch = query.size(0)
        # Apply each projection, then reshape (batch, time, h*d_k) ->
        # (batch, h, time, d_k) so heads become an independent batch axis.
        q, k, v = (
            proj(tensor).view(batch, -1, self.h, self.d_k).transpose(1, 2)
            for proj, tensor in (
                (self.linear_q, query),
                (self.linear_k, key),
                (self.linear_v, value),
            )
        )
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Weight the values by the (masked, normalized) attention scores.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        batch = value.size(0)
        if mask is None:
            attn = torch.softmax(scores, dim=-1)
        else:
            # Broadcast over the head axis; masked (zero) positions get -inf
            # before softmax, then are zeroed again afterwards so fully-masked
            # rows do not propagate NaNs.
            pad_mask = mask.unsqueeze(1).eq(0)
            scores = scores.masked_fill(pad_mask, -float("inf"))
            attn = torch.softmax(scores, dim=-1).masked_fill(pad_mask, 0.0)

        context = torch.matmul(self.dropout(attn), value)
        # Merge heads back: (batch, h, time1, d_k) -> (batch, time1, h*d_k).
        context = (
            context.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        )
        return self.linear_out(context)

    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.d_k ** (-0.5)
        return self.forward_attention(v, scores, mask)
|
|
|
|
|
|
|
|
class EncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention + feed-forward, each with
    a residual connection and layer normalization.

    This class is copied and modified from
    https://github.com/modelscope/FunASR/blob/main/funasr/models/transformer/encoder.py
    """

    def __init__(
        self,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct the encoder layer.

        Args:
            size (int): Input/output feature dimension.
            self_attn (nn.Module): Self-attention module, invoked as
                ``self_attn(query, key, value, mask)``.
            feed_forward (nn.Module): Position-wise feed-forward module.
            dropout_rate (float): Dropout rate on each sub-layer output.
            normalize_before (bool): If True (pre-norm), apply LayerNorm
                before each sub-layer; otherwise after (post-norm).
            concat_after (bool): If True, concatenate the attention input
                and output and project back to ``size`` instead of applying
                dropout to the attention output.
            stochastic_depth_rate (float): Kept for interface compatibility;
                stochastic depth is not applied in this modified version.
        """
        super().__init__()

        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-12)
        self.norm2 = nn.LayerNorm(size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate

    def forward(self, x, mask=None, cache=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        # NOTE: the upstream stochastic-depth logic was already stripped in
        # this copy (skip probability fixed at 0, coefficient fixed at 1.0),
        # so the unreachable skip branch and the no-op coefficient are removed.
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if cache is None:
            x_q = x
        else:
            # Incremental decoding: only the newest frame is computed as a
            # query; earlier frames are supplied by the cache.
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]

        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, mask
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Projector mapping encoder features into the LLM embedding space.

    A two-layer MLP (linear -> ReLU -> linear) optionally followed by a
    stack of Transformer encoder layers operating at ``llm_dim``.
    """

    def __init__(
        self,
        downsample_rate=1,
        encoder_dim=512,
        llm_dim=512,
        ffn_dim: int = 2048,
        n_layer: int = 5,
        **kwargs,
    ):
        """Construct the projector.

        Args:
            downsample_rate (int): Frame downsampling factor; only 1 is
                supported by this implementation.
            encoder_dim (int): Dimension of the incoming encoder features.
            llm_dim (int): Target (LLM) embedding dimension.
            ffn_dim (int): Hidden dimension of the projection MLP.
            n_layer (int): Number of encoder layers; 0 disables the stack.
            **kwargs: Optional overrides: ``attention_heads`` (default 8),
                ``attention_dropout_rate`` (default 0.0),
                ``dropout_rate`` (default 0.0).
        """
        super().__init__()
        assert downsample_rate == 1, downsample_rate
        self.k = downsample_rate
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)

        self.blocks = None
        if n_layer > 0:
            self.blocks = nn.ModuleList(
                [
                    EncoderLayer(
                        llm_dim,
                        MultiHeadedAttention(
                            kwargs.get("attention_heads", 8),
                            llm_dim,
                            kwargs.get("attention_dropout_rate", 0.0),
                        ),
                        torch_model.PositionwiseFeedForward(
                            llm_dim,
                            llm_dim // 4,
                            kwargs.get("dropout_rate", 0.0),
                        ),
                        kwargs.get("dropout_rate", 0.0),
                    )
                    # Loop index is unused; layers are identical in shape.
                    for _ in range(n_layer)
                ]
            )

    def forward(self, x):
        """Project features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, encoder_dim).

        Returns:
            torch.Tensor: Output tensor (#batch, time, llm_dim).
        """
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)

        # No padding mask is available at this point; encoder layers run
        # with mask=None.
        masks = None

        if self.blocks is not None:
            for block in self.blocks:
                x, masks = block(x, masks)
        return x
|
|
|