Ajay-user's picture
app files
70c7b07
raw
history blame contribute delete
655 Bytes
import torch as tr
class WordModel(tr.nn.Module):
    """Three-layer fully connected network that produces one score per vocabulary entry.

    Architecture::

        Linear(IN_DIMS -> HIDDEN_DIMS) -> ReLU
        -> Linear(HIDDEN_DIMS -> HIDDEN_DIMS) -> ReLU
        -> Linear(HIDDEN_DIMS -> VOCAB_SIZE)

    The final layer's output is returned as-is: no softmax or other
    normalisation is applied, so callers receive raw logits.

    Args:
        IN_DIMS: Size of the last dimension of the input tensor.
        HIDDEN_DIMS: Width of both hidden layers.
        VOCAB_SIZE: Number of output scores per input row.
    """

    def __init__(self, IN_DIMS: int, HIDDEN_DIMS: int, VOCAB_SIZE: int) -> None:
        super().__init__()
        # The attribute names and their registration order define the
        # state_dict keys (l1.*, l2.*, l3.*) and the order in which weights
        # are randomly initialised. Renaming or reordering them would break
        # loading any previously saved state_dict.
        self.l1 = tr.nn.Linear(IN_DIMS, HIDDEN_DIMS)
        self.l2 = tr.nn.Linear(HIDDEN_DIMS, HIDDEN_DIMS)
        self.l3 = tr.nn.Linear(HIDDEN_DIMS, VOCAB_SIZE)
        # One stateless ReLU instance is shared by both hidden layers.
        self.activation = tr.nn.ReLU()

    def forward(self, inputs):
        """Map ``inputs`` of shape (..., IN_DIMS) to logits of shape (..., VOCAB_SIZE)."""
        hidden = inputs
        # Both hidden layers follow the same Linear -> ReLU pattern.
        for hidden_layer in (self.l1, self.l2):
            hidden = self.activation(hidden_layer(hidden))
        # The output projection is left un-activated (raw logits).
        return self.l3(hidden)