shubhamc-iiitd commited on
Commit
ccf1103
·
verified ·
1 Parent(s): 364bd69

Initial model upload

Browse files
Files changed (6) hide show
  1. README.md +56 -0
  2. config.yaml +19 -0
  3. inference.py +40 -0
  4. model.pt +3 -0
  5. network.py +59 -0
  6. tokenizer_mapping.json +27 -0
README.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: gpl-3.0
3
+ tags:
4
+ - protein
5
+ - peptide
6
+ - deep-learning
7
+ - pytorch
8
+ - bioinformatics
9
+ library_name: pytorch
10
+ ---
11
+
12
+ # Revised Peptide LGBM Model
13
+
14
+ This repository contains a PyTorch deep learning model trained to predict peptide properties from amino acid sequences.
15
+
16
+ ## Model Description
17
+
18
+ The model uses tokenized amino acid sequences as input and predicts a probability score indicating the likelihood of the peptide belonging to the positive class.
19
+
20
+ The architecture is defined in `model/network.py` and initialized using a YAML configuration file.
21
+
22
+ ## Input Representation
23
+
24
+ Sequences are tokenized using the following mapping:
25
+
26
+ | Token | Description |
27
+ |------|-------------|
28
+ | PAD | Padding |
29
+ | UNK | Unknown |
30
+ | CLS | Start token |
31
+ | SEP | Separator |
32
+ | MASK | Mask token |
33
+ | L,A,G,V,E,S,I,K,R,D,T,P,N,Q,F,Y,M,H,C,W | Amino acids |
34
+
35
+ Sequences are padded to the maximum length within a batch.
36
+
37
+ ## Files
38
+
39
+ | File | Description |
40
+ |----|----|
41
+ | model.pt | Trained model checkpoint |
42
+ | config.yaml | Model configuration |
43
+ | tokenizer_mapping.json | Amino acid token mapping |
44
+ | inference.py | Example inference script |
45
+
46
+ ## Usage
47
+
48
+ Example inference:
49
+
50
+ ```python
51
+ from inference import predict
52
+
53
+ sequence = "LAGVEST"
54
+ probability = predict(sequence)
55
+
56
+ print(probability)
+ ```
config.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ epochs: 50
2
+ batch_size: 32
3
+ vocab_size: 25
4
+ task: revised_peptide_LGBM_5 # hemo, sol, nf
5
+ debug: false
6
+
7
+ network:
8
+ hidden_size: 480
9
+ hidden_layers: 12
10
+ attn_heads: 12
11
+ dropout: 0.15
12
+
13
+ optim:
14
+ lr: 1.0e-5
15
+
16
+ sch:
17
+ name: lronplateau # onecycle, lronplateau
18
+ factor: 0.1
19
+ patience: 4
inference.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import yaml
3
+ import json
4
+ from model.network import create_model
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ repo_id = "YOUR_USERNAME/revised_peptide_LGBM_5"
8
+
9
+ # download files
10
+ model_path = hf_hub_download(repo_id, "model.pt")
11
+ config_path = hf_hub_download(repo_id, "config.yaml")
12
+ mapping_path = hf_hub_download(repo_id, "tokenizer_mapping.json")
13
+
14
+ config = yaml.safe_load(open(config_path))
15
+
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ config["device"] = device
18
+
19
+ model = create_model(config)
20
+ model.load_state_dict(torch.load(model_path)["model_state_dict"], strict=False)
21
+
22
+ model.to(device)
23
+ model.eval()
24
+
25
+ mapping = json.load(open(mapping_path))
26
+
27
+ def predict(sequence):
28
+
29
+ tokens = [mapping.get(c, mapping["[UNK]"]) for c in sequence]
30
+ input_ids = torch.tensor([tokens]).to(device)
31
+
32
+ attention_mask = (input_ids != 0).float()
33
+
34
+ with torch.no_grad():
35
+ prob = model(input_ids, attention_mask)[0].item()
36
+
37
+ return prob
38
+
39
+
40
+ print(predict("LAGVEST"))
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3fb16c1864d4c944a8133acd9fb358f699126ed4252174e79e0b9d653a2fc95
3
+ size 564544429
network.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BertModel, BertConfig, logging
3
+
4
+ logging.set_verbosity_error()
5
+
6
+
7
class PeptideBERT(torch.nn.Module):
    """BERT encoder initialized from ProtBERT-BFD with a sigmoid head.

    Maps tokenized peptide sequences to a single probability per sequence,
    computed from the encoder's pooled output.
    """

    def __init__(self, bert_config):
        super().__init__()
        # Load the pretrained encoder; ignore_mismatched_sizes allows the
        # (possibly smaller) dimensions in `bert_config` to override the
        # pretrained checkpoint's shapes.
        self.protbert = BertModel.from_pretrained(
            'Rostlab/prot_bert_bfd',
            config=bert_config,
            ignore_mismatched_sizes=True,
        )
        # Binary classification head: hidden_size -> 1 -> sigmoid probability.
        self.head = torch.nn.Sequential(
            torch.nn.Linear(bert_config.hidden_size, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, inputs, attention_mask):
        """Return a (batch, 1) tensor of probabilities."""
        encoded = self.protbert(inputs, attention_mask=attention_mask)
        pooled = encoded.pooler_output
        return self.head(pooled)
25
+
26
+
27
def create_model(config):
    """Build a PeptideBERT model from `config` and move it to config['device'].

    Expects top-level keys 'vocab_size' and 'device', plus a 'network'
    section with 'hidden_size', 'hidden_layers', 'attn_heads' and 'dropout'.
    """
    net = config['network']
    bert_config = BertConfig(
        vocab_size=config['vocab_size'],
        hidden_size=net['hidden_size'],
        num_hidden_layers=net['hidden_layers'],
        num_attention_heads=net['attn_heads'],
        hidden_dropout_prob=net['dropout'],
    )
    return PeptideBERT(bert_config).to(config['device'])
38
+
39
+
40
def cri_opt_sch(config, model):
    """Build the training criterion, optimizer and LR scheduler.

    Args:
        config: dict with config['optim']['lr'], config['epochs'] and a
            config['sch'] section whose 'name' is 'onecycle' or 'lronplateau'.
        model: torch.nn.Module whose parameters are optimized.

    Returns:
        (criterion, optimizer, scheduler) tuple.

    Raises:
        ValueError: if config['sch']['name'] is not a known scheduler name.
    """
    # BCELoss matches the sigmoid output of PeptideBERT's head.
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['optim']['lr'])

    sch_name = config['sch']['name']
    if sch_name == 'onecycle':
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config['optim']['lr'],
            epochs=config['epochs'],
            steps_per_epoch=config['sch']['steps']
        )
    elif sch_name == 'lronplateau':
        # mode='max': the monitored validation metric is maximized.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=config['sch']['factor'],
            patience=config['sch']['patience']
        )
    else:
        # Previously an unrecognized name fell through and raised a
        # confusing NameError on `scheduler`; fail fast instead.
        raise ValueError(f"Unknown scheduler name: {sch_name!r}")

    return criterion, optimizer, scheduler
tokenizer_mapping.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[PAD]": 0,
3
+ "[UNK]": 1,
4
+ "[CLS]": 2,
5
+ "[SEP]": 3,
6
+ "[MASK]": 4,
7
+ "L": 5,
8
+ "A": 6,
9
+ "G": 7,
10
+ "V": 8,
11
+ "E": 9,
12
+ "S": 10,
13
+ "I": 11,
14
+ "K": 12,
15
+ "R": 13,
16
+ "D": 14,
17
+ "T": 15,
18
+ "P": 16,
19
+ "N": 17,
20
+ "Q": 18,
21
+ "F": 19,
22
+ "Y": 20,
23
+ "M": 21,
24
+ "H": 22,
25
+ "C": 23,
26
+ "W": 24
27
+ }