PyTorch
sealinka committed on
Commit
3250adb
·
verified ·
1 Parent(s): 28a2698

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +192 -0
README.md ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+ [Github repo](https://github.com/klemens-floege/oneprot/)|
5
+ [Paper link](https://arxiv.org/abs/2411.04863)
6
+
7
+
8
+ ## Overview
9
+
10
+ OneProt is a multimodal model that integrates protein sequence, protein structure (both in the form of an augmented sequence and in the form of a graph), protein binding sites and protein text annotations. Contrastive learning is used to align each of the modalities to the central one, which is the protein sequence. In the pre-training phase, the InfoNCE loss is computed between pairs (protein sequence, other modality).
11
+ This model **omits the structure encoder based on the ESM2 architecture, keeping only the GNN encoder for structure**, and therefore comprises only 4 of the 5 possible modalities.
12
+
13
+ ## Model architecture
14
+
15
+ Protein sequence encoder: [esm2_t33_650M_UR50D](https://huggingface.co/facebook/esm2_t33_650M_UR50D)
16
+
17
+ Protein structure encoder GNN: [ProNet](https://github.com/divelab/DIG)
18
+
19
+ Pocket (binding sites encoder) GNN: [ProNet](https://github.com/divelab/DIG)
20
+
21
+ Text encoder: [BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext](https://huggingface.co/microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext)
22
+
23
+ Below is example code showing how to obtain the embeddings (requires cloning our repo first). Note that the example data for the transformer models are read from `.txt` files and can, in principle, be passed directly as strings, whilst the data for the GNN models are contained in the example `.h5` files and subsequently need to be converted to graphs.
24
+
25
+ ```python
26
+ import torch
27
+ import hydra
28
+ from omegaconf import OmegaConf
29
+ from huggingface_hub import HfApi, hf_hub_download
30
+ import sys
31
+ import os
32
+ import h5py
33
+ from torch_geometric.data import Batch
34
+ from transformers import AutoTokenizer
35
+
36
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) # assuming that you are running this script from the oneprot repo, can be any other path
37
+
38
+ from src.models.oneprot_module import OneProtLitModule
39
+ from src.data.utils.struct_graph_utils import protein_to_graph
40
+
41
+ ### If you are not running on the supercomputer, you may need to uncomment the following two lines
42
+ #os.environ['RANK']='0'
43
+ #os.environ['WORLD_SIZE']='1'
44
+
45
+
46
+ #Load the config file and read it off
47
+
48
+ config_path = hf_hub_download(
49
+ repo_id="HelmholtzAI-FZJ/oneprot-4",
50
+ filename="config.yaml",
51
+ )
52
+
53
+ with open(config_path, 'r') as f:
54
+ cfg = OmegaConf.load(f)
55
+
56
+ # Prepare components dictionary from config
57
+ components = {
58
+ 'sequence': hydra.utils.instantiate(cfg.model.components.sequence),
59
+ 'struct_graph': hydra.utils.instantiate(cfg.model.components.struct_graph),
60
+ 'pocket': hydra.utils.instantiate(cfg.model.components.pocket),
61
+ 'text': hydra.utils.instantiate(cfg.model.components.text)
62
+ }
63
+
64
+ # Load the model checkpoint
65
+
66
+ checkpoint_path = hf_hub_download(
67
+ repo_id="HelmholtzAI-FZJ/oneprot-4",
68
+ filename="pytorch_model.bin",
69
+ repo_type="model"
70
+ )
71
+
72
+ # Create model instance and load the checkpoint
73
+
74
+ model = OneProtLitModule(
75
+ components=components,
76
+ optimizer=None,
77
+ loss_fn=cfg.model.loss_fn,
78
+ local_loss=cfg.model.local_loss,
79
+ gather_with_grad=cfg.model.gather_with_grad,
80
+ use_l1_regularization=cfg.model.use_l1_regularization,
81
+ train_on_all_modalities_after_step=cfg.model.train_on_all_modalities_after_step,
82
+ use_seqsim=cfg.model.use_seqsim
83
+ )
84
+
85
+ state_dict = torch.load(checkpoint_path)
86
+ model_state_dict = model.state_dict()
87
+ model.load_state_dict(state_dict, strict=True)
88
+
89
+ # Define the tokenisers
90
+
91
+ tokenizers = {
92
+ 'sequence': "facebook/esm2_t33_650M_UR50D",
93
+ 'struct_token': "facebook/esm2_t33_650M_UR50D",
94
+ 'text': "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
95
+ }
96
+
97
+ loaded_tokenizers = {}
98
+ for modality, tokenizer_name in tokenizers.items():
99
+ tokenizer = AutoTokenizer.from_pretrained(tokenizers[modality])
100
+ loaded_tokenizers[modality] = tokenizer
101
+
102
+ # Get example embeddings for each modality
103
+
104
+
105
+ ##########################sequence##############################
106
+
107
+ modality = "sequence"
108
+
109
+ file_path = hf_hub_download(
110
+ repo_id="HelmholtzAI-FZJ/oneprot",
111
+ filename="data_examples/sequence_example.txt",
112
+ repo_type="model" # or "dataset"
113
+ )
114
+
115
+ with open(file_path, 'r') as file:
116
+ input_sequence = file.read().strip()
117
+
118
+ input_tensor = loaded_tokenizers[modality](input_sequence, return_tensors="pt")["input_ids"]
119
+ output = model.network[modality](input_tensor)
120
+ print(f"Output for modality '{modality}': {output}")
121
+
122
+
123
+ ###########################text#################################
124
+
125
+ modality = "text"
126
+
127
+ file_path = hf_hub_download(
128
+ repo_id="HelmholtzAI-FZJ/oneprot",
129
+ filename="data_examples/text_example.txt",
130
+ repo_type="model" # or "dataset"
131
+ )
132
+
133
+ with open(file_path, 'r') as file:
134
+ input_text = file.read().strip()
135
+
136
+ input_tensor = loaded_tokenizers[modality](input_text, return_tensors="pt")["input_ids"]
137
+ output = model.network[modality](input_tensor)
138
+ print(f"Output for modality '{modality}': {output}")
139
+
140
+
141
+
142
+ #####################graph structure############################
143
+
144
+ modality = "struct_graph"
145
+ file_path = hf_hub_download(
146
+ repo_id="HelmholtzAI-FZJ/oneprot",
147
+ filename="data_examples/seqstruc_example.h5",
148
+ repo_type="model" # or "dataset"
149
+ )
150
+
151
+ with h5py.File(file_path, 'r') as file:
152
+ input_struct_graph=[protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=False)]
153
+ input_struct_graph = Batch.from_data_list(input_struct_graph)
154
+ output=model.network[modality](input_struct_graph)
155
+ print(f"Output for modality '{modality}': {output}")
156
+
157
+
158
+ ##########################pocket################################
159
+
160
+
161
+ modality = "pocket" # Replace with the desired modality
162
+
163
+ file_path = hf_hub_download(
164
+ repo_id="HelmholtzAI-FZJ/oneprot",
165
+ filename="data_examples/pocket_example.h5",
166
+ repo_type="model" # or "dataset"
167
+ )
168
+
169
+ with h5py.File(file_path, 'r') as file:
170
+ input_pocket=[protein_to_graph('E6Y2X0', file_path, 'non_pdb', 'A', pockets=True)]
171
+ input_pocket = Batch.from_data_list(input_pocket)
172
+ output=model.network[modality](input_pocket)
173
+
174
+ print(f"Output for modality '{modality}': {output}")
175
+
176
+ ```
177
+
178
+ ## Citation
179
+
180
+ ```
181
+ @misc{flöge2024oneprotmultimodalproteinfoundation,
182
+ title={OneProt: Towards Multi-Modal Protein Foundation Models},
183
+ author={Klemens Flöge and Srisruthi Udayakumar and Johanna Sommer and Marie Piraud and Stefan Kesselheim and Vincent Fortuin and Stephan Günnemann and Karel J van der Weg and Holger Gohlke and Alina Bazarova and Erinc Merdivan},
184
+ year={2024},
185
+ eprint={2411.04863},
186
+ archivePrefix={arXiv},
187
+ primaryClass={cs.LG},
188
+ url={https://arxiv.org/abs/2411.04863},
189
+ }
190
+
191
+ ```
192
+