---
license: mit
---
[GitHub repo](https://github.com/klemens-floege/oneprot/) | [Paper](https://arxiv.org/abs/2411.04863)


## Overview

OneProt is a multimodal model that integrates protein sequences, protein structures (represented both as an augmented token sequence and as a graph), protein binding sites, and protein text annotations. Contrastive learning is used to align each modality with a central one, the protein sequence: during pre-training, an InfoNCE loss is computed between (protein sequence, other modality) pairs.
This checkpoint **omits the ESM2-based structure-token encoder, keeping only the GNN encoder for structure**, and therefore comprises 4 of the 5 possible modalities.
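The pre-training objective can be sketched as follows. This is an illustrative, symmetric InfoNCE implementation, not the repository's exact code; the batch size, embedding dimension, and temperature below are placeholders:

```python
import torch
import torch.nn.functional as F

def info_nce(seq_emb, other_emb, temperature=0.07):
    """Symmetric InfoNCE loss between two batches of embeddings.

    Matching rows of seq_emb and other_emb are positive pairs; every
    other pairing in the batch serves as a negative.
    """
    seq_emb = F.normalize(seq_emb, dim=-1)
    other_emb = F.normalize(other_emb, dim=-1)
    logits = seq_emb @ other_emb.t() / temperature   # (B, B) similarity matrix
    targets = torch.arange(seq_emb.size(0))          # i-th row matches i-th column
    return 0.5 * (F.cross_entropy(logits, targets)
                  + F.cross_entropy(logits.t(), targets))

# Toy example: a batch of 8 sequence embeddings aligned with 8 embeddings
# from another modality, both projected to a shared 128-dim space.
loss = info_nce(torch.randn(8, 128), torch.randn(8, 128))
```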

## Model architecture

Protein sequence encoder: [esm2_t33_650M_UR50D](https://huggingface.co/facebook/esm2_t33_650M_UR50D)

Protein structure encoder GNN: [ProNet](https://github.com/divelab/DIG)

Pocket (binding sites encoder) GNN: [ProNet](https://github.com/divelab/DIG)

Text encoder: [BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext](https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext)

Below is example code showing how to obtain the embeddings (requires cloning our repo first). Note that the example data for the transformer models are read from `.txt` files and could in principle be passed as strings, whereas the data for the GNN models are contained in the example `.h5` files and must first be converted to graphs.

```python
import torch
import hydra
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
import sys
import os
import h5py
from torch_geometric.data import Batch
from transformers import AutoTokenizer

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))  # make the oneprot repo importable; adjust if you are not running this script from inside it

from src.models.oneprot_module import OneProtLitModule
from src.data.utils.struct_graph_utils import protein_to_graph

# If you are not running on a cluster, you may need to uncomment
# the two following lines:
# os.environ['RANK'] = '0'
# os.environ['WORLD_SIZE'] = '1'


# Load the config file

config_path = hf_hub_download(
        repo_id="HelmholtzAI-FZJ/oneprot-4",
        filename="config.yaml",
    )

with open(config_path, 'r') as f:
    cfg = OmegaConf.load(f)

# Prepare components dictionary from config
components = {
        'sequence': hydra.utils.instantiate(cfg.model.components.sequence),
        'struct_graph': hydra.utils.instantiate(cfg.model.components.struct_graph),
        'pocket': hydra.utils.instantiate(cfg.model.components.pocket),
        'text': hydra.utils.instantiate(cfg.model.components.text)
    }

# Load the model checkpoint

checkpoint_path = hf_hub_download(
        repo_id="HelmholtzAI-FZJ/oneprot-4",
        filename="pytorch_model.bin",
        repo_type="model"
    )

# Create model instance and load the checkpoint

model = OneProtLitModule(
        components=components,
        optimizer=None,
        loss_fn=cfg.model.loss_fn,
        local_loss=cfg.model.local_loss,
        gather_with_grad=cfg.model.gather_with_grad,
        use_l1_regularization=cfg.model.use_l1_regularization,
        train_on_all_modalities_after_step=cfg.model.train_on_all_modalities_after_step,
        use_seqsim=cfg.model.use_seqsim
    )

state_dict = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state_dict, strict=True)

# Define the tokenizers

tokenizers = {
        'sequence': "facebook/esm2_t33_650M_UR50D",
        'struct_token': "facebook/esm2_t33_650M_UR50D",
        'text': "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
    }

loaded_tokenizers = {}
for modality, tokenizer_name in tokenizers.items():
    loaded_tokenizers[modality] = AutoTokenizer.from_pretrained(tokenizer_name)

# Get example embeddings for each modality


##########################sequence##############################

modality = "sequence"
 
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/sequence_example.txt",
    repo_type="model"  # or "dataset"
)

with open(file_path, 'r') as file:
    input_sequence = file.read().strip()

input_tensor = loaded_tokenizers[modality](input_sequence, return_tensors="pt")["input_ids"]
output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")


###########################text#################################

modality = "text"

file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/text_example.txt",
    repo_type="model"  # or "dataset"
)

with open(file_path, 'r') as file:    
    input_text = file.read().strip()

input_tensor = loaded_tokenizers[modality](input_text, return_tensors="pt")["input_ids"]
output = model.network[modality](input_tensor)
print(f"Output for modality '{modality}': {output}")



#####################graph structure############################

modality = "struct_graph"  
file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/seqstruc_example.h5",
    repo_type="model"  # or "dataset"
)

input_struct_graph = [protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=False)]
input_struct_graph = Batch.from_data_list(input_struct_graph)
output = model.network[modality](input_struct_graph)
print(f"Output for modality '{modality}': {output}")


##########################pocket################################


modality = "pocket"

file_path = hf_hub_download(
    repo_id="HelmholtzAI-FZJ/oneprot",
    filename="data_examples/pocket_example.h5",
    repo_type="model"  # or "dataset"
)

input_pocket = [protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=True)]
input_pocket = Batch.from_data_list(input_pocket)
output = model.network[modality](input_pocket)
print(f"Output for modality '{modality}': {output}")

```
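Because all encoders project into a shared embedding space, embeddings from different modalities can be compared directly, e.g. for cross-modal retrieval. The sketch below uses random placeholder embeddings (the real embedding dimension depends on the projection heads defined in the config):

```python
import torch
import torch.nn.functional as F

# Hypothetical pre-computed embeddings: one query sequence embedding and a
# small gallery of text-annotation embeddings (dimensions are illustrative).
torch.manual_seed(0)
seq_emb = torch.randn(1, 256)
text_gallery = torch.randn(5, 256)

# Cosine similarity = dot product of L2-normalised vectors.
sims = F.normalize(seq_emb, dim=-1) @ F.normalize(text_gallery, dim=-1).t()
best = sims.argmax(dim=-1)
print(f"Most similar annotation index: {best.item()}")
```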

## Citation

```bibtex
@misc{flöge2024oneprotmultimodalproteinfoundation,
      title={OneProt: Towards Multi-Modal Protein Foundation Models}, 
      author={Klemens Flöge and Srisruthi Udayakumar and Johanna Sommer and Marie Piraud and Stefan Kesselheim and Vincent Fortuin and Stephan Günnemann and Karel J van der Weg and Holger Gohlke and Alina Bazarova and Erinc Merdivan},
      year={2024},
      eprint={2411.04863},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2411.04863}, 
}

```