import argparse
import json
import logging
import os
import time

from dataset import PeptidePairDataset, PeptidePairPicDataset, SimplePairClsDataset
from network import DMutaPeptide, DMutaPeptideCNN
from sklearn.model_selection import KFold
from train import train_cls
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler, RandomSampler, Subset
import numpy as np
from loss import MLCE, SuperLoss, LogCoshLoss, BMCLoss
from utils import set_seed

parser = argparse.ArgumentParser(description='Fine-tune DMutaPeptide models')

# model architecture options
parser.add_argument('--model', type=str, default='resnet34',
                    help='resnet34 resnet50 densenet')
parser.add_argument('--q-encoder', dest='q_encoder', type=str, default='cnn',
                    help='cnn rn18 lstm mamba mla')
parser.add_argument('--side-enc', dest='side_enc', type=str, default='lstm',
                    help='encoder for side features (e.g. lstm); empty string disables them')
parser.add_argument('--channels', type=int, default=16)
parser.add_argument('--fusion', type=str, default='att',
                    help='mlp att diff')
parser.add_argument('--glob-feat', dest='glob_feat', action='store_true', default=False,
                    help='use global features')
parser.add_argument('--non-siamese', dest='non_siamese', action='store_true', default=False,
                    help='use non-siamese architecture')
parser.add_argument('--widen', action='store_true', default=False,
                    help='use widened non-siamese architecture')

# dataset options
parser.add_argument('--task', type=str, default='cls',
                    help='reg or cls')
parser.add_argument('--pdb-src', type=str, dest='pdb_src', default='af',
                    help='af or hf')
parser.add_argument('--data-ver', type=str, dest='data_ver', default='250228',
                    help='data version')
parser.add_argument('--one-way', action='store_true', dest='one_way', default=True,
                    help='use one-way constructed dataset (already on by default)')
parser.add_argument('--max-length', dest='max_length', type=int, default=30,
                    help='Max length for sequence filtering')
parser.add_argument('--split', type=int, default=5,
                    help='Split k folds in cross validation (default: 5)')
parser.add_argument('--run-folds', type=int, dest='run_folds', nargs='+', default=-1,
                    help='Specify which folds to run (-1 runs all folds)')
parser.add_argument('--seed', type=int, default=1,
                    help='Seed (default: 1)')
parser.add_argument('--pcs', action='store_true', default=False,
                    help='Consider protease cut site')
parser.add_argument('--mix-pcs', dest='mix_pcs', action='store_true', default=False,
                    help='Consider protease cut site (mixed variant)')
parser.add_argument('--resize', type=int, default=[768], nargs='+',
                    help='resize the image')
parser.add_argument('--llm-data', action='store_true', default=False,
                    help='Use LLM augmentation data')

# training options
parser.add_argument('--gpu', type=int, default=0,
                    help='GPU index to use, -1 for CPU (default: 0)')
parser.add_argument('--batch-size', type=int, dest='batch_size', default=32,
                    help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=50,
                    help='number of epochs to train (default: 50)')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (default: 0.001)')
parser.add_argument('--decay', type=float, default=0.0005,
                    help='weight decay (default: 0.0005)')
parser.add_argument('--warm-steps', type=int, dest='warm_steps', default=0,
                    help='number of warm start steps for learning rate (default: 0)')
parser.add_argument('--patience', type=int, default=10,
                    help='patience for early stopping (default: 10)')
parser.add_argument('--pretrain', type=str, dest='pretrain', default='',
                    help='path of the pretrained model')
parser.add_argument('--metric-avg', type=str, dest='metric_avg', default='macro',
                    help='metric average type')

parser.add_argument('--loss', type=str, default='ce',
                    help='loss function')
parser.add_argument('--dir', action='store_true', default=False,
                    help='use DIR')

# curriculum learning options
parser.add_argument('--bias-curri', dest='bias_curri', action='store_true', default=False,
                    help='directly use the loss to order the training data (biased) or not (unbiased)')
parser.add_argument('--anti-curri', dest='anti_curri', action='store_true', default=False,
                    help='easy to hard (curri), hard to easy (anti)')
parser.add_argument('--std-coff', dest='std_coff', type=float, default=1,
                    help='std coefficient hyper-parameter')

# fine-tuning options
parser.add_argument('--ft-epochs', dest='ft_epochs', type=int, default=15,
                    help='fine-tune epochs')
parser.add_argument('--ft-lr', dest='ft_lr', type=float, default=0.0002,
                    help='fine-tune learning rate')

parser.add_argument('--simple', dest='simple', action='store_true', default=False,
                    help='use the simple pair classification dataset')

args = parser.parse_args()

# derived settings: the LLM-augmented data uses the simple dataset format,
# and the simple dataset is one-way constructed
if args.llm_data:
    args.simple = True

if args.simple:
    args.one_way = True

# -1 means run every fold of the cross validation
if args.run_folds == -1:
    args.run_folds = list(range(args.split))
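
# Hedged addition (not in the original script): sanity-check that every requested
# fold index falls inside the configured k-fold split before any work starts.
assert all(0 <= f < args.split for f in args.run_folds), \
    f'--run-folds entries must lie in [0, {args.split})'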


def main():
    set_seed(args.seed)

    # select the criterion for the task; some cross-task loss names are
    # silently coerced to the task default (mse for reg, ce for cls)
    if args.task == 'reg':
        args.classes = 1
        if args.loss in ('mse', 'ce'):
            args.loss = 'mse'
            criterion = nn.MSELoss()
        elif args.loss == 'smoothl1':
            criterion = nn.SmoothL1Loss()
        elif args.loss == 'super':
            criterion = SuperLoss()
        elif args.loss in ('bmc', 'bmc_ln'):
            criterion = BMCLoss()
        else:
            raise NotImplementedError('unimplemented regression task loss function')
    elif args.task == 'cls':
        args.classes = 2
        if args.loss in ('ce', 'mse', 'smoothl1', 'super'):
            args.loss = 'ce'
            criterion = nn.CrossEntropyLoss()
        else:
            raise NotImplementedError('unimplemented classification task loss function')
    else:
        raise NotImplementedError('unimplemented task')

    # the checkpoint directory name encodes the experiment configuration
    if args.q_encoder in ['cnn', 'rn18']:
        weight_dir = f'./run-{args.task}/{args.q_encoder}{"-non-siamese" if args.non_siamese else ""}-{args.fusion}-{args.channels}{f"-{args.side_enc}" if args.side_enc else ""}{"-mixpcs" if args.mix_pcs else ""}{"-pcs" if args.pcs else ""}{"-simple" if args.simple else ""}{"-llm" if args.llm_data else ""}{"-" + "x".join(str(n) for n in args.resize) if args.resize else ""}{"-gf" if args.glob_feat else ""}{"-oneway" if args.one_way else ""}-{args.loss + "-dir" if args.dir else args.loss}-{str(args.batch_size)}-{str(args.lr)}-{str(args.epochs)}'
    else:
        weight_dir = f'./run-{args.task}/{args.q_encoder}{"-non-siamese" if args.non_siamese else ""}-{args.fusion}-{args.channels}{"-simple" if args.simple else ""}{"-llm" if args.llm_data else ""}{"-gf" if args.glob_feat else ""}{"-oneway" if args.one_way else ""}-{args.loss + "-dir" if args.dir else args.loss}-{str(args.batch_size)}-{str(args.lr)}-{str(args.epochs)}'

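    # Hedged addition (not in the original): weight_dir is expected to exist from the
    # earlier training run; fail with a clear message before the FileHandler below
    # tries to open finetune.log inside it.
    if not os.path.isdir(weight_dir):
        raise FileNotFoundError(f'expected pretrained run directory not found: {weight_dir}')
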
    logging.basicConfig(handlers=[
        logging.FileHandler(filename=os.path.join(weight_dir, "finetune.log"), encoding='utf-8', mode='w+'),
        logging.StreamHandler()],
        format="%(asctime)s: %(message)s", datefmt="%F %T", level=logging.INFO)

    logging.info(f'Finetuning: {weight_dir}')

    device = torch.device("cpu" if args.gpu == -1 or not torch.cuda.is_available() else f"cuda:{args.gpu}")

    logging.info('Loading Training Dataset')
    train_set = SimplePairClsDataset(pad_length=args.max_length, ftr2=True, gf=args.glob_feat, q_encoder=args.q_encoder, side_enc=args.side_enc, pcs=args.pcs, resize=args.resize)

    logging.info('Loading Test Dataset')
    if args.q_encoder in ['cnn', 'rn18']:
        test_set = PeptidePairPicDataset(mode='r2_case', pad_length=args.max_length, task=args.task, gf=args.glob_feat, side_enc=args.side_enc, pcs=args.pcs, resize=args.resize)
    else:
        test_set = PeptidePairDataset(mode='r2_case', pad_length=args.max_length, task=args.task, gf=args.glob_feat)

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True)

    # one slot per fold; only the folds that are actually run get filled in
    best_perform_list = [[] for _ in range(args.split)]

    for fold in args.run_folds:
        logging.info(f'Finetuning Fold {fold}')
        logging.info(f'Fold {fold} Train set: {len(train_set)}, Test set: {len(test_set)}')

        if args.q_encoder in ['cnn', 'rn18']:
            model = DMutaPeptideCNN(q_encoder=args.q_encoder, classes=args.classes, channels=args.channels, dir=args.dir, gf=args.glob_feat, side_enc=args.side_enc, fusion=args.fusion, non_siamese=args.non_siamese)
        else:
            model = DMutaPeptide(q_encoder=args.q_encoder, classes=args.classes, channels=args.channels, dir=args.dir, gf=args.glob_feat, fusion=args.fusion, non_siamese=args.non_siamese)

        weights_path = f"{weight_dir}/model_{fold}.pth"
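
        # Hedged addition (not in the original): fail early with an explicit message
        # if the cross-validation checkpoint for this fold was never produced.
        if not os.path.isfile(weights_path):
            raise FileNotFoundError(f'pretrained fold checkpoint not found: {weights_path}')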

        model.to(device)
        # strict=False ignores keys that are missing from or unexpected in the checkpoint
        model.load_state_dict(torch.load(weights_path, map_location=device), strict=False)

        optimizer = torch.optim.AdamW(model.parameters(), lr=args.ft_lr)

        best_metric = -float('inf')

        if args.task == 'cls':
            for epoch in range(1, args.ft_epochs + 1):
                train_loss, ap, auc, f1, acc = train_cls(args, epoch, model, train_loader, test_loader, device, criterion, optimizer)
                logging.info(f'Epoch: {epoch:03d} Train Loss: {train_loss:.3f}, ap: {ap:.3f}, auc: {auc:.3f}, f1: {f1:.3f}, acc: {acc:.3f}')
                # model selection on the sum of AP and AUC over the evaluation split
                avg_metric = ap + auc
                if avg_metric > best_metric:
                    logging.info(f'Epoch: {epoch:03d} New best VALIDATION metrics')
                    best_metric = avg_metric
                    best_perform_list[fold] = np.asarray([ap, auc, f1, acc])
                    torch.save(model.state_dict(), weights_path.replace('.pth', '_ft.pth'))
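
    # Hedged addition (not in the original): report the best metrics of each completed
    # fold, averaged, once all requested folds have been fine-tuned.
    completed = [np.asarray(p) for p in best_perform_list if len(p) > 0]
    if completed:
        mean_ap, mean_auc, mean_f1, mean_acc = np.mean(np.stack(completed), axis=0)
        logging.info(f'Fold-averaged best metrics - ap: {mean_ap:.3f}, auc: {mean_auc:.3f}, f1: {mean_f1:.3f}, acc: {mean_acc:.3f}')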


if __name__ == "__main__":
    main()