root
commited on
Commit
·
c43fbc6
1
Parent(s):
3efa812
dependencies and embedding_exploration benchmark
Browse files- fuson_plm/README.md +588 -0
- fuson_plm/benchmarking/README.md +11 -0
- fuson_plm/benchmarking/__init__.py +0 -0
- fuson_plm/benchmarking/embed.py +296 -0
- fuson_plm/benchmarking/embedding_exploration/README.md +58 -0
- fuson_plm/benchmarking/embedding_exploration/__init__.py +0 -0
- fuson_plm/benchmarking/embedding_exploration/config.py +10 -0
- fuson_plm/benchmarking/embedding_exploration/data/salokas_2020_tableS3.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/data/tf_and_kinase_fusions.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/data/top_genes.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/plot.py +496 -0
- fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_source_data.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_visualization.png +0 -0
- fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_source_data.csv +3 -0
- fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_visualization.png +0 -0
- fuson_plm/benchmarking/mutation_prediction/README.md +1 -1
- fuson_plm/benchmarking/puncta/train.py +1 -1
- fuson_plm/benchmarking/xgboost_predictor.py +65 -0
fuson_plm/README.md
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dependencies
|
| 2 |
+
|
| 3 |
+
Here we provide the package versions needed to run FusOn-pLM code. For the project, Docker containers were used. We provide a pip list of what is inside the Docker container, as well as the images used for our containers.
|
| 4 |
+
|
| 5 |
+
## pip installs
|
| 6 |
+
|
| 7 |
+
The following dependencies were used for all training and benchmarking except for the `puncta` benchmarks.
|
| 8 |
+
Note that after cloning the repository, you will need to run `pip install -e .` outside the `fuson_plm` directory to install the `fuson_plm` package.
|
| 9 |
+
|
| 10 |
+
Package Version Editable project location
|
| 11 |
+
------------------------- -------------------- -------------------------
|
| 12 |
+
absl-py 1.4.0
|
| 13 |
+
aiohttp 3.8.4
|
| 14 |
+
aiosignal 1.3.1
|
| 15 |
+
apex 0.1
|
| 16 |
+
argon2-cffi 21.3.0
|
| 17 |
+
argon2-cffi-bindings 21.2.0
|
| 18 |
+
asttokens 2.2.1
|
| 19 |
+
astunparse 1.6.3
|
| 20 |
+
async-timeout 4.0.2
|
| 21 |
+
attrs 23.1.0
|
| 22 |
+
audioread 3.0.0
|
| 23 |
+
backcall 0.2.0
|
| 24 |
+
beautifulsoup4 4.12.2
|
| 25 |
+
bio 1.7.1
|
| 26 |
+
biopython 1.84
|
| 27 |
+
biothings-client 0.3.1
|
| 28 |
+
bleach 6.0.0
|
| 29 |
+
blis 0.7.10
|
| 30 |
+
cachetools 5.3.1
|
| 31 |
+
catalogue 2.0.9
|
| 32 |
+
certifi 2023.7.22
|
| 33 |
+
cffi 1.15.1
|
| 34 |
+
charset-normalizer 3.2.0
|
| 35 |
+
click 8.1.5
|
| 36 |
+
cloudpickle 2.2.1
|
| 37 |
+
cmake 3.27.1
|
| 38 |
+
comm 0.1.4
|
| 39 |
+
confection 0.1.1
|
| 40 |
+
contourpy 1.1.0
|
| 41 |
+
cubinlinker 0.3.0+2.g7c3675e
|
| 42 |
+
cuda-python 12.1.0rc5+1.g994d8d0
|
| 43 |
+
cudf 23.6.0
|
| 44 |
+
cugraph 23.6.0
|
| 45 |
+
cugraph-dgl 23.6.0
|
| 46 |
+
cugraph-service-client 23.6.0
|
| 47 |
+
cugraph-service-server 23.6.0
|
| 48 |
+
cuml 23.6.0
|
| 49 |
+
cupy-cuda12x 12.1.0
|
| 50 |
+
cycler 0.11.0
|
| 51 |
+
cymem 2.0.7
|
| 52 |
+
Cython 3.0.0
|
| 53 |
+
dask 2023.3.2
|
| 54 |
+
dask-cuda 23.6.0
|
| 55 |
+
dask-cudf 23.6.0
|
| 56 |
+
debugpy 1.6.7
|
| 57 |
+
decorator 5.1.1
|
| 58 |
+
defusedxml 0.7.1
|
| 59 |
+
distributed 2023.3.2.1
|
| 60 |
+
dm-tree 0.1.8
|
| 61 |
+
docker-pycreds 0.4.0
|
| 62 |
+
einops 0.6.1
|
| 63 |
+
exceptiongroup 1.1.2
|
| 64 |
+
execnet 2.0.2
|
| 65 |
+
executing 1.2.0
|
| 66 |
+
expecttest 0.1.3
|
| 67 |
+
fair-esm 2.0.0
|
| 68 |
+
fastjsonschema 2.18.0
|
| 69 |
+
fastrlock 0.8.1
|
| 70 |
+
filelock 3.12.2
|
| 71 |
+
flash-attn 2.0.4
|
| 72 |
+
fonttools 4.42.0
|
| 73 |
+
frozenlist 1.4.0
|
| 74 |
+
fsspec 2023.6.0
|
| 75 |
+
fuson-plm 1.0 /workspace/FusOn-pLM
|
| 76 |
+
gast 0.5.4
|
| 77 |
+
gdown 5.2.0
|
| 78 |
+
gitdb 4.0.11
|
| 79 |
+
GitPython 3.1.43
|
| 80 |
+
google-auth 2.22.0
|
| 81 |
+
google-auth-oauthlib 0.4.6
|
| 82 |
+
gprofiler-official 1.0.0
|
| 83 |
+
graphsurgeon 0.4.6
|
| 84 |
+
grpcio 1.56.2
|
| 85 |
+
huggingface-hub 0.25.2
|
| 86 |
+
hypothesis 5.35.1
|
| 87 |
+
idna 3.4
|
| 88 |
+
importlib-metadata 6.8.0
|
| 89 |
+
iniconfig 2.0.0
|
| 90 |
+
intel-openmp 2021.4.0
|
| 91 |
+
ipykernel 6.25.0
|
| 92 |
+
ipython 8.14.0
|
| 93 |
+
ipython-genutils 0.2.0
|
| 94 |
+
jedi 0.19.0
|
| 95 |
+
Jinja2 3.1.2
|
| 96 |
+
joblib 1.3.1
|
| 97 |
+
json5 0.9.14
|
| 98 |
+
jsonschema 4.18.6
|
| 99 |
+
jsonschema-specifications 2023.7.1
|
| 100 |
+
jupyter_client 8.3.0
|
| 101 |
+
jupyter_core 5.3.1
|
| 102 |
+
jupyter-tensorboard 0.2.0
|
| 103 |
+
jupyterlab 2.3.2
|
| 104 |
+
jupyterlab-pygments 0.2.2
|
| 105 |
+
jupyterlab-server 1.2.0
|
| 106 |
+
jupytext 1.15.0
|
| 107 |
+
kiwisolver 1.4.4
|
| 108 |
+
langcodes 3.3.0
|
| 109 |
+
librosa 0.9.2
|
| 110 |
+
lightning-utilities 0.11.8
|
| 111 |
+
llvmlite 0.40.1
|
| 112 |
+
locket 1.0.0
|
| 113 |
+
Markdown 3.4.4
|
| 114 |
+
markdown-it-py 3.0.0
|
| 115 |
+
MarkupSafe 2.1.3
|
| 116 |
+
matplotlib 3.7.2
|
| 117 |
+
matplotlib-inline 0.1.6
|
| 118 |
+
mdit-py-plugins 0.4.0
|
| 119 |
+
mdurl 0.1.2
|
| 120 |
+
mistune 3.0.1
|
| 121 |
+
mkl 2021.1.1
|
| 122 |
+
mkl-devel 2021.1.1
|
| 123 |
+
mkl-include 2021.1.1
|
| 124 |
+
mock 5.1.0
|
| 125 |
+
mpmath 1.3.0
|
| 126 |
+
msgpack 1.0.5
|
| 127 |
+
multidict 6.0.4
|
| 128 |
+
murmurhash 1.0.9
|
| 129 |
+
mygene 3.2.2
|
| 130 |
+
nbclient 0.8.0
|
| 131 |
+
nbconvert 7.7.3
|
| 132 |
+
nbformat 5.9.2
|
| 133 |
+
nest-asyncio 1.5.7
|
| 134 |
+
networkx 2.6.3
|
| 135 |
+
ninja 1.11.1
|
| 136 |
+
notebook 6.4.10
|
| 137 |
+
numba 0.57.1+1.gc785c8f1f
|
| 138 |
+
numpy 1.22.2
|
| 139 |
+
nvidia-cublas-cu12 12.4.5.8
|
| 140 |
+
nvidia-cuda-cupti-cu12 12.4.127
|
| 141 |
+
nvidia-cuda-nvrtc-cu12 12.4.127
|
| 142 |
+
nvidia-cuda-runtime-cu12 12.4.127
|
| 143 |
+
nvidia-cudnn-cu12 9.1.0.70
|
| 144 |
+
nvidia-cufft-cu12 11.2.1.3
|
| 145 |
+
nvidia-curand-cu12 10.3.5.147
|
| 146 |
+
nvidia-cusolver-cu12 11.6.1.9
|
| 147 |
+
nvidia-cusparse-cu12 12.3.1.170
|
| 148 |
+
nvidia-dali-cuda120 1.28.0
|
| 149 |
+
nvidia-nccl-cu12 2.21.5
|
| 150 |
+
nvidia-nvjitlink-cu12 12.4.127
|
| 151 |
+
nvidia-nvtx-cu12 12.4.127
|
| 152 |
+
nvidia-pyindex 1.0.9
|
| 153 |
+
nvtx 0.2.5
|
| 154 |
+
oauthlib 3.2.2
|
| 155 |
+
onnx 1.14.0
|
| 156 |
+
opencv 4.7.0
|
| 157 |
+
packaging 23.1
|
| 158 |
+
pandas 1.5.2
|
| 159 |
+
pandocfilters 1.5.0
|
| 160 |
+
parso 0.8.3
|
| 161 |
+
partd 1.4.0
|
| 162 |
+
pathy 0.10.2
|
| 163 |
+
pexpect 4.8.0
|
| 164 |
+
pickleshare 0.7.5
|
| 165 |
+
Pillow 9.2.0
|
| 166 |
+
pip 23.2.1
|
| 167 |
+
platformdirs 3.10.0
|
| 168 |
+
pluggy 1.2.0
|
| 169 |
+
ply 3.11
|
| 170 |
+
polygraphy 0.47.1
|
| 171 |
+
pooch 1.7.0
|
| 172 |
+
preshed 3.0.8
|
| 173 |
+
prettytable 3.8.0
|
| 174 |
+
prometheus-client 0.17.1
|
| 175 |
+
prompt-toolkit 3.0.39
|
| 176 |
+
protobuf 4.21.12
|
| 177 |
+
psutil 5.9.4
|
| 178 |
+
ptxcompiler 0.8.1+1.g4a94326
|
| 179 |
+
ptyprocess 0.7.0
|
| 180 |
+
pure-eval 0.2.2
|
| 181 |
+
py3Dmol 2.4.0
|
| 182 |
+
pyarrow 11.0.0
|
| 183 |
+
pyasn1 0.5.0
|
| 184 |
+
pyasn1-modules 0.3.0
|
| 185 |
+
pybind11 2.11.1
|
| 186 |
+
pycocotools 2.0+nv0.7.3
|
| 187 |
+
pycparser 2.21
|
| 188 |
+
pydantic 1.10.12
|
| 189 |
+
Pygments 2.16.1
|
| 190 |
+
pylibcugraph 23.6.0
|
| 191 |
+
pylibcugraphops 23.6.0
|
| 192 |
+
pylibraft 23.6.0
|
| 193 |
+
pynndescent 0.5.13
|
| 194 |
+
pynvml 11.4.1
|
| 195 |
+
pyparsing 3.0.9
|
| 196 |
+
PySocks 1.7.1
|
| 197 |
+
pytest 7.4.0
|
| 198 |
+
pytest-flakefinder 1.1.0
|
| 199 |
+
pytest-rerunfailures 12.0
|
| 200 |
+
pytest-shard 0.1.2
|
| 201 |
+
pytest-xdist 3.3.1
|
| 202 |
+
python-dateutil 2.8.2
|
| 203 |
+
python-hostlist 1.23.0
|
| 204 |
+
pytorch-lightning 2.4.0
|
| 205 |
+
pytorch-quantization 2.1.2
|
| 206 |
+
pytz 2023.3
|
| 207 |
+
PyYAML 6.0.1
|
| 208 |
+
pyzmq 25.1.0
|
| 209 |
+
raft-dask 23.6.0
|
| 210 |
+
referencing 0.30.2
|
| 211 |
+
regex 2023.6.3
|
| 212 |
+
requests 2.31.0
|
| 213 |
+
requests-oauthlib 1.3.1
|
| 214 |
+
resampy 0.4.2
|
| 215 |
+
rmm 23.6.0
|
| 216 |
+
rpds-py 0.9.2
|
| 217 |
+
rsa 4.9
|
| 218 |
+
safetensors 0.4.5
|
| 219 |
+
scikit-learn 1.2.0
|
| 220 |
+
scipy 1.11.1
|
| 221 |
+
seaborn 0.13.2
|
| 222 |
+
Send2Trash 1.8.2
|
| 223 |
+
sentencepiece 0.2.0
|
| 224 |
+
sentry-sdk 2.16.0
|
| 225 |
+
setproctitle 1.3.3
|
| 226 |
+
setuptools 68.0.0
|
| 227 |
+
six 1.16.0
|
| 228 |
+
smart-open 6.3.0
|
| 229 |
+
smmap 5.0.1
|
| 230 |
+
sortedcontainers 2.4.0
|
| 231 |
+
soundfile 0.12.1
|
| 232 |
+
soupsieve 2.4.1
|
| 233 |
+
spacy 3.6.0
|
| 234 |
+
spacy-legacy 3.0.12
|
| 235 |
+
spacy-loggers 1.0.4
|
| 236 |
+
sphinx-glpi-theme 0.3
|
| 237 |
+
srsly 2.4.7
|
| 238 |
+
stack-data 0.6.2
|
| 239 |
+
sympy 1.13.1
|
| 240 |
+
tabulate 0.9.0
|
| 241 |
+
tbb 2021.10.0
|
| 242 |
+
tblib 2.0.0
|
| 243 |
+
tensorboard 2.9.0
|
| 244 |
+
tensorboard-data-server 0.6.1
|
| 245 |
+
tensorboard-plugin-wit 1.8.1
|
| 246 |
+
tensorrt 8.6.1
|
| 247 |
+
terminado 0.17.1
|
| 248 |
+
thinc 8.1.10
|
| 249 |
+
threadpoolctl 3.2.0
|
| 250 |
+
thriftpy2 0.4.16
|
| 251 |
+
tinycss2 1.2.1
|
| 252 |
+
tokenizers 0.20.1
|
| 253 |
+
toml 0.10.2
|
| 254 |
+
tomli 2.0.1
|
| 255 |
+
toolz 0.12.0
|
| 256 |
+
torch 2.5.0
|
| 257 |
+
torch-tensorrt 2.0.0.dev0
|
| 258 |
+
torchdata 0.7.0a0
|
| 259 |
+
torchmetrics 1.5.0
|
| 260 |
+
torchtext 0.16.0a0
|
| 261 |
+
torchvision 0.16.0a0
|
| 262 |
+
tornado 6.3.2
|
| 263 |
+
tqdm 4.65.0
|
| 264 |
+
traitlets 5.9.0
|
| 265 |
+
transformer-engine 0.11.0+3f01b4f
|
| 266 |
+
transformers 4.45.2
|
| 267 |
+
treelite 3.2.0
|
| 268 |
+
treelite-runtime 3.2.0
|
| 269 |
+
triton 3.1.0
|
| 270 |
+
typer 0.9.0
|
| 271 |
+
types-dataclasses 0.6.6
|
| 272 |
+
typing_extensions 4.12.2
|
| 273 |
+
ucx-py 0.32.0
|
| 274 |
+
uff 0.6.9
|
| 275 |
+
umap-learn 0.5.6
|
| 276 |
+
urllib3 1.26.16
|
| 277 |
+
wandb 0.18.3
|
| 278 |
+
wasabi 1.1.2
|
| 279 |
+
wcwidth 0.2.6
|
| 280 |
+
webencodings 0.5.1
|
| 281 |
+
Werkzeug 2.3.6
|
| 282 |
+
wheel 0.41.1
|
| 283 |
+
xdoctest 1.0.2
|
| 284 |
+
xgboost 1.7.5
|
| 285 |
+
yarl 1.9.2
|
| 286 |
+
zict 3.0.0
|
| 287 |
+
zipp 3.16.2
|
| 288 |
+
|
| 289 |
+
The following packages and versions were used for the `puncta` benchmarks. A different environment was required to run ProtT5.
|
| 290 |
+
|
| 291 |
+
Package Version Editable project location
|
| 292 |
+
------------------------- -------------------------- -------------------------
|
| 293 |
+
absl-py 2.1.0
|
| 294 |
+
aiohttp 3.9.3
|
| 295 |
+
aiosignal 1.3.1
|
| 296 |
+
annotated-types 0.6.0
|
| 297 |
+
anyio 4.8.0
|
| 298 |
+
apex 0.1
|
| 299 |
+
argon2-cffi 23.1.0
|
| 300 |
+
argon2-cffi-bindings 21.2.0
|
| 301 |
+
asttokens 2.4.1
|
| 302 |
+
astunparse 1.6.3
|
| 303 |
+
async-timeout 4.0.3
|
| 304 |
+
attrs 23.2.0
|
| 305 |
+
audioread 3.0.1
|
| 306 |
+
beautifulsoup4 4.12.3
|
| 307 |
+
bio 1.7.1
|
| 308 |
+
biopython 1.85
|
| 309 |
+
biothings_client 0.4.1
|
| 310 |
+
bleach 6.1.0
|
| 311 |
+
blis 0.7.11
|
| 312 |
+
cachetools 5.3.3
|
| 313 |
+
catalogue 2.0.10
|
| 314 |
+
certifi 2024.2.2
|
| 315 |
+
cffi 1.16.0
|
| 316 |
+
charset-normalizer 3.3.2
|
| 317 |
+
click 8.1.7
|
| 318 |
+
cloudpathlib 0.16.0
|
| 319 |
+
cloudpickle 3.0.0
|
| 320 |
+
cmake 3.29.0.1
|
| 321 |
+
comm 0.2.2
|
| 322 |
+
confection 0.1.4
|
| 323 |
+
contourpy 1.2.1
|
| 324 |
+
cuda-python 12.4.0rc7+3.ge75c8a9.dirty
|
| 325 |
+
cudf 24.2.0
|
| 326 |
+
cudnn 1.1.2
|
| 327 |
+
cugraph 24.2.0
|
| 328 |
+
cugraph-dgl 24.2.0
|
| 329 |
+
cugraph-service-client 24.2.0
|
| 330 |
+
cugraph-service-server 24.2.0
|
| 331 |
+
cuml 24.2.0
|
| 332 |
+
cupy-cuda12x 13.0.0
|
| 333 |
+
cycler 0.12.1
|
| 334 |
+
cymem 2.0.8
|
| 335 |
+
Cython 3.0.10
|
| 336 |
+
dask 2024.1.1
|
| 337 |
+
dask-cuda 24.2.0
|
| 338 |
+
dask-cudf 24.2.0
|
| 339 |
+
debugpy 1.8.1
|
| 340 |
+
decorator 5.1.1
|
| 341 |
+
defusedxml 0.7.1
|
| 342 |
+
distributed 2024.1.1
|
| 343 |
+
dm-tree 0.1.8
|
| 344 |
+
docker-pycreds 0.4.0
|
| 345 |
+
einops 0.7.0
|
| 346 |
+
exceptiongroup 1.2.0
|
| 347 |
+
execnet 2.0.2
|
| 348 |
+
executing 2.0.1
|
| 349 |
+
expecttest 0.1.3
|
| 350 |
+
fair-esm 2.0.0
|
| 351 |
+
fastjsonschema 2.19.1
|
| 352 |
+
fastrlock 0.8.2
|
| 353 |
+
filelock 3.13.3
|
| 354 |
+
flash-attn 2.4.2
|
| 355 |
+
fonttools 4.51.0
|
| 356 |
+
frozenlist 1.4.1
|
| 357 |
+
fsspec 2024.2.0
|
| 358 |
+
fuson-plm 1.0 /workspace/FusOn-pLM
|
| 359 |
+
gast 0.5.4
|
| 360 |
+
gdown 5.2.0
|
| 361 |
+
gitdb 4.0.12
|
| 362 |
+
GitPython 3.1.44
|
| 363 |
+
google-auth 2.29.0
|
| 364 |
+
google-auth-oauthlib 0.4.6
|
| 365 |
+
gprofiler-official 1.0.0
|
| 366 |
+
graphsurgeon 0.4.6
|
| 367 |
+
grpcio 1.62.1
|
| 368 |
+
h11 0.14.0
|
| 369 |
+
httpcore 1.0.7
|
| 370 |
+
httpx 0.28.1
|
| 371 |
+
huggingface-hub 0.27.1
|
| 372 |
+
hypothesis 5.35.1
|
| 373 |
+
idna 3.6
|
| 374 |
+
igraph 0.11.4
|
| 375 |
+
importlib_metadata 7.0.2
|
| 376 |
+
iniconfig 2.0.0
|
| 377 |
+
intel-openmp 2021.4.0
|
| 378 |
+
ipykernel 6.29.4
|
| 379 |
+
ipython 8.21.0
|
| 380 |
+
ipython-genutils 0.2.0
|
| 381 |
+
jedi 0.19.1
|
| 382 |
+
Jinja2 3.1.3
|
| 383 |
+
joblib 1.3.2
|
| 384 |
+
json5 0.9.24
|
| 385 |
+
jsonschema 4.21.1
|
| 386 |
+
jsonschema-specifications 2023.12.1
|
| 387 |
+
jupyter_client 8.6.1
|
| 388 |
+
jupyter_core 5.7.2
|
| 389 |
+
jupyter-tensorboard 0.2.0
|
| 390 |
+
jupyterlab 2.3.2
|
| 391 |
+
jupyterlab_pygments 0.3.0
|
| 392 |
+
jupyterlab-server 1.2.0
|
| 393 |
+
jupytext 1.16.1
|
| 394 |
+
kiwisolver 1.4.5
|
| 395 |
+
langcodes 3.3.0
|
| 396 |
+
lark 1.1.9
|
| 397 |
+
lazy_loader 0.4
|
| 398 |
+
librosa 0.10.1
|
| 399 |
+
lightning-thunder 0.1.0
|
| 400 |
+
lightning-utilities 0.11.2
|
| 401 |
+
llvmlite 0.42.0
|
| 402 |
+
locket 1.0.0
|
| 403 |
+
looseversion 1.3.0
|
| 404 |
+
Markdown 3.6
|
| 405 |
+
markdown-it-py 3.0.0
|
| 406 |
+
MarkupSafe 2.1.5
|
| 407 |
+
matplotlib 3.8.4
|
| 408 |
+
matplotlib-inline 0.1.6
|
| 409 |
+
mdit-py-plugins 0.4.0
|
| 410 |
+
mdurl 0.1.2
|
| 411 |
+
mistune 3.0.2
|
| 412 |
+
mkl 2021.1.1
|
| 413 |
+
mkl-devel 2021.1.1
|
| 414 |
+
mkl-include 2021.1.1
|
| 415 |
+
mock 5.1.0
|
| 416 |
+
mpmath 1.3.0
|
| 417 |
+
msgpack 1.0.8
|
| 418 |
+
multidict 6.0.5
|
| 419 |
+
murmurhash 1.0.10
|
| 420 |
+
mygene 3.2.2
|
| 421 |
+
nbclient 0.10.0
|
| 422 |
+
nbconvert 7.16.3
|
| 423 |
+
nbformat 5.10.4
|
| 424 |
+
nest-asyncio 1.6.0
|
| 425 |
+
networkx 2.6.3
|
| 426 |
+
ninja 1.11.1.1
|
| 427 |
+
notebook 6.4.10
|
| 428 |
+
numba 0.59.0+1.g20ae2b56c
|
| 429 |
+
numpy 1.24.4
|
| 430 |
+
nvfuser 0.1.6a0+a684e2a
|
| 431 |
+
nvidia-dali-cuda120 1.36.0
|
| 432 |
+
nvidia-nvimgcodec-cu12 0.2.0.7
|
| 433 |
+
nvidia-pyindex 1.0.9
|
| 434 |
+
nvtx 0.2.5
|
| 435 |
+
oauthlib 3.2.2
|
| 436 |
+
onnx 1.16.0
|
| 437 |
+
opencv 4.7.0
|
| 438 |
+
opt-einsum 3.3.0
|
| 439 |
+
optree 0.11.0
|
| 440 |
+
packaging 23.2
|
| 441 |
+
pandas 1.5.3
|
| 442 |
+
pandocfilters 1.5.1
|
| 443 |
+
parso 0.8.4
|
| 444 |
+
partd 1.4.1
|
| 445 |
+
pexpect 4.9.0
|
| 446 |
+
pillow 10.2.0
|
| 447 |
+
pip 24.0
|
| 448 |
+
platformdirs 4.2.0
|
| 449 |
+
pluggy 1.4.0
|
| 450 |
+
ply 3.11
|
| 451 |
+
polygraphy 0.49.8
|
| 452 |
+
pooch 1.8.1
|
| 453 |
+
preshed 3.0.9
|
| 454 |
+
prettytable 3.10.0
|
| 455 |
+
prometheus_client 0.20.0
|
| 456 |
+
prompt-toolkit 3.0.43
|
| 457 |
+
protobuf 4.24.4
|
| 458 |
+
psutil 5.9.4
|
| 459 |
+
ptyprocess 0.7.0
|
| 460 |
+
pure-eval 0.2.2
|
| 461 |
+
py3Dmol 2.4.2
|
| 462 |
+
pyarrow 14.0.1
|
| 463 |
+
pyasn1 0.6.0
|
| 464 |
+
pyasn1_modules 0.4.0
|
| 465 |
+
pybind11 2.12.0
|
| 466 |
+
pybind11_global 2.12.0
|
| 467 |
+
pycocotools 2.0+nv0.8.0
|
| 468 |
+
pycparser 2.22
|
| 469 |
+
pydantic 2.6.4
|
| 470 |
+
pydantic_core 2.16.3
|
| 471 |
+
Pygments 2.17.2
|
| 472 |
+
pylibcugraph 24.2.0
|
| 473 |
+
pylibcugraphops 24.2.0
|
| 474 |
+
pylibraft 24.2.0
|
| 475 |
+
pynndescent 0.5.13
|
| 476 |
+
pynvjitlink 0.1.13
|
| 477 |
+
pynvml 11.4.1
|
| 478 |
+
pyparsing 3.1.2
|
| 479 |
+
PySocks 1.7.1
|
| 480 |
+
pytest 8.1.1
|
| 481 |
+
pytest-flakefinder 1.1.0
|
| 482 |
+
pytest-rerunfailures 14.0
|
| 483 |
+
pytest-shard 0.1.2
|
| 484 |
+
pytest-xdist 3.5.0
|
| 485 |
+
python-dateutil 2.9.0.post0
|
| 486 |
+
python-hostlist 1.23.0
|
| 487 |
+
pytorch-lightning 2.5.0.post0
|
| 488 |
+
pytorch-quantization 2.1.2
|
| 489 |
+
pytorch-triton 3.0.0+a9bc1a364
|
| 490 |
+
pytz 2024.1
|
| 491 |
+
PyYAML 6.0.1
|
| 492 |
+
pyzmq 25.1.2
|
| 493 |
+
raft-dask 24.2.0
|
| 494 |
+
rapids-dask-dependency 24.2.0a0
|
| 495 |
+
referencing 0.34.0
|
| 496 |
+
regex 2023.12.25
|
| 497 |
+
requests 2.31.0
|
| 498 |
+
requests-oauthlib 2.0.0
|
| 499 |
+
rich 13.7.1
|
| 500 |
+
rmm 24.2.0
|
| 501 |
+
rpds-py 0.18.0
|
| 502 |
+
rsa 4.9
|
| 503 |
+
safetensors 0.5.2
|
| 504 |
+
scikit-learn 1.2.0
|
| 505 |
+
scipy 1.12.0
|
| 506 |
+
seaborn 0.13.2
|
| 507 |
+
Send2Trash 1.8.2
|
| 508 |
+
sentencepiece 0.2.0
|
| 509 |
+
sentry-sdk 2.20.0
|
| 510 |
+
setproctitle 1.3.4
|
| 511 |
+
setuptools 68.2.2
|
| 512 |
+
six 1.16.0
|
| 513 |
+
smart-open 6.4.0
|
| 514 |
+
smmap 5.0.2
|
| 515 |
+
sniffio 1.3.1
|
| 516 |
+
sortedcontainers 2.4.0
|
| 517 |
+
soundfile 0.12.1
|
| 518 |
+
soupsieve 2.5
|
| 519 |
+
soxr 0.3.7
|
| 520 |
+
spacy 3.7.4
|
| 521 |
+
spacy-legacy 3.0.12
|
| 522 |
+
spacy-loggers 1.0.5
|
| 523 |
+
sphinx_glpi_theme 0.6
|
| 524 |
+
srsly 2.4.8
|
| 525 |
+
stack-data 0.6.3
|
| 526 |
+
sympy 1.12
|
| 527 |
+
tabulate 0.9.0
|
| 528 |
+
tbb 2021.12.0
|
| 529 |
+
tblib 3.0.0
|
| 530 |
+
tensorboard 2.9.0
|
| 531 |
+
tensorboard-data-server 0.6.1
|
| 532 |
+
tensorboard-plugin-wit 1.8.1
|
| 533 |
+
tensorrt 8.6.3
|
| 534 |
+
terminado 0.18.1
|
| 535 |
+
texttable 1.7.0
|
| 536 |
+
thinc 8.2.3
|
| 537 |
+
threadpoolctl 3.3.0
|
| 538 |
+
thriftpy2 0.4.17
|
| 539 |
+
tinycss2 1.2.1
|
| 540 |
+
tokenizers 0.21.0
|
| 541 |
+
toml 0.10.2
|
| 542 |
+
tomli 2.0.1
|
| 543 |
+
toolz 0.12.1
|
| 544 |
+
torch 2.3.0a0+6ddf5cf85e.nv24.4
|
| 545 |
+
torch-tensorrt 2.3.0a0
|
| 546 |
+
torchdata 0.7.1a0
|
| 547 |
+
torchmetrics 1.6.1
|
| 548 |
+
torchtext 0.17.0a0
|
| 549 |
+
torchvision 0.18.0a0
|
| 550 |
+
tornado 6.4
|
| 551 |
+
tqdm 4.66.2
|
| 552 |
+
traitlets 5.9.0
|
| 553 |
+
transformer-engine 1.5.0+6a9edc3
|
| 554 |
+
transformers 4.48.0
|
| 555 |
+
treelite 4.0.0
|
| 556 |
+
typer 0.9.4
|
| 557 |
+
types-dataclasses 0.6.6
|
| 558 |
+
typing_extensions 4.10.0
|
| 559 |
+
ucx-py 0.36.0
|
| 560 |
+
uff 0.6.9
|
| 561 |
+
umap-learn 0.5.7
|
| 562 |
+
urllib3 1.26.18
|
| 563 |
+
wandb 0.19.4
|
| 564 |
+
wasabi 1.1.2
|
| 565 |
+
wcwidth 0.2.13
|
| 566 |
+
weasel 0.3.4
|
| 567 |
+
webencodings 0.5.1
|
| 568 |
+
Werkzeug 3.0.2
|
| 569 |
+
wheel 0.43.0
|
| 570 |
+
xdoctest 1.0.2
|
| 571 |
+
xgboost 1.7.5
|
| 572 |
+
yarl 1.9.4
|
| 573 |
+
zict 3.0.0
|
| 574 |
+
zipp 3.17.0
|
| 575 |
+
|
| 576 |
+
## Docker
|
| 577 |
+
|
| 578 |
+
The following image was used for Container 1 (all code except puncta benchmark):
|
| 579 |
+
|
| 580 |
+
```
|
| 581 |
+
nvcr.io/nvidia/pytorch:23.08-py3
|
| 582 |
+
```
|
| 583 |
+
|
| 584 |
+
The following image was used for Container 2 (puncta benchmark):
|
| 585 |
+
|
| 586 |
+
```
|
| 587 |
+
nvcr.io/nvidia/pytorch:24.04-py3
|
| 588 |
+
```
|
fuson_plm/benchmarking/README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Benchmarking
|
| 2 |
+
|
| 3 |
+
This outer directory for the FusOn-pLM benchmarks contains utility functions stored in `.py` files, which are shared across the individual benchmarks.
|
| 4 |
+
|
| 5 |
+
### embed.py
|
| 6 |
+
|
| 7 |
+
This file contains functions used to make and organize FusOn-pLM and ESM embeddings of benchmarking datasets. Its functions are used in all benchmarks.
|
| 8 |
+
|
| 9 |
+
### xgboost_predictor.py
|
| 10 |
+
|
| 11 |
+
This file contains functions used to train XGBoost predictors, which are utilized in the `puncta` benchmark.
|
fuson_plm/benchmarking/__init__.py
ADDED
|
File without changes
|
fuson_plm/benchmarking/embed.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python file for making embeddings from a FusOn-pLM model for any dataset
|
| 2 |
+
from fuson_plm.utils.embedding import get_esm_embeddings, load_esm2_type, redump_pickle_dictionary, load_prott5, get_prott5_embeddings
|
| 3 |
+
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
|
| 4 |
+
from fuson_plm.utils.data_cleaning import find_invalid_chars
|
| 5 |
+
from fuson_plm.utils.constants import VALID_AAS
|
| 6 |
+
from fuson_plm.training.model import FusOnpLM
|
| 7 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
|
| 8 |
+
import logging
|
| 9 |
+
import torch
|
| 10 |
+
import pickle
|
| 11 |
+
import os
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import numpy as np
|
| 14 |
+
|
| 15 |
+
def validate_sequence_col(df, seq_col):
    """Validate the amino-acid sequence column of a benchmarking dataframe.

    Checks that ``seq_col`` exists in ``df`` and that every sequence contains
    only valid amino-acid characters; logs a warning if duplicate sequences
    are present.

    Args:
        df (pd.DataFrame): input dataframe of sequences.
        seq_col (str): name of the column holding amino-acid sequences.

    Raises:
        Exception: if the column is missing, or if any sequence contains
            characters outside VALID_AAS.
    """
    # if column isn't there, error
    if seq_col not in df.columns:
        raise Exception("Error: provided sequence column does not exist in the input dataframe")

    # if column contains invalid characters, error.
    # NOTE: work on a standalone Series so the caller's dataframe is never
    # mutated (the previous version temporarily wrote an 'invalid_chars'
    # column into df; `df = df.drop(...)` only rebound the local name, so
    # the extra column leaked to the caller).
    invalid_per_seq = df[seq_col].apply(lambda x: find_invalid_chars(x, VALID_AAS))
    all_invalid_chars = set().union(*invalid_per_seq)
    if len(all_invalid_chars) > 0:
        raise Exception(f"Error: invalid characters {all_invalid_chars} found in the sequence column")

    # make sure there are no duplicates
    sequences = df[seq_col]
    if len(set(sequences)) < len(sequences):
        log_update("\tWARNING: input data has duplicate sequences")
