File size: 635 Bytes
cbabe4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from transformers import PreTrainedModel
from .custom_config import LinearConfig
import torch.nn as nn
import torch

class BasicLinear(PreTrainedModel):
    config_class = LinearConfig
    
    def __init__(self, config):
        super().__init__(config)
        self.weight = nn.Parameter(torch.randn(config.out_features, config.in_features) * 0.01)
        if config.bias:
            self.bias = nn.Parameter(torch.zeros(config.out_features))
        else:
            self.bias = None

    def forward(self, x):
        out = x @ self.weight.T
        if self.bias is not None:
            out = out + self.bias
        return out