Spaces:
Sleeping
Sleeping
| # from .transformer import TransformerModel | |
| from .rnn_decoder import DecoderRNN | |
| from .tree_decoder import TreeDecoder | |
| from .transformer import TransformerDecoder | |
# Names accepted for ``params.decoder_type``; get_decoder rejects anything else.
decoder_list = ["rnn_decoder", "tree_decoder", "transformer"]


def get_decoder(params, *args):
    """Build the decoder selected by ``params.decoder_type``.

    Args:
        params: Configuration object. Only its ``decoder_type`` attribute is
            read here; the whole object is forwarded to the chosen decoder's
            constructor.
        *args: Extra positional arguments forwarded unchanged to the chosen
            decoder's constructor.

    Returns:
        A ``TransformerDecoder``, ``DecoderRNN`` or ``TreeDecoder`` instance
        for ``"transformer"``, ``"rnn_decoder"`` or ``"tree_decoder"``
        respectively.

    Raises:
        NotImplementedError: If ``params.decoder_type`` is not a supported
            decoder name.
    """
    decoder_type = params.decoder_type
    # Validate up front so an unknown type fails with a clear message before
    # any decoder class is touched.
    if decoder_type not in decoder_list:
        raise NotImplementedError(
            "Unsupported Decoder: {}".format(decoder_type))
    if decoder_type == "transformer":
        return TransformerDecoder(params, *args)
    if decoder_type == "rnn_decoder":
        return DecoderRNN(params, *args)
    if decoder_type == "tree_decoder":
        return TreeDecoder(params, *args)
    # Reached only if decoder_list gains a name without a matching branch.
    raise NotImplementedError(
        "Unsupported Decoder: {}".format(decoder_type))