# Preprocess ES7 HeLa ABC splicing-analysis replicates into
# model-ready train/test data (features, targets, barcode statistics).
import argparse
import numpy as np
import pandas as pd
from joblib import dump, load
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import utils
from pathlib import Path
# Global RNG seed so the train/test split (and any other NumPy-based
# stochastic step) is reproducible across runs.
SEED = 420
np.random.seed(SEED)
def read_dataset(
    splicing_analysis_csv_path,
    filter_cryptic_restriction_site=True,
    flanking_length=5,
):
    """Load a splicing-analysis CSV into a barcode-indexed DataFrame.

    Parameters
    ----------
    splicing_analysis_csv_path : path or buffer accepted by ``pd.read_csv``.
        Must contain a "barcode" column and a boolean "badly_coupled"
        column; an "exon" column is required when filtering is enabled.
    filter_cryptic_restriction_site : bool
        When True, drop barcodes whose flanked exon or flanked barcode
        contains an Esp3I restriction site (those reads are artifacts).
    flanking_length : int
        Number of flanking bases added before the Esp3I-site scan.
        Default 5 preserves the previously hard-coded behavior.

    Returns
    -------
    pd.DataFrame
        Barcode-indexed statistics with badly coupled (and, optionally,
        Esp3I-site-containing) barcodes removed.
    """
    barcode_statistics = pd.read_csv(splicing_analysis_csv_path).set_index("barcode")
    # Explicit `== False` (rather than `~`) also drops NaN entries.
    barcode_statistics = barcode_statistics[
        barcode_statistics.badly_coupled == False
    ]  # remove badly coupled barcodes
    # Filter barcodes containing restriction site, as those contain artifacts
    if filter_cryptic_restriction_site:
        contains_restriction_site = barcode_statistics.apply(
            lambda x: utils.contains_Esp3I_site(
                utils.add_flanking(x.exon, flanking_length)
            )
            or utils.contains_Esp3I_site(
                utils.add_barcode_flanking(x.name, flanking_length)
            ),
            axis=1,
        )
        barcode_statistics = barcode_statistics[~contains_restriction_site]
    return barcode_statistics
def to_input_data(df, flanking_length=10):
    """Encode exon sequences (with flanking context) as model input.

    Parameters
    ----------
    df : pd.DataFrame
        Must have an "exon" column of sequence strings.
    flanking_length : int in [0, 30]
        Number of flanking bases added to each exon before encoding.

    Returns
    -------
    Whatever ``utils.create_input_data`` produces for the flanked exons.

    Raises
    ------
    ValueError
        If flanking_length is outside [0, 30].
    """
    # Validate with a real exception instead of `assert`, which is
    # silently stripped when Python runs with -O.
    if not 0 <= flanking_length <= 30:
        raise ValueError(
            f"flanking_length must be in [0, 30], got {flanking_length}"
        )
    return utils.create_input_data(
        [utils.add_flanking(exon, flanking_length) for exon in df.exon]
    )
def to_target_data(df):
    """Return the per-barcode exon-inclusion fraction as a NumPy array.

    Computed as inclusion / (inclusion + skipping) from the read counts.
    """
    inclusion = df.num_exon_inclusion
    skipping = df.num_exon_skipping
    fraction = inclusion / (inclusion + skipping)
    return fraction.to_numpy()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input_folder", required=True, type=str, help="Input folder")
args = parser.parse_args()
data_files = [
"BS11504A_S1_splicing_analysis.csv",
"BS11505A_S2_splicing_analysis.csv",
"BS11506A_S3_splicing_analysis.csv",
]
data_folder = args.input_folder
splicing_analysis_csvs = [
a
for b in [
list(Path(data_folder).rglob(f"*{data_file}")) for data_file in data_files
]
for a in b
]
print('Reading datasets... ', end='')
datasets = [read_dataset(d) for d in splicing_analysis_csvs]
print('Done.')
numeric_columns = np.unique(
[e for d in datasets for e in d.columns.values if "num" in e]
)
non_numeric_columns = np.unique(
[e for d in datasets for e in d.columns.values if e not in numeric_columns]
)
d_numeric = sum([d[numeric_columns] for d in datasets])
dataset = (datasets[0][non_numeric_columns]).join(d_numeric).dropna()
# add statistics to dataset
dataset["others"] = (
dataset.num_unknown_splicing
+ dataset.num_intron_retention
+ dataset.num_bad_reads
+ dataset.num_bad_exon1
)
dataset["total"] = (
dataset.others
+ dataset.num_exon_skipping
+ dataset.num_exon_inclusion
+ dataset.num_splicing_in_exon
)
# filter exons with too few reads
MIN_READS = 60
dataset = dataset[
dataset.num_exon_skipping + dataset.num_exon_inclusion >= MIN_READS
]
# Also, we want inclusion and skipping to be at least 80% of the total reads;
# this gets rid of splice sites inside exon
dataset = dataset[
(dataset.num_exon_inclusion + dataset.num_exon_skipping) / dataset.total > 0.8
]
# split dataset
TEST_SPLIT_FRACTION = 0.2
dataset_tr, dataset_te = train_test_split(
dataset,
test_size=TEST_SPLIT_FRACTION,
train_size=1 - TEST_SPLIT_FRACTION,
random_state=SEED,
)
# create datasets
print('Computing structure, one-hot-encoding... ', end='')
xTr = to_input_data(dataset_tr)
yTr = to_target_data(dataset_tr)
xTe = to_input_data(dataset_te)
yTe = to_target_data(dataset_te)
print('Done.')
data_dump_list = [xTr, yTr, xTe, yTe, dataset_tr, dataset_te]
dataset_names = [
"xTr",
"yTr",
"xTe",
"yTe",
"barcode_statistics_train",
"barcode_statistics_test",
]
print('Dumping preprocessed data to disk... ', end='')
for D, Dn in tqdm(
zip(data_dump_list, dataset_names), leave=False, total=len(data_dump_list)
):
dump(D, Path(data_folder) / f"{Dn}_ES7_HeLa_ABC.pkl.gz")
print('Done.')