import argparse
import time

import pandas as pd
import torch
from torch.utils.data import DataLoader

from dataset import PeptidePairPicCaseDataset
from network import DMutaPeptideCNN
from train import move_to_device
from utils import set_seed

parser = argparse.ArgumentParser(description='Case-study inference for DMutaPeptideCNN')
# model setting
parser.add_argument('--model', type=str, default='resnet34', help='resnet34 resnet50 densenet')
parser.add_argument('--q-encoder', dest='q_encoder', type=str, default='cnn', help='cnn lstm mamba mla')
parser.add_argument('--channels', type=int, default=16)
parser.add_argument('--side-enc', dest='side_enc', type=str, default='lstm', help='encoder for side features')
parser.add_argument('--fusion', type=str, default='att', help='mlp att')
parser.add_argument('--glob-feat', dest='glob_feat', action='store_true', default=False, help='use global features')
parser.add_argument('--non-siamese', dest='non_siamese', action='store_true', default=False, help='use non-siamese architecture')
# task & dataset setting
parser.add_argument('--task', type=str, default='cls', help='reg or cls')
parser.add_argument('--one-way', dest='one_way', action='store_true', default=False, help='use one-way constructed dataset')
parser.add_argument('--max-length', dest='max_length', type=int, default=30, help='max length for sequence filtering')
parser.add_argument('--resize', type=int, default=[768], nargs='+', help='resize the image')
parser.add_argument('--split', type=int, default=5, help='k in k-fold cross validation (default: 5)')
parser.add_argument('--seed', type=int, default=1, help='seed for model initialization (default: 1)')
parser.add_argument('--pcs', action='store_true', default=False, help='consider protease cut site')
parser.add_argument('--mix-pcs', dest='mix_pcs', action='store_true', default=False, help='consider protease cut site in mixed mode (sets pcs to "mix")')
# training setting
parser.add_argument('--gpu', type=int, default=0, help='GPU index to use, -1 for CPU (default: 0)')
parser.add_argument('--batch-size', dest='batch_size', type=int, default=32, help='input batch size for training (default: 32)')
parser.add_argument('--epochs', type=int, default=50, help='number of epochs to train (default: 50)')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate (default: 0.001)')
parser.add_argument('--decay', type=float, default=0.0005, help='weight decay (default: 0.0005)')
parser.add_argument('--pretrain', dest='pretrain', type=str, default='', help='path of the pretrained model')
parser.add_argument('--metric-avg', dest='metric_avg', type=str, default='macro', help='metric average type')
parser.add_argument('--loss', type=str, default='ce', help='loss function')
parser.add_argument('--dir', action='store_true', default=False, help='use DIR')
parser.add_argument('--simple', dest='simple', action='store_true', default=False)
parser.add_argument('--llm-data', dest='llm_data', action='store_true', default=False)
# case-study specific
parser.add_argument('--case', type=str, default='r2', help='case to infer')
parser.add_argument('--use-ft', dest='use_ft', action='store_true', default=False, help='load fine-tuned checkpoints (model_i_ft.pth)')
args = parser.parse_args()

# Normalize interdependent flags: --llm-data implies --simple, which implies
# --one-way; --mix-pcs overrides --pcs with the string 'mix'.
if args.llm_data:
    args.simple = True
if args.simple:
    args.one_way = True
if args.mix_pcs:
    args.pcs = 'mix'

if args.gpu != -1:
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision('high')
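
# Inference runs a cross-validation ensemble: load_model() below restores one
# checkpoint per fold, main() collects every fold's logits batch by batch, and
# the fold predictions are averaged once at the very end.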
def load_model(args, weight_path, device):
    model = DMutaPeptideCNN(q_encoder=args.q_encoder, classes=args.classes,
                            channels=args.channels, dir=args.dir, gf=args.glob_feat,
                            side_enc=args.side_enc, fusion=args.fusion,
                            non_siamese=args.non_siamese).to(device).eval()
    model.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
    model.compile()
    return model


def main():
    set_seed(args.seed)
    if args.task == 'reg':
        args.classes = 1
    elif args.task == 'cls':
        args.classes = 2
    else:
        raise NotImplementedError(f'unimplemented task: {args.task}')

    # Reconstruct the run directory name from the training configuration so the
    # matching fold checkpoints can be located.
    weight_dir = (
        f'./run-{args.task}/{args.q_encoder}'
        f'{"-non-siamese" if args.non_siamese else ""}'
        f'-{args.fusion}-{args.channels}'
        f'{f"-{args.side_enc}" if args.side_enc else ""}'
        f'{"-mixpcs" if args.mix_pcs else ""}'
        f'{"-pcs" if args.pcs is True else ""}'  # pcs may be the string 'mix', which must not trigger "-pcs"
        f'{"-simple" if args.simple else ""}'
        f'{"-llm" if args.llm_data else ""}'
        f'{"-" + "x".join(str(n) for n in args.resize) if args.resize else ""}'
        f'{"-gf" if args.glob_feat else ""}'
        f'{"-oneway" if args.one_way else ""}'
        f'-{args.loss + "-dir" if args.dir else args.loss}'
        f'-{args.batch_size}-{args.lr}-{args.epochs}'
    )
    device = torch.device('cpu' if args.gpu == -1 or not torch.cuda.is_available()
                          else f'cuda:{args.gpu}')
    print(weight_dir)
    print(device)

    test_set = PeptidePairPicCaseDataset(case=args.case, pad_length=args.max_length,
                                         side_enc=args.side_enc, pcs=True,
                                         resize=args.resize, gf=args.glob_feat)
    test_loader = DataLoader(test_set, batch_size=128, shuffle=False,
                             num_workers=16, pin_memory=True)

    # Load one checkpoint per cross-validation fold.
    models = [load_model(args, f'{weight_dir}/model_{i}{"_ft" if args.use_ft else ""}.pth', device)
              for i in range(args.split)]

    all_seqs = []
    logits_batches = []  # per-batch [m, B, classes] logits from the m fold models, kept on CPU
    start_time = time.time()
    with torch.no_grad():
        for x, gt in test_loader:
            # x: [B, ...] in pinned CPU memory; gt: tuple of B sequence strings
            x = move_to_device(x, device, non_blocking=True)
            # 1) collect the logits of every fold model for this batch
            logits = torch.zeros(len(models), len(gt), args.classes, device=device)  # [m, B, classes]
            for i, m in enumerate(models):
                logits[i] = m(x)
            # 2) move to CPU right away so logits do not accumulate on the GPU
            logits_batches.append(logits.cpu())
            all_seqs.extend(gt)

    # Concatenate along the batch dimension into [m, n, classes], n = number of samples.
    all_logits = torch.cat(logits_batches, dim=1)
    if args.task == 'reg':
        preds = all_logits.mean(0).squeeze().tolist()
    elif args.task == 'cls':
        # One softmax at the end: take the positive-class probability, then average over folds.
        preds = torch.softmax(all_logits, dim=-1)[:, :, 1].mean(0).tolist()

    consumed_time = time.time() - start_time
    print(f'total consumed time: {consumed_time} s')
    print(f'time per sample: {consumed_time / len(test_set)} s')

    # Save predictions to CSV.
    df = pd.DataFrame({
        'seq': all_seqs,
        'pred': preds,
    })
    df.to_csv(f'{weight_dir}/preds_case.csv', index=False)


if __name__ == '__main__':
    main()
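
# Example invocation (a sketch: the script filename is an assumption, and the
# weight directory derived from these flags must already contain
# model_0.pth ... model_{split-1}.pth from the cross-validation training run):
#   python infer_case.py --task cls --q-encoder cnn --channels 16 --fusion att --case r2 --gpu 0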