Update geneformer/tokenizer.py
- Add checks for CLS and EOS tokens when special_token = True
- More efficient filter of gene_mapping_dict for values in gene_token_dict
- Remove summing of genes that do not exist in gene_token_dict for loom files
- geneformer/tokenizer.py +14 -12
geneformer/tokenizer.py
CHANGED
|
@@ -94,18 +94,17 @@ def sum_ensembl_ids(data_directory,
|
|
| 94 |
|
| 95 |
if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
|
| 96 |
return data_directory
|
| 97 |
-
|
| 98 |
else:
|
| 99 |
dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
|
| 100 |
-
data.ra["
|
| 101 |
-
dup_genes = [idx for idx, count in Counter(data.ra["
|
| 102 |
num_chunks = int(np.ceil(data.shape[1] / chunk_size))
|
| 103 |
first_chunk = True
|
| 104 |
for _, _, view in tqdm(data.scan(axis = 1, batch_size = chunk_size), total = num_chunks):
|
| 105 |
def process_chunk(view, duplic_genes):
|
| 106 |
-
data_count_view = pd.DataFrame(view, index=data.ra["
|
| 107 |
unique_data_df = data_count_view.loc[~data_count_view.index.isin(duplic_genes)]
|
| 108 |
-
dup_data_df = data_count_view.loc[data_count_view.index.isin(duplic_genes)]
|
| 109 |
summed_data = dup_data_df.groupby(dup_data_df.index).sum()
|
| 110 |
if not summed_data.index.is_unique:
|
| 111 |
raise ValueError("Error: Ensembl IDs in summed data frame non-unique.")
|
|
@@ -117,12 +116,6 @@ def sum_ensembl_ids(data_directory,
|
|
| 117 |
processed_array = processed_chunk.to_numpy()
|
| 118 |
new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
|
| 119 |
|
| 120 |
-
ra_keys = [k for k in data.ra.keys() if k != "ensembl_id"]
|
| 121 |
-
for ra_value in ra_keys:
|
| 122 |
-
mapping_dict = dict(zip(data.ra["ensembl_id"], data.ra[ra_value]))
|
| 123 |
-
values_new = [mapping_dict[i] for i in processed_chunk.index]
|
| 124 |
-
new_row_attrs[ra_value] = np.array(values_new)
|
| 125 |
-
|
| 126 |
if "n_counts" not in view.ca.keys():
|
| 127 |
total_count_view = np.sum(view[:,:], axis=0).astype(int)
|
| 128 |
view.ca["n_counts"] = total_count_view
|
|
@@ -263,6 +256,14 @@ class TranscriptomeTokenizer:
|
|
| 263 |
with open(token_dictionary_file, "rb") as f:
|
| 264 |
self.gene_token_dict = pickle.load(f)
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
# if collapsing duplicate gene IDs
|
| 267 |
self.collapse_gene_ids = collapse_gene_ids
|
| 268 |
|
|
@@ -277,7 +278,8 @@ class TranscriptomeTokenizer:
|
|
| 277 |
self.gene_keys = list(self.gene_token_dict.keys())
|
| 278 |
|
| 279 |
# Filter gene mapping dict for items that exist in gene_token_dict
|
| 280 |
-
|
|
|
|
| 281 |
|
| 282 |
# protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
|
| 283 |
self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
|
|
|
|
| 94 |
|
| 95 |
if (len(set(gene_ids_collapsed_in_dict)) == len(set(gene_ids_in_dict))) and token_genes_unique:
|
| 96 |
return data_directory
|
|
|
|
| 97 |
else:
|
| 98 |
dedup_filename = data_directory.with_name(data_directory.stem + "__dedup.loom")
|
| 99 |
+
data.ra["gene_ids_collapsed"] = gene_ids_collapsed
|
| 100 |
+
dup_genes = [idx for idx, count in Counter(data.ra["gene_ids_collapsed"]).items() if count > 1]
|
| 101 |
num_chunks = int(np.ceil(data.shape[1] / chunk_size))
|
| 102 |
first_chunk = True
|
| 103 |
for _, _, view in tqdm(data.scan(axis = 1, batch_size = chunk_size), total = num_chunks):
|
| 104 |
def process_chunk(view, duplic_genes):
|
| 105 |
+
data_count_view = pd.DataFrame(view, index=data.ra["gene_ids_collapsed"])
|
| 106 |
unique_data_df = data_count_view.loc[~data_count_view.index.isin(duplic_genes)]
|
| 107 |
+
dup_data_df = data_count_view.loc[data_count_view.index.isin([i for i in duplic_genes if "None" not in i])]
|
| 108 |
summed_data = dup_data_df.groupby(dup_data_df.index).sum()
|
| 109 |
if not summed_data.index.is_unique:
|
| 110 |
raise ValueError("Error: Ensembl IDs in summed data frame non-unique.")
|
|
|
|
| 116 |
processed_array = processed_chunk.to_numpy()
|
| 117 |
new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
|
| 118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
if "n_counts" not in view.ca.keys():
|
| 120 |
total_count_view = np.sum(view[:,:], axis=0).astype(int)
|
| 121 |
view.ca["n_counts"] = total_count_view
|
|
|
|
| 256 |
with open(token_dictionary_file, "rb") as f:
|
| 257 |
self.gene_token_dict = pickle.load(f)
|
| 258 |
|
| 259 |
+
# check for special token in gene_token_dict
|
| 260 |
+
if self.special_token:
|
| 261 |
+
if ("<cls>" not in self.gene_token_dict.keys()) and ("<eos>" not in self.gene_token_dict.keys()):
|
| 262 |
+
logger.error(
|
| 263 |
+
"<cls> and <eos> required in gene_token_dict when special_token = True."
|
| 264 |
+
)
|
| 265 |
+
raise
|
| 266 |
+
|
| 267 |
# if collapsing duplicate gene IDs
|
| 268 |
self.collapse_gene_ids = collapse_gene_ids
|
| 269 |
|
|
|
|
| 278 |
self.gene_keys = list(self.gene_token_dict.keys())
|
| 279 |
|
| 280 |
# Filter gene mapping dict for items that exist in gene_token_dict
|
| 281 |
+
gene_keys_set = set(self.gene_token_dict.keys())
|
| 282 |
+
self.gene_mapping_dict = {k: v for k, v in self.gene_mapping_dict.items() if v in gene_keys_set}
|
| 283 |
|
| 284 |
# protein-coding and miRNA gene list dictionary for selecting .loom rows for tokenization
|
| 285 |
self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys)))
|