Christina Theodoris
committed on
Commit
·
fb6abe0
1
Parent(s):
5cada60
add emb extractor option for saving all gene embs
Browse files- geneformer/emb_extractor.py +60 -12
geneformer/emb_extractor.py
CHANGED
|
@@ -42,6 +42,8 @@ def get_embs(
|
|
| 42 |
special_token=False,
|
| 43 |
summary_stat=None,
|
| 44 |
silent=False,
|
|
|
|
|
|
|
| 45 |
):
|
| 46 |
model_input_size = pu.get_model_input_size(model)
|
| 47 |
total_batch_length = len(filtered_input_data)
|
|
@@ -180,12 +182,18 @@ def get_embs(
|
|
| 180 |
# calculate summary stat embs from approximated tdigests
|
| 181 |
elif summary_stat is not None:
|
| 182 |
if emb_mode == "cell":
|
|
|
|
|
|
|
|
|
|
| 183 |
if summary_stat == "mean":
|
| 184 |
summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
|
| 185 |
elif summary_stat == "median":
|
| 186 |
summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
|
| 187 |
embs_stack = torch.tensor(summary_emb_list)
|
| 188 |
elif emb_mode == "gene":
|
|
|
|
|
|
|
|
|
|
| 189 |
if summary_stat == "mean":
|
| 190 |
[
|
| 191 |
update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
|
|
@@ -252,7 +260,7 @@ def label_cell_embs(embs, downsampled_data, emb_labels):
|
|
| 252 |
return embs_df
|
| 253 |
|
| 254 |
|
| 255 |
-
def label_gene_embs(embs, downsampled_data, token_gene_dict):
|
| 256 |
gene_set = {
|
| 257 |
element for sublist in downsampled_data["input_ids"] for element in sublist
|
| 258 |
}
|
|
@@ -267,16 +275,39 @@ def label_gene_embs(embs, downsampled_data, token_gene_dict):
|
|
| 267 |
)
|
| 268 |
for k in dict_i.keys():
|
| 269 |
gene_emb_dict[k].append(dict_i[k])
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
| 277 |
embs_df.index = [token_gene_dict[token] for token in embs_df.index]
|
| 278 |
return embs_df
|
| 279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
|
| 281 |
def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
|
| 282 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
|
@@ -404,7 +435,7 @@ class EmbExtractor:
|
|
| 404 |
"num_classes": {int},
|
| 405 |
"emb_mode": {"cls", "cell", "gene"},
|
| 406 |
"cell_emb_style": {"mean_pool"},
|
| 407 |
-
"gene_emb_style": {"mean_pool"},
|
| 408 |
"filter_data": {None, dict},
|
| 409 |
"max_ncells": {None, int},
|
| 410 |
"emb_layer": {-1, 0},
|
|
@@ -432,6 +463,7 @@ class EmbExtractor:
|
|
| 432 |
forward_batch_size=100,
|
| 433 |
nproc=4,
|
| 434 |
summary_stat=None,
|
|
|
|
| 435 |
model_version="V2",
|
| 436 |
token_dictionary_file=None,
|
| 437 |
):
|
|
@@ -451,9 +483,9 @@ class EmbExtractor:
|
|
| 451 |
cell_emb_style : {"mean_pool"}
|
| 452 |
| Method for summarizing cell embeddings if not using CLS token.
|
| 453 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
| 454 |
-
gene_emb_style : "mean_pool"
|
| 455 |
| Method for summarizing gene embeddings.
|
| 456 |
-
| Currently only option is mean pooling of contextual gene embeddings for given gene.
|
| 457 |
filter_data : None, dict
|
| 458 |
| Default is to extract embeddings from all input data.
|
| 459 |
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
|
@@ -483,6 +515,9 @@ class EmbExtractor:
|
|
| 483 |
| If mean or median, outputs only approximated mean or median embedding of input data.
|
| 484 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
| 485 |
| Non-exact is slower but more memory-efficient.
|
|
|
|
|
|
|
|
|
|
| 486 |
model_version : str
|
| 487 |
| To auto-select settings for model version other than current default.
|
| 488 |
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
|
@@ -526,9 +561,16 @@ class EmbExtractor:
|
|
| 526 |
else:
|
| 527 |
self.summary_stat = summary_stat
|
| 528 |
self.exact_summary_stat = None
|
|
|
|
| 529 |
|
| 530 |
self.validate_options()
|
| 531 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
if self.model_version == "V1":
|
| 533 |
from . import TOKEN_DICTIONARY_FILE_30M
|
| 534 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
|
@@ -635,6 +677,10 @@ class EmbExtractor:
|
|
| 635 |
self.model_type, self.num_classes, model_directory, mode="eval"
|
| 636 |
)
|
| 637 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 638 |
embs = get_embs(
|
| 639 |
model=model,
|
| 640 |
filtered_input_data=downsampled_data,
|
|
@@ -644,6 +690,8 @@ class EmbExtractor:
|
|
| 644 |
forward_batch_size=self.forward_batch_size,
|
| 645 |
token_gene_dict=self.token_gene_dict,
|
| 646 |
summary_stat=self.summary_stat,
|
|
|
|
|
|
|
| 647 |
)
|
| 648 |
|
| 649 |
if self.emb_mode == "cell":
|
|
@@ -653,7 +701,7 @@ class EmbExtractor:
|
|
| 653 |
embs_df = pd.DataFrame(embs.cpu().numpy()).T
|
| 654 |
elif self.emb_mode == "gene":
|
| 655 |
if self.summary_stat is None:
|
| 656 |
-
embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict)
|
| 657 |
elif self.summary_stat is not None:
|
| 658 |
embs_df = pd.DataFrame(embs).T
|
| 659 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
|
|
|
| 42 |
special_token=False,
|
| 43 |
summary_stat=None,
|
| 44 |
silent=False,
|
| 45 |
+
save_tdigest=False,
|
| 46 |
+
tdigest_path=None,
|
| 47 |
):
|
| 48 |
model_input_size = pu.get_model_input_size(model)
|
| 49 |
total_batch_length = len(filtered_input_data)
|
|
|
|
| 182 |
# calculate summary stat embs from approximated tdigests
|
| 183 |
elif summary_stat is not None:
|
| 184 |
if emb_mode == "cell":
|
| 185 |
+
if save_tdigest:
|
| 186 |
+
with open(f"{tdigest_path}","wb") as fp:
|
| 187 |
+
pickle.dump(embs_tdigests, fp)
|
| 188 |
if summary_stat == "mean":
|
| 189 |
summary_emb_list = tdigest_mean(embs_tdigests, emb_dims)
|
| 190 |
elif summary_stat == "median":
|
| 191 |
summary_emb_list = tdigest_median(embs_tdigests, emb_dims)
|
| 192 |
embs_stack = torch.tensor(summary_emb_list)
|
| 193 |
elif emb_mode == "gene":
|
| 194 |
+
if save_tdigest:
|
| 195 |
+
with open(f"{tdigest_path}","wb") as fp:
|
| 196 |
+
pickle.dump(embs_tdigests_dict, fp)
|
| 197 |
if summary_stat == "mean":
|
| 198 |
[
|
| 199 |
update_tdigest_dict_mean(embs_tdigests_dict, gene, emb_dims)
|
|
|
|
| 260 |
return embs_df
|
| 261 |
|
| 262 |
|
| 263 |
+
def label_gene_embs(embs, downsampled_data, token_gene_dict, gene_emb_style="mean_pool"):
|
| 264 |
gene_set = {
|
| 265 |
element for sublist in downsampled_data["input_ids"] for element in sublist
|
| 266 |
}
|
|
|
|
| 275 |
)
|
| 276 |
for k in dict_i.keys():
|
| 277 |
gene_emb_dict[k].append(dict_i[k])
|
| 278 |
+
if gene_emb_style != "all":
|
| 279 |
+
for k in gene_emb_dict.keys():
|
| 280 |
+
gene_emb_dict[k] = (
|
| 281 |
+
torch.squeeze(torch.mean(torch.stack(gene_emb_dict[k]), dim=0), dim=0)
|
| 282 |
+
.cpu()
|
| 283 |
+
.numpy()
|
| 284 |
+
)
|
| 285 |
+
embs_df = pd.DataFrame(gene_emb_dict).T
|
| 286 |
+
else:
|
| 287 |
+
embs_df = dict_lol_to_df(gene_emb_dict)
|
| 288 |
embs_df.index = [token_gene_dict[token] for token in embs_df.index]
|
| 289 |
return embs_df
|
| 290 |
|
| 291 |
+
def dict_lol_to_df(data_dict):
    """Convert a dict of {key: list of equal-length array-likes} to a DataFrame.

    Each array-like becomes one row, prefixed with its dict key ("Gene",
    used as the index) and its position within the list ("Identifier").
    Sublists are assumed to all have the same length and to expose a
    ``.tolist()`` method (e.g. torch tensors or numpy arrays).

    Parameters
    ----------
    data_dict : dict
        Maps a gene token/identifier to a list of equal-length embeddings.

    Returns
    -------
    pd.DataFrame
        Indexed by "Gene"; columns are "Identifier" followed by one
        stringified integer column per embedding dimension.
    """
    # Guard: the naive first-sublist lookup below raises IndexError on {}.
    if not data_dict:
        return pd.DataFrame(columns=["Identifier"]).rename_axis("Gene")

    rows = [
        [key, i] + sublist.tolist()
        for key, list_of_lists in data_dict.items()
        for i, sublist in enumerate(list_of_lists)
    ]

    # Column count taken from the first sublist; all sublists are assumed
    # to share that length. next(iter(...)) avoids materializing values().
    emb_dims = len(next(iter(data_dict.values()))[0])
    column_names = ["Gene", "Identifier"] + [f"{j}" for j in range(emb_dims)]

    return pd.DataFrame(rows, columns=column_names).set_index("Gene")
