---
tags:
- protein language model
datasets:
- IEDB
base_model:
- facebook/esm2_t12_35M_UR50D
pipeline_tag: text-classification
license: mit
language:
- en
---

# TriStageHLA-PRE

A minimal Hugging Face-compatible PyTorch model for peptide presentation prediction, built on ESM with optional LoRA and cross-attention.
There is no custom predict API; inference follows the training path: tokenize the peptide and HLA pseudosequence with the ESM tokenizer, pad or truncate to fixed lengths (default peptide=16, HLA=36), run a forward pass as `logits, features = model(epitope_ids, hla_ids)`, then apply softmax to obtain the binding probability. A paired-input sketch is shown below, after the single-peptide example.

## Quick Start

Requirements:
- Python >= 3.8
- torch >= 2.0
- transformers >= 4.40
- peft (only if you use LoRA/PEFT adapters)

## Install
```bash
pip install torch transformers peft
```
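
A quick, optional way to confirm the installed versions match the requirements above (`peft` is only needed if you use LoRA/PEFT adapters):
```python
import torch
import transformers

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)

# peft is optional; only required for LoRA/PEFT adapters
try:
    import peft
    print("peft:", peft.__version__)
except ImportError:
    print("peft: not installed (only needed for LoRA/PEFT adapters)")
```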
## Usage (Transformers)
```python
import torch
from transformers import AutoModel

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model (custom architecture, so trust_remote_code is required)
model_id = "SkywalkerLu/TriStageHLA-PRE"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
```
## How to use TriStageHLA-PRE
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "SkywalkerLu/TriStageHLA-PRE"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()

# Load tokenizer used in training (ESM2 650M)
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
peptide = "GILGFVFTL"

# Fixed peptide length used in training
PEP_LEN = 16
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1

def pad_to_len(ids_list, target_len, pad_id):
    """Pad with pad_id or truncate to the fixed length."""
    if len(ids_list) < target_len:
        return ids_list + [pad_id] * (target_len - len(ids_list))
    return ids_list[:target_len]

# Tokenize and pad the peptide, then build a [1, PEP_LEN] tensor
pep_ids = tok(peptide, add_special_tokens=True)["input_ids"]
pep_ids = pad_to_len(pep_ids, PEP_LEN, PAD_ID)
pep_tensor = torch.tensor([pep_ids], dtype=torch.long, device=device)

# Forward pass: logits has shape [1, 2]; class 1 is the binding class
with torch.no_grad():
    logits, features = model(pep_tensor)
    prob_bind = F.softmax(logits, dim=1)[0, 1].item()
    pred = int(prob_bind >= 0.5)

print({"peptide": peptide, "pre_prob": round(prob_bind, 6), "label": pred})
```
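
## Peptide + HLA Pseudosequence (Sketch)

The introduction describes the forward pass as `logits, features = model(epitope_ids, hla_ids)`, with the HLA pseudosequence padded to 36 tokens, while the examples in this card pass only the peptide. The sketch below shows how the paired call would look under the same tokenization and padding conventions; it assumes the model's forward accepts a second `hla_ids` tensor exactly as stated in the introduction, and the pseudosequence shown is only an illustrative placeholder.

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "SkywalkerLu/TriStageHLA-PRE"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Fixed lengths from the description above (must match training)
PEP_LEN, HLA_LEN = 16, 36
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1

def pad_to_len(ids_list, target_len, pad_id):
    if len(ids_list) < target_len:
        return ids_list + [pad_id] * (target_len - len(ids_list))
    return ids_list[:target_len]

peptide = "GILGFVFTL"
# Placeholder 34-residue HLA pseudosequence; replace with the pseudosequence of your allele
hla_pseudoseq = "YFAMYQENMAHTDANTLYIIYRDYTWVARVYRGY"

pep_ids = pad_to_len(tok(peptide, add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID)
hla_ids = pad_to_len(tok(hla_pseudoseq, add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID)

epitope_ids = torch.tensor([pep_ids], dtype=torch.long, device=device)
hla_tensor = torch.tensor([hla_ids], dtype=torch.long, device=device)

with torch.no_grad():
    # Paired forward pass as described in the introduction
    logits, features = model(epitope_ids, hla_tensor)
    prob_bind = F.softmax(logits, dim=1)[0, 1].item()

print({"peptide": peptide, "pre_prob": round(prob_bind, 6)})
```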


## Batch Inference (Python)

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the PRE model (note: make sure the model actually returns logits)
model_id = "SkywalkerLu/TriStageHLA-PRE"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()

# Load the same ESM2 tokenizer used in training
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Fixed length (must match training)
PEP_LEN = 16
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1

def pad_to_len(ids_list, target_len, pad_id):
    if len(ids_list) < target_len:
        return ids_list + [pad_id] * (target_len - len(ids_list))
    return ids_list[:target_len]

# Example batch (replace with your own list of peptide sequences)
batch = [
    {"peptide": "GILGFVFTL"},
    {"peptide": "NLVPMVATV"},
    {"peptide": "SIINFEKL"},
    {"peptide": "GLCTLVAML"},
]

# Tokenize and pad the whole batch
pep_ids_batch = []
for item in batch:
    pep = item["peptide"]
    ids = tok(pep, add_special_tokens=True)["input_ids"]
    ids = pad_to_len(ids, PEP_LEN, PAD_ID)
    pep_ids_batch.append(ids)

# Convert to a tensor of shape [B, PEP_LEN]
pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device)

# Forward pass
with torch.no_grad():
    logits, features = model(pep_tensor)   # expected logits shape: [B, 2]
    probs = F.softmax(logits, dim=1)[:, 1] # probability of class 1 (bind)

# Binary labels (0/1)
labels = (probs >= 0.5).long().tolist()

# Print results
for i, item in enumerate(batch):
    print({
        "peptide": item["peptide"],
        "pre_prob": float(probs[i].item()),
        "label": labels[i]
    })
```