root
data cleaning, blast, and splitting code with source data, also deleting unnecessary files
6efd653
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import entropy
from sklearn.manifold import TSNE
import pickle
import pandas as pd
import os

from fuson_plm.utils.logging import log_update
from fuson_plm.utils.visualizing import set_font, visualize_splits


def main():
    """Visualize the train/val/test cluster splits and replace raw sequences
    with sequence IDs in the CSVs emitted by ``visualize_splits``.

    Reads the three cluster-split CSVs and ``fuson_db.csv`` from disk, flags
    every cluster that contains a benchmark sequence, renders the split
    visualizations, then rewrites each CSV under ``splits/split_vis`` that has
    a ``sequence`` column so it carries ``seq_id`` instead of the raw
    amino-acid sequence.
    """
    set_font()

    # Load the three cluster splits; concatenate them so benchmark membership
    # can be checked across all splits at once.
    train_clusters = pd.read_csv('splits/train_cluster_split.csv')
    val_clusters = pd.read_csv('splits/val_cluster_split.csv')
    test_clusters = pd.read_csv('splits/test_cluster_split.csv')
    clusters = pd.concat([train_clusters, val_clusters, test_clusters])

    fuson_db = pd.read_csv('fuson_db.csv')

    # Get the sequence IDs of all clustered benchmark sequences
    # (rows whose 'benchmark' column is non-null).
    benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna()]['seq_id']
    # Use benchmark_seq_ids to find which clusters contain benchmark sequences.
    benchmark_cluster_reps = clusters.loc[
        clusters['member seq_id'].isin(benchmark_seq_ids)
    ]['representative seq_id'].unique().tolist()

    visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)

    ## Add seq_id to every source data file that is saved from visualize_splits
    seq_to_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))
    # str.endswith is clearer and safer than the x[-4::] == ".csv" slice check
    # (which also misbehaves on names shorter than 4 characters).
    files_to_edit = [x for x in os.listdir("splits/split_vis") if x.endswith(".csv")]
    log_update(f"Adding seq_ids to the following files: {files_to_edit}")
    for fname in files_to_edit:
        source_data_file = pd.read_csv(f"splits/split_vis/{fname}")
        # Only rewrite files that actually have a raw-sequence column; calling
        # drop(columns=['sequence']) on any other file would raise KeyError.
        if "sequence" in list(source_data_file.columns):
            # Map each raw sequence to its seq_id; sequences absent from
            # fuson_db map to NaN (presumably all sequences are present —
            # TODO confirm against fuson_db coverage).
            source_data_file["seq_id"] = source_data_file["sequence"].map(seq_to_id_dict)
            source_data_file.drop(columns=['sequence']).to_csv(
                f"splits/split_vis/{fname}", index=False
            )


if __name__ == "__main__":
    main()