File size: 5,586 Bytes
8081b08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import argparse
import multiprocessing as mp
import os
import sys
import traceback

import loompy
import numpy as np
import pandas as pd
import scanpy as sc
import scrublet
from loguru import logger


def get_filename(file_path: str = __file__) -> str:
    """Return the log file name for a script: its file stem plus ``.log``.

    Only the final extension is stripped, so ``/x/prep.v2.py`` yields
    ``prep.v2.log`` rather than being truncated at the first dot.

    :param file_path: path of the script; defaults to this module's ``__file__``.
    :returns: ``"<stem>.log"``.
    """
    stem = os.path.splitext(os.path.basename(file_path))[0]
    return f"{stem}.log"


def set_log() -> None:
    """Attach a loguru file sink at ``../log/<name from get_filename()>``.

    The path is relative, so it is resolved against the current working
    directory rather than this file's location.
    """
    logger.add('../log/' + get_filename())


set_log()


gene_info_table_path = '/scDifformer/data/240412_test/material/gene_info_table.csv'
out_path = '/scDifformer/data/240412_test/looms'
home_dir_list = ['/scDifformer/data/240412_test/raw_h5ad']

gene_info = pd.read_csv(gene_info_table_path)

# Convert the gene info table into lookup dictionaries.
_genes_by_name = gene_info.set_index('gene_name')
gene_name_id_combine_dict = _genes_by_name['ensembl_id'].to_dict()  # gene_name -> ensembl_id
gene_name_type_dict = _genes_by_name['gene_type'].to_dict()  # gene_name -> gene_type
gene_id_type_dict = gene_info.set_index('ensembl_id')['gene_type'].to_dict()  # ensembl_id -> gene_type
# Ensembl IDs of the genes kept downstream: protein-coding or miRNA, in table order.
func_gene_list = gene_info.loc[
    gene_info["gene_type"].isin(["protein_coding", "miRNA"]), "ensembl_id"
].tolist()


def _within_n_sd(values, n_sd: float = 3.0):
    """Return a boolean mask of entries within ``n_sd`` population standard deviations of the mean.

    ``values`` is an ``adata.obs`` column (a pandas Series), so the mask keeps the
    obs index and can slice an AnnData directly.
    """
    mean = np.mean(values)
    std = np.std(values)
    return (values >= mean - n_sd * std) & (values <= mean + n_sd * std)


