"""
Geneformer tokenizer.

Usage:
  from geneformer import TranscriptomeTokenizer
  tk = TranscriptomeTokenizer({"cell_type": "cell_type", "organ_major": "organ_major"}, nproc=4)
  tk.tokenize_data("loom_data_directory", "output_directory", "output_prefix")
"""

import pickle
from pathlib import Path

import loompy as lp
import numpy as np
from datasets import Dataset

GENE_MEDIAN_FILE = Path(__file__).parent / "gene_median_dictionary.pkl"
TOKEN_DICTIONARY_FILE = Path(__file__).parent / "token_dictionary.pkl"


def tokenize_cell(gene_vector, gene_tokens):
    """
    Convert normalized gene expression vector to tokenized rank value encoding.
    """
    # mask undetected genes
    nonzero_mask = np.nonzero(gene_vector)[0]
    # rank nonzero genes by median-scaled expression, descending
    sorted_indices = np.argsort(-gene_vector[nonzero_mask])
    # convert ranked gene indices to tokens
    sentence_tokens = gene_tokens[nonzero_mask][sorted_indices]
    return sentence_tokens
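
# Illustrative example of rank value encoding (hypothetical values):
# >>> gene_vector = np.array([0.5, 2.0, 0.0, 1.2])
# >>> gene_tokens = np.array([10, 11, 12, 13])
# >>> tokenize_cell(gene_vector, gene_tokens)
# array([11, 13, 10])
# The undetected gene (token 12) is masked out; the remaining genes are
# ordered by descending expression: 2.0 (token 11), 1.2 (token 13), 0.5 (token 10).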


class TranscriptomeTokenizer:
    def __init__(
        self,
        custom_attr_name_dict,
        nproc=1,
        gene_median_file=GENE_MEDIAN_FILE,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize tokenizer.

        Parameters
        ----------
        custom_attr_name_dict : dict
            Dictionary of custom attributes to be added to the dataset.
            Keys are the names of the attributes in the loom file.
            Values are the names of the attributes in the dataset.
        nproc : int
            Number of processes to use for dataset mapping.
        gene_median_file : Path
            Path to pickle file containing dictionary of non-zero median
            gene expression values across Genecorpus-30M.
        token_dictionary_file : Path
            Path to pickle file containing token dictionary (Ensembl IDs:token).
        """
        # dictionary of custom attributes {input .loom column name: output dataset column name}
        self.custom_attr_name_dict = custom_attr_name_dict

        # number of processes for dataset mapping
        self.nproc = nproc

        # load dictionary of gene normalization factors
        # (non-zero median value of expression across Genecorpus-30M)
        with open(gene_median_file, "rb") as f:
            self.gene_median_dict = pickle.load(f)

        # load token dictionary (Ensembl IDs:token)
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_median_dict.keys())

        # protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
        self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
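
    # Illustrative structure of the two pickled dictionaries (hypothetical entries):
    # gene_median_dict = {"ENSG00000000003": 1.2, "ENSG00000000005": 0.8, ...}
    # gene_token_dict = {"ENSG00000000003": 2, "ENSG00000000005": 3, ...}
    # i.e. Ensembl gene IDs mapped to normalization factors and to integer tokens.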

    def tokenize_data(self, loom_data_directory, output_directory, output_prefix):
        """
        Tokenize .loom files in loom_data_directory and save as tokenized .dataset in output_directory.

        Parameters
        ----------
        loom_data_directory : Path
            Path to directory containing loom files
        output_directory : Path
            Path to directory where tokenized data will be saved as .dataset
        output_prefix : str
            Prefix for output .dataset
        """
        tokenized_cells, cell_metadata = self.tokenize_files(loom_data_directory)
        tokenized_dataset = self.create_dataset(tokenized_cells, cell_metadata)

        output_path = (Path(output_directory) / output_prefix).with_suffix(".dataset")
        tokenized_dataset.save_to_disk(output_path)
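
    # Example usage (paths are hypothetical):
    # >>> tk = TranscriptomeTokenizer({"cell_type": "cell_type"}, nproc=4)
    # >>> tk.tokenize_data("loom_data", "out", "my_corpus")
    # >>> # the tokenized dataset is written to out/my_corpus.dataset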

    def tokenize_files(self, loom_data_directory):
        tokenized_cells = []
        # output dataset columns are named by the custom_attr_name_dict values
        cell_metadata = {
            attr_key: [] for attr_key in self.custom_attr_name_dict.values()
        }

        # loop through directory to tokenize .loom files
        for loom_file_path in Path(loom_data_directory).glob("*.loom"):
            print(f"Tokenizing {loom_file_path}")
            file_tokenized_cells, file_cell_metadata = self.tokenize_file(
                loom_file_path
            )
            tokenized_cells += file_tokenized_cells
            # extend (rather than replace) the metadata lists so that attributes
            # collected from earlier files are not overwritten, mapping .loom
            # attribute names (keys) to dataset column names (values)
            for k in self.custom_attr_name_dict.keys():
                cell_metadata[self.custom_attr_name_dict[k]] += file_cell_metadata[k]

        return tokenized_cells, cell_metadata

    def tokenize_file(self, loom_file_path):
        file_cell_metadata = {
            attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
        }

        with lp.connect(str(loom_file_path)) as data:
            # define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
            )[0]
            norm_factor_vector = np.array(
                [
                    self.gene_median_dict[i]
                    for i in data.ra["ensembl_id"][coding_miRNA_loc]
                ]
            )
            coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            # define coordinates of cells passing filters for inclusion (e.g. QC)
            if "filter_pass" in data.ca.keys():
                filter_pass_loc = np.where(
                    [i == 1 for i in data.ca["filter_pass"]]
                )[0]
            else:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all cells."
                )
                filter_pass_loc = np.arange(data.shape[1])

            # scan through .loom file and tokenize cells
            tokenized_cells = []
            for (_ix, _selection, view) in data.scan(items=filter_pass_loc, axis=1):
                # select subview with protein-coding and miRNA genes
                subview = view.view[coding_miRNA_loc, :]

                # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision
                # and normalize by gene normalization factors
                subview_norm_array = (
                    subview[:, :]
                    / subview.ca.n_counts
                    * 10_000
                    / norm_factor_vector[:, None]
                )
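                # e.g. (hypothetical numbers): a gene with 5 raw counts in a
                # cell with n_counts = 2500 and a Genecorpus-30M median of 0.8
                # is scaled to 5 / 2500 * 10_000 / 0.8 = 25.0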

                # tokenize subview gene vectors
                tokenized_cells += [
                    tokenize_cell(subview_norm_array[:, i], coding_miRNA_tokens)
                    for i in range(subview_norm_array.shape[1])
                ]

                # add custom attributes for subview to dict
                for k in file_cell_metadata.keys():
                    file_cell_metadata[k] += subview.ca[k].tolist()

        return tokenized_cells, file_cell_metadata

    def create_dataset(self, tokenized_cells, cell_metadata):
        # create dict for dataset creation
        dataset_dict = {"input_ids": tokenized_cells}
        dataset_dict.update(cell_metadata)

        # create dataset
        output_dataset = Dataset.from_dict(dataset_dict)

        # truncate rank value encodings to the maximum input size of 2048
        def truncate(example):
            example["input_ids"] = example["input_ids"][0:2048]
            return example

        output_dataset_truncated = output_dataset.map(truncate, num_proc=self.nproc)

        # measure lengths of rank value encodings
        def measure_length(example):
            example["length"] = len(example["input_ids"])
            return example

        output_dataset_truncated_w_length = output_dataset_truncated.map(
            measure_length, num_proc=self.nproc
        )
        return output_dataset_truncated_w_length
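
# Inspecting the saved output (illustrative; column names depend on
# custom_attr_name_dict and row count on the input .loom files):
# >>> from datasets import load_from_disk
# >>> ds = load_from_disk("output_directory/output_prefix.dataset")
# >>> ds
# Dataset({
#     features: ['input_ids', 'cell_type', 'organ_major', 'length'],
#     num_rows: ...
# })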