Added feature to perturb a set of indices to help with debugging and with very large runtimes
#175
by
davidjwen
- opened
geneformer/in_silico_perturber.py
CHANGED
|
@@ -604,6 +604,7 @@ class InSilicoPerturber:
|
|
| 604 |
"filter_data": {None, dict},
|
| 605 |
"cell_states_to_model": {None, dict},
|
| 606 |
"max_ncells": {None, int},
|
|
|
|
| 607 |
"emb_layer": {-1, 0},
|
| 608 |
"forward_batch_size": {int},
|
| 609 |
"nproc": {int},
|
|
@@ -622,6 +623,7 @@ class InSilicoPerturber:
|
|
| 622 |
filter_data=None,
|
| 623 |
cell_states_to_model=None,
|
| 624 |
max_ncells=None,
|
|
|
|
| 625 |
emb_layer=-1,
|
| 626 |
forward_batch_size=100,
|
| 627 |
nproc=4,
|
|
@@ -687,6 +689,13 @@ class InSilicoPerturber:
|
|
| 687 |
max_ncells : None, int
|
| 688 |
Maximum number of cells to test.
|
| 689 |
If None, will test all cells.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 690 |
emb_layer : {-1, 0}
|
| 691 |
Embedding layer to use for quantification.
|
| 692 |
-1: 2nd to last layer (recommended for pretrained Geneformer)
|
|
@@ -723,6 +732,7 @@ class InSilicoPerturber:
|
|
| 723 |
self.filter_data = filter_data
|
| 724 |
self.cell_states_to_model = cell_states_to_model
|
| 725 |
self.max_ncells = max_ncells
|
|
|
|
| 726 |
self.emb_layer = emb_layer
|
| 727 |
self.forward_batch_size = forward_batch_size
|
| 728 |
self.nproc = nproc
|
|
@@ -886,7 +896,7 @@ class InSilicoPerturber:
|
|
| 886 |
if self.perturb_type in ["inhibit","activate"]:
|
| 887 |
if self.perturb_rank_shift is None:
|
| 888 |
logger.error(
|
| 889 |
-
"If
|
| 890 |
"quartile to shift by must be specified.")
|
| 891 |
raise
|
| 892 |
|
|
@@ -897,6 +907,18 @@ class InSilicoPerturber:
|
|
| 897 |
logger.warning(
|
| 898 |
"Values in filter_data dict must be lists. " \
|
| 899 |
f"Changing {key} value to list ([{value}]).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 900 |
|
| 901 |
def perturb_data(self,
|
| 902 |
model_directory,
|
|
@@ -995,6 +1017,15 @@ class InSilicoPerturber:
|
|
| 995 |
cos_sims_dict = defaultdict(list)
|
| 996 |
pickle_batch = -1
|
| 997 |
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 998 |
|
| 999 |
# make perturbation batch w/ single perturbation in multiple cells
|
| 1000 |
if self.perturb_group == True:
|
|
|
|
| 604 |
"filter_data": {None, dict},
|
| 605 |
"cell_states_to_model": {None, dict},
|
| 606 |
"max_ncells": {None, int},
|
| 607 |
+
"cell_inds_to_perturb": {"all", dict},
|
| 608 |
"emb_layer": {-1, 0},
|
| 609 |
"forward_batch_size": {int},
|
| 610 |
"nproc": {int},
|
|
|
|
| 623 |
filter_data=None,
|
| 624 |
cell_states_to_model=None,
|
| 625 |
max_ncells=None,
|
| 626 |
+
cell_inds_to_perturb="all",
|
| 627 |
emb_layer=-1,
|
| 628 |
forward_batch_size=100,
|
| 629 |
nproc=4,
|
|
|
|
| 689 |
max_ncells : None, int
|
| 690 |
Maximum number of cells to test.
|
| 691 |
If None, will test all cells.
|
| 692 |
+
cell_inds_to_perturb : "all", list
|
| 693 |
+
Default is perturbing each cell in the dataset.
|
| 694 |
+
Otherwise, may provide a dict of indices of cells to perturb with keys start_ind and end_ind.
|
| 695 |
+
start_ind: the first index to perturb.
|
| 696 |
+
end_ind: the last index to perturb (exclusive).
|
| 697 |
+
Indices will be selected *after* the filter_data criteria and sorting.
|
| 698 |
+
Useful for splitting extremely large datasets across separate GPUs.
|
| 699 |
emb_layer : {-1, 0}
|
| 700 |
Embedding layer to use for quantification.
|
| 701 |
-1: 2nd to last layer (recommended for pretrained Geneformer)
|
|
|
|
| 732 |
self.filter_data = filter_data
|
| 733 |
self.cell_states_to_model = cell_states_to_model
|
| 734 |
self.max_ncells = max_ncells
|
| 735 |
+
self.cell_inds_to_perturb = cell_inds_to_perturb
|
| 736 |
self.emb_layer = emb_layer
|
| 737 |
self.forward_batch_size = forward_batch_size
|
| 738 |
self.nproc = nproc
|
|
|
|
| 896 |
if self.perturb_type in ["inhibit","activate"]:
|
| 897 |
if self.perturb_rank_shift is None:
|
| 898 |
logger.error(
|
| 899 |
+
"If perturb_type is inhibit or activate then " \
|
| 900 |
"quartile to shift by must be specified.")
|
| 901 |
raise
|
| 902 |
|
|
|
|
| 907 |
logger.warning(
|
| 908 |
"Values in filter_data dict must be lists. " \
|
| 909 |
f"Changing {key} value to list ([{value}]).")
|
| 910 |
+
|
| 911 |
+
if self.cell_inds_to_perturb != "all":
|
| 912 |
+
if set(self.cell_inds_to_perturb.keys()) != {"start", "end"}:
|
| 913 |
+
logger.error(
|
| 914 |
+
"If cell_inds_to_perturb is a dictionary, keys must be 'start' and 'end'."
|
| 915 |
+
)
|
| 916 |
+
raise
|
| 917 |
+
if self.cell_inds_to_perturb["start"] < 0 or self.cell_inds_to_perturb["end"] < 0:
|
| 918 |
+
logger.error(
|
| 919 |
+
'cell_inds_to_perturb must be positive.'
|
| 920 |
+
)
|
| 921 |
+
raise
|
| 922 |
|
| 923 |
def perturb_data(self,
|
| 924 |
model_directory,
|
|
|
|
| 1017 |
cos_sims_dict = defaultdict(list)
|
| 1018 |
pickle_batch = -1
|
| 1019 |
filtered_input_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
| 1020 |
+
if self.cell_inds_to_perturb != "all":
|
| 1021 |
+
if self.cell_inds_to_perturb["start"] >= len(filtered_input_data):
|
| 1022 |
+
logger.error("cell_inds_to_perturb['start'] is larger than the filtered dataset.")
|
| 1023 |
+
raise
|
| 1024 |
+
if self.cell_inds_to_perturb["end"] > len(filtered_input_data):
|
| 1025 |
+
logger.warning("cell_inds_to_perturb['end'] is larger than the filtered dataset. \
|
| 1026 |
+
Setting to the end of the filtered dataset.")
|
| 1027 |
+
self.cell_inds_to_perturb["end"] = len(filtered_input_data)
|
| 1028 |
+
filtered_input_data = filtered_input_data.select([i for i in range(self.cell_inds_to_perturb["start"], self.cell_inds_to_perturb["end"])])
|
| 1029 |
|
| 1030 |
# make perturbation batch w/ single perturbation in multiple cells
|
| 1031 |
if self.perturb_group == True:
|