# Provenance (copied from the hosting page): uploaded by allenxiao,
# commit "Upload 17 files" (8081b08, verified).
import argparse
import multiprocessing as mp
import os
import sys
import traceback
import loompy
import numpy as np
import pandas as pd
import scanpy as sc
import scrublet
from loguru import logger
def get_filename(path: str = __file__) -> str:
    """Return the log-file name for a script: its base name with the extension replaced by ``.log``.

    Only the final extension is stripped, so ``my.script.py`` maps to ``my.script.log``.

    :param path: path of the script; defaults to this module's own file.
    :returns: the bare file name (no directory) of the log file.
    """
    stem, _extension = os.path.splitext(os.path.basename(path))
    return f"{stem}.log"
def set_log() -> None:
    """Register a loguru file sink at ``../log/<script>.log``.

    The path is relative, so it is resolved against the current working
    directory of the process, not against the directory of this file.
    """
    log_file = '../log/' + get_filename()
    logger.add(log_file)
# Configure file logging as soon as the module is loaded.
set_log()

# Hard-coded locations for this run: gene annotation table, output loom
# directory, and the directories scanned for input .h5ad files.
gene_info_table_path = '/scDifformer/data/240412_test/material/gene_info_table.csv'
out_path = '/scDifformer/data/240412_test/looms'
home_dir_list = ['/scDifformer/data/240412_test/raw_h5ad']

gene_info = pd.read_csv(gene_info_table_path)

# Lookup dictionaries built from the annotation table:
#   gene_name_id_combine_dict: gene name  -> Ensembl ID
#   gene_name_type_dict:       gene name  -> gene type
#   gene_id_type_dict:         Ensembl ID -> gene type
gene_name_id_combine_dict, gene_name_type_dict = (
    gene_info.set_index('gene_name')[column].to_dict()
    for column in ('ensembl_id', 'gene_type')
)
gene_id_type_dict = gene_info.set_index('ensembl_id')['gene_type'].to_dict()

# Ensembl IDs of every protein-coding or miRNA gene in the table, in table order.
func_gene_list = gene_info.loc[
    gene_info["gene_type"].isin(["protein_coding", "miRNA"]), "ensembl_id"
].tolist()
def _keep_within_sd(adata, column: str, n_sd: float = 3):
    """Return the cells of ``adata`` whose ``obs[column]`` lies within mean ± n_sd * std (inclusive).

    ``np.std`` is used with its default ddof=0. The result is an AnnData view.
    """
    values = adata.obs[column]
    mean = np.mean(values)
    std = np.std(values)
    in_range = (values >= mean - n_sd * std) & (values <= mean + n_sd * std)
    return adata[in_range]


def _is_functional_gene(gene: str, func_gene_set: set) -> bool:
    """True if ``gene`` is a protein-coding/miRNA Ensembl ID, or a gene name mapping to one."""
    if gene in func_gene_set:
        return True
    return gene in gene_name_id_combine_dict and gene_name_id_combine_dict[gene] in func_gene_set


def _gene_annotations(var_names) -> tuple[list, list]:
    """Resolve each var name to its (Ensembl ID, gene type), preserving order.

    Names found in the gene-name tables are mapped through them; names starting
    with ``ENSG`` are taken to already be Ensembl IDs.

    :returns: ``(ensembl_ids, gene_types)``, two lists aligned with ``var_names``.
    :raises Exception: if a name is neither a known gene name nor ``ENSG``-prefixed.
    :raises KeyError: if an ``ENSG`` name is missing from the ID -> type table.
    """
    ensembl_ids, gene_types = [], []
    for var_info in var_names:
        if var_info in gene_name_type_dict:
            gene_types.append(gene_name_type_dict[var_info])
            ensembl_ids.append(gene_name_id_combine_dict[var_info])
        elif str(var_info).startswith("ENSG"):
            gene_types.append(gene_id_type_dict[var_info])
            ensembl_ids.append(var_info)
        else:
            raise Exception(f"can not find {var_info}")
    return ensembl_ids, gene_types


