# Generate per-cell embeddings for an .h5ad dataset using a pre-trained model.
import argparse
import logging
import os
import numpy as np
import scanpy as sc
import torch
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.data.dataset_ds import DatasetMultiPad
from src.models.chromfd_mixer import PretrainModelMambaLM
from src.utils.model_utils import ModelUtils
logging.basicConfig(level=logging.INFO)
def load_data(file_path):
    """Read a single-cell dataset from an .h5ad file.

    Args:
        file_path: path to the input file; must end with ``.h5ad``.

    Returns:
        The loaded AnnData object.

    Raises:
        ValueError: if the path does not have the ``.h5ad`` extension.
    """
    if not file_path.endswith('.h5ad'):
        raise ValueError("Unsupported file format. Please provide a .h5ad file.")
    logging.info(f"Reading h5ad file from {file_path}")
    return sc.read_h5ad(file_path)
class EmbeddingModel(PretrainModelMambaLM):
    """Pre-trained model variant that returns raw backbone features.

    Reuses the pre-trained embedding and backbone but skips whatever head
    the parent applies, so ``forward`` yields the representation itself.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, value, chromosome, hg38_start, hg38_end, **kwargs):
        # Positions/chromosome ids are cast to long for the embedding lookup.
        hidden = self.embedding(
            value, chromosome.long(), hg38_start.long(), hg38_end.long()
        )
        return self.backbone(hidden)
def generate_embeddings(model, adata_dataloader, device):
    """Run the model over every batch and return stacked embeddings.

    Args:
        model: embedding model; called as ``model(value, chromosome,
            pos_start, pos_end)`` under ``torch.no_grad()``.
        adata_dataloader: iterable yielding
            ``(value, chromosome, pos_start, pos_end, cell_type)`` batches;
            the cell-type field is not used here.
        device: torch device to move batch tensors onto.

    Returns:
        ``np.ndarray`` with all batch embeddings concatenated on axis 0.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    feature_embeddings = []
    with torch.no_grad():
        for value, chromosome, pos_start, pos_end, _cell_type in tqdm(adata_dataloader):
            value = value.to(device)
            chromosome = chromosome.to(device)
            pos_start = pos_start.to(device)
            pos_end = pos_end.to(device)
            embedding = model(value, chromosome, pos_start, pos_end)
            # NOTE(review): pools over the LAST dim (hidden features), not the
            # sequence dim -- confirm this is the intended pooling axis.
            embedding = embedding.mean(dim=-1)
            feature_embeddings.append(embedding.detach().cpu().numpy())
            # Eagerly release GPU memory between batches; harmless on CPU.
            del value, chromosome, pos_start, pos_end, embedding
            torch.cuda.empty_cache()
    if not feature_embeddings:
        # np.concatenate([]) would raise ValueError anyway; give a clearer message.
        raise ValueError("adata_dataloader produced no batches; nothing to embed.")
    return np.concatenate(feature_embeddings, axis=0)
def main():
    """CLI entry point: embed an .h5ad dataset with a pre-trained model.

    Loads the pre-training config and checkpoint, builds the dataloader,
    computes embeddings, and writes ``<output_path>/embeddings.h5ad`` with
    the result stored in ``adata.obsm['X_embedding']``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, help='GPU rank', default=0)
    parser.add_argument('--data_path', type=str, required=True, help='path of h5ad file')
    parser.add_argument('--pretrain_checkpoint_path', type=str, required=True, help='path to pre-trained model')
    parser.add_argument('--pretrain_model_file', type=str, required=True, help='file name of pre-trained model')
    parser.add_argument('--pretrain_config_file', type=str, required=True, help='file name of pre-trained config')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--cell_type_col', required=True, help='column name of cell type')
    parser.add_argument('--output_path', type=str, required=True, help='Path to save the updated h5ad file')
    args = parser.parse_args()

    device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
    # BUG FIX: torch.cuda.set_device raises on CPU-only hosts, defeating the
    # CPU fallback above -- only pin the device when CUDA actually exists.
    if torch.cuda.is_available():
        torch.cuda.set_device(device)

    pretrain_path = args.pretrain_checkpoint_path
    with open(os.path.join(pretrain_path, args.pretrain_config_file), 'r') as file:
        pretrain_config = yaml.safe_load(file)
    pretrain_model_args = pretrain_config['model_args']
    pretrain_data_args = pretrain_config["data_args"]
    chromosome_vocab = ModelUtils.get_chromosome_vocab(os.path.join(pretrain_path, "chromosome_vocab.yaml"))
    pretrain_data_args["chromosome_vocab"] = chromosome_vocab

    adata = load_data(args.data_path)
    max_length = adata.shape[1]
    # BUG FIX: honor --cell_type_col instead of the hard-coded 'celltype'
    # column (the script requires the argument but previously ignored it here).
    cell_types = sorted(adata.obs[args.cell_type_col].unique().tolist())
    cell_type_map = {name: idx for idx, name in enumerate(cell_types)}

    # Mirror dataset geometry and runtime settings into both arg dicts.
    pretrain_data_args['cell_type_map'] = cell_type_map
    pretrain_model_args["cell_type_num"] = len(cell_type_map)
    pretrain_data_args['cell_type_col'] = args.cell_type_col
    pretrain_data_args["feature_num"] = adata.shape[1]
    pretrain_model_args["feature_num"] = adata.shape[1]
    pretrain_model_args["batch_size"] = args.batch_size
    pretrain_data_args["max_length"] = max_length
    pretrain_model_args["max_length"] = max_length
    pretrain_model_args["device"] = device
    pretrain_model_args["add_cls"] = pretrain_data_args["add_cls"]
    pretrain_model_args["mask_ratio"] = 0.0  # no masking at inference time
    pretrain_data_args["return_batch_label"] = False

    model = EmbeddingModel(**pretrain_model_args)
    # map_location lets CUDA-trained checkpoints load on CPU-only hosts.
    state_dict = torch.load(os.path.join(pretrain_path, args.pretrain_model_file), map_location=device)
    model.load_state_dict(state_dict['module'])
    model = model.to(device)

    adataset = DatasetMultiPad(adata, **pretrain_data_args)
    adata_dataloader = DataLoader(
        adataset, batch_size=args.batch_size, shuffle=False, pin_memory=True
    )
    embeddings = generate_embeddings(model, adata_dataloader, device)

    if not os.path.exists(args.output_path):
        logging.info(f"make directory {args.output_path}")
    os.makedirs(args.output_path, exist_ok=True)  # race-safe vs. check-then-create
    adata.obsm['X_embedding'] = embeddings
    adata.write_h5ad(os.path.join(args.output_path, "embeddings.h5ad"))
    logging.info(f"Embeddings shape: {adata.obsm['X_embedding'].shape} saved to {args.output_path}")


if __name__ == '__main__':
    main()
|