import numpy as np

from .embedding import embedding_lookup, add_positional_encoding
from .positional_encoding import sinusoidal_positional_encoding
from .decoder import transformer_decoder_block


class TransformerDecoder:
    """Decoder stack: token embedding + sinusoidal positions + a sequence of decoder blocks."""

    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, max_seq_len,
                 embedding_weights, block_weights_list, driver=None, scheduler=None):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len
        self.embedding_weights = embedding_weights
        self.block_weights_list = block_weights_list
        # Precompute the positional-encoding table once; forward() slices it
        # to the current sequence length.
        self.pos_encoding = sinusoidal_positional_encoding(max_seq_len, hidden_dim)
        # driver and scheduler are passed through unchanged to each decoder block call.
        self.driver = driver
        self.scheduler = scheduler

    def forward(self, input_ids, enc_out, self_mask=None, enc_dec_mask=None):
        """Embed input_ids, add positional encodings, then apply each decoder block.

        self_mask masks the decoder self-attention; enc_dec_mask masks the
        encoder-decoder attention over enc_out.
        """
        x = embedding_lookup(input_ids, self.embedding_weights)
        # Slice the precomputed table to the actual sequence length.
        x = add_positional_encoding(x, self.pos_encoding[:x.shape[1]])
        # Each block gets its own weights; head count and masks are shared.
        for block_weights in self.block_weights_list:
            x = transformer_decoder_block(
                x, enc_out, block_weights, self.num_heads,
                self_mask, enc_dec_mask,
                driver=self.driver, scheduler=self.scheduler,
            )
        return x
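

# Expected call pattern (illustrative sketch; the variable names below are
# placeholders, and the exact weight layouts are defined by the embedding and
# decoder-block helpers, not by this class):
#
#   decoder = TransformerDecoder(vocab_size, hidden_dim, num_layers, num_heads,
#                                max_seq_len, embedding_weights, block_weights_list)
#   hidden_states = decoder.forward(input_ids, enc_out,
#                                   self_mask=causal_mask,
#                                   enc_dec_mask=padding_mask)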