import datetime
import os
import random
import sys
# SECURITY NOTE(review): a wandb API key is hard-coded and committed below.
# Credentials belong in the environment or a secrets manager, not in source
# control — rotate this key and read it via os.environ.get("WANDB_API_KEY").
os.environ["WANDB_API_KEY"] = "8d6e97d8cea5e94d8585723fd27477ca9436566a" # wh
import warnings
# Silence all third-party warnings globally; intentional for clean benchmark logs.
warnings.filterwarnings("ignore")
import argparse
import pandas as pd
import torch
from pytorch_lightning.trainer import Trainer
import pytorch_lightning as pl
import pytorch_lightning.callbacks as plc
from model_interface import MInterface
from data_interface import DInterface
import pytorch_lightning.loggers as plog
from src.utils.logger import SetupCallback
from pytorch_lightning.callbacks import EarlyStopping
from src.utils.utils import process_args
import math
import wandb
def create_parser():
    """Build the CLI parser and resolve the final run configuration.

    Delegates to ``process_args`` to merge CLI flags with the Hydra task
    config under ``../../tasks/configs``, prints the resolved namespace,
    and returns it.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # --- experiment / run setup ---
    add('--res_dir', default='./results', type=str)
    add('--check_val_every_n_epoch', default=1, type=int)
    add('--offline', default=1, type=int)
    add('--seed', default=2024, type=int)
    add('--batch_size', default=32, type=int)
    add('--pretrain_batch_size', default=4, type=int)
    add('--num_workers', default=4, type=int)
    add('--seq_len', default=1022, type=int)
    add('--gpus_per_node', default=1, type=int)
    add('--num_nodes', default=1, type=int)

    # --- optimization ---
    add('--epoch', default=50, type=int, help='end epoch')
    add('--lr', default=1e-4, type=float, help='Learning rate')
    add('--lr_scheduler', default='cosine')

    # --- model / fine-tuning ---
    add('--sequence_only', default=0, type=int)
    add('--finetune_type', default='adapter', type=str, choices=['adapter', 'peft'])
    add('--peft_type', default='adalora', type=str, choices=['lora', 'adalora', 'ia3', 'dora', 'freeze'])
    add('--pretrain_model_name', default='esm2_650m', type=str, choices=[
        'esm2_650m', 'esm3_1.4b', 'esmc_600m', 'procyon', 'prollama', 'progen2', 'prostt5',
        'protgpt2', 'protrek', 'saport', 'gearnet', 'prost', 'prosst2048', 'venusplm',
        'prott5', 'dplm', 'ontoprotein', 'ankh_base', 'pglm', 'esm2_35m', 'esm2_150m',
        'esm2_3b', 'esm2_15b', 'protrek_35m', 'saport_35m', 'saport_1.3b', 'dplm_150m', 'dplm_3b', 'pglm-3b'
    ])
    add("--config_name", type=str, default='fitness_prediction', help="Name of the Hydra config to use")
    add("--metric", type=str, default='val_loss', help="metric for early stop")
    add("--direction", type=str, default='min', help="metric direction")
    add("--enable_es", type=int, default=1, help="enable early stopping")
    add("--feature_extraction", type=int, default=0, help="perform feature extraction(paper used only)")
    add("--feature_save_dir", type=str, default=None, help="feature saved dir(paper used only)")

    args = process_args(parser, config_path='../../tasks/configs')
    print(args)
    return args
def automl_setup(args, logger):
    """Rename the run after its wandb id so sweep trials get unique dirs/names.

    Nests ``res_dir`` under the original experiment name, then rewires both
    ``args`` and the WandbLogger's private fields to the active wandb run id.
    Assumes ``wandb.init`` has already been called.
    """
    args.res_dir = os.path.join(args.res_dir, args.ex_name)
    print(wandb.run)

    run_id = wandb.run.id
    args.ex_name = run_id
    wandb.run.name = run_id

    save_dir = str(args.res_dir)
    os.makedirs(save_dir, exist_ok=True)
    # NOTE: reaching into the logger's private attributes mirrors how the
    # project re-targets an already-constructed WandbLogger for sweeps.
    logger._save_dir = save_dir
    logger._name = run_id
    logger._id = run_id
    return args, logger
def main():
    """Entry point: parse config, set up wandb logging, train, then test the
    best checkpoint.

    Fixes vs. the previous revision:
      * ``callbacks`` and ``trainer`` were referenced but never defined,
        so the script raised ``NameError`` before training — both are now
        constructed explicitly (callbacks[0] must be the ModelCheckpoint,
        since its ``best_model_path`` is consumed below).
      * ``torch.cuda.device_count()`` can be 0 on CPU-only hosts, which made
        the steps-per-epoch division crash — clamped to at least 1.
      * ``args.seed`` was hard-coded to 42, silently ignoring the ``--seed``
        CLI flag — the flag is now honored.
    """
    args = create_parser()
    if args.offline:
        os.environ["WANDB_MODE"] = "offline"

    run_dir = str(os.path.join(args.res_dir, args.ex_name))
    wandb.init(project='protein_benchmark', entity='biomap_ai', dir=run_dir)
    logger = plog.WandbLogger(
        project='protein_benchmark',
        name=args.ex_name,
        save_dir=run_dir,
        dir=run_dir,
        offline=args.offline,
        entity="biomap_ai")

    # ================ for wandb sweep ==================
    args, logger = automl_setup(args, logger)
    # ====================================================

    pl.seed_everything(args.seed)

    data_module = DInterface(**vars(args))
    data_module.data_setup()

    # Clamp to 1 so steps-per-epoch math survives CPU-only environments.
    gpu_count = max(torch.cuda.device_count(), 1)
    steps_per_epoch = math.ceil(len(data_module.train_set) / args.batch_size / gpu_count)
    args.lr_decay_steps = steps_per_epoch * args.epoch

    model = MInterface(**vars(args))
    data_module.MInterface = model

    # ============================
    # Callbacks and trainer (previously missing — caused NameError).
    # ============================
    callbacks = [
        # Must stay at index 0: best_model_path is read from callbacks[0] below.
        plc.ModelCheckpoint(
            monitor=args.metric,
            mode=args.direction,
            save_top_k=1,
            dirpath=os.path.join(args.res_dir, 'checkpoints'),
        ),
    ]
    if args.enable_es:
        callbacks.append(EarlyStopping(monitor=args.metric, mode=args.direction))

    trainer = Trainer(
        accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        devices=args.gpus_per_node,
        num_nodes=args.num_nodes,
        max_epochs=args.epoch,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        logger=logger,
        callbacks=callbacks,
        # The checkpoint layout loaded below (checkpoint/mp_rank_00_model_states.pt)
        # is DeepSpeed's — TODO(review): confirm the stage used by the cluster.
        strategy='deepspeed_stage_2',
    )
    trainer.fit(model, datamodule=data_module)

    # ============================
    # Evaluate the best model
    # ============================
    checkpoint_callback = callbacks[0]
    print(f"Best model path: {checkpoint_callback.best_model_path}")
    # Load the best weights from the DeepSpeed checkpoint directory.
    model_state_path = os.path.join(checkpoint_callback.best_model_path, "checkpoint", "mp_rank_00_model_states.pt")
    state = torch.load(model_state_path, map_location="cuda:0")
    model.load_state_dict(state['module'])
    # Run the test loop on the reloaded model.
    results = trainer.test(model, datamodule=data_module)
    print(f"Test Results: {results}")
if __name__ == "__main__":
main()