from typing import Optional

import torch
from torch import nn

from src.config import Config


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, drop_out=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(drop_out)
        # Scale attention scores by sqrt(head_dim).
        self.scale = (embed_dim // num_heads) ** 0.5
        self.embed_dim = embed_dim
        self.num_heads = num_heads

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        pad_mask: Optional[torch.Tensor] = None,
    ):
        bs = q.shape[0]
        q_len = q.shape[1]
        k_len = k.shape[1]

        Q: torch.Tensor = self.q_proj(q)
        K: torch.Tensor = self.k_proj(k)
        V: torch.Tensor = self.v_proj(v)

        # Split the embedding dimension into heads:
        # [bs, len, embed_dim] -> [bs, head, len, head_dim].
        q_state = Q.view(bs, q_len, self.num_heads, -1).transpose(1, 2)
        k_state = K.view(bs, k_len, self.num_heads, -1).transpose(1, 2)
        v_state = V.view(bs, k_len, self.num_heads, -1).transpose(1, 2)

        # [bs, head, q_len, dim] @ [bs, head, dim, k_len] = [bs, head, q_len, k_len]
        attn = q_state @ k_state.transpose(-1, -2)
        attn: torch.Tensor = attn / self.scale

        # `mask` is a boolean attention mask (e.g. causal) where True marks
        # positions that may be attended to; `pad_mask` marks valid key
        # positions with shape [bs, k_len].
        if mask is not None:
            attn = attn.masked_fill(~mask, -1e8)
        if pad_mask is not None:
            attn = attn.masked_fill(~pad_mask.unsqueeze(1).unsqueeze(2), -1e8)

        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # [bs, head, q_len, k_len] @ [bs, head, k_len, dim] = [bs, head, q_len, dim]
        out = attn @ v_state
        # Merge heads back: [bs, head, q_len, dim] -> [bs, q_len, embed_dim].
        out = out.transpose(1, 2).contiguous().view(bs, q_len, -1)
        out = self.out_proj(out)
        return out


class FFN(nn.Module):
    def __init__(self, embed_dim, drop_out=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(drop_out),
        )

    def forward(self, x):
        return self.mlp(x)


class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, drop_out=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(
            embed_dim=embed_dim, num_heads=num_heads, drop_out=drop_out
        )
        self.ffn = FFN(embed_dim=embed_dim, drop_out=drop_out)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, pad_mask=None):
        # Post-LN residual blocks: self-attention, then feed-forward.
        x = x + self.mha(x, x, x, pad_mask=pad_mask)
        x = self.norm1(x)
        x = x + self.ffn(x)
        x = self.norm2(x)
        return x


class Encoder(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(config.encoder_layer):
            self.layers.append(
                EncoderLayer(config.embed_dim, config.num_heads, config.drop_out)
            )

    def forward(self, x, pad_mask=None):
        for layer in self.layers:
            x = layer(x, pad_mask)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, drop_out=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(
            embed_dim=embed_dim, num_heads=num_heads, drop_out=drop_out
        )
        self.cross_attn = MultiHeadAttention(
            embed_dim=embed_dim, num_heads=num_heads, drop_out=drop_out
        )
        self.ffn = FFN(embed_dim=embed_dim, drop_out=drop_out)
        self.norm0 = nn.LayerNorm(embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x: torch.Tensor, memory, src_pad_mask=None, tgt_pad_mask=None):
        x_len = x.shape[1]
        # Causal mask: each target position may only attend to itself and
        # earlier positions.
        mask = torch.ones(
            size=(1, 1, x_len, x_len), device=x.device, dtype=torch.bool
        ).tril()
        x = x + self.self_attn(x, x, x, mask=mask, pad_mask=tgt_pad_mask)
        x = self.norm0(x)
        # Cross-attention over the encoder memory, masked by the source padding mask.
        x = x + self.cross_attn(x, memory, memory, pad_mask=src_pad_mask)
        x = self.norm1(x)
        x = x + self.ffn(x)
        x = self.norm2(x)
        return x
class Decoder(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(config.decoder_layer):
            self.layers.append(
                DecoderLayer(config.embed_dim, config.num_heads, config.drop_out)
            )

    def forward(self, x: torch.Tensor, memory, src_pad_mask=None, tgt_pad_mask=None):
        for layer in self.layers:
            x = layer(x, memory, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask)
        return x


class PositionEmbedding(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        # Fixed sinusoidal position encodings, precomputed up to max_len.
        pe = torch.zeros(config.max_len, config.embed_dim)
        pos = torch.arange(0, config.max_len, 1).float().unsqueeze(1)
        _2i = torch.arange(0, config.embed_dim, 2)
        pe[:, 0::2] = torch.sin(pos / (10000 ** (_2i / config.embed_dim)))
        pe[:, 1::2] = torch.cos(pos / (10000 ** (_2i / config.embed_dim)))
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x_len = x.shape[1]
        return x + self.pe[:, :x_len].to(dtype=x.dtype)


class TranslateModel(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        self.position_embedding = PositionEmbedding(config=config)
        self.encoder = Encoder(config=config)
        self.decoder = Decoder(config=config)
        # Source and target share a single embedding table over one vocabulary.
        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.head = nn.Linear(config.embed_dim, config.vocab_size)
        self.drop = nn.Dropout(config.drop_out)

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_pad_mask=None,
        tgt_pad_mask=None,
    ):
        # Encoder: embed the source, add position encodings, encode into memory.
        src_embedding = self.embedding(src)
        src_embedding = self.position_embedding(src_embedding)
        memory = self.encoder(src_embedding, src_pad_mask)

        # Decoder: embed the target and attend to the encoder memory.
        tgt_embedding = self.embedding(tgt)
        tgt_embedding = self.position_embedding(tgt_embedding)
        output = self.decoder(tgt_embedding, memory, src_pad_mask, tgt_pad_mask)

        output = self.drop(output)
        output = self.head(output)
        return output
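

# Minimal smoke-test sketch (not part of the original module). It shows one
# possible way to run a forward pass: a stand-in SimpleNamespace stands in for
# `src.config.Config` and only exposes the fields this file actually reads
# (vocab_size, embed_dim, num_heads, encoder_layer, decoder_layer, drop_out,
# max_len); the real Config may be constructed differently.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        vocab_size=1000,
        embed_dim=128,
        num_heads=8,
        encoder_layer=2,
        decoder_layer=2,
        drop_out=0.1,
        max_len=64,
    )

    model = TranslateModel(cfg)

    src = torch.randint(0, cfg.vocab_size, (2, 10))  # [batch, src_len] token ids
    tgt = torch.randint(0, cfg.vocab_size, (2, 7))  # [batch, tgt_len] token ids
    # Padding masks: True = real token, False = padding (here, no padding).
    src_pad_mask = torch.ones(2, 10, dtype=torch.bool)
    tgt_pad_mask = torch.ones(2, 7, dtype=torch.bool)

    logits = model(src, tgt, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask)
    print(logits.shape)  # expected: torch.Size([2, 7, 1000])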