"""Factory helpers that map configuration names to encoder and model objects."""

from flare.models.spec_encoder import (
    SpecEncMLP_BIN,
    SpecFormulaEncMLP,
    SpecFormulaTransformer,
    SpecFormula_mz_Encoder,
    SpecMzIntTokenTransformer,
)
from flare.models.mol_encoder import MolEnc
from flare.models.encoders import MLP
from flare.models.contrastive import (
    ContrastiveModel,
    CrossAttenContrastive,
    FilipContrastive,
    FilipGlobalContrastive,
)

# Name -> class registries, built once at import time instead of on every call.
_SPEC_ENCODERS = {
    "MLP_BIN": SpecEncMLP_BIN,
    "MLP_Formula": SpecFormulaEncMLP,
    "Transformer_Formula": SpecFormulaTransformer,
    "Formula_BinnedSpec": SpecFormula_mz_Encoder,
    "Transformer_MzInt": SpecMzIntTokenTransformer,
}

_MOL_ENCODERS = {
    'GNN': MolEnc,
}

_MODEL_CLASSES = {
    'contrastive': ContrastiveModel,
    'crossAttenContrastive': CrossAttenContrastive,
    "filipContrastive": FilipContrastive,
    "filipGlobalContrastive": FilipGlobalContrastive,
}


def get_spec_encoder(spec_enc: str, args):
    """Instantiate the spectrum encoder registered under *spec_enc*.

    Args:
        spec_enc: One of the keys of ``_SPEC_ENCODERS`` ("MLP_BIN",
            "MLP_Formula", "Transformer_Formula", "Formula_BinnedSpec",
            "Transformer_MzInt").
        args: Configuration object passed unchanged as the sole positional
            argument to the encoder constructor.

    Returns:
        A new instance of the selected encoder class.

    Raises:
        KeyError: If *spec_enc* is not a registered encoder name.
    """
    encoder_cls = _SPEC_ENCODERS[spec_enc]
    return encoder_cls(args)


def get_mol_encoder(mol_enc: str, args, in_dim: int = 78):
    """Instantiate the molecule encoder registered under *mol_enc*.

    Args:
        mol_enc: Encoder name; currently only 'GNN' is registered.
        args: Configuration object passed as the first positional argument
            to the encoder constructor.
        in_dim: Forwarded to the encoder constructor as ``in_dim``.
            Defaults to 78, the value that was previously hard-coded.

    Returns:
        A new instance of the selected encoder class.

    Raises:
        KeyError: If *mol_enc* is not a registered encoder name.
    """
    encoder_cls = _MOL_ENCODERS[mol_enc]
    return encoder_cls(args, in_dim=in_dim)


def get_fp_pred_model(args):
    """Build the MLP head mapping an embedding to an ``fp_size`` output.

    Reads ``args.final_embedding_dim`` (input width), ``args.fp_size``
    (single hidden/output layer width) and ``args.fp_dropout``. The final
    activation is 'sigmoid'.
    NOTE(review): "fp" presumably stands for fingerprint - confirm.
    """
    return MLP(
        in_dim=args.final_embedding_dim,
        hidden_dims=[args.fp_size],
        final_activation='sigmoid',
        dropout=args.fp_dropout,
    )


def get_fp_enc_model(args):
    """Build the MLP mapping an ``fp_size`` input into the embedding space.

    Reads ``args.fp_size`` (input width) and ``args.final_embedding_dim``;
    the layer widths are ``[d, 2 * d, d]`` with ``d = final_embedding_dim``.
    No final activation and no dropout (0.0) are used.
    """
    embed_dim = args.final_embedding_dim
    return MLP(
        in_dim=args.fp_size,
        hidden_dims=[embed_dim, embed_dim * 2, embed_dim],
        final_activation=None,
        dropout=0.0,
    )


def get_model(model: str, params: dict):
    """Create the model named *model*, optionally restoring it from a checkpoint.

    Args:
        model: One of 'contrastive', 'crossAttenContrastive',
            'filipContrastive', 'filipGlobalContrastive'.
        params: Keyword arguments forwarded to the model constructor as
            ``**params``. If ``params['checkpoint_pth']`` is present and is
            neither ``None`` nor ``""``, the model is instead loaded via the
            class's ``load_from_checkpoint``; in that case
            ``params['log_only_loss_at_stages']`` and
            ``params['df_test_path']`` must also be present.

    Returns:
        The constructed (or checkpoint-restored) model instance.

    Raises:
        ValueError: If *model* is not a known model name. ``ValueError`` is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers keep working.
        KeyError: If a checkpoint is requested but one of the required
            override keys is missing from *params*.
    """
    try:
        model_cls = _MODEL_CLASSES[model]
    except (KeyError, TypeError):
        raise ValueError(f"Model {model} not implemented.") from None

    # NOTE(review): the instance is built even when a checkpoint will replace
    # it right below. This mirrors the original order of operations (the
    # constructor may validate params or consume RNG state); consider skipping
    # it once confirmed unnecessary.
    instance = model_cls(**params)

    # A missing key is treated the same as "no checkpoint requested".
    checkpoint_pth = params.get('checkpoint_pth')
    if checkpoint_pth is not None and checkpoint_pth != "":
        # Load the model from the checkpoint instead of using the fresh one.
        instance = model_cls.load_from_checkpoint(
            checkpoint_pth,
            log_only_loss_at_stages=params['log_only_loss_at_stages'],
            df_test_path=params['df_test_path'],
        )
        print("Loaded Model from checkpoint")

    return instance