# Tiny-Translator: src/model.py
import torch
from torch import nn

from src.config import Config
class MultiHeadAttention(nn.Module):
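    """Scaled dot-product multi-head attention with separate Q/K/V and output projections."""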
def __init__(self, embed_dim, num_heads, drop_out=0.1):
super().__init__()
assert embed_dim % num_heads == 0
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(drop_out)
self.scale = (embed_dim // num_heads) ** 0.5
self.embed_dim = embed_dim
self.num_heads = num_heads
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: torch.Tensor = None,
pad_mask: torch.Tensor = None,
):
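        """Compute attention.

        Shapes: q is [bs, q_len, embed_dim]; k and v are [bs, k_len, embed_dim];
        mask broadcasts to [bs, num_heads, q_len, k_len] with True at allowed
        positions; pad_mask is [bs, k_len] with True at real (non-pad) tokens.
        Returns [bs, q_len, embed_dim].
        """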
bs = q.shape[0]
q_len = q.shape[1]
k_len = k.shape[1]
Q: torch.Tensor = self.q_proj(q)
K: torch.Tensor = self.k_proj(k)
V: torch.Tensor = self.v_proj(v)
q_state = Q.view(bs, q_len, self.num_heads, -1).transpose(1, 2)
k_state = K.view(bs, k_len, self.num_heads, -1).transpose(1, 2)
v_state = V.view(bs, k_len, self.num_heads, -1).transpose(1, 2)
        # [bs, head, q_len, dim] @ [bs, head, dim, k_len] -> [bs, head, q_len, k_len]
        attn: torch.Tensor = (q_state @ k_state.transpose(-1, -2)) / self.scale
if mask is not None:
attn = attn.masked_fill(~mask, -1e8)
if pad_mask is not None:
attn = attn.masked_fill(~pad_mask.unsqueeze(1).unsqueeze(2), -1e8)
attn = torch.softmax(attn, dim=-1)
attn = self.dropout(attn)
        # [bs, head, q_len, k_len] @ [bs, head, k_len, dim] = [bs, head, q_len, dim]
        out = attn @ v_state
out = out.transpose(1, 2).contiguous().view(bs, q_len, -1)
out = self.out_proj(out)
return out
class FFN(nn.Module):
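    """Position-wise feed-forward network: Linear -> ReLU -> Linear with a 4x hidden expansion, followed by dropout."""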
def __init__(self, embed_dim, drop_out=0.1):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(embed_dim, embed_dim * 4),
nn.ReLU(),
nn.Linear(embed_dim * 4, embed_dim),
nn.Dropout(drop_out),
)
def forward(self, x):
return self.mlp(x)
class EncoderLayer(nn.Module):
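    """Encoder block: self-attention and FFN sublayers, each with a residual connection and post-norm LayerNorm."""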
def __init__(self, embed_dim, num_heads, drop_out=0.1):
super().__init__()
self.mha = MultiHeadAttention(
embed_dim=embed_dim, num_heads=num_heads, drop_out=drop_out
)
self.ffn = FFN(embed_dim=embed_dim, drop_out=drop_out)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
def forward(self, x: torch.Tensor, pad_mask=None):
x = x + self.mha(x, x, x, pad_mask=pad_mask)
x = self.norm1(x)
x = x + self.ffn(x)
x = self.norm2(x)
return x
class Encoder(nn.Module):
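    """Stack of config.encoder_layer identical EncoderLayer blocks."""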
def __init__(self, config: Config):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(config.encoder_layer):
self.layers.append(
EncoderLayer(config.embed_dim, config.num_heads, config.drop_out)
)
def forward(self, x, pad_mask=None):
for layer in self.layers:
x = layer(x, pad_mask)
return x
class DecoderLayer(nn.Module):
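    """Decoder block: causally masked self-attention, cross-attention over the
    encoder memory, and an FFN, each with a residual connection and post-norm
    LayerNorm."""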
def __init__(self, embed_dim, num_heads, drop_out=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(
embed_dim=embed_dim, num_heads=num_heads, drop_out=drop_out
)
self.cross_attn = MultiHeadAttention(
embed_dim=embed_dim, num_heads=num_heads, drop_out=drop_out
)
self.ffn = FFN(embed_dim=embed_dim, drop_out=drop_out)
self.norm0 = nn.LayerNorm(embed_dim)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
def forward(self, x: torch.Tensor, memory, src_pad_mask=None, tgt_pad_mask=None):
x_len = x.shape[1]
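        # Lower-triangular (causal) mask: position i may only attend to positions <= i.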
mask = torch.ones(size=(1, 1, x_len, x_len), device=x.device, dtype=torch.bool).tril()
x = x + self.self_attn(x, x, x, mask=mask, pad_mask=tgt_pad_mask)
x = self.norm0(x)
x = x + self.cross_attn(x, memory, memory, pad_mask=src_pad_mask)
x = self.norm1(x)
x = x + self.ffn(x)
x = self.norm2(x)
return x
class Decoder(nn.Module):
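    """Stack of config.decoder_layer identical DecoderLayer blocks."""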
def __init__(self, config: Config):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(config.decoder_layer):
self.layers.append(
DecoderLayer(config.embed_dim, config.num_heads, config.drop_out)
)
def forward(self, x: torch.Tensor, memory, src_pad_mask=None, tgt_pad_mask=None):
for layer in self.layers:
x = layer(x, memory, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask)
return x
class PositionEmbedding(nn.Module):
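    """Fixed sinusoidal position encoding (as in "Attention Is All You Need"), added to the token embeddings."""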
def __init__(self, config: Config):
super().__init__()
pe = torch.zeros(config.max_len, config.embed_dim)
pos = torch.arange(0, config.max_len, 1).float().unsqueeze(1)
_2i = torch.arange(0, config.embed_dim, 2)
pe[:, 0::2] = torch.sin(pos / (10000 ** (_2i / config.embed_dim)))
pe[:, 1::2] = torch.cos(pos / (10000 ** (_2i / config.embed_dim)))
pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)
def forward(self, x):
x_len = x.shape[1]
return x + self.pe[:, :x_len].to(dtype=x.dtype)
class TranslateModel(nn.Module):
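    """Encoder-decoder Transformer for translation; source and target share one embedding table."""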
def __init__(self, config: Config):
super().__init__()
self.position_embedding = PositionEmbedding(config=config)
self.encoder = Encoder(config=config)
self.decoder = Decoder(config=config)
self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
self.head = nn.Linear(config.embed_dim, config.vocab_size)
self.drop = nn.Dropout(config.drop_out)
def forward(
self,
src: torch.Tensor,
tgt: torch.Tensor,
src_pad_mask=None,
tgt_pad_mask=None,
):
        # Encode the source sequence.
        src_embedding = self.embedding(src)
        src_embedding = self.position_embedding(src_embedding)
        memory = self.encoder(src_embedding, src_pad_mask)
        # Decode the target sequence, attending to the encoder memory.
        tgt_embedding = self.embedding(tgt)
        tgt_embedding = self.position_embedding(tgt_embedding)
        output = self.decoder(tgt_embedding, memory, src_pad_mask, tgt_pad_mask)
output = self.drop(output)
output = self.head(output)
return output
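

# Minimal usage sketch (not part of the original file). It assumes Config() can be
# constructed with defaults and exposes the fields used above: vocab_size, embed_dim,
# num_heads, encoder_layer, decoder_layer, drop_out, and max_len. Pad masks are
# boolean with True at real (non-pad) tokens, and sequence lengths must not exceed
# config.max_len.
if __name__ == "__main__":
    config = Config()
    model = TranslateModel(config)
    src = torch.randint(0, config.vocab_size, (2, 10))    # [batch, src_len]
    tgt = torch.randint(0, config.vocab_size, (2, 12))    # [batch, tgt_len]
    src_pad_mask = torch.ones(2, 10, dtype=torch.bool)    # True = keep token
    tgt_pad_mask = torch.ones(2, 12, dtype=torch.bool)
    logits = model(src, tgt, src_pad_mask, tgt_pad_mask)  # [2, 12, vocab_size]
    print(logits.shape)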