---
license: cc-by-nc-4.0
library_name: transformers
datasets:
- BIOGRID
- Negatome
pipeline_tag: text-classification
tags:
- protein language model
- biology
widget:
- text: >-
    M S H S V K I Y D T C I G C T Q C V R A C P T D V L E M I P W G G C K A K Q
    I A S A P R T E D C V G C K R C E S A C P T D F L S V R V Y L W H E T T R S
    M G L A Y [SEP] M I N L P S L F V P L V G L L F P A V A M A S L F L H V E K
    R L L F S T K K I N
  example_title: Non-interacting proteins
- text: >-
    M S I N I C R D N H D P F Y R Y K M P P I Q A K V E G R G N G I K T A V L N
    V A D I S H A L N R P A P Y I V K Y F G F E L G A Q T S I S V D K D R Y L V
    N G V H E P A K L Q D V L D G F I N K F V L C G S C K N P E T E I I I T K D
    N D L V R D C K A C G K R T P M D L R H K L S S F I L K N P P D S V S G S K
    K K K K A A T A S A N V R G G G L S I S D I A Q G K S Q N A P S D G T G S S
    T P Q H H D E D E D E L S R Q I K A A A S T L E D I E V K D D E W A V D M S
    E E A I R A R A K E L E V N S E L T Q L D E Y G E W I L E Q A G E D K E N L
    P S D V E L Y K K A A E L D V L N D P K I G C V L A Q C L F D E D I V N E I
    A E H N A F F T K I L V T P E Y E K N F M G G I E R F L G L E H K D L I P L
    L P K I L V Q L Y N N D I I S E E E I M R F G T K S S K K F V P K E V S K K
    V R R A A K P F I T W L E T A E S D D D E E D D E [SEP] M S I E N L K S F D
    P F A D T G D D E T A T S N Y I H I R I Q Q R N G R K T L T T V Q G V P E E
    Y D L K R I L K V L K K D F A C N G N I V K D P E M G E I I Q L Q G D Q R A
    K V C E F M I S Q L G L Q K K N I K I H G F
  example_title: Interacting proteins
---

[SYNTERACT 2.0](https://huggingface.co/Synthyra/SYNTERACT2) is coming soon; please stay tuned!

<img src="https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/Ro4uhQDurP-x7IHJj11xa.png" width="350">

## Model description

SYNTERACT (SYNThetic data-driven protein-protein intERACtion Transformer) is a fine-tuned version of [ProtBERT](https://huggingface.co/Rostlab/prot_bert_bfd) that attends to two amino acid sequences separated by [SEP] to determine whether they plausibly interact in a biological context.

We trained the model on the multi-validated physical interaction dataset from BioGRID, non-interacting pairs from Negatome, and synthetic negative samples. Check out our [preprint](https://www.biorxiv.org/content/10.1101/2023.06.07.544109v1.full) for more details.

SYNTERACT achieved unprecedented performance across a broad phylogenetic range, with 92-96% accuracy on real, unseen examples, and is already being used to accelerate drug target screening and peptide therapeutic design.

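The input format described above can be sketched as a small helper. It assumes the same preprocessing as the usage example in "How to use" (rare residues U, Z, O, and B mapped to X, residues separated by spaces, and the two chains joined by `[SEP]`) and uppercase one-letter sequences; the name `format_pair` is hypothetical and not part of any released package.

```python
import re


def format_pair(seq_a: str, seq_b: str) -> str:
    """Build a SYNTERACT-style input string from two raw amino acid sequences.

    Hypothetical helper: assumes uppercase one-letter codes and mirrors the
    preprocessing in the usage example (U, Z, O, B -> X; space-separated residues).
    """
    def sanitize(seq: str) -> str:
        # Map rare/ambiguous residues to X, then put a space between residues
        return ' '.join(re.sub(r'[UZOB]', 'X', seq))

    # The two chains are joined by the [SEP] token that the tokenizer keeps intact
    return sanitize(seq_a) + ' [SEP] ' + sanitize(seq_b)


print(format_pair('MKU', 'GBA'))  # M K X [SEP] G X A
```

The resulting string is what gets passed to the tokenizer; the tokenizer itself then adds the leading `[CLS]` and trailing `[SEP]`.
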
## How to use

```python
# Imports
import re

import torch
import torch.nn.functional as F
from transformers import BertForSequenceClassification, BertTokenizer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # gather device
model = BertForSequenceClassification.from_pretrained('GleghornLab/SYNTERACT', attn_implementation='sdpa').to(device).eval()  # load model onto the device
tokenizer = BertTokenizer.from_pretrained('GleghornLab/SYNTERACT')  # load tokenizer

sequence_a = 'MEKSCSIGNGREQYGWGHGEQCGTQFLECVYRNASMYSVLGDLITYVVFLGATCYAILFGFRLLLSCVRIVLKVVIALFVIRLLLALGSVDITSVSYSG'  # UniProt A1Z8T3
sequence_b = 'MRLTLLALIGVLCLACAYALDDSENNDQVVGLLDVADQGANHANDGAREARQLGGWGGGWGGRGGWGGRGGWGGRGGWGGRGGWGGGWGGRGGWGGRGGGWYGR'  # UniProt A1Z8H0
# Replace rare amino acids (U, Z, O, B) with X and put spaces between residues
sequence_a = ' '.join(re.sub(r'[UZOB]', 'X', sequence_a))
sequence_b = ' '.join(re.sub(r'[UZOB]', 'X', sequence_b))
example = sequence_a + ' [SEP] ' + sequence_b  # join the pair with a SEP token

example = tokenizer(example, return_tensors='pt', padding=False).to(device)  # tokenize example
with torch.no_grad():
    logits = model(**example).logits.detach().cpu()  # get logits from model

probability = F.softmax(logits, dim=-1)  # use softmax to get "confidence" in the prediction
prediction = probability.argmax(dim=-1)  # 0 for no interaction, 1 for interaction
```
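With only two output classes, the softmax probability of class 1 (interaction) reduces to a logistic sigmoid of the logit difference, p = 1 / (1 + exp(l0 - l1)). The standard-library sketch below, which does not load the model, checks that the two computations agree and shows how `argmax` maps onto the 0/1 labels; the logit values are made up for illustration and are not real SYNTERACT output.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def interaction_probability(logit_no, logit_yes):
    # For two classes, softmax(...)[1] equals sigmoid(logit_yes - logit_no)
    return 1.0 / (1.0 + math.exp(logit_no - logit_yes))


logits = [-1.2, 2.3]  # illustrative values only
probs = softmax(logits)
prediction = max(range(len(probs)), key=probs.__getitem__)  # 0 = no interaction, 1 = interaction
print(round(probs[1], 4), round(interaction_probability(*logits), 4), prediction)
```

This is why the inference script below can keep only column 1 of the softmax output: the no-interaction probability is always its complement.
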

## Intended use and limitations

We define a protein-protein interaction as physical contact that mediates chemical or conformational change, especially with non-generic function. However, given SYNTERACT's propensity to predict false positives, we believe it identifies plausible interaction-induced conformational changes, whether or not they are relevant to function.

## Our lab

The [Gleghorn lab](https://www.gleghornlab.com/) is an interdisciplinary research group at the University of Delaware that focuses on solving translational problems using expertise in engineering, biology, and chemistry. We develop inexpensive and reliable tools to study organ development, maternal-fetal health, and drug delivery. Recently, we have begun exploring protein language models, and we strive to make protein design and annotation accessible.

## Please cite

```
@article {Hallee_ppi_2023,
    author = {Logan Hallee and Jason P. Gleghorn},
    title = {Protein-Protein Interaction Prediction is Achievable with Large Language Models},
    year = {2023},
    doi = {10.1101/2023.06.07.544109},
    publisher = {Cold Spring Harbor Laboratory},
    journal = {bioRxiv}
}
```

## A simple inference script

```python
import argparse
import re
from typing import Dict, List, Tuple

import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import BertForSequenceClassification, BertTokenizer


class PairDataset(Dataset):
    """Holds paired sequences so that item i is (sequences_a[i], sequences_b[i])."""

    def __init__(self, sequences_a: List[str], sequences_b: List[str]):
        assert len(sequences_a) == len(sequences_b), 'sequences_a and sequences_b must have the same length'
        self.sequences_a = sequences_a
        self.sequences_b = sequences_b

    def __len__(self):
        return len(self.sequences_a)

    def __getitem__(self, idx: int) -> Tuple[str, str]:
        return self.sequences_a[idx], self.sequences_b[idx]


class PairCollator:
    """Sanitizes, joins, and tokenizes a batch of sequence pairs."""

    def __init__(self, tokenizer, max_length=1024):
        self.tokenizer = tokenizer
        self.max_length = max_length

    def sanitize_seq(self, seq: str) -> str:
        # Replace rare amino acids (U, Z, O, B) with X and put spaces between residues
        seq = ' '.join(re.sub(r'[UZOB]', 'X', seq))
        return seq

    def __call__(self, batch: List[Tuple[str, str]]) -> Dict[str, torch.Tensor]:
        seqs_a, seqs_b = zip(*batch)
        seqs = []
        for a, b in zip(seqs_a, seqs_b):
            seq = self.sanitize_seq(a) + ' [SEP] ' + self.sanitize_seq(b)
            seqs.append(seq)
        # Pad to the longest pair in the batch; truncate anything beyond max_length tokens
        seqs = self.tokenizer(seqs, padding='longest', truncation=True, max_length=self.max_length, return_tensors='pt')
        return {
            'input_ids': seqs['input_ids'],
            'attention_mask': seqs['attention_mask'],
        }


def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Loading model from {args.model_path}")
    # When using PyTorch >= 2.5.1 on a Linux machine, SDPA attention will greatly speed up inference
    model = BertForSequenceClassification.from_pretrained(args.model_path, attn_implementation="sdpa").eval().to(device)
    tokenizer = BertTokenizer.from_pretrained(args.model_path)
    print("Tokenizer loaded")

    # Load your data into two lists of sequences, where you want the PPI for each pair (sequences_a[i], sequences_b[i]).
    # We recommend trimming (or filtering out) pairs whose combined length exceeds 1021 residues: SYNTERACT accepts
    # at most 1024 tokens, and every pair also uses three special tokens ([CLS] plus two [SEP]).
    # We also recommend sorting pairs by combined length in descending order, as this speeds up inference by reducing padding.
    #
    # Example:
    # from datasets import load_dataset
    # data = load_dataset('Synthyra/NEGATOME', split='combined')
    # # Filter out examples where the total length exceeds 1021
    # data = data.filter(lambda x: len(x['SeqA']) + len(x['SeqB']) <= 1021)
    # # Add a new column 'total_length' that is the sum of lengths of SeqA and SeqB
    # data = data.map(lambda x: {'total_length': len(x['SeqA']) + len(x['SeqB'])})
    # # Sort the dataset by 'total_length' in descending order (longest pairs first)
    # data = data.sort('total_length', reverse=True)
    # # Now retrieve the sorted sequences
    # sequences_a = data['SeqA']
    # sequences_b = data['SeqB']
    print("Loading data...")
    sequences_a = []  # replace with your own list of sequences
    sequences_b = []  # replace with the partner sequence for each entry of sequences_a

    print("Creating torch dataset...")
    pair_dataset = PairDataset(sequences_a, sequences_b)
    pair_collator = PairCollator(tokenizer, max_length=1024)
    # shuffle stays False (the default), so batch i lines up with sequences_a[i * batch_size:(i + 1) * batch_size]
    data_loader = DataLoader(pair_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=pair_collator)

    all_seqs_a = []
    all_seqs_b = []
    all_probs = []
    all_preds = []

    print("Starting inference...")
    with torch.no_grad():
        # Because sequences are sorted longest-first, the initial time estimate will be much longer than the actual time
        for i, batch in enumerate(tqdm(data_loader, total=len(data_loader), desc="Batches processed")):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask=attention_mask).logits.detach().cpu()

            prob_of_interaction = torch.softmax(logits, dim=1)[:, 1]  # use 1 - this for the no-interaction probability
            pred = torch.argmax(logits, dim=1)  # 0 for no interaction, 1 for interaction

            # Store results alongside the original (unsanitized) sequences
            batch_start = i * args.batch_size
            batch_end = min((i + 1) * args.batch_size, len(sequences_a))
            all_seqs_a.extend(sequences_a[batch_start:batch_end])
            all_seqs_b.extend(sequences_b[batch_start:batch_end])
            all_probs.extend(prob_of_interaction.tolist())
            all_preds.extend(pred.tolist())

    # Round to 5 decimal places
    all_probs = [round(prob, 5) for prob in all_probs]

    # Create dataframe and save to CSV ('probabilities' holds the probability of interaction)
    results_df = pd.DataFrame({
        'sequence_a': all_seqs_a,
        'sequence_b': all_seqs_b,
        'probabilities': all_probs,
        'prediction': all_preds
    })
    print(f"Saving results to {args.save_path}")
    results_df.to_csv(args.save_path, index=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='GleghornLab/SYNTERACT')
    parser.add_argument('--save_path', type=str, default='ppi_predictions.csv')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--num_workers', type=int, default=0)  # increase (4 is usually a good value) to load batches in parallel worker processes
    args = parser.parse_args()

    main(args)
```
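The notes inside `main` recommend dropping pairs that do not fit the context window and sorting the rest by combined length. If you would rather not depend on the `datasets` library, the same preprocessing can be sketched in plain Python before filling `sequences_a` and `sequences_b`. The function name `prepare_pairs` is hypothetical, and the 3-token overhead assumes the `[CLS] A [SEP] B [SEP]` layout that the collator's `' [SEP] '` join plus the tokenizer's own special tokens produce, with one token per residue.

```python
from typing import List, Tuple

# Assumed overhead per pair: [CLS], the [SEP] between chains, and the final [SEP]
SPECIAL_TOKENS = 3


def prepare_pairs(pairs: List[Tuple[str, str]], max_length: int = 1024) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: filter pairs that would be truncated, then sort longest-first."""
    # Keep only pairs whose residues plus special tokens fit within max_length
    kept = [(a, b) for a, b in pairs if len(a) + len(b) + SPECIAL_TOKENS <= max_length]
    # Longest pairs first, so each batch holds similar lengths and needs little padding
    kept.sort(key=lambda p: len(p[0]) + len(p[1]), reverse=True)
    sequences_a = [a for a, _ in kept]
    sequences_b = [b for _, b in kept]
    return sequences_a, sequences_b


pairs = [('MK', 'GA'), ('M' * 600, 'G' * 500), ('MKTA', 'GAVLI')]
sequences_a, sequences_b = prepare_pairs(pairs)
print([len(a) + len(b) for a, b in zip(sequences_a, sequences_b)])  # [9, 4]
```

Filtering rather than truncating keeps both chains intact; if you prefer to keep long pairs, you would instead need to decide how to trim each chain, which this sketch does not attempt.
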