Update geneformer/emb_extractor.py
Browse files
Check to make sure that all the emb_labels exist in the tokenized data before extracting embeddings
geneformer/emb_extractor.py
CHANGED
|
@@ -411,7 +411,7 @@ class EmbExtractor:
|
|
| 411 |
self,
|
| 412 |
model_type="Pretrained",
|
| 413 |
num_classes=0,
|
| 414 |
-
emb_mode="
|
| 415 |
cell_emb_style="mean_pool",
|
| 416 |
gene_emb_style="mean_pool",
|
| 417 |
filter_data=None,
|
|
@@ -596,6 +596,12 @@ class EmbExtractor:
|
|
| 596 |
filtered_input_data = pu.load_and_filter(
|
| 597 |
self.filter_data, self.nproc, input_data_file
|
| 598 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 599 |
if cell_state is not None:
|
| 600 |
filtered_input_data = pu.filter_by_dict(
|
| 601 |
filtered_input_data, cell_state, self.nproc
|
|
@@ -719,12 +725,6 @@ class EmbExtractor:
|
|
| 719 |
)
|
| 720 |
raise
|
| 721 |
|
| 722 |
-
if self.emb_label is not None:
|
| 723 |
-
logger.error(
|
| 724 |
-
"For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
|
| 725 |
-
)
|
| 726 |
-
raise
|
| 727 |
-
|
| 728 |
state_embs_dict = dict()
|
| 729 |
state_key = cell_states_to_model["state_key"]
|
| 730 |
for k, v in cell_states_to_model.items():
|
|
|
|
| 411 |
self,
|
| 412 |
model_type="Pretrained",
|
| 413 |
num_classes=0,
|
| 414 |
+
emb_mode="cell",
|
| 415 |
cell_emb_style="mean_pool",
|
| 416 |
gene_emb_style="mean_pool",
|
| 417 |
filter_data=None,
|
|
|
|
| 596 |
filtered_input_data = pu.load_and_filter(
|
| 597 |
self.filter_data, self.nproc, input_data_file
|
| 598 |
)
|
| 599 |
+
|
| 600 |
+
# Check to make sure that all the labels exist in the tokenized data:
|
| 601 |
+
if self.emb_label is not None:
|
| 602 |
+
for label in self.emb_label:
|
| 603 |
+
assert label in list(filtered_input_data.features), f"Attribute `{label}` not present in dataset features"
|
| 604 |
+
|
| 605 |
if cell_state is not None:
|
| 606 |
filtered_input_data = pu.filter_by_dict(
|
| 607 |
filtered_input_data, cell_state, self.nproc
|
|
|
|
| 725 |
)
|
| 726 |
raise
|
| 727 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 728 |
state_embs_dict = dict()
|
| 729 |
state_key = cell_states_to_model["state_key"]
|
| 730 |
for k, v in cell_states_to_model.items():
|