File size: 5,781 Bytes
9627ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import datetime
import os
import random
import sys
# import sys; sys.path.append("/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark")
# sys.path.append('/nfs_beijing/kubeflow-user/zhangyang_2024/workspace/protein_benchmark/model_zoom')
# sys.path.append(os.getcwd())
# os.environ["WANDB_API_KEY"] = "ddb1831ecbd2bf95c3323502ae17df6e1df44ec0" # gzy
# SECURITY(review): a W&B API key is committed in source control. It should be
# rotated and supplied through the environment instead of living in this file.
# setdefault() keeps the existing fallback working while no longer clobbering a
# WANDB_API_KEY that the user has already exported.
os.environ.setdefault("WANDB_API_KEY", "8d6e97d8cea5e94d8585723fd27477ca9436566a")  # wh
import warnings
warnings.filterwarnings("ignore")
import argparse
import pandas as pd
import torch
from pytorch_lightning.trainer import Trainer
import pytorch_lightning as pl
import pytorch_lightning.callbacks as plc
from model_interface import MInterface
from data_interface import DInterface
import pytorch_lightning.loggers as plog
from src.utils.logger import SetupCallback
from pytorch_lightning.callbacks import EarlyStopping
from src.utils.utils import process_args
import math
import wandb


def create_parser():
    """Declare the command-line options and return the processed arguments.

    The options are registered on an ``argparse.ArgumentParser`` in a fixed
    order and the parser is then handed to ``process_args`` together with the
    config directory ``'../../tasks/configs'``. Whatever ``process_args``
    returns is printed and returned to the caller.

    Returns:
        The object produced by ``process_args`` (printed before returning).
    """
    # Backbones accepted by --pretrain_model_name.
    backbone_names = [
        'esm2_650m', 'esm3_1.4b', 'esmc_600m', 'procyon', 'prollama', 'progen2', 'prostt5',
        'protgpt2', 'protrek', 'saport', 'gearnet', 'prost', 'prosst2048', 'venusplm',
        'prott5', 'dplm', 'ontoprotein', 'ankh_base', 'pglm', 'esm2_35m', 'esm2_150m',
        'esm2_3b', 'esm2_15b', 'protrek_35m', 'saport_35m', 'saport_1.3b', 'dplm_150m', 'dplm_3b', 'pglm-3b',
    ]

    # (flag, add_argument keyword arguments), listed in registration order.
    option_table = [
        # --- set-up ---
        # NOTE(review): no '--ex_name' option is registered here although later
        # code reads args.ex_name; presumably process_args supplies it - confirm.
        ('--res_dir', dict(default='./results', type=str)),
        ('--check_val_every_n_epoch', dict(default=1, type=int)),
        ('--offline', dict(default=1, type=int)),
        ('--seed', dict(default=2024, type=int)),
        ('--batch_size', dict(default=32, type=int)),
        ('--pretrain_batch_size', dict(default=4, type=int)),
        ('--num_workers', dict(default=4, type=int)),
        ('--seq_len', dict(default=1022, type=int)),
        ('--gpus_per_node', dict(default=1, type=int)),
        ('--num_nodes', dict(default=1, type=int)),
        # --- training ---
        ('--epoch', dict(default=50, type=int, help='end epoch')),
        ('--lr', dict(default=1e-4, type=float, help='Learning rate')),
        ('--lr_scheduler', dict(default='cosine')),
        # --- model ---
        ('--sequence_only', dict(default=0, type=int)),
        ('--finetune_type', dict(default='adapter', type=str, choices=['adapter', 'peft'])),
        ('--peft_type', dict(default='adalora', type=str,
                             choices=['lora', 'adalora', 'ia3', 'dora', 'freeze'])),
        ('--pretrain_model_name', dict(default='esm2_650m', type=str, choices=backbone_names)),
        ("--config_name", dict(type=str, default='fitness_prediction',
                               help="Name of the Hydra config to use")),
        ("--metric", dict(type=str, default='val_loss', help="metric for early stop")),
        ("--direction", dict(type=str, default='min', help="metric direction")),
        ("--enable_es", dict(type=int, default=1, help="enable early stopping")),
        ("--feature_extraction", dict(type=int, default=0,
                                      help="perform feature extraction(paper used only)")),
        ("--feature_save_dir", dict(type=str, default=None,
                                    help="feature saved dir(paper used only)")),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    parsed = process_args(cli, config_path='../../tasks/configs')
    print(parsed)
    return parsed

def automl_setup(args, logger):
    """Re-key the experiment and its logger on the active W&B run id.

    ``args.res_dir`` is first extended with the current ``args.ex_name``; the
    experiment name, the W&B run name and the logger's name/id are then all
    set from ``wandb.run``. The logger's save directory is pointed at the new
    ``args.res_dir`` and created if it does not exist yet.

    Args:
        args: namespace carrying ``res_dir`` and ``ex_name``; mutated in place.
        logger: logger object whose ``_save_dir``, ``_name`` and ``_id``
            attributes are overwritten.

    Returns:
        The same ``(args, logger)`` pair after mutation.
    """
    # Must happen before ex_name is replaced by the run id below.
    args.res_dir = os.path.join(args.res_dir, args.ex_name)

    active_run = wandb.run
    print(active_run)

    # The run id becomes both the experiment name and the displayed run name.
    args.ex_name = active_run.id
    active_run.name = active_run.id

    output_dir = str(args.res_dir)
    logger._save_dir = output_dir
    os.makedirs(output_dir, exist_ok=True)

    logger._name = active_run.name
    logger._id = active_run.id
    return args, logger
    

def _build_callbacks(args):
    """Create the training callbacks; the checkpoint callback is always element 0."""
    ckpt_dir = os.path.join(args.res_dir, args.ex_name, "checkpoints")
    callbacks = [
        plc.ModelCheckpoint(
            dirpath=ckpt_dir,
            monitor=args.metric,
            mode=args.direction,
            save_top_k=1,
        )
    ]
    if args.enable_es:
        callbacks.append(EarlyStopping(monitor=args.metric, mode=args.direction))
    return callbacks


def _build_trainer(args, logger, callbacks):
    """Create the Lightning Trainer from the parsed arguments."""
    use_gpu = torch.cuda.is_available()
    trainer_kwargs = dict(
        accelerator="gpu" if use_gpu else "cpu",
        devices=args.gpus_per_node if use_gpu else 1,
        num_nodes=args.num_nodes,
        max_epochs=args.epoch,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        logger=logger,
        callbacks=callbacks,
    )
    # NOTE(review): the evaluation step below understands a DeepSpeed checkpoint
    # directory, which suggests the run was meant to use a DeepSpeed strategy.
    # No such option is declared in this file, so a strategy is only forwarded
    # when the processed config provides one - confirm the intended value.
    strategy = getattr(args, "strategy", None)
    if strategy:
        trainer_kwargs["strategy"] = strategy
    return Trainer(**trainer_kwargs)


def _restore_best_weights(model, best_model_path):
    """Load the best checkpoint into ``model`` when it is a DeepSpeed directory.

    Returns the ``ckpt_path`` to hand to ``trainer.test``: ``None`` when the
    weights were loaded here (or when there is no checkpoint at all), otherwise
    the path of a regular Lightning checkpoint file for Lightning to restore.
    """
    if not best_model_path:
        print("No best checkpoint was recorded; testing with the current weights.")
        return None
    if os.path.isdir(best_model_path):
        # DeepSpeed layout: <best>/checkpoint/mp_rank_00_model_states.pt
        model_state_path = os.path.join(best_model_path, "checkpoint", "mp_rank_00_model_states.pt")
        # Fall back to CPU so evaluation does not crash on a host without CUDA.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        state = torch.load(model_state_path, map_location=device)
        model.load_state_dict(state['module'])
        return None
    return best_model_path


def main():
    """Train a model on the configured task, then test the best checkpoint.

    Steps: parse arguments, start the W&B run and logger, seed everything,
    build the data module and the model, fit with checkpointing (and optional
    early stopping), restore the best checkpoint and run the test loop.
    """
    args = create_parser()

    if args.offline:
        os.environ["WANDB_MODE"] = "offline"

    run_dir = str(os.path.join(args.res_dir, args.ex_name))
    # Make sure the run directory exists before W&B is asked to write into it.
    os.makedirs(run_dir, exist_ok=True)
    wandb.init(project='protein_benchmark', entity='biomap_ai', dir=run_dir)
    logger = plog.WandbLogger(
                    project = 'protein_benchmark',
                    name=args.ex_name,
                    save_dir=run_dir,
                    dir = run_dir,
                    offline = args.offline,
                    entity = "biomap_ai")

    #================ for wandb sweep ==================
    args, logger = automl_setup(args, logger)
    #====================================================

    # NOTE(review): this hard-coded value overrides whatever was passed via
    # --seed. It is kept so existing results stay reproducible; drop this line
    # if the command-line seed is supposed to take effect.
    args.seed = 42
    pl.seed_everything(args.seed)

    data_module = DInterface(**vars(args))

    data_module.data_setup()
    # device_count() is 0 on a CPU-only host; clamp to 1 to avoid dividing by zero.
    gpu_count = max(1, torch.cuda.device_count())
    steps_per_epoch = math.ceil(len(data_module.train_set)/args.batch_size/gpu_count)
    args.lr_decay_steps = steps_per_epoch*args.epoch

    model = MInterface(**vars(args))

    data_module.MInterface = model

    # ============================
    # Train
    # ============================
    # The original code referenced `callbacks` and `trainer` without ever
    # creating them, so they are built here before use.
    callbacks = _build_callbacks(args)
    trainer = _build_trainer(args, logger, callbacks)
    trainer.fit(model, datamodule=data_module)

    # ============================
    # Evaluate the best model
    # ============================
    checkpoint_callback = callbacks[0]
    print(f"Best model path: {checkpoint_callback.best_model_path}")

    # Load the best model
    ckpt_path = _restore_best_weights(model, checkpoint_callback.best_model_path)

    # Run the test loop
    results = trainer.test(model, datamodule=data_module, ckpt_path=ckpt_path)
    # Print the test results
    print(f"Test Results: {results}")


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()