# allenxiao's picture
# Upload 17 files
# 8081b08 verified
import multiprocessing as mp
import os
import traceback
import pandas as pd
import scanpy as sc
from loguru import logger
from scipy.sparse import csr_matrix
def get_filename() -> str:
    """Return the log file name derived from this script's file name.

    The directory part and the final extension of ``__file__`` are dropped
    and ``.log`` is appended, e.g. ``/path/to/convert.py`` -> ``convert.log``.

    Only the last extension is removed, so a dotted script name such as
    ``convert.v2.py`` maps to ``convert.v2.log`` rather than being cut at the
    first dot (which would make differently named scripts share one log).

    Returns:
        The log file name, without any directory component.
    """
    script_name = os.path.basename(__file__)
    stem, _ext = os.path.splitext(script_name)
    return f"{stem}.log"
def set_log() -> None:
    """Add a loguru file sink under ``../log`` named after this script.

    The file name comes from ``get_filename()``. The path is relative, so it
    is resolved against the current working directory of the process.
    """
    logger.add(f'../log/{get_filename()}')
# Attach the file log sink as soon as the module is loaded.
set_log()

# Gene lookup table; the code below reads its 'gene_name' and 'ensembl_id' columns.
gene_info_table_path = '/mnt/nfs/data/geneformer/input/material/gene_info_table.csv'
gene_info = pd.read_csv(gene_info_table_path)
# Convert to a dictionary: gene_name -> ensembl_id.
gene_name_id_combine_dict = gene_info.set_index('gene_name')['ensembl_id'].to_dict()

# Directory scanned for input .h5ad files and directory receiving the rewritten copies.
input_dir = '/mnt/nfs/hs/disco_1115_all'
output_dir = '/mnt/nfs/hs/disco_1117_update'
# exist_ok=True avoids the check-then-create race of a separate os.path.exists()
# test (another process may create the directory between the check and the call).
os.makedirs(output_dir, exist_ok=True)
def process_fn(file_path: str) -> None:
    """Rewrite one ``.h5ad`` file as a reduced, sparse copy in ``output_dir``.

    The copy contains only the matrix ``X`` (converted to CSR), the
    observation and variable names, and a ``cell_type`` column in ``obs``
    taken from the input's ``cell.type`` or ``cell_type`` column
    (``cell_type`` wins when both are present). The output keeps the input's
    file name.

    Any exception is caught and logged with the file path and the full
    traceback, so a single failing file does not propagate an error to the
    caller.

    Args:
        file_path: Path of the input ``.h5ad`` file.
    """
    try:
        adata = sc.read_h5ad(file_path)
        # The raw data is not copied to the output; release it.
        adata.raw = None

        dst_path = os.path.join(output_dir, os.path.basename(file_path))

        new_adata = sc.AnnData(X=adata.X)
        new_adata.obs_names = list(adata.obs_names)
        new_adata.var_names = list(adata.var_names)

        # Normalise the annotation column name to "cell_type"; the second
        # check intentionally overrides the first when both columns exist.
        if "cell.type" in adata.obs:
            new_adata.obs["cell_type"] = list(adata.obs["cell.type"])
        if "cell_type" in adata.obs:
            new_adata.obs["cell_type"] = list(adata.obs["cell_type"])

        # Store the matrix in CSR form before writing.
        new_adata.X = csr_matrix(new_adata.X)
        new_adata.write_h5ad(dst_path)
    except Exception:
        logger.error(f"====={file_path}")
        # format_exc() returns the traceback text; print_exc() would return
        # None (logging the string "None") and print to stderr instead.
        logger.error(traceback.format_exc())
if __name__ == "__main__":
    # Collect every .h5ad file directly inside input_dir.
    files = [
        os.path.join(input_dir, name)
        for name in os.listdir(input_dir)
        if name.endswith(".h5ad")
    ]
    if not files:
        # mp.Pool(processes=0) raises ValueError, so skip the pool entirely.
        logger.warning(f"no .h5ad files found in {input_dir}")
    else:
        cpu_num = mp.cpu_count()
        # One worker per file for small batches; otherwise at most 8 workers,
        # and never more than the number of CPUs.
        worker_count = len(files) if len(files) < cpu_num else min(8, cpu_num)
        with mp.Pool(processes=worker_count) as pool:
            pool.map(process_fn, files)