import datetime
import os
import random
import sys

# import sys; sys.path.append("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark")
# sys.path.append('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom')
# sys.path.append(os.getcwd())
# NOTE(review): SECURITY - a Weights & Biases API key is hard-coded here and
# is therefore leaked to anyone who can read this file. It is kept only to
# preserve existing behavior; rotate the key and supply it through the
# environment (or `wandb login`) instead of committing it.
os.environ["WANDB_API_KEY"] = "8d6e97d8cea5e94d8585723fd27477ca9436566a"  # wh
import warnings

warnings.filterwarnings("ignore")
import argparse
import pandas as pd
import torch
from pytorch_lightning.trainer import Trainer
import pytorch_lightning as pl
import pytorch_lightning.callbacks as plc
from model_interface import MInterface
from data_interface import DInterface
import pytorch_lightning.loggers as plog
from src.utils.logger import SetupCallback
from pytorch_lightning.callbacks import EarlyStopping
from src.utils.utils import process_args
import math
import wandb


def create_parser():
    """Build the command-line parser and return the processed arguments.

    The parser is handed to ``process_args`` (project helper) together with
    the config directory ``'../../tasks/configs'``; whatever that helper
    returns is printed and returned to the caller.

    Returns:
        The object returned by ``process_args``. The rest of this script
        reads attributes such as ``res_dir``, ``ex_name``, ``offline``,
        ``batch_size`` and ``epoch`` from it.
    """
    parser = argparse.ArgumentParser()

    # Set-up parameters
    parser.add_argument('--res_dir', default='./results', type=str)
    # NOTE(review): `--ex_name` is disabled here although `args.ex_name` is
    # read later; presumably `process_args` injects it from the config -
    # confirm.
    # parser.add_argument('--ex_name', default='debug', type=str)
    parser.add_argument('--check_val_every_n_epoch', default=1, type=int)
    parser.add_argument('--offline', default=1, type=int)
    parser.add_argument('--seed', default=2024, type=int)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--pretrain_batch_size', default=4, type=int)
    parser.add_argument('--num_workers', default=4, type=int)
    parser.add_argument('--seq_len', default=1022, type=int)
    parser.add_argument('--gpus_per_node', default=1, type=int)
    parser.add_argument('--num_nodes', default=1, type=int)

    # Training parameters
    parser.add_argument('--epoch', default=50, type=int, help='end epoch')
    parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate')
    parser.add_argument('--lr_scheduler', default='cosine')

    # Model parameters
    parser.add_argument('--sequence_only', default=0, type=int)
    parser.add_argument('--finetune_type', default='adapter', type=str,
                        choices=['adapter', 'peft'])
    parser.add_argument('--peft_type', default='adalora', type=str,
                        choices=['lora', 'adalora', 'ia3', 'dora', 'freeze'])
    parser.add_argument('--pretrain_model_name', default='esm2_650m', type=str,
                        choices=[
                            'esm2_650m', 'esm3_1.4b', 'esmc_600m', 'procyon',
                            'prollama', 'progen2', 'prostt5', 'protgpt2',
                            'protrek', 'saport', 'gearnet', 'prost',
                            'prosst2048', 'venusplm', 'prott5', 'dplm',
                            'ontoprotein', 'ankh_base', 'pglm', 'esm2_35m',
                            'esm2_150m', 'esm2_3b', 'esm2_15b', 'protrek_35m',
                            'saport_35m', 'saport_1.3b', 'dplm_150m',
                            'dplm_3b', 'pglm-3b',
                        ])
    parser.add_argument("--config_name", type=str, default='fitness_prediction',
                        help="Name of the Hydra config to use")
    parser.add_argument("--metric", type=str, default='val_loss',
                        help="metric for early stop")
    parser.add_argument("--direction", type=str, default='min',
                        help="metric direction")
    parser.add_argument("--enable_es", type=int, default=1,
                        help="enable early stopping")
    parser.add_argument("--feature_extraction", type=int, default=0,
                        help="perform feature extraction(paper used only)")
    parser.add_argument("--feature_save_dir", type=str, default=None,
                        help="feature saved dir(paper used only)")

    args = process_args(parser, config_path='../../tasks/configs')
    print(args)
    return args


