Update geneformer/tokenizer.py
Update to use Ensembl ID mapped throughout
- geneformer/tokenizer.py +14 -30
geneformer/tokenizer.py
CHANGED
|
@@ -63,17 +63,6 @@ logger = logging.getLogger(__name__)
|
|
| 63 |
|
| 64 |
from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
|
| 65 |
|
| 66 |
-
def rename_attr(data_ra_or_ca, old_name, new_name):
|
| 67 |
-
""" Rename attributes
|
| 68 |
-
Args:
|
| 69 |
-
data_ra_or_ca: data as a record array or column attribute
|
| 70 |
-
old_name (str): old name of attribute
|
| 71 |
-
new_name (str): new name of attribute
|
| 72 |
-
"""
|
| 73 |
-
data_ra_or_ca[new_name] = data_ra_or_ca[old_name]
|
| 74 |
-
if new_name != old_name:
|
| 75 |
-
del data_ra_or_ca[old_name]
|
| 76 |
-
|
| 77 |
def rank_genes(gene_vector, gene_tokens):
|
| 78 |
"""
|
| 79 |
Rank gene expression vector.
|
|
@@ -131,18 +120,16 @@ def sum_ensembl_ids(
|
|
| 131 |
]
|
| 132 |
|
| 133 |
if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
|
| 134 |
-
|
| 135 |
-
rename_attr(data.ra, "ensembl_id", "ensembl_id_original")
|
| 136 |
-
data.ra["ensembl_id"] = gene_ids_collapsed
|
| 137 |
return data_directory
|
| 138 |
else:
|
| 139 |
dedup_filename = data_directory.with_name(
|
| 140 |
data_directory.stem + "__dedup.loom"
|
| 141 |
)
|
| 142 |
-
data.ra["ensembl_id"] = gene_ids_collapsed
|
| 143 |
dup_genes = [
|
| 144 |
idx
|
| 145 |
-
for idx, count in Counter(data.ra["ensembl_id"]).items()
|
| 146 |
if count > 1
|
| 147 |
]
|
| 148 |
num_chunks = int(np.ceil(data.shape[1] / chunk_size))
|
|
@@ -153,7 +140,7 @@ def sum_ensembl_ids(
|
|
| 153 |
|
| 154 |
def process_chunk(view, duplic_genes):
|
| 155 |
data_count_view = pd.DataFrame(
|
| 156 |
-
view, index=data.ra["ensembl_id"]
|
| 157 |
)
|
| 158 |
unique_data_df = data_count_view.loc[
|
| 159 |
~data_count_view.index.isin(duplic_genes)
|
|
@@ -179,7 +166,7 @@ def sum_ensembl_ids(
|
|
| 179 |
|
| 180 |
processed_chunk = process_chunk(view[:, :], dup_genes)
|
| 181 |
processed_array = processed_chunk.to_numpy()
|
| 182 |
-
new_row_attrs = {"ensembl_id": processed_chunk.index.to_numpy()}
|
| 183 |
|
| 184 |
if "n_counts" not in view.ca.keys():
|
| 185 |
total_count_view = np.sum(view[:, :], axis=0).astype(int)
|
|
@@ -230,11 +217,11 @@ def sum_ensembl_ids(
|
|
| 230 |
gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
|
| 231 |
]
|
| 232 |
if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
|
| 233 |
-
data.var["ensembl_id"] = gene_ids_collapsed
|
| 234 |
return data
|
| 235 |
|
| 236 |
else:
|
| 237 |
-
data.var["ensembl_id"] = gene_ids_collapsed
|
| 238 |
data.var_names = gene_ids_collapsed
|
| 239 |
data = data[:, ~data.var.index.isna()]
|
| 240 |
dup_genes = [
|
|
@@ -265,16 +252,13 @@ def sum_ensembl_ids(
|
|
| 265 |
processed_chunks = pd.concat(processed_chunks, axis=1)
|
| 266 |
processed_genes.append(processed_chunks)
|
| 267 |
processed_genes = pd.concat(processed_genes, axis=0)
|
| 268 |
-
var_df = pd.DataFrame({"gene_ids_collapsed": processed_genes.columns})
|
| 269 |
var_df.index = processed_genes.columns
|
| 270 |
processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
|
| 271 |
|
| 272 |
data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
|
| 273 |
data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
|
| 274 |
data_dedup.obs = data.obs
|
| 275 |
-
data_dedup.var = data_dedup.var.rename(
|
| 276 |
-
columns={"gene_ids_collapsed": "ensembl_id"}
|
| 277 |
-
)
|
| 278 |
return data_dedup
|
| 279 |
|
| 280 |
|
|
@@ -474,15 +458,15 @@ class TranscriptomeTokenizer:
|
|
| 474 |
}
|
| 475 |
|
| 476 |
coding_miRNA_loc = np.where(
|
| 477 |
-
[self.genelist_dict.get(i, False) for i in adata.var["ensembl_id"]]
|
| 478 |
)[0]
|
| 479 |
norm_factor_vector = np.array(
|
| 480 |
[
|
| 481 |
self.gene_median_dict[i]
|
| 482 |
-
for i in adata.var["ensembl_id"][coding_miRNA_loc]
|
| 483 |
]
|
| 484 |
)
|
| 485 |
-
coding_miRNA_ids = adata.var["ensembl_id"][coding_miRNA_loc]
|
| 486 |
coding_miRNA_tokens = np.array(
|
| 487 |
[self.gene_token_dict[i] for i in coding_miRNA_ids]
|
| 488 |
)
|
|
@@ -546,15 +530,15 @@ class TranscriptomeTokenizer:
|
|
| 546 |
with lp.connect(str(loom_file_path)) as data:
|
| 547 |
# define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
|
| 548 |
coding_miRNA_loc = np.where(
|
| 549 |
-
[self.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
|
| 550 |
)[0]
|
| 551 |
norm_factor_vector = np.array(
|
| 552 |
[
|
| 553 |
self.gene_median_dict[i]
|
| 554 |
-
for i in data.ra["ensembl_id"][coding_miRNA_loc]
|
| 555 |
]
|
| 556 |
)
|
| 557 |
-
coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]
|
| 558 |
coding_miRNA_tokens = np.array(
|
| 559 |
[self.gene_token_dict[i] for i in coding_miRNA_ids]
|
| 560 |
)
|
|
|
|
| 63 |
|
| 64 |
from . import ENSEMBL_MAPPING_FILE, GENE_MEDIAN_FILE, TOKEN_DICTIONARY_FILE
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
def rank_genes(gene_vector, gene_tokens):
|
| 67 |
"""
|
| 68 |
Rank gene expression vector.
|
|
|
|
| 120 |
]
|
| 121 |
|
| 122 |
if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
|
| 123 |
+
data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
|
|
|
|
|
|
|
| 124 |
return data_directory
|
| 125 |
else:
|
| 126 |
dedup_filename = data_directory.with_name(
|
| 127 |
data_directory.stem + "__dedup.loom"
|
| 128 |
)
|
| 129 |
+
data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
|
| 130 |
dup_genes = [
|
| 131 |
idx
|
| 132 |
+
for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
|
| 133 |
if count > 1
|
| 134 |
]
|
| 135 |
num_chunks = int(np.ceil(data.shape[1] / chunk_size))
|
|
|
|
| 140 |
|
| 141 |
def process_chunk(view, duplic_genes):
|
| 142 |
data_count_view = pd.DataFrame(
|
| 143 |
+
view, index=data.ra["ensembl_id_collapsed"]
|
| 144 |
)
|
| 145 |
unique_data_df = data_count_view.loc[
|
| 146 |
~data_count_view.index.isin(duplic_genes)
|
|
|
|
| 166 |
|
| 167 |
processed_chunk = process_chunk(view[:, :], dup_genes)
|
| 168 |
processed_array = processed_chunk.to_numpy()
|
| 169 |
+
new_row_attrs = {"ensembl_id_collapsed": processed_chunk.index.to_numpy()}
|
| 170 |
|
| 171 |
if "n_counts" not in view.ca.keys():
|
| 172 |
total_count_view = np.sum(view[:, :], axis=0).astype(int)
|
|
|
|
| 217 |
gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
|
| 218 |
]
|
| 219 |
if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
|
| 220 |
+
data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
|
| 221 |
return data
|
| 222 |
|
| 223 |
else:
|
| 224 |
+
data.var["ensembl_id_collapsed"] = gene_ids_collapsed
|
| 225 |
data.var_names = gene_ids_collapsed
|
| 226 |
data = data[:, ~data.var.index.isna()]
|
| 227 |
dup_genes = [
|
|
|
|
| 252 |
processed_chunks = pd.concat(processed_chunks, axis=1)
|
| 253 |
processed_genes.append(processed_chunks)
|
| 254 |
processed_genes = pd.concat(processed_genes, axis=0)
|
| 255 |
+
var_df = pd.DataFrame({"ensembl_id_collapsed": processed_genes.columns})
|
| 256 |
var_df.index = processed_genes.columns
|
| 257 |
processed_genes = sc.AnnData(X=processed_genes, obs=data.obs, var=var_df)
|
| 258 |
|
| 259 |
data_dedup = data[:, ~data.var.index.isin(dup_genes)] # Deduplicated data
|
| 260 |
data_dedup = sc.concat([data_dedup, processed_genes], axis=1)
|
| 261 |
data_dedup.obs = data.obs
|
|
|
|
|
|
|
|
|
|
| 262 |
return data_dedup
|
| 263 |
|
| 264 |
|
|
|
|
| 458 |
}
|
| 459 |
|
| 460 |
coding_miRNA_loc = np.where(
|
| 461 |
+
[self.genelist_dict.get(i, False) for i in adata.var["ensembl_id_collapsed"]]
|
| 462 |
)[0]
|
| 463 |
norm_factor_vector = np.array(
|
| 464 |
[
|
| 465 |
self.gene_median_dict[i]
|
| 466 |
+
for i in adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
|
| 467 |
]
|
| 468 |
)
|
| 469 |
+
coding_miRNA_ids = adata.var["ensembl_id_collapsed"][coding_miRNA_loc]
|
| 470 |
coding_miRNA_tokens = np.array(
|
| 471 |
[self.gene_token_dict[i] for i in coding_miRNA_ids]
|
| 472 |
)
|
|
|
|
| 530 |
with lp.connect(str(loom_file_path)) as data:
|
| 531 |
# define coordinates of detected protein-coding or miRNA genes and vector of their normalization factors
|
| 532 |
coding_miRNA_loc = np.where(
|
| 533 |
+
[self.genelist_dict.get(i, False) for i in data.ra["ensembl_id_collapsed"]]
|
| 534 |
)[0]
|
| 535 |
norm_factor_vector = np.array(
|
| 536 |
[
|
| 537 |
self.gene_median_dict[i]
|
| 538 |
+
for i in data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
|
| 539 |
]
|
| 540 |
)
|
| 541 |
+
coding_miRNA_ids = data.ra["ensembl_id_collapsed"][coding_miRNA_loc]
|
| 542 |
coding_miRNA_tokens = np.array(
|
| 543 |
[self.gene_token_dict[i] for i in coding_miRNA_ids]
|
| 544 |
)
|