# DAminoMuta/tmp_test.py
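"""Evaluate trained DMutaPeptide checkpoints on the peptide-pair test set.

Loads one model per cross-validation fold from the run directory matching the
given configuration and prints every misclassified sequence pair.
"""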
import argparse
import json
import logging
import os
import time
from dataset import PeptidePairDataset, PeptidePairPicDataset, SimplePairClsDataset, AA_to_index
from network import DMutaPeptide, DMutaPeptideCNN  # , DMutaPeptideWiden
from sklearn.model_selection import KFold
from train import move_to_device
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler, RandomSampler, Subset
import numpy as np
from loss import MLCE, SuperLoss, LogCoshLoss, BMCLoss
from utils import set_seed
from infer_case import FasterModelForCase, CustomDataset
parser = argparse.ArgumentParser(description='DMutaPeptide peptide-pair evaluation')
# model setting
parser.add_argument('--model', type=str, default='resnet34',
help='resnet34 resnet50 densenet')
parser.add_argument('--q-encoder', dest='q_encoder', type=str, default='rn18',
                    help='cnn rn18 lstm mamba mla')
parser.add_argument("--side-enc", dest='side_enc', type=str, default='mamba',
help="use side features")
parser.add_argument('--channels', type=int, default=16)
parser.add_argument('--fusion', type=str, default='diff',
help='mlp att diff')
parser.add_argument('--glob-feat', dest='glob_feat', action='store_true', default=False,
help="use global features")
parser.add_argument('--non-siamese', dest='non_siamese', action='store_true', default=False,
help="use non-siamese architecture")
parser.add_argument('--widen', action='store_true', default=False,
help='use widen non-siamese architecture')
# task & dataset setting
parser.add_argument('--task', type=str, default='cls',
help='reg or cls')
parser.add_argument('--pdb-src', type=str, dest='pdb_src', default='af',
help='af or hf')
parser.add_argument('--data-ver', type=str, dest='data_ver', default='250228',
help='data version')
parser.add_argument('--one-way', action='store_true', dest='one_way', default=False,
help='use one-way constructed dataset')
parser.add_argument('--max-length', dest='max_length', type=int, default=30,
help='Max length for sequence filtering')
parser.add_argument('--split', type=int, default=5,
help="Split k fold in cross validation (default: 5)")
parser.add_argument('--run-folds', type=int, dest='run_folds', nargs='+', default=-1,
help='specify which folds to run')
parser.add_argument('--seed', type=int, default=1,
help="Seed (default: 1)")
parser.add_argument('--pcs', action='store_true', default=False,
help='Consider protease cut site')
parser.add_argument('--mix-pcs', dest='mix_pcs', action='store_true', default=False,
                    help='Mix data with and without protease cut sites')
parser.add_argument('--resize', type=int, default=[768], nargs='+',
help='resize the image')
parser.add_argument('--llm-data', action='store_true', default=False,
help='Use LLM augmentation data')
# training setting
parser.add_argument('--gpu', type=int, default=0,
help='GPU index to use, -1 for CPU (default: 0)')
parser.add_argument('--batch-size', type=int, dest='batch_size', default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=50,
                    help='number of epochs to train (default: 50)')
parser.add_argument('--lr', type=float, default=0.001,
help='learning rate (default: 0.001)')
parser.add_argument('--decay', type=float, default=0.0005,
help='weight decay (default: 0.0005)')
parser.add_argument('--warm-steps', type=int, dest='warm_steps', default=0,
                    help='number of warm start steps for learning rate (default: 0)')
parser.add_argument('--patience', type=int, default=10,
help='patience for early stopping (default: 10)')
parser.add_argument('--pretrain', type=str, dest='pretrain', default='',
                    help='path to the pretrained model')
parser.add_argument('--metric-avg', type=str, dest='metric_avg', default='macro',
help='metric average type')
parser.add_argument('--loss', type=str, default='ce',
help='loss function')
parser.add_argument('--dir', action='store_true', default=False,
help='use DIR')
parser.add_argument('--bias-curri', dest='bias_curri', action='store_true', default=False,
                    help='use the raw loss as the curriculum difficulty score (biased) rather than an unbiased estimate')
parser.add_argument('--anti-curri', dest='anti_curri', action='store_true', default=False,
                    help='train hard-to-easy (anti-curriculum) instead of easy-to-hard (curriculum)')
parser.add_argument('--std-coff', dest='std_coff', type=float, default=1,
                    help='scaling coefficient of the std term')
parser.add_argument('--ft-epochs', dest='ft_epochs', type=int, default=15,
help='fine-tune epochs')
parser.add_argument('--ft-lr', dest='ft_lr', type=float, default=0.0002,
help='fine-tune learning rate')
parser.add_argument('--simple', dest='simple', action='store_true', default=False,
                    help='use the simplified pair-classification dataset')
args = parser.parse_args()
if args.llm_data:
args.simple = True
if args.simple:
args.one_way = True
if args.run_folds == -1:
args.run_folds = list(range(args.split))
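# Example invocation (hypothetical; flags mirror the defaults above):
#   python tmp_test.py --task cls --q-encoder rn18 --fusion diff --channels 16 \
#       --side-enc mamba --resize 768 --batch-size 32 --lr 0.001 --epochs 50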
def main():
set_seed(args.seed)
if args.task == 'reg':
args.classes = 1
if args.loss == "mse" or args.loss in ['ce']:
args.loss = 'mse'
criterion = nn.MSELoss()
elif args.loss == "smoothl1":
criterion = nn.SmoothL1Loss()
elif args.loss == "super":
criterion = SuperLoss()
elif args.loss in ["bmc", "bmc_ln"]:
criterion = BMCLoss()
else:
raise NotImplementedError("unimplemented regression task loss function")
elif args.task == 'cls':
args.classes = 2
        if args.loss in ('ce', 'mse', 'smoothl1', 'super'):  # regression losses fall back to cross-entropy
            args.loss = 'ce'
criterion = nn.CrossEntropyLoss()
else:
raise NotImplementedError("unimplemented classification task loss function")
else:
raise NotImplementedError("unimplemented task")
    # Reconstruct the run directory name produced by the training configuration;
    # it must match exactly so the per-fold checkpoints can be located.
    tags = [args.q_encoder]
    if args.non_siamese:
        tags.append('non-siamese')
    tags += [args.fusion, str(args.channels)]
    if args.q_encoder in ['cnn', 'rn18']:
        if args.side_enc:
            tags.append(args.side_enc)
        if args.mix_pcs:
            tags.append('mixpcs')
        if args.pcs:
            tags.append('pcs')
    if args.simple:
        tags.append('simple')
    if args.llm_data:
        tags.append('llm')
    if args.q_encoder in ['cnn', 'rn18'] and args.resize:
        tags.append('x'.join(str(n) for n in args.resize))
    if args.glob_feat:
        tags.append('gf')
    if args.one_way:
        tags.append('oneway')
    tags.append(f'{args.loss}-dir' if args.dir else args.loss)
    tags += [str(args.batch_size), str(args.lr), str(args.epochs)]
    weight_dir = f'./run-{args.task}/' + '-'.join(tags)
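    # With the defaults above this resolves to, e.g.:
    #   ./run-cls/rn18-diff-16-mamba-768-ce-32-0.001-50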
logging.basicConfig(handlers=[
# logging.FileHandler(filename=os.path.join(weight_dir, "finetune.log"), encoding='utf-8', mode='w+'),
logging.StreamHandler()],
format="%(asctime)s: %(message)s", datefmt="%F %T", level=logging.INFO)
    logging.info(f'Testing: {weight_dir}')
device = torch.device("cpu" if args.gpu == -1 or not torch.cuda.is_available() else f"cuda:{args.gpu}")
# logging.info(f'Loading Training Dataset')
# train_set = SimplePairClsDataset(pad_length=args.max_length, ftr2=True, gf=args.glob_feat, q_encoder=args.q_encoder, side_enc=args.side_enc, pcs=args.pcs, resize=args.resize)
logging.info('Loading Test Dataset')
    if args.q_encoder in ['cnn', 'rn18']:
        test_set = PeptidePairPicDataset(mode='train', pad_length=args.max_length, task=args.task, one_way=True, gf=args.glob_feat, side_enc=args.side_enc, pcs=args.pcs, resize=args.resize)
        # test_set = CustomDataset(case='r2', pad_length=args.max_length, side_enc=args.side_enc, pcs=True, resize=args.resize, gf=args.glob_feat)
    else:
        raise NotImplementedError(f'no test dataset is wired up for q_encoder={args.q_encoder!r} in this script')
    # train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True)
best_perform_list = [[] for i in range(5)]
    for fold in args.run_folds:
        logging.info(f'Testing Fold {fold}')
logging.info(f'Fold {fold}, Test set: {len(test_set)}')
# if args.widen:
# model = DMutaPeptideWiden(q_encoder=args.q_encoder, classes=args.classes, channels=args.channels, dir=args.dir, gf=args.glob_feat, fusion=args.fusion, side_enc=args.side_enc)
# else:
# model = FasterModelForCase(classes=args.classes, channels=args.channels, dir=args.dir, gf=args.glob_feat, side_enc=args.side_enc, fusion=args.fusion, non_siamese=args.non_siamese).to(device).eval()
model = DMutaPeptideCNN(q_encoder=args.q_encoder, classes=args.classes, channels=args.channels, dir=args.dir, gf=args.glob_feat, side_enc=args.side_enc, fusion=args.fusion, non_siamese=args.non_siamese).to(device).eval()
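        # One checkpoint per cross-validation fold, written by the training run as model_<fold>.pth.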
weights_path = f"{weight_dir}/model_{fold}.pth"
model.load_state_dict(torch.load(weights_path, map_location=device))
# model.load_state_dict(torch.load(weights_path.replace('.pth', '_test.pth'), map_location=device), strict=False)
if args.task == 'cls':
            eval_cls(args, None, model, None, test_loader, device, criterion, None)
def eval_cls(args, epoch, model, train_loader, valid_loader, device, criterion, optimizer):
    # Evaluation only: epoch, train_loader, criterion, and optimizer are unused but
    # kept for signature compatibility with the training-loop counterpart.
    model.eval()
seqs_t, seqs_d = [], []
preds = []
gt_list_valid = []
with torch.no_grad():
for data in valid_loader:
x, gt = data
seqs_1, seqs_2 = get_seq_from_batched_data(x)
seqs_t.extend(seqs_1)
seqs_d.extend(seqs_2)
# x1, x2 = move_to_device(x, device)
x = move_to_device(x, device)
# model.cache_temp_vector(x1)
gt_list_valid.append(gt.to(device))
# out = model(x2)
out = model(x)
preds.append(out)
# calculate metrics
preds = torch.softmax(torch.cat(preds, dim=0), dim=-1).squeeze()
gt_list_valid = torch.cat(gt_list_valid, dim=0).int().squeeze()
preds = (preds[:, 1] > 0.5).int()
    wrong_preds = (preds != gt_list_valid)
    # Print each misclassified pair: the two input sequences, prediction, and label.
    for i in torch.nonzero(wrong_preds).flatten().tolist():
        print(f"{seqs_t[i]} {seqs_d[i]} {preds[i].item()} {gt_list_valid[i].item()}")
index_to_aa = {v: k for k, v in AA_to_index.items()}
def get_seq_from_batched_data(x):
    # Each side of the batched pair x keeps its residue encoding at position [1].
    seq_encs_1, seq_encs_2 = x[0][1], x[1][1]
seqs_1 = get_seq_from_enc(seq_encs_1)
seqs_2 = get_seq_from_enc(seq_encs_2)
return seqs_1, seqs_2
def get_seq_from_enc(enc: torch.Tensor):
    encs = enc.cpu().numpy()
    seqs = []
    for one_hot in encs:
        seq = ''
        # Channel 0 flags D-amino acids; the remaining channels one-hot encode the residue.
        d_indicator = one_hot[:, 0].astype(bool)
        one_hot[:, 0] = 0.
        # Shift by 1 so that all-zero (padding) rows map to -1 and stop the decode.
        index = np.argmax(one_hot, axis=-1) - 1
        for d, i in zip(d_indicator, index):
            if i < 0:
                break
            # Lower case marks a D-amino acid, upper case an L-amino acid.
            seq += index_to_aa[i].lower() if d else index_to_aa[i]
        seqs.append(seq)
    return seqs
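# Encoding assumed by get_seq_from_enc, inferred from the decoding above:
#   enc has shape (L, 1 + len(AA_to_index)); column 0 flags D-amino acids,
#   columns 1.. are a one-hot over AA_to_index, and all-zero rows are padding.
#   D-residues are rendered lower-case, L-residues upper-case (e.g. 'aGk').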
if __name__ == "__main__":
main()