| import torch |
| from torch import nn |
|
|
|
|
class FeedForward(nn.Module):
    """Position-wise feed-forward network.

    The input's last dimension is expanded from ``n_embd`` to ``4 * n_embd``,
    passed through a ReLU, projected back down to ``n_embd`` and finally run
    through dropout.

    Args:
        n_embd: Size of the last dimension of both the input and the output.
        dropout: Drop probability handed to the trailing ``nn.Dropout``.
    """

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden_width = n_embd * 4

        # Layers are created in the order they run so that parameter
        # initialisation consumes the RNG in a fixed sequence.
        expand = nn.Linear(n_embd, hidden_width)
        activation = nn.ReLU()
        contract = nn.Linear(hidden_width, n_embd)
        regularise = nn.Dropout(dropout)

        self.net = nn.Sequential(expand, activation, contract, regularise)

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the stack; the result has the same shape as ``x``."""
        return self.net(x)
|
|
|
|
class Block(nn.Module):
    """One transformer layer: self-attention followed by a feed-forward net.

    Each sub-layer receives a layer-normalised copy of its input, and its
    output is added back onto the un-normalised input (residual connection).

    Args:
        n_embd: Width of the last dimension of the input and output.
        block_size: Forwarded to ``MultiHead``.
        n_head: Number of attention heads; each head gets
            ``n_embd // n_head`` features.
        dropout: Drop probability forwarded to both sub-layers.
    """

    def __init__(self, n_embd: int, block_size: int, n_head: int, dropout: float):
        super().__init__()
        # Floor division: every head receives the same integer share.
        per_head_width = n_embd // n_head

        self.sa_head = MultiHead(n_head, block_size, n_embd, per_head_width, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.norm1 = nn.LayerNorm(n_embd)
        self.norm2 = nn.LayerNorm(n_embd)

    def forward(self, x: torch.Tensor):
        """Apply attention and feed-forward sub-layers, each with a residual add."""
        attended = self.sa_head(self.norm1(x))
        x = x + attended

        transformed = self.ffwd(self.norm2(x))
        return x + transformed
|
|
|
|
class MultiHead(nn.Module):
    """Several attention heads run side by side, then mixed by a linear layer.

    Every head sees the same input. Their outputs are concatenated along the
    last dimension, projected to ``n_embd`` features and passed through
    dropout.

    Args:
        num_heads: Number of ``Head`` instances to create.
        block_size: Forwarded to each ``Head``.
        n_embd: Width of the input's last dimension and of the output.
        head_size: Output width of each individual head.
        dropout: Drop probability for the heads and for the output dropout.
    """

    def __init__(self, num_heads: int, block_size: int, n_embd: int, head_size: int, dropout: float):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(block_size, n_embd, head_size, dropout) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. That
        # only equals n_embd when n_embd is an exact multiple of num_heads, so
        # size the projection from what it actually receives instead of
        # assuming the two widths coincide.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the projected concatenation of all head outputs for ``x``."""
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.drop(self.proj(concatenated))
|
|
|
|
class Head(nn.Module):
    """A single head of causally masked self-attention.

    Args:
        block_size: Side length of the lower-triangular mask buffer; the
            sequence length passed to ``forward`` must not exceed it.
        n_embd: Width of the input's last dimension.
        head_size: Width of the key, query and value projections, and
            therefore of the output's last dimension.
        dropout: Drop probability applied to the attention weights.
    """

    def __init__(self, block_size: int, n_embd: int, head_size: int, dropout: float):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Registered as a buffer, not a parameter: it follows the module
        # across devices and into the state dict but is never trained.
        self.register_buffer('tril', torch.tril(
            torch.ones(block_size, block_size)))
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend each position to itself and to earlier positions only.

        Args:
            x: Tensor whose last two dimensions are ``(T, n_embd)``; any
                leading dimensions are carried through unchanged.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is
            ``head_size``.
        """
        seq_len = x.shape[-2]

        q: torch.Tensor = self.query(x)
        k: torch.Tensor = self.key(x)
        v: torch.Tensor = self.value(x)

        # Scale by the key/query width, not the input embedding width: the
        # variance of a q.k dot product grows with the number of summed terms,
        # which is head_size.
        scale = k.shape[-1] ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale

        # Zeros in the lower-triangular buffer mark future positions; -inf
        # there becomes exactly 0 after softmax.
        scores = scores.masked_fill(
            self.tril[:seq_len, :seq_len] == 0, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        weights = self.drop(weights)
        return weights @ v
|
|