File size: 2,005 Bytes
8081b08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import multiprocessing as mp
import os
import traceback

import pandas as pd
import scanpy as sc
from loguru import logger
from scipy.sparse import csr_matrix


def get_filename(path: str = __file__) -> str:
    """Return a log file name derived from a script path.

    The base name's final extension is replaced by ``.log``, e.g.
    ``/a/b/convert.py`` -> ``convert.log``. Only the last suffix is removed
    (via ``os.path.splitext``), so dotted names such as ``step.1.py`` keep
    their full stem and become ``step.1.log``.

    Args:
        path: Script path to derive the name from; defaults to this
            module's ``__file__``.

    Returns:
        The base name of ``path`` with its extension replaced by ``.log``.
    """
    stem, _ext = os.path.splitext(os.path.basename(path))
    return f"{stem}.log"


def set_log() -> None:
    """Attach a loguru file sink at ``../log/<script name>.log``.

    The ``../log`` path is relative to the current working directory,
    not to the location of this file.
    """
    log_file = get_filename()
    sink_path = f'../log/{log_file}'
    logger.add(sink_path)


set_log()

# Gene annotation table; the code below reads its `gene_name` and
# `ensembl_id` columns.
gene_info_table_path = '/mnt/nfs/data/geneformer/input/material/gene_info_table.csv'

gene_info = pd.read_csv(gene_info_table_path)

# Convert to a dict mapping gene_name -> ensembl_id. If a gene_name occurs
# on several rows, later rows overwrite earlier ones.
gene_name_id_combine_dict = gene_info.set_index('gene_name')['ensembl_id'].to_dict()

input_dir = '/mnt/nfs/hs/disco_1115_all'
output_dir = '/mnt/nfs/hs/disco_1117_update'

# exist_ok=True avoids the check-then-create race of `os.path.exists` followed
# by `os.makedirs` (e.g. when worker processes re-import this module under the
# "spawn" start method and several try to create the directory at once).
os.makedirs(output_dir, exist_ok=True)


def process_fn(file_path: str) -> None:
    """Rewrite one ``.h5ad`` file as a slimmed-down AnnData in ``output_dir``.

    The new object keeps only:
      * the expression matrix ``X``, stored as a CSR sparse matrix,
      * obs names and var names,
      * a ``cell_type`` obs column, copied from ``cell_type`` if present,
        otherwise from ``cell.type`` if present.

    It is written to ``output_dir`` under the same base file name.

    Any exception is logged (file path, then the full traceback) and
    swallowed, so one bad file does not make ``pool.map`` fail the run.

    Args:
        file_path: Path to the input ``.h5ad`` file.
    """
    try:
        adata = sc.read_h5ad(file_path)
        adata.raw = None
        dst_path = os.path.join(output_dir, os.path.basename(file_path))

        new_adata = sc.AnnData(X=adata.X)
        new_adata.obs_names = list(adata.obs_names)
        new_adata.var_names = list(adata.var_names)
        # `cell_type` takes precedence over `cell.type` when both exist.
        if "cell_type" in adata.obs:
            new_adata.obs["cell_type"] = list(adata.obs["cell_type"])
        elif "cell.type" in adata.obs:
            new_adata.obs["cell_type"] = list(adata.obs["cell.type"])
        # Disabled: map var names to IDs via gene_name_id_combine_dict.
        # new_adata.var["gene_ids"] = new_adata.var_names.map(gene_name_id_combine_dict)
        new_adata.X = csr_matrix(new_adata.X)
        new_adata.write_h5ad(dst_path)
    except Exception:
        # Broad catch is deliberate: this is the worker boundary, and an
        # exception escaping here would be re-raised by pool.map in the parent.
        logger.error(f"====={file_path}")
        # traceback.print_exc() prints to stderr and returns None, so it would
        # log the literal "None"; format_exc() returns the traceback text.
        logger.error(traceback.format_exc())


if __name__ == "__main__":
    # Sorted for a deterministic processing/log order.
    files = sorted(
        os.path.join(input_dir, name)
        for name in os.listdir(input_dir)
        if name.endswith(".h5ad")
    )
    if not files:
        # mp.Pool(processes=0) raises ValueError, so bail out explicitly.
        logger.warning(f"No .h5ad files found in {input_dir}")
    else:
        # Upper bound on concurrent workers, kept from the original script.
        # NOTE(review): presumably a memory cap, since each worker loads a
        # full AnnData.
        max_workers = 8
        # Never more workers than files or CPUs. The original expression
        # could exceed cpu_count and gave more workers for fewer files.
        processes = min(len(files), mp.cpu_count(), max_workers)
        with mp.Pool(processes=processes) as pool:
            pool.map(process_fn, files)