ICLR_FLAG / models /encoders /__init__.py
zaixizhang
renew
10efe81
raw
history blame contribute delete
916 Bytes
from .schnet import SchNetEncoder
from .tf import TransformerEncoder
from .gnn import GNN_graphpred, MLP
def get_encoder(config):
    """Instantiate the encoder selected by ``config.name``.

    Args:
        config: Object exposing a ``name`` attribute plus the attributes
            read for the selected encoder:

            * ``'schnet'``: ``hidden_channels``, ``num_filters``,
              ``num_interactions``, ``edge_channels``, ``cutoff``.
            * ``'tf'``: ``hidden_channels``, ``edge_channels``,
              ``key_channels``, ``num_heads``, ``num_interactions``,
              ``knn``, ``cutoff``.

    Returns:
        A ``SchNetEncoder`` when the name is ``'schnet'``, or a
        ``TransformerEncoder`` when the name is ``'tf'``.

    Raises:
        NotImplementedError: If ``config.name`` is neither ``'schnet'``
            nor ``'tf'``.
    """
    encoder_name = config.name

    if encoder_name == 'schnet':
        schnet_kwargs = dict(
            hidden_channels=config.hidden_channels,
            num_filters=config.num_filters,
            num_interactions=config.num_interactions,
            edge_channels=config.edge_channels,
            cutoff=config.cutoff,
        )
        return SchNetEncoder(**schnet_kwargs)

    if encoder_name == 'tf':
        tf_kwargs = dict(
            hidden_channels=config.hidden_channels,
            edge_channels=config.edge_channels,
            key_channels=config.key_channels,
            num_heads=config.num_heads,
            num_interactions=config.num_interactions,
            # The config field is called ``knn``; the encoder takes it as ``k``.
            k=config.knn,
            cutoff=config.cutoff,
        )
        return TransformerEncoder(**tf_kwargs)

    # No branch matched: report the unrecognised name to the caller.
    raise NotImplementedError('Unknown encoder: %s' % encoder_name)