def preprocess(file_path: str) -> None:
    """Quality-filter one ``.h5ad`` file and write the result as a ``.loom`` file.

    Steps:
      1. Keep cells whose mitochondrial percentage lies within 3 s.d. of the mean,
         then (on those cells) whose total counts lie within 3 s.d. of the mean.
      2. Drop cells Scrublet predicts to be doublets.
      3. Keep only protein-coding / miRNA genes (per the gene info table), then drop
         cells with fewer than 7 detected genes among them.
      4. Write ``<out_path>/<file stem>.loom`` with per-gene ``ensembl_id`` and
         ``gene_type`` attributes and a per-cell ``n_counts`` attribute.

    Any exception is logged together with its traceback and swallowed, so one bad
    file does not abort the whole worker pool.

    :param file_path: path to the input ``.h5ad`` file.
    """
    try:
        # splitext strips only the final extension, so "a.v1.h5ad" and "a.v2.h5ad"
        # get distinct outputs instead of both writing "a.loom".
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        adata = sc.read_h5ad(file_path)
        adata.raw = None
        adata.X = np.round(adata.X, decimals=1)

        # 1. QC filters. Genes whose names start with "MT-" are annotated as mitochondrial.
        adata.var['mt'] = adata.var_names.str.startswith('MT-')
        sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
        adata = adata[_within_n_sd(adata.obs["pct_counts_mt"]), :]
        # Total-count statistics are computed on the cells that survived the mito filter.
        adata = adata[_within_n_sd(adata.obs["total_counts"])]

        # 2. Doublet detection with Scrublet.
        scrub = scrublet.Scrublet(adata.X)
        _, predicted_doublets = scrub.scrub_doublets()
        # predicted_doublets may be None (presumably when Scrublet cannot pick a
        # threshold); in that case no cells are removed.
        if predicted_doublets is not None:
            adata = adata[~predicted_doublets]

        # 3. Keep genes that are protein-coding or miRNA, either listed directly by
        #    Ensembl ID or via their gene name. dict.fromkeys dedupes while keeping
        #    var order, so the output gene order is deterministic (a set's order
        #    varies between runs due to str hash randomisation).
        func_gene_set = set(func_gene_list)
        retained_genes = list(dict.fromkeys(
            name for name in adata.var_names
            if name in func_gene_set
            or (name in gene_name_id_combine_dict and gene_name_id_combine_dict[name] in func_gene_set)
        ))
        # .copy() materialises the chain of views so filter_cells edits a real AnnData.
        adata = adata[:, retained_genes].copy()
        sc.pp.filter_cells(adata, min_genes=7)

        # Per-gene annotations: gene names go through the name tables; ENSG IDs
        # go through the ID table.
        gene_type_arr = []
        ensembl_id_arr = []
        for var_name in adata.var_names:
            if var_name in gene_name_type_dict:
                gene_type_arr.append(gene_name_type_dict[var_name])
                ensembl_id_arr.append(gene_name_id_combine_dict[var_name])
            elif str(var_name).startswith("ENSG"):
                gene_type_arr.append(gene_id_type_dict[var_name])
                ensembl_id_arr.append(var_name)
            else:
                raise Exception(f"can not find {var_name}")

        # Per-cell total counts. A sparse X returns an (n, 1) matrix from sum(axis=1);
        # asarray(...).ravel() flattens it (a dense X is already 1-D).
        n_counts = np.asarray(adata.X.sum(axis=1)).ravel()
        filename = os.path.join(out_path, f'{file_name}.loom')
        loompy.create(filename, adata.X.T,
                      {"ensembl_id": np.array(ensembl_id_arr), "gene_type": np.array(gene_type_arr)},
                      {"n_counts": n_counts})
        logger.info(f"{filename} is done")
        del adata
    except Exception:
        # logger.exception records the active traceback. traceback.print_exc() writes
        # to stderr and returns None, so embedding it in a log message logs "None".
        logger.exception(f"====={file_path}")


def argv_to_args(argv=None):
    """Parse the command-line options of this script.

    :param argv: full argument vector including the program name (same shape as
        ``sys.argv``); its first element is skipped. Defaults to ``sys.argv``.
    :returns: ``argparse.Namespace`` with ``start`` (required int) and ``end``
        (int, default 1).
    :raises SystemExit: on invalid arguments. argparse prints the usage and the
        error to stderr and exits with status 2.
    """
    if argv is None:
        argv = sys.argv
    parser = argparse.ArgumentParser()

    # start / end of the iteration range
    parser.add_argument("--start", '-s', type=int, help="start num for iter", required=True)
    parser.add_argument("--end", '-e', type=int, help="end num for iter", default=1)

    # No try/except here: argparse reports bad input via SystemExit, which is a
    # BaseException, so an `except Exception` handler could never run.
    return parser.parse_args(argv[1:])


# File names (as returned by os.listdir, not full paths) to skip in every home dir.
no_need_handle_set = set()

if __name__ == '__main__':
    for home_path in home_dir_list:
        # Every *.h5ad entry in this directory that is not explicitly excluded.
        files = [
            os.path.join(home_path, entry)
            for entry in os.listdir(home_path)
            if entry.endswith(".h5ad") and entry not in no_need_handle_set
        ]
        # Log message reads: "number of files handled in this run".
        logger.info(f'本次处理的文件个数: {len(files)}')
        # Fan out across 32 worker processes; preprocess logs its own failures.
        with mp.Pool(processes=32) as pool:
            pool.map(preprocess, files)