import numpy as np

from .embedding import embedding_lookup, add_positional_encoding
from .positional_encoding import sinusoidal_positional_encoding
from .decoder import transformer_decoder_block


class TransformerDecoder:
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads,
                 max_seq_len, embedding_weights, block_weights_list,
                 driver=None, scheduler=None):
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len
        self.embedding_weights = embedding_weights  # (vocab_size, hidden_dim)
        self.block_weights_list = block_weights_list  # list of dicts, one per block
        # Precompute the sinusoidal table once; forward() slices it down to
        # the actual target sequence length.
        self.pos_encoding = sinusoidal_positional_encoding(max_seq_len, hidden_dim)
        self.driver = driver
        self.scheduler = scheduler

    def forward(self, input_ids, enc_out, self_mask=None, enc_dec_mask=None):
        # input_ids: (batch, tgt_seq_len)
        x = embedding_lookup(input_ids, self.embedding_weights)
        x = add_positional_encoding(x, self.pos_encoding[:x.shape[1]])
        # Pass the embedded targets through each decoder block in turn,
        # attending to the encoder output enc_out via cross-attention.
        for block_weights in self.block_weights_list:
            x = transformer_decoder_block(x, enc_out, block_weights,
                                          self.num_heads, self_mask, enc_dec_mask,
                                          driver=self.driver, scheduler=self.scheduler)
        return x
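

# Usage sketch (illustrative only): the per-block weight dict schema is
# defined by transformer_decoder_block and is not shown in this module, so
# make_block_weights below is a hypothetical helper, and the concrete sizes
# (vocab 32000, width 512, 6 layers, 8 heads) are assumptions.
#
#   rng = np.random.default_rng(0)
#   emb = rng.standard_normal((32000, 512)).astype(np.float32)
#   blocks = [make_block_weights(512, num_heads=8) for _ in range(6)]
#   dec = TransformerDecoder(vocab_size=32000, hidden_dim=512, num_layers=6,
#                            num_heads=8, max_seq_len=1024,
#                            embedding_weights=emb, block_weights_list=blocks)
#   out = dec.forward(input_ids, enc_out)  # (batch, tgt_seq_len, hidden_dim)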