SkywalkerLu committed on
Commit f0e6d01 · verified · 1 Parent(s): d9a3a2e

Update README.md

Files changed (1)
  1. README.md +67 -14
README.md CHANGED
@@ -1,14 +1,67 @@
- # TransHLA2.0
-
- **TransHLA2.0** is a protein sequence model based on a fusion of ESM2 + LoRA + CNN-Transformer.
- It supports `trust_remote_code` loading on the Hugging Face Hub and custom fine-tuning.
-
- ## Quick Load
-
- ```python
- from transformers import AutoModel
-
- model = AutoModel.from_pretrained(
-     "SkywalkerLu/TransHLA2.0",
-     trust_remote_code=True
- )
+ # TransHLA2.0
+
+ A minimal Hugging Face-compatible PyTorch model for peptide–HLA binding classification, built on ESM with optional LoRA and cross-attention. There is no custom predict API; inference follows the training path: tokenize the peptide and the HLA pseudosequence with the ESM tokenizer, pad or truncate to fixed lengths (default peptide=16, HLA=36), run a forward pass as `logits, features = model(epitope_ids, hla_ids)`, then apply softmax to obtain the binding probability.
+
+ ## Quick Start
+
+ Requirements:
+ - Python >= 3.8
+ - torch >= 2.0
+ - transformers >= 4.40
+ - peft (only if you use LoRA/PEFT adapters; see the adapter sketch below)
+
+ Install:
+ ```bash
+ pip install torch transformers peft
+ ```
+
+ Minimal single-pair inference:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoTokenizer
+
+ # Device
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Load model (replace with your model id if different)
+ model_id = "SkywalkerLu/TransHLA2.0"
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device).eval()
+
+ # Load the tokenizer used in training (ESM2 650M)
+ tok = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+
+ # Example inputs
+ peptide = "GILGFVFTL"  # 9-mer example
+ # Fake placeholder pseudosequence for demo; replace with a real one from your mapping/data
+ hla_pseudoseq = (
+     "GSHSMRYFYTAVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASPRMEPRAPWIEQEGPEYWERETRNVK"
+     "AQSQTDRVDLRTLLRYNQSEAGSHTVQRMYGCDVGSDWRFLRGYHQYAYDGKDYIALNEDLRSWTAAD"
+     "MAAQTTKHKWEQAGAAER"
+ )
+
+ # Fixed lengths (must match training)
+ PEP_LEN = 16
+ HLA_LEN = 36
+ PAD_ID = tok.pad_token_id if tok.pad_token_id is not None else 1
+
+ def pad_to_len(ids_list, target_len, pad_id):
+     """Right-pad with pad_id, or truncate, so the list is exactly target_len long."""
+     if len(ids_list) < target_len:
+         return ids_list + [pad_id] * (target_len - len(ids_list))
+     return ids_list[:target_len]
+
+ # Tokenize (the ESM tokenizer adds CLS/EOS special tokens around each sequence)
+ pep_ids = tok(peptide, add_special_tokens=True)["input_ids"]
+ hla_ids = tok(hla_pseudoseq, add_special_tokens=True)["input_ids"]
+
+ # Pad/truncate to the fixed training lengths
+ pep_ids = pad_to_len(pep_ids, PEP_LEN, PAD_ID)
+ hla_ids = pad_to_len(hla_ids, HLA_LEN, PAD_ID)
+
+ # Tensors (batch=1)
+ pep_tensor = torch.tensor([pep_ids], dtype=torch.long, device=device)
+ hla_tensor = torch.tensor([hla_ids], dtype=torch.long, device=device)
+
+ # Forward + probability
+ with torch.no_grad():
+     logits, features = model(pep_tensor, hla_tensor)
+ prob_bind = F.softmax(logits, dim=1)[0, 1].item()  # class 1 = binder
+ pred = int(prob_bind >= 0.5)
+
+ print({"peptide": peptide, "bind_prob": round(prob_bind, 6), "label": pred})
+ ```
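+
+ ## Batched Inference (sketch)
+
+ The forward pass above uses batch=1, but the same path extends to batches by stacking padded ID lists. Below is a minimal sketch, assuming the model's forward accepts batched `(epitope_ids, hla_ids)` tensors exactly as in the single-pair case; the helper `predict_batch` is illustrative and not part of this repo.
+
+ ```python
+ # Hypothetical helper: score several peptides against one HLA pseudosequence.
+ # Reuses model, tok, pad_to_len, PEP_LEN, HLA_LEN, PAD_ID, device from above.
+ def predict_batch(peptides, hla_pseudoseq):
+     pep_batch = [
+         pad_to_len(tok(p, add_special_tokens=True)["input_ids"], PEP_LEN, PAD_ID)
+         for p in peptides
+     ]
+     hla_row = pad_to_len(tok(hla_pseudoseq, add_special_tokens=True)["input_ids"], HLA_LEN, PAD_ID)
+     pep_tensor = torch.tensor(pep_batch, dtype=torch.long, device=device)
+     hla_tensor = torch.tensor([hla_row] * len(peptides), dtype=torch.long, device=device)
+     with torch.no_grad():
+         logits, _ = model(pep_tensor, hla_tensor)
+     return F.softmax(logits, dim=1)[:, 1].tolist()  # per-peptide binding probability
+
+ probs = predict_batch(["GILGFVFTL", "NLVPMVATV"], hla_pseudoseq)
+ print({p: round(pr, 6) for p, pr in zip(["GILGFVFTL", "NLVPMVATV"], probs)})
+ ```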
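+
+ ## Loading a LoRA/PEFT Adapter (optional)
+
+ If you fine-tuned this model with LoRA via `peft`, you can attach the saved adapter to the base model before running the same inference path. A minimal sketch, assuming you have a trained adapter directory; the path `./my_lora_adapter` is a placeholder, not an artifact shipped with this repo.
+
+ ```python
+ from peft import PeftModel
+
+ # Load the base model as before, then wrap it with the saved LoRA adapter.
+ # "./my_lora_adapter" is a hypothetical local path to your trained adapter.
+ base = AutoModel.from_pretrained("SkywalkerLu/TransHLA2.0", trust_remote_code=True)
+ model = PeftModel.from_pretrained(base, "./my_lora_adapter").to(device).eval()
+ # Inference then proceeds exactly as in the Quick Start example.
+ ```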