# INV / helium / decoder_model.py
# Source: uploaded by Fred808 ("Upload 256 files", commit 7a0c684, verified)
import numpy as np
from .embedding import embedding_lookup, add_positional_encoding
from .positional_encoding import sinusoidal_positional_encoding
from .decoder import transformer_decoder_block
class TransformerDecoder:
    """A stack of transformer decoder blocks with token embedding and
    sinusoidal positional encoding, driven by per-block weight dicts."""

    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads, max_seq_len, embedding_weights, block_weights_list, driver=None, scheduler=None):
        """Store configuration and weights; precompute the positional table.

        embedding_weights: (vocab_size, hidden_dim) token-embedding matrix.
        block_weights_list: one weight dict per decoder block.
        driver / scheduler: opaque execution handles forwarded to each block.
        """
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len
        self.embedding_weights = embedding_weights
        self.block_weights_list = block_weights_list
        # Precompute a (max_seq_len, hidden_dim) sinusoidal table once;
        # forward() slices it to the actual sequence length.
        self.pos_encoding = sinusoidal_positional_encoding(max_seq_len, hidden_dim)
        self.driver = driver
        self.scheduler = scheduler

    def forward(self, input_ids, enc_out, self_mask=None, enc_dec_mask=None):
        """Embed `input_ids` (batch, tgt_seq_len), add positions, and run
        the stack of decoder blocks against the encoder output `enc_out`.

        Returns the final hidden states from the last block.
        """
        hidden = embedding_lookup(input_ids, self.embedding_weights)
        seq_len = hidden.shape[1]
        hidden = add_positional_encoding(hidden, self.pos_encoding[:seq_len])
        for weights in self.block_weights_list:
            hidden = transformer_decoder_block(
                hidden,
                enc_out,
                weights,
                self.num_heads,
                self_mask,
                enc_dec_mask,
                driver=self.driver,
                scheduler=self.scheduler,
            )
        return hidden