| import argparse | |
| import logging | |
| import os | |
| import numpy as np | |
| import scanpy as sc | |
| import torch | |
| import yaml | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from src.data.dataset_ds import DatasetMultiPad | |
| from src.models.chromfd_mixer import PretrainModelMambaLM | |
| from src.utils.model_utils import ModelUtils | |
# Configure the root logger once at import time so the INFO-level progress
# messages emitted by this script are shown.
logging.basicConfig(level=logging.INFO)
def load_data(file_path):
    """Read an ``.h5ad`` file with scanpy.

    Args:
        file_path: Path to the input file; must end with ``.h5ad``.

    Returns:
        Whatever ``sc.read_h5ad`` returns for ``file_path``.

    Raises:
        ValueError: If ``file_path`` does not end with ``.h5ad``.
    """
    # Guard clause: reject every other extension before touching the disk.
    if not file_path.endswith('.h5ad'):
        raise ValueError("Unsupported file format. Please provide a .h5ad file.")

    logging.info(f"Reading h5ad file from {file_path}")
    return sc.read_h5ad(file_path)
class EmbeddingModel(PretrainModelMambaLM):
    """Variant of ``PretrainModelMambaLM`` whose forward pass stops at the backbone.

    Construction is delegated unchanged to the parent class; only ``forward``
    is overridden so that the backbone output is returned directly.
    """

    def __init__(self, **kwargs):
        """Forward every keyword argument to the parent constructor."""
        super().__init__(**kwargs)

    def forward(self, value, chromosome, hg38_start, hg38_end, **kwargs):
        """Embed the inputs and run them through the backbone.

        Args:
            value: Passed to ``self.embedding`` as-is.
            chromosome: Cast with ``.long()`` before embedding.
            hg38_start: Cast with ``.long()`` before embedding.
            hg38_end: Cast with ``.long()`` before embedding.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The output of ``self.backbone`` applied to the embedded inputs.
        """
        # ``self.embedding`` / ``self.backbone`` are presumably created by the
        # parent class -- they are not defined in this file.
        chrom_ids = chromosome.long()
        start_ids = hg38_start.long()
        end_ids = hg38_end.long()
        embedded = self.embedding(value, chrom_ids, start_ids, end_ids)
        return self.backbone(embedded)
