Spaces:
Sleeping
Sleeping
File size: 1,563 Bytes
d9df210 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
from mvp.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaTransformer
from mvp.models.mol_encoder import MolEnc
from mvp.models.encoders import MLP
from mvp.models.contrastive import ContrastiveModel, MultiViewContrastive
def get_spec_encoder(spec_enc: str, args):
    """Instantiate the spectrum encoder registered under ``spec_enc``.

    Args:
        spec_enc: Encoder name; ``"MLP_BIN"`` selects ``SpecEncMLP_BIN`` and
            ``"Transformer_Formula"`` selects ``SpecFormulaTransformer``.
        args: Configuration object passed unchanged to the encoder constructor.

    Returns:
        A new instance of the selected encoder class.

    Raises:
        KeyError: If ``spec_enc`` is not one of the names listed above.
    """
    registry = {
        "MLP_BIN": SpecEncMLP_BIN,
        "Transformer_Formula": SpecFormulaTransformer,
    }
    encoder_cls = registry[spec_enc]
    return encoder_cls(args)
def get_mol_encoder(mol_enc: str, args):
    """Instantiate the molecule encoder registered under ``mol_enc``.

    Args:
        mol_enc: Encoder name; only ``'GNN'`` (``MolEnc``) is registered.
        args: Configuration object passed as the first constructor argument.

    Returns:
        A new instance of the selected encoder class.

    Raises:
        KeyError: If ``mol_enc`` is not a registered name.
    """
    registry = {'GNN': MolEnc}
    encoder_cls = registry[mol_enc]
    # The input dimension is fixed at 78 here.
    # NOTE(review): presumably the per-node feature size -- confirm against MolEnc.
    return encoder_cls(args, in_dim=78)
def get_fp_pred_model(args):
    """Build the MLP that maps a final embedding to a fingerprint-sized output.

    Args:
        args: Configuration object providing ``final_embedding_dim``,
            ``fp_size`` and ``fp_dropout``.

    Returns:
        An ``MLP`` whose input width is ``args.final_embedding_dim``, with a
        single hidden-dims entry of ``args.fp_size``, a ``'sigmoid'`` final
        activation and dropout ``args.fp_dropout``.
    """
    embedding_dim = args.final_embedding_dim
    layer_widths = [args.fp_size]
    dropout_rate = args.fp_dropout
    return MLP(
        in_dim=embedding_dim,
        hidden_dims=layer_widths,
        final_activation='sigmoid',
        dropout=dropout_rate,
    )
def get_fp_enc_model(args):
    """Build the MLP that maps a fingerprint-sized input to the embedding width.

    Args:
        args: Configuration object providing ``fp_size`` and
            ``final_embedding_dim``.

    Returns:
        An ``MLP`` whose input width is ``args.fp_size`` and whose hidden dims
        are ``[d, 2 * d, d]`` for ``d = args.final_embedding_dim``, with no
        final activation and no dropout.
    """
    fingerprint_dim = args.fp_size
    embedding_dim = args.final_embedding_dim
    # Widths expand to twice the embedding size and contract back to it.
    layer_widths = [embedding_dim, embedding_dim * 2, embedding_dim]
    return MLP(
        in_dim=fingerprint_dim,
        hidden_dims=layer_widths,
        final_activation=None,
        dropout=0.0,
    )
def get_model(model: str,
              params):
    """Build the requested model, optionally restoring it from a checkpoint.

    Args:
        model: Model name; ``'contrastive'`` selects ``ContrastiveModel`` and
            ``"MultiviewContrastive"`` selects ``MultiViewContrastive``.
        params: Mapping of constructor keyword arguments. When it carries a
            non-empty ``'checkpoint_pth'`` it must also carry
            ``'log_only_loss_at_stages'`` and ``'df_test_path'``, which are
            forwarded to ``load_from_checkpoint``.

    Returns:
        The model loaded from ``params['checkpoint_pth']`` when that entry is
        set to something other than ``None`` or ``""``; otherwise a fresh
        instance built as ``model_cls(**params)``.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
        KeyError: If a checkpoint is requested but ``params`` lacks
            ``'log_only_loss_at_stages'`` or ``'df_test_path'``.
    """
    # Resolve the class first so that an unknown name fails before any work is
    # done, and so the ``model`` argument is never rebound to a non-string.
    if model == 'contrastive':
        model_cls = ContrastiveModel
    elif model == "MultiviewContrastive":
        model_cls = MultiViewContrastive
    else:
        # ValueError is still caught by callers handling plain ``Exception``.
        raise ValueError(f"Model {model} not implemented.")

    # A missing key is treated the same as "no checkpoint requested".
    checkpoint_pth = params.get('checkpoint_pth')
    if checkpoint_pth is None or checkpoint_pth == "":
        return model_cls(**params)

    # Load directly from the checkpoint instead of first constructing a full
    # model from ``params`` only to throw it away.
    # NOTE(review): ``load_from_checkpoint`` looks like the PyTorch Lightning
    # classmethod, where extra kwargs override saved hyperparameters -- confirm.
    restored = model_cls.load_from_checkpoint(
        checkpoint_pth,
        log_only_loss_at_stages=params['log_only_loss_at_stages'],
        df_test_path=params['df_test_path'],
    )
    print("Loaded Model from checkpoint")
    return restored