Spaces:
Sleeping
Sleeping
File size: 837 Bytes
383bfb8 6f74e93 383bfb8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
# from .transformer import TransformerModel
from .rnn_decoder import DecoderRNN
from .tree_decoder import TreeDecoder
from .transformer import TransformerDecoder
# Decoder type names that ``get_decoder`` accepts in ``params.decoder_type``.
decoder_list = [
    "rnn_decoder",
    "tree_decoder",
    "transformer",
]
def get_decoder(params, *args):
    """Build the decoder selected by ``params.decoder_type``.

    Args:
        params: Configuration object. Its ``decoder_type`` attribute names
            the decoder to build, and the object itself is passed as the
            first argument to the decoder constructor.
        *args: Extra positional arguments forwarded unchanged to the
            decoder constructor.

    Returns:
        A newly constructed ``TransformerDecoder``, ``DecoderRNN`` or
        ``TreeDecoder`` instance.

    Raises:
        NotImplementedError: If ``params.decoder_type`` is not listed in
            ``decoder_list`` or has no constructor registered here.
    """
    decoder_type = params.decoder_type

    # Map each supported type name to the class that implements it.
    constructors = {
        "transformer": TransformerDecoder,
        "rnn_decoder": DecoderRNN,
        "tree_decoder": TreeDecoder,
    }

    # A type must be both enabled in ``decoder_list`` and have a constructor
    # here; either gap is reported with the same decoder-specific message
    # (the first check previously said "Classifier" by mistake).
    if decoder_type not in decoder_list or decoder_type not in constructors:
        raise NotImplementedError(
            "Unsupported Decoder: {}".format(decoder_type))

    return constructors[decoder_type](params, *args)
|