def preprocess(file_path: str) -> None:
    """QC-filter one ``.h5ad`` file and write it to ``<out_path>/<stem>.loom``.

    Steps:
      1. Round X to one decimal. Drop cells whose mitochondrial percentage, and then
         total counts, fall outside mean ± 3 s.d.
      2. Drop cells that Scrublet predicts to be doublets. This step is skipped when
         Scrublet returns no prediction.
      3. Keep only protein-coding / miRNA genes, then drop cells with fewer than
         7 detected genes.
      4. Write the transposed count matrix (genes as rows) with ``ensembl_id`` /
         ``gene_type`` row attributes and an ``n_counts`` column attribute.

    Any exception is logged (file path plus full traceback) and swallowed, so one
    bad file does not abort the other workers.
    """
    try:
        # splitext strips only the final extension, so "a.v1.h5ad" and "a.v2.h5ad"
        # do not collide on the same output name.
        file_name = os.path.splitext(os.path.basename(file_path))[0]
        adata = sc.read_h5ad(file_path)
        adata.raw = None
        adata.X = np.round(adata.X, decimals=1)

        # 1. Outlier removal: mitochondrial fraction and total counts within 3 s.d. of the mean.
        adata.var['mt'] = adata.var_names.str.startswith('MT-')  # annotate mitochondrial genes as 'mt'
        sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
        adata = _keep_within_sd(adata, "pct_counts_mt")
        # The total-counts statistics are computed only on cells that passed the mt filter.
        adata = _keep_within_sd(adata, "total_counts")

        # 2. Doublet detection with Scrublet.
        scrub = scrublet.Scrublet(adata.X)
        _doublet_scores, predicted_doublets = scrub.scrub_doublets()
        if predicted_doublets is not None:
            adata = adata[~predicted_doublets]

        # 3. Keep only Ensembl-annotated protein-coding or miRNA genes, then exclude
        #    cells with fewer than seven such genes detected.
        func_gene_set = set(func_gene_list)
        keep_genes = np.array(
            [_is_functional_gene(gene, func_gene_set) for gene in adata.var_names], dtype=bool
        )
        # A boolean mask keeps the input gene order. A list built from a set would make
        # the order depend on per-process string-hash randomisation.
        # .copy() materialises the view before filter_cells modifies it in place.
        adata = adata[:, keep_genes].copy()
        sc.pp.filter_cells(adata, min_genes=7)

        ensembl_ids, gene_types = _gene_annotations(adata.var_names)
        # Per-cell total counts. asarray + ravel flatten the (n, 1) np.matrix that a
        # scipy sparse X returns into a 1-D vector.
        n_counts = np.asarray(adata.X.sum(axis=1)).ravel()

        filename = os.path.join(out_path, f'{file_name}.loom')
        loompy.create(filename, adata.X.T,
                      {"ensembl_id": np.array(ensembl_ids), "gene_type": gene_types},
                      {"n_counts": n_counts})
        logger.info(f"{filename} is done")
        del adata
    except Exception:
        # traceback.print_exc() prints to stderr and returns None; format_exc()
        # returns the traceback text so it actually reaches the log file.
        logger.error(f"====={file_path}")
        logger.error(traceback.format_exc())
def argv_to_args(argv=None):
    """Parse ``--start/-s`` (required int) and ``--end/-e`` (int, default 1) from the command line.

    :param argv: full argument vector including the program name, like ``sys.argv``.
        Defaults to ``sys.argv``. The first element is skipped.
    :returns: the parsed ``argparse.Namespace``.
    :raises SystemExit: on missing or invalid arguments (argparse prints usage and
        the error to stderr, exit code 2), or on ``--help`` (exit code 0).
    """
    if argv is None:
        argv = sys.argv
    parser = argparse.ArgumentParser()
    parser.add_argument("--start", '-s', type=int, help="start num for iter", required=True)
    parser.add_argument("--end", '-e', type=int, help="end num for iter", default=1)
    # argparse reports bad input by raising SystemExit, which ``except Exception``
    # never catches. A wrapper would be dead code; argparse already prints usage.
    return parser.parse_args(argv[1:])
no_need_handle_set = set()
if __name__ == '__main__':
for home_dir in home_dir_list:
home_path = home_dir
files = [os.path.join(home_path, file_name) for file_name in os.listdir(home_path) if
str(file_name).endswith(".h5ad") and file_name not in no_need_handle_set]
logger.info(f'本次处理的文件个数: {len(files)}')
with mp.Pool(processes=32) as pool:
pool.map(preprocess, files)