---
tags:
- protein language model
datasets:
- IEDB
base_model:
- facebook/esm2_t12_35M_UR50D
pipeline_tag: text-classification
license: mit
language:
- en
---
# TriStageHLA-PRE
A minimal Hugging Face-compatible PyTorch model for peptide presentation prediction, built on ESM with optional LoRA and cross-attention.
There is no custom predict API; inference follows the training path: tokenize the peptide and the HLA pseudosequence with the ESM tokenizer, pad or truncate to fixed lengths (default peptide=16, HLA=36), run a forward pass as `logits, features = model(epitope_ids, hla_ids)`, then apply softmax to the logits to obtain the binding probability.
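The fixed-length padding step above can be sketched in isolation. The helper below mirrors the `pad_to_len` logic used in the examples; the token ids are purely illustrative stand-ins, not real ESM vocabulary output:

```python
def pad_to_len(ids, target_len, pad_id=1):
    """Pad a token-id list with pad_id up to target_len, or truncate if longer."""
    if len(ids) < target_len:
        return ids + [pad_id] * (target_len - len(ids))
    return ids[:target_len]

# Illustrative ids only: a 9-mer peptide tokenized to 11 ids (BOS + 9 residues + EOS)
pep_ids = [0, 6, 12, 4, 6, 18, 7, 18, 11, 4, 2]
# Illustrative stand-in for a tokenized HLA pseudosequence
hla_ids = list(range(30))

print(len(pad_to_len(pep_ids, 16)))  # padded to the peptide length, 16
print(len(pad_to_len(hla_ids, 36)))  # padded to the HLA length, 36
```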
## Quick Start
Requirements:
- Python >= 3.8
- torch >= 2.0
- transformers >= 4.40
- peft (only if you use LoRA/PEFT adapters)
Install:
```bash
pip install torch transformers peft
```
## Usage (Transformers)
```python
import torch
from transformers import AutoModel

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "SkywalkerLu/TriStageHLA-PRE"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
```
## How to use TriStageHLA-PRE
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "SkywalkerLu/TriStageHLA-PRE"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()

# Load the tokenizer used in training (ESM2 650M)
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

peptide = "GILGFVFTL"
PEP_LEN = 16
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1

def pad_to_len(ids_list, target_len, pad_id):
    # Pad with pad_id up to target_len, or truncate if longer.
    if len(ids_list) < target_len:
        return ids_list + [pad_id] * (target_len - len(ids_list))
    return ids_list[:target_len]

pep_ids = tok(peptide, add_special_tokens=True)["input_ids"]
pep_ids = pad_to_len(pep_ids, PEP_LEN, PAD_ID)
pep_tensor = torch.tensor([pep_ids], dtype=torch.long, device=device)

with torch.no_grad():
    logits, features = model(pep_tensor)
    prob_bind = F.softmax(logits, dim=1)[0, 1].item()

pred = int(prob_bind >= 0.5)
print({"peptide": peptide, "pre_prob": round(prob_bind, 6), "label": pred})
```
## Batch Inference (Python)
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the PRE model (note: make sure the model actually returns logits)
model_id = "SkywalkerLu/TriStageHLA-PRE"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()

# Load the same ESM2 tokenizer used in training
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Fixed length (must match training)
PEP_LEN = 16
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1

def pad_to_len(ids_list, target_len, pad_id):
    if len(ids_list) < target_len:
        return ids_list + [pad_id] * (target_len - len(ids_list))
    return ids_list[:target_len]

# Example batch (replace with your real peptide sequences)
batch = [
    {"peptide": "GILGFVFTL"},
    {"peptide": "NLVPMVATV"},
    {"peptide": "SIINFEKL"},
    {"peptide": "GLCTLVAML"},
]

# Tokenize and pad the batch
pep_ids_batch = []
for item in batch:
    pep = item["peptide"]
    ids = tok(pep, add_special_tokens=True)["input_ids"]
    ids = pad_to_len(ids, PEP_LEN, PAD_ID)
    pep_ids_batch.append(ids)

# Convert to a tensor of shape [B, PEP_LEN]
pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device)

# Forward pass
with torch.no_grad():
    logits, features = model(pep_tensor)   # expected logits shape: [B, 2]
    probs = F.softmax(logits, dim=1)[:, 1]  # probability of class 1 (bind)

# Binary labels (0/1)
labels = (probs >= 0.5).long().tolist()

# Print results
for i, item in enumerate(batch):
    print({
        "peptide": item["peptide"],
        "pre_prob": float(probs[i].item()),
        "label": labels[i],
    })
```