SkywalkerLu committed · Commit 6969682 · verified · 1 Parent(s): bd7b816

Update README.md

Files changed (1): README.md +128 -3
README.md CHANGED
@@ -1,3 +1,128 @@
- ---
- license: mit
- ---
+ ---
+ tags:
+ - protein language model
+ datasets:
+ - IEDB
+ base_model:
+ - facebook/esm2_t12_35M_UR50D
+ pipeline_tag: text-classification
+ license: mit
+ language:
+ - en
+ ---
+
+ # TransHLA2.0-PRE
+
+ A minimal Hugging Face-compatible PyTorch model for peptide–HLA binding classification, built on ESM with optional LoRA and cross-attention. There is no custom predict API; inference follows the training path: tokenize the peptide and HLA pseudosequence with the ESM tokenizer, pad or truncate to fixed lengths (default peptide=16, HLA=36), run a forward pass as `logits, features = model(epitope_ids, hla_ids)`, and apply softmax to get the binding probability.
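+
+ For reference, a minimal sketch of that two-input path (the HLA pseudosequence below is a dummy placeholder, not a real allele; substitute the pseudosequence used in your data):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoTokenizer
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = AutoModel.from_pretrained("SkywalkerLu/TransHLA2.0-PRE", trust_remote_code=True).to(device).eval()
+ tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+ pad_id = tok.pad_token_id if tok.pad_token_id is not None else 1
+
+ def encode(seq, length):
+     # Tokenize, then pad or truncate to the fixed length used in training.
+     ids = tok(seq, add_special_tokens=True)["input_ids"]
+     ids = ids[:length] if len(ids) > length else ids + [pad_id] * (length - len(ids))
+     return torch.tensor([ids], dtype=torch.long, device=device)
+
+ epitope_ids = encode("GILGFVFTL", 16)  # peptide, fixed length 16
+ hla_ids = encode("A" * 34, 36)         # dummy placeholder pseudosequence, fixed length 36
+
+ with torch.no_grad():
+     logits, features = model(epitope_ids, hla_ids)
+ print(F.softmax(logits, dim=1)[0, 1].item())  # binding probability
+ ```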
+
+ ## Quick Start
+
+ Requirements:
+ - Python >= 3.8
+ - torch >= 2.0
+ - transformers >= 4.40
+ - peft (only if you use LoRA/PEFT adapters; see the sketch after the install step)
+
+ ## Install
+ ```bash
+ pip install torch transformers peft
+ ```
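+
+ If you have a separately published LoRA adapter, a minimal sketch of attaching it with peft (the adapter id below is hypothetical; this repo may already ship with adapters merged):
+
+ ```python
+ from peft import PeftModel
+ from transformers import AutoModel
+
+ base = AutoModel.from_pretrained("SkywalkerLu/TransHLA2.0-PRE", trust_remote_code=True)
+ # "your-org/your-lora-adapter" is a placeholder, not a real repo.
+ model = PeftModel.from_pretrained(base, "your-org/your-lora-adapter").eval()
+ ```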
+ ## Usage (Transformers)
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_id = "SkywalkerLu/TransHLA2.0-PRE"
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
+ ```
+
+ ## How to use TransHLA2.0-PRE
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoTokenizer
+
+ # Device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model_id = "SkywalkerLu/TransHLA2.0-PRE"
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
+
+ # Load the tokenizer used in training (ESM2 650M; ESM2 checkpoints share the same vocabulary)
+ tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+ peptide = "GILGFVFTL"
+ PEP_LEN = 16
+ PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
+
+ def pad_to_len(ids_list, target_len, pad_id):
+     # Pad with pad_id up to target_len, or truncate if longer.
+     if len(ids_list) < target_len:
+         return ids_list + [pad_id] * (target_len - len(ids_list))
+     return ids_list[:target_len]
+
+ pep_ids = tok(peptide, add_special_tokens=True)["input_ids"]  # adds BOS/EOS tokens
+ pep_ids = pad_to_len(pep_ids, PEP_LEN, PAD_ID)
+ pep_tensor = torch.tensor([pep_ids], dtype=torch.long, device=device)
+
+ with torch.no_grad():
+     logits, features = model(pep_tensor)
+ prob_bind = F.softmax(logits, dim=1)[0, 1].item()
+ pred = int(prob_bind >= 0.5)
+
+ print({"peptide": peptide, "pre_prob": round(prob_bind, 6), "label": pred})
+ ```
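+
+ The forward pass also returns `features`; their exact shape is defined by the model's remote code, so a shape-agnostic way to reduce them to one embedding per sequence (a sketch, assuming a `[B, D]` or `[B, L, D]` tensor) is:
+
+ ```python
+ emb = features
+ if emb.dim() == 3:  # [B, L, D] -> mean-pool the token axis
+     emb = emb.mean(dim=1)
+ print(emb.shape)    # one embedding vector per input sequence
+ ```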
+
+ ## Batch Inference (Python)
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoTokenizer
+
+ # Device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load the PRE model (note: make sure this model actually returns logits)
+ model_id = "SkywalkerLu/TransHLA2.0-PRE"
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
+
+ # Load the same ESM2 tokenizer used in training
+ tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+ # Fixed length (must match training)
+ PEP_LEN = 16
+ PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
+
+ def pad_to_len(ids_list, target_len, pad_id):
+     if len(ids_list) < target_len:
+         return ids_list + [pad_id] * (target_len - len(ids_list))
+     return ids_list[:target_len]
+
+ # Example batch (replace with your own list of peptide sequences)
+ batch = [
+     {"peptide": "GILGFVFTL"},
+     {"peptide": "NLVPMVATV"},
+     {"peptide": "SIINFEKL"},
+     {"peptide": "GLCTLVAML"},
+ ]
+
+ # Tokenize and pad the whole batch
+ pep_ids_batch = []
+ for item in batch:
+     pep = item["peptide"]
+     ids = tok(pep, add_special_tokens=True)["input_ids"]
+     ids = pad_to_len(ids, PEP_LEN, PAD_ID)
+     pep_ids_batch.append(ids)
+
+ # Convert to a tensor of shape [B, PEP_LEN]
+ pep_tensor = torch.tensor(pep_ids_batch, dtype=torch.long, device=device)
+
+ # Forward pass
+ with torch.no_grad():
+     logits, features = model(pep_tensor)    # expected logits shape: [B, 2]
+     probs = F.softmax(logits, dim=1)[:, 1]  # probability of class 1 (bind)
+
+ # Binary labels (0/1)
+ labels = (probs >= 0.5).long().tolist()
+
+ # Print results
+ for i, item in enumerate(batch):
+     print({
+         "peptide": item["peptide"],
+         "pre_prob": float(probs[i].item()),
+         "label": labels[i]
+     })
+ ```
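+
+ To persist batch results, a small stdlib-only sketch that writes the records above to a CSV file (the `pre_predictions.csv` filename is just an example):
+
+ ```python
+ import csv
+
+ with open("pre_predictions.csv", "w", newline="") as f:
+     writer = csv.DictWriter(f, fieldnames=["peptide", "pre_prob", "label"])
+     writer.writeheader()
+     for item, p, y in zip(batch, probs.tolist(), labels):
+         writer.writerow({"peptide": item["peptide"], "pre_prob": p, "label": y})
+ ```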