|
| 311 |
|
| 312 |
def plot_umap(embs_df, emb_dims, labels_clean, output_prefix, output_directory, kwargs_dict, seed=0):
|
| 313 |
only_embs_df = embs_df.iloc[:, :emb_dims]
|
|
|
|
| 435 |
"num_classes": {int},
|
| 436 |
"emb_mode": {"cls", "cell", "gene"},
|
| 437 |
"cell_emb_style": {"mean_pool"},
|
| 438 |
+
"gene_emb_style": {"mean_pool", "all"},
|
| 439 |
"filter_data": {None, dict},
|
| 440 |
"max_ncells": {None, int},
|
| 441 |
"emb_layer": {-1, 0},
|
|
|
|
| 463 |
forward_batch_size=100,
|
| 464 |
nproc=4,
|
| 465 |
summary_stat=None,
|
| 466 |
+
save_tdigest=False,
|
| 467 |
model_version="V2",
|
| 468 |
token_dictionary_file=None,
|
| 469 |
):
|
|
|
|
| 483 |
cell_emb_style : {"mean_pool"}
|
| 484 |
| Method for summarizing cell embeddings if not using CLS token.
|
| 485 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
| 486 |
+
gene_emb_style : {"mean_pool", "all"}
|
| 487 |
| Method for summarizing gene embeddings.
|
| 488 |
+
| Options are returning all contextual gene embeddings, or their mean pooling, for each given gene.
|
| 489 |
filter_data : None, dict
|
| 490 |
| Default is to extract embeddings from all input data.
|
| 491 |
| Otherwise, dictionary specifying .dataset column name and list of values to filter by.
|
|
|
|
| 515 |
| If mean or median, outputs only approximated mean or median embedding of input data.
|
| 516 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
| 517 |
| Non-exact is slower but more memory-efficient.
|
| 518 |
+
save_tdigest : bool
|
| 519 |
+
| Whether to save a dictionary of tdigests for each gene and embedding dimension.
|
| 520 |
+
| Only applies when summary_stat is not None.
|
| 521 |
model_version : str
|
| 522 |
| To auto-select settings for model version other than current default.
|
| 523 |
| Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
|
|
|
|
| 561 |
else:
|
| 562 |
self.summary_stat = summary_stat
|
| 563 |
self.exact_summary_stat = None
|
| 564 |
+
self.save_tdigest = save_tdigest
|
| 565 |
|
| 566 |
self.validate_options()
|
| 567 |
|
| 568 |
+
if (summary_stat is None) and (save_tdigest is True):
|
| 569 |
+
logger.warning(
|
| 570 |
+
"tdigests will not be saved since summary_stat is None."
|
| 571 |
+
)
|
| 572 |
+
save_tdigest = False
|
| 573 |
+
|
| 574 |
if self.model_version == "V1":
|
| 575 |
from . import TOKEN_DICTIONARY_FILE_30M
|
| 576 |
self.token_dictionary_file = TOKEN_DICTIONARY_FILE_30M
|
|
|
|
| 677 |
self.model_type, self.num_classes, model_directory, mode="eval"
|
| 678 |
)
|
| 679 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
| 680 |
+
if self.save_tdigest:
|
| 681 |
+
tdigest_path = (Path(output_directory) / f"{output_prefix}_tdigest").with_suffix(".pkl")
|
| 682 |
+
else:
|
| 683 |
+
tdigest_path = None
|
| 684 |
embs = get_embs(
|
| 685 |
model=model,
|
| 686 |
filtered_input_data=downsampled_data,
|
|
|
|
| 690 |
forward_batch_size=self.forward_batch_size,
|
| 691 |
token_gene_dict=self.token_gene_dict,
|
| 692 |
summary_stat=self.summary_stat,
|
| 693 |
+
save_tdigest=self.save_tdigest,
|
| 694 |
+
tdigest_path=tdigest_path,
|
| 695 |
)
|
| 696 |
|
| 697 |
if self.emb_mode == "cell":
|
|
|
|
| 701 |
embs_df = pd.DataFrame(embs.cpu().numpy()).T
|
| 702 |
elif self.emb_mode == "gene":
|
| 703 |
if self.summary_stat is None:
|
| 704 |
+
embs_df = label_gene_embs(embs, downsampled_data, self.token_gene_dict, self.gene_emb_style)
|
| 705 |
elif self.summary_stat is not None:
|
| 706 |
embs_df = pd.DataFrame(embs).T
|
| 707 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|