# File: spam / hugging face / modeling_beast.py
# Uploaded by abdullahalioo — commit 1984adb ("Upload 3 files", verified)
import torch
import torch.nn as nn
class BeastModel(nn.Module):
    """Small binary scorer: sums token ids per row and maps the sum to (0, 1).

    Args:
        config: Any object exposing an integer ``hidden_size`` attribute; it
            sets the input width of the single linear layer.
    """

    def __init__(self, config):
        super().__init__()
        # Parameter shapes (weight: (1, hidden_size), bias: (1,)) are kept
        # unchanged so previously saved state dicts still load.
        self.linear = nn.Linear(config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids=None, **kwargs):
        """Score a batch of token-id sequences.

        Args:
            input_ids: Integer tensor whose dim 1 is summed over; presumably
                (batch, seq_len) — TODO confirm against callers.
            **kwargs: Accepted and ignored (e.g. an ``attention_mask``);
                presumably so tokenizer output can be passed with ``**``.

        Returns:
            A (batch, 1) tensor of sigmoid scores, or ``None`` when
            ``input_ids`` is not given.
        """
        if input_ids is None:
            return None
        # One scalar feature per row: the sum of its token ids. Cast to the
        # layer's own dtype (instead of hard-coding float32) so models moved
        # to .half()/.double() don't hit a dtype mismatch in the matmul.
        x = input_ids.sum(dim=1, keepdim=True).to(self.linear.weight.dtype)
        # A Linear over `hidden_size` features cannot accept a (batch, 1)
        # input unless hidden_size == 1. Broadcasting the scalar across the
        # feature axis makes the shapes agree without changing parameters;
        # for hidden_size == 1 this is a no-op.
        # NOTE(review): confirm this matches the intended model semantics.
        x = x.expand(-1, self.linear.in_features)
        return self.sigmoid(self.linear(x))