SkywalkerLu committed
Commit b14311b · verified · 1 Parent(s): c14acfa

Update README.md

Files changed (1)
  1. README.md +135 -1
README.md CHANGED
@@ -7,4 +7,138 @@ base_model:
  tags:
  - Protein
  - Language
- ---
+ ---
+
+
+ # TransHLA2.0-IM
+
+ A minimal Hugging Face-compatible PyTorch model for peptide–HLA binding classification, built on ESM with optional LoRA and cross-attention. There is no custom predict API; inference follows the training path: tokenize the peptide and the HLA pseudosequence with the ESM tokenizer, pad or truncate to fixed lengths (default peptide=16, HLA=36), run a forward pass as `logits, features = model(epitope_ids, hla_ids)`, then apply softmax to the logits to get the binding probability.
+
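+ As a quick sanity check of that contract, the sketch below runs dummy all-padding inputs through the model and prints the logit shape. This is illustrative only: the model id and tokenizer are those used in the examples later in this README, and the `[1, 2]` logit shape follows the two-class setup shown there.
+
+ ```python
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = AutoModel.from_pretrained("SkywalkerLu/TransHLA2.0-IM", trust_remote_code=True).to(device).eval()
+ tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+ # Dummy inputs at the fixed training lengths (peptide=16, HLA=36); padding ids only.
+ epitope_ids = torch.full((1, 16), tok.pad_token_id, dtype=torch.long, device=device)
+ hla_ids = torch.full((1, 36), tok.pad_token_id, dtype=torch.long, device=device)
+
+ with torch.no_grad():
+     logits, features = model(epitope_ids, hla_ids)
+ print(logits.shape)  # expected: torch.Size([1, 2])
+ ```
+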
+ ## Quick Start
+
+ Requirements:
+ - Python >= 3.8
+ - torch >= 2.0
+ - transformers >= 4.40
+ - peft (only if you use LoRA/PEFT adapters)
+
+ ## Install
+ ```bash
+ pip install torch transformers peft
+ ```
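+ If you want to confirm your environment meets the minimums above, a quick optional check:
+
+ ```python
+ import torch, transformers
+ print(torch.__version__, transformers.__version__)
+ ```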
+ ## Usage (Transformers)
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_id = "SkywalkerLu/TransHLA2.0-IM"
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
+ ```
+ ## How to use TransHLA2.0-IM
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoTokenizer
+
+ # Device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load model (replace with your model id if different)
+ model_id = "SkywalkerLu/TransHLA2.0-IM"
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
+
+ # Load the tokenizer used in training (ESM2 650M)
+ tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+ # Example inputs
+ peptide = "GILGFVFTL"  # 9-mer example
+ # Placeholder pseudosequence for the demo; replace with a real one from your mapping/data
+ hla_pseudoseq = "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"
+
+ # Fixed lengths (must match training)
+ PEP_LEN = 16
+ HLA_LEN = 36
+ PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
+
+ def pad_to_len(ids_list, target_len, pad_id):
+     if len(ids_list) < target_len:
+         return ids_list + [pad_id] * (target_len - len(ids_list))
+     return ids_list[:target_len]
+
+ # Tokenize
+ pep_ids = tok(peptide, add_special_tokens=True)["input_ids"]
+ hla_ids = tok(hla_pseudoseq, add_special_tokens=True)["input_ids"]
+
+ # Pad/truncate to the fixed lengths
+ pep_ids = pad_to_len(pep_ids, PEP_LEN, PAD_ID)
+ hla_ids = pad_to_len(hla_ids, HLA_LEN, PAD_ID)
+
+ # Tensors (batch=1)
+ pep_tensor = torch.tensor([pep_ids], dtype=torch.long, device=device)
+ hla_tensor = torch.tensor([hla_ids], dtype=torch.long, device=device)
+
+ # Forward + probability
+ with torch.no_grad():
+     logits, features = model(pep_tensor, hla_tensor)
+     prob_bind = F.softmax(logits, dim=1)[0, 1].item()
+ pred = int(prob_bind >= 0.5)
+
+ print({"peptide": peptide, "bind_prob": round(prob_bind, 6), "label": pred})
+ ```
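+ The demo above hard-codes a placeholder pseudosequence. In practice you would look pseudosequences up by allele from a table (for example, NetMHCpan-style 34-residue pseudosequences, which peptide–HLA models of this kind commonly use). A minimal sketch, assuming a hypothetical two-column CSV `hla_pseudo.csv` with rows like `HLA-A*02:01,<34-aa pseudosequence>`:
+
+ ```python
+ import csv
+
+ def load_pseudoseq_table(path="hla_pseudo.csv"):
+     # Hypothetical file format: "allele,pseudoseq" per row (illustrative only).
+     table = {}
+     with open(path, newline="") as f:
+         for allele, pseudo in csv.reader(f):
+             table[allele] = pseudo
+     return table
+
+ # table = load_pseudoseq_table()
+ # hla_pseudoseq = table["HLA-A*02:01"]
+ ```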
+
+ ## Batch Inference (Python)
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoTokenizer
+
+ # Device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load model and tokenizer
+ model_id = "SkywalkerLu/TransHLA2.0-IM"  # replace with your model id if different
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
+ tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+ # Fixed lengths (must match training)
+ PEP_LEN = 16
+ HLA_LEN = 36
+ PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
+
+ def pad_to_len(ids_list, target_len, pad_id):
+     if len(ids_list) < target_len:
+         return ids_list + [pad_id] * (target_len - len(ids_list))
+     return ids_list[:target_len]
+
+ # Example batch (use real HLA pseudosequences in your data)
+ batch = [
+     {"peptide": "GILGFVFTL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
+     {"peptide": "NLVPMVATV", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
+     {"peptide": "SIINFEKL", "hla_pseudo": "YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY"},
+ ]
+
+ # Tokenize and pad/truncate
+ pep_ids_batch, hla_ids_batch = [], []
+ for item in batch:
+     pep_ids = tok(item["peptide"], add_special_tokens=True)["input_ids"]
+     hla_ids = tok(item["hla_pseudo"], add_special_tokens=True)["input_ids"]
+     pep_ids_batch.append(pad_to_len(pep_ids, PEP_LEN, PAD_ID))
+     hla_ids_batch.append(pad_to_len(hla_ids, HLA_LEN, PAD_ID))
+
+ # To tensors
+ pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device)  # [B, PEP_LEN]
+ hla_tensor = torch.tensor(hla_ids_batch, dtype=torch.long, device=device)  # [B, HLA_LEN]
+
+ # Forward
+ with torch.no_grad():
+     logits, _ = model(pep_tensor, hla_tensor)  # logits shape: [B, 2]
+     probs = F.softmax(logits, dim=1)[:, 1]  # binding probability for class 1
+
+ # Threshold to labels (0/1)
+ labels = (probs >= 0.5).long().tolist()
+
+ # Print results
+ for i, item in enumerate(batch):
+     print({
+         "peptide": item["peptide"],
+         "bind_prob": float(probs[i].item()),
+         "label": labels[i],
+     })
+ ```
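+ For larger inputs, the same path scales by processing fixed-size chunks so GPU memory stays bounded. A sketch reusing `tok`, `model`, `device`, `pad_to_len`, `PEP_LEN`, `HLA_LEN`, and `PAD_ID` from the batch example above (`encode_pair` and `predict_chunked` are illustrative helpers, not part of the model's API):
+
+ ```python
+ def encode_pair(peptide, hla_pseudo):
+     # Tokenize + pad/truncate one peptide–HLA pair (same steps as above).
+     pep = pad_to_len(tok(peptide, add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID)
+     hla = pad_to_len(tok(hla_pseudo, add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID)
+     return pep, hla
+
+ def predict_chunked(pairs, chunk_size=64):
+     # pairs: list of (peptide, hla_pseudo) tuples; returns class-1 probabilities.
+     probs_all = []
+     for start in range(0, len(pairs), chunk_size):
+         chunk = pairs[start:start + chunk_size]
+         encoded = [encode_pair(p, h) for p, h in chunk]
+         pep = torch.tensor([e[0] for e in encoded], dtype=torch.long, device=device)
+         hla = torch.tensor([e[1] for e in encoded], dtype=torch.long, device=device)
+         with torch.no_grad():
+             logits, _ = model(pep, hla)
+         probs_all.extend(F.softmax(logits, dim=1)[:, 1].tolist())
+     return probs_all
+
+ # probs = predict_chunked([(x["peptide"], x["hla_pseudo"]) for x in batch])
+ ```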