File size: 1,563 Bytes
d9df210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from mvp.models.spec_encoder import SpecEncMLP_BIN, SpecFormulaTransformer
from mvp.models.mol_encoder import MolEnc
from mvp.models.encoders import MLP
from mvp.models.contrastive import ContrastiveModel, MultiViewContrastive

def get_spec_encoder(spec_enc: str, args):
    """Instantiate the spectrum encoder registered under ``spec_enc``.

    Args:
        spec_enc: Registry key, either ``"MLP_BIN"`` (``SpecEncMLP_BIN``) or
            ``"Transformer_Formula"`` (``SpecFormulaTransformer``).
        args: Configuration object passed unchanged to the encoder constructor.

    Returns:
        The encoder instance built from ``args``.

    Raises:
        KeyError: If ``spec_enc`` is not one of the registered keys.
    """
    registry = {
        "MLP_BIN": SpecEncMLP_BIN,
        "Transformer_Formula": SpecFormulaTransformer,
    }
    encoder_cls = registry[spec_enc]
    return encoder_cls(args)

def get_mol_encoder(mol_enc: str, args):
    """Instantiate the molecule encoder registered under ``mol_enc``.

    Args:
        mol_enc: Registry key; only ``'GNN'`` (``MolEnc``) is available.
        args: Configuration object passed to the encoder constructor.

    Returns:
        The encoder instance, always constructed with ``in_dim=78``.

    Raises:
        KeyError: If ``mol_enc`` is not a registered key.
    """
    encoders = {'GNN': MolEnc}
    encoder_cls = encoders[mol_enc]
    # NOTE(review): 78 is presumably the per-node input feature size produced
    # by the molecular featurizer; confirm before changing it.
    return encoder_cls(args, in_dim=78)

def get_fp_pred_model(args):
    """Build the fingerprint-prediction MLP head.

    The network takes ``args.final_embedding_dim`` inputs. It uses
    ``hidden_dims=[args.fp_size]`` (presumably the output width — confirm
    against ``MLP``), a sigmoid final activation and dropout
    ``args.fp_dropout``.

    Args:
        args: Configuration object providing ``final_embedding_dim``,
            ``fp_size`` and ``fp_dropout``.

    Returns:
        The configured ``MLP`` instance.
    """
    embed_dim = args.final_embedding_dim
    return MLP(
        in_dim=embed_dim,
        hidden_dims=[args.fp_size],
        final_activation='sigmoid',
        dropout=args.fp_dropout,
    )

def get_fp_enc_model(args):
    """Build the fingerprint-encoding MLP.

    The network takes ``args.fp_size`` inputs and uses hidden widths
    ``[d, 2 * d, d]``, where ``d = args.final_embedding_dim``. It has no
    final activation and no dropout.

    Args:
        args: Configuration object providing ``fp_size`` and
            ``final_embedding_dim``.

    Returns:
        The configured ``MLP`` instance.
    """
    fp_dim = args.fp_size
    embed_dim = args.final_embedding_dim
    # Widen to twice the embedding size in the middle layer, then narrow back.
    layer_widths = [embed_dim, embed_dim * 2, embed_dim]
    return MLP(
        in_dim=fp_dim,
        hidden_dims=layer_widths,
        final_activation=None,
        dropout=0.0,
    )

def get_model(model: str,
              params):
    """Build a contrastive model by name, or restore it from a checkpoint.

    Args:
        model: Model identifier. ``'contrastive'`` selects ``ContrastiveModel``
            and ``"MultiviewContrastive"`` selects ``MultiViewContrastive``.
        params: Mapping of keyword arguments for the model constructor. If
            ``params.get('checkpoint_pth')`` is a non-empty path, the model is
            instead restored with ``load_from_checkpoint``. In that case
            ``params`` must also contain ``'log_only_loss_at_stages'`` and
            ``'df_test_path'``, which are forwarded to the loader.

    Returns:
        The freshly constructed or checkpoint-restored model instance.

    Raises:
        NotImplementedError: If ``model`` is not a known identifier. This
            subclasses ``Exception``, so existing handlers still catch it.
    """
    if model == 'contrastive':
        model_cls = ContrastiveModel
    elif model == "MultiviewContrastive":
        model_cls = MultiViewContrastive
    else:
        raise NotImplementedError(f"Model {model} not implemented.")

    # A missing key is treated the same as None / "" (no checkpoint), rather
    # than raising KeyError.
    checkpoint_pth = params.get('checkpoint_pth')
    if checkpoint_pth:
        # load_from_checkpoint is invoked on the class, so constructing an
        # instance from ``params`` first would only be discarded.
        loaded = model_cls.load_from_checkpoint(
            checkpoint_pth,
            log_only_loss_at_stages=params['log_only_loss_at_stages'],
            df_test_path=params['df_test_path']
        )
        print("Loaded Model from checkpoint")
        return loaded

    return model_cls(**params)