def generate_embeddings(model, adata_dataloader, device):
    """Run ``model`` over every batch and collect the results in one array.

    For each batch the model output is averaged over its last axis, moved to
    the CPU and converted to NumPy; the per-batch arrays are then concatenated
    along axis 0 in the order the loader produced them.

    Args:
        model: Module called as ``model(value, chromosome, pos_start, pos_end)``.
        adata_dataloader: Iterable of batches whose first four items are the
            value, chromosome, start-position and end-position tensors. Any
            further items (such as a cell-type label) are ignored.
        device: Device the four input tensors are moved to before the call.

    Returns:
        numpy.ndarray: The concatenated per-batch results.

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.eval()
    batch_embeddings = []
    with torch.no_grad():
        for batch_data in tqdm(adata_dataloader):
            # Only the first four items feed the model; trailing items (labels)
            # are not needed here, so tolerate any number of them.
            value, chromosome, pos_start, pos_end, *_ = batch_data
            value = value.to(device)
            chromosome = chromosome.to(device)
            pos_start = pos_start.to(device)
            pos_end = pos_end.to(device)

            embedding = model(value, chromosome, pos_start, pos_end)
            embedding = embedding.mean(dim=-1)
            batch_embeddings.append(embedding.detach().cpu().numpy())

            # Drop the device tensors and release cached blocks before the
            # next batch -- presumably to keep peak GPU memory low.
            del value, chromosome, pos_start, pos_end, embedding
            torch.cuda.empty_cache()

    if not batch_embeddings:
        # np.concatenate would fail here with a message that does not point
        # at the real cause; raise the same exception type with a clear one.
        raise ValueError("The dataloader produced no batches; there is nothing to embed.")

    return np.concatenate(batch_embeddings, axis=0)
def main():
    """Command-line entry point: embed an h5ad dataset with a pre-trained model.

    Steps:
        1. Parse the command-line arguments.
        2. Load the pre-training YAML config and the chromosome vocabulary
           from ``--pretrain_checkpoint_path``.
        3. Read the h5ad file and derive the cell-type map from the column
           named by ``--cell_type_col``.
        4. Build ``EmbeddingModel``, load the checkpoint weights and run the
           model over the whole dataset.
        5. Store the result in ``adata.obsm['X_embedding']`` and write the
           object to ``<output_path>/embeddings.h5ad``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, help='GPU rank', default=0)
    parser.add_argument('--data_path', type=str, required=True, help='path of h5ad file')
    parser.add_argument('--pretrain_checkpoint_path', type=str, required=True, help='path to pre-trained model')
    parser.add_argument('--pretrain_model_file', type=str, required=True, help='file name of pre-trained model')
    parser.add_argument('--pretrain_config_file', type=str, required=True, help='file name of pre-trained config')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size')
    parser.add_argument('--cell_type_col', required=True, help='column name of cell type')
    parser.add_argument('--output_path', type=str, required=True, help='Path to save the updated h5ad file')
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{args.local_rank}" if use_cuda else "cpu")
    if use_cuda:
        # set_device only accepts CUDA devices, so it must be skipped when
        # the script falls back to the CPU.
        torch.cuda.set_device(device)

    pretrain_path = args.pretrain_checkpoint_path
    with open(os.path.join(pretrain_path, args.pretrain_config_file), 'r') as file:
        pretrain_config = yaml.safe_load(file)
    pretrain_model_args = pretrain_config['model_args']
    pretrain_data_args = pretrain_config["data_args"]

    chromosome_vocab = ModelUtils.get_chromosome_vocab(os.path.join(pretrain_path, "chromosome_vocab.yaml"))
    pretrain_data_args["chromosome_vocab"] = chromosome_vocab

    adata = load_data(args.data_path)
    max_length = adata.shape[1]

    # Build the label map from the user-supplied column (not a hard-coded
    # name) so it matches the 'cell_type_col' handed to the dataset below.
    cell_types = sorted(adata.obs[args.cell_type_col].unique().tolist())
    cell_type_map = {name: idx for idx, name in enumerate(cell_types)}

    pretrain_data_args['cell_type_map'] = cell_type_map
    pretrain_model_args["cell_type_num"] = len(cell_type_map)
    pretrain_data_args['cell_type_col'] = args.cell_type_col
    pretrain_data_args["feature_num"] = adata.shape[1]
    pretrain_model_args["feature_num"] = adata.shape[1]
    pretrain_model_args["batch_size"] = args.batch_size
    pretrain_data_args["max_length"] = max_length
    pretrain_model_args["max_length"] = max_length
    pretrain_model_args["device"] = device
    pretrain_model_args["add_cls"] = pretrain_data_args["add_cls"]
    # Inference only: disable masking and batch labels.
    pretrain_model_args["mask_ratio"] = 0.0
    pretrain_data_args["return_batch_label"] = False

    model = EmbeddingModel(**pretrain_model_args)
    checkpoint_file = os.path.join(pretrain_path, args.pretrain_model_file)
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    # map_location="cpu" lets a checkpoint saved on any GPU be loaded on this
    # machine; the model is moved to the target device right after.
    state_dict = torch.load(checkpoint_file, map_location="cpu")
    # The checkpoint keeps the model weights under the 'module' key.
    model.load_state_dict(state_dict['module'])
    model = model.to(device)

    adataset = DatasetMultiPad(adata, **pretrain_data_args)
    adata_dataloader = DataLoader(
        adataset,
        batch_size=args.batch_size,
        shuffle=False,  # keep row order aligned with adata.obs
        pin_memory=use_cuda,  # pinned host memory is only useful with a GPU
    )

    embeddings = generate_embeddings(model, adata_dataloader, device)

    if not os.path.isdir(args.output_path):
        # exist_ok avoids a crash if the directory appears between the check
        # and the call.
        os.makedirs(args.output_path, exist_ok=True)
        logging.info(f"make directory {args.output_path}")

    adata.obsm['X_embedding'] = embeddings
    adata.write_h5ad(os.path.join(args.output_path, "embeddings.h5ad"))
    logging.info(f"Embeddings shape: {adata.obsm['X_embedding'].shape} saved to {args.output_path}")
# Run the embedding pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()