def load_fuson_model(ckpt_path):
    """Load a FusOn-pLM checkpoint for embedding extraction.

    Args:
        ckpt_path (str): path to a HuggingFace-format checkpoint directory.

    Returns:
        tuple: (model, tokenizer, device) where the model is on ``device``
            and in eval mode.
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias',
    # 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    # Set device (GPU if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use log_update for consistency with the rest of the file's logging
    log_update(f"Using device: {device}")

    # Load model and tokenizer from the checkpoint
    model = AutoModel.from_pretrained(ckpt_path)        # initialize model
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)  # initialize tokenizer

    # Model to device and in eval mode
    model.to(device)
    model.eval()  # disables dropout for deterministic results

    return model, tokenizer, device
def get_fuson_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=2000):
    """Generate FusOn-pLM embeddings for a list of sequences.

    Args:
        model: FusOn-pLM model (e.g. from ``load_fuson_model``).
        tokenizer: tokenizer matching the model.
        sequences (list[str]): amino-acid sequences to embed.
        device (torch.device): device the model lives on.
        average (bool): if True, mean-pool over residues to produce one
            vector per sequence; otherwise keep per-residue embeddings.
        print_updates (bool): if True, log progress for each sequence.
        savepath (str | None): optional path for a pickle mapping
            sequence -> embedding; '.pkl' is appended if missing.
        save_at_end (bool): if True, dump the full dictionary once at the
            end; otherwise append each embedding as it is computed (so
            partial progress survives a crash) and consolidate at the end.
        max_length (int | None): tokenizer truncation length; if None, it
            is set to the longest input sequence plus 2 (BOS + EOS).
    """
    # Correct save path to pickle if necessary
    if savepath is not None:
        if not savepath.endswith('.pkl'):
            savepath += '.pkl'

    if print_updates: log_update(f"Dataset contains {len(sequences)} sequences.")

    # If no max length was passed, set it to the longest sequence in the
    # dataset (+2 for BOS, EOS). Computed lazily so an explicit max_length
    # works even for an empty sequence list.
    if max_length is None:
        max_length = max(len(s) for s in sequences) + 2

    # Initialize an empty dict to store the FusOn-pLM embeddings
    embedding_dict = {}
    # Iterate through the seqs
    for i, sequence in enumerate(sequences):
        # Get the embeddings
        with torch.no_grad():
            # Tokenize the input sequence
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            # The embeddings are in the last_hidden_state tensor
            embedding = outputs.last_hidden_state
            # remove extra batch dimension
            embedding = embedding.squeeze(0)
            # remove BOS and EOS tokens
            embedding = embedding[1:-1, :]

            # Convert embeddings to numpy array
            embedding = embedding.cpu().numpy()

            # Average over the residue dimension (if requested)
            if average:
                embedding = embedding.mean(0)

            # Add to dictionary
            embedding_dict[sequence] = embedding

            # Save each embedding incrementally (crash-resilient mode);
            # the appended pickle stream is consolidated below.
            if savepath is not None and not save_at_end:
                with open(savepath, 'ab+') as f:
                    pickle.dump({sequence: embedding}, f)

            # Print update (if necessary)
            if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")

    # Dump all at once at the end (if necessary)
    if savepath is not None:
        if save_at_end:
            # Saving for the first (and only) time: write the full dictionary
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        else:
            # We've been appending all along and made it here without
            # crashing: rewrite the file as a single dictionary so it can
            # be loaded with one pickle.load
            redump_pickle_dictionary(savepath)
def embed_dataset(path_to_file, path_to_output, seq_col='aa_seq', model_type='fuson_plm', fuson_ckpt_path=None, average=True, overwrite=True, print_updates=False, max_length=2000):
    """
    Embed every unique sequence in a CSV with the requested model and pickle the results.

    Args:
        path_to_file (str): path to the input CSV containing sequences.
        path_to_output (str): path where the embeddings pickle is written.
        seq_col (str): name of the sequence column in the CSV.
        model_type (str): one of 'fuson_plm', 'esm2_t33_650M_UR50D',
            'prot_t5_xl_half_uniref50_enc'.
        fuson_ckpt_path (str | None): checkpoint path; required when model_type='fuson_plm'.
        average (bool): if True, store per-sequence mean embeddings instead of per-residue 2D embeddings.
        overwrite (bool): if True, warn and overwrite any existing file at path_to_output;
            if False and the file exists, log a warning and return None without embedding.
        print_updates (bool): if True, log one line per embedded sequence.
        max_length (int): tokenizer truncation length.

    Returns:
        None. Embeddings are written incrementally to path_to_output by the
        model-specific get_*_embeddings helpers.
    """
    # Make sure we aren't silently overwriting pre-existing embeddings
    if os.path.exists(path_to_output):
        if overwrite:
            log_update(f"WARNING: these embeddings may already exist at {path_to_output} and will be overwritten")
        else:
            log_update(f"WARNING: these embeddings may already exist at {path_to_output}. Skipping.")
            return None

    dataset = pd.read_csv(path_to_file)
    # Make sure the sequence column is valid
    validate_sequence_col(dataset, seq_col)

    sequences = dataset[seq_col].unique().tolist()  # ensure all entries are unique

    # The model branches are mutually exclusive, so use elif.
    # Exceptions are chained (`from e`) so the underlying traceback is preserved
    # instead of being swallowed by a bare `except`.
    if model_type == 'fuson_plm':
        if not os.path.exists(fuson_ckpt_path):
            raise Exception("FusOn-pLM ckpt path does not exist")

        # Load model
        try:
            model, tokenizer, device = load_fuson_model(fuson_ckpt_path)
        except Exception as e:
            raise Exception(f"Could not load FusOn-pLM from {fuson_ckpt_path}") from e

        # Generate embeddings
        try:
            get_fuson_embeddings(model, tokenizer, sequences, device, average=average,
                                 print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                                 max_length=max_length)
        except Exception as e:
            raise Exception("Could not generate FusOn-pLM embeddings") from e

    elif model_type == 'esm2_t33_650M_UR50D':
        # Load model
        try:
            model, tokenizer, device = load_esm2_type(model_type)
        except Exception as e:
            raise Exception(f"Could not load {model_type}") from e
        # Generate embeddings
        try:
            get_esm_embeddings(model, tokenizer, sequences, device, average=average,
                               print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                               max_length=max_length)
        except Exception as e:
            raise Exception(f"Could not generate {model_type} embeddings") from e

    elif model_type == "prot_t5_xl_half_uniref50_enc":
        # Load model
        try:
            model, tokenizer, device = load_prott5()
        except Exception as e:
            raise Exception(f"Could not load {model_type}") from e
        # Generate embeddings
        try:
            get_prott5_embeddings(model, tokenizer, sequences, device, average=average,
                                  print_updates=print_updates, savepath=path_to_output, save_at_end=False,
                                  max_length=max_length)
        except Exception as e:
            raise Exception(f"Could not generate {model_type} embeddings") from e
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def embed_dataset_for_benchmark(fuson_ckpts=None, input_data_path=None, input_fname=None, average=True, seq_col='seq', benchmark_fusonplm=False, benchmark_esm=False, benchmark_fo_puncta_ml=False, benchmark_prott5=False, overwrite=False, max_length=None):
    """
    Create embedding files for every model being benchmarked.

    Args:
        fuson_ckpts: dict mapping run name -> list of epochs (self-trained models),
            or the string "FusOn-pLM" to use the published best checkpoint.
        input_data_path (str): CSV with the sequences to embed.
        input_fname (str): short name used in the output pickle file names.
        average (bool): mean-pooled embeddings if True, per-residue 2D if False.
        seq_col (str): sequence column in the input CSV.
        benchmark_fusonplm / benchmark_esm / benchmark_fo_puncta_ml / benchmark_prott5 (bool):
            which models to embed with.
        overwrite (bool): passed through to embed_dataset.
        max_length (int | None): tokenizer truncation length.

    Returns:
        dict: embedding pickle path -> {'model_type', 'model', 'epoch'}.
    """
    # make directory for embeddings inside benchmarking dataset if one doesn't already exist
    os.makedirs('embeddings', exist_ok=True)

    emb_type_tag = 'average' if average else '2D'

    all_embedding_paths = dict()  # dictionary organized where embedding path points to model, epoch

    def _embed_fuson(model_name, epoch, fuson_ckpt_path):
        # Register + embed one FusOn-pLM checkpoint (shared by both fuson_ckpts forms).
        if not os.path.exists(fuson_ckpt_path):
            raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")

        # epoch-specific subdir for self-trained runs; plain model dir for the best ckpt
        if epoch is None:
            embedding_output_dir = f'embeddings/fuson_plm/{model_name}'
        else:
            embedding_output_dir = f'embeddings/fuson_plm/{model_name}/epoch{epoch}'
        embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
        os.makedirs(embedding_output_dir, exist_ok=True)

        all_embedding_paths[embedding_output_path] = {
            'model_type': 'fuson_plm',
            'model': model_name,
            'epoch': epoch
        }

        # Create embeddings (or skip if they're already made)
        log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
        embed_dataset(input_data_path, embedding_output_path,
                      seq_col=seq_col, model_type='fuson_plm',
                      fuson_ckpt_path=fuson_ckpt_path, average=average,
                      overwrite=overwrite, print_updates=True,
                      max_length=max_length)

    def _embed_baseline(model_type, pretty_name):
        # Register + embed one non-FusOn baseline model (shared by ESM-2 and ProtT5).
        os.makedirs(f'embeddings/{model_type}', exist_ok=True)
        embedding_output_path = f'embeddings/{model_type}/{input_fname}_{emb_type_tag}_embeddings.pkl'

        all_embedding_paths[embedding_output_path] = {
            'model_type': model_type,
            'model': model_type,
            'epoch': np.nan
        }

        log_update(f"\nMaking {pretty_name} embeddings for {input_data_path} and saving results to {embedding_output_path}...")
        embed_dataset(input_data_path, embedding_output_path,
                      seq_col=seq_col, model_type=model_type,
                      fuson_ckpt_path=None, average=average,
                      overwrite=overwrite, print_updates=True,
                      max_length=max_length)

    if benchmark_fusonplm:
        os.makedirs('embeddings/fuson_plm', exist_ok=True)
        log_update(f"\nMaking Fuson-PLM embeddings")
        if type(fuson_ckpts) == dict:
            # Self-trained models: one subdir per run, one pickle per epoch
            for model_name, epoch_list in fuson_ckpts.items():
                os.makedirs(f'embeddings/fuson_plm/{model_name}', exist_ok=True)
                for epoch in epoch_list:
                    _embed_fuson(model_name, epoch,
                                 f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}')
        elif fuson_ckpts == "FusOn-pLM":
            # Published best checkpoint lives at the FusOn-pLM repo root
            model_name = "best"
            os.makedirs(f'embeddings/fuson_plm/{model_name}', exist_ok=True)
            _embed_fuson(model_name, None, "../../..")
        else:
            raise Exception(f"Error. fuson_ckpts should be a dict or str")

    if benchmark_esm:
        _embed_baseline('esm2_t33_650M_UR50D', 'ESM-2-650M')

    if benchmark_prott5:
        _embed_baseline('prot_t5_xl_half_uniref50_enc', 'ProtT5-XL-UniRef50')

    if benchmark_fo_puncta_ml:
        # Physicochemical features are precomputed; just register the path
        all_embedding_paths['FOdb_physicochemical_embeddings.pkl'] = {
            'model_type': 'fo_puncta_ml',
            'model': 'fo_puncta_ml',
            'epoch': np.nan
        }

    return all_embedding_paths
|
fuson_plm/benchmarking/embedding_exploration/README.md
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Embedding exploration
|
| 2 |
+
|
| 3 |
+
This folder contains all the data and code needed to run embedding exploration (Fig. S3).
|
| 4 |
+
|
| 5 |
+
### Data download
|
| 6 |
+
To help select TF (transcription factor) and Kinase-containing fusions for investigation (Fig. S3a), Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8) was downloaded as a reference of transcription factors and kinases.
|
| 7 |
+
|
| 8 |
+
```
|
| 9 |
+
benchmarking/
|
| 10 |
+
└── embedding_exploration/
|
| 11 |
+
└── data/
|
| 12 |
+
├── salokas_2020_tableS3.csv
|
| 13 |
+
├── tf_and_kinase_fusions.csv
|
| 14 |
+
├── top_genes.csv
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
- **`data/salokas_2020_tableS3.csv`**: Supplementary Table 3 from [Salokas et al. 2020](https://doi.org/10.1038/s41598-020-71040-8)
|
| 18 |
+
- **`data/tf_and_kinase_fusions.csv`**: set of TF::TF and Kinase::Kinase fusion oncoproteins from FusOn-DB database. Curated in `plot.py`
|
| 19 |
+
- **`data/top_genes.csv`**: fusion oncoproteins (and their head and tail components) visualized in Fig. S3b. Sequences for head and tail components were pulled from the best-aligned sequences in `fuson_plm/data/blast/blast_outputs/best_htg_alignments_swissprot_seqs.pkl`
|
| 20 |
+
|
| 21 |
+
### Plotting
|
| 22 |
+
|
| 23 |
+
Run `plot.py` to regenerate plots in Figure S3:
|
| 24 |
+
|
| 25 |
+
```
|
| 26 |
+
# Dictionary: key = run name, values = epochs. (use this option if you've trained your own model)
|
| 27 |
+
# Or "FusOn-pLM" to use the official model
|
| 28 |
+
FUSON_PLM_CKPT= "FusOn-pLM"
|
| 29 |
+
|
| 30 |
+
# Type of dim reduction
|
| 31 |
+
PLOT_UMAP = True
|
| 32 |
+
PLOT_TSNE = False
|
| 33 |
+
|
| 34 |
+
# Overwriting configs
|
| 35 |
+
PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
To run, use:
|
| 39 |
+
```
|
| 40 |
+
nohup python plot.py > plot.out 2> plot.err &
|
| 41 |
+
```
|
| 42 |
+
- All **results** are stored in `embedding_exploration/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started the run.
|
| 43 |
+
|
| 44 |
+
Below are the FusOn-pLM paper results in `results/final/umap_plots/fuson_plm/best/`:
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
benchmarking/
|
| 48 |
+
└── embedding_exploration/
|
| 49 |
+
└── results/final/umap_plots/fuson_plm/best/
|
| 50 |
+
└── favorites/
|
| 51 |
+
├── umap_favorites_source_data.csv
|
| 52 |
+
├── umap_favorites_visualization.png
|
| 53 |
+
└── tf_and_kinase/
|
| 54 |
+
├── umap_tf_and_kinase_fusions_source_data.csv
        ├── umap_tf_and_kinase_fusions_visualization.png
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
- **`favorites/umap_favorites_visualization.png`**: Fig. S3b, with the data directly plotted stored in `favorites/umap_favorites_source_data.csv`
|
| 58 |
+
- **`tf_and_kinase/umap_tf_and_kinase_fusions_visualization.png`**: Fig. S3a, with the data directly plotted stored in `tf_and_kinase/umap_tf_and_kinase_fusions_source_data.csv`.
|
fuson_plm/benchmarking/embedding_exploration/__init__.py
ADDED
|
File without changes
|
fuson_plm/benchmarking/embedding_exploration/config.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FusOn-pLM checkpoint selection.
# Either a dict mapping run name -> list of epochs (use this if you trained your own model),
# or the string "FusOn-pLM" to use the "best" checkpoint from the FusOn-pLM paper.
FUSON_PLM_CKPT = "FusOn-pLM"

# Which dimensionality reductions to plot
PLOT_UMAP = True
PLOT_TSNE = False

# Overwriting configs
PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
|
fuson_plm/benchmarking/embedding_exploration/data/salokas_2020_tableS3.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d8bebc0871a4329015a3c6c7843f5bbc86c48811b2a836c42f1ef46b37f4282a
|
| 3 |
+
size 19626
|
fuson_plm/benchmarking/embedding_exploration/data/tf_and_kinase_fusions.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:372321137ed12b2f8aa7c4891dafd0e88d64d5c5d0ea9c6f3a0aa9d897e8ead6
|
| 3 |
+
size 557262
|
fuson_plm/benchmarking/embedding_exploration/data/top_genes.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33d568fe413107318caebd5ee260ee66fe8571461ed8f8d1b47888441f7b5034
|
| 3 |
+
size 16695
|
fuson_plm/benchmarking/embedding_exploration/plot.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pickle
|
| 4 |
+
from sklearn.manifold import TSNE
|
| 5 |
+
import matplotlib.font_manager as fm
|
| 6 |
+
from matplotlib.font_manager import FontProperties
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import matplotlib.gridspec as gridspec
|
| 9 |
+
import matplotlib.patches as patches
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
import umap
|
| 12 |
+
import os
|
| 13 |
+
|
| 14 |
+
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
|
| 15 |
+
import fuson_plm.benchmarking.embedding_exploration.config as config
|
| 16 |
+
from fuson_plm.utils.visualizing import set_font
|
| 17 |
+
from fuson_plm.utils.constants import TCGA_CODES, FODB_CODES, VALID_AAS, DELIMITERS
|
| 18 |
+
from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_dimred_embeddings(embeddings, dimred_type="umap"):
    """Reduce embeddings to 2D with the requested method.

    Args:
        embeddings: list/array of embedding vectors.
        dimred_type (str): "umap" or "tsne".

    Returns:
        np.ndarray of shape (n, 2).

    Raises:
        ValueError: if dimred_type is not "umap" or "tsne" (the original
        implementation silently returned None in that case).
    """
    if dimred_type == "umap":
        return get_umap_embeddings(embeddings)
    if dimred_type == "tsne":
        return get_tsne_embeddings(embeddings)
    raise ValueError(f"Unknown dimred_type: {dimred_type!r}; expected 'umap' or 'tsne'")
|
| 28 |
+
|
| 29 |
+
def get_tsne_embeddings(embeddings):
    """Project a collection of embedding vectors down to 2D with t-SNE.

    Uses a fixed random seed (42) and perplexity of 5 so runs are reproducible.
    """
    matrix = np.array(embeddings)
    reducer = TSNE(n_components=2, random_state=42, perplexity=5)
    return reducer.fit_transform(matrix)
|
| 34 |
+
|
| 35 |
+
def get_umap_embeddings(embeddings):
    """Project a collection of embedding vectors down to 2D with UMAP.

    Uses UMAP's default settings (n_neighbors=15, euclidean metric) with a
    fixed random seed (42) so runs are reproducible.
    """
    matrix = np.array(embeddings)
    reducer = umap.UMAP(n_components=2, random_state=42, n_neighbors=15, metric='euclidean')  # default parameters for UMAP
    return reducer.fit_transform(matrix)
|
| 40 |
+
|
| 41 |
+
def plot_half_filled_circle(ax, x, y, left_color, right_color, size=100):
    """
    Plots a circle filled in halves with specified colors.

    Parameters:
    - ax: Matplotlib axis to draw on.
    - x, y: Coordinates of the marker (data coordinates).
    - left_color: Color of the left half.
    - right_color: Color of the right half.
    - size: Size of the marker; radius is sqrt(size)/100 in data units.
    """
    radius = (size ** 0.5) / 100  # Scale the radius
    # Left half-circle: wedge sweeping 90° to 270° (counterclockwise)
    left_half = patches.Wedge((x, y), radius, 90, 270, color=left_color, ec="black")
    # Right half-circle: wedge sweeping 270° around to 90°
    right_half = patches.Wedge((x, y), radius, 270, 90, color=right_color, ec="black")

    # Add both halves to the plot
    ax.add_patch(left_half)
    ax.add_patch(right_half)
|
| 61 |
+
|
| 62 |
+
def plot_umap_scatter_tftf_kk(df, filename="umap.png"):
    """
    Plots a 2D scatterplot of UMAP coordinates, colored by fusion type.
    Only for TF::TF and Kinase::Kinase fusions.

    Parameters:
    - df (pd.DataFrame): DataFrame containing 'umap1', 'umap2', 'sequence', and 'fusion_type' columns.
      NOTE(review): column names are hard-coded to 'umap1'/'umap2'; callers using another
      dimensionality reduction would need matching column names — confirm upstream.
    - filename (str): path where the figure is saved (300 dpi).
    """
    set_font()

    # Define colors for each type
    colors = {
        "TF": "pink",
        "Kinase": "orange"
    }

    # Fill color for each fusion-type combination
    marker_colors = {
        "TF::TF": colors["TF"],
        "Kinase::Kinase": colors["Kinase"],
    }

    # Create the plot, padding the axis limits by 1 unit on each side
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(df["umap1"].min() - 1, df["umap1"].max() + 1)
    ax.set_ylim(df["umap2"].min() - 1, df["umap2"].max() + 1)

    # Plot each fusion as a small outlined dot colored by category
    # (iterrows avoids the positional iloc lookup per index)
    for _, row in df.iterrows():
        ax.scatter(row["umap1"], row["umap2"], color=marker_colors[row["fusion_type"]],
                   s=15, edgecolors="black", linewidth=0.5)

    # Add custom legend
    legend_elements = [
        patches.Patch(facecolor="pink", edgecolor="black", label="TF::TF"),
        patches.Patch(facecolor="orange", edgecolor="black", label="Kinase::Kinase")
    ]
    ax.legend(handles=legend_elements, title="Fusion Type", fontsize=16, title_fontsize=16)

    # Add labels and title
    plt.xlabel("UMAP 1", fontsize=20)
    plt.ylabel("UMAP 2", fontsize=20)
    plt.title("FusOn-pLM-embedded Transcription Factor and Kinase Fusions", fontsize=20)
    plt.tight_layout()

    # Save and show the plot
    plt.savefig(filename, dpi=300)
    plt.show()
|
| 116 |
+
|
| 117 |
+
def plot_umap_scatter_half_filled(df, filename="umap.png"):
    """
    Plots a 2D scatterplot of UMAP coordinates using half-filled circle markers:
    the left half reflects the head gene's type and the right half the tail
    gene's type (TF / Kinase / Other).

    Parameters:
    - df (pd.DataFrame): DataFrame containing 'umap1', 'umap2', 'sequence', and 'fusion_type' columns.
    - filename (str): path where the figure is saved (300 dpi).
    """
    # Define colors for each type
    colors = {
        "TF": "pink",
        "Kinase": "orange",
        "Other": "grey"
    }

    # Left/right half colors for every head::tail type combination
    marker_colors = {
        "TF::TF": {"left": colors["TF"], "right": colors["TF"]},
        "TF::Other": {"left": colors["TF"], "right": colors["Other"]},
        "Other::TF": {"left": colors["Other"], "right": colors["TF"]},
        "Kinase::Kinase": {"left": colors["Kinase"], "right": colors["Kinase"]},
        "Kinase::Other": {"left": colors["Kinase"], "right": colors["Other"]},
        "Other::Kinase": {"left": colors["Other"], "right": colors["Kinase"]},
        "Kinase::TF": {"left": colors["Kinase"], "right": colors["TF"]},
        "TF::Kinase": {"left": colors["TF"], "right": colors["Kinase"]},
        "Other::Other": {"left": colors["Other"], "right": colors["Other"]}
    }

    # Create the plot, padding the axis limits by 1 unit on each side
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(df["umap1"].min() - 1, df["umap1"].max() + 1)
    ax.set_ylim(df["umap2"].min() - 1, df["umap2"].max() + 1)

    # Plot each point with the specified half-filled marker
    # (iterrows avoids the positional iloc lookup per index)
    for _, row in df.iterrows():
        halves = marker_colors[row["fusion_type"]]
        plot_half_filled_circle(ax, row["umap1"], row["umap2"],
                                halves["left"], halves["right"], size=100)

    # Add custom legend
    legend_elements = [
        patches.Patch(facecolor="pink", edgecolor="black", label="TF"),
        patches.Patch(facecolor="orange", edgecolor="black", label="Kinase"),
        patches.Patch(facecolor="grey", edgecolor="black", label="Other")
    ]
    ax.legend(handles=legend_elements, title="Type")

    # Add labels and title
    plt.xlabel("UMAP 1")
    plt.ylabel("UMAP 2")
    plt.title("UMAP Scatter Plot")
    plt.tight_layout()

    # Save and show the plot
    plt.savefig(filename, dpi=300)
    plt.show()
|
| 177 |
+
|
| 178 |
+
def get_gene_type(gene, d):
    """Map a gene symbol to 'Kinase', 'TF', or 'Other' using lookup dict d.

    Args:
        gene (str): gene symbol.
        d (dict): gene -> 'kinase' | 'tf' annotation (Salokas et al. 2020 table).

    Returns:
        str: 'Kinase', 'TF', or 'Other'.

    Bug fix: the original returned None (implicitly) when the gene was present
    in d but annotated with a label other than 'kinase'/'tf'; such genes now
    fall through to 'Other'.
    """
    label = d.get(gene)
    if label == 'kinase':
        return 'Kinase'
    if label == 'tf':
        return 'TF'
    return 'Other'
|
| 186 |
+
|
| 187 |
+
def get_tf_and_kinase_fusions_dataset():
    """
    Build the TF::TF and Kinase::Kinase fusion set for plotting.

    Reads the Salokas et al. 2020 TF/kinase annotation table and the
    FusOn-DB head/tail database, labels each fusion by its head and tail
    gene types, and keeps only the TF::TF and Kinase::Kinase categories.

    Returns:
        pd.DataFrame with columns ['sequence', 'ID', 'fusion_type', 'type'].
    """
    # Reference table of genes annotated as transcription factor or kinase
    tf_kinase_parts = pd.read_csv("data/salokas_2020_tableS3.csv")
    log_update(f"{tf_kinase_parts}")  # was print(); use log_update like the rest of the pipeline
    ht_tf_kinase_dict = dict(zip(tf_kinase_parts['Gene'], tf_kinase_parts['Kinase or TF']))

    # This one has each row with one fusiongene name
    fuson_ht_db = pd.read_csv("../../data/blast/fuson_ht_db.csv")
    fuson_ht_db[['hg', 'tg']] = fuson_ht_db['fusiongenes'].str.split("::", expand=True)

    fuson_ht_db['hg_type'] = fuson_ht_db['hg'].apply(lambda g: get_gene_type(g, ht_tf_kinase_dict))
    fuson_ht_db['tg_type'] = fuson_ht_db['tg'].apply(lambda g: get_gene_type(g, ht_tf_kinase_dict))
    fuson_ht_db['fusion_type'] = fuson_ht_db['hg_type'] + '::' + fuson_ht_db['tg_type']
    fuson_ht_db['type'] = ['fusion'] * len(fuson_ht_db)

    # Manually chosen, visually separable categories.
    # (The original also derived categories from value_counts().reset_index()['index'],
    # which was dead code — immediately overwritten — and breaks on pandas >= 2,
    # where the reset column is named after the counted column instead of 'index'.)
    categories = ["TF::TF", "Kinase::Kinase"]
    log_update(f"{categories}")

    # Keep every fusion in each selected category
    plot_df = pd.concat(
        [fuson_ht_db.loc[fuson_ht_db['fusion_type'] == cat] for cat in categories],
        axis=0
    ).reset_index(drop=True)

    log_update(f"{plot_df['fusion_type'].value_counts()}")

    # Standardize column names for downstream embedding/merging
    plot_df = plot_df[['aa_seq', 'fusiongenes', 'fusion_type', 'type']].rename(
        columns={'aa_seq': 'sequence', 'fusiongenes': 'ID'}
    )

    return plot_df
|
| 223 |
+
|
| 224 |
+
def make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir='', dimred_type='umap'):
    """
    Save the TF/Kinase fusion source data CSV and its scatter plot.

    Args:
        seqs_with_embeddings (pd.DataFrame): must contain f'{dimred_type}1',
            f'{dimred_type}2', 'sequence', 'fusion_type', and 'ID' columns.
        savedir (str): output directory (created if missing).
        dimred_type (str): tag used in output file names ('umap' or 'tsne').
    """
    fuson_db = pd.read_csv("../../data/fuson_db.csv")
    seq_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))

    # .copy() so adding seq_id below mutates a fresh frame, not a view of the
    # caller's DataFrame (avoids pandas SettingWithCopyWarning)
    data = seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2', 'sequence', 'fusion_type', 'ID']].copy()
    data['seq_id'] = data['sequence'].map(seq_id_dict)

    os.makedirs(savedir, exist_ok=True)
    data.to_csv(f"{savedir}/{dimred_type}_tf_and_kinase_fusions_source_data.csv", index=False)
    plot_umap_scatter_tftf_kk(data, filename=f"{savedir}/{dimred_type}_tf_and_kinase_fusions_visualization.png")
|
| 236 |
+
|
| 237 |
+
def tf_and_kinase_fusions_plot(dimred_types, output_dir):
    """
    Embed the TF::TF and Kinase::Kinase fusion set, then plot each requested
    dimensionality reduction (Fig. S3a).

    Args:
        dimred_types (list[str]): e.g. ['umap'] or ['umap', 'tsne'].
        output_dir (str): root results directory for this run.
    """
    plot_df = get_tf_and_kinase_fusions_dataset()
    plot_df.to_csv("data/tf_and_kinase_fusions.csv", index=False)

    # Embed the curated fusion set with the configured FusOn-pLM checkpoint
    input_fname = 'tf_and_kinase'
    all_embedding_paths = embed_dataset_for_benchmark(
        fuson_ckpts=config.FUSON_PLM_CKPT,
        input_data_path='data/tf_and_kinase_fusions.csv', input_fname=input_fname,
        average=True, seq_col='sequence',
        benchmark_fusonplm=True,
        benchmark_esm=False,
        benchmark_fo_puncta_ml=False,
        overwrite=config.PERMISSION_TO_OVERWRITE)

    # For each of the models we are benchmarking, load embeddings and make plots
    log_update("\nEmbedding sequences")
    for embedding_path, details in all_embedding_paths.items():
        log_update(f"\tBenchmarking embeddings at: {embedding_path}")
        try:
            with open(embedding_path, "rb") as f:
                embeddings = pickle.load(f)
        except Exception as e:
            # chain the original error rather than hiding it behind a bare except
            raise Exception(f"Cannot read embeddings from {embedding_path}") from e

        # combine the embeddings and fusion metadata into one dataframe
        seqs_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
        seqs_with_embeddings = seqs_with_embeddings.rename(columns={0: 'sequence', 1: 'embedding'})
        seqs_with_embeddings = pd.merge(seqs_with_embeddings, plot_df, on='sequence', how='inner')

        # dimensionality-reduce the embeddings and plot each reduction
        for dimred_type in dimred_types:
            dimred_embeddings = get_dimred_embeddings(seqs_with_embeddings['embedding'].tolist(), dimred_type=dimred_type)

            # turn the result into a dataframe, and add it to seqs_with_embeddings
            data = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
            seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = data

            # subdirectory mirrors the embedding path (e.g. fuson_plm/best)
            intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
            cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"

            os.makedirs(cur_output_dir, exist_ok=True)
            make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir=cur_output_dir, dimred_type=dimred_type)
|
| 287 |
+
|
| 288 |
+
def make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir=None, dimred_type='umap'):
    """
    Make plots showing that PAX3::FOXO1, EWS::FLI1, SS18::SSX1, EML4::ALK are embedded
    distinctly from their heads and tails.

    Args:
        seqs_with_embeddings (pd.DataFrame): must contain a 'sequence' column and the
            reduced coordinate columns f'{dimred_type}1' and f'{dimred_type}2'.
        savedir (str or None): directory where the figure and its source data are written.
        dimred_type (str): 'umap' or 'tsne'; selects which coordinate columns are plotted.
    """
    set_font()

    # Load one sequence each for four proteins in the test set (PAX3::FOXO1, EWS::FLI1,
    # SS18::SSX1, EML4::ALK) and tag each row as fusion, head parent, or tail parent.
    data = pd.read_csv("data/top_genes.csv")
    seqs_with_embeddings = pd.merge(seqs_with_embeddings, data, on="sequence")
    seqs_with_embeddings["Type"] = [""] * len(seqs_with_embeddings)
    fusion_mask = seqs_with_embeddings["gene"].str.contains("::")
    seqs_with_embeddings.loc[fusion_mask, "Type"] = "fusion_embeddings"
    gene_parts = seqs_with_embeddings.loc[fusion_mask]["gene"].str.split("::", expand=True)
    heads = gene_parts[0].tolist()
    tails = gene_parts[1].tolist()
    seqs_with_embeddings.loc[seqs_with_embeddings["gene"].isin(heads), "Type"] = "h_embeddings"
    seqs_with_embeddings.loc[seqs_with_embeddings["gene"].isin(tails), "Type"] = "t_embeddings"

    # Build one row per fusion holding its head/tail gene names and sequences
    merge = seqs_with_embeddings.loc[fusion_mask].reset_index(drop=True)[['gene', 'sequence']]
    merge["head"] = merge["gene"].str.split("::", expand=True)[0]
    merge["tail"] = merge["gene"].str.split("::", expand=True)[1]
    merge = pd.merge(
        merge,
        seqs_with_embeddings[['gene', 'sequence']].rename(
            columns={'gene': 'head', 'sequence': 'h_sequence'}),
        on='head', how='left'
    )
    merge = pd.merge(
        merge,
        seqs_with_embeddings[['gene', 'sequence']].rename(
            columns={'gene': 'tail', 'sequence': 't_sequence'}),
        on='tail', how='left'
    )

    # Define colors, markers, and display labels for each embedding type
    colors = {
        'fusion_embeddings': '#cf9dfa',  # old color #0C4A4D
        'h_embeddings': '#eb8888',       # old color #619283
        't_embeddings': '#5fa3e3',       # old color #619283
    }
    markers = {
        'fusion_embeddings': 'o',
        'h_embeddings': '^',
        't_embeddings': 'v',
    }
    label_map = {
        'fusion_embeddings': 'Fusion',
        'h_embeddings': 'Head',
        't_embeddings': 'Tail',
    }
    label_transform = {
        'tsne': 't-SNE',
        'umap': 'UMAP',
    }

    # Create a 2x3 grid of plots.
    # NOTE: the stray bare plt.figure() the original opened (and never used) is removed.
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Manually set axis ranges for cleaner plotting.
    # NOTE(review): these bounds were tuned for this dataset's UMAP coordinates —
    # confirm they still make sense when dimred_type == 'tsne'.
    x_min, x_max = [11, 16]
    y_min, y_max = [10, 22]

    # Determine tick positions
    x_ticks = np.arange(x_min, x_max + 1, 1)
    y_ticks = np.arange(y_min, y_max + 1, 1)

    # Flatten the axes array for easier iteration; cap at the number of fusions so a
    # shorter favorites list no longer raises a KeyError on the unused axes.
    axes = axes.flatten()
    for i, ax in enumerate(axes[:len(merge)]):
        # Extract the gene names for this fusion and its parents
        fgene_name = merge.loc[i, 'gene']
        hgene = merge.loc[i, 'head']
        tgene = merge.loc[i, 'tail']

        # Filter the reduced embeddings for the relevant entries
        panel_data = seqs_with_embeddings[seqs_with_embeddings['gene'].isin([fgene_name, hgene, tgene])]

        # Plot each type (fusion / head / tail)
        for emb_type in panel_data['Type'].unique():
            subset = panel_data[panel_data['Type'] == emb_type]
            ax.scatter(subset[f'{dimred_type}1'], subset[f'{dimred_type}2'],
                       label=label_map[emb_type], color=colors[emb_type],
                       marker=markers[emb_type], s=120, zorder=3)

        ax.set_title(f'{fgene_name}', fontsize=44)
        ax.set_xlabel(f'{label_transform[dimred_type]} 1', fontsize=44)
        ax.set_ylabel(f'{label_transform[dimred_type]} 2', fontsize=44)
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', zorder=1)

        # Set the same limits and ticks for all axes
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xticks(x_ticks)
        ax.set_yticks(y_ticks)

        # Rotate x-axis labels, then enlarge all tick labels (the final size is 24;
        # the intermediate labelsize=16 tick_params call was redundant and is removed)
        ax.set_xticklabels(ax.get_xticks(), rotation=45, ha='right')
        for label in ax.get_xticklabels():
            label.set_fontsize(24)
        for label in ax.get_yticklabels():
            label.set_fontsize(24)

        # Show the legend on the first panel only
        if i == 0:
            legend = ax.legend(fontsize=20, markerscale=2, loc='best')
            for text in legend.get_texts():
                text.set_fontsize(24)

    # Adjust layout to prevent overlap
    plt.tight_layout()

    # BUG FIX: save BEFORE plt.show(). On interactive backends show() can leave the
    # current figure empty, so the original save-after-show risked a blank image.
    fig.savefig(f'{savedir}/{dimred_type}_favorites_visualization.png', dpi=300)
    plt.show()

    # Save the source data, mapping each sequence back to its database seq_id
    seq_to_id_dict = pd.read_csv("../../data/fuson_db.csv")
    seq_to_id_dict = dict(zip(seq_to_id_dict['aa_seq'], seq_to_id_dict['seq_id']))
    seqs_with_embeddings['seq_id'] = seqs_with_embeddings['sequence'].map(seq_to_id_dict)
    # BUG FIX: use the active dimred_type's coordinate columns instead of the hardcoded
    # 'umap1'/'umap2', which raised a KeyError whenever dimred_type == 'tsne'.
    coord_cols = [f'{dimred_type}1', f'{dimred_type}2']
    seqs_with_embeddings[coord_cols + ['sequence', 'Type', 'gene', 'id', 'seq_id']].to_csv(
        f"{savedir}/{dimred_type}_favorites_source_data.csv", index=False)
|
| 422 |
+
def fusion_v_parts_favorites(dimred_types, output_dir):
    """
    Make the embeddings for the four favorite fusions, THEN call the plot.

    Args:
        dimred_types (list[str]): subset of ['umap', 'tsne'] to compute and plot.
        output_dir (str): root results directory for this run.
    """
    # path to the pkl file with FOdb embeddings
    input_fname = 'favorites'
    all_embedding_paths = embed_dataset_for_benchmark(
        fuson_ckpts=config.FUSON_PLM_CKPT,
        input_data_path='data/top_genes.csv', input_fname=input_fname,
        average=True, seq_col='sequence',
        benchmark_fusonplm=True,
        benchmark_esm=False,
        benchmark_fo_puncta_ml=False,
        overwrite=config.PERMISSION_TO_OVERWRITE)

    # For each of the models we are benchmarking, load embeddings and make plots
    log_update("\nEmbedding sequences")
    # loop through the embedding paths and plot each one
    # (only the keys are needed; the unused `details` values are no longer unpacked)
    for embedding_path in all_embedding_paths:
        log_update(f"\tBenchmarking embeddings at: {embedding_path}")
        # BUG FIX: the original bare `except:` also swallowed KeyboardInterrupt/SystemExit
        # and discarded the underlying traceback; catch Exception and chain the cause.
        try:
            with open(embedding_path, "rb") as f:
                embeddings = pickle.load(f)
        except Exception as e:
            raise Exception(f"Cannot read embeddings from {embedding_path}") from e

        # combine the embeddings into one dataframe; the column that was called
        # FusOn-pLM is now called embedding
        seqs_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
        seqs_with_embeddings = seqs_with_embeddings.rename(columns={0: 'sequence', 1: 'embedding'})

        # get the dimensionality-reduced transform of the embeddings
        for dimred_type in dimred_types:
            dimred_embeddings = get_dimred_embeddings(
                seqs_with_embeddings['embedding'].tolist(), dimred_type=dimred_type)

            # turn the result into a dataframe, and add it to seqs_with_embeddings
            data = pd.DataFrame(dimred_embeddings, columns=[f'{dimred_type}1', f'{dimred_type}2'])
            seqs_with_embeddings[[f'{dimred_type}1', f'{dimred_type}2']] = data

            # make a subdirectory mirroring the embedding path layout
            # (the unused `model_name` local from the original is removed)
            intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])
            cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
            os.makedirs(cur_output_dir, exist_ok=True)
            make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir=cur_output_dir, dimred_type=dimred_type)
+
|
| 471 |
+
def main():
    """Run the embedding-exploration benchmark: embed, reduce, and plot."""
    # Each run writes into a fresh, timestamped results folder
    run_dir = f'results/{get_local_time()}'
    os.makedirs('results', exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)

    # Collect the enabled dimensionality-reduction methods and pre-create one
    # plot subdirectory per method
    dimred_types = []
    for enabled, method in ((config.PLOT_UMAP, "umap"), (config.PLOT_TSNE, "tsne")):
        if enabled:
            dimred_types.append(method)
            os.makedirs(f"{run_dir}/{method}_plots", exist_ok=True)

    with open_logfile(f'{run_dir}/embedding_exploration_log.txt'):
        print_configpy(config)
        # make the distinct embeddings plot
        fusion_v_parts_favorites(dimred_types, run_dir)

        tf_and_kinase_fusions_plot(dimred_types, run_dir)


if __name__ == "__main__":
    main()
|
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_source_data.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:28c0b51f513da01df3dee3c4e71aa0c583bd57d9878137bdac9e7ebc704694e4
|
| 3 |
+
size 17383
|
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/favorites/umap_favorites_visualization.png
ADDED
|
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_source_data.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b26a5a6c2f8f54225fd46f01dab52813532438732624561af8e2e4ad005e5dc7
|
| 3 |
+
size 570073
|
fuson_plm/benchmarking/embedding_exploration/results/final/umap_plots/fuson_plm/best/tf_and_kinase/umap_tf_and_kinase_fusions_visualization.png
ADDED
|
fuson_plm/benchmarking/mutation_prediction/README.md
CHANGED
|
@@ -81,7 +81,7 @@ To run, use:
|
|
| 81 |
```
|
| 82 |
nohup python discover.py > discover.out 2> discover.err &
|
| 83 |
```
|
| 84 |
-
- All **results** are stored in `
|
| 85 |
|
| 86 |
Below are the FusOn-pLM paper results in `results/final`:
|
| 87 |
|
|
|
|
| 81 |
```
|
| 82 |
nohup python discover.py > discover.out 2> discover.err &
|
| 83 |
```
|
| 84 |
+
- All **results** are stored in `mutation_prediction/results/<timestamp>`, where `timestamp` is a unique string encoding the date and time when you started training.
|
| 85 |
|
| 86 |
Below are the FusOn-pLM paper results in `results/final`:
|
| 87 |
|
fuson_plm/benchmarking/puncta/train.py
CHANGED
|
@@ -5,7 +5,7 @@ import numpy as np
|
|
| 5 |
import pickle
|
| 6 |
import os
|
| 7 |
|
| 8 |
-
from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
|
| 9 |
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
|
| 10 |
import fuson_plm.benchmarking.puncta.config as config
|
| 11 |
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
|
|
|
|
| 5 |
import pickle
|
| 6 |
import os
|
| 7 |
|
| 8 |
+
from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor
|
| 9 |
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
|
| 10 |
import fuson_plm.benchmarking.puncta.config as config
|
| 11 |
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
|
fuson_plm/benchmarking/xgboost_predictor.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from sklearn.model_selection import train_test_split, StratifiedKFold
|
| 2 |
+
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score, roc_auc_score, average_precision_score
|
| 3 |
+
from fuson_plm.utils.logging import log_update
|
| 4 |
+
import time
|
| 5 |
+
import xgboost as xgb
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
def train_final_predictor(X_train, y_train, n_estimators=50, tree_method="hist"):
    """Fit an XGBoost classifier on the full training set and return it.

    Args:
        X_train: training feature matrix.
        y_train: binary training labels.
        n_estimators (int): number of boosting rounds.
        tree_method (str): XGBoost tree construction algorithm.
    """
    model = xgb.XGBClassifier(n_estimators=n_estimators, tree_method=tree_method)
    model.fit(X_train, y_train)
    return model
|
| 14 |
+
def _binary_classification_metrics(y_true, y_pred_label, y_pred_prob):
    """Build a one-row DataFrame of binary-classification metrics.

    'AUROC'/'AUPRC' are probability-based; the '... Label' variants are computed
    from the hard label predictions. Column order matches the original output.
    """
    return pd.DataFrame(data={
        'Accuracy': [accuracy_score(y_true, y_pred_label)],
        'Precision': [precision_score(y_true, y_pred_label)],
        'Recall': [recall_score(y_true, y_pred_label)],
        'F1 Score': [f1_score(y_true, y_pred_label)],
        'AUROC': [roc_auc_score(y_true, y_pred_prob)],
        'AUROC Label': [roc_auc_score(y_true, y_pred_label)],
        'AUPRC': [average_precision_score(y_true, y_pred_prob)],
        'AUPRC Label': [average_precision_score(y_true, y_pred_label)],
    })

def evaluate_predictor(clf, X_test, y_test, class1_thresh=None):
    """Evaluate a trained binary classifier on a held-out test set.

    Args:
        clf: fitted classifier exposing predict() and predict_proba().
        X_test: test feature matrix.
        y_test: binary test labels.
        class1_thresh (float or None): optional custom probability threshold for
            calling class 1; when given, a second metrics table is computed from
            labels re-derived at this threshold.

    Returns:
        tuple: (automatic_stats_df, custom_stats_df). custom_stats_df is None
        when class1_thresh is None. The probability-based AUROC/AUPRC columns
        are identical in both tables, since thresholding does not affect them.
    """
    # Predict labels (classifier's default 0.5 threshold) and class-1 probabilities
    y_pred_test = clf.predict(X_test)
    y_pred_prob_test = clf.predict_proba(X_test)[:, 1]

    # Metrics at the automatic threshold.
    # The previously duplicated metric computations are factored into a helper.
    automatic_stats_df = _binary_classification_metrics(y_test, y_pred_test, y_pred_prob_test)

    # Metrics at the custom threshold, if requested
    custom_stats_df = None
    if class1_thresh is not None:
        y_pred_customthresh_test = np.where(np.asarray(y_pred_prob_test) >= class1_thresh, 1, 0)
        custom_stats_df = _binary_classification_metrics(y_test, y_pred_customthresh_test, y_pred_prob_test)

    return automatic_stats_df, custom_stats_df
|