---
license: cc-by-nc-4.0
library_name: transformers
datasets:
- BIOGRID
- Negatome
pipeline_tag: text-classification
tags:
- protein language model
- biology
widget:
- text: >-
    M S H S V K I Y D T C I G C T Q C V R A C P T D V L E M I P W G G C K A K Q
    I A S A P R T E D C V G C K R C E S A C P T D F L S V R V Y L W H E T T R S
    M G L A Y [SEP] M I N L P S L F V P L V G L L F P A V A M A S L F L H V E K
    R L L F S T K K I N
  example_title: Non-interacting proteins
- text: >-
    M S I N I C R D N H D P F Y R Y K M P P I Q A K V E G R G N G I K T A V L N
    V A D I S H A L N R P A P Y I V K Y F G F E L G A Q T S I S V D K D R Y L V
    N G V H E P A K L Q D V L D G F I N K F V L C G S C K N P E T E I I I T K D
    N D L V R D C K A C G K R T P M D L R H K L S S F I L K N P P D S V S G S K
    K K K K A A T A S A N V R G G G L S I S D I A Q G K S Q N A P S D G T G S S
    T P Q H H D E D E D E L S R Q I K A A A S T L E D I E V K D D E W A V D M S
    E E A I R A R A K E L E V N S E L T Q L D E Y G E W I L E Q A G E D K E N L
    P S D V E L Y K K A A E L D V L N D P K I G C V L A Q C L F D E D I V N E I
    A E H N A F F T K I L V T P E Y E K N F M G G I E R F L G L E H K D L I P L
    L P K I L V Q L Y N N D I I S E E E I M R F G T K S S K K F V P K E V S K K
    V R R A A K P F I T W L E T A E S D D D E E D D E [SEP] M S I E N L K S F D
    P F A D T G D D E T A T S N Y I H I R I Q Q R N G R K T L T T V Q G V P E E
    Y D L K R I L K V L K K D F A C N G N I V K D P E M G E I I Q L Q G D Q R A
    K V C E F M I S Q L G L Q K K N I K I H G F
  example_title: Interacting proteins
---
[SYNTERACT 2.0](https://huggingface.co/Synthyra/SYNTERACT2) is coming soon. Please stay tuned!
<img src="https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/Ro4uhQDurP-x7IHJj11xa.png" width="350">
## Model description
SYNTERACT (SYNThetic data-driven protein-protein intERACtion Transformer) is a fine-tuned version of [ProtBERT](https://huggingface.co/Rostlab/prot_bert_bfd) that attends over two amino acid sequences separated by a [SEP] token to determine whether they plausibly interact in a biological context.
We trained our model on the multi-validated physical interaction dataset from BioGRID, together with the Negatome and synthetic negative samples. Check out our [preprint](https://www.biorxiv.org/content/10.1101/2023.06.07.544109v1.full) for more details.
SYNTERACT achieved unprecedented performance across a vast phylogenetic range, reaching 92-96% accuracy on real unseen examples, and is already being used to accelerate drug target screening and peptide therapeutic design.
## How to use
```python
# Imports
import re
import torch
import torch.nn.functional as F
from transformers import BertForSequenceClassification, BertTokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # gather device
model = BertForSequenceClassification.from_pretrained('GleghornLab/SYNTERACT', attn_implementation='sdpa').eval().to(device) # load model and move to device
tokenizer = BertTokenizer.from_pretrained('GleghornLab/SYNTERACT') # load tokenizer
sequence_a = 'MEKSCSIGNGREQYGWGHGEQCGTQFLECVYRNASMYSVLGDLITYVVFLGATCYAILFGFRLLLSCVRIVLKVVIALFVIRLLLALGSVDITSVSYSG' # Uniprot A1Z8T3
sequence_b = 'MRLTLLALIGVLCLACAYALDDSENNDQVVGLLDVADQGANHANDGAREARQLGGWGGGWGGRGGWGGRGGWGGRGGWGGRGGWGGGWGGRGGWGGRGGGWYGR' # Uniprot A1Z8H0
sequence_a = ' '.join(list(re.sub(r'[UZOB]', 'X', sequence_a))) # need spaces in between amino acids
sequence_b = ' '.join(list(re.sub(r'[UZOB]', 'X', sequence_b))) # replace rare amino acids with X
example = sequence_a + ' [SEP] ' + sequence_b # add SEP token
example = tokenizer(example, return_tensors='pt', padding=False).to(device) # tokenize example
with torch.no_grad():
    logits = model(**example).logits.detach().cpu() # get logits from model
probability = F.softmax(logits, dim=-1) # use softmax to get "confidence" in the prediction
prediction = probability.argmax(dim=-1) # 0 for no interaction, 1 for interaction
```
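As a quick check on the outputs, here is a minimal sketch that reuses the `probability` and `prediction` tensors from the snippet above (the 0 = no interaction, 1 = interaction mapping is as noted in the comments):
```python
# probability has shape (1, 2): column 0 is P(no interaction), column 1 is P(interaction)
interaction_prob = probability[0, 1].item()
print(f'P(interaction) = {interaction_prob:.4f}')
print('Predicted: interaction' if prediction.item() == 1 else 'Predicted: no interaction')
```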
## Intended use and limitations
We define a protein-protein interaction as physical contact that mediates a chemical or conformational change, especially one with non-generic function. However, because SYNTERACT has a propensity to predict false positives, we believe it also flags plausible interaction-driven conformational changes that have no relevance to function.
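If false positives are a concern for a particular screen, one option is to replace the argmax decision with a stricter probability cutoff. This is only a sketch, not part of the published pipeline, and the 0.9 value is an arbitrary illustration rather than a validated threshold:
```python
# Hypothetical stricter decision rule (illustrative threshold, not validated)
THRESHOLD = 0.9
interaction_prob = probability[0, 1].item()  # P(interaction) from the softmax above
prediction = int(interaction_prob >= THRESHOLD)  # call an interaction only when the model is very confident
```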
## Our lab
The [Gleghorn lab](https://www.gleghornlab.com/) is an interdisciplinary research group at the University of Delaware that focuses on solving translational problems with our expertise in engineering, biology, and chemistry. We develop inexpensive and reliable tools to study organ development, maternal-fetal health, and drug delivery. Recently, we have begun exploring protein language models and strive to make protein design and annotation accessible.
## Please cite
```
@article {Hallee_ppi_2023,
author = {Logan Hallee and Jason P. Gleghorn},
title = {Protein-Protein Interaction Prediction is Achievable with Large Language Models},
year = {2023},
doi = {10.1101/2023.06.07.544109},
publisher = {Cold Spring Harbor Laboratory},
journal = {bioRxiv}
}
```
## A simple inference script
```python
import torch
import re
import argparse
import pandas as pd
from transformers import BertForSequenceClassification, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict
from tqdm.auto import tqdm
class PairDataset(Dataset):
    """Simple dataset wrapping two parallel lists of sequences."""
    def __init__(self, sequences_a: List[str], sequences_b: List[str]):
        self.sequences_a = sequences_a
        self.sequences_b = sequences_b

    def __len__(self):
        return len(self.sequences_a)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        return self.sequences_a[idx], self.sequences_b[idx]


class PairCollator:
    """Sanitizes, joins with [SEP], and tokenizes a batch of sequence pairs."""
    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        # Replace rare amino acids with X and add spaces between residues for the ProtBERT tokenizer
        seq = ' '.join(list(re.sub(r'[UZOB]', 'X', seq)))
        return seq

    def __call__(self, batch: List[Tuple[str, str]]) -> Dict[str, torch.Tensor]:
        seqs_a, seqs_b = zip(*batch)
        seqs = []
        for a, b in zip(seqs_a, seqs_b):
            seq = self.sanitize_seq(a) + ' [SEP] ' + self.sanitize_seq(b)
            seqs.append(seq)
        seqs = self.tokenizer(seqs, padding='longest', truncation=True, max_length=self.max_length, return_tensors='pt')
        return {
            'input_ids': seqs['input_ids'],
            'attention_mask': seqs['attention_mask'],
        }


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Loading model from {args.model_path}")
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa").eval().to(device)
    # When using PyTorch >= 2.5.1 on a Linux machine, sdpa attention will greatly speed up inference
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print("Tokenizer loaded")
    """
    Load your data into two lists of sequences, where you want the PPI for each pair sequences_a[i], sequences_b[i].
    We recommend trimming or filtering out sequence pairs whose lengths sum to more than 1022 residues
    (to fit SYNTERACT's 1024-token maximum length).
    We also recommend sorting the sequences by length in descending order, as this will speed up inference by reducing padding.
    Example:
        from datasets import load_dataset
        data = load_dataset('Synthyra/NEGATOME', split='combined')
        # Filter out examples where the total length exceeds 1022
        data = data.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= 1022)
        # Add a new column 'total_length' that is the sum of lengths of SeqA and SeqB
        data = data.map(lambda x: {"total_length": len(x['SeqA']) + len(x['SeqB'])})
        # Sort the dataset by 'total_length' in descending order (longest sequences first)
        data = data.sort("total_length", reverse=True)
        # Now retrieve the sorted sequences
        sequences_a = data['SeqA']
        sequences_b = data['SeqB']
    """
    print("Loading data...")
    sequences_a = []
    sequences_b = []

    print("Creating torch dataset...")
    pair_dataset = PairDataset(sequences_a, sequences_b)
    pair_collator = PairCollator(tokenizer, max_length=1024)
    data_loader = DataLoader(pair_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=pair_collator)

    all_seqs_a = []
    all_seqs_b = []
    all_probs = []
    all_preds = []

    print("Starting inference...")
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader, total=len(data_loader), desc="Batches processed")):
            # Because sequences are sorted longest-first, the initial time estimate will be much longer than the actual run time
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits.detach().cpu()
            prob_of_interaction = torch.softmax(logits, dim=1)[:, 1]  # can do 1 - this for the no-interaction probability
            pred = torch.argmax(logits, dim=1)
            # Store results
            batch_start = i * args.batch_size
            batch_end = min((i + 1) * args.batch_size, len(sequences_a))
            all_seqs_a.extend(sequences_a[batch_start:batch_end])
            all_seqs_b.extend(sequences_b[batch_start:batch_end])
            all_probs.extend(prob_of_interaction.tolist())
            all_preds.extend(pred.tolist())

    # Round to 5 decimal places
    all_probs = [round(prob, 5) for prob in all_probs]

    # Create dataframe and save to CSV
    results_df = pd.DataFrame({
        'sequence_a': all_seqs_a,
        'sequence_b': all_seqs_b,
        'probabilities': all_probs,
        'prediction': all_preds
    })
    print(f"Saving results to {args.save_path}")
    results_df.to_csv(args.save_path, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='GleghornLab/SYNTERACT')
    parser.add_argument('--save_path', type=str, default='ppi_predictions.csv')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--num_workers', type=int, default=0)  # increase to use multiprocessing for the dataloader; 4 is usually a good value
    args = parser.parse_args()
    main(args)
``` |
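Assuming the script above is saved as, say, `ppi_inference.py` (the filename is arbitrary), it can be run with `python ppi_inference.py --batch_size 8 --save_path ppi_predictions.csv`. The resulting CSV can then be inspected with pandas; the 0.9 cutoff below is again only an illustrative value, not a validated threshold:
```python
import pandas as pd

results = pd.read_csv('ppi_predictions.csv')
# 'probabilities' holds P(interaction); 'prediction' holds the argmax label (0 or 1)
confident_hits = results[(results['prediction'] == 1) & (results['probabilities'] >= 0.9)]
print(f'{len(confident_hits)} of {len(results)} pairs predicted to interact with p >= 0.9')
```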