File size: 655 Bytes
70c7b07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch as tr


class WordModel(tr.nn.Module):
    """Three-layer fully connected network that maps input features to vocabulary scores.

    Architecture::

        Linear(IN_DIMS -> HIDDEN_DIMS) -> ReLU
        -> Linear(HIDDEN_DIMS -> HIDDEN_DIMS) -> ReLU
        -> Linear(HIDDEN_DIMS -> VOCAB_SIZE)

    Args:
        IN_DIMS: Size of the last dimension of the input tensor.
        HIDDEN_DIMS: Width of both hidden layers.
        VOCAB_SIZE: Number of output scores produced along the last dimension.
    """

    def __init__(self, IN_DIMS: int, HIDDEN_DIMS: int, VOCAB_SIZE: int) -> None:
        super().__init__()

        # Build the layers in a fixed order (l1, l2, l3). That order determines
        # which random draws each layer's initial weights receive, and the order
        # of keys in state_dict().
        self.l1 = tr.nn.Linear(IN_DIMS, HIDDEN_DIMS)
        self.l2 = tr.nn.Linear(HIDDEN_DIMS, HIDDEN_DIMS)
        self.l3 = tr.nn.Linear(HIDDEN_DIMS, VOCAB_SIZE)

        # ReLU holds no parameters, so one instance is reused after both hidden layers.
        self.activation = tr.nn.ReLU()

    def forward(self, inputs: tr.Tensor) -> tr.Tensor:
        """Compute raw scores for ``inputs``.

        Args:
            inputs: Tensor whose last dimension has size ``IN_DIMS``; any leading
                dimensions are passed through unchanged by the linear layers.

        Returns:
            Tensor with the same leading dimensions and a last dimension of size
            ``VOCAB_SIZE``. No softmax or other normalisation is applied.
        """
        hidden = inputs
        # The two hidden layers are each followed by the shared ReLU.
        for layer in (self.l1, self.l2):
            hidden = self.activation(layer(hidden))
        # The output projection is left un-activated (logits).
        return self.l3(hidden)