|
|
import argparse |
|
|
import multiprocessing as mp |
|
|
import os |
|
|
import sys |
|
|
import traceback |
|
|
|
|
|
import loompy |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import scanpy as sc |
|
|
import scrublet |
|
|
from loguru import logger |
|
|
|
|
|
|
|
|
def get_filename() -> str:
    """Return the log file name for this script: ``<script stem>.log``.

    Uses ``os.path.splitext`` so a script name containing extra dots
    (e.g. ``run.v2.py`` -> ``run.v2.log``) keeps its full stem; the
    previous ``split('.')[0]`` would have truncated it to ``run``.
    """
    stem = os.path.splitext(os.path.basename(__file__))[0]
    return f"{stem}.log"
|
|
|
|
|
|
|
|
def set_log() -> None:
    """Attach a loguru file sink under ``../log/`` named after this script.

    Bug fix: the computed ``filename`` was previously ignored and a
    literal placeholder path was passed to ``logger.add`` instead of
    interpolating the name.
    """
    filename = get_filename()
    logger.add(f'../log/{filename}')
|
|
|
|
|
|
|
|
# Configure logging at import time so all module-level work below is
# captured in the log file.
set_log()


# --- Environment-specific absolute paths ---------------------------------
# Gene annotation table (input), output directory for .loom files, and the
# directories containing the raw .h5ad files to preprocess.
gene_info_table_path = '/scDifformer/data/240412_test/material/gene_info_table.csv'
out_path = '/scDifformer/data/240412_test/looms'
home_dir_list = ['/scDifformer/data/240412_test/raw_h5ad']

# Gene annotation table; assumed to contain at least the columns
# 'gene_name', 'ensembl_id' and 'gene_type' -- TODO confirm schema.
gene_info = pd.read_csv(gene_info_table_path)

# Lookup tables derived from the annotation table:
#   gene symbol -> Ensembl id
gene_name_id_combine_dict = gene_info.set_index('gene_name')['ensembl_id'].to_dict()
#   gene symbol -> gene type (e.g. protein_coding, miRNA)
gene_name_type_dict = gene_info.set_index('gene_name')['gene_type'].to_dict()
#   Ensembl id -> gene type
gene_id_type_dict = gene_info.set_index('ensembl_id')['gene_type'].to_dict()

# Ensembl ids of the "functional" genes retained downstream: only
# protein-coding genes and miRNAs.
func_gene_list = [i for i in
                  gene_info[(gene_info["gene_type"] == "protein_coding") | (gene_info["gene_type"] == "miRNA")][
                      "ensembl_id"]]
|
|
|
|
|
|
|
|
def _within_three_std(values):
    """Boolean mask keeping entries within mean +/- 3 standard deviations."""
    mean = np.mean(values)
    std = np.std(values)
    return (values >= mean - 3 * std) & (values <= mean + 3 * std)


