#!/usr/bin/env python3
r"""
inference_ppiDCE.py

Inference script for the ppiDCE cross-encoder PPI classifier.

Reads sequence pairs from the first two columns of the input CSV and writes
the same rows to the output CSV with `pred_label`, `prob_0` and `prob_1`
columns appended.

Usage:
    python inference_ppiDCE.py \
        --model_path path/to/ppiDCE_final.pth \
        --model_config facebook/esm1b_t33_650M_UR50S \
        --input_file test.csv \
        --output_file preds.csv \
        --batch_size 4 \
        --max_length 1024 \
        --device cuda

Example:
    python inference_ppiDCE.py \
        --model_path out/ppiDCE_final.pth \
        --model_config facebook/esm1b_t33_650M_UR50S \
        --input_file test_pairs.csv \
        --output_file predictions.csv \
        --batch_size 4 --max_length 1024 --device cuda
"""
import argparse
import os
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import EsmConfig, EsmTokenizer, EsmModel
from tqdm import tqdm
class PPICrossDataset(Dataset):
    """Dataset of sequence pairs read from a CSV file.

    The first two columns of the CSV are taken, by position, as the two
    sequences of a pair. Each item is the pair tokenized jointly in a single
    tokenizer call, truncated/padded to ``max_length``.
    """

    def __init__(self, csv_file, tokenizer, max_length):
        """Load ``csv_file`` and remember the tokenizer settings.

        Args:
            csv_file: Path (or anything ``pandas.read_csv`` accepts) of the
                CSV holding the sequence pairs.
            tokenizer: Callable invoked as ``tokenizer(seq1, seq2, ...)``;
                its result must expose ``input_ids`` and ``attention_mask``.
            max_length: Length every encoded pair is truncated/padded to.
        """
        self.df = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows (pairs) in the CSV."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Tokenize the pair stored in row ``idx``.

        Returns:
            dict with keys ``'input_ids'`` and ``'attention_mask'``; each
            value is the tokenizer output with its leading dimension
            squeezed away.
        """
        first_seq = self.df.iloc[idx, 0]
        second_seq = self.df.iloc[idx, 1]

        tokenizer_options = {
            'truncation': True,
            'padding': 'max_length',
            'max_length': self.max_length,
            'return_tensors': 'pt',
        }
        encoded = self.tokenizer(first_seq, second_seq, **tokenizer_options)

        # The tokenizer is called on a single pair, so dimension 0 is
        # squeezed out here; the DataLoader re-adds a batch dimension.
        return {
            field: getattr(encoded, field).squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
class ppiDCE(nn.Module):
    """Cross-encoder classifier: ESM encoder, dropout, then a linear head.

    The submodule attribute names (``esm``, ``dropout``, ``classifier``)
    determine the keys of the module's state dict, so they must stay in sync
    with any checkpoint that is loaded into this model.
    """

    def __init__(self, config, num_labels=2):
        """Build the encoder and the classification head.

        Args:
            config: Configuration object passed to ``EsmModel``; its
                ``hidden_size`` sets the input width of the linear head.
            num_labels: Number of output logits per example.
        """
        super().__init__()
        self.esm = EsmModel(config)
        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(
            in_features=config.hidden_size,
            out_features=num_labels,
        )

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of tokenized pairs.

        The hidden state at sequence position 0 is used as the summary
        vector of each example, passed through dropout and the linear head.
        """
        encoder_out = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        first_token_state = encoder_out.last_hidden_state[:, 0]
        return self.classifier(self.dropout(first_token_state))
def get_device(device_str):
    """Parse a ``--device`` string into a primary device and optional GPU ids.

    Accepted forms (case-insensitive, surrounding whitespace ignored):

    * ``'cpu'``
    * a single device such as ``'cuda'`` or ``'cuda:1'``; a bare index such
      as ``'1'`` is read as ``'cuda:1'``
    * a comma-separated list such as ``'cuda:0,cuda:1'`` or ``'0,1'``;
      empty entries (e.g. a trailing comma) are ignored

    Args:
        device_str: Device specification from the command line.

    Returns:
        ``(device, device_ids)`` where ``device`` is the ``torch.device`` the
        model and inputs are placed on, and ``device_ids`` is the list of
        integer device indices for a multi-device specification, or ``None``
        when a single device was given.

    Raises:
        ValueError: If a comma-separated specification contains no entries
            or an entry without an integer device index.
    """
    spec = device_str.strip().lower()
    if spec == 'cpu':
        return torch.device('cpu'), None

    if ',' not in spec:
        # A bare number is a device index, not a valid device string.
        if spec.isdigit():
            spec = f'cuda:{spec}'
        return torch.device(spec), None

    entries = [part.strip() for part in spec.split(',')]
    entries = [entry for entry in entries if entry]
    if not entries:
        raise ValueError(f"No devices found in device list {device_str!r}")

    ids = []
    for entry in entries:
        # 'cuda:1' -> '1'; a bare '1' has no separator and is kept whole.
        index_text = entry.rpartition(':')[2]
        try:
            ids.append(int(index_text))
        except ValueError:
            raise ValueError(
                f"Invalid entry {entry!r} in device list {device_str!r}: "
                "expected 'cuda:<index>' or '<index>'"
            ) from None

    # The first entry is the primary device; default to CUDA when the list
    # holds bare indices only.
    device_type = entries[0].rpartition(':')[0] or 'cuda'
    return torch.device(f'{device_type}:{ids[0]}'), ids
