Sophia Vincoff
commited on
Commit
·
8d9d9da
1
Parent(s):
3527fb2
puncta benchmark
Browse files- fuson_plm/benchmarking/puncta/FOdb_physicochemical_embeddings.pkl +3 -0
- fuson_plm/benchmarking/puncta/README.md +95 -0
- fuson_plm/benchmarking/puncta/__init__.py +0 -0
- fuson_plm/benchmarking/puncta/clean.py +184 -0
- fuson_plm/benchmarking/puncta/cleaned_dataset_s4.csv +3 -0
- fuson_plm/benchmarking/puncta/cleaning_log.txt +3 -0
- fuson_plm/benchmarking/puncta/config.py +17 -0
- fuson_plm/benchmarking/puncta/plot.py +244 -0
- fuson_plm/benchmarking/puncta/results/final/cytoplasm_verificationFOs_results.csv +3 -0
- fuson_plm/benchmarking/puncta/results/final/figures/cytoplasm_verificationFOs_barchart.png +0 -0
- fuson_plm/benchmarking/puncta/results/final/figures/cytoplasm_verificationFOs_barchart_source_data.csv +3 -0
- fuson_plm/benchmarking/puncta/results/final/figures/formation_verificationFOs_0.83thresh_barchart.png +0 -0
- fuson_plm/benchmarking/puncta/results/final/figures/formation_verificationFOs_0.83thresh_barchart_source_data.csv +3 -0
- fuson_plm/benchmarking/puncta/results/final/figures/nucleus_verificationFOs_barchart.png +0 -0
- fuson_plm/benchmarking/puncta/results/final/figures/nucleus_verificationFOs_barchart_source_data.csv +3 -0
- fuson_plm/benchmarking/puncta/results/final/formation_verificationFOs_0.83thresh_results.csv +3 -0
- fuson_plm/benchmarking/puncta/results/final/nucleus_verificationFOs_results.csv +3 -0
- fuson_plm/benchmarking/puncta/splits.csv +3 -0
- fuson_plm/benchmarking/puncta/train.py +155 -0
fuson_plm/benchmarking/puncta/FOdb_physicochemical_embeddings.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3d78986f0724138ed83c72fd4274154ca3f96a09f5fd8ad94030493375788006
|
| 3 |
+
size 168405
|
fuson_plm/benchmarking/puncta/README.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Puncta Prediction Benchmark
|
| 2 |
+
|
| 3 |
+
This folder contains all the data and code needed to perform the **puncta prediction benchmark** (Figure 3).
|
| 4 |
+
|
| 5 |
+
### From raw data to train/test splits
|
| 6 |
+
To train the puncta predictors, we processed raw data from FOdb [(Tripathi et al. 2023)](https://doi.org/10.1038/s41467-023-41655-2) Supplementary dataset 4 (`fuson_plm/data/raw_data/FOdb_puncta.csv`) and Supplementary dataset 5 (`fuson_plm/data/raw_data/FODb_SD5.csv`) using the file `clean.py` in the `puncta` directory.
|
| 7 |
+
|
| 8 |
+
```
|
| 9 |
+
data/
|
| 10 |
+
└── raw_data/
|
| 11 |
+
├── FOdb_puncta.csv
|
| 12 |
+
├── FOdb_SD5.csv
|
| 13 |
+
|
| 14 |
+
benchmarking/
|
| 15 |
+
└── puncta/
|
| 16 |
+
├── clean.py
|
| 17 |
+
├── cleaned_dataset_s4.csv
|
| 18 |
+
├── splits.csv
|
| 19 |
+
├── FOdb_physicochemical_embeddings.pkl
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
The `clean.py` script generates the following files:
|
| 23 |
+
- **`cleaned_dataset_s4.csv`**: clean version of `FOdb_puncta.csv`, where fusion oncoproteins with puncta status "Other" or "Nucleolar" have been removed, and only the 25 low-MI features from `FOdb_SD5.csv' are retained.
|
| 24 |
+
- **`splits.csv`**: fusion oncoproteins from `cleaned_dataset_s4.csv`, labeled in the `split` column as either being part of the *train* set ("Expressed_Set" in FOdb) or *test* set ("Verification_Set" in FOdb). This dataset also features `nucleus`, `cytoplasm`, and `formation` columns of 1s and 0s. In `nucleus`, 1=forms a condensate in the nucleus, 0=does not; in `cytoplasm`, 1=forms a condensate in the cytoplasm, 0=does not; in `formation`, 1=forms a condensate at all, 0=does not.
|
| 25 |
+
- **`FOdb_physicochemical_embeddings.pkl`**: a dictionary where fusion proteins from `splits.csv` are they keys, and their feature vectors of 25 low-MI features from `cleaned_dataset_s4.csv` are the values.
|
| 26 |
+
|
| 27 |
+
### Training
|
| 28 |
+
|
| 29 |
+
`config.py` holds training configuations.
|
| 30 |
+
|
| 31 |
+
```
|
| 32 |
+
# Benchmarking configs
|
| 33 |
+
BENCHMARK_FUSONPLM = True # True if you want to benchmark a FusOn-pLM Model
|
| 34 |
+
|
| 35 |
+
# FUSONPLM_CKPTS. If you've traiend your own model, this is a dictionary: key = run name, values = epochs
|
| 36 |
+
# If you want to use the trained FusOn-pLM, instead FUSONPLM_CKPTS="FusOn-pLM"
|
| 37 |
+
FUSONPLM_CKPTS= {}
|
| 38 |
+
|
| 39 |
+
# Model comparison configs
|
| 40 |
+
BENCHMARK_ESM = True # True if you want to benchmark ESM-2-650M
|
| 41 |
+
BENCHMARK_PROTT5 = True # True if you want to benchmark ProtT5
|
| 42 |
+
BENCHMARK_FO_PUNCTA_ML = True # True if you want to benchmark FO-Puncta-ML from the FOdb paper
|
| 43 |
+
|
| 44 |
+
# Overwriting configs
|
| 45 |
+
PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
|
| 46 |
+
|
| 47 |
+
# GPU configs
|
| 48 |
+
CUDA_VISIBLE_DEVICES="0" # GPUs to make visible for this process
|
| 49 |
+
```
|
| 50 |
+
<br>
|
| 51 |
+
|
| 52 |
+
`train.py` will train the XGBoost classifiers.
|
| 53 |
+
- All **results** are stored in `puncta/results/timestamp`, where `timestamp` is a unique string encoding the date and time when you started training.
|
| 54 |
+
- All **embeddings** made for training will be stored in a new folder called `puncta/embeddings/` with subfolders for each model. This allows you to use the same model multiple times without regenerating embeddings.
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
benchmarking/
|
| 58 |
+
└── puncta/
|
| 59 |
+
└── embeddings/
|
| 60 |
+
└── esm2_t33_650M_UR50D/...
|
| 61 |
+
└── fuson_plm/...
|
| 62 |
+
└── prot_t5_xl_half_uniref50_enc/...
|
| 63 |
+
└── results/
|
| 64 |
+
└── final/
|
| 65 |
+
└── figures/
|
| 66 |
+
├── cytoplasm_verificationFOs_barchart_source_data.csv
|
| 67 |
+
├── cytoplasm_verificationFOs_barchart.png
|
| 68 |
+
├── formation_verificationFOs_0.83thresh_barchart_source_data.csv
|
| 69 |
+
├── formation_verificationFOs_0.83thresh_barchart.png
|
| 70 |
+
├── nucleus_verificationFOs_barchart_source_data.csv
|
| 71 |
+
├── nucleus_verificationFOs_barchart.png
|
| 72 |
+
├── cytoplasm_verificationFOs_results.csv
|
| 73 |
+
├── formation_verificationFOs_0.83thresh_results.csv
|
| 74 |
+
├── nucleus_verificationFOs_results.csv
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
The following files are in `results/final/figures`:
|
| 78 |
+
- **`cytoplasm_verificationFOs_barchart.png`**: bar chart of performance on the cytoplasm puncta prediction task (Fig. 3E), and the formatted data that went directly into the plot (`cytoplasm_verificationFOs_barchart_source_data.csv`)
|
| 79 |
+
- **`formation_verificationFOs_0.83thresh_barchart.png`**: bar chart of performance on the puncta formation prediction task (Fig. 3C), and the formatted data that went directly into the plot (`formation_verificationFOs_0.83thresh_barchart_source_data.csv`)
|
| 80 |
+
- **`nucleus_verificationFOs_barchart.png`**: bar chart of performance on the nucleus puncta prediction task (Fig. 3D), and the formatted data that went directly into the plot (`nucleus_verificationFOs_barchart_source_data.csv`)
|
| 81 |
+
|
| 82 |
+
The raw data are included in `results/final` as `cytoplasm_verificationFOs_results.csv`, `formation_verificationFOs_0.83thresh_results.csv`, and `nucleus_verificationFOs_results.csv`.
|
| 83 |
+
|
| 84 |
+
If you train a new model, the equivalents of these files will be created in `results/timestamp` for your specific configurations set in `config.py`.
|
| 85 |
+
|
| 86 |
+
To run training, enter in terminal:
|
| 87 |
+
```
|
| 88 |
+
python train.py
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
To regnerate plots, run
|
| 92 |
+
```
|
| 93 |
+
python plot.py
|
| 94 |
+
```
|
| 95 |
+
|
fuson_plm/benchmarking/puncta/__init__.py
ADDED
|
File without changes
|
fuson_plm/benchmarking/puncta/clean.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Cleans raw data to prepare FO labels and embeddings
|
| 2 |
+
from fuson_plm.utils.logging import open_logfile, log_update
|
| 3 |
+
from fuson_plm.utils.data_cleaning import find_invalid_chars
|
| 4 |
+
from fuson_plm.utils.constants import VALID_AAS
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pickle
|
| 8 |
+
|
| 9 |
+
def find_localization(row):
|
| 10 |
+
puncta_status = row['Puncta_Status']
|
| 11 |
+
cytoplasm = (row['Cytoplasm']=='Punctate')
|
| 12 |
+
nucleus = (row['Nucleus']=='Punctate')
|
| 13 |
+
both = cytoplasm and nucleus
|
| 14 |
+
|
| 15 |
+
if puncta_status=='YES':
|
| 16 |
+
if both:
|
| 17 |
+
return 'Both'
|
| 18 |
+
else:
|
| 19 |
+
if cytoplasm:
|
| 20 |
+
return 'Cytoplasm'
|
| 21 |
+
if nucleus:
|
| 22 |
+
return 'Nucleus'
|
| 23 |
+
return np.nan
|
| 24 |
+
|
| 25 |
+
def clean_s5(df):
|
| 26 |
+
log_update("Cleaning FOdb Supplementary Table 5")
|
| 27 |
+
|
| 28 |
+
# extract only the physicochemical features used by the FO-Puncta ML model
|
| 29 |
+
retained_features = df.loc[
|
| 30 |
+
df['Low MI Set: Used In ML Model'].isin(['Yes','Yet']) # allow flexibility for typo in this DF
|
| 31 |
+
]['Parameter Label (Sup Table 2 & Matlab Scripts)'].tolist()
|
| 32 |
+
retained_features = sorted(retained_features)
|
| 33 |
+
|
| 34 |
+
# log the result
|
| 35 |
+
log_update(f'\tIsolated the {len(retained_features)} low-MI features used to train ML model')
|
| 36 |
+
for i, feat in enumerate(retained_features): log_update(f'\t\t{i+1}. {feat}')
|
| 37 |
+
|
| 38 |
+
# return the result
|
| 39 |
+
return retained_features
|
| 40 |
+
|
| 41 |
+
def make_label_df(df):
|
| 42 |
+
"""
|
| 43 |
+
Input df should be cleaned s4
|
| 44 |
+
"""
|
| 45 |
+
label_df = df[['FO_Name','AAseq','Localization','Puncta_Status','Dataset']].rename(columns={'FO_Name':'fusiongene','AAseq':'aa_seq','Dataset':'dataset'})
|
| 46 |
+
dataset_to_split_dict = {'Expressed_Set': 'train', 'Verification_Set': 'test'}
|
| 47 |
+
label_df['split'] = label_df['dataset'].apply(lambda x: dataset_to_split_dict[x])
|
| 48 |
+
label_df['nucleus'] = label_df['Localization'].apply(lambda x: 1 if x in ['Nucleus','Both'] else 0)
|
| 49 |
+
label_df['cytoplasm'] = label_df['Localization'].apply(lambda x: 1 if x in ['Cytoplasm','Both'] else 0)
|
| 50 |
+
label_df['formation'] = label_df['Puncta_Status'].apply(lambda x: 1 if x=='YES' else 0)
|
| 51 |
+
label_df = label_df[['fusiongene','aa_seq','dataset','split','nucleus','cytoplasm','formation']]
|
| 52 |
+
|
| 53 |
+
return label_df
|
| 54 |
+
|
| 55 |
+
def make_embeddings(df, physicochemical_features):
|
| 56 |
+
feat_string = '\n\t' + '\n\t'.join([str(i)+'. '+feat for i,feat in enumerate(physicochemical_features)])
|
| 57 |
+
log_update(f"\nMaking phyisochemical feature vectors.\nFeature Order: {feat_string}")
|
| 58 |
+
embeddings = {}
|
| 59 |
+
aa_seqs = df['AAseq'].unique()
|
| 60 |
+
for seq in aa_seqs:
|
| 61 |
+
feats = df.loc[df['AAseq']==seq].reset_index(drop=True)[physicochemical_features].T[0].tolist()
|
| 62 |
+
embeddings[seq] = feats
|
| 63 |
+
|
| 64 |
+
return embeddings
|
| 65 |
+
|
| 66 |
+
def clean_s4(df, retained_features):
|
| 67 |
+
log_update("Cleaning FOdb Supplementary Table 4")
|
| 68 |
+
df = df.loc[
|
| 69 |
+
df['Puncta_Status'].isin(['YES','NO'])
|
| 70 |
+
].reset_index(drop=True)
|
| 71 |
+
log_update(f'\tRemoved invalid FOs (puncta status = "Other" or "Nucleolar"). Remaining FOs: {len(df)}')
|
| 72 |
+
|
| 73 |
+
# check for duplicate sequences
|
| 74 |
+
dup_seqs = df.loc[df['AAseq'].duplicated()]['AAseq'].unique()
|
| 75 |
+
log_update(f"\tTotal duplicated sequences: {len(dup_seqs)}")
|
| 76 |
+
|
| 77 |
+
# check for invalid characters
|
| 78 |
+
df['invalid_chars'] = df['AAseq'].apply(lambda x: find_invalid_chars(x, VALID_AAS))
|
| 79 |
+
all_invalid_chars = set().union(*df['invalid_chars'])
|
| 80 |
+
log_update(f"\tChecking for invalid characters...\n\t\tFound {len(all_invalid_chars)} invalid characters")
|
| 81 |
+
for c in all_invalid_chars:
|
| 82 |
+
subset = df.loc[df['AAseq'].str.contains(c)]['AAseq'].tolist()
|
| 83 |
+
for seq in subset:
|
| 84 |
+
log_update(f"\t\tInvalid char {c} at index {seq.index(c)}/{len(seq)-1} of sequence {seq}")
|
| 85 |
+
# going to just remove the "-" from the special sequence
|
| 86 |
+
df = df.drop(columns=['invalid_chars'])
|
| 87 |
+
df.loc[
|
| 88 |
+
df['AAseq'].str.contains('-'),'AAseq'
|
| 89 |
+
] = df.loc[df['AAseq'].str.contains('-'),'AAseq'].item().replace('-','')
|
| 90 |
+
|
| 91 |
+
# change FO format to ::
|
| 92 |
+
df['FO_Name'] = df['FO_Name'].apply(lambda x: x.replace('_','::'))
|
| 93 |
+
log_update(f'\tChanged FO names to Head::Tail format')
|
| 94 |
+
|
| 95 |
+
# Isolate positive and negative sets
|
| 96 |
+
df['Localization'] = ['']*len(df)
|
| 97 |
+
df['Localization'] = df.apply(lambda row: find_localization(row), axis=1)
|
| 98 |
+
puncta_positive = df.loc[
|
| 99 |
+
df['Puncta_Status']=='YES'
|
| 100 |
+
].reset_index(drop=True)
|
| 101 |
+
puncta_negative = df.loc[
|
| 102 |
+
df['Puncta_Status']=='NO'
|
| 103 |
+
].reset_index(drop=True)
|
| 104 |
+
|
| 105 |
+
# Only keeping retained features
|
| 106 |
+
cols = list(df.columns)
|
| 107 |
+
mi_feats_included = set(retained_features).intersection(set(cols))
|
| 108 |
+
log_update(f"\tChecking for the {len(retained_features)} low-MI features... {len(mi_feats_included)} found")
|
| 109 |
+
# make sure all of these are no-na
|
| 110 |
+
for rf in retained_features:
|
| 111 |
+
# if there's NaN, log it. Make sure the only instances of np.nan are for Verification Set FOs.
|
| 112 |
+
if df[rf].isna().sum()>0:
|
| 113 |
+
nas = df.loc[df[rf].isna()]
|
| 114 |
+
log_update(f"\t\tFeature {rf} has {len(nas)} np.nan values in the following datasets:")
|
| 115 |
+
for k,v in nas['Dataset'].value_counts().items():
|
| 116 |
+
print(f'\t\t\t{k}: {v}')
|
| 117 |
+
|
| 118 |
+
df = df[['FO_Name', 'Nucleus', 'Nucleolus', 'Cytoplasm','Puncta_Status', 'Dataset', 'Localization', 'AAseq',
|
| 119 |
+
'Puncta.pred', 'Puncta.prob']+retained_features]
|
| 120 |
+
|
| 121 |
+
# Quantify localization
|
| 122 |
+
log_update(f'\n\tPuncta localization for {len(puncta_positive)} FOs where Puncta_Status==YES')
|
| 123 |
+
for k, v in puncta_positive['Localization'].value_counts().items():
|
| 124 |
+
pcnt = 100*v/sum(puncta_positive['Localization'].value_counts())
|
| 125 |
+
log_update(f'\t\t{k}: \t{v} ({pcnt:.2f}%)')
|
| 126 |
+
|
| 127 |
+
log_update("\tDataset breakdown...")
|
| 128 |
+
dataset_vc = df['Dataset'].value_counts()
|
| 129 |
+
expressed_puncta_statuses = df.loc[df['Dataset']=='Expressed_Set']['Puncta_Status'].value_counts()
|
| 130 |
+
expressed_positive_locs = puncta_positive.loc[puncta_positive['Dataset']=='Expressed_Set']['Localization'].value_counts()
|
| 131 |
+
verification_positive_locs = puncta_positive.loc[puncta_positive['Dataset']=='Verification_Set']['Localization'].value_counts()
|
| 132 |
+
verification_puncta_statuses = df.loc[df['Dataset']=='Verification_Set']['Puncta_Status'].value_counts()
|
| 133 |
+
for k, v in dataset_vc.items():
|
| 134 |
+
pcnt = 100*v/sum(dataset_vc)
|
| 135 |
+
log_update(f'\t\t{k}: \t{v} ({pcnt:.2f}%)')
|
| 136 |
+
if k=='Expressed_Set':
|
| 137 |
+
for key, val in expressed_puncta_statuses.items():
|
| 138 |
+
pcnt = 100*val/v
|
| 139 |
+
log_update(f'\t\t\t{key}: \t{val} ({pcnt:.2f}%)')
|
| 140 |
+
if key=='YES':
|
| 141 |
+
log_update('\t\t\t\tLocalizations...')
|
| 142 |
+
for key2, val2 in expressed_positive_locs.items():
|
| 143 |
+
pcnt = 100*val2/val
|
| 144 |
+
log_update(f'\t\t\t\t\t{key2}: \t{val2} ({pcnt:.2f}%)')
|
| 145 |
+
if k=='Verification_Set':
|
| 146 |
+
for key, val in verification_puncta_statuses.items():
|
| 147 |
+
pcnt = 100*val/v
|
| 148 |
+
log_update(f'\t\t\t{key}: \t{val} ({pcnt:.2f}%)')
|
| 149 |
+
if key=='YES':
|
| 150 |
+
log_update('\t\t\t\tLocalizations...')
|
| 151 |
+
for key2, val2 in verification_positive_locs.items():
|
| 152 |
+
pcnt = 100*val2/val
|
| 153 |
+
log_update(f'\t\t\t\t\t{key2}: \t{val2} ({pcnt:.2f}%)')
|
| 154 |
+
|
| 155 |
+
return df
|
| 156 |
+
|
| 157 |
+
def main():
|
| 158 |
+
LOG_PATH = 'cleaning_log.txt'
|
| 159 |
+
FODB_S4_PATH = '../../data/raw_data/FOdb_puncta.csv'
|
| 160 |
+
FODB_S5_PATH = '../../data/raw_data/FOdb_SD5.csv'
|
| 161 |
+
|
| 162 |
+
with open_logfile(LOG_PATH):
|
| 163 |
+
s4 = pd.read_csv(FODB_S4_PATH)
|
| 164 |
+
s5 = pd.read_csv(FODB_S5_PATH)
|
| 165 |
+
|
| 166 |
+
retained_features = clean_s5(s5)
|
| 167 |
+
cleaned_s4 = clean_s4(s4, retained_features)
|
| 168 |
+
|
| 169 |
+
label_df = make_label_df(cleaned_s4)
|
| 170 |
+
embeddings = make_embeddings(cleaned_s4, retained_features)
|
| 171 |
+
|
| 172 |
+
# save the results
|
| 173 |
+
cleaned_s4.to_csv('cleaned_dataset_s4.csv', index=False)
|
| 174 |
+
log_update("\nSaved cleaned table S5 to cleaned_dataset_s4.csv")
|
| 175 |
+
|
| 176 |
+
label_df.to_csv('splits.csv', index=False)
|
| 177 |
+
log_update("\nSaved train-test splits with nucleus, cytoplasm, and formation labels to splits.csv")
|
| 178 |
+
|
| 179 |
+
with open('FOdb_physicochemical_embeddings.pkl','wb') as f:
|
| 180 |
+
pickle.dump(embeddings, f)
|
| 181 |
+
log_update("\nSaved physicochemical embeddings as a dictionary to FOdb_physicochemical_embeddings.pkl")
|
| 182 |
+
|
| 183 |
+
if __name__ == '__main__':
|
| 184 |
+
main()
|
fuson_plm/benchmarking/puncta/cleaned_dataset_s4.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f9075866f3296746c83eac61caf5c871e3a6dd54a2986896c9fd71a5a11511c
|
| 3 |
+
size 183523
|
fuson_plm/benchmarking/puncta/cleaning_log.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22677b05ff483b17390edc30e53555097e85ce7ac6aaa3cd04aece67d3963bc1
|
| 3 |
+
size 3356
|
fuson_plm/benchmarking/puncta/config.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Benchmarking configs
|
| 2 |
+
BENCHMARK_FUSONPLM = True # True if you want to benchmark a FusOn-pLM Model
|
| 3 |
+
|
| 4 |
+
# FUSONPLM_CKPTS. If you've traiend your own model, this is a dictionary: key = run name, values = epochs
|
| 5 |
+
# If you want to use the trained FusOn-pLM, instead FUSONPLM_CKPTS="FusOn-pLM"
|
| 6 |
+
FUSONPLM_CKPTS= "FusOn-pLM"
|
| 7 |
+
|
| 8 |
+
# Model comparison configs
|
| 9 |
+
BENCHMARK_ESM = True # True if you want to benchmark ESM-2-650M
|
| 10 |
+
BENCHMARK_PROTT5 = True # True if you want to benchmark ProtT5
|
| 11 |
+
BENCHMARK_FO_PUNCTA_ML = True # True if you want to benchmark FO-Puncta-ML from the FOdb paper
|
| 12 |
+
|
| 13 |
+
# Overwriting configs
|
| 14 |
+
PERMISSION_TO_OVERWRITE = False # if False, script will halt if it believes these embeddings have already been made.
|
| 15 |
+
|
| 16 |
+
# GPU configs
|
| 17 |
+
CUDA_VISIBLE_DEVICES="0" # GPUs to make visible for this process
|
fuson_plm/benchmarking/puncta/plot.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import matplotlib.patches as mpatches
|
| 3 |
+
import seaborn as sns
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
import matplotlib.colors as mcolors
|
| 8 |
+
from fuson_plm.utils.visualizing import set_font
|
| 9 |
+
|
| 10 |
+
fo_puncta_db_training_thresh31 = pd.DataFrame(data={
|
| 11 |
+
'Model Type': ['fo_puncta_ml'],
|
| 12 |
+
'Model Name': ['fo_puncta_ml_literature'],
|
| 13 |
+
'Model Epoch': np.nan,
|
| 14 |
+
'Accuracy': 0.81,
|
| 15 |
+
'Precision': 0.78,
|
| 16 |
+
'Recall': 0.98,
|
| 17 |
+
'F1 Score': 0.87,
|
| 18 |
+
'AUROC': 0.88,
|
| 19 |
+
'AUPRC': 0.94
|
| 20 |
+
})
|
| 21 |
+
|
| 22 |
+
fo_puncta_db_verification_thresh83 = pd.DataFrame(data={
|
| 23 |
+
'Model Type': ['fo_puncta_ml'],
|
| 24 |
+
'Model Name': ['fo_puncta_ml_literature'],
|
| 25 |
+
'Model Epoch': np.nan,
|
| 26 |
+
'Accuracy': 0.79,
|
| 27 |
+
'Precision': 0.81,
|
| 28 |
+
'Recall': 0.89,
|
| 29 |
+
'F1 Score': 0.85,
|
| 30 |
+
'AUROC': 0.73,
|
| 31 |
+
'AUPRC': 0.82
|
| 32 |
+
})
|
| 33 |
+
|
| 34 |
+
# Method for lengthening the model name
|
| 35 |
+
def lengthen_model_name(row):
|
| 36 |
+
name = row['Model Name']
|
| 37 |
+
epoch = row['Model Epoch']
|
| 38 |
+
|
| 39 |
+
if 'esm' in name:
|
| 40 |
+
return name
|
| 41 |
+
if 'puncta' in name:
|
| 42 |
+
return name
|
| 43 |
+
|
| 44 |
+
return f'{name}_e{epoch}'
|
| 45 |
+
|
| 46 |
+
# Method for shortening the model name for display
|
| 47 |
+
def shorten_model_name(row):
|
| 48 |
+
name = row['Model Name']
|
| 49 |
+
epoch = row['Model Epoch']
|
| 50 |
+
|
| 51 |
+
if 'esm' in name:
|
| 52 |
+
return 'ESM-2-650M'
|
| 53 |
+
if name=='fo_puncta_ml':
|
| 54 |
+
return 'FO-Puncta-ML'
|
| 55 |
+
if name=='fo_puncta_ml_literature':
|
| 56 |
+
return 'FO-Puncta-ML Lit'
|
| 57 |
+
if name=="prot_t5_xl_half_uniref50_enc":
|
| 58 |
+
return 'ProtT5-XL-U50' # this is waht they call it in the paper
|
| 59 |
+
|
| 60 |
+
if 'snp_' in name:
|
| 61 |
+
prob_type = 'snp'
|
| 62 |
+
elif 'uniform_' in name:
|
| 63 |
+
prob_type = 'uni'
|
| 64 |
+
|
| 65 |
+
layers = name.split('layers')[0].split('_')[-1]
|
| 66 |
+
dt = name.split('mask')[1].split('-', 1)[1]
|
| 67 |
+
|
| 68 |
+
return f'{prob_type}_{layers}L_{dt}_e{epoch}'
|
| 69 |
+
|
| 70 |
+
def make_final_bar(dataframe, title, save_path):
|
| 71 |
+
set_font()
|
| 72 |
+
df = dataframe.copy(deep=True)
|
| 73 |
+
|
| 74 |
+
# Pivot the DataFrame to have metrics as rows and names as columns, and reorder columns
|
| 75 |
+
pivot_df = df.pivot(index='Metric', columns='Name', values='Value')
|
| 76 |
+
ordered_columns = [x for x in ['FOdb','ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM'] if x in pivot_df.columns]
|
| 77 |
+
pivot_df = pivot_df[ordered_columns]
|
| 78 |
+
|
| 79 |
+
# Define the groups
|
| 80 |
+
engineered_embeddings = ['FOdb']
|
| 81 |
+
deep_learning_embeddings = ['ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM']
|
| 82 |
+
|
| 83 |
+
# Reorder the metrics
|
| 84 |
+
metric_order = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC'][::-1]
|
| 85 |
+
pivot_df = pivot_df.reindex(metric_order)
|
| 86 |
+
|
| 87 |
+
# Plotting
|
| 88 |
+
fig, ax = plt.subplots(figsize=(8, 6), dpi=300) # Increased figure size for better legend placement
|
| 89 |
+
|
| 90 |
+
# Define bar width and positions
|
| 91 |
+
bar_width = 0.2
|
| 92 |
+
indices = np.arange(len(pivot_df))
|
| 93 |
+
|
| 94 |
+
# Use a colorblind-friendly color scheme from tableau
|
| 95 |
+
color_map = {
|
| 96 |
+
#'One-Hot': "#999999",
|
| 97 |
+
'FOdb': "#E69F00",
|
| 98 |
+
'ESM-2-650M': "#F0E442",
|
| 99 |
+
'FusOn-pLM': "#FF69B4",
|
| 100 |
+
'ProtT5-XL-U50': "#00ccff" # light blue
|
| 101 |
+
}
|
| 102 |
+
colors = [color_map[col] for col in ordered_columns]
|
| 103 |
+
|
| 104 |
+
# Plot bars for each category and add them to appropriate legend groups
|
| 105 |
+
engineered_handles = []
|
| 106 |
+
deep_learning_handles = []
|
| 107 |
+
for i, (name, color) in enumerate(zip(pivot_df.columns, colors)):
|
| 108 |
+
bars = ax.barh(indices + i * bar_width, pivot_df[name], bar_width, label=name, color=color)
|
| 109 |
+
if name in engineered_embeddings:
|
| 110 |
+
engineered_handles.append(bars[0])
|
| 111 |
+
else:
|
| 112 |
+
deep_learning_handles.append(bars[0])
|
| 113 |
+
|
| 114 |
+
# Add bold black asterisks next to the winning bars for each category (could be multiple)
|
| 115 |
+
#for j, metric in enumerate(pivot_df.index):
|
| 116 |
+
# max_value = pivot_df.loc[metric].max()
|
| 117 |
+
# max_indices = pivot_df.loc[metric][pivot_df.loc[metric] == max_value].index
|
| 118 |
+
# for max_name in max_indices:
|
| 119 |
+
# max_index = list(pivot_df.columns).index(max_name)
|
| 120 |
+
# ax.text(max_value + 0.01, j + max_index * bar_width - bar_width / 4, '*',
|
| 121 |
+
# color='black', fontsize=12, fontweight='bold', ha='center', va='center')
|
| 122 |
+
|
| 123 |
+
# Set labels, ticks, and title
|
| 124 |
+
plt.xlabel('Value', fontsize=44) # Adjusted font size
|
| 125 |
+
ax.set_yticks(indices + bar_width * 1.5)
|
| 126 |
+
ax.set_xlim([0, 1])
|
| 127 |
+
ax.set_yticklabels(pivot_df.index)
|
| 128 |
+
# make the xticklabels size 24
|
| 129 |
+
ax.tick_params(axis='x')
|
| 130 |
+
ax.set_title(title, fontsize=44) # Adjusted font size
|
| 131 |
+
|
| 132 |
+
# Setting font size for tick labels
|
| 133 |
+
for label in plt.gca().get_xticklabels():
|
| 134 |
+
label.set_fontsize(32) # Adjusted font size
|
| 135 |
+
for label in plt.gca().get_yticklabels():
|
| 136 |
+
label.set_fontsize(32) # Adjusted font size
|
| 137 |
+
|
| 138 |
+
# Create two separate legends
|
| 139 |
+
if engineered_handles:
|
| 140 |
+
legend1 = fig.legend(
|
| 141 |
+
engineered_handles[::-1],
|
| 142 |
+
[emb for emb in engineered_embeddings if emb in ordered_columns][::-1],
|
| 143 |
+
loc='center left',
|
| 144 |
+
bbox_to_anchor=(1, 0.4),
|
| 145 |
+
title="Engineered Embeddings",
|
| 146 |
+
title_fontsize=24) # Adjusted font size
|
| 147 |
+
if deep_learning_handles:
|
| 148 |
+
legend2 = fig.legend(
|
| 149 |
+
deep_learning_handles[::-1],
|
| 150 |
+
[emb for emb in deep_learning_embeddings if emb in ordered_columns][::-1],
|
| 151 |
+
loc='center left',
|
| 152 |
+
bbox_to_anchor=(1, 0.6),
|
| 153 |
+
title="Learned Embeddings",
|
| 154 |
+
title_fontsize=24) # Adjusted font size
|
| 155 |
+
|
| 156 |
+
# Adjust legend text size
|
| 157 |
+
if engineered_handles:
|
| 158 |
+
ax.add_artist(legend1)
|
| 159 |
+
for text in legend1.get_texts():
|
| 160 |
+
text.set_fontsize(22) # Adjusted font size
|
| 161 |
+
for handle in legend1.legendHandles:
|
| 162 |
+
if isinstance(handle, mpatches.Patch):
|
| 163 |
+
handle.set_height(15) # Adjust height
|
| 164 |
+
handle.set_width(20) # Adjust width
|
| 165 |
+
elif hasattr(handle, '_sizes'):
|
| 166 |
+
handle._sizes = [200] # Increase marker size in the legend
|
| 167 |
+
|
| 168 |
+
if deep_learning_handles:
|
| 169 |
+
ax.add_artist(legend2)
|
| 170 |
+
for text in legend2.get_texts():
|
| 171 |
+
text.set_fontsize(22) # Adjusted font size
|
| 172 |
+
for handle in legend2.legendHandles:
|
| 173 |
+
if isinstance(handle, mpatches.Patch):
|
| 174 |
+
handle.set_height(15) # Adjust height
|
| 175 |
+
handle.set_width(20) # Adjust width
|
| 176 |
+
elif hasattr(handle, '_sizes'):
|
| 177 |
+
handle._sizes = [200] # Increase marker size in the legend
|
| 178 |
+
|
| 179 |
+
plt.tight_layout() # Adjust layout to make room for the legends
|
| 180 |
+
|
| 181 |
+
# Save the plot to a file
|
| 182 |
+
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 183 |
+
|
| 184 |
+
plt.show()
|
| 185 |
+
|
| 186 |
+
def prepare_data_for_bar(results_dir, task, split, thresh=None):
|
| 187 |
+
fname = f"{task}_{split}FOs_results.csv"
|
| 188 |
+
if thresh is not None: fname = f"{task}_{split}FOs_{thresh}thresh_results.csv"
|
| 189 |
+
image_save_path = results_dir + '/figures/' + fname.split('_results.csv')[0]+'_barchart.png'
|
| 190 |
+
|
| 191 |
+
data = pd.read_csv(f"{results_dir}/{fname}")
|
| 192 |
+
data = data.loc[
|
| 193 |
+
data['Model Name'].isin(['best',
|
| 194 |
+
'fo_puncta_ml',
|
| 195 |
+
'esm2_t33_650M_UR50D',
|
| 196 |
+
'prot_t5_xl_half_uniref50_enc'])
|
| 197 |
+
]
|
| 198 |
+
data = pd.DataFrame(data = {
|
| 199 |
+
'Name': data['Model Name'].tolist() * 5,
|
| 200 |
+
'Metric': ['Accuracy', 'Accuracy', 'Accuracy','Accuracy',
|
| 201 |
+
'Precision', 'Precision', 'Precision', 'Precision',
|
| 202 |
+
'Recall', 'Recall', 'Recall', 'Recall',
|
| 203 |
+
'F1', 'F1', 'F1','F1',
|
| 204 |
+
'AUROC', 'AUROC', 'AUROC','AUROC'],
|
| 205 |
+
'Value': data['Accuracy'].tolist() + data['Precision'].tolist() + data['Recall'].tolist() + data['F1 Score'].tolist() + data['AUROC'].tolist()
|
| 206 |
+
}
|
| 207 |
+
)
|
| 208 |
+
rename_dict = {'fo_puncta_ml': 'FOdb',
|
| 209 |
+
'esm2_t33_650M_UR50D':'ESM-2-650M',
|
| 210 |
+
'best':'FusOn-pLM',
|
| 211 |
+
'prot_t5_xl_half_uniref50_enc': 'ProtT5-XL-U50'}
|
| 212 |
+
data['Name'] = data['Name'].map(rename_dict)
|
| 213 |
+
return data, image_save_path
|
| 214 |
+
|
| 215 |
+
def make_all_final_bar_charts(results_dir):
|
| 216 |
+
# Puncta verification
|
| 217 |
+
data, image_save_path = prepare_data_for_bar(results_dir,"formation","verification",thresh=0.83)
|
| 218 |
+
data_cp = data.copy(deep=True)
|
| 219 |
+
data_cp["Value"] = data_cp["Value"].round(3)
|
| 220 |
+
data_cp.to_csv(image_save_path.replace(".png","_source_data.csv"),index=False)
|
| 221 |
+
make_final_bar(data, "Puncta Propensity", image_save_path)
|
| 222 |
+
|
| 223 |
+
# Nucleus verification
|
| 224 |
+
data, image_save_path = prepare_data_for_bar(results_dir,"nucleus","verification",thresh=None)
|
| 225 |
+
data_cp = data.copy(deep=True)
|
| 226 |
+
data_cp["Value"] = data_cp["Value"].round(3)
|
| 227 |
+
data_cp.to_csv(image_save_path.replace(".png","_source_data.csv"),index=False)
|
| 228 |
+
make_final_bar(data, "Nucleus Localization", image_save_path)
|
| 229 |
+
|
| 230 |
+
# Cytoplasm verification
|
| 231 |
+
data, image_save_path = prepare_data_for_bar(results_dir,"cytoplasm","verification",thresh=None)
|
| 232 |
+
data_cp = data.copy(deep=True)
|
| 233 |
+
data_cp["Value"] = data_cp["Value"].round(3)
|
| 234 |
+
data_cp.to_csv(image_save_path.replace(".png","_source_data.csv"),index=False)
|
| 235 |
+
make_final_bar(data, "Cytoplasm Localization", image_save_path)
|
| 236 |
+
|
| 237 |
+
def main():
|
| 238 |
+
# Read in the input data
|
| 239 |
+
results_dir="results/final"
|
| 240 |
+
os.makedirs(f"{results_dir}/figures",exist_ok=True)
|
| 241 |
+
make_all_final_bar_charts(results_dir)
|
| 242 |
+
|
| 243 |
+
if __name__ == '__main__':
|
| 244 |
+
main()
|
fuson_plm/benchmarking/puncta/results/final/cytoplasm_verificationFOs_results.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:800f935f72b089b357fb4b0ac22a4c75a09a4578e44fac2c20a297c60c76df76
|
| 3 |
+
size 871
|
fuson_plm/benchmarking/puncta/results/final/figures/cytoplasm_verificationFOs_barchart.png
ADDED
|
fuson_plm/benchmarking/puncta/results/final/figures/cytoplasm_verificationFOs_barchart_source_data.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:06aa241a68bff40ae38cd6d484c4ff3ebf4d8613fb0e671576a3f07b6977dbda
|
| 3 |
+
size 470
|
fuson_plm/benchmarking/puncta/results/final/figures/formation_verificationFOs_0.83thresh_barchart.png
ADDED
|
fuson_plm/benchmarking/puncta/results/final/figures/formation_verificationFOs_0.83thresh_barchart_source_data.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:85ad0497edcb0438fafe20d2807afb694114bdc3a73401ca0ed6b739baca1603
|
| 3 |
+
size 472
|
fuson_plm/benchmarking/puncta/results/final/figures/nucleus_verificationFOs_barchart.png
ADDED
|
fuson_plm/benchmarking/puncta/results/final/figures/nucleus_verificationFOs_barchart_source_data.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:73697291e6f1d8036fd089babbde87e39c30d040e98a2c20d71dfb202925e316
|
| 3 |
+
size 472
|
fuson_plm/benchmarking/puncta/results/final/formation_verificationFOs_0.83thresh_results.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:72c68d45ca772a2bded7473803767c12dbafa4bac09bc10aed70a075c386682c
|
| 3 |
+
size 888
|
fuson_plm/benchmarking/puncta/results/final/nucleus_verificationFOs_results.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37d6fc7ec393c48756c286e01ddb942b8b98b03564f22a099d01e2bd537f33ca
|
| 3 |
+
size 887
|
fuson_plm/benchmarking/puncta/splits.csv
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:44f627efa4f76a35b2a4a83be77ae8815c7b728e3ca2ca5127d8177789127f7e
|
| 3 |
+
size 133807
|
fuson_plm/benchmarking/puncta/train.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import time
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pickle
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
from fuson_plm.benchmarking.xgboost_predictor import train_final_predictor, evaluate_predictor, train_predictor_xval
|
| 9 |
+
from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark
|
| 10 |
+
import fuson_plm.benchmarking.puncta.config as config
|
| 11 |
+
from fuson_plm.benchmarking.puncta.plot import make_all_final_bar_charts
|
| 12 |
+
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy, get_local_time, CustomParams
|
| 13 |
+
|
| 14 |
+
def check_splits(df):
|
| 15 |
+
# make sure everything has a split
|
| 16 |
+
if len(df.loc[df['split'].isna()])>0:
|
| 17 |
+
raise Exception("Error: not every benchmarking sequence has been allocated to a split (train or test)")
|
| 18 |
+
# make sure the only things are train and test
|
| 19 |
+
if len({'train','test'} - set(df['split'].unique()))!=0:
|
| 20 |
+
raise Exception("Error: splits column should only have \'train\' and \'test\'.")
|
| 21 |
+
# make sure there are no duplicate sequences
|
| 22 |
+
if len(df.loc[df['aa_seq'].duplicated()])>0:
|
| 23 |
+
raise Exception("Error: duplicate sequences provided")
|
| 24 |
+
|
| 25 |
+
def train_and_evaluate_puncta_predictor(details, splits_with_embeddings,outdir,task='nucleus',class1_thresh=0.5,n_estimators=50,tree_method="hist"):
|
| 26 |
+
"""
|
| 27 |
+
task = 'nucleus', 'cytoplasm', or 'formation'
|
| 28 |
+
"""
|
| 29 |
+
# unpack the details dictioanry
|
| 30 |
+
benchmark_model_type = details['model_type']
|
| 31 |
+
benchmark_model_name = details['model']
|
| 32 |
+
benchmark_model_epoch = details['epoch']
|
| 33 |
+
|
| 34 |
+
# prepare train and test sets for model
|
| 35 |
+
train_split = splits_with_embeddings.loc[splits_with_embeddings['split']=='train'].reset_index(drop=True)
|
| 36 |
+
test_split = splits_with_embeddings.loc[splits_with_embeddings['split']=='test'].reset_index(drop=True)
|
| 37 |
+
|
| 38 |
+
X_train = np.array(train_split['embedding'].tolist())
|
| 39 |
+
y_train = np.array(train_split[task].tolist())
|
| 40 |
+
X_test = np.array(test_split['embedding'].tolist())
|
| 41 |
+
y_test = np.array(test_split[task].tolist())
|
| 42 |
+
|
| 43 |
+
# Train the final model on all the data
|
| 44 |
+
clf = train_final_predictor(X_train, y_train, n_estimators=n_estimators, tree_method=tree_method)
|
| 45 |
+
|
| 46 |
+
# Evaluate it
|
| 47 |
+
automatic_stats_df, custom_stats_df = evaluate_predictor(clf, X_test, y_test, class1_thresh=class1_thresh)
|
| 48 |
+
|
| 49 |
+
# Add the model details back in
|
| 50 |
+
cols = list(automatic_stats_df.columns)
|
| 51 |
+
automatic_stats_df['Model Type'] = [benchmark_model_type]
|
| 52 |
+
automatic_stats_df['Model Name'] = [benchmark_model_name]
|
| 53 |
+
automatic_stats_df['Model Epoch'] = [benchmark_model_epoch]
|
| 54 |
+
newcols = ['Model Type','Model Name','Model Epoch'] + cols
|
| 55 |
+
automatic_stats_df = automatic_stats_df[newcols]
|
| 56 |
+
|
| 57 |
+
cols = list(custom_stats_df.columns)
|
| 58 |
+
custom_stats_df['Model Type'] = [benchmark_model_type]
|
| 59 |
+
custom_stats_df['Model Name'] = [benchmark_model_name]
|
| 60 |
+
custom_stats_df['Model Epoch'] = [benchmark_model_epoch]
|
| 61 |
+
newcols = ['Model Type','Model Name','Model Epoch'] + cols
|
| 62 |
+
custom_stats_df = custom_stats_df[newcols]
|
| 63 |
+
|
| 64 |
+
# Save automatic results (for nucleus and cytoplasm)
|
| 65 |
+
if task!="formation":
|
| 66 |
+
automatic_stats_path = f'{outdir}/{task}_verificationFOs_results.csv'
|
| 67 |
+
if not(os.path.exists(automatic_stats_path)):
|
| 68 |
+
automatic_stats_df.to_csv(automatic_stats_path,index=False)
|
| 69 |
+
else:
|
| 70 |
+
automatic_stats_df.to_csv(automatic_stats_path,mode='a',index=False,header=False)
|
| 71 |
+
|
| 72 |
+
# Save custom threshold results (only if it's formation)
|
| 73 |
+
if task=="formation":
|
| 74 |
+
custom_stats_path = f'{outdir}/{task}_verificationFOs_{class1_thresh}thresh_results.csv'
|
| 75 |
+
if not(os.path.exists(custom_stats_path)):
|
| 76 |
+
custom_stats_df.to_csv(custom_stats_path,index=False)
|
| 77 |
+
else:
|
| 78 |
+
custom_stats_df.to_csv(custom_stats_path,mode='a',index=False,header=False)
|
| 79 |
+
|
| 80 |
+
def main():
|
| 81 |
+
# make output directory for this run
|
| 82 |
+
os.makedirs('results',exist_ok=True)
|
| 83 |
+
output_dir = f'results/{get_local_time()}'
|
| 84 |
+
os.makedirs(output_dir,exist_ok=True)
|
| 85 |
+
|
| 86 |
+
with open_logfile(f'{output_dir}/puncta_benchmark_log.txt'):
|
| 87 |
+
# print configurations
|
| 88 |
+
print_configpy(config)
|
| 89 |
+
|
| 90 |
+
# Verify that the environment variables are set correctly
|
| 91 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES
|
| 92 |
+
log_update("\nChecking on environment variables...")
|
| 93 |
+
log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
|
| 94 |
+
|
| 95 |
+
# make embeddings if needed
|
| 96 |
+
all_embedding_paths = embed_dataset_for_benchmark(
|
| 97 |
+
fuson_ckpts=config.FUSONPLM_CKPTS,
|
| 98 |
+
input_data_path='splits.csv', input_fname='FOdb_puncta_sequences',
|
| 99 |
+
average=True, seq_col='aa_seq',
|
| 100 |
+
benchmark_fusonplm=config.BENCHMARK_FUSONPLM,
|
| 101 |
+
benchmark_esm=config.BENCHMARK_ESM,
|
| 102 |
+
benchmark_fo_puncta_ml=config.BENCHMARK_FO_PUNCTA_ML,
|
| 103 |
+
benchmark_prott5 = config.BENCHMARK_PROTT5,
|
| 104 |
+
overwrite=config.PERMISSION_TO_OVERWRITE)
|
| 105 |
+
|
| 106 |
+
# load the splits with labels
|
| 107 |
+
splits = pd.read_csv('splits.csv')
|
| 108 |
+
# perform some sanity checks on the splits
|
| 109 |
+
check_splits(splits)
|
| 110 |
+
n_train = len(splits.loc[splits['split']=='train'])
|
| 111 |
+
n_test = len(splits.loc[splits['split']=='test'])
|
| 112 |
+
log_update(f"\nSplit breakdown...\n\t{n_train} Training FOs\n\t{n_test} Verification FOs")
|
| 113 |
+
|
| 114 |
+
# set training constants
|
| 115 |
+
train_params = CustomParams(
|
| 116 |
+
N_ESTIMATORS = 50,
|
| 117 |
+
TREE_METHOD = "hist",
|
| 118 |
+
CLASS1_THRESHOLDS = {
|
| 119 |
+
'nucleus': 0.83,
|
| 120 |
+
'cytoplasm': 0.83,
|
| 121 |
+
'formation': 0.83
|
| 122 |
+
},
|
| 123 |
+
)
|
| 124 |
+
log_update("\nTraining configs:")
|
| 125 |
+
train_params.print_config(indent='\t')
|
| 126 |
+
|
| 127 |
+
log_update("\nTraining models")
|
| 128 |
+
# loop through the embedding paths and train each one
|
| 129 |
+
for embedding_path, details in all_embedding_paths.items():
|
| 130 |
+
log_update(f"\tBenchmarking embeddings at: {embedding_path}")
|
| 131 |
+
try:
|
| 132 |
+
with open(embedding_path, "rb") as f:
|
| 133 |
+
embeddings = pickle.load(f)
|
| 134 |
+
except:
|
| 135 |
+
raise Exception(f"Cannot read embeddings from {embedding_path}")
|
| 136 |
+
|
| 137 |
+
# combine the embeddings and splits into one dataframe
|
| 138 |
+
splits_with_embeddings = pd.DataFrame.from_dict(embeddings.items())
|
| 139 |
+
splits_with_embeddings = splits_with_embeddings.rename(columns={0: 'aa_seq', 1: 'embedding'})
|
| 140 |
+
splits_with_embeddings = pd.merge(splits_with_embeddings, splits, on='aa_seq',how='left')
|
| 141 |
+
|
| 142 |
+
for task in ['nucleus','cytoplasm','formation']:
|
| 143 |
+
log_update(f"\t\tTask: {task}")
|
| 144 |
+
train_and_evaluate_puncta_predictor(details, splits_with_embeddings, output_dir, task=task,
|
| 145 |
+
class1_thresh=train_params.CLASS1_THRESHOLDS[task],
|
| 146 |
+
n_estimators=train_params.N_ESTIMATORS,tree_method=train_params.TREE_METHOD)
|
| 147 |
+
|
| 148 |
+
log_update(f"\nMaking summary figures:\n")
|
| 149 |
+
log_update(f"\tbar charts...")
|
| 150 |
+
os.makedirs(f"{output_dir}/figures",exist_ok=True)
|
| 151 |
+
make_all_final_bar_charts(output_dir)
|
| 152 |
+
log_update(f"\tDone.")
|
| 153 |
+
|
| 154 |
+
if __name__ == '__main__':
|
| 155 |
+
main()
|