| """ |
| Geneformer tokenizer. |
| |
| Input data: |
| Required format: raw counts scRNAseq data without feature selection as .loom file |
| Required row (gene) attribute: "ensembl_id"; Ensembl ID for each gene |
| Required col (cell) attribute: "n_counts"; total read counts in that cell |
| Optional col (cell) attribute: "filter_pass"; binary indicator of whether cell should be tokenized based on user-defined filtering criteria |
| Optional col (cell) attributes: any other cell metadata can be passed on to the tokenized dataset as a custom attribute dictionary as shown below |
| |
| Usage: |
| from geneformer import TranscriptomeTokenizer |
| tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4) |
| tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix") |
| """ |


import logging
import pickle
import warnings
from pathlib import Path

# silence numba's "nopython" deprecation warning, which is otherwise emitted
# when loompy is imported below
warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")

import loompy as lp
import numpy as np
from datasets import Dataset

logger = logging.getLogger(__name__)

GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"


def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
    """
    # mask undetected genes, then rank the detected genes by descending
    # normalized expression and emit their tokens in that order
    nonzero_mask = np.nonzero(gene_vector)[0]
    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
    sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
    return sentence_tokens
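
# Illustrative example with made-up values: for a normalized vector
# np.array([0.0, 2.5, 0.5]) and tokens np.array([7, 8, 9]), the undetected gene
# at index 0 is dropped and the detected genes are ranked by descending value,
# so tokenize_cell returns array([8, 9]).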


class TranscriptomeTokenizer:
    def __init__(
        self,
        custom_attr_name_dict=None,
        nproc=1,
        gene_median_file=GENE_MEDIAN_FILE,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize tokenizer.

        Parameters
        ----------
        custom_attr_name_dict : None, dict
            Dictionary of custom attributes to be added to the dataset.
            Keys are the names of the attributes in the loom file.
            Values are the names of the attributes in the dataset.
        nproc : int
            Number of processes to use for dataset mapping.
        gene_median_file : Path
            Path to pickle file containing dictionary of non-zero median
            gene expression values across Genecorpus-30M.
        token_dictionary_file : Path
            Path to pickle file containing token dictionary (Ensembl IDs:token).
        """
| |
| self.custom_attr_name_dict = custom_attr_name_dict |
|
|
| |
| self.nproc = nproc |
|
|
| |
| |
| with open(gene_median_file, "rb") as f: |
| self.gene_median_dict = pickle.load(f) |
|
|
| |
| with open(token_dictionary_file, "rb") as f: |
| self.gene_token_dict = pickle.load(f) |
|
|
| |
| self.gene_keys = list(self.gene_median_dict.keys()) |
|
|
| |
| self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys))) |
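
        # For orientation (hypothetical entries): gene_median_dict maps
        # Ensembl IDs to floats, e.g. {"ENSG01": 0.5}, and gene_token_dict maps
        # the same IDs to integer token ids, e.g. {"ENSG01": 25426}.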

    def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
        """
        Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.

        Parameters
        ----------
        loom_data_directory : Path
            Path to directory containing loom files
        output_directory : Path
            Path to directory where tokenized data will be saved as .dataset
        output_prefix : str
            Prefix for output .dataset
        """
        tokenized_cells, cell_metadata = self.tokenize_files(Path(loom_data_directory))
        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)

        output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
        tokenized_dataset.save_to_disk(output_path)
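
        # For example (illustrative names), tokenize_data("loom_dir", "out_dir",
        # "my_corpus") writes the tokenized dataset to out_dir/my_corpus.dataset.
        # Note that with_suffix replaces any "." suffix already in output_prefix.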

    def tokenize_files(self, loom_data_directory):
        tokenized_cells = []
        if self.custom_attr_name_dict is not None:
            loom_cell_attr = list(self.custom_attr_name_dict.keys())
            cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.values()
            }
        else:
            cell_metadata = None

        # loop through directory to tokenize .loom files
        file_found = False
        for loom_file_path in loom_data_directory.glob("*.loom"):
            file_found = True
            print(f"Tokenizing {loom_file_path}")
            file_tokenized_cells, file_cell_metadata = self.tokenize_file(
                loom_file_path
            )
            tokenized_cells += file_tokenized_cells
            if self.custom_attr_name_dict is not None:
                for k in loom_cell_attr:
                    cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]

        if not file_found:
            logger.error(
                f"No .loom files found in directory {loom_data_directory}."
            )
            raise FileNotFoundError(
                f"No .loom files found in directory {loom_data_directory}."
            )
        return tokenized_cells, cell_metadata

    def tokenize_file(self, loom_file_path):
        if self.custom_attr_name_dict is not None:
            file_cell_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
            }
        else:
            file_cell_metadata = None

        with lp.connect(str(loom_file_path)) as data:
            # coordinates of genes in the vocabulary (protein-coding and miRNA),
            # their normalization factors, and their tokens
            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
            )[0]
            norm_factor_vector = np.array(
                [
                    self.gene_median_dict[i]
                    for i in data.ra["ensembl_id"][coding_miRNA_loc]
                ]
            )
            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            # check whether the file has a filter_pass column attribute
            try:
                data.ca["filter_pass"]
            except AttributeError:
                var_exists = False
            else:
                var_exists = True

            if var_exists:
                # tokenize only cells flagged to pass filtering
                filter_pass_loc = np.where(
                    [i == 1 for i in data.ca["filter_pass"]]
                )[0]
            else:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
                )
                filter_pass_loc = np.arange(data.shape[1])

            # scan through the selected cells in chunks along the cell axis
            tokenized_cells = []
            for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1):
                # restrict the view to genes in the vocabulary
                subview = view.view[coding_miRNA_loc, :]

                # normalize each cell to 10,000 total counts, then scale each
                # gene by its non-zero median across Genecorpus-30M, which
                # deprioritizes ubiquitously highly expressed genes in the ranking
                subview_norm_array = (
                    subview[:, :]
                    / subview.ca.n_counts
                    * 10_000
                    / norm_factor_vector[:, None]
                )

                # rank value encode each cell in the chunk
                tokenized_cells += [
                    tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
                    for i in range(subview_norm_array.shape[1])
                ]
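
                # Worked example with made-up numbers: a gene with 5 raw counts
                # in a cell with n_counts = 25,000 and a corpus-wide non-zero
                # median of 0.5 normalizes to 5 / 25000 * 10_000 / 0.5 = 4.0.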

                # carry over custom cell attributes for the cells in this chunk
                if self.custom_attr_name_dict is not None:
                    for k in file_cell_metadata.keys():
                        file_cell_metadata[k] += subview.ca[k].tolist()

        return tokenized_cells, file_cell_metadata

    def create_dataset(self, tokenized_cells, cell_metadata):
        # create dict for dataset creation
        dataset_dict = {"input_ids": tokenized_cells}
        if self.custom_attr_name_dict is not None:
            dataset_dict.update(cell_metadata)

        # create dataset
        output_dataset = Dataset.from_dict(dataset_dict)

        # truncate rank value encodings to Geneformer's 2,048-token input size
        def truncate(example):
            example["input_ids"] = example["input_ids"][0:2048]
            return example

        output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)

        # record the length of each rank value encoding
        def measure_length(example):
            example["length"] = len(example["input_ids"])
            return example

        output_dataset_truncated_w_length = output_dataset_truncated.map(
            measure_length, num_proc=self.nproc
        )

        return output_dataset_truncated_w_length