Add function for summing counts of duplicate Ensembl IDs
Browse files- geneformer/tokenizer.py +135 -4
geneformer/tokenizer.py
CHANGED
|
@@ -36,14 +36,21 @@ Geneformer tokenizer.
|
|
| 36 |
|
| 37 |
from __future__ import annotations
|
| 38 |
|
|
|
|
| 39 |
import logging
|
| 40 |
import pickle
|
|
|
|
| 41 |
import warnings
|
| 42 |
from pathlib import Path
|
| 43 |
from typing import Literal
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
import anndata as ad
|
| 46 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
import scipy.sparse as sp
|
| 48 |
from datasets import Dataset
|
| 49 |
|
|
@@ -52,7 +59,7 @@ import loompy as lp # noqa
|
|
| 52 |
|
| 53 |
logger = logging.getLogger(__name__)
|
| 54 |
|
| 55 |
-
from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
|
| 56 |
|
| 57 |
|
| 58 |
def rank_genes(gene_vector, gene_tokens):
|
|
@@ -74,6 +81,115 @@ def tokenize_cell(gene_vector, gene_tokens):
|
|
| 74 |
# rank by median-scaled gene values
|
| 75 |
return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
class TranscriptomeTokenizer:
|
| 79 |
def __init__(
|
|
@@ -85,6 +201,7 @@ class TranscriptomeTokenizer:
|
|
| 85 |
special_token=False,
|
| 86 |
gene_median_file=GENE_MEDIAN_FILE,
|
| 87 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
|
|
|
| 88 |
):
|
| 89 |
"""
|
| 90 |
Initialize tokenizer.
|
|
@@ -103,11 +220,15 @@ class TranscriptomeTokenizer:
|
|
| 103 |
| Max input size of model to truncate input to.
|
| 104 |
special_token : bool = False
|
| 105 |
| Adds CLS token before and EOS token after rank value encoding.
|
|
|
|
|
|
|
| 106 |
gene_median_file : Path
|
| 107 |
| Path to pickle file containing dictionary of non-zero median
|
| 108 |
| gene expression values across Genecorpus-30M.
|
| 109 |
token_dictionary_file : Path
|
| 110 |
| Path to pickle file containing token dictionary (Ensembl IDs:token).
|
|
|
|
|
|
|
| 111 |
|
| 112 |
"""
|
| 113 |
# dictionary of custom attributes {output dataset column name: input .loom column name}
|
|
@@ -134,6 +255,10 @@ class TranscriptomeTokenizer:
|
|
| 134 |
with open(token_dictionary_file, "rb") as f:
|
| 135 |
self.gene_token_dict = pickle.load(f)
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
# gene keys for full vocabulary
|
| 138 |
self.gene_keys = list(self.gene_token_dict.keys())
|
| 139 |
|
|
@@ -214,7 +339,7 @@ class TranscriptomeTokenizer:
|
|
| 214 |
return tokenized_cells, cell_metadata
|
| 215 |
|
| 216 |
def tokenize_anndata(self, adata_file_path, target_sum=10_000):
|
| 217 |
-
adata =
|
| 218 |
|
| 219 |
if self.custom_attr_name_dict is not None:
|
| 220 |
file_cell_metadata = {
|
|
@@ -256,7 +381,8 @@ class TranscriptomeTokenizer:
|
|
| 256 |
idx = filter_pass_loc[i : i + self.chunk_size]
|
| 257 |
|
| 258 |
n_counts = adata[idx].obs["n_counts"].values[:, None]
|
| 259 |
-
|
|
|
|
| 260 |
X_norm = X_view / n_counts * target_sum / norm_factor_vector
|
| 261 |
X_norm = sp.csr_matrix(X_norm)
|
| 262 |
|
|
@@ -280,6 +406,8 @@ class TranscriptomeTokenizer:
|
|
| 280 |
attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
|
| 281 |
}
|
| 282 |
|
|
|
|
|
|
|
| 283 |
with lp.connect(str(loom_file_path)) as data:
|
| 284 |
# define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
|
| 285 |
coding_miRNA_loc = np.where(
|
|
@@ -341,6 +469,9 @@ class TranscriptomeTokenizer:
|
|
| 341 |
else:
|
| 342 |
file_cell_metadata = None
|
| 343 |
|
|
|
|
|
|
|
|
|
|
| 344 |
return tokenized_cells, file_cell_metadata
|
| 345 |
|
| 346 |
def create_dataset(
|
|
|
|
| 36 |
|
| 37 |
from __future__ import annotations
|
| 38 |
|
| 39 |
+
import os
|
| 40 |
import logging
|
| 41 |
import pickle
|
| 42 |
+
import sys
|
| 43 |
import warnings
|
| 44 |
from pathlib import Path
|
| 45 |
from typing import Literal
|
| 46 |
+
from tqdm import tqdm
|
| 47 |
+
from collections import Counter
|
| 48 |
|
|
|
|
| 49 |
import numpy as np
|
| 50 |
+
import scanpy as sc
|
| 51 |
+
import loompy as lp
|
| 52 |
+
import pandas as pd
|
| 53 |
+
import anndata as ad
|
| 54 |
import scipy.sparse as sp
|
| 55 |
from datasets import Dataset
|
| 56 |
|
|
|
|
| 59 |
|
| 60 |
logger = logging.getLogger(__name__)
|
| 61 |
|
| 62 |
+
from . import GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE, ENSEMBL_MAPPING_FILE
|
| 63 |
|
| 64 |
|
| 65 |
def rank_genes(gene_vector, gene_tokens):
|
|
|
|
| 81 |
# rank by median-scaled gene values
|
| 82 |
return rank_genes(gene_vector[nonzero_mask], gene_tokens[nonzero_mask])
|
| 83 |
|
| 84 |
+
def sum_ensembl_ids(data_directory,
|
| 85 |
+
gene_mapping_dict,
|
| 86 |
+
file_format = "loom",
|
| 87 |
+
chunk_size = 512):
|
| 88 |
+
if file_format == "loom":
|
| 89 |
+
"""
|
| 90 |
+
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
| 91 |
+
"""
|
| 92 |
+
with lp.connect(data_directory) as data:
|
| 93 |
+
assert "ensembl_id" in data.ra.keys(), "'ensembl_id' column missing from data.ra.keys()"
|
| 94 |
+
gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id]
|
| 95 |
+
|
| 96 |
+
if len(set(gene_ids_collapsed)) == len(set(data.ra.ensembl_id)):
|
| 97 |
+
return data_directory
|
| 98 |
+
|
| 99 |
+
else:
|
| 100 |
+
dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
|
| 101 |
+
dup_genes = [idx for idx, count in Counter(data.ra["ensembl_id"]).items() if count > 1]
|
| 102 |
+
num_chunks = int(np.ceil(data.shape[1] / chunk_size))
|
| 103 |
+
first_chunk = True
|
| 104 |
+
for _, _, view in tqdm(data.scan(axis = 1, batch_size = chunk_size), total = num_chunks):
|
| 105 |
+
def process_chunk(view, duplic_genes):
|
| 106 |
+
data_count_view = pd.DataFrame(view, index=data.ra["ensembl_id"])
|
| 107 |
+
unique_data_df = data_count_view.loc[~data_count_view.index.isin(duplic_genes)]
|
| 108 |
+
dup_data_df = data_count_view.loc[data_count_view.index.isin(duplic_genes)]
|
| 109 |
+
summed_data = dup_data_df.groupby(dup_data_df.index).sum()
|
| 110 |
+
if not summed_data.index.is_unique:
|
| 111 |
+
raise ValueError("Error: summed data frame non-unique.")
|
| 112 |
+
data_count_view = pd.concat([unique_data_df, summed_data], axis=0)
|
| 113 |
+
if not data_count_view.index.is_unique:
|
| 114 |
+
raise ValueError("Error: final data frame non-unique.")
|
| 115 |
+
return data_count_view
|
| 116 |
+
processed_chunk = process_chunk(view[:, :], dup_genes)
|
| 117 |
+
processed_array = processed_chunk.to_numpy()
|
| 118 |
+
new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
|
| 119 |
+
|
| 120 |
+
ra_keys = [k for k in data.ra.keys() if k != "ensembl_id"]
|
| 121 |
+
for ra_value in ra_keys:
|
| 122 |
+
mapping_dict = dict(zip(data.ra["ensembl_id"], data.ra[ra_value]))
|
| 123 |
+
values_new = [mapping_dict[i] for i in processed_chunk.index]
|
| 124 |
+
new_row_attrs[ra_value] = np.array(values_new)
|
| 125 |
+
|
| 126 |
+
if "n_counts" not in view.ca.keys():
|
| 127 |
+
total_count_view = np.sum(view[:,:], axis=0).astype(int)
|
| 128 |
+
view.ca["n_counts"] = total_count_view
|
| 129 |
+
|
| 130 |
+
if first_chunk: # Create the Loom file with the first chunk
|
| 131 |
+
lp.create(f"{dedup_filename}", processed_array, row_attrs=new_row_attrs, col_attrs=view.ca)
|
| 132 |
+
first_chunk = False
|
| 133 |
+
else: # Append subsequent chunks
|
| 134 |
+
with lp.connect(dedup_filename, mode='r+') as dsout:
|
| 135 |
+
dsout.add_columns(processed_array, col_attrs=view.ca)
|
| 136 |
+
return dedup_filename
|
| 137 |
+
|
| 138 |
+
elif file_format == "h5ad":
|
| 139 |
+
"""
|
| 140 |
+
Map Ensembl IDs from gene mapping dictionary. If duplicate Ensembl IDs are found, sum counts together.
|
| 141 |
+
Returns adata object with deduplicated Ensembl IDs.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
data = sc.read_h5ad(str(data_directory))
|
| 145 |
+
|
| 146 |
+
assert "ensembl_id" in data.var.columns, "'ensembl_id' column missing from data.var"
|
| 147 |
+
|
| 148 |
+
gene_ids_collapsed = [gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id]
|
| 149 |
+
|
| 150 |
+
if len(set(gene_ids_collapsed)) == len(set(data.var.ensembl_id)):
|
| 151 |
+
return data
|
| 152 |
+
|
| 153 |
+
else:
|
| 154 |
+
data.var["gene_ids_collapsed"] = gene_ids_collapsed
|
| 155 |
+
data.var_names = gene_ids_collapsed
|
| 156 |
+
data = data[:, ~data.var.index.isna()]
|
| 157 |
+
dup_genes = [idx for idx, count in Counter(data.var_names).items() if count > 1]
|
| 158 |
+
|
| 159 |
+
num_chunks = int(np.ceil(data.shape[0] / chunk_size))
|
| 160 |
+
|
| 161 |
+
processed_genes = []
|
| 162 |
+
for i in tqdm(range(num_chunks)):
|
| 163 |
+
|
| 164 |
+
start_idx = i * chunk_size
|
| 165 |
+
end_idx = min((i + 1) * chunk_size, data.shape[0])
|
| 166 |
+
data_chunk = data[start_idx:end_idx, :]
|
| 167 |
+
|
| 168 |
+
processed_chunks = []
|
| 169 |
+
for dup_gene in dup_genes:
|
| 170 |
+
data_dup_gene = data_chunk[:, data_chunk.var_names == dup_gene]
|
| 171 |
+
df = pd.DataFrame.sparse.from_spmatrix(data_dup_gene.X,
|
| 172 |
+
index=data_dup_gene.obs_names,
|
| 173 |
+
columns=data_dup_gene.var_names)
|
| 174 |
+
df_sum = pd.DataFrame(df.sum(axis=1))
|
| 175 |
+
df_sum.columns = [dup_gene]
|
| 176 |
+
df_sum.index = data_dup_gene.obs.index
|
| 177 |
+
processed_chunks.append(df_sum)
|
| 178 |
+
|
| 179 |
+
processed_chunks = pd.concat(processed_chunks, axis=1)
|
| 180 |
+
processed_genes.append(processed_chunks)
|
| 181 |
+
processed_genes = pd.concat(processed_genes, axis = 0)
|
| 182 |
+
var_df = pd.DataFrame({"gene_ids_collapsed" : processed_genes.columns})
|
| 183 |
+
var_df.index = processed_genes.columns
|
| 184 |
+
processed_genes = sc.AnnData(X = processed_genes,
|
| 185 |
+
obs = data.obs,
|
| 186 |
+
var = var_df)
|
| 187 |
+
|
| 188 |
+
data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
|
| 189 |
+
data_dedup = sc.concat([data_dedup, processed_genes], axis = 1)
|
| 190 |
+
data_dedup.obs = data.obs
|
| 191 |
+
data_dedup.var = data_dedup.var.rename(columns = {"gene_ids_collapsed" : "ensembl_id"})
|
| 192 |
+
return data_dedup
|
| 193 |
|
| 194 |
class TranscriptomeTokenizer:
|
| 195 |
def __init__(
|
|
|
|
| 201 |
special_token=False,
|
| 202 |
gene_median_file=GENE_MEDIAN_FILE,
|
| 203 |
token_dictionary_file=TOKEN_DICTIONARY_FILE,
|
| 204 |
+
gene_mapping_file=ENSEMBL_MAPPING_FILE,
|
| 205 |
):
|
| 206 |
"""
|
| 207 |
Initialize tokenizer.
|
|
|
|
| 220 |
| Max input size of model to truncate input to.
|
| 221 |
special_token : bool = False
|
| 222 |
| Adds CLS token before and EOS token after rank value encoding.
|
| 223 |
+
collapse_gene_ids : bool = False
|
| 224 |
+
| Whether to collapse gene IDs based on gene mapping dictionary.
|
| 225 |
gene_median_file : Path
|
| 226 |
| Path to pickle file containing dictionary of non-zero median
|
| 227 |
| gene expression values across Genecorpus-30M.
|
| 228 |
token_dictionary_file : Path
|
| 229 |
| Path to pickle file containing token dictionary (Ensembl IDs:token).
|
| 230 |
+
gene_mapping_file : Path
|
| 231 |
+
| Path to pickle file containing dictionary for collapsing gene IDs.
|
| 232 |
|
| 233 |
"""
|
| 234 |
# dictionary of custom attributes {output dataset column name: input .loom column name}
|
|
|
|
| 255 |
with open(token_dictionary_file, "rb") as f:
|
| 256 |
self.gene_token_dict = pickle.load(f)
|
| 257 |
|
| 258 |
+
# load gene mappings dictionary (Ensembl IDs:Ensembl ID)
|
| 259 |
+
with open(gene_mapping_file, "rb") as f:
|
| 260 |
+
self.gene_mapping_dict = pickle.load(f)
|
| 261 |
+
|
| 262 |
# gene keys for full vocabulary
|
| 263 |
self.gene_keys = list(self.gene_token_dict.keys())
|
| 264 |
|
|
|
|
| 339 |
return tokenized_cells, cell_metadata
|
| 340 |
|
| 341 |
def tokenize_anndata(self, adata_file_path, target_sum=10_000):
|
| 342 |
+
adata = sum_ensembl_ids(adata_file_path, self.gene_mapping_dict, file_format = "h5ad", chunk_size = self.chunk_size)
|
| 343 |
|
| 344 |
if self.custom_attr_name_dict is not None:
|
| 345 |
file_cell_metadata = {
|
|
|
|
| 381 |
idx = filter_pass_loc[i : i + self.chunk_size]
|
| 382 |
|
| 383 |
n_counts = adata[idx].obs["n_counts"].values[:, None]
|
| 384 |
+
X_view0 = adata[idx,:].X
|
| 385 |
+
X_view = X_view0[:, coding_miRNA_loc]
|
| 386 |
X_norm = X_view / n_counts * target_sum / norm_factor_vector
|
| 387 |
X_norm = sp.csr_matrix(X_norm)
|
| 388 |
|
|
|
|
| 406 |
attr_key: [] for attr_key in self.custom_attr_name_dict.keys()
|
| 407 |
}
|
| 408 |
|
| 409 |
+
loom_file_path = sum_ensembl_ids(loom_file_path, self.gene_mapping_dict, file_format = "loom", chunk_size = self.chunk_size)
|
| 410 |
+
|
| 411 |
with lp.connect(str(loom_file_path)) as data:
|
| 412 |
# define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
|
| 413 |
coding_miRNA_loc = np.where(
|
|
|
|
| 469 |
else:
|
| 470 |
file_cell_metadata = None
|
| 471 |
|
| 472 |
+
if "__dedup" in str(loom_file_path):
|
| 473 |
+
os.remove(str(loom_file_path))
|
| 474 |
+
|
| 475 |
return tokenized_cells, file_cell_metadata
|
| 476 |
|
| 477 |
def create_dataset(
|