| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
class SelfAttention(nn.Module):
    """Multi-head causal self-attention.

    A single fused linear layer produces queries, keys and values; each
    position attends only to itself and to earlier positions.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
            Must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability, applied both to the attention weights
            and to the projected output.
    """

    def __init__(self, hidden_dim: int, n_heads: int, dropout: float):
        super().__init__()
        assert hidden_dim % n_heads == 0, "hidden_dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads

        # Fused projection: the output's last dimension is laid out as
        # (3, n_heads, head_dim), which is how forward() reshapes it.
        self.qkv = nn.Linear(hidden_dim, hidden_dim * 3)

        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(batch_size, seq_length, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch_size, seq_length, hidden_dim)``.
        """
        batch_size, seq_length, hidden_dim = x.size()

        qkv = self.qkv(x)

        # (batch, seq, 3 * hidden) -> (3, batch, heads, seq, head_dim)
        qkv = qkv.reshape(batch_size, seq_length, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores: (batch, heads, seq, seq).
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # True strictly above the diagonal, i.e. at "future" key positions.
        # The diagonal is never masked, so every softmax row has at least
        # one finite entry.
        future_mask = torch.triu(
            torch.ones(seq_length, seq_length, dtype=torch.bool, device=x.device),
            diagonal=1,
        )
        attn_scores = attn_scores.masked_fill(future_mask, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (batch, heads, seq, head_dim)
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, hidden).
        # The tensor is 4-D, so the last axis index is 3 (the previous index 4
        # was out of range and raised on every call).
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(
            batch_size, seq_length, hidden_dim
        )

        output = self.out_proj(attn_output)
        output = self.dropout(output)
        return output
| |
|
| |
|
class FeedForward(nn.Module):
    """Two-layer feed-forward network with a 4x expansion and GELU.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        dropout: Dropout probability applied to the final projection.
    """

    def __init__(self, hidden_dim: int, dropout: float):
        super().__init__()
        expanded_dim = 4 * hidden_dim
        self.fc1 = nn.Linear(hidden_dim, expanded_dim)
        self.fc2 = nn.Linear(expanded_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand, apply GELU, project back, then apply dropout."""
        expanded = self.fc1(x)
        projected = self.fc2(F.gelu(expanded))
        return self.dropout(projected)
| |
|
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each residual.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        n_heads: Number of heads passed to the attention sub-layer.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, hidden_dim: int, n_heads: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.attn = SelfAttention(hidden_dim, n_heads, dropout)
        self.ff = FeedForward(hidden_dim, dropout)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual steps."""
        # Each sub-layer sees the normalised input; the residual adds the
        # un-normalised one.
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.ff(self.ln2(after_attn))
| |
|