# NOTE: non-code page residue (Spaces status, file size, commit hashes, line-number
# gutter) removed from the top of this scraped source file.
import argparse
import datetime
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from rdkit import RDLogger
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from flare.data.data_module import ContrastiveDataModule
from flare.definitions import TEST_RESULTS_DIR
import yaml
from flare.data.datasets import ContrastiveDataset
from functools import partial
from flare.utils.data import get_ms_dataset, get_spec_featurizer, get_mol_featurizer
from flare.utils.models import get_model
# Silence RDKit below CRITICAL: molecule featurization emits noisy parse warnings.
_rdkit_logger = RDLogger.logger()
_rdkit_logger.setLevel(RDLogger.CRITICAL)

# CLI: the only argument is the path to the YAML hyper-parameter file.
parser = argparse.ArgumentParser()
parser.add_argument("--param_pth", type=str, default="params_formSpec.yaml")
def main(params):
    """Train a contrastive spectra-molecule model.

    Args:
        params: dict loaded from the YAML config; keys read here include
            'seed', 'debug', 'spectra_view', 'molecule_view', 'spec_enc',
            'split_pth', 'batch_size', 'num_workers', 'model', 'no_wandb',
            'experiment_dir', 'run_name', 'project_name',
            'early_stopping_patience', and the Trainer-related keys below.

    Side effects: seeds RNGs, writes checkpoints under
    params['experiment_dir'], optionally logs to Weights & Biases, and runs
    validation + training.
    """
    # Seed all RNGs (python, numpy, torch) for reproducibility.
    pl.seed_everything(params['seed'])

    # Debug mode: point at a small sample dataset and drop candidate/split files.
    if params['debug']:
        params['dataset_pth'] = "/data/yzhouc01/MVP/data/sample/data.tsv"
        params['candidates_pth'] = None
        params['split_pth'] = None

    # Build view-specific featurizers, then the paired MS dataset.
    spec_featurizer = get_spec_featurizer(params['spectra_view'], params)
    mol_featurizer = get_mol_featurizer(params['molecule_view'], params)
    dataset = get_ms_dataset(
        params['spectra_view'], params['molecule_view'],
        spec_featurizer, mol_featurizer, params,
    )

    # Data module: collate_fn is specialized for the encoder/view combination.
    collate_fn = partial(
        ContrastiveDataset.collate_fn,
        spec_enc=params['spec_enc'],
        spectra_view=params['spectra_view'],
    )
    data_module = ContrastiveDataModule(
        dataset=dataset,
        collate_fn=collate_fn,
        split_pth=params['split_pth'],
        batch_size=params['batch_size'],
        num_workers=params['num_workers'],
    )

    model = get_model(params['model'], params)

    # Logger: W&B unless explicitly disabled.
    if params['no_wandb']:
        logger = None
    else:
        # NOTE(review): the original passed save_dir, dir, AND log_dir all set to
        # experiment_dir. `log_dir` is not a WandbLogger parameter — it is forwarded
        # to wandb.init, which rejects unknown kwargs — and `dir` duplicates
        # save_dir, so only save_dir is kept. Confirm against the pinned
        # pytorch_lightning/wandb versions.
        logger = pl.loggers.WandbLogger(
            save_dir=params['experiment_dir'],
            name=params['run_name'],
            project=params['project_name'],
            log_model=False,
            config=model.hparams,
        )

    # Callbacks: one top-k checkpoint per monitored metric (no last-epoch copy),
    # plus optional early stopping where the monitor requests it.
    callbacks = [pl.callbacks.ModelCheckpoint(save_last=False)]
    for monitor in model.get_checkpoint_monitors():
        monitor_name = monitor['monitor']
        callbacks.append(pl.callbacks.ModelCheckpoint(
            monitor=monitor_name,
            save_top_k=1,
            mode=monitor['mode'],
            dirpath=params['experiment_dir'],
            # e.g. "epoch=3-val_loss=0.42.ckpt"
            filename=f'{{epoch}}-{{{monitor_name}:.2f}}',
            auto_insert_metric_name=True,
        ))
        if monitor.get('early_stopping', False):
            callbacks.append(EarlyStopping(
                monitor=monitor_name,
                mode=monitor['mode'],
                verbose=True,
                patience=params['early_stopping_patience'],
            ))

    trainer = Trainer(
        accelerator=params['accelerator'],
        devices=params['devices'],
        max_epochs=params['max_epochs'],
        logger=logger,
        log_every_n_steps=params['log_every_n_steps'],
        val_check_interval=params['val_check_interval'],
        callbacks=callbacks,
        default_root_dir=params['experiment_dir'],
    )

    # Materialize the data module so a pre-training validation pass can run.
    data_module.prepare_data()
    data_module.setup()

    # Baseline validation before any optimization, then train.
    trainer.validate(model, datamodule=data_module)
    trainer.fit(model, datamodule=data_module)
if __name__ == "__main__":
args = parser.parse_args([] if "__file__" not in globals() else None)
# Get current time
now = datetime.datetime.now()
now_formatted = now.strftime("%Y%m%d")
# Load
with open(args.param_pth) as f:
params = yaml.load(f, Loader=yaml.FullLoader)
experiment_dir = str(TEST_RESULTS_DIR / f"{now_formatted}_{params['run_name']}")
params['experiment_dir'] = experiment_dir
if not params['df_test_path']:
params['df_test_path'] = os.path.join(experiment_dir, "result.pkl")
main(params)