pgps-demo / model /encoder /__init__.py
asdfasdfdsafdsa's picture
Fix missing config attributes and imports
6f74e93 verified
raw
history blame contribute delete
654 Bytes
from .lstm import LSTM
from .gru import GRU
from .transformer import TransformerEncoder
# Encoder names accepted by get_encoder (compared against params.encoder_type).
encoder_list = ['lstm', 'gru', 'transformer']


def get_encoder(params, *args):
    """Build the encoder selected by ``params.encoder_type``.

    Args:
        params: Config object. ``params.encoder_type`` picks the encoder, and
            the whole object is forwarded to the encoder constructor.
        *args: Extra positional arguments forwarded unchanged to the encoder
            constructor.

    Returns:
        An ``LSTM`` instance for ``"lstm"`` or a ``GRU`` instance for ``"gru"``.

    Raises:
        NotImplementedError: If ``params.encoder_type`` is not in
            ``encoder_list``, or names an encoder that has no constructor
            wired up below (currently ``"transformer"``).
    """
    encoder_type = params.encoder_type
    if encoder_type not in encoder_list:
        raise NotImplementedError("Unsupported Encoder: {}".format(encoder_type))

    if encoder_type == "lstm":
        return LSTM(params, *args)
    if encoder_type == "gru":
        return GRU(params, *args)

    # "transformer" is listed in encoder_list, but it previously hit a bare
    # `pass` and then crashed with UnboundLocalError at the return. Fail
    # explicitly instead of returning an unbound name.
    # TODO(review): construct TransformerEncoder here once its constructor is
    # confirmed to follow the same (params, *args) convention as LSTM/GRU.
    raise NotImplementedError(
        "Encoder not implemented in get_encoder: {}".format(encoder_type))