def automl_setup(args, logger):
    """Re-key the run directory and logger on the active wandb run id.

    Requires an active wandb run (``wandb.run`` must not be ``None``).

    Args:
        args: Argument namespace; ``res_dir`` is extended with the current
            ``ex_name`` and ``ex_name`` is then replaced by the wandb run id.
        logger: Logger whose private ``_save_dir``, ``_name`` and ``_id``
            fields are overwritten to match the wandb run.

    Returns:
        The same ``(args, logger)`` pair, mutated in place.
    """
    args.res_dir = os.path.join(args.res_dir, args.ex_name)
    print(wandb.run)
    run_id = wandb.run.id
    args.ex_name = run_id
    wandb.run.name = run_id
    logger._save_dir = str(args.res_dir)
    os.makedirs(logger._save_dir, exist_ok=True)
    logger._name = wandb.run.name
    logger._id = run_id
    return args, logger


def _build_callbacks(args):
    """Return the Trainer callbacks; the checkpoint callback is always first.

    NOTE(review): the original callback set-up was missing from this file
    (``callbacks`` was used without being defined). This is a minimal
    reconstruction driven by the ``--metric``, ``--direction``,
    ``--enable_es``, ``--lr_scheduler`` and ``--check_val_every_n_epoch``
    options. The imported ``SetupCallback`` is not used because its
    signature is not visible here. Confirm against the intended set-up.
    """
    ckpt_dir = os.path.join(str(args.res_dir), str(args.ex_name), "checkpoints")
    callbacks = [
        plc.ModelCheckpoint(
            monitor=args.metric,
            mode=args.direction,
            save_top_k=1,
            dirpath=ckpt_dir,
            every_n_epochs=args.check_val_every_n_epoch,
        )
    ]
    if args.enable_es:
        # Library-default patience is used; tune if needed.
        callbacks.append(EarlyStopping(monitor=args.metric, mode=args.direction))
    if args.lr_scheduler:
        callbacks.append(plc.LearningRateMonitor(logging_interval=None))
    return callbacks


def main():
    """Train a model, then reload the best checkpoint and run the test set.

    Steps: parse arguments, initialise wandb and its Lightning logger, seed
    everything, build the data and model interfaces, fit, load the best
    checkpoint's weights and call ``trainer.test``.
    """
    args = create_parser()
    if args.offline:
        os.environ["WANDB_MODE"] = "offline"

    run_dir = str(os.path.join(args.res_dir, args.ex_name))
    wandb.init(project='protein_benchmark', entity='biomap_ai', dir=run_dir)
    logger = plog.WandbLogger(
        project='protein_benchmark',
        name=args.ex_name,
        save_dir=run_dir,
        dir=run_dir,
        offline=args.offline,
        entity="biomap_ai",
    )

    # ================ for wandb sweep ==================
    args, logger = automl_setup(args, logger)
    # ====================================================

    # generated a random seed
    # args.seed = random.randint(1, 9999)
    # print(f"Generated random seed: {args.seed}")
    # NOTE(review): this overrides whatever was passed via --seed.
    args.seed = 42
    pl.seed_everything(args.seed)

    data_module = DInterface(**vars(args))
    data_module.data_setup()

    # Guard against a zero GPU count, which previously raised
    # ZeroDivisionError on machines with no visible CUDA device.
    gpu_count = max(torch.cuda.device_count(), 1)
    steps_per_epoch = math.ceil(len(data_module.train_set) / args.batch_size / gpu_count)
    args.lr_decay_steps = steps_per_epoch * args.epoch

    model = MInterface(**vars(args))
    data_module.MInterface = model

    # ============================
    # 3. Train
    # ============================
    # NOTE(review): the Trainer construction and fit call were missing from
    # this file (`trainer` was used without being defined). The DeepSpeed
    # strategy is inferred from the checkpoint layout read below
    # (<best_model_path>/checkpoint/mp_rank_00_model_states.pt with a
    # 'module' entry); precision and other options are left at library
    # defaults. Confirm against the intended configuration.
    callbacks = _build_callbacks(args)
    trainer = Trainer(
        accelerator='gpu',
        devices=-1,  # all visible GPUs, matching the gpu_count used above
        num_nodes=args.num_nodes,
        max_epochs=args.epoch,
        strategy='deepspeed_stage_2',
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        callbacks=callbacks,
        logger=logger,
    )
    trainer.fit(model, datamodule=data_module)

    # ============================
    # 4. Evaluate the best model
    # ============================
    checkpoint_callback = callbacks[0]
    print(f"Best model path: {checkpoint_callback.best_model_path}")

    # Load the best model
    model_state_path = os.path.join(
        checkpoint_callback.best_model_path,
        "checkpoint",
        "mp_rank_00_model_states.pt",
    )
    state = torch.load(model_state_path, map_location="cuda:0")
    model.load_state_dict(state['module'])

    # Run the test
    results = trainer.test(model, datamodule=data_module)

    # Print the test results
    print(f"Test Results: {results}")


if __name__ == "__main__":
    main()