File size: 4,854 Bytes
2ec04c3 fd13523 6d2c29b fc1f145 2ec04c3 b429e09 f0e6d01 f1d4969 f0e6d01 b429e09 2ec04c3 a18517c aed7b7f 2ec04c3 f1d4969 b429e09 f0e6d01 7285f3b f0e6d01 b429e09 f0e6d01 b429e09 6a0b0b6 aed7b7f 6a0b0b6 6aeea5b 6a0b0b6 fd13523 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
---
tags:
- protein language model
datasets:
- IEDB
base_model:
- facebook/esm2_t12_35M_UR50D
pipeline_tag: text-classification
license: mit
language:
- en
---
# TransHLA2.0-BIND
A minimal Hugging Face-compatible PyTorch model for peptide–HLA binding classification using ESM with optional LoRA and cross-attention. There is no custom predict API; inference follows the training path: tokenize peptide and HLA pseudosequence with the ESM tokenizer, pad or truncate to fixed lengths (default peptide=16, HLA=36), run a forward pass as `logits, features = model(epitope_ids, hla_ids)`, then apply softmax to get the binding probability.
## Quick Start
Requirements:
- Python >= 3.8
- torch >= 2.0
- transformers >= 4.40
- peft (only if you use LoRA/PEFT adapters)
## Installation
```bash
pip install torch transformers peft
```
## Usage (Transformers)
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Pick the device before loading: the .to(device) call below needs it defined.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "SkywalkerLu/TransHLA2.0-BIND"
# trust_remote_code=True is required because the architecture lives in the model repo.
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
```
## How to use TransHLA2.0-BIND
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Device: prefer GPU when available; tensors below are created directly on it.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model (replace with your model id if different).
# trust_remote_code=True pulls the custom architecture code from the model repo.
model_id = "SkywalkerLu/TransHLA2.0-BIND"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
# Load tokenizer used in training (ESM2 650M).
# NOTE(review): the card's base_model metadata lists facebook/esm2_t12_35M_UR50D;
# ESM2 checkpoints share one vocabulary, but confirm which tokenizer matches training.
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
# Example inputs
peptide = "GILGFVFTL" # 9-mer example
# Fake placeholder pseudosequence for demo; replace with a real one from your mapping/data
hla_pseudoseq = (
"YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"
)
# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
# ESM2's pad token id is 1; used as a fallback when the tokenizer has no pad token set.
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
def pad_to_len(ids_list, target_len, pad_id):
    """Return ids_list right-padded with pad_id (or truncated) to exactly target_len."""
    # Over-pad unconditionally, then slice: covers both the short and long cases.
    padded = ids_list + [pad_id] * target_len
    return padded[:target_len]
# Encode both sequences with the ESM tokenizer, then force the fixed lengths
# the model was trained with (pad or truncate in one step).
encoded_pep = pad_to_len(tok(peptide, add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID)
encoded_hla = pad_to_len(tok(hla_pseudoseq, add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID)
# Single-example (batch=1) tensors on the chosen device.
pep_tensor = torch.tensor([encoded_pep], dtype=torch.long, device=device)
hla_tensor = torch.tensor([encoded_hla], dtype=torch.long, device=device)
# Forward pass; the class-1 softmax output is the binding probability.
with torch.no_grad():
    logits, features = model(pep_tensor, hla_tensor)
    prob_bind = F.softmax(logits, dim=1)[0, 1].item()
pred = int(prob_bind >= 0.5)
print({"peptide": peptide, "bind_prob": round(prob_bind, 6), "label": pred})
```
## Batch Inference (Python)
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Device: prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model and tokenizer; trust_remote_code=True pulls the architecture from the repo.
model_id = "SkywalkerLu/TransHLA2.0-BIND" # replace with your model id if different
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
# NOTE(review): base_model metadata says facebook/esm2_t12_35M_UR50D — ESM2 tokenizers
# share a vocabulary, but confirm this 650M tokenizer matches the one used in training.
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
# ESM2's pad token id is 1; fallback in case the tokenizer has no pad token set.
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
def pad_to_len(ids_list, target_len, pad_id):
    """Clip ids_list to target_len, then fill any remaining slots with pad_id."""
    clipped = ids_list[:target_len]
    return clipped + [pad_id] * (target_len - len(clipped))
# Example batch (use real HLA pseudosequences in your data)
batch = [
    {"peptide": "GILGFVFTL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "NLVPMVATV", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "SIINFEKL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
]
# Tokenize every pair and force the fixed training lengths in one pass.
pep_ids_batch = [
    pad_to_len(tok(entry["peptide"], add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID)
    for entry in batch
]
hla_ids_batch = [
    pad_to_len(tok(entry["hla_pseudo"], add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID)
    for entry in batch
]
# Stack into batch tensors on the chosen device.
pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device)  # [B, PEP_LEN]
hla_tensor = torch.tensor(hla_ids_batch, dtype=torch.long, device=device)  # [B, HLA_LEN]
# Forward pass; column 1 of the softmax is the binding probability.
with torch.no_grad():
    logits, _ = model(pep_tensor, hla_tensor)  # logits shape: [B, 2]
    probs = F.softmax(logits, dim=1)[:, 1]  # binding probability for class-1
# Threshold to labels (0/1)
labels = (probs >= 0.5).long().tolist()
# Print one result dict per input
for entry, prob, label in zip(batch, probs, labels):
    print({
        "peptide": entry["peptide"],
        "bind_prob": float(prob.item()),
        "label": label
    })
``` |