File size: 5,042 Bytes
4707555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import wandb
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils import data
from transformers import DebertaTokenizerFast, T5EncoderModel, T5Config
import argparse
import pickle
from functools import partial
from typing import Callable, Optional, Union, Tuple, List
from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
import timm
import pandas as pd
import torch.nn.functional as F
import transformers
from scipy.stats import pearsonr, spearmanr
from timm.layers import Mlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
from sklearn.metrics import accuracy_score
from model import mRNA2vec, T5_encoder,Regression_Model
from LoadData import DataLoad_downstream
SEED = 30

def setup_seed(seed: int) -> None:
    """Seed every RNG used in this script so runs are reproducible.

    Covers torch (CPU + all CUDA devices), NumPy, and Python's `random`.
    Also pins cuDNN to deterministic algorithms.

    Args:
        seed: the seed value applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Fix: benchmark autotuning picks kernels nondeterministically, which
    # defeats `deterministic = True`; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False

setup_seed(SEED)



def main():
    """Fine-tune and evaluate a regression head on a downstream mRNA task.

    Reads the task/data/checkpoint configuration from the command line,
    optionally loads a pretrained T5 encoder into `Regression_Model`,
    trains for a fixed number of epochs under mixed-precision forward
    passes, and logs train loss / eval Spearman correlation to wandb.

    Side effects: initializes a wandb run, reads the checkpoint and CSV
    data from disk, and runs on CUDA device 0 (or `--cuda_device`).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task_name", type=str, default='stability')
    parser.add_argument("--exp_name", type=str, default='unload')
    parser.add_argument("--data_path", type=str, default='data1.csv')
    parser.add_argument("--model_name", type=str, default="model_d2v_mfe0.01_ss0.001_warmup.pt")
    parser.add_argument("--load_model", type=str, default='False')
    parser.add_argument("--cuda_device", type=str, default='0')
    args = parser.parse_args()
    print(args)
    # NOTE(review): torch is already imported at this point, so this only
    # masks devices if CUDA has not been initialized yet — confirm nothing
    # touches CUDA before here.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_device
    #torch.cuda.set_device(int(args.cuda_device))
    task_name = args.task_name
    data_path = args.data_path
    model_name = args.model_name
    DATA_PATH = f'./data/downstream/{task_name}/{data_path}'
    MODEL_PATH = f'./checkpoint/{model_name}'
    TOKENIZER_PATH = './tokenizer/'
    BATCH_SIZE = 256
    EPOCHS = 100
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-2

    wandb.init(project=f"mRNA_exp_linear_{task_name}_Folders", dir='./', name=args.exp_name)
    # scaler = torch.cuda.amp.GradScaler()

    model = Regression_Model(num_attention_heads=4, num_hidden_layers=4, pad_token_id=1, hidden_size=256)
    ckpt = torch.load(MODEL_PATH, map_location='cpu')
    if args.load_model == 'True':
        print('loading model---------')
        model.T5_encoder.load_state_dict(ckpt['encoder'], strict=True)
    # Fix: batches below are moved to CUDA, so the model must live there too.
    # `.to('cuda')` is idempotent if the model already places itself.
    model = model.to('cuda')

    train_db = DataLoad_downstream(mode='train', data_path=DATA_PATH, tokenizer_path=TOKENIZER_PATH)
    train_loader = torch.utils.data.DataLoader(train_db, batch_size=BATCH_SIZE, num_workers=1, shuffle=True)
    valid_db = DataLoad_downstream(mode='test', data_path=DATA_PATH, tokenizer_path=TOKENIZER_PATH)
    val_loader = torch.utils.data.DataLoader(valid_db, batch_size=BATCH_SIZE, num_workers=1)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # Warmup-stable-decay schedule; step counts are expressed in epochs
    # worth of batches (40 warmup / 20 stable / 40 decay = 100 epochs).
    lr_decay = transformers.get_wsd_schedule(optimizer=optimizer,
                                             num_warmup_steps=len(train_loader) * 40,
                                             num_stable_steps=len(train_loader) * 20,
                                             num_decay_steps=len(train_loader) * 40,
                                             )

    best_spear = 0.
    for e in range(EPOCHS):
        loss_lst = []
        for batch_train in train_loader:
            optimizer.zero_grad()
            x, attention_mask, label = [t.to('cuda') for t in batch_train]
            # Fix: only the forward pass belongs under autocast; backward
            # and optimizer.step() run outside it per PyTorch AMP guidance.
            # NOTE(review): GradScaler is commented out — fp16 gradients may
            # underflow without it; confirm this is intentional.
            with torch.cuda.amp.autocast():
                loss = model(x, attention_mask, label)
            # scaler.scale(loss).backward()
            # scaler.step(optimizer)
            # scaler.update()
            loss.backward()
            optimizer.step()
            lr_decay.step()
            loss_lst.append(loss.item())
        current_lr = optimizer.param_groups[0]['lr']
        wandb.log({"train loss": np.mean(loss_lst),
                   'lr': current_lr, })
        model.eval()
        label_lst = []
        predict_lst = []
        for batch_test in val_loader:
            x, attention_mask, label = [t.to('cuda') for t in batch_test]
            with torch.no_grad():
                outs = model.forward_logit_linear(x, attention_mask)
            label_lst.extend(label.reshape(-1).cpu().numpy().tolist())
            predict_lst.extend(outs.reshape(-1).cpu().numpy().tolist())
        spearman_corr = spearmanr(predict_lst, label_lst)[0].item()
        if best_spear < spearman_corr:
            best_spear = spearman_corr
        print(f'test on {task_name}', e, spearman_corr, best_spear)
        wandb.log({"eval_spearman": spearman_corr,
                   'best_spearman': best_spear, })
        model.train()

if __name__ == '__main__':
    main()