import math

import torch


class PositionalEncoding(torch.nn.Module):
| """ |
| https://pytorch.org/tutorials/beginner/transformer_tutorial.html |
| """ |
|
|
| def __init__(self, d_model: int, max_len: int = 512): |
| super().__init__() |
|
|
| position = torch.arange(max_len).unsqueeze(1) |
| div_term = torch.exp( |
| torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) |
| ) |
|
|
| pe = torch.zeros(max_len, d_model) |
| pe[:, : d_model // 2] = torch.sin(position * div_term) |
| pe[:, d_model // 2 :] = torch.cos(position * div_term) |
|
|
| self.register_buffer("pe", pe) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| x = x + self.pe[: x.size(0)] |
| return x |
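
# Usage sketch (shapes only; every module in this file operates on unbatched
# tensors):
#
#     enc = PositionalEncoding(d_model=128)
#     out = enc(torch.zeros(10, 128))  # out.shape == (10, 128)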
|
|
class MultiheadSelfAttention(torch.nn.Module):
    """Multi-head self-attention over an unbatched (seq_len, embed_dim) input.

    Unlike torch.nn.MultiheadAttention, there is no output projection.
    """

    def __init__(self, embed_dim: int, num_heads: int = 8):
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.query = torch.nn.Linear(in_features=embed_dim, out_features=embed_dim)
        self.key = torch.nn.Linear(in_features=embed_dim, out_features=embed_dim)
        self.value = torch.nn.Linear(in_features=embed_dim, out_features=embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split each projection into heads: q and v become
        # (num_heads, seq_len, head_dim); k becomes (num_heads, head_dim, seq_len)
        # so that q @ k yields the (num_heads, seq_len, seq_len) score matrix.
        q = self.query(x).view(x.shape[0], self.num_heads, -1).transpose(0, 1)
        k = self.key(x).view(x.shape[0], self.num_heads, -1).permute(1, 2, 0)
        v = self.value(x).view(x.shape[0], self.num_heads, -1).transpose(0, 1)

        # Scaled dot-product attention: softmax(q @ k / sqrt(head_dim)) @ v.
        qk = torch.softmax(torch.matmul(q, k) / self.head_dim**0.5, dim=-1)

        # Merge the heads back into a single (seq_len, embed_dim) tensor.
        return torch.matmul(qk, v).transpose(0, 1).reshape(x.shape[0], -1)
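
# The forward pass above matches (up to numerics) the fused kernel
# torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0); the same
# computation with it would look like:
#
#     out = torch.nn.functional.scaled_dot_product_attention(
#         q, k.transpose(-2, -1), v  # k is stored pre-transposed above
#     ).transpose(0, 1).reshape(x.shape[0], -1)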
|
|
class Block(torch.nn.Module):
    """Pre-LayerNorm transformer block with a 4x-expansion ReLU MLP."""

    def __init__(self, d_model: int, num_heads: int = 8, eps: float = 1e-6):
        super().__init__()

        self.ln1 = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
        self.attn = MultiheadSelfAttention(embed_dim=d_model, num_heads=num_heads)
        self.ln2 = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
        self.linear1 = torch.nn.Linear(in_features=d_model, out_features=d_model * 4)
        self.linear2 = torch.nn.Linear(in_features=d_model * 4, out_features=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The attention branch reads ln1(x) and the MLP branch reads
        # ln2(x + attn); both branch outputs are added onto the residual stream.
        attn = self.attn(self.ln1(x))
        mlp = self.linear2(torch.relu(self.linear1(self.ln2(x + attn))))
        return mlp + x + attn
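
# This is the same wiring as torch.nn.TransformerEncoderLayer with
# norm_first=True (each sub-layer reads a normalized copy of the residual
# stream), minus dropout and the attention output projection.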
|
|
class Head(torch.nn.Module):
    """Pools a (seq_len, d_model) sequence into one unit-norm d_model vector."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()

        self.d_model = d_model
        self.eps = eps

        self.ln = torch.nn.LayerNorm(normalized_shape=d_model, eps=eps)
        self.linear1 = torch.nn.Linear(in_features=d_model, out_features=d_model)
        self.linear2 = torch.nn.Linear(in_features=d_model, out_features=d_model)
        self.tanh_layer = torch.nn.Linear(in_features=d_model * 2, out_features=d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ln = self.ln(x)
        # Strictly positive, element-wise weights for a weighted mean over positions.
        mlp = torch.exp(self.linear2(torch.nn.functional.elu(self.linear1(ln))))
        # Concatenate the plain mean with the weighted mean -> (2 * d_model,).
        res = torch.cat(
            [
                ln.mean(dim=0),
                (mlp * ln).sum(dim=0) / mlp.sum(dim=0),
            ]
        )
        res = torch.tanh(self.tanh_layer(res))
        # L2-normalize out of place; an in-place division here would invalidate
        # the tensors autograd saved for the pow backward.
        return res / (res**2).sum() ** 0.5
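
# The weighted mean above is sum_t(w_t * ln_t) / sum_t(w_t) with element-wise
# weights w_t = exp(MLP(ln_t)) > 0, i.e. a softmax-free attention pool per
# feature dimension; dividing by the norm makes the cosine similarity between
# two Head outputs equal to their dot product.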
|
|
class MUSE(torch.nn.Module):
    """Maps token ids of shape (seq_len,) to a unit-norm (d_model,) embedding."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        d_model: int,
        num_heads: int,
        eps: float = 1e-6,
    ):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.d_model = d_model
        self.num_heads = num_heads
        self.eps = eps

        self.embedding = torch.nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        self.linear = torch.nn.Linear(
            in_features=embedding_dim,
            out_features=d_model,
        )
        self.pe = PositionalEncoding(
            d_model=d_model,
            max_len=512,
        )
        # All six blocks share the constructor's num_heads and eps.
        self.blocks = torch.nn.ModuleList(
            Block(d_model=d_model, num_heads=num_heads, eps=eps) for _ in range(6)
        )
        self.head = Head(d_model=d_model, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)  # (seq_len,) -> (seq_len, embedding_dim)
        x = self.linear(x)  # project to (seq_len, d_model)
        x = self.pe(x)  # add positional encodings
        for block in self.blocks:
            x = block(x)
        return self.head(x)  # pool to a unit-norm (d_model,) vector
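

# A minimal smoke test; the hyperparameters below are illustrative values, not
# a configuration taken from the original model.
if __name__ == "__main__":
    model = MUSE(num_embeddings=30_000, embedding_dim=128, d_model=256, num_heads=8)
    token_ids = torch.randint(0, 30_000, (16,))  # one unbatched sequence
    sentence_embedding = model(token_ids)
    assert sentence_embedding.shape == (256,)
    # Head L2-normalizes its output, so the norm should be ~1.
    assert torch.allclose(sentence_embedding.norm(), torch.tensor(1.0), atol=1e-4)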