Spaces:
Configuration error
Configuration error
File size: 2,446 Bytes
d541e5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import torch
from torch import nn
from src.util import device
class Transpose(nn.Module):
    """Swap two dimensions of the input tensor, usable as a layer.

    Wrapping ``torch.transpose`` in a module lets it be placed inside an
    ``nn.Sequential``.

    Args:
        dim0: First dimension to swap. ``None`` means ``tensor.dim() - 2``
            (the second-to-last dimension), resolved freshly on every call.
        dim1: Second dimension to swap. ``None`` means ``tensor.dim() - 1``
            (the last dimension), resolved freshly on every call.
    """

    def __init__(self, dim0=None, dim1=None):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` with the configured two dimensions swapped."""
        ndim = tensor.dim()
        # Resolve defaults into locals rather than overwriting self.dim0/dim1:
        # writing them back froze the axes of the first input seen, so a later
        # input of a different rank was transposed along the wrong dimensions.
        # Each default is resolved independently so passing only one of the
        # two dimensions works instead of calling transpose(..., None).
        dim0 = ndim - 2 if self.dim0 is None else self.dim0
        dim1 = ndim - 1 if self.dim1 is None else self.dim1
        return torch.transpose(tensor, dim0, dim1)
class Model2(nn.Module):
    """Token-level LSTM model with a four-stage MLP head.

    Token ids are embedded and fed one step at a time through an
    ``nn.LSTMCell``. The hidden state at every position is then projected to
    ``vocab_size`` logits by ``lin1`` .. ``lin4``.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        embedding_dim: Size of each token embedding.
        state_size: Hidden/cell size of the LSTM cell.
        pad_index: Token id used for padding. It is passed as the embedding's
            ``padding_idx``. At positions holding it, the recurrent state is
            carried forward unchanged instead of being updated.
        dropout: Dropout probability used after ``lin1``, ``lin2`` and
            ``lin3`` (defaults to the previously hard-coded 0.5).
    """

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        state_size,
        pad_index,
        dropout=0.5,
    ):
        super().__init__()
        self.state_size = state_size
        self.pad_index = pad_index
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
            padding_idx=pad_index,
        )
        self.rnn_layer = nn.LSTMCell(input_size=embedding_dim, hidden_size=state_size)
        self.lin1 = nn.Sequential(
            nn.Linear(state_size, state_size * 4),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )
        # BatchNorm1d is applied to the feature axis: the Transposes move the
        # features from the last dimension to dimension 1 and back.
        # NOTE(review): batch statistics are computed over every (batch, time)
        # position, padded ones included -- confirm this is intended.
        self.lin2 = nn.Sequential(
            nn.Linear(state_size * 4, state_size * 8),
            Transpose(),
            nn.BatchNorm1d(state_size * 8),
            Transpose(),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )
        self.lin3 = nn.Sequential(
            nn.Linear(state_size * 8, state_size * 16),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        )
        self.lin4 = nn.Sequential(nn.Linear(state_size * 16, vocab_size))

    def forward(self, X):
        """Return per-position vocabulary logits.

        Args:
            X: Token ids with two dimensions, unpacked as ``(N, T)``.

        Returns:
            Logits of shape ``(N, T, vocab_size)``.
        """
        N, T = X.shape
        # (N, T, 1) so each step's mask broadcasts against (N, state_size).
        keep = (X != self.pad_index).unsqueeze(-1)
        embedded = self.embedding_layer(X)

        # Allocate the initial state on the embeddings' device and dtype rather
        # than a module-level device global, so the model works wherever it
        # has been moved with .to() / .double().
        state = embedded.new_zeros((N, self.state_size))
        c = embedded.new_zeros((N, self.state_size))

        states = []
        for t in range(T):
            next_state, next_c = self.rnn_layer(embedded[:, t, :], (state, c))
            # Advance the recurrence only at non-pad positions; at pad
            # positions the previous hidden and cell state are carried forward.
            step_keep = keep[:, t]
            state = torch.where(step_keep, next_state, state)
            c = torch.where(step_keep, next_c, c)
            states.append(state)

        # (N, T, state_size)
        states = torch.stack(states, dim=1)
        output = self.lin1(states)
        output = self.lin2(output)
        output = self.lin3(output)
        output = self.lin4(output)
        return output