Christina Theodoris
committed on
Commit
·
feeecd0
1
Parent(s):
dc1481d
Add function to create remainder emb for in silico overexpression batch
Browse files
geneformer/in_silico_perturber.py
CHANGED
|
@@ -140,6 +140,18 @@ def make_comparison_batch(original_emb, indices_to_perturb):
|
|
| 140 |
all_embs_list += [torch.cat(emb_list)]
|
| 141 |
return torch.stack(all_embs_list)
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
# average embedding position of goal cell states
|
| 144 |
def get_cell_state_avg_embs(model,
|
| 145 |
filtered_input_data,
|
|
@@ -188,6 +200,7 @@ def get_cell_state_avg_embs(model,
|
|
| 188 |
|
| 189 |
# quantify cosine similarity of perturbed vs original or alternate states
|
| 190 |
def quant_cos_sims(model,
|
|
|
|
| 191 |
perturbation_batch,
|
| 192 |
forward_batch_size,
|
| 193 |
layer_to_quant,
|
|
@@ -226,8 +239,14 @@ def quant_cos_sims(model,
|
|
| 226 |
minibatch_emb = outputs.hidden_states[layer_to_quant]
|
| 227 |
if cell_states_to_model is None:
|
| 228 |
minibatch_comparison = comparison_batch[i:max_range]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
|
| 230 |
-
|
| 231 |
for state in possible_states:
|
| 232 |
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb, minibatch_emb, state_embs_dict[state])
|
| 233 |
del outputs
|
|
@@ -279,9 +298,9 @@ def pad_tensor_list(tensor_list, dynamic_or_constant, token_dictionary):
|
|
| 279 |
class InSilicoPerturber:
|
| 280 |
valid_option_dict = {
|
| 281 |
"perturb_type": {"delete","overexpress","inhibit","activate"},
|
| 282 |
-
"perturb_rank_shift": {None,
|
| 283 |
"genes_to_perturb": {"all", list},
|
| 284 |
-
"combos": {0,1,2},
|
| 285 |
"anchor_gene": {None, str},
|
| 286 |
"model_type": {"Pretrained","GeneClassifier","CellClassifier"},
|
| 287 |
"num_classes": {int},
|
|
@@ -326,7 +345,7 @@ class InSilicoPerturber:
|
|
| 326 |
"overexpress": move gene to front of rank value encoding
|
| 327 |
"inhibit": move gene to lower quartile of rank value encoding
|
| 328 |
"activate": move gene to higher quartile of rank value encoding
|
| 329 |
-
perturb_rank_shift : None,
|
| 330 |
Number of quartiles by which to shift rank of gene.
|
| 331 |
For example, if perturb_type="activate" and perturb_rank_shift=1:
|
| 332 |
genes in 4th quartile will move to middle of 3rd quartile.
|
|
@@ -414,6 +433,15 @@ class InSilicoPerturber:
|
|
| 414 |
self.tokens_to_perturb = [self.gene_token_dict[gene] for gene in self.genes_to_perturb]
|
| 415 |
|
| 416 |
def validate_options(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
for attr_name,valid_options in self.valid_option_dict.items():
|
| 418 |
attr_value = self.__dict__[attr_name]
|
| 419 |
if type(attr_value) not in {list, dict}:
|
|
@@ -442,7 +470,7 @@ class InSilicoPerturber:
|
|
| 442 |
elif self.perturb_type == "overexpress":
|
| 443 |
logger.warning(
|
| 444 |
"perturb_rank_shift set to None. " \
|
| 445 |
-
"If perturb type is
|
| 446 |
"of rank value encoding rather than shifted by quartile")
|
| 447 |
self.perturb_rank_shift = None
|
| 448 |
|
|
@@ -626,13 +654,14 @@ class InSilicoPerturber:
|
|
| 626 |
combo_lvl,
|
| 627 |
self.nproc)
|
| 628 |
cos_sims_data = quant_cos_sims(model,
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
|
|
|
| 636 |
|
| 637 |
if self.cell_states_to_model is None:
|
| 638 |
# update cos sims dict
|
|
@@ -699,6 +728,7 @@ class InSilicoPerturber:
|
|
| 699 |
0,
|
| 700 |
self.nproc)
|
| 701 |
cos_sims_data = quant_cos_sims(model,
|
|
|
|
| 702 |
perturbation_batch,
|
| 703 |
self.forward_batch_size,
|
| 704 |
layer_to_quant,
|
|
@@ -715,6 +745,7 @@ class InSilicoPerturber:
|
|
| 715 |
1,
|
| 716 |
self.nproc)
|
| 717 |
combo_cos_sims_data = quant_cos_sims(model,
|
|
|
|
| 718 |
combo_perturbation_batch,
|
| 719 |
self.forward_batch_size,
|
| 720 |
layer_to_quant,
|
|
|
|
| 140 |
all_embs_list += [torch.cat(emb_list)]
|
| 141 |
return torch.stack(all_embs_list)
|
| 142 |
|
| 143 |
+
# perturbed cell emb removing the activated/overexpressed/inhibited gene emb
|
| 144 |
+
# so that only non-perturbed gene embeddings are compared to each other
|
| 145 |
+
# in original or perturbed context
|
| 146 |
+
def make_perturbed_remainder_batch(emb_batch, indices_to_remove):
    """Return the embedding batch with the perturbed gene position(s) removed.

    Dropping the embedding of the activated/overexpressed/inhibited gene
    leaves only the non-perturbed gene embeddings, so these can be compared
    between the original and the perturbed context.

    Parameters
    ----------
    emb_batch : torch.Tensor
        Batch of embeddings. Dimension 0 is iterated per cell and dimension 1
        is the position axis that is filtered.
        NOTE(review): presumably (batch, sequence, hidden) -- confirm with caller.
    indices_to_remove : int or list
        int: a single position removed from every cell in the batch
        (negative values follow ``list.pop`` semantics).
        list: one entry per cell; entry ``i`` is forwarded together with
        ``emb_batch[i]`` to ``make_comparison_batch``.

    Returns
    -------
    torch.Tensor
        Batch of remainder embeddings.

    Raises
    ------
    IndexError
        If an int index is out of range for dimension 1, or the list is
        shorter than the batch.
    TypeError
        If ``indices_to_remove`` is neither an int nor a list.
    """
    # bool is a subclass of int; keep it out of the int branch so True/False
    # are not silently treated as positions 1/0
    if isinstance(indices_to_remove, int) and not isinstance(indices_to_remove, bool):
        indices_to_keep = list(range(emb_batch.size()[1]))
        indices_to_keep.pop(indices_to_remove)
        # one indexing op over the position axis instead of a per-cell loop
        return emb_batch[:, indices_to_keep, :]

    if isinstance(indices_to_remove, list):
        # make_comparison_batch is defined elsewhere in this module
        return torch.stack([
            make_comparison_batch(emb, indices_to_remove[i])
            for i, emb in enumerate(emb_batch)
        ])

    # previously fell through to an UnboundLocalError at the return statement
    raise TypeError(
        "indices_to_remove must be an int or a list, "
        f"got {type(indices_to_remove).__name__}"
    )
|
| 154 |
+
|
| 155 |
# average embedding position of goal cell states
|
| 156 |
def get_cell_state_avg_embs(model,
|
| 157 |
filtered_input_data,
|
|
|
|
| 200 |
|
| 201 |
# quantify cosine similarity of perturbed vs original or alternate states
|
| 202 |
def quant_cos_sims(model,
|
| 203 |
+
perturb_type,
|
| 204 |
perturbation_batch,
|
| 205 |
forward_batch_size,
|
| 206 |
layer_to_quant,
|
|
|
|
| 239 |
minibatch_emb = outputs.hidden_states[layer_to_quant]
|
| 240 |
if cell_states_to_model is None:
|
| 241 |
minibatch_comparison = comparison_batch[i:max_range]
|
| 242 |
+
if perturb_type == "overexpress":
|
| 243 |
+
index_to_remove = 0
|
| 244 |
+
minibatch_emb = make_perturbed_remainder_batch(minibatch_emb, index_to_remove)
|
| 245 |
+
# elif (perturb_type == "inhibit") or (perturb_type == "activate"):
|
| 246 |
+
# index_to_remove = placeholder
|
| 247 |
+
# minibatch_emb = make_perturbed_remainder_batch(minibatch_emb, index_to_remove)
|
| 248 |
cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
|
| 249 |
+
elif cell_states_to_model is not None:
|
| 250 |
for state in possible_states:
|
| 251 |
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_emb, minibatch_emb, state_embs_dict[state])
|
| 252 |
del outputs
|
|
|
|
| 298 |
class InSilicoPerturber:
|
| 299 |
valid_option_dict = {
|
| 300 |
"perturb_type": {"delete","overexpress","inhibit","activate"},
|
| 301 |
+
"perturb_rank_shift": {None, 1, 2, 3},
|
| 302 |
"genes_to_perturb": {"all", list},
|
| 303 |
+
"combos": {0, 1, 2},
|
| 304 |
"anchor_gene": {None, str},
|
| 305 |
"model_type": {"Pretrained","GeneClassifier","CellClassifier"},
|
| 306 |
"num_classes": {int},
|
|
|
|
| 345 |
"overexpress": move gene to front of rank value encoding
|
| 346 |
"inhibit": move gene to lower quartile of rank value encoding
|
| 347 |
"activate": move gene to higher quartile of rank value encoding
|
| 348 |
+
perturb_rank_shift : None, {1,2,3}
|
| 349 |
Number of quartiles by which to shift rank of gene.
|
| 350 |
For example, if perturb_type="activate" and perturb_rank_shift=1:
|
| 351 |
genes in 4th quartile will move to middle of 3rd quartile.
|
|
|
|
| 433 |
self.tokens_to_perturb = [self.gene_token_dict[gene] for gene in self.genes_to_perturb]
|
| 434 |
|
| 435 |
def validate_options(self):
|
| 436 |
+
# first disallow options under development
|
| 437 |
+
if self.perturb_type in ["inhibit", "activate"]:
|
| 438 |
+
logger.error(
|
| 439 |
+
f"In silico inhibition and activation currently under developemnt. " \
|
| 440 |
+
f"Current valid options for 'perturb_type': 'delete' or 'overexpress'"
|
| 441 |
+
)
|
| 442 |
+
raise
|
| 443 |
+
|
| 444 |
+
# confirm arguments are within valid options and compatible with each other
|
| 445 |
for attr_name,valid_options in self.valid_option_dict.items():
|
| 446 |
attr_value = self.__dict__[attr_name]
|
| 447 |
if type(attr_value) not in {list, dict}:
|
|
|
|
| 470 |
elif self.perturb_type == "overexpress":
|
| 471 |
logger.warning(
|
| 472 |
"perturb_rank_shift set to None. " \
|
| 473 |
+
"If perturb type is overexpress then gene is moved to front " \
|
| 474 |
"of rank value encoding rather than shifted by quartile")
|
| 475 |
self.perturb_rank_shift = None
|
| 476 |
|
|
|
|
| 654 |
combo_lvl,
|
| 655 |
self.nproc)
|
| 656 |
cos_sims_data = quant_cos_sims(model,
|
| 657 |
+
self.perturb_type,
|
| 658 |
+
perturbation_batch,
|
| 659 |
+
self.forward_batch_size,
|
| 660 |
+
layer_to_quant,
|
| 661 |
+
original_emb,
|
| 662 |
+
indices_to_perturb,
|
| 663 |
+
self.cell_states_to_model,
|
| 664 |
+
state_embs_dict)
|
| 665 |
|
| 666 |
if self.cell_states_to_model is None:
|
| 667 |
# update cos sims dict
|
|
|
|
| 728 |
0,
|
| 729 |
self.nproc)
|
| 730 |
cos_sims_data = quant_cos_sims(model,
|
| 731 |
+
self.perturb_type,
|
| 732 |
perturbation_batch,
|
| 733 |
self.forward_batch_size,
|
| 734 |
layer_to_quant,
|
|
|
|
| 745 |
1,
|
| 746 |
self.nproc)
|
| 747 |
combo_cos_sims_data = quant_cos_sims(model,
|
| 748 |
+
self.perturb_type,
|
| 749 |
combo_perturbation_batch,
|
| 750 |
self.forward_batch_size,
|
| 751 |
layer_to_quant,
|