import argparse
import json
import logging
import os
import time
from copy import deepcopy

from dataset import PeptidePairPicCaseDataset, PeptidePairPicDataset
from network import DMutaPeptide, DMutaPeptideCNN
from sklearn.model_selection import KFold
from train import move_to_device
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, RandomSampler
import numpy as np
from loss import MLCE, SuperLoss, LogCoshLoss, BMCLoss
from utils import set_seed
from torchmetrics import (
    MeanAbsoluteError,
    RelativeSquaredError,
    PearsonCorrCoef,
    KendallRankCorrCoef,
)

# NOTE(review): json, time, deepcopy, KFold, nn, Subset, DMutaPeptide and the
# loss classes (MLCE, SuperLoss, LogCoshLoss, BMCLoss) are not referenced by
# this script; they are kept as-is (import order unchanged) in case a local
# module relies on import-time side effects.

# Command-line interface, built at import time.
parser = argparse.ArgumentParser(description="resnet26")

# --- model setting ---
parser.add_argument(
    "--model",
    type=str,
    default="resnet34",
    help="resnet34 resnet50 densenet",
)
parser.add_argument(
    "--q-encoder",
    dest="q_encoder",
    type=str,
    default="lstm",
    help="lstm mamba mla",
)
parser.add_argument(
    "--side-enc",
    dest="side_enc",
    type=str,
    default=None,
    help="use side features",
)
parser.add_argument("--channels", type=int, default=256)
parser.add_argument(
    "--fusion",
    type=str,
    default="att",
    help="mlp att diff",
)
parser.add_argument(
    "--glob-feat",
    dest="glob_feat",
    action="store_true",
    default=False,
    help="use global features",
)
parser.add_argument(
    "--non-siamese",
    dest="non_siamese",
    action="store_true",
    default=False,
    help="use non-siamese architecture",
)

# --- task & dataset setting ---
parser.add_argument(
    "--task",
    type=str,
    default="reg",
    help="reg or cls",
)
parser.add_argument(
    "--one-way",
    dest="one_way",
    action="store_true",
    default=False,
    help="use one-way constructed dataset",
)
parser.add_argument(
    "--max-length",
    dest="max_length",
    type=int,
    default=30,
    help="Max length for sequence filtering",
)
parser.add_argument(
    "--split",
    type=int,
    default=5,
    help="Split k fold in cross validation (default: 5)",
)
parser.add_argument(
    "--seed",
    type=int,
    default=1,
    help="Seed (default: 1)",
)
parser.add_argument(
    "--pcs",
    action="store_true",
    default=False,
    help="Consider protease cleavage site",
)
parser.add_argument('--mix-pcs', dest='mix_pcs', action='store_true', default=False,
                    help='Consider protease cleavage site')
parser.add_argument('--resize', type=int, default=[768], nargs='+',
                    help='resize the image')

# training setting
# NOTE: --batch-size, --lr and --epochs are also part of the run-directory name
# that the pretrained fold checkpoints are loaded from (see _build_weight_dir).
parser.add_argument('--gpu', type=int, default=0,
                    help='GPU index to use, -1 for CPU (default: 0)')
