File size: 837 Bytes
383bfb8
 
 
 
 
6f74e93
 
383bfb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# from .transformer import TransformerModel
from .rnn_decoder import DecoderRNN
from .tree_decoder import TreeDecoder
from .transformer import TransformerDecoder

# Names accepted for ``params.decoder_type`` by ``get_decoder``.
decoder_list = ["rnn_decoder", "tree_decoder", "transformer"]


def get_decoder(params, *args):
    """Construct the decoder selected by ``params.decoder_type``.

    Args:
        params: Configuration object. Only its ``decoder_type`` attribute is
            read here; the whole object is forwarded to the decoder's
            constructor.
        *args: Extra positional arguments forwarded unchanged to the decoder's
            constructor.

    Returns:
        A ``TransformerDecoder`` for ``"transformer"``, a ``DecoderRNN`` for
        ``"rnn_decoder"``, or a ``TreeDecoder`` for ``"tree_decoder"``.

    Raises:
        NotImplementedError: If ``params.decoder_type`` is not listed in
            ``decoder_list``.
    """
    decoder_type = params.decoder_type
    if decoder_type not in decoder_list:
        raise NotImplementedError(
            "Unsupported Decoder: {}".format(decoder_type))

    if decoder_type == "transformer":
        return TransformerDecoder(params, *args)
    if decoder_type == "rnn_decoder":
        return DecoderRNN(params, *args)
    if decoder_type == "tree_decoder":
        return TreeDecoder(params, *args)
    # Only reachable if a name is added to decoder_list without a branch here.
    raise NotImplementedError("Unsupported Decoder: {}".format(decoder_type))