# Hugging Face file-viewer header (page residue, not part of the module):
#   HSinghHuggingFace's picture
#   huggingface app
#   3aa6cf7
#   raw
#   history blame contribute delete
#   468 Bytes
import torch.nn as nn
from .attention import CausalSelfAttention
from .mlp import MLP
class Block(nn.Module):
    """Residual block with an attention sub-layer followed by an MLP sub-layer.

    Each sub-layer is applied to a layer-normalized copy of its input, and the
    sub-layer's output is added back onto the un-normalized input (i.e. the
    normalization happens *before* each sub-layer, inside the residual branch).

    Args:
        config: Configuration object. This class reads ``config.n_embd`` as
            the normalized size of both ``nn.LayerNorm`` layers and passes
            ``config`` unchanged to ``CausalSelfAttention`` and ``MLP``.
            Any further fields those two constructors need are not visible
            from this class.
    """

    def __init__(self, config):
        super().__init__()
        # Sub-modules are registered in the order they are used in forward();
        # this order also fixes the parameter / state_dict ordering.
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        """Apply the attention and MLP residual branches in sequence.

        Args:
            x: Input tensor. Both layer norms are built with
                ``config.n_embd``, so the trailing dimension of ``x`` is
                expected to have that size.

        Returns:
            ``x + attn(ln_1(x))`` further updated by adding
            ``mlp(ln_2(.))`` of that intermediate result. The input tensor
            is not modified in place by this method itself (each step binds
            a new tensor).
        """
        # NOTE(review): the residual additions require attn/mlp outputs to be
        # broadcast-compatible with x — presumably they preserve x's shape;
        # confirm against CausalSelfAttention and MLP.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x