---
license: mit
language:
- en
base_model:
- facebook/esm2_t12_35M_UR50D
tags:
- Protein
- Language
---
# TransHLA2.0-IM
A minimal Hugging Face-compatible PyTorch model for peptide–HLA binding classification, built on ESM with optional LoRA and cross-attention. There is no custom predict API; inference follows the training path: tokenize the peptide and the HLA pseudosequence with the ESM tokenizer, pad or truncate each to a fixed length (default peptide=16, HLA=36), run a forward pass as `logits, features = model(epitope_ids, hla_ids)`, and apply softmax to obtain the binding probability.
## Quick Start
Requirements:
- Python >= 3.8
- torch >= 2.0
- transformers >= 4.40
- peft (only if you use LoRA/PEFT adapters)
## Install
```bash
pip install torch transformers peft
```
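If the environment is unclear, a quick version check against the floors above (a minimal sketch, nothing model-specific):
```python
import torch
import transformers

print("torch", torch.__version__)                # expect >= 2.0
print("transformers", transformers.__version__)  # expect >= 4.40
```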
## Usage (Transformers)
```python
import torch
from transformers import AutoModel, AutoTokenizer

# Pick a device before moving the model
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model (custom code) and the ESM tokenizer used in training
model_id = "SkywalkerLu/TransHLA2.0-IM"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
```
## How to use TransHLA2.0-IM
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model (replace with your model id if different)
model_id = "SkywalkerLu/TransHLA2.0-IM"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
# Load tokenizer used in training (ESM2 650M)
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
# Example inputs
peptide = "GILGFVFTL" # 9-mer example
# Placeholder pseudosequence for demo; replace with a real one from your mapping/data
hla_pseudoseq = "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"
# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1  # ESM2 uses 1 for <pad>
def pad_to_len(ids_list, target_len, pad_id):
    # Right-pad with pad_id up to target_len; truncate if already longer
    return ids_list + [pad_id] * (target_len - len(ids_list)) if len(ids_list) < target_len else ids_list[:target_len]
# Tokenize (the ESM tokenizer adds <cls> and <eos> around the sequence)
pep_ids = tok(peptide, add_special_tokens=True)["input_ids"]
hla_ids = tok(hla_pseudoseq, add_special_tokens=True)["input_ids"]
# Pad/truncate
pep_ids = pad_to_len(pep_ids, PEP_LEN, PAD_ID)
hla_ids = pad_to_len(hla_ids, HLA_LEN, PAD_ID)
# Tensors (batch=1)
pep_tensor = torch.tensor([pep_ids], dtype=torch.long, device=device)
hla_tensor = torch.tensor([hla_ids], dtype=torch.long, device=device)
# Forward + probability
with torch.no_grad():
    logits, features = model(pep_tensor, hla_tensor)  # logits: [1, 2]
prob_bind = F.softmax(logits, dim=1)[0, 1].item()  # probability of the binder class
pred = int(prob_bind >= 0.5)
print({"peptide": peptide, "bind_prob": round(prob_bind, 6), "label": pred})
```
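## Optional: LoRA/PEFT Adapters
If you trained LoRA adapters with PEFT, they can be attached on top of the base model before running the same forward pass. A minimal sketch, assuming a PEFT-format adapter; `your-org/transhla2-im-lora` is a hypothetical placeholder id, not a published adapter:
```python
# Minimal PEFT sketch: wrap the already-loaded base model with a LoRA adapter.
# The adapter id is a hypothetical placeholder; point it at your own adapter
# directory or Hub repo.
from peft import PeftModel

adapter_id = "your-org/transhla2-im-lora"  # hypothetical placeholder
model = PeftModel.from_pretrained(model, adapter_id).to(device).eval()

# Inference is unchanged: logits, features = model(pep_tensor, hla_tensor)
```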
## Batch Inference (Python)
```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load model and tokenizer
model_id = "SkywalkerLu/TransHLA2.0-IM" # replace with your model id if different
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
def pad_to_len(ids_list, target_len, pad_id):
    # Right-pad with pad_id up to target_len; truncate if already longer
    return ids_list + [pad_id] * (target_len - len(ids_list)) if len(ids_list) < target_len else ids_list[:target_len]
# Example batch (use real HLA pseudosequences in your data)
batch = [
    {"peptide": "GILGFVFTL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "NLVPMVATV", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "SIINFEKL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
]
# Tokenize and pad/truncate
pep_ids_batch, hla_ids_batch = [], []
for item in batch:
    pep_ids = tok(item["peptide"], add_special_tokens=True)["input_ids"]
    hla_ids = tok(item["hla_pseudo"], add_special_tokens=True)["input_ids"]
    pep_ids_batch.append(pad_to_len(pep_ids, PEP_LEN, PAD_ID))
    hla_ids_batch.append(pad_to_len(hla_ids, HLA_LEN, PAD_ID))
# To tensors
pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device) # [B, PEP_LEN]
hla_tensor = torch.tensor(hla_ids_batch, dtype=torch.long, device=device) # [B, HLA_LEN]
# Forward
with torch.no_grad():
    logits, _ = model(pep_tensor, hla_tensor)  # logits shape: [B, 2]
probs = F.softmax(logits, dim=1)[:, 1]  # binding probability for class 1
# Threshold to labels (0/1)
labels = (probs >= 0.5).long().tolist()
# Print results
for i, item in enumerate(batch):
    print({
        "peptide": item["peptide"],
        "bind_prob": float(probs[i].item()),
        "label": labels[i],
    })
```
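For longer peptide lists, the same pipeline can be wrapped in a chunked loop so GPU memory stays bounded. A minimal sketch reusing `tok`, `model`, `pad_to_len`, and the constants from the batch example above; `score_pairs` and `chunk_size` are illustrative names, not part of the released API:
```python
# Minimal sketch: score (peptide, hla_pseudo) pairs in fixed-size chunks.
# Reuses tok, model, pad_to_len, PEP_LEN, HLA_LEN, PAD_ID, device from above.
def score_pairs(pairs, chunk_size=64):
    probs_all = []
    for start in range(0, len(pairs), chunk_size):
        chunk = pairs[start:start + chunk_size]
        pep = [pad_to_len(tok(p, add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID) for p, _ in chunk]
        hla = [pad_to_len(tok(h, add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID) for _, h in chunk]
        with torch.no_grad():
            logits, _ = model(
                torch.tensor(pep, dtype=torch.long, device=device),
                torch.tensor(hla, dtype=torch.long, device=device),
            )
        probs_all.extend(F.softmax(logits, dim=1)[:, 1].tolist())
    return probs_all

# Example: one known 9-mer against the demo pseudosequence
pairs = [("GILGFVFTL", "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY")]
print(score_pairs(pairs))
```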