import math

import torch
from torch import nn
import torch.nn.functional as F

device = 'cpu'


class ModifiedTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu, layer_norm_eps=1e-5,
                 batch_first=True, norm_first=False, device=device, dtype=None, alpha=1, beta=1):
        kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **kwargs)
        self.linear1 = nn.Linear(d_model, dim_feedforward, **kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **kwargs)
        self.norm_first = norm_first if (alpha == 1 and beta == 1) else False  # deepnorm requires post-norm
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation
        self.alpha = alpha  # NEW
        self.beta = beta  # NEW

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, x, src_mask=None, src_key_padding_mask=None, is_causal=False):
        # is_causal is unused here; it is accepted so that newer versions of
        # nn.TransformerEncoder, which may pass it to each layer, do not error
        if self.norm_first:
            x = self.alpha * x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)  # UPDATED
            x = self.alpha * x + self._ff_block(self.norm2(x))  # UPDATED
        else:
            x = self.norm1(self.alpha * x + self._sa_block(x, src_mask, src_key_padding_mask))  # UPDATED
            x = self.norm2(self.alpha * x + self._ff_block(x))  # UPDATED
        return x

    def _sa_block(self, x, attn_mask, key_padding_mask):
        q, k, v = x / self.beta, x / self.beta, x  # NEW
        x = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
        return self.dropout1(x)

    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
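

# DeepNorm note (a summary of what the layer above appears to implement, per the DeepNet
# paper of Wang et al.; stated as intent rather than verified behavior):
#   post-norm residual:  x = LayerNorm(alpha * x + sublayer(x))
#   encoder-only scaling for N layers:  alpha = (2N)^(1/4),  beta = (8N)^(-1/4)
# beta is applied as the Xavier init gain in EncoderLM.init_weights below; dividing the
# query/key inputs by beta in _sa_block cancels that gain for the q/k projections only.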


class PositionalEncoding(nn.Module):
    def __init__(self, emb_size, max_len=5000, N=10_000, dropout=0.1):
        super().__init__()
        position = torch.arange(max_len)[:, None]
        div_term = torch.exp(torch.arange(0, emb_size, 2) * (-math.log(N) / emb_size))
        encoding = torch.zeros(max_len, 1, emb_size)
        encoding[:, 0, 0::2] = torch.sin(position * div_term)
        encoding[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('encoding', encoding)  # buffer so .to(device) moves it with the module
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # x.shape = y.shape = (batch_size, seq_len, emb_size)
        seq_len = x.shape[1]
        p = self.encoding[:seq_len].transpose(0, 1)  # (1, seq_len, emb_size), broadcast over the batch
        y = self.dropout(x + p)
        return y
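

# Sinusoidal positional encoding (Vaswani et al.), which the class above computes:
#   PE(pos, 2i)   = sin(pos / N^(2i / emb_size)),   PE(pos, 2i+1) = cos(pos / N^(2i / emb_size))
# so div_term is N^(-2i / emb_size) = exp(2i * (-log(N) / emb_size)) for 2i = 0, 2, ..., emb_size - 2.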


class EncoderLM(nn.Module):
    def __init__(self, vocab_size, emb_size, num_layers=3, nhead=8, dropout=0.1, dim_feedforward=256,
                 masking=False, max_len=1024, padding_idx=None, deepnorm=False, activation=nn.GELU(),
                 norm_first=False):
        super().__init__()
        self.masking = masking
        self.padding_idx = padding_idx
        self.emb_size = emb_size
        self.alpha = (2 * num_layers)**(1/4) if deepnorm else 1
        self.beta = (8 * num_layers)**(-1/4) if deepnorm else 1
        self.mask = torch.triu(torch.ones(max_len, max_len) * (-torch.inf), diagonal=1).to(device)
        self.positional = PositionalEncoding(emb_size, dropout=dropout)
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=padding_idx)
        self.transformer = nn.TransformerEncoder(ModifiedTransformerEncoderLayer(
            emb_size, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True,
            alpha=self.alpha, beta=self.beta, activation=activation, norm_first=norm_first), num_layers)
        self.fc = nn.Linear(emb_size, vocab_size)
        self.apply(self.init_weights)

    def forward(self, X):
        # X.shape = (batch_size, seq_len)
        # Yhat.shape = (batch_size, seq_len, vocab_size)
        batch_size, seq_len = X.shape
        mask = self.mask[:seq_len, :seq_len] if self.masking else None
        pad_mask = (X == self.padding_idx) if self.padding_idx is not None else None
        X = self.embedding(X) * math.sqrt(self.emb_size)
        X = self.positional(X)
        X = self.transformer(X, mask, pad_mask)
        X = self.fc(X)
        return X

    def init_weights(self, layer):
        if isinstance(layer, nn.modules.linear.Linear) or \
           isinstance(layer, nn.modules.sparse.Embedding) or \
           isinstance(layer, nn.modules.linear.NonDynamicallyQuantizableLinear):
            torch.nn.init.xavier_normal_(layer.weight, gain=self.beta)


class Seq2SeqCrossEntropyLoss(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss(**kwargs)

    def forward(self, Yhat, Y):
        # Yhat.shape = (batch_size, seq_len, vocab_size), Y.shape = (batch_size, seq_len)
        bs, bptt = Y.shape[0], Y.shape[1]
        y = Y.reshape(bs * bptt)
        yhat = Yhat.reshape(bs * bptt, -1)
        loss = self.loss_fn(yhat, y).mean()
        return loss


if __name__ == '__main__':
    model = EncoderLM(10_000, 512, num_layers=1, nhead=8, dim_feedforward=512,
                      masking=True, padding_idx=1, dropout=0.1, max_len=525, deepnorm=True).to(device)
    print(model)
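
    # Quick smoke test (not in the original file; the batch size, sequence length, and
    # next-token shift below are illustrative assumptions, not a prescribed training setup).
    tokens = torch.randint(2, 10_000, (2, 16), device=device)  # (batch_size, seq_len) token ids
    tokens[:, -2:] = 1                                         # a little padding to exercise pad_mask
    logits = model(tokens)                                     # (batch_size, seq_len, vocab_size)
    print(logits.shape)

    loss_fn = Seq2SeqCrossEntropyLoss(ignore_index=1)
    loss = loss_fn(logits[:, :-1], tokens[:, 1:])              # predict token t+1 from tokens <= t
    print(loss.item())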