Spaces:
Build error
Build error
File size: 468 Bytes
3aa6cf7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | import torch.nn as nn
from .attention import CausalSelfAttention
from .mlp import MLP
class Block(nn.Module):
    """Residual block pairing an attention sub-layer with an MLP sub-layer.

    Each sub-layer receives a layer-normalized copy of its input, and its
    output is added back onto the un-normalized input (normalize first,
    then residual add).
    """

    def __init__(self, config):
        """Build the two LayerNorms, the attention module and the MLP.

        Args:
            config: Configuration object. Only ``config.n_embd`` is read
                here (as the LayerNorm size); the whole object is also
                handed to ``CausalSelfAttention`` and ``MLP``.
        """
        super().__init__()
        width = config.n_embd
        # Creation order is kept as-is so that parameter registration
        # (and therefore state-dict ordering) does not change.
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply attention and then the MLP, each with a residual add.

        Args:
            x: Input tensor; presumably the last dimension equals
                ``config.n_embd`` (required by the LayerNorms) -- TODO
                confirm the full expected shape against the caller.

        Returns:
            The result of both residual updates applied to ``x``.
        """
        attended = self.attn(self.ln_1(x))
        hidden = x + attended
        transformed = self.mlp(self.ln_2(hidden))
        return hidden + transformed