def _unwrap_state_dict(ckpt):
    """Return the parameter mapping stored inside a loaded checkpoint."""
    if isinstance(ckpt, dict):
        # Training scripts often nest the weights under one of these keys.
        for wrapper_key in ('state_dict', 'model_state_dict'):
            nested = ckpt.get(wrapper_key)
            if isinstance(nested, dict):
                return nested
    return ckpt


def _load_matching_weights(model, ckpt):
    """Load checkpoint tensors whose name and shape match ``model``.

    Returns ``(skipped, missing)``: checkpoint entries that were not used and
    model tensors that were not found in the checkpoint. Raises
    ``RuntimeError`` when no tensor at all matches.
    """
    source = _unwrap_state_dict(ckpt)
    target = model.state_dict()

    matched, skipped = {}, []
    for name, value in source.items():
        key = name
        # Checkpoints saved from an nn.DataParallel wrapper prefix every
        # key with 'module.'.
        if key not in target and key.startswith('module.'):
            key = key[len('module.'):]
        if (
            key in target
            and hasattr(value, 'size')
            and value.size() == target[key].size()
        ):
            matched[key] = value
        else:
            skipped.append(name)

    if not matched:
        # Continuing would run inference on randomly initialised weights.
        raise RuntimeError(
            "No tensor in the checkpoint matches the model by name and shape; "
            "refusing to run inference with uninitialised weights"
        )

    missing = [key for key in target if key not in matched]
    target.update(matched)
    model.load_state_dict(target)
    return skipped, missing


def main():
    """Command-line entry point: run ppiDCE inference over a CSV of pairs.

    Parses the CLI arguments, builds the tokenizer, dataset and model, loads
    the checkpoint (keeping only tensors that match the model by name and
    shape), predicts every row, and writes the input rows plus
    ``pred_label``, ``prob_0`` and ``prob_1`` columns to ``--output_file``.

    Raises:
        RuntimeError: If the checkpoint shares no tensor with the model.
    """
    parser = argparse.ArgumentParser(description='Inference with ppiDCE model')
    parser.add_argument('--model_path', required=True, help='Path to ppiDCE checkpoint (.pth)')
    parser.add_argument('--model_config', required=True, help='ESM model name or local path')
    parser.add_argument('--input_file', required=True, help='CSV file with seq1, seq2')
    parser.add_argument('--output_file', required=True, help='CSV to save predictions')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--max_length', type=int, default=1024)
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    # device
    device, device_ids = get_device(args.device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        # The default is 'cuda'; do not crash on CPU-only machines.
        print("Warning: CUDA is not available; falling back to CPU")
        device, device_ids = torch.device('cpu'), None

    # tokenizer + config
    tokenizer = EsmTokenizer.from_pretrained(args.model_config)
    config = EsmConfig.from_pretrained(args.model_config)

    # dataset + loader (shuffle=False keeps predictions aligned with rows)
    ds = PPICrossDataset(args.input_file, tokenizer, args.max_length)
    loader = DataLoader(ds, batch_size=args.batch_size, shuffle=False)

    # model init
    model = ppiDCE(config, num_labels=2)

    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code; only load checkpoints from a trusted source (consider
    # weights_only=True where the installed torch version supports it).
    ckpt = torch.load(args.model_path, map_location='cpu')
    skipped, missing = _load_matching_weights(model, ckpt)
    if skipped:
        print(f"Warning: ignored {len(skipped)} checkpoint entries with no matching model tensor")
    if missing:
        print(f"Warning: {len(missing)} model tensors were not in the checkpoint and keep their initial values")

    if device_ids:
        model = nn.DataParallel(model, device_ids=device_ids)
    model.to(device)
    model.eval()

    preds, probs = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc='Infer'):
            input_ids = batch['input_ids'].to(device)
            attn = batch['attention_mask'].to(device)
            logits = model(input_ids, attn)
            p = torch.softmax(logits, dim=1)
            preds.extend(p.argmax(dim=1).cpu().tolist())
            probs.extend(p.cpu().tolist())

    # assemble output from the rows the dataset actually iterated over,
    # rather than re-reading the CSV from disk
    df = ds.df.copy()
    df['pred_label'] = preds
    df['prob_0'] = [p[0] for p in probs]
    df['prob_1'] = [p[1] for p in probs]
    os.makedirs(os.path.dirname(args.output_file) or '.', exist_ok=True)
    df.to_csv(args.output_file, index=False)
    print(f"Saved inference results to {args.output_file}")
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()