---
license: cc-by-nc-4.0
library_name: transformers
datasets:
- BIOGRID
- Negatome
pipeline_tag: text-classification
tags:
- protein language model
- biology
widget:
- text: >-
    M S H S V K I Y D T C I G C T Q C V R A C P T D V L E M I P W G G C K A K Q
    I A S A P R T E D C V G C K R C E S A C P T D F L S V R V Y L W H E T T R S
    M G L A Y [SEP] M I N L P S L F V P L V G L L F P A V A M A S L F L H V E K
    R L L F S T K K I N
  example_title: Non-interacting proteins
- text: >-
    M S I N I C R D N H D P F Y R Y K M P P I Q A K V E G R G N G I K T A V L N
    V A D I S H A L N R P A P Y I V K Y F G F E L G A Q T S I S V D K D R Y L V
    N G V H E P A K L Q D V L D G F I N K F V L C G S C K N P E T E I I I T K D
    N D L V R D C K A C G K R T P M D L R H K L S S F I L K N P P D S V S G S K
    K K K K A A T A S A N V R G G G L S I S D I A Q G K S Q N A P S D G T G S S
    T P Q H H D E D E D E L S R Q I K A A A S T L E D I E V K D D E W A V D M S
    E E A I R A R A K E L E V N S E L T Q L D E Y G E W I L E Q A G E D K E N L
    P S D V E L Y K K A A E L D V L N D P K I G C V L A Q C L F D E D I V N E I
    A E H N A F F T K I L V T P E Y E K N F M G G I E R F L G L E H K D L I P L
    L P K I L V Q L Y N N D I I S E E E I M R F G T K S S K K F V P K E V S K K
    V R R A A K P F I T W L E T A E S D D D E E D D E [SEP] M S I E N L K S F D
    P F A D T G D D E T A T S N Y I H I R I Q Q R N G R K T L T T V Q G V P E E
    Y D L K R I L K V L K K D F A C N G N I V K D P E M G E I I Q L Q G D Q R A
    K V C E F M I S Q L G L Q K K N I K I H G F
  example_title: Interacting proteins
---

[SYNTERACT 2.0](https://huggingface.co/Synthyra/SYNTERACT2) is coming soon; stay tuned!


<img src="https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/Ro4uhQDurP-x7IHJj11xa.png" width="350">

## Model description

SYNTERACT (SYNThetic data-driven protein-protein intERACtion Transformer) is a fine-tuned version of [ProtBERT](https://huggingface.co/Rostlab/prot_bert_bfd) that attends over two amino acid sequences separated by a [SEP] token to determine whether they plausibly interact in a biological context.

We trained our model on the multi-validated physical interaction dataset from BioGRID, negative examples from the Negatome, and synthetic negative samples. Check out our [preprint](https://www.biorxiv.org/content/10.1101/2023.06.07.544109v1.full) for more details.

SYNTERACT achieved 92-96% accuracy on real unseen examples across a wide phylogenetic range, and is already being used to accelerate drug-target screening and peptide therapeutic design.


## How to use

```python
# Imports
import re
import torch
import torch.nn.functional as F
from transformers import BertForSequenceClassification, BertTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # gather device
model = BertForSequenceClassification.from_pretrained('GleghornLab/SYNTERACT', attn_implementation='sdpa').eval().to(device) # load model, set to eval mode, move to device
tokenizer = BertTokenizer.from_pretrained('GleghornLab/SYNTERACT') # load tokenizer

sequence_a = 'MEKSCSIGNGREQYGWGHGEQCGTQFLECVYRNASMYSVLGDLITYVVFLGATCYAILFGFRLLLSCVRIVLKVVIALFVIRLLLALGSVDITSVSYSG' # Uniprot A1Z8T3
sequence_b = 'MRLTLLALIGVLCLACAYALDDSENNDQVVGLLDVADQGANHANDGAREARQLGGWGGGWGGRGGWGGRGGWGGRGGWGGRGGWGGGWGGRGGWGGRGGGWYGR' # Uniprot A1Z8H0
sequence_a = ' '.join(list(re.sub(r'[UZOB]', 'X', sequence_a))) # need spaces inbetween amino acids
sequence_b = ' '.join(list(re.sub(r'[UZOB]', 'X', sequence_b))) # replace rare amino acids with X
example = sequence_a + ' [SEP] ' + sequence_b # add SEP token

example = tokenizer(example, return_tensors='pt', padding=False).to(device) # tokenize example
with torch.no_grad():
    logits = model(**example).logits.detach().cpu() # get logits from model

probability = F.softmax(logits, dim=-1) # use softmax to get "confidence" in the prediction
prediction = probability.argmax(dim=-1) # 0 for no interaction, 1 for interaction
```
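The sanitization step above (mapping the rare amino acids U, Z, O, and B to X, then spacing the residues as ProtBERT-style tokenizers expect) can be factored into a small helper. A minimal sketch using only the standard library:

```python
import re

def sanitize(seq: str) -> str:
    """Replace rare amino acids (U, Z, O, B) with X and insert
    spaces between residues for ProtBERT-style tokenization."""
    return ' '.join(re.sub(r'[UZOB]', 'X', seq))

print(sanitize('MKUZT'))  # -> M K X X T
```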

## Intended use and limitations
We define a protein-protein interaction as physical contact that mediates chemical or conformational change, especially with non-generic function. However, given SYNTERACT's propensity to predict false positives, we believe it also flags plausible conformational changes caused by interactions that have no functional relevance.
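Given this propensity for false positives, one practical mitigation is to require a higher softmax probability than the 0.5 implied by `argmax` before calling an interaction. A minimal sketch (the 0.9 threshold is illustrative, not a validated cutoff):

```python
import torch

def call_interaction(logits: torch.Tensor, threshold: float = 0.9) -> torch.Tensor:
    """Label a pair as interacting (1) only when the softmax probability
    of the positive class exceeds `threshold`; otherwise 0."""
    probs = torch.softmax(logits, dim=-1)[:, 1]
    return (probs >= threshold).long()

# Dummy logits for two pairs: the first is confidently positive,
# the second is borderline and gets rejected by the stricter threshold
logits = torch.tensor([[0.2, 2.5], [1.0, 1.2]])
print(call_interaction(logits))  # -> tensor([1, 0])
```

Plain `argmax` would label both pairs as interacting; thresholding trades recall for precision.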

## Our lab
The [Gleghorn lab](https://www.gleghornlab.com/) is an interdisciplinary research group at the University of Delaware that focuses on solving translational problems with our expertise in engineering, biology, and chemistry. We develop inexpensive and reliable tools to study organ development, maternal-fetal health, and drug delivery. Recently we have begun exploration into protein language models and strive to make protein design and annotation accessible.

## Please cite
```
@article {Hallee_ppi_2023,
	author = {Logan Hallee and Jason P. Gleghorn},
	title = {Protein-Protein Interaction Prediction is Achievable with Large Language Models},
	year = {2023},
	doi = {10.1101/2023.06.07.544109},
	publisher = {Cold Spring Harbor Laboratory},
	journal = {bioRxiv}
}
```


## A simple inference script

```python
import torch
import re
import argparse
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict
from tqdm.auto import tqdm


class PairDataset(Dataset):
    def __init__(self, sequences_a: List[str], sequences_b: List[str]):
        self.sequences_a = sequences_a
        self.sequences_b = sequences_b

    def __len__(self):
        return len(self.sequences_a)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        return self.sequences_a[idx], self.sequences_b[idx]
    

class PairCollator:
    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        seq = ' '.join(list(re.sub(r'[UZOB]', 'X', seq)))
        return seq

    def __call__(self, batch: List[Tuple[str, str]]) -> Dict[str, torch.Tensor]:
        seqs_a, seqs_b = zip(*batch)
        seqs = []
        for a, b in zip(seqs_a, seqs_b):
            seq = self.sanitize_seq(a) + ' [SEP] ' + self.sanitize_seq(b)
            seqs.append(seq)
        seqs = self.tokenizer(seqs, padding='longest', truncation=True, max_length=self.max_length, return_tensors='pt')
        return {
            'input_ids': seqs['input_ids'],
            'attention_mask': seqs['attention_mask'],
        }


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Loading model from {args.model_path}")
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa").eval().to(device)
    # When using PyTorch >= 2.5.1 on a Linux machine, SDPA attention greatly speeds up inference
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print("Tokenizer loaded")

    """
    Load your data into two lists of sequences, where you want the PPI for each pair sequences_a[i], sequences_b[i]
    We recommend trimmed sequence pairs that sum over 1022 tokens (for the 1024 max length limit of SYNTERACT)
    We also recommend sorting the sequences by length in descending order, as this will speed up inference by reducing padding

    Example:
        from datasets import load_dataset
        data = load_dataset('Synthyra/NEGATOME', split='combined')
        # Filter out examples where the total length exceeds 1022
        data = data.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= 1022)
        # Add a new column 'total_length' that is the sum of lengths of SeqA and SeqB
        data = data.map(lambda x: {"total_length": len(x['SeqA']) + len(x['SeqB'])})
        # Sort the dataset by 'total_length' in descending order (longest sequences first)
        data = data.sort("total_length", reverse=True)
        # Now retrieve the sorted sequences
        sequences_a = data['SeqA']
        sequences_b = data['SeqB']
    """
    print("Loading data...")
    sequences_a = []
    sequences_b = []

    print("Creating torch dataset...")
    pair_dataset = PairDataset(sequences_a, sequences_b)
    pair_collator = PairCollator(tokenizer, max_length=1024)
    data_loader = DataLoader(pair_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=pair_collator)

    all_seqs_a = []
    all_seqs_b = []
    all_probs = []
    all_preds = []

    print("Starting inference...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader, total=len(data_loader), desc="Batches processed")):
            # Because sequences are sorted, the initial estimate for time will be much longer than the actual time it will take
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits.detach().cpu()

            prob_of_interaction = torch.softmax(logits, dim=1)[:, 1] # 1 - this value gives the probability of no interaction
            pred = torch.argmax(logits, dim=1)

            # Store results
            batch_start = i * args.batch_size
            batch_end = min((i + 1) * args.batch_size, len(sequences_a))
            all_seqs_a.extend(sequences_a[batch_start:batch_end])
            all_seqs_b.extend(sequences_b[batch_start:batch_end])
            all_probs.extend(prob_of_interaction.tolist())
            all_preds.extend(pred.tolist())

    # round to 5 decimal places
    all_probs = [round(prob, 5) for prob in all_probs]

    # Create dataframe and save to CSV
    results_df = pd.DataFrame({
        'sequence_a': all_seqs_a,
        'sequence_b': all_seqs_b,
        'probabilities': all_probs,
        'prediction': all_preds
    })
    print(f"Saving results to {args.save_path}")
    results_df.to_csv(args.save_path, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='GleghornLab/SYNTERACT')
    parser.add_argument('--save_path', type=str, default='ppi_predictions.csv')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--num_workers', type=int, default=0) # increase to use multiprocessing in the dataloader; 4 is usually a good value
    args = parser.parse_args()

    main(args)
```