---
tags:
- protein language model
datasets:
- IEDB
base_model:
- facebook/esm2_t12_35M_UR50D
pipeline_tag: text-classification
license: mit
language:
- en
---

# TransHLA2.0-BIND

A minimal Hugging Face-compatible PyTorch model for peptide–HLA binding classification using ESM with optional LoRA and cross-attention. There is no custom predict API; inference follows the training path: tokenize peptide and HLA pseudosequence with the ESM tokenizer, pad or truncate to fixed lengths (default peptide=16, HLA=36), run a forward pass as `logits, features = model(epitope_ids, hla_ids)`, then apply softmax to get the binding probability.

## Quick Start

Requirements:
- Python >= 3.8
- torch >= 2.0
- transformers >= 4.40
- peft (only if you use LoRA/PEFT adapters; see the optional sketch after the install command)

## Install
```bash
pip install torch transformers peft
```
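
The `peft` package is only needed if you layer a LoRA/PEFT adapter on top of the base model. A minimal sketch of how that could look, assuming you have such an adapter; the adapter id below is a hypothetical placeholder, not a checkpoint published with this model:

```python
from transformers import AutoModel
from peft import PeftModel

# Base model (custom code from the repository)
base = AutoModel.from_pretrained("SkywalkerLu/TransHLA2.0-BIND", trust_remote_code=True)

# Hypothetical LoRA adapter checkpoint -- replace with your own fine-tuned adapter, if any
model = PeftModel.from_pretrained(base, "your-username/TransHLA2.0-BIND-lora")
model = model.merge_and_unload()  # optionally fold the adapter weights back in for plain inference
model.eval()
```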
## Usage (Transformers)
```python
import torch
from transformers import AutoModel

# Pick a device and load the model (trust_remote_code pulls in the custom model class)
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "SkywalkerLu/TransHLA2.0-BIND"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
```
## How to use TransHLA2.0-BIND
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model (replace with your model id if different)
model_id = "SkywalkerLu/TransHLA2.0-BIND"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()

# Load tokenizer used in training (ESM2 650M)
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Example inputs
peptide = "GILGFVFTL"  # 9-mer example
# Demo placeholder pseudosequence; replace with a real one from your allele mapping/data
# (see the lookup sketch after this example)
hla_pseudoseq = "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"

# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1

def pad_to_len(ids_list, target_len, pad_id):
    # Right-pad with pad_id (or truncate) so the sequence is exactly target_len tokens
    if len(ids_list) < target_len:
        return ids_list + [pad_id] * (target_len - len(ids_list))
    return ids_list[:target_len]

# Tokenize
pep_ids = tok(peptide, add_special_tokens=True)["input_ids"]
hla_ids = tok(hla_pseudoseq, add_special_tokens=True)["input_ids"]

# Pad/truncate
pep_ids = pad_to_len(pep_ids, PEP_LEN, PAD_ID)
hla_ids = pad_to_len(hla_ids, HLA_LEN, PAD_ID)

# Tensors (batch=1)
pep_tensor = torch.tensor([pep_ids], dtype=torch.long, device=device)
hla_tensor = torch.tensor([hla_ids], dtype=torch.long, device=device)

# Forward + probability
with torch.no_grad():
    logits, features = model(pep_tensor, hla_tensor)
    prob_bind = F.softmax(logits, dim=1)[0, 1].item()
    pred = int(prob_bind >= 0.5)

print({"peptide": peptide, "bind_prob": round(prob_bind, 6), "label": pred})
```
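
The pseudosequence in the example above is only a demo placeholder. In practice you would keep a lookup from HLA allele names to their pseudosequences (for example, 34-residue NetMHCpan-style pseudosequences) and feed the looked-up string through the same tokenize/pad pipeline. A minimal sketch, reusing `tok`, `pad_to_len`, `PEP_LEN`, `HLA_LEN`, `PAD_ID`, `model`, and `device` from the example above; the allele keys and pseudosequence values below are hypothetical placeholders:

```python
# Hypothetical allele -> pseudosequence lookup; build it from your own reference data.
# The values here are the same demo placeholder as above, NOT real pseudosequences.
HLA_PSEUDO = {
    "HLA-A*02:01": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY",
    "HLA-B*07:02": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY",
}

def encode_pair(peptide, allele):
    # Tokenize and pad both inputs exactly as in the example above
    pep = pad_to_len(tok(peptide, add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID)
    hla = pad_to_len(tok(HLA_PSEUDO[allele], add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID)
    return (
        torch.tensor([pep], dtype=torch.long, device=device),
        torch.tensor([hla], dtype=torch.long, device=device),
    )

pep_tensor, hla_tensor = encode_pair("GILGFVFTL", "HLA-A*02:01")
with torch.no_grad():
    logits, _ = model(pep_tensor, hla_tensor)
    print(F.softmax(logits, dim=1)[0, 1].item())
```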


## Batch Inference (Python)

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer
model_id = "SkywalkerLu/TransHLA2.0-BIND"  # replace with your model id if different
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1

def pad_to_len(ids_list, target_len, pad_id):
    # Right-pad with pad_id (or truncate) so the sequence is exactly target_len tokens
    if len(ids_list) < target_len:
        return ids_list + [pad_id] * (target_len - len(ids_list))
    return ids_list[:target_len]

# Example batch (use real HLA pseudosequences in your data)
batch = [
    {"peptide": "GILGFVFTL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "NLVPMVATV", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "SIINFEKL",  "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
]

# Tokenize and pad/truncate
pep_ids_batch, hla_ids_batch = [], []
for item in batch:
    pep_ids = tok(item["peptide"], add_special_tokens=True)["input_ids"]
    hla_ids = tok(item["hla_pseudo"], add_special_tokens=True)["input_ids"]
    pep_ids_batch.append(pad_to_len(pep_ids, PEP_LEN, PAD_ID))
    hla_ids_batch.append(pad_to_len(hla_ids, HLA_LEN, PAD_ID))

# To tensors
pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device)  # [B, PEP_LEN]
hla_tensor = torch.tensor(hla_ids_batch, dtype=torch.long, device=device)  # [B, HLA_LEN]

# Forward
with torch.no_grad():
    logits, _ = model(pep_tensor, hla_tensor)   # logits shape: [B, 2]
    probs = F.softmax(logits, dim=1)[:, 1]      # binding probability for class-1

# Threshold to labels (0/1)
labels = (probs >= 0.5).long().tolist()

# Print results
for i, item in enumerate(batch):
    print({
        "peptide": item["peptide"],
        "bind_prob": float(probs[i].item()),
        "label": labels[i]
    })
```
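
For longer peptide lists, the same logic can run in fixed-size mini-batches so memory stays bounded. A minimal sketch reusing `model`, `tok`, `pad_to_len`, `PEP_LEN`, `HLA_LEN`, `PAD_ID`, `device`, and `batch` from the script above; the batch size is an arbitrary choice:

```python
def predict_pairs(pairs, batch_size=64):
    """Score (peptide, hla_pseudo) pairs in mini-batches and return class-1 (binding) probabilities."""
    probs_out = []
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start:start + batch_size]
        # Tokenize and pad each side exactly as above
        pep = [pad_to_len(tok(p, add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID) for p, _ in chunk]
        hla = [pad_to_len(tok(h, add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID) for _, h in chunk]
        pep_t = torch.tensor(pep, dtype=torch.long, device=device)
        hla_t = torch.tensor(hla, dtype=torch.long, device=device)
        with torch.no_grad():
            logits, _ = model(pep_t, hla_t)
            probs_out.extend(F.softmax(logits, dim=1)[:, 1].tolist())
    return probs_out

pairs = [(item["peptide"], item["hla_pseudo"]) for item in batch]
print(predict_pairs(pairs))
```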