def preprocess(file_path: str) -> None:
    """QC-filter one .h5ad file and write the result as a .loom file.

    Pipeline: load -> mitochondrial-fraction outlier removal ->
    total-count outlier removal -> Scrublet doublet removal ->
    restrict to functional (protein_coding / miRNA) genes ->
    write ``<out_path>/<file_name>.loom``.

    Errors are logged (never raised) so a multiprocessing pool can keep
    going when a single file fails.
    """
    try:
        file_name = os.path.basename(file_path).split(".")[0]
        adata = sc.read_h5ad(file_path)
        adata.raw = None
        # Round counts to one decimal; presumably the matrix holds
        # near-integer counts -- TODO confirm upstream format.
        adata.X = np.round(adata.X, decimals=1)

        # QC metrics: mitochondrial genes identified by the 'MT-' prefix
        # (human naming convention).
        adata.var['mt'] = adata.var_names.str.startswith('MT-')
        sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

        # Drop cells whose mitochondrial fraction is a 3-sigma outlier.
        adata = adata[_within_three_std(adata.obs["pct_counts_mt"]), :]

        # Drop cells whose total counts are 3-sigma outliers; statistics
        # are recomputed on the already mt-filtered subset, matching the
        # original sequential flow.
        adata = adata[_within_three_std(adata.obs["total_counts"])]

        # Doublet detection/removal; scrub_doublets may return None for
        # the prediction vector, in which case nothing is removed.
        scrub = scrublet.Scrublet(adata.X)
        doublet_scores, predicted_doublets = scrub.scrub_doublets()
        if predicted_doublets is not None:
            adata = adata[~predicted_doublets]

        # Keep genes that are functional (protein_coding / miRNA), matching
        # either directly by Ensembl id or via the symbol->id mapping.
        # Iterating var_names (instead of materialising a set) keeps the
        # column order deterministic -- bug fix: list(set) ordering was
        # nondeterministic across runs. dict.fromkeys dedupes while
        # preserving order.
        func_gene_set = set(func_gene_list)
        retain_genes = list(dict.fromkeys(
            g for g in adata.var_names
            if g in func_gene_set
            or gene_name_id_combine_dict.get(g) in func_gene_set
        ))
        adata = adata[:, retain_genes]

        # Drop cells with fewer than 7 detected genes after the subset.
        sc.pp.filter_cells(adata, min_genes=7)

        # Resolve each retained var name to (gene_type, ensembl_id).
        gene_type_arr = []
        ensembl_id_arr = []
        for var_info in adata.var_names:
            if var_info in gene_name_type_dict:
                gene_type_arr.append(gene_name_type_dict[var_info])
                ensembl_id_arr.append(gene_name_id_combine_dict[var_info])
            elif str(var_info).startswith("ENSG"):
                # Already an Ensembl id; look up its type directly.
                gene_type_arr.append(gene_id_type_dict[var_info])
                ensembl_id_arr.append(var_info)
            else:
                raise Exception(f"can not find {var_info}")

        # Per-cell total counts, stored as the loom column attribute.
        total_genes_per_row = adata.X.sum(axis=1)
        loom_path = f'{out_path}/{file_name}.loom'
        # loompy expects genes as rows, hence the transpose.
        loompy.create(loom_path, adata.X.T,
                      {"ensembl_id": np.array(ensembl_id_arr), "gene_type": gene_type_arr},
                      {"n_counts": total_genes_per_row})
        # Bug fix: the completion message previously logged a literal
        # placeholder instead of the processed file's name.
        logger.info(f"{file_name} is done")
        del adata
    except Exception:
        logger.error(f"====={file_path}")
        # Bug fix: traceback.print_exc() returns None (it prints to
        # stderr), so the old code logged the string 'None'.
        logger.error(traceback.format_exc())
|
|
|
|
|
|
|
|
def argv_to_args(argv):
    """Parse command-line arguments.

    Parameters:
        argv: the full argv vector (``argv[0]`` is the program name).

    Returns:
        argparse.Namespace with ``start`` (required int) and ``end``
        (int, default 1).

    Exits with status 1 (after logging the error and printing help) on
    invalid input.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--start", '-s', type=int, help="start num for iter", required=True)
    parser.add_argument("--end", '-e', type=int, help="end num for iter", default=1)

    try:
        args = parser.parse_args(argv[1:])
    except SystemExit as e:
        # Bug fix: argparse reports errors via SystemExit, which is a
        # BaseException -- the previous `except Exception` never fired,
        # so the log/help/exit(1) path was dead code. Clean exits
        # (e.g. --help, exit code 0) are re-raised untouched.
        if e.code in (0, None):
            raise
        logger.error(e)
        parser.print_help()
        sys.exit(1)

    return args
|
|
|
|
|
|
|
|
# File names (not paths) to skip; empty by default.
no_need_handle_set = set()


if __name__ == '__main__':
    for home_dir in home_dir_list:
        # Collect every .h5ad file in this directory that is not on the
        # skip list, as absolute paths for the worker function.
        files = [
            os.path.join(home_dir, entry)
            for entry in os.listdir(home_dir)
            if str(entry).endswith(".h5ad") and entry not in no_need_handle_set
        ]
        logger.info(f'本次处理的文件个数: {len(files)}')
        # Fan the files out across 32 worker processes; preprocess logs
        # its own failures, so the pool never aborts on a bad file.
        with mp.Pool(processes=32) as pool:
            pool.map(preprocess, files)