# File size: 449 bytes (revision 1984adb)
import torch
import torch.nn as nn
class BeastModel(nn.Module):
    """Minimal scoring model: sums each row of ``input_ids`` and maps it through
    a single linear layer followed by a sigmoid.

    Args:
        config: Object exposing an integer ``hidden_size``, used as the input
            width of the ``nn.Linear(hidden_size, 1)`` layer.
    """

    def __init__(self, config):
        super().__init__()
        self.linear = nn.Linear(config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids=None, **kwargs):
        """Score ``input_ids``.

        Args:
            input_ids: Tensor reduced by summation over dim 1; presumably
                ``(batch, seq_len)`` token ids — TODO confirm with callers.
                ``None`` short-circuits to a ``None`` result.
            **kwargs: Accepted so callers may pass extra keyword arguments;
                they are ignored.

        Returns:
            Sigmoid-activated scores (shape ``(batch, 1)`` for 2-D input), or
            ``None`` when ``input_ids`` is ``None``.
        """
        if input_ids is None:
            return None
        # Keep the reduced axis so a 2-D input yields (batch, 1), and cast to the
        # layer's parameter dtype (not a hard-coded float32) so the model still
        # works after .double() / .half().
        summed = input_ids.sum(dim=1, keepdim=True).to(self.linear.weight.dtype)
        # The linear layer expects `in_features` (= hidden_size) inputs on the
        # last axis. Broadcast a width-1 sum across that axis; when the width
        # already matches (e.g. hidden_size == 1) expand is a no-op, so the
        # previously working cases produce identical results.
        features = summed.expand(*summed.shape[:-1], self.linear.in_features)
        return self.sigmoid(self.linear(features))
|