---
license: mit
language:
- en
base_model:
- facebook/esm2_t12_35M_UR50D
tags:
- Protein
- Language
---

# TransHLA2.0-IM

A minimal Hugging Face-compatible PyTorch model for peptide–HLA binding classification using ESM with optional LoRA and cross-attention.

There is no custom predict API. Inference follows the training path:

1. Tokenize the peptide and the HLA pseudosequence with the ESM tokenizer.
2. Pad or truncate them to fixed lengths (default peptide=16, HLA=36).
3. Run a forward pass as `logits, features = model(epitope_ids, hla_ids)`.
4. Apply softmax to obtain the binding probability.

## Quick Start

Requirements:
- Python >= 3.8
- torch >= 2.0
- transformers >= 4.40
- peft (only if you use LoRA/PEFT adapters)

## Install

```bash
pip install torch transformers peft
```

## Usage (Transformers)

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

model_id = "SkywalkerLu/TransHLA2.0-IM"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
```

## How to use TransHLA2.0-IM

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model (replace with your model id if different)
model_id = "SkywalkerLu/TransHLA2.0-IM"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()

# Load tokenizer used in training (ESM2 650M)
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Example inputs
peptide = "GILGFVFTL"  # 9-mer example
# Placeholder pseudosequence for demo purposes; replace with a real one from your mapping/data
hla_pseudoseq = (
    "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"
)

# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1

def pad_to_len(ids_list, target_len, pad_id):
    return ids_list + [pad_id] * (target_len - len(ids_list)) if len(ids_list) < target_len else ids_list[:target_len]

# Tokenize
pep_ids = tok(peptide, add_special_tokens=True)["input_ids"]
hla_ids = tok(hla_pseudoseq, add_special_tokens=True)["input_ids"]

# Pad/truncate
pep_ids = pad_to_len(pep_ids, PEP_LEN, PAD_ID)
hla_ids = pad_to_len(hla_ids, HLA_LEN, PAD_ID)

# Tensors (batch=1)
pep_tensor = torch.tensor([pep_ids], dtype=torch.long, device=device)
hla_tensor = torch.tensor([hla_ids], dtype=torch.long, device=device)

# Forward + probability
with torch.no_grad():
    logits, features = model(pep_tensor, hla_tensor)
    prob_bind = F.softmax(logits, dim=1)[0, 1].item()
    pred = int(prob_bind >= 0.5)

print({"peptide": peptide, "bind_prob": round(prob_bind, 6), "label": pred})
```

## Batch Inference (Python)

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer
model_id = "SkywalkerLu/TransHLA2.0-IM"  # replace with your model id if different
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Fixed lengths (must match training)
PEP_LEN = 16
HLA_LEN = 36
PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1

def pad_to_len(ids_list, target_len, pad_id):
    return ids_list + [pad_id] * (target_len - len(ids_list)) if len(ids_list) < target_len else ids_list[:target_len]

# Example batch (use real HLA pseudosequences in your data)
batch = [
    {"peptide": "GILGFVFTL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "NLVPMVATV", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
    {"peptide": "SIINFEKL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
]

# Tokenize and pad/truncate
pep_ids_batch, hla_ids_batch = [], []
for item in batch:
    pep_ids = tok(item["peptide"], add_special_tokens=True)["input_ids"]
    hla_ids = tok(item["hla_pseudo"], add_special_tokens=True)["input_ids"]
    pep_ids_batch.append(pad_to_len(pep_ids, PEP_LEN, PAD_ID))
    hla_ids_batch.append(pad_to_len(hla_ids, HLA_LEN, PAD_ID))

# To tensors
pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device)  # [B, PEP_LEN]
hla_tensor = torch.tensor(hla_ids_batch, dtype=torch.long, device=device)  # [B, HLA_LEN]

# Forward
with torch.no_grad():
    logits, _ = model(pep_tensor, hla_tensor)  # logits shape: [B, 2]
    probs = F.softmax(logits, dim=1)[:, 1]  # binding probability for class-1

# Threshold to labels (0/1)
labels = (probs >= 0.5).long().tolist()

# Print results
for i, item in enumerate(batch):
    print({
        "peptide": item["peptide"],
        "bind_prob": float(probs[i].item()),
        "label": labels[i]
    })
```