# TransHLA2.0-BIND
A minimal Hugging Face-compatible PyTorch model for peptide–HLA binding classification, built on ESM with optional LoRA and cross-attention. There is no custom predict API; inference follows the training path: tokenize the peptide and the HLA pseudosequence with the ESM tokenizer, pad or truncate them to fixed lengths (default peptide = 16, HLA = 36), run a forward pass with `logits, features = model(epitope_ids, hla_ids)`, and apply softmax to the logits to obtain the binding probability.
## Quick Start
Requirements:
- Python >= 3.8
- torch >= 2.0
- transformers >= 4.40
- peft (only if you use LoRA/PEFT adapters)
Install:

```bash
pip install torch transformers peft
```
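`peft` is only required if you attach a separately trained LoRA/PEFT adapter to the base checkpoint; otherwise you can skip it. The snippet below is a minimal sketch of loading such an adapter, assuming a hypothetical adapter repository id (`your-org/transhla-lora-adapter` is a placeholder):

```python
import torch
from transformers import AutoModel
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the base model, then attach the LoRA adapter weights on top of it.
base = AutoModel.from_pretrained("SkywalkerLu/TransHLA2.0-BIND", trust_remote_code=True)
model = PeftModel.from_pretrained(base, "your-org/transhla-lora-adapter")  # placeholder id
model = model.to(device).eval()
```

If you are not using an adapter, load the model directly as shown in the next section.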
## How to use TransHLA2.0-BIND
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model (replace with your model id if different)
model_id = "SkywalkerLu/TransHLA2.0-BIND"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
# Load tokenizer used in training (ESM2 650M)
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
# Example inputs
peptide = "GILGFVFTL" # 9-mer example
# Fake placeholder pseudosequence for demo; replace with a real one from your mapping/data
hla_pseudoseq = (
"YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"
)
# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
def pad_to_len(ids_list, target_len, pad_id):
    # Right-pad with pad_id up to target_len, or truncate if the sequence is longer.
    if len(ids_list) < target_len:
        return ids_list + [pad_id] * (target_len - len(ids_list))
    return ids_list[:target_len]
# Tokenize
pep_ids = tok(peptide, add_special_tokens=True)["input_ids"]
hla_ids = tok(hla_pseudoseq, add_special_tokens=True)["input_ids"]
# Pad/truncate
pep_ids = pad_to_len(pep_ids, PEP_LEN, PAD_ID)
hla_ids = pad_to_len(hla_ids, HLA_LEN, PAD_ID)
# Tensors (batch=1)
pep_tensor = torch.tensor([pep_ids], dtype=torch.long, device=device)
hla_tensor = torch.tensor([hla_ids], dtype=torch.long, device=device)
# Forward + probability
with torch.no_grad():
    logits, features = model(pep_tensor, hla_tensor)
    prob_bind = F.softmax(logits, dim=1)[0, 1].item()
pred = int(prob_bind >= 0.5)
print({"peptide": peptide, "bind_prob": round(prob_bind, 6), "label": pred})
## Batch Inference (Python)

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model and tokenizer
model_id = "SkywalkerLu/TransHLA2.0-BIND" # replace with your model id if different
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
def pad_to_len(ids_list, target_len, pad_id):
    # Right-pad with pad_id up to target_len, or truncate if the sequence is longer.
    if len(ids_list) < target_len:
        return ids_list + [pad_id] * (target_len - len(ids_list))
    return ids_list[:target_len]
# Example batch (use real HLA pseudosequences in your data)
batch = [
    {"peptide": "GILGFVFTL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "NLVPMVATV", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "SIINFEKL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
]
# Tokenize and pad/truncate
pep_ids_batch, hla_ids_batch = [], []
for item in batch:
    pep_ids = tok(item["peptide"], add_special_tokens=True)["input_ids"]
    hla_ids = tok(item["hla_pseudo"], add_special_tokens=True)["input_ids"]
    pep_ids_batch.append(pad_to_len(pep_ids, PEP_LEN, PAD_ID))
    hla_ids_batch.append(pad_to_len(hla_ids, HLA_LEN, PAD_ID))
# To tensors
pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device) # [B, PEP_LEN]
hla_tensor = torch.tensor(hla_ids_batch, dtype=torch.long, device=device) # [B, HLA_LEN]
# Forward
with torch.no_grad():
    logits, _ = model(pep_tensor, hla_tensor)  # logits shape: [B, 2]
    probs = F.softmax(logits, dim=1)[:, 1]     # binding probability for class 1
# Threshold to labels (0/1)
labels = (probs >= 0.5).long().tolist()
# Print results
for i, item in enumerate(batch):
    print({
        "peptide": item["peptide"],
        "bind_prob": float(probs[i].item()),
        "label": labels[i],
    })
```
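Besides the logits, the forward pass also returns a feature tensor (`features` in the single-example snippet, `_` above). Its exact shape is not documented here; the sketch below assumes it is a 2-D `[B, D]` tensor and uses it as a paired peptide–HLA embedding, e.g. for cosine similarity between the first two pairs of the batch:

```python
import torch
import torch.nn.functional as F

with torch.no_grad():
    _, feats = model(pep_tensor, hla_tensor)  # assumed shape: [B, D]

# Cosine similarity between the first two peptide–HLA pairs of the batch above.
sim = F.cosine_similarity(feats[0:1], feats[1:2], dim=1)
print({"feature_cosine_similarity": float(sim.item())})
```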
## Model tree

Base model: facebook/esm2_t12_35M_UR50D