SkywalkerLu committed on
Commit
6a0b0b6
·
verified ·
1 Parent(s): b429e09

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +59 -0
README.md CHANGED
@@ -66,3 +66,62 @@ with torch.no_grad():
66
  print({"peptide": peptide, "bind_prob": round(prob_bind, 6), "label": pred})
67
  ```
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  print({"peptide": peptide, "bind_prob": round(prob_bind, 6), "label": pred})
67
  ```
68
 
69
+
70
+ ## Batch Inference (Python)
71
+
72
+ ```python
73
+ import torch
74
+ import torch.nn.functional as F
75
+ from transformers import AutoModel, AutoTokenizer
76
+
77
+ # Device
78
+ device = "cuda" if torch.cuda.is_available() else "cpu"
79
+
80
+ # Load model and tokenizer
81
+ model_id = "SkywalkerLu/TransHLA2.0" # replace with your model id if different
82
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
83
+ tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
84
+
85
+ # Fixed token-sequence lengths, special tokens included (must match the lengths used in training)
86
+ PEP_LEN = 16
87
+ HLA_LEN = 36
88
+ PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
89
+
90
+ def pad_to_len(ids_list, target_len, pad_id):
91
+ return ids_list + [pad_id] * (target_len - len(ids_list)) if len(ids_list) < target_len else ids_list[:target_len]
92
+
93
+ # Example batch (the YOUR_HLA_PSEUDOSEQ_* values are placeholders; replace them with real HLA pseudosequences from your data)
94
+ batch = [
95
+ {"peptide": "GILGFVFTL", "hla_pseudo": "YOUR_HLA_PSEUDOSEQ_1"},
96
+ {"peptide": "NLVPMVATV", "hla_pseudo": "YOUR_HLA_PSEUDOSEQ_2"},
97
+ {"peptide": "SIINFEKL", "hla_pseudo": "YOUR_HLA_PSEUDOSEQ_3"},
98
+ ]
99
+
100
+ # Tokenize and pad/truncate
101
+ pep_ids_batch, hla_ids_batch = [], []
102
+ for item in batch:
103
+ pep_ids = tok(item["peptide"], add_special_tokens=True)["input_ids"]
104
+ hla_ids = tok(item["hla_pseudo"], add_special_tokens=True)["input_ids"]
105
+ pep_ids_batch.append(pad_to_len(pep_ids, PEP_LEN, PAD_ID))
106
+ hla_ids_batch.append(pad_to_len(hla_ids, HLA_LEN, PAD_ID))
107
+
108
+ # To tensors
109
+ pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device) # [B, PEP_LEN]
110
+ hla_tensor = torch.tensor(hla_ids_batch, dtype=torch.long, device=device) # [B, HLA_LEN]
111
+
112
+ # Forward
113
+ with torch.no_grad():
114
+ logits, _ = model(pep_tensor, hla_tensor) # logits shape: [B, 2]
115
+ probs = F.softmax(logits, dim=1)[:, 1] # binding probability for class-1
116
+
117
+ # Threshold to labels (0/1)
118
+ labels = (probs >= 0.5).long().tolist()
119
+
120
+ # Print results
121
+ for i, item in enumerate(batch):
122
+ print({
123
+ "peptide": item["peptide"],
124
+ "bind_prob": float(probs[i].item()),
125
+ "label": labels[i]
126
+ })
127
+ ```