import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionHead(nn.Module):
    """A single head of masked (causal) self-attention."""

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.n_embed = config.n_embed
        self.head_size = config.head_size
        # Linear projections for keys, queries, and values (no bias).
        self.key = nn.Linear(self.n_embed, self.head_size, bias=False)
        self.query = nn.Linear(self.n_embed, self.head_size, bias=False)
        self.value = nn.Linear(self.n_embed, self.head_size, bias=False)
        # Lower-triangular mask so each position can only attend to itself
        # and earlier positions.
        self.register_buffer(
            'tril',
            torch.tril(torch.ones(self.block_size, self.block_size))
        )
        self.dropout = nn.Dropout(config.attn_dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Scaled dot-product attention: scale by 1/sqrt(head_size) so the
        # logits stay well-behaved before the softmax.
        wei = q @ k.transpose(-2, -1) * (self.head_size ** -0.5)  # (B, T, T)
        # Mask out future positions, then normalize into attention weights.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v      # (B, T, head_size)
        return out
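

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of wiring AttentionHead to a config object. The `Config`
# dataclass below is an assumption: any object exposing block_size, n_embed,
# head_size, and attn_dropout attributes would work the same way.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class Config:
        block_size: int = 8        # maximum sequence length the causal mask supports
        n_embed: int = 32          # embedding dimension of the input tokens
        head_size: int = 16        # output dimension of this attention head
        attn_dropout: float = 0.1  # dropout applied to the attention weights

    cfg = Config()
    head = AttentionHead(cfg)
    x = torch.randn(4, cfg.block_size, cfg.n_embed)  # (batch, time, channels)
    out = head(x)
    print(out.shape)  # torch.Size([4, 8, 16])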