Fixed error with perturbing individual genes and updated ways to specify cell_states_to_model
Browse files- geneformer/in_silico_perturber.py +138 -34
geneformer/in_silico_perturber.py
CHANGED
|
@@ -105,6 +105,12 @@ def downsample_and_sort(data_shuffled, max_ncells):
|
|
| 105 |
data_sorted = data_subset.sort("length",reverse=True)
|
| 106 |
return data_sorted
|
| 107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
def forward_pass_single_cell(model, example_cell, layer_to_quant):
|
| 109 |
example_cell.set_format(type="torch")
|
| 110 |
input_data = example_cell["input_ids"]
|
|
@@ -235,13 +241,15 @@ def get_cell_state_avg_embs(model,
|
|
| 235 |
num_proc):
|
| 236 |
|
| 237 |
model_input_size = get_model_input_size(model)
|
| 238 |
-
possible_states =
|
| 239 |
state_embs_dict = dict()
|
| 240 |
for possible_state in possible_states:
|
| 241 |
state_embs_list = []
|
|
|
|
| 242 |
|
| 243 |
def filter_states(example):
|
| 244 |
-
|
|
|
|
| 245 |
filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
|
| 246 |
total_batch_length = len(filtered_input_data_state)
|
| 247 |
if ((total_batch_length-1)/forward_batch_size).is_integer():
|
|
@@ -254,6 +262,7 @@ def get_cell_state_avg_embs(model,
|
|
| 254 |
state_minibatch.set_format(type="torch")
|
| 255 |
|
| 256 |
input_data_minibatch = state_minibatch["input_ids"]
|
|
|
|
| 257 |
input_data_minibatch = pad_tensor_list(input_data_minibatch,
|
| 258 |
max_len,
|
| 259 |
pad_token_id,
|
|
@@ -271,8 +280,12 @@ def get_cell_state_avg_embs(model,
|
|
| 271 |
del input_data_minibatch
|
| 272 |
del state_embs_i
|
| 273 |
torch.cuda.empty_cache()
|
| 274 |
-
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
state_embs_dict[possible_state] = avg_state_emb
|
| 277 |
return state_embs_dict
|
| 278 |
|
|
@@ -291,7 +304,6 @@ def quant_cos_sims(model,
|
|
| 291 |
pad_token_id,
|
| 292 |
model_input_size,
|
| 293 |
nproc):
|
| 294 |
-
|
| 295 |
cos = torch.nn.CosineSimilarity(dim=2)
|
| 296 |
total_batch_length = len(perturbation_batch)
|
| 297 |
if ((total_batch_length-1)/forward_batch_size).is_integer():
|
|
@@ -301,7 +313,7 @@ def quant_cos_sims(model,
|
|
| 301 |
comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
|
| 302 |
cos_sims = []
|
| 303 |
else:
|
| 304 |
-
possible_states =
|
| 305 |
cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
|
| 306 |
|
| 307 |
# measure length of each element in perturbation_batch
|
|
@@ -316,6 +328,7 @@ def quant_cos_sims(model,
|
|
| 316 |
|
| 317 |
# determine if need to pad or truncate batch
|
| 318 |
minibatch_length_set = set(perturbation_minibatch["length"])
|
|
|
|
| 319 |
if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
|
| 320 |
needs_pad_or_trunc = True
|
| 321 |
else:
|
|
@@ -360,6 +373,7 @@ def quant_cos_sims(model,
|
|
| 360 |
# truncate to the (model input size - # tokens to overexpress) to ensure comparability
|
| 361 |
# since max input size of perturb batch will be reduced by # tokens to overexpress
|
| 362 |
original_minibatch = original_emb.select([i for i in range(i, max_range)])
|
|
|
|
| 363 |
original_minibatch_length_set = set(original_minibatch["length"])
|
| 364 |
if perturb_type == "overexpress":
|
| 365 |
new_max_len = model_input_size - len(tokens_to_perturb)
|
|
@@ -385,7 +399,32 @@ def quant_cos_sims(model,
|
|
| 385 |
original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
|
| 386 |
else:
|
| 387 |
original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
# cosine similarity between original emb and batch items
|
| 390 |
if cell_states_to_model is None:
|
| 391 |
if perturb_group == False:
|
|
@@ -406,7 +445,9 @@ def quant_cos_sims(model,
|
|
| 406 |
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
|
| 407 |
minibatch_emb,
|
| 408 |
state_embs_dict[state],
|
| 409 |
-
perturb_group
|
|
|
|
|
|
|
| 410 |
del outputs
|
| 411 |
del minibatch_emb
|
| 412 |
if cell_states_to_model is None:
|
|
@@ -421,14 +462,40 @@ def quant_cos_sims(model,
|
|
| 421 |
return cos_sims_vs_alt_dict
|
| 422 |
|
| 423 |
# calculate cos sim shift of perturbation with respect to origin and alternative cell
|
| 424 |
-
def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group):
|
| 425 |
cos = torch.nn.CosineSimilarity(dim=2)
|
| 426 |
-
|
| 427 |
-
|
| 428 |
original_emb = original_emb[None, :]
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
return [(perturb_v_end-origin_v_end).to("cpu")]
|
| 433 |
|
| 434 |
def pad_list(input_ids, pad_token_id, max_len):
|
|
@@ -706,6 +773,12 @@ class InSilicoPerturber:
|
|
| 706 |
|
| 707 |
if self.cell_states_to_model is not None:
|
| 708 |
if len(self.cell_states_to_model.items()) == 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 709 |
for key,value in self.cell_states_to_model.items():
|
| 710 |
if (len(value) == 3) and isinstance(value, tuple):
|
| 711 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
|
@@ -713,14 +786,48 @@ class InSilicoPerturber:
|
|
| 713 |
all_values = value[0]+value[1]+value[2]
|
| 714 |
if len(all_values) == len(set(all_values)):
|
| 715 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 716 |
else:
|
| 717 |
logger.error(
|
| 718 |
-
"
|
| 719 |
-
"
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
|
|
|
|
|
|
|
|
|
| 723 |
raise
|
|
|
|
| 724 |
if self.anchor_gene is not None:
|
| 725 |
self.anchor_gene = None
|
| 726 |
logger.warning(
|
|
@@ -770,6 +877,13 @@ class InSilicoPerturber:
|
|
| 770 |
if self.cell_states_to_model is None:
|
| 771 |
state_embs_dict = None
|
| 772 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
# get dictionary of average cell state embeddings for comparison
|
| 774 |
downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
| 775 |
state_embs_dict = get_cell_state_avg_embs(model,
|
|
@@ -780,9 +894,9 @@ class InSilicoPerturber:
|
|
| 780 |
self.forward_batch_size,
|
| 781 |
self.nproc)
|
| 782 |
# filter for start state cells
|
| 783 |
-
start_state =
|
| 784 |
def filter_for_origin(example):
|
| 785 |
-
return example[
|
| 786 |
|
| 787 |
filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
|
| 788 |
|
|
@@ -878,7 +992,6 @@ class InSilicoPerturber:
|
|
| 878 |
# or (perturbed_genes, "cell_emb") for avg cell emb change
|
| 879 |
cos_sims_data = cos_sims_data.to("cuda")
|
| 880 |
max_padded_len = cos_sims_data.shape[1]
|
| 881 |
-
|
| 882 |
for j in range(cos_sims_data.shape[0]):
|
| 883 |
# remove padding before mean pooling cell embedding
|
| 884 |
original_length = original_lengths[j]
|
|
@@ -900,21 +1013,13 @@ class InSilicoPerturber:
|
|
| 900 |
# update cos sims dict
|
| 901 |
# key is tuple of (perturbed_genes, "cell_emb")
|
| 902 |
# value is list of tuples of cos sims for cell_states_to_model
|
| 903 |
-
origin_state_key =
|
| 904 |
cos_sims_origin = cos_sims_data[origin_state_key]
|
| 905 |
for j in range(cos_sims_origin.shape[0]):
|
| 906 |
-
original_length = original_lengths[j]
|
| 907 |
-
max_padded_len = cos_sims_origin.shape[1]
|
| 908 |
-
indices_removed = indices_to_perturb[j]
|
| 909 |
-
padding_to_remove = max_padded_len - (original_length \
|
| 910 |
-
- len(self.tokens_to_perturb) \
|
| 911 |
-
- len(indices_removed))
|
| 912 |
data_list = []
|
| 913 |
for data in list(cos_sims_data.values()):
|
| 914 |
data_item = data.to("cuda")
|
| 915 |
-
|
| 916 |
-
cell_data = torch.mean(nonpadding_data_item).item()
|
| 917 |
-
data_list += [cell_data]
|
| 918 |
cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
|
| 919 |
|
| 920 |
with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
|
|
@@ -987,7 +1092,7 @@ class InSilicoPerturber:
|
|
| 987 |
# update cos sims dict
|
| 988 |
# key is tuple of (perturbed_gene, "cell_emb")
|
| 989 |
# value is list of tuples of cos sims for cell_states_to_model
|
| 990 |
-
origin_state_key =
|
| 991 |
cos_sims_origin = cos_sims_data[origin_state_key]
|
| 992 |
|
| 993 |
for j in range(cos_sims_origin.shape[0]):
|
|
@@ -1109,4 +1214,3 @@ class InSilicoPerturber:
|
|
| 1109 |
# save remainder cells
|
| 1110 |
with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
|
| 1111 |
pickle.dump(cos_sims_dict, fp)
|
| 1112 |
-
|
|
|
|
| 105 |
data_sorted = data_subset.sort("length",reverse=True)
|
| 106 |
return data_sorted
|
| 107 |
|
| 108 |
+
def get_possible_states(cell_states_to_model):
    """Return the states to model as one flat list: start, goal, then alternates.

    Args:
        cell_states_to_model: dict with the keys "state_key", "start_state",
            "goal_state" and "alt_states". "alt_states" is either None or a
            list of additional states.

    Returns:
        ``[start_state, goal_state]`` followed by the entries of
        ``alt_states`` when it is not None.

    Raises:
        KeyError: if "start_state" or "goal_state" is missing.
    """
    # Look entries up by key instead of by position in ``dict.values()``:
    # a dictionary holding the same four keys in a different insertion order
    # would otherwise return the wrong states without any error.
    possible_states = [
        cell_states_to_model["start_state"],
        cell_states_to_model["goal_state"],
    ]
    alt_states = cell_states_to_model.get("alt_states")
    if alt_states is None:
        return possible_states
    # list + list keeps the original contract (a non-list alt_states raises TypeError)
    return possible_states + alt_states
|
| 113 |
+
|
| 114 |
def forward_pass_single_cell(model, example_cell, layer_to_quant):
|
| 115 |
example_cell.set_format(type="torch")
|
| 116 |
input_data = example_cell["input_ids"]
|
|
|
|
| 241 |
num_proc):
|
| 242 |
|
| 243 |
model_input_size = get_model_input_size(model)
|
| 244 |
+
possible_states = get_possible_states(cell_states_to_model)
|
| 245 |
state_embs_dict = dict()
|
| 246 |
for possible_state in possible_states:
|
| 247 |
state_embs_list = []
|
| 248 |
+
original_lens = []
|
| 249 |
|
| 250 |
def filter_states(example):
|
| 251 |
+
state_key = cell_states_to_model["state_key"]
|
| 252 |
+
return example[state_key] in possible_state
|
| 253 |
filtered_input_data_state = filtered_input_data.filter(filter_states, num_proc=num_proc)
|
| 254 |
total_batch_length = len(filtered_input_data_state)
|
| 255 |
if ((total_batch_length-1)/forward_batch_size).is_integer():
|
|
|
|
| 262 |
state_minibatch.set_format(type="torch")
|
| 263 |
|
| 264 |
input_data_minibatch = state_minibatch["input_ids"]
|
| 265 |
+
original_lens += [tensor.numel() for tensor in input_data_minibatch]
|
| 266 |
input_data_minibatch = pad_tensor_list(input_data_minibatch,
|
| 267 |
max_len,
|
| 268 |
pad_token_id,
|
|
|
|
| 280 |
del input_data_minibatch
|
| 281 |
del state_embs_i
|
| 282 |
torch.cuda.empty_cache()
|
| 283 |
+
|
| 284 |
+
# import here to avoid circular imports
|
| 285 |
+
from .emb_extractor import mean_nonpadding_embs
|
| 286 |
+
state_embs = torch.cat(state_embs_list)
|
| 287 |
+
avg_state_emb = mean_nonpadding_embs(state_embs, torch.Tensor(original_lens).to("cuda"))
|
| 288 |
+
avg_state_emb = torch.mean(avg_state_emb, dim=0, keepdim=True)
|
| 289 |
state_embs_dict[possible_state] = avg_state_emb
|
| 290 |
return state_embs_dict
|
| 291 |
|
|
|
|
| 304 |
pad_token_id,
|
| 305 |
model_input_size,
|
| 306 |
nproc):
|
|
|
|
| 307 |
cos = torch.nn.CosineSimilarity(dim=2)
|
| 308 |
total_batch_length = len(perturbation_batch)
|
| 309 |
if ((total_batch_length-1)/forward_batch_size).is_integer():
|
|
|
|
| 313 |
comparison_batch = make_comparison_batch(original_emb, indices_to_perturb, perturb_group)
|
| 314 |
cos_sims = []
|
| 315 |
else:
|
| 316 |
+
possible_states = get_possible_states(cell_states_to_model)
|
| 317 |
cos_sims_vs_alt_dict = dict(zip(possible_states,[[] for i in range(len(possible_states))]))
|
| 318 |
|
| 319 |
# measure length of each element in perturbation_batch
|
|
|
|
| 328 |
|
| 329 |
# determine if need to pad or truncate batch
|
| 330 |
minibatch_length_set = set(perturbation_minibatch["length"])
|
| 331 |
+
minibatch_lengths = perturbation_minibatch["length"]
|
| 332 |
if (len(minibatch_length_set) > 1) or (max(minibatch_length_set) > model_input_size):
|
| 333 |
needs_pad_or_trunc = True
|
| 334 |
else:
|
|
|
|
| 373 |
# truncate to the (model input size - # tokens to overexpress) to ensure comparability
|
| 374 |
# since max input size of perturb batch will be reduced by # tokens to overexpress
|
| 375 |
original_minibatch = original_emb.select([i for i in range(i, max_range)])
|
| 376 |
+
original_minibatch_lengths = original_minibatch["length"]
|
| 377 |
original_minibatch_length_set = set(original_minibatch["length"])
|
| 378 |
if perturb_type == "overexpress":
|
| 379 |
new_max_len = model_input_size - len(tokens_to_perturb)
|
|
|
|
| 399 |
original_minibatch_emb = torch.squeeze(original_outputs.hidden_states[layer_to_quant])
|
| 400 |
else:
|
| 401 |
original_minibatch_emb = original_outputs.hidden_states[layer_to_quant]
|
| 402 |
+
|
| 403 |
+
# remove perturbed index before calculating the cos sims
|
| 404 |
+
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
    """Return ``emb`` with the given positions dropped along one axis.

    Args:
        emb: tensor for a single item; the caller slices one item out of the
            batch before passing it in.
        indices_to_remove: list of indices to remove along the gene axis.
        gene_dim: index of the gene axis in the *batched* tensor. It is
            reduced by one here because the batch axis is already gone.

    Returns:
        A tensor holding only the positions of ``emb`` along that axis that
        are not listed in ``indices_to_remove``.
    """
    gene_dim -= 1 # removing a dim in calling the function
    indices_to_keep = [i for i in range(emb.size()[gene_dim]) if i not in indices_to_remove]
    # Build the index as a tuple: indexing with a plain list of slices/lists is
    # a deprecated non-tuple form that newer releases reinterpret as a single
    # tensor index, which would change the result or raise.
    emb_slice = tuple(
        indices_to_keep if dim == gene_dim else slice(None)
        for dim in range(emb.dim())
    )
    return emb[emb_slice]
|
| 412 |
+
|
| 413 |
+
# this could probably be optimized
|
| 414 |
+
gene_dim = 1
|
| 415 |
+
|
| 416 |
+
# currently, if a gene is not expressed and is being overexpressed,
|
| 417 |
+
# the dimensions will be thrown off --> not removing indices to get around that issue
|
| 418 |
+
# not sure what's the best way to handle it
|
| 419 |
+
if perturb_type != "overexpress":
|
| 420 |
+
original_minibatch_emb = torch.stack([
|
| 421 |
+
remove_indices_from_emb(original_minibatch_emb[i, :, :], idx, gene_dim) for
|
| 422 |
+
i, idx in enumerate(indices_to_perturb)
|
| 423 |
+
])
|
| 424 |
+
|
| 425 |
+
# do the averaging here
|
| 426 |
+
|
| 427 |
+
|
| 428 |
# cosine similarity between original emb and batch items
|
| 429 |
if cell_states_to_model is None:
|
| 430 |
if perturb_group == False:
|
|
|
|
| 445 |
cos_sims_vs_alt_dict[state] += cos_sim_shift(original_minibatch_emb,
|
| 446 |
minibatch_emb,
|
| 447 |
state_embs_dict[state],
|
| 448 |
+
perturb_group,
|
| 449 |
+
torch.tensor(original_minibatch_lengths, device="cuda"),
|
| 450 |
+
torch.tensor(minibatch_lengths, device="cuda"))
|
| 451 |
del outputs
|
| 452 |
del minibatch_emb
|
| 453 |
if cell_states_to_model is None:
|
|
|
|
| 462 |
return cos_sims_vs_alt_dict
|
| 463 |
|
| 464 |
# calculate cos sim shift of perturbation with respect to origin and alternative cell
|
| 465 |
+
def cos_sim_shift(original_emb, minibatch_emb, alt_emb, perturb_group, original_minibatch_lengths = None, minibatch_lengths = None,):
    """Return how far a perturbation moves cells toward a target state embedding.

    The result is cos(perturbed, alt_emb) - cos(original, alt_emb), so a
    positive value means the perturbed embedding is closer to ``alt_emb``.

    Args:
        original_emb: embedding(s) of the unperturbed cell(s).
            NOTE(review): shapes are not checked here; the call sites suggest
            (genes, dim) when ``perturb_group`` is false and
            (batch, genes, dim) otherwise. Confirm against the caller.
        minibatch_emb: embeddings of the perturbed cells; pooled over dim 1.
        alt_emb: embedding of the state to compare against.
        perturb_group: selects how ``original_emb`` is pooled (see above).
        original_minibatch_lengths: optional lengths passed to
            ``mean_nonpadding_embs`` for ``original_emb``. Only used when
            ``perturb_group`` is true.
        minibatch_lengths: optional lengths passed to
            ``mean_nonpadding_embs`` for ``minibatch_emb``.

    Returns:
        A one-element list holding the cosine-similarity shift tensor, moved
        to the CPU.

    Raises:
        RuntimeError: if ``perturb_group`` is true and the two embeddings
            differ in size.
    """
    cos = torch.nn.CosineSimilarity(dim=2)
    if not perturb_group:
        # pool over dim 0, then add a leading axis so cos(dim=2) applies
        original_emb = torch.mean(original_emb,dim=0,keepdim=True)
        original_emb = original_emb[None, :]
        origin_v_end = torch.squeeze(cos(original_emb, alt_emb))
    else:
        if original_emb.size() != minibatch_emb.size():
            message = (
                f"Embeddings are not the same dimensions. "
                f"original_emb is {original_emb.size()}. "
                f"minibatch_emb is {minibatch_emb.size()}. "
            )
            logger.error(message)
            # A bare ``raise`` outside an ``except`` block surfaces as
            # "RuntimeError: No active exception to reraise". Keep the
            # exception type but carry the actual reason.
            raise RuntimeError(message)

        original_emb = _mean_gene_embs(original_emb, original_minibatch_lengths)

        alt_emb = torch.unsqueeze(alt_emb, 1)
        origin_v_end = cos(original_emb, alt_emb)
        origin_v_end = torch.squeeze(origin_v_end)

    perturb_emb = _mean_gene_embs(minibatch_emb, minibatch_lengths)

    perturb_v_end = cos(perturb_emb, alt_emb)
    perturb_v_end = torch.squeeze(perturb_v_end)

    return [(perturb_v_end-origin_v_end).to("cpu")]


def _mean_gene_embs(embs, lengths):
    """Pool ``embs`` over dim 1, delegating to ``mean_nonpadding_embs`` when lengths are given."""
    if lengths is None:
        return torch.mean(embs,dim=1,keepdim=True)
    # import here to avoid circular imports
    from .emb_extractor import mean_nonpadding_embs
    # presumably averages only the non-padding positions (per its name) - confirm in emb_extractor
    return mean_nonpadding_embs(embs, lengths)
|
| 500 |
|
| 501 |
def pad_list(input_ids, pad_token_id, max_len):
|
|
|
|
| 773 |
|
| 774 |
if self.cell_states_to_model is not None:
|
| 775 |
if len(self.cell_states_to_model.items()) == 1:
|
| 776 |
+
logger.warning(
|
| 777 |
+
"The single value dictionary for cell_states_to_model will be " \
|
| 778 |
+
"replaced with explicitly modeling start and end states. " \
|
| 779 |
+
"Please specify state_key, start_state, end_state, and alt_states " \
|
| 780 |
+
"in the cell_states_to_model dictionary for future use."
|
| 781 |
+
)
|
| 782 |
for key,value in self.cell_states_to_model.items():
|
| 783 |
if (len(value) == 3) and isinstance(value, tuple):
|
| 784 |
if isinstance(value[0],list) and isinstance(value[1],list) and isinstance(value[2],list):
|
|
|
|
| 786 |
all_values = value[0]+value[1]+value[2]
|
| 787 |
if len(all_values) == len(set(all_values)):
|
| 788 |
continue
|
| 789 |
+
# reformat to the new format
|
| 790 |
+
state_values = flatten_list(list(self.cell_states_to_model.values()))
|
| 791 |
+
self.cell_states_to_model = {
|
| 792 |
+
"state_key": list(self.cell_states_to_model.keys())[0],
|
| 793 |
+
"start_state": state_values[0][0],
|
| 794 |
+
"goal_state": state_values[1][0],
|
| 795 |
+
"alt_states": state_values[2:][0]
|
| 796 |
+
}
|
| 797 |
+
elif set(self.cell_states_to_model.keys()) == {"state_key", "start_state", "goal_state", "alt_states"}:
|
| 798 |
+
if self.cell_states_to_model["start_state"] is None or self.cell_states_to_model["goal_state"] is None:
|
| 799 |
+
logger.error(
|
| 800 |
+
"Please specify 'start_state' and 'goal_state' in cell_states_to_model.")
|
| 801 |
+
raise
|
| 802 |
+
|
| 803 |
+
if self.cell_states_to_model["start_state"] == self.cell_states_to_model["goal_state"]:
|
| 804 |
+
logger.error(
|
| 805 |
+
"All states must be unique.")
|
| 806 |
+
raise
|
| 807 |
+
|
| 808 |
+
if self.cell_states_to_model["alt_states"] is not None:
|
| 809 |
+
if type(self.cell_states_to_model["alt_states"]) is not list:
|
| 810 |
+
logger.error(
|
| 811 |
+
"self.cell_states_to_model['alt_states'] must be a list (even if it is one element)."
|
| 812 |
+
)
|
| 813 |
+
raise
|
| 814 |
+
if len(self.cell_states_to_model["alt_states"])!= len(set(self.cell_states_to_model["alt_states"])):
|
| 815 |
+
logger.error(
|
| 816 |
+
"All states must be unique.")
|
| 817 |
+
raise
|
| 818 |
+
|
| 819 |
else:
|
| 820 |
logger.error(
|
| 821 |
+
"states_to_model must only have the following four keys: 'state_key', 'start_state', 'goal_state', 'alt_states'." \
|
| 822 |
+
"For example, cell_states_to_model={ \
|
| 823 |
+
'state_key': 'disease', \
|
| 824 |
+
'start_state': 'dcm', \
|
| 825 |
+
'goal_state': 'nf'', \
|
| 826 |
+
'alt_states': ['hcm', 'other1', 'other2'] \
|
| 827 |
+
}"
|
| 828 |
+
)
|
| 829 |
raise
|
| 830 |
+
|
| 831 |
if self.anchor_gene is not None:
|
| 832 |
self.anchor_gene = None
|
| 833 |
logger.warning(
|
|
|
|
| 877 |
if self.cell_states_to_model is None:
|
| 878 |
state_embs_dict = None
|
| 879 |
else:
|
| 880 |
+
# make sure that all states are valid; save time on filtering
|
| 881 |
+
state_name = self.cell_states_to_model["state_key"]
|
| 882 |
+
for value in get_possible_states(self.cell_states_to_model):
|
| 883 |
+
if value not in filtered_input_data[state_name]:
|
| 884 |
+
logger.error(
|
| 885 |
+
f"{value} is not a valid value in {state_name}.")
|
| 886 |
+
raise
|
| 887 |
# get dictionary of average cell state embeddings for comparison
|
| 888 |
downsampled_data = downsample_and_sort(filtered_input_data, self.max_ncells)
|
| 889 |
state_embs_dict = get_cell_state_avg_embs(model,
|
|
|
|
| 894 |
self.forward_batch_size,
|
| 895 |
self.nproc)
|
| 896 |
# filter for start state cells
|
| 897 |
+
start_state = self.cell_states_to_model["start_state"]
|
| 898 |
def filter_for_origin(example):
|
| 899 |
+
return example[state_name] in [start_state]
|
| 900 |
|
| 901 |
filtered_input_data = filtered_input_data.filter(filter_for_origin, num_proc=self.nproc)
|
| 902 |
|
|
|
|
| 992 |
# or (perturbed_genes, "cell_emb") for avg cell emb change
|
| 993 |
cos_sims_data = cos_sims_data.to("cuda")
|
| 994 |
max_padded_len = cos_sims_data.shape[1]
|
|
|
|
| 995 |
for j in range(cos_sims_data.shape[0]):
|
| 996 |
# remove padding before mean pooling cell embedding
|
| 997 |
original_length = original_lengths[j]
|
|
|
|
| 1013 |
# update cos sims dict
|
| 1014 |
# key is tuple of (perturbed_genes, "cell_emb")
|
| 1015 |
# value is list of tuples of cos sims for cell_states_to_model
|
| 1016 |
+
origin_state_key = self.cell_states_to_model["start_state"]
|
| 1017 |
cos_sims_origin = cos_sims_data[origin_state_key]
|
| 1018 |
for j in range(cos_sims_origin.shape[0]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1019 |
data_list = []
|
| 1020 |
for data in list(cos_sims_data.values()):
|
| 1021 |
data_item = data.to("cuda")
|
| 1022 |
+
data_list += [data_item]
|
|
|
|
|
|
|
| 1023 |
cos_sims_dict[(perturbed_genes, "cell_emb")] += [tuple(data_list)]
|
| 1024 |
|
| 1025 |
with open(f"{output_path_prefix}_raw.pickle", "wb") as fp:
|
|
|
|
| 1092 |
# update cos sims dict
|
| 1093 |
# key is tuple of (perturbed_gene, "cell_emb")
|
| 1094 |
# value is list of tuples of cos sims for cell_states_to_model
|
| 1095 |
+
origin_state_key = self.cell_states_to_model["start_state"]
|
| 1096 |
cos_sims_origin = cos_sims_data[origin_state_key]
|
| 1097 |
|
| 1098 |
for j in range(cos_sims_origin.shape[0]):
|
|
|
|
| 1214 |
# save remainder cells
|
| 1215 |
with open(f"{output_path_prefix}{pickle_batch}_raw.pickle", "wb") as fp:
|
| 1216 |
pickle.dump(cos_sims_dict, fp)
|
|
|