File size: 449 Bytes
1984adb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17

import torch
import torch.nn as nn

class BeastModel(nn.Module):
    """Minimal scorer: reduce ``input_ids`` along dim 1, project, squash.

    The forward pass sums ``input_ids`` over dimension 1, feeds the result
    through a single ``nn.Linear(config.hidden_size, 1)`` layer and applies a
    sigmoid, so every output element lies in the open interval (0, 1).
    """

    def __init__(self, config):
        """Create the linear projection and the sigmoid activation.

        Args:
            config: Object exposing a ``hidden_size`` attribute, used as the
                number of input features of the linear layer.
        """
        super().__init__()
        self.linear = nn.Linear(config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids=None, **kwargs):
        """Score a batch of inputs.

        Args:
            input_ids: Tensor with at least two dimensions; it is summed over
                dimension 1 (kept as a size-1 dimension) and cast to float.
            **kwargs: Accepted and ignored, so callers may pass extra keyword
                arguments without error.

        Returns:
            The sigmoid of the linear projection, with a trailing dimension of
            size 1, or ``None`` when ``input_ids`` is not provided.
        """
        if input_ids is None:
            return None

        x = input_ids.sum(dim=1, keepdim=True).float()

        # For 2-D input the reduction leaves a trailing dimension of size 1,
        # which only matches the linear layer when hidden_size == 1. Broadcast
        # that scalar across all expected input features so the model runs for
        # any hidden_size; inputs whose trailing dimension already matches the
        # layer are passed through untouched.
        in_features = self.linear.in_features
        if x.shape[-1] == 1 and in_features != 1:
            x = x.expand(*x.shape[:-1], in_features)

        return self.sigmoid(self.linear(x))