import argparse
from pathlib import Path

import numpy as np
import pandas as pd
from joblib import dump
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

import utils

SEED = 420
np.random.seed(SEED)


def read_dataset(splicing_analysis_csv_path, filter_cryptic_restriction_site=True):
    barcode_statistics = pd.read_csv(splicing_analysis_csv_path).set_index("barcode")
    # Remove badly coupled barcodes
    barcode_statistics = barcode_statistics[barcode_statistics.badly_coupled == False]
    # Filter barcodes containing a restriction site, as those contain artifacts
    if filter_cryptic_restriction_site:
        contains_restriction_site = barcode_statistics.apply(
            lambda x: utils.contains_Esp3I_site(utils.add_flanking(x.exon, 5))
            or utils.contains_Esp3I_site(utils.add_barcode_flanking(x.name, 5)),
            axis=1,
        )
        barcode_statistics = barcode_statistics[~contains_restriction_site]
    return barcode_statistics


def to_input_data(df, flanking_length=10):
    assert 0 <= flanking_length <= 30
    return utils.create_input_data(
        [utils.add_flanking(exon, flanking_length) for exon in df.exon]
    )


def to_target_data(df):
    # Inclusion ratio: exon-inclusion reads over inclusion + skipping reads
    return np.array(
        df.num_exon_inclusion / (df.num_exon_inclusion + df.num_exon_skipping)
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_folder", required=True, type=str, help="Input folder")
    args = parser.parse_args()

    data_files = [
        "BS11504A_S1_splicing_analysis.csv",
        "BS11505A_S2_splicing_analysis.csv",
        "BS11506A_S3_splicing_analysis.csv",
    ]
    data_folder = args.input_folder
    splicing_analysis_csvs = [
        a
        for b in [
            list(Path(data_folder).rglob(f"*{data_file}")) for data_file in data_files
        ]
        for a in b
    ]

    print('Reading datasets... ', end='')
    datasets = [read_dataset(d) for d in splicing_analysis_csvs]
    print('Done.')

    # Sum the read counts across replicates; take the non-numeric columns
    # (e.g. exon sequences) from the first replicate
    numeric_columns = np.unique(
        [e for d in datasets for e in d.columns.values if "num" in e]
    )
    non_numeric_columns = np.unique(
        [e for d in datasets for e in d.columns.values if e not in numeric_columns]
    )
    d_numeric = sum(d[numeric_columns] for d in datasets)
    dataset = datasets[0][non_numeric_columns].join(d_numeric).dropna()

    # Add summary statistics to the dataset
    dataset["others"] = (
        dataset.num_unknown_splicing
        + dataset.num_intron_retention
        + dataset.num_bad_reads
        + dataset.num_bad_exon1
    )
    dataset["total"] = (
        dataset.others
        + dataset.num_exon_skipping
        + dataset.num_exon_inclusion
        + dataset.num_splicing_in_exon
    )

    # Filter exons with too few reads
    MIN_READS = 60
    dataset = dataset[
        dataset.num_exon_skipping + dataset.num_exon_inclusion >= MIN_READS
    ]
    # Also require inclusion and skipping to make up at least 80% of the total
    # reads; this gets rid of splice sites inside the exon
    dataset = dataset[
        (dataset.num_exon_inclusion + dataset.num_exon_skipping) / dataset.total > 0.8
    ]

    # Split the dataset into train and test sets
    TEST_SPLIT_FRACTION = 0.2
    dataset_tr, dataset_te = train_test_split(
        dataset,
        test_size=TEST_SPLIT_FRACTION,
        train_size=1 - TEST_SPLIT_FRACTION,
        random_state=SEED,
    )

    # Create model inputs and targets
    print('Computing structure, one-hot-encoding... ', end='')
    xTr = to_input_data(dataset_tr)
    yTr = to_target_data(dataset_tr)
    xTe = to_input_data(dataset_te)
    yTe = to_target_data(dataset_te)
    print('Done.')

    data_dump_list = [xTr, yTr, xTe, yTe, dataset_tr, dataset_te]
    dataset_names = [
        "xTr",
        "yTr",
        "xTe",
        "yTe",
        "barcode_statistics_train",
        "barcode_statistics_test",
    ]
    print('Dumping preprocessed data to disk... ', end='')
    for D, Dn in tqdm(
        zip(data_dump_list, dataset_names), leave=False, total=len(data_dump_list)
    ):
        dump(D, Path(data_folder) / f"{Dn}_ES7_HeLa_ABC.pkl.gz")
    print('Done.')
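# ---------------------------------------------------------------------------
# Sketch (assumption, not part of this pipeline): the dumped splits can be
# reloaded for downstream training with joblib.load. The file names follow the
# f"{Dn}_ES7_HeLa_ABC.pkl.gz" pattern used above; the folder path below is a
# hypothetical placeholder.
#
#     from pathlib import Path
#     from joblib import load
#
#     data_folder = Path("path/to/input_folder")  # hypothetical location
#     xTr = load(data_folder / "xTr_ES7_HeLa_ABC.pkl.gz")
#     yTr = load(data_folder / "yTr_ES7_HeLa_ABC.pkl.gz")
#     xTe = load(data_folder / "xTe_ES7_HeLa_ABC.pkl.gz")
#     yTe = load(data_folder / "yTe_ES7_HeLa_ABC.pkl.gz")
# ---------------------------------------------------------------------------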