Spaces:
Sleeping
Sleeping
File size: 364 Bytes
0c8750c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch.nn as nn
class MLMHead(nn.Module):
    """Feature transform presumably used as part of a masked-LM head (per the class name).

    Applies, in order, a bias-free ``d_model -> d_model`` linear projection,
    a GELU activation, and layer normalisation over the last dimension. The
    output keeps the input's feature size (``d_model``); this module itself
    contains no projection onto a vocabulary.

    Args:
        d_model: Size of the last dimension of both the input and the output.
    """

    def __init__(self, d_model=256):
        super().__init__()
        # These attribute names (and their registration order) define the
        # state_dict keys: lin.weight, norm.weight, norm.bias. Renaming them
        # would break loading of any previously saved weights.
        self.lin = nn.Linear(d_model, d_model, bias=False)
        self.gelu = nn.GELU()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return ``norm(gelu(lin(x)))``.

        Args:
            x: Tensor whose last dimension is ``d_model``. Leading dimensions
               pass through unchanged (NOTE(review): presumably
               (batch, seq, d_model) — confirm against the caller).

        Returns:
            Tensor with the same shape as ``x``.
        """
        projected = self.lin(x)
        activated = self.gelu(projected)
        return self.norm(activated)
|