"""Convert raw ``.h5ad`` single-cell files into filtered ``.loom`` files.

For every ``.h5ad`` file found in ``home_dir_list`` this script:

1. drops cells whose mitochondrial-count percentage or total count lies
   outside mean +/- 3 standard deviations,
2. removes doublets predicted by Scrublet,
3. keeps only protein-coding / miRNA genes (per ``gene_info_table_path``)
   and drops cells with fewer than seven such genes detected,
4. writes the result to ``out_path`` as a loom file with ``ensembl_id`` /
   ``gene_type`` row attributes and an ``n_counts`` column attribute.

Files are processed in parallel with a 32-process multiprocessing pool.
"""
import argparse
import multiprocessing as mp
import os
import sys
import traceback

import loompy
import numpy as np
import pandas as pd
import scanpy as sc
import scrublet
from loguru import logger
from scipy import sparse


def get_filename() -> str:
    """Return the log file name: this script's name up to its first '.', plus '.log'."""
    current_file_name = os.path.basename(__file__)
    log_name = "{}.log"
    return log_name.format(current_file_name.split('.')[0])


def set_log() -> None:
    """Add a loguru file sink at ``../log/<script>.log`` (relative to the CWD)."""
    filename = get_filename()
    logger.add(f'../log/{filename}')


set_log()

gene_info_table_path = '/scDifformer/data/240412_test/material/gene_info_table.csv'
out_path = '/scDifformer/data/240412_test/looms'
home_dir_list = ['/scDifformer/data/240412_test/raw_h5ad']

gene_info = pd.read_csv(gene_info_table_path)

# Lookup dictionaries built from the gene info table.
gene_name_id_combine_dict = gene_info.set_index('gene_name')['ensembl_id'].to_dict()
gene_name_type_dict = gene_info.set_index('gene_name')['gene_type'].to_dict()
gene_id_type_dict = gene_info.set_index('ensembl_id')['gene_type'].to_dict()

# Ensembl IDs of the gene types kept in the output: protein-coding and miRNA.
func_gene_list = list(
    gene_info.loc[gene_info["gene_type"].isin(["protein_coding", "miRNA"]), "ensembl_id"]
)
# Built once per process instead of once per processed file.
func_gene_set = set(func_gene_list)


def _within_n_sd(values: pd.Series, n_sd: float = 3) -> pd.Series:
    """Return a boolean mask of entries within mean +/- ``n_sd`` population s.d. (inclusive)."""
    mean = np.mean(values)
    std = np.std(values)
    return (values >= mean - n_sd * std) & (values <= mean + n_sd * std)


def _is_functional_gene(gene: str) -> bool:
    """Return True if *gene* (an Ensembl ID or a gene name) maps to a protein-coding/miRNA gene."""
    if gene in func_gene_set:
        return True
    return gene in gene_name_id_combine_dict and gene_name_id_combine_dict[gene] in func_gene_set


def _annotate_genes(var_names) -> tuple[list, list]:
    """Return parallel lists ``(ensembl_ids, gene_types)`` for *var_names*.

    A var name is resolved as a gene name first, then as an ``ENSG`` Ensembl ID.

    Raises:
        Exception: if a var name is neither a known gene name nor starts with "ENSG".
        KeyError: if an ``ENSG`` ID is absent from the gene info table.
    """
    ensembl_ids = []
    gene_types = []
    for var_info in var_names:
        if var_info in gene_name_type_dict:
            gene_types.append(gene_name_type_dict[var_info])
            ensembl_ids.append(gene_name_id_combine_dict[var_info])
        elif str(var_info).startswith("ENSG"):
            gene_types.append(gene_id_type_dict[var_info])
            ensembl_ids.append(var_info)
        else:
            raise Exception(f"can not find {var_info}")
    return ensembl_ids, gene_types