parser.add_argument('--batch-size', type=int, dest='batch_size', default=32,
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=50,
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (default: %(default)s)')
parser.add_argument('--decay', type=float, default=0.0005,
                    help='weight decay (default: %(default)s)')
parser.add_argument('--pretrain', type=str, dest='pretrain', default='',
                    help='path of the pretrain model')
parser.add_argument('--metric-avg', type=str, dest='metric_avg', default='macro',
                    help='metric average type')
parser.add_argument('--loss', type=str, default='mse',
                    help='loss function')
parser.add_argument('--dir', action='store_true', default=False,
                    help='use DIR')
parser.add_argument('--case', type=str, default='r2')
parser.add_argument('--iter-num', dest='iter_num', type=int, default=1000)
args = parser.parse_args()


def noise_and_move(x, intensity: float = 0.05, device=torch.device('cpu')):
    """Return a copy of ``x`` with additive Gaussian noise, moved to ``device``.

    Args:
        x: a tensor, or a (possibly nested) list/tuple of tensors; containers
            are rebuilt with the same type.
        intensity: scale of the noise; ``randn_like(x) * intensity`` is added.
        device: device the returned tensor(s) are moved to.

    Non-floating-point tensors are moved without noise, because
    ``torch.randn_like`` is not defined for integer/bool dtypes.
    """
    if isinstance(x, (tuple, list)):
        return type(x)(noise_and_move(x_i, intensity, device) for x_i in x)
    if not torch.is_floating_point(x):
        return x.to(device)
    return (x + torch.randn_like(x) * intensity).to(device)


def _resolve_task():
    """Set ``args.classes`` and normalise ``args.loss`` for ``args.task``.

    ``args.loss`` is part of the run-directory name the fold checkpoints are
    loaded from, so it is mapped to the canonical name for the task.

    Raises:
        NotImplementedError: unsupported task, or loss not accepted for it.
    """
    if args.task == 'reg':
        args.classes = 1
        if args.loss not in ('mse', 'ce'):
            raise NotImplementedError("unimplemented regression task loss function")
        args.loss = 'mse'
    elif args.task == 'cls':
        args.classes = 2
        if args.loss not in ('ce', 'mse', 'smoothl1', 'super'):
            raise NotImplementedError("unimplemented classification task loss function")
        args.loss = 'ce'
    else:
        raise NotImplementedError("unimplemented task")


def _build_weight_dir():
    """Return the run directory holding ``model_<fold>.pth`` for the current args.

    Image encoders ('cnn', 'rn18') additionally encode the side-encoding,
    cleavage-site and resize options in the name.
    """
    name = args.q_encoder
    if args.non_siamese:
        name += '-non-siamese'
    name += f'-{args.fusion}-{args.channels}'
    if args.q_encoder in ['cnn', 'rn18']:
        if args.side_enc:
            name += f'-{args.side_enc}'
        if args.mix_pcs:
            name += '-mixpcs'
        if args.pcs:
            name += '-pcs'
        if args.resize:
            name += '-' + 'x'.join(str(n) for n in args.resize)
    if args.glob_feat:
        name += '-gf'
    if args.one_way:
        name += '-oneway'
    loss_tag = args.loss + '-dir' if args.dir else args.loss
    name += f'-{loss_tag}-{args.batch_size}-{args.lr}-{args.epochs}'
    return f'./run-{args.task}/{name}'


def _evaluate(model, loader, metric_funcs, device):
    """Score ``model`` on ``loader`` and return ``[mae, rse, pcc, kcc]`` as floats.

    The model is switched to eval mode for the pass, so any dropout is disabled
    and any BatchNorm layers neither use batch statistics nor update their
    running statistics with validation data. Its previous mode is restored
    afterwards. Each torchmetrics object is reset after use so its internal
    state does not accumulate across repeated validations; KendallRankCorrCoef
    otherwise keeps every prediction it has ever seen.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            preds, gts = [], []
            for vx, vgt in loader:
                vx = move_to_device(vx, device, non_blocking=True)
                preds.append(model(vx))
                gts.append(vgt.to(device, non_blocking=True))
            pred_all = torch.cat(preds, dim=0)
            gt_all = torch.cat(gts, dim=0)
            scores = []
            for key in ('mae', 'rse', 'pcc', 'kcc'):
                metric = metric_funcs[key]
                metric.update(pred_all, gt_all)
                scores.append(metric.compute().item())
                metric.reset()
            return scores
    finally:
        model.train(was_training)


def main():
    """Fine-tune each cross-validation fold's model with a noise-consistency loss.

    Presumably SFDA stands for source-free domain adaptation, based on the
    log/file names. For every fold, ``<weight_dir>/model_<fold>.pth`` is loaded
    and trained on the unlabeled case dataset. The loss is the MSE between the
    model's outputs on a lightly (0.05) and a heavily (0.2) noised copy of the
    same batch; the dataset labels are ignored. Every 10 iterations the model
    is scored on the ``r2_case`` validation set. The checkpoint with the best
    ``pcc + kcc - mae - rse`` is saved as ``model_<fold>_sfda.pth``. Per-fold
    best scores and their mean/std are logged at the end.
    """
    set_seed(args.seed)
    _resolve_task()
    weight_dir = _build_weight_dir()

    logging.basicConfig(handlers=[
        logging.FileHandler(filename=os.path.join(weight_dir, "sfda_tuning.log"), encoding='utf-8', mode='w+'),
        logging.StreamHandler()], format="%(asctime)s: %(message)s", datefmt="%F %T", level=logging.INFO)

    device = torch.device("cpu" if args.gpu == -1 or not torch.cuda.is_available() else f"cuda:{args.gpu}")

    # Tuning batches are half of --batch-size, presumably because each batch is
    # passed through the model twice (two noisy views). Clamp to 1 so
    # --batch-size 1 still yields a valid loader.
    half_batch = max(1, args.batch_size // 2)
    dataset = PeptidePairPicCaseDataset(case=args.case, pad_length=args.max_length, side_enc=args.side_enc,
                                        pcs=args.pcs, resize=args.resize, gf=args.glob_feat)
    # Sample with replacement so each fold runs exactly --iter-num batches.
    sampler = RandomSampler(dataset, replacement=True, num_samples=args.iter_num * half_batch)
    dataloader = DataLoader(dataset, batch_size=half_batch, sampler=sampler, num_workers=16, pin_memory=True)
    valset = PeptidePairPicDataset(mode='r2_case', pad_length=args.max_length, side_enc=args.side_enc,
                                   pcs=args.pcs, resize=args.resize, gf=args.glob_feat)
    valloader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=16, pin_memory=True)

    criterion = torch.nn.MSELoss()
    metric_funcs = {
        'mae': MeanAbsoluteError().to(device),
        'rse': RelativeSquaredError().to(device),
        'pcc': PearsonCorrCoef().to(device),
        'kcc': KendallRankCorrCoef().to(device),
    }

    best_perform_list = [[] for _ in range(args.split)]
    for fold in range(args.split):
        logging.info(f"Fold {fold}")
        weights_path = f"{weight_dir}/model_{fold}.pth"
        sfda_path = f"{weight_dir}/model_{fold}_sfda.pth"
        model = DMutaPeptideCNN(q_encoder=args.q_encoder, classes=args.classes, channels=args.channels,
                                dir=args.dir, gf=args.glob_feat, side_enc=args.side_enc, fusion=args.fusion,
                                non_siamese=args.non_siamese)
        # map_location lets checkpoints saved on a GPU load on CPU (--gpu -1).
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model = model.to(device)
        model.train()
        # The tuning learning rate is fixed; --lr only selects the run directory.
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        best_val_metric = -float('inf')

        for iteration, (x, _) in enumerate(dataloader, 1):
            x1 = noise_and_move(x, 0.05, device)
            x2 = noise_and_move(x, 0.2, device)
            y1 = model(x1)
            y2 = model(x2)
            loss = criterion(y1, y2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if iteration % 10 != 0:
                continue
            val_mae, val_rse, val_pcc, val_kcc = _evaluate(model, valloader, metric_funcs, device)
            val_metric = val_pcc + val_kcc - val_mae - val_rse
            logging.info(f'Iteration {iteration}, Train Loss: {loss.item():.4f}, Val: mae: {val_mae:.4f} rse: {val_rse:.4f} pcc: {val_pcc:.4f} kcc: {val_kcc:.4f}')
            if val_metric > best_val_metric:
                logging.info('NEW best validation iteration')
                best_val_metric = val_metric
                best_perform_list[fold] = [val_mae, val_rse, val_pcc, val_kcc]
                torch.save(model.state_dict(), sfda_path)

    logging.info('SFDA Tuning Finished!')

    # A fold keeps an empty entry if it never validated (iter_num < 10) or its
    # score was always NaN; a ragged list cannot be turned into an array.
    completed = [perf for perf in best_perform_list if perf]
    if len(completed) < len(best_perform_list):
        logging.warning('%d fold(s) never recorded a best validation score '
                        '(no validation ran, or the score was always NaN); '
                        'they are excluded from the summary',
                        len(best_perform_list) - len(completed))
    if not completed:
        return
    best_perform_list = np.asarray(completed)
    logging.info('Best validation perform list\n%s', best_perform_list)
    logging.info('mean: %s', np.round(np.mean(best_perform_list, 0), 4))
    logging.info('std: %s', np.round(np.std(best_perform_list, 0), 4))


if __name__ == '__main__':
    main()