|
|
--- |
|
|
license: mit |
|
|
language: |
|
|
- en |
|
|
base_model: |
|
|
- facebook/esm2_t12_35M_UR50D |
|
|
tags: |
|
|
- Protein |
|
|
- Language
|
|
--- |
|
|
|
|
|
|
|
|
# TransHLA2.0-IM |
|
|
|
|
|
A minimal Hugging Face-compatible PyTorch model for peptide–HLA binding classification, built on ESM with optional LoRA and cross-attention. There is no custom predict API; inference follows the training path: tokenize the peptide and the HLA pseudosequence with the ESM tokenizer, pad or truncate the token ids to fixed lengths (defaults: peptide = 16, HLA = 36, special tokens included), run a forward pass as `logits, features = model(epitope_ids, hla_ids)`, then apply a softmax over the logits and take class index 1 as the binding probability.
|
|
|
|
|
## Quick Start |
|
|
|
|
|
Requirements: |
|
|
- Python >= 3.8 |
|
|
- torch >= 2.0 |
|
|
- transformers >= 4.40 |
|
|
- peft (only if you use LoRA/PEFT adapters) |
|
|
|
|
|
## Install
|
|
```bash |
|
|
pip install torch transformers peft |
|
|
``` |
|
|
## Usage (Transformers) |
|
|
```python |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
# Device: use a GPU when one is available, otherwise fall back to CPU.
# (The original snippet used `device` without defining it, which raised NameError.)
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "SkywalkerLu/TransHLA2.0-IM"

# trust_remote_code=True lets transformers run the model code shipped in the Hub repo.
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()

# Tokenizer used by the other examples below (ESM2 650M); it produces the input ids.
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
|
``` |
|
|
## How to use TransHLA2.0-IM |
|
|
```python |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
# Run on GPU when one is available, CPU otherwise
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model checkpoint on the Hugging Face Hub (swap in your own id if it differs)
model_id = "SkywalkerLu/TransHLA2.0-IM"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
model = model.to(device).eval()

# Tokenizer used during training (ESM2 650M)
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Example inputs
peptide = "GILGFVFTL"  # 9-mer example
# Fake placeholder pseudosequence for demo; replace with a real one from your mapping/data
hla_pseudoseq = "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"

# Fixed token lengths (special tokens included); these must match training
PEP_LEN = 16
HLA_LEN = 36

# Use the tokenizer's pad id, falling back to 1 if it does not define one
if tok.pad_token_id is None:
    PAD_ID = 1
else:
    PAD_ID = tok.pad_token_id
|
|
|
|
|
def pad_to_len(ids_list, target_len, pad_id):
    """Return ``ids_list`` right-padded with ``pad_id`` or truncated to ``target_len`` items.

    A list that is already long enough is cut to its first ``target_len`` ids.
    A shorter list gets ``pad_id`` appended until it reaches ``target_len``.
    """
    shortfall = target_len - len(ids_list)
    if shortfall > 0:
        return ids_list + [pad_id] * shortfall
    return ids_list[:target_len]
|
|
|
|
|
# Tokenize with special tokens, then pad/truncate to the fixed training lengths
epitope_ids = pad_to_len(tok(peptide, add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID)
allele_ids = pad_to_len(tok(hla_pseudoseq, add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID)

# Wrap each id list as a batch of one: shape [1, length]
epitope_tensor = torch.tensor([epitope_ids], dtype=torch.long, device=device)
allele_tensor = torch.tensor([allele_ids], dtype=torch.long, device=device)

# Forward pass without gradients; class index 1 of the softmax is the binding probability
with torch.no_grad():
    logits, _features = model(epitope_tensor, allele_tensor)
    class_probs = F.softmax(logits, dim=1)
    bind_probability = class_probs[0, 1].item()

# Threshold at 0.5 to get a hard 0/1 label
label = int(bind_probability >= 0.5)

print({"peptide": peptide, "bind_prob": round(bind_probability, 6), "label": label})
|
|
``` |
|
|
|
|
|
|
|
|
## Batch Inference (Python) |
|
|
|
|
|
```python |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers import AutoModel, AutoTokenizer |
|
|
|
|
|
# Run on GPU when one is available, CPU otherwise
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model (replace with your model id if different) and the training tokenizer (ESM2 650M)
model_id = "SkywalkerLu/TransHLA2.0-IM"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
model = model.to(device).eval()
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Fixed token lengths (special tokens included); these must match training
PEP_LEN = 16
HLA_LEN = 36

# Use the tokenizer's pad id, falling back to 1 if it does not define one
if tok.pad_token_id is None:
    PAD_ID = 1
else:
    PAD_ID = tok.pad_token_id
|
|
|
|
|
def pad_to_len(ids_list, target_len, pad_id):
    """Return ``ids_list`` right-padded with ``pad_id`` or truncated to ``target_len`` items.

    A list that is already long enough is cut to its first ``target_len`` ids.
    A shorter list gets ``pad_id`` appended until it reaches ``target_len``.
    """
    shortfall = target_len - len(ids_list)
    if shortfall > 0:
        return ids_list + [pad_id] * shortfall
    return ids_list[:target_len]
|
|
|
|
|
# Example batch (use real HLA pseudosequences in your data)
batch = [
    {"peptide": "GILGFVFTL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "NLVPMVATV", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "SIINFEKL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
]

# Tokenize every entry with special tokens and pad/truncate to the fixed lengths
pep_ids_batch = [
    pad_to_len(tok(entry["peptide"], add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID)
    for entry in batch
]
hla_ids_batch = [
    pad_to_len(tok(entry["hla_pseudo"], add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID)
    for entry in batch
]

# Stack into [B, PEP_LEN] and [B, HLA_LEN] id tensors
pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device)
hla_tensor = torch.tensor(hla_ids_batch, dtype=torch.long, device=device)

# Forward pass without gradients; logits are [B, 2], column 1 is the binding class
with torch.no_grad():
    logits, _ = model(pep_tensor, hla_tensor)
    probs = F.softmax(logits, dim=1)[:, 1]

# Threshold at 0.5 to get hard 0/1 labels
labels = (probs >= 0.5).long().tolist()

# Report one result per input pair
for entry, bind_prob, label in zip(batch, probs.tolist(), labels):
    print({
        "peptide": entry["peptide"],
        "bind_prob": float(bind_prob),
        "label": label
    })
|
|
``` |