def preprocess(file_path: str) -> None:
    """Filter one ``.h5ad`` file and write it to ``out_path`` as ``<stem>.loom``.

    Any exception is logged together with its traceback and not re-raised,
    so one bad file does not abort the whole pool.
    """
    try:
        # splitext keeps dotted stems intact ("a.v2.h5ad" -> "a.v2") so inputs
        # sharing a prefix do not overwrite each other's output.
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        adata = sc.read_h5ad(file_path)
        adata.raw = None
        # Round expression values to one decimal; for sparse X only the stored
        # values are rounded (values that round to 0 are then dropped).
        if sparse.issparse(adata.X):
            adata.X.data = np.round(adata.X.data, decimals=1)
            adata.X.eliminate_zeros()
        else:
            adata.X = np.round(adata.X, decimals=1)

        # 1. Drop outlier cells: mitochondrial-count percentage, then total
        #    counts, must lie within 3 s.d. of the mean.
        # NOTE(review): the 'MT-' prefix only matches gene symbols; if var_names
        # are Ensembl IDs no gene is flagged as mitochondrial — confirm input.
        adata.var['mt'] = adata.var_names.str.startswith('MT-')
        sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
        adata = adata[_within_n_sd(adata.obs["pct_counts_mt"]), :]
        # Mean/s.d. of total counts are taken over the cells that survived the mt filter.
        adata = adata[_within_n_sd(adata.obs["total_counts"])]

        # 2. Doublet detection with Scrublet.
        scrub = scrublet.Scrublet(adata.X)
        _, predicted_doublets = scrub.scrub_doublets()
        # Predictions may come back as None; in that case no cells are removed.
        if predicted_doublets is not None:
            adata = adata[~predicted_doublets]

        # 3. Keep protein-coding / miRNA genes (original var order, deduplicated,
        #    so the output gene order is deterministic), then exclude cells with
        #    fewer than seven such genes detected.
        retained_genes = list(dict.fromkeys(g for g in adata.var_names if _is_functional_gene(g)))
        # copy() materialises the view so filter_cells below modifies a real AnnData.
        adata = adata[:, retained_genes].copy()
        sc.pp.filter_cells(adata, min_genes=7)

        ensembl_id_arr, gene_type_arr = _annotate_genes(adata.var_names)

        # For sparse X, sum(axis=1) is an (n, 1) np.matrix; flatten to a 1-D
        # per-cell vector as required for a loom column attribute.
        counts_per_cell = np.asarray(adata.X.sum(axis=1)).ravel()
        filename = f'{out_path}/{file_name}.loom'
        # Loom layout is genes x cells, hence the transpose.
        loompy.create(filename, adata.X.T,
                      {"ensembl_id": np.array(ensembl_id_arr), "gene_type": gene_type_arr},
                      {"n_counts": counts_per_cell})
        logger.info(f"{filename} is done")
    except Exception:
        logger.error(f"====={file_path}")
        # format_exc() returns the traceback text; print_exc() would print to
        # stderr and return None, which is what would end up in the log.
        logger.error(traceback.format_exc())


def argv_to_args(argv):
    """Parse ``--start``/``-s`` (required int) and ``--end``/``-e`` (int, default 1).

    ``argv[0]`` (the program name) is skipped.
    NOTE(review): this function is not called in the visible code.
    """
    parser = argparse.ArgumentParser()
    # start
    parser.add_argument("--start", '-s', type=int, help="start num for iter", required=True)
    parser.add_argument("--end", '-e', type=int, help="end num for iter", default=1)
    try:
        args = parser.parse_args(argv[1:])
    except Exception as e:
        # argparse reports bad arguments by raising SystemExit itself (not an
        # Exception subclass), so this branch only sees other errors.
        logger.error(e)
        parser.print_help()
        sys.exit(1)
    return args


# Bare file names (not paths) of .h5ad files to skip.
no_need_handle_set = set()

if __name__ == '__main__':
    # Make sure the loom output directory exists before the workers write to it.
    os.makedirs(out_path, exist_ok=True)
    for home_dir in home_dir_list:
        files = [os.path.join(home_dir, file_name)
                 for file_name in sorted(os.listdir(home_dir))
                 if str(file_name).endswith(".h5ad") and file_name not in no_need_handle_set]
        logger.info(f'本次处理的文件个数: {len(files)}')
        with mp.Pool(processes=32) as pool:
            pool.map(preprocess, files)