pgps-demo / model /decoder /__init__.py
asdfasdfdsafdsa's picture
Fix missing config attributes and imports
6f74e93 verified
raw
history blame contribute delete
837 Bytes
# from .transformer import TransformerModel
from .rnn_decoder import DecoderRNN
from .tree_decoder import TreeDecoder
from .transformer import TransformerDecoder
# Values accepted for ``params.decoder_type`` by ``get_decoder``.
decoder_list = ["rnn_decoder", "tree_decoder", "transformer"]


def get_decoder(params, *args):
    """Instantiate the decoder selected by ``params.decoder_type``.

    Args:
        params: Configuration object. Only its ``decoder_type`` attribute is
            read here; the whole object is forwarded to the decoder
            constructor as its first argument.
        *args: Extra positional arguments forwarded unchanged to the decoder
            constructor.

    Returns:
        A new ``TransformerDecoder``, ``DecoderRNN`` or ``TreeDecoder``
        instance, chosen by ``params.decoder_type``.

    Raises:
        NotImplementedError: If ``params.decoder_type`` is not one of
            ``decoder_list`` or has no decoder class registered below.
    """
    decoder_type = params.decoder_type
    # One message for every rejection path (the guard previously said
    # "Unsupported Classifier", a copy-paste slip in a decoder factory).
    error_msg = "Unsupported Decoder: {}".format(decoder_type)
    if decoder_type not in decoder_list:
        raise NotImplementedError(error_msg)

    # Resolve the class only after validation; keys mirror decoder_list.
    decoder_cls = {
        "rnn_decoder": DecoderRNN,
        "tree_decoder": TreeDecoder,
        "transformer": TransformerDecoder,
    }.get(decoder_type)
    if decoder_cls is None:
        # decoder_list was extended without registering a class here.
        raise NotImplementedError(error_msg)
    return decoder_cls(params, *args)