Update geneformer/in_silico_perturber.py
Browse files
geneformer/in_silico_perturber.py
CHANGED
|
@@ -604,7 +604,7 @@ class InSilicoPerturber:
|
|
| 604 |
"filter_data": {None, dict},
|
| 605 |
"cell_states_to_model": {None, dict},
|
| 606 |
"max_ncells": {None, int},
|
| 607 |
-
"
|
| 608 |
"emb_layer": {-1, 0},
|
| 609 |
"forward_batch_size": {int},
|
| 610 |
"nproc": {int},
|
|
@@ -623,7 +623,7 @@ class InSilicoPerturber:
|
|
| 623 |
filter_data=None,
|
| 624 |
cell_states_to_model=None,
|
| 625 |
max_ncells=None,
|
| 626 |
-
|
| 627 |
emb_layer=-1,
|
| 628 |
forward_batch_size=100,
|
| 629 |
nproc=4,
|
|
@@ -689,9 +689,9 @@ class InSilicoPerturber:
|
|
| 689 |
max_ncells : None, int
|
| 690 |
Maximum number of cells to test.
|
| 691 |
If None, will test all cells.
|
| 692 |
-
|
| 693 |
Default is perturbing each cell in the dataset.
|
| 694 |
-
Otherwise, may provide a dict of indices of
|
| 695 |
start_ind: the first index to perturb.
|
| 696 |
end_ind: the last index to perturb (exclusive).
|
| 697 |
Indices will be selected *after* the filter_data criteria and sorting.
|
|
@@ -732,7 +732,7 @@ class InSilicoPerturber:
|
|
| 732 |
self.filter_data = filter_data
|
| 733 |
self.cell_states_to_model = cell_states_to_model
|
| 734 |
self.max_ncells = max_ncells
|
| 735 |
-
self.
|
| 736 |
self.emb_layer = emb_layer
|
| 737 |
self.forward_batch_size = forward_batch_size
|
| 738 |
self.nproc = nproc
|
|
@@ -908,15 +908,15 @@ class InSilicoPerturber:
|
|
| 908 |
"Values in filter_data dict must be lists. " \
|
| 909 |
f"Changing {key} value to list ([{value}]).")
|
| 910 |
|
| 911 |
-
if self.
|
| 912 |
-
if set(self.
|
| 913 |
logger.error(
|
| 914 |
-
"If
|
| 915 |
)
|
| 916 |
raise
|
| 917 |
-
if self.
|
| 918 |
logger.error(
|
| 919 |
-
'
|
| 920 |
)
|
| 921 |
raise
|
| 922 |
|
|
@@ -1017,15 +1017,15 @@ class InSilicoPerturber:
|
|
| 1017 |
cos_sims_dict = defaultdict(list)
|
| 1018 |
pickle_batch = -1
|
| 1019 |
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
| 1020 |
-
if self.
|
| 1021 |
-
if self.
|
| 1022 |
-
logger.error("
|
| 1023 |
raise
|
| 1024 |
-
if self.
|
| 1025 |
-
logger.warning("
|
| 1026 |
Setting to the end of the filtered dataset.")
|
| 1027 |
-
self.
|
| 1028 |
-
filtered_input_data = filtered_input_data.select([i for i in range(self.
|
| 1029 |
|
| 1030 |
# make perturbation batch w/ single perturbation in multiple cells
|
| 1031 |
if self.perturb_group == True:
|
|
|
|
| 604 |
"filter_data": {None, dict},
|
| 605 |
"cell_states_to_model": {None, dict},
|
| 606 |
"max_ncells": {None, int},
|
| 607 |
+
"cell_inds_to_perturb": {"all", dict},
|
| 608 |
"emb_layer": {-1, 0},
|
| 609 |
"forward_batch_size": {int},
|
| 610 |
"nproc": {int},
|
|
|
|
| 623 |
filter_data=None,
|
| 624 |
cell_states_to_model=None,
|
| 625 |
max_ncells=None,
|
| 626 |
+
cell_inds_to_perturb="all",
|
| 627 |
emb_layer=-1,
|
| 628 |
forward_batch_size=100,
|
| 629 |
nproc=4,
|
|
|
|
| 689 |
max_ncells : None, int
|
| 690 |
Maximum number of cells to test.
|
| 691 |
If None, will test all cells.
|
| 692 |
+
cell_inds_to_perturb : "all", list
|
| 693 |
Default is perturbing each cell in the dataset.
|
| 694 |
+
Otherwise, may provide a dict of indices of cells to perturb with keys start_ind and end_ind.
|
| 695 |
start_ind: the first index to perturb.
|
| 696 |
end_ind: the last index to perturb (exclusive).
|
| 697 |
Indices will be selected *after* the filter_data criteria and sorting.
|
|
|
|
| 732 |
self.filter_data = filter_data
|
| 733 |
self.cell_states_to_model = cell_states_to_model
|
| 734 |
self.max_ncells = max_ncells
|
| 735 |
+
self.cell_inds_to_perturb = cell_inds_to_perturb
|
| 736 |
self.emb_layer = emb_layer
|
| 737 |
self.forward_batch_size = forward_batch_size
|
| 738 |
self.nproc = nproc
|
|
|
|
| 908 |
"Values in filter_data dict must be lists. " \
|
| 909 |
f"Changing {key} value to list ([{value}]).")
|
| 910 |
|
| 911 |
+
if self.cell_inds_to_perturb != "all":
|
| 912 |
+
if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
|
| 913 |
logger.error(
|
| 914 |
+
"If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
|
| 915 |
)
|
| 916 |
raise
|
| 917 |
+
if self.cell_inds_to_perturb["start"] < 0 or self.cell_inds_to_perturb["end"] < 0:
|
| 918 |
logger.error(
|
| 919 |
+
'cell_inds_to_perturb must be positive.'
|
| 920 |
)
|
| 921 |
raise
|
| 922 |
|
|
|
|
| 1017 |
cos_sims_dict = defaultdict(list)
|
| 1018 |
pickle_batch = -1
|
| 1019 |
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
| 1020 |
+
if self.cell_inds_to_perturb != "all":
|
| 1021 |
+
if self.cell_inds_to_perturb["start"] >= len(filtered_input_data):
|
| 1022 |
+
logger.error("cell_inds_to_perturb['start'] is larger than the filtered dataset.")
|
| 1023 |
raise
|
| 1024 |
+
if self.cell_inds_to_perturb["end"] > len(filtered_input_data):
|
| 1025 |
+
logger.warning("cell_inds_to_perturb['end'] is larger than the filtered dataset. \
|
| 1026 |
Setting to the end of the filtered dataset.")
|
| 1027 |
+
self.cell_inds_to_perturb["end"] = len(filtered_input_data)
|
| 1028 |
+
filtered_input_data = filtered_input_data.select([i for i in range(self.cell_inds_to_perturb["start"], self.cell_inds_to_perturb["end"])])
|
| 1029 |
|
| 1030 |
# make perturbation batch w/ single perturbation in multiple cells
|
| 1031 |
if self.perturb_group == True:
|