import torch as tr


class WordModel(tr.nn.Module):
    """Three-layer feed-forward network that maps input features to vocabulary scores.

    Layer stack:
        Linear(IN_DIMS -> HIDDEN_DIMS) -> ReLU
        -> Linear(HIDDEN_DIMS -> HIDDEN_DIMS) -> ReLU
        -> Linear(HIDDEN_DIMS -> VOCAB_SIZE)

    Args:
        IN_DIMS: Size of the last dimension of the input tensor.
        HIDDEN_DIMS: Width of both hidden layers.
        VOCAB_SIZE: Number of output scores produced for each input row.
    """

    def __init__(self, IN_DIMS: int, HIDDEN_DIMS: int, VOCAB_SIZE: int) -> None:
        super().__init__()
        # The attribute names l1/l2/l3/activation determine the state_dict keys.
        # Keep them stable so that previously saved checkpoints still load.
        self.l1 = tr.nn.Linear(in_features=IN_DIMS, out_features=HIDDEN_DIMS)
        self.l2 = tr.nn.Linear(in_features=HIDDEN_DIMS, out_features=HIDDEN_DIMS)
        self.l3 = tr.nn.Linear(in_features=HIDDEN_DIMS, out_features=VOCAB_SIZE)
        # ReLU has no parameters or state, so one instance can safely be
        # shared by both hidden layers.
        self.activation = tr.nn.ReLU()

    def forward(self, inputs: tr.Tensor) -> tr.Tensor:
        """Compute raw output scores for ``inputs``.

        Args:
            inputs: Tensor whose last dimension has size ``IN_DIMS``. Any
                leading dimensions are carried through, following
                ``nn.Linear`` semantics.

        Returns:
            Tensor with the same leading dimensions as ``inputs`` and a last
            dimension of size ``VOCAB_SIZE``. No softmax is applied, so these
            are unnormalized scores (logits).
        """
        x = self.activation(self.l1(inputs))
        x = self.activation(self.l2(x))
        logits = self.l3(x)
        return logits