gpt2 / src /model /feed_forward.py
triton329's picture
Upload folder using huggingface_hub
c21e887 verified
Raw
History Blame Contribute Delete
342 Bytes
import torch.nn as nn
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> GELU -> Linear [-> Dropout].

    The two ``nn.Linear`` layers act on the last dimension of the input, so any
    tensor whose trailing dimension equals ``embed_dim`` is accepted and the
    output has the same shape as the input.

    Args:
        embed_dim: Size of the input and output feature dimension.
        expansion: Width multiplier for the hidden layer
            (hidden size = ``expansion * embed_dim``). Defaults to 4.
        dropout: Dropout probability applied to the output. With the default
            0.0 no dropout layer is added, so the module layout (and therefore
            ``state_dict`` keys and ``net`` indices) is unchanged.
        approximate: Passed to ``nn.GELU``. ``"none"`` (default) is the exact
            erf-based GELU. ``"tanh"`` selects the tanh approximation that the
            original GPT-2 release used.

    Raises:
        ValueError: If ``expansion`` is not a positive integer.
    """

    def __init__(
        self,
        embed_dim: int,
        *,
        expansion: int = 4,
        dropout: float = 0.0,
        approximate: str = "none",
    ) -> None:
        super().__init__()
        if expansion < 1:
            raise ValueError(f"expansion must be a positive integer, got {expansion}")
        hidden_dim = expansion * embed_dim

        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(approximate=approximate),
            nn.Linear(hidden_dim, embed_dim),
        ]
        # Only append dropout when it does something. This keeps the default
        # module identical to the 3-layer Sequential (net[-1] is the output Linear).
        if dropout > 0.0:
            layers.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward network to ``x`` (trailing dim == embed_dim)."""
        return self.net(x)