Upload in_silico_perturber.py #187
opened by davidjwen
geneformer/in_silico_perturber.py
CHANGED
@@ -396,19 +396,22 @@ def quant_cos_sims(model,
             original_minibatch = original_emb.select([i for i in range(i, max_range)])
             original_minibatch_lengths = original_minibatch["length"]
             original_minibatch_length_set = set(original_minibatch["length"])
+
+            indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
+
             if perturb_type == "overexpress":
                 new_max_len = model_input_size - len(tokens_to_perturb)
             else:
                 new_max_len = model_input_size
             if (len(original_minibatch_length_set) > 1) or (max(original_minibatch_length_set) > new_max_len):
-
+                new_max_len = min(max(original_minibatch_length_set),new_max_len)
                 def pad_or_trunc_example(example):
-                    example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id,
+                    example["input_ids"] = pad_or_truncate_encoding(example["input_ids"], pad_token_id, new_max_len)
                     return example
                 original_minibatch = original_minibatch.map(pad_or_trunc_example, num_proc=nproc)
             original_minibatch.set_format(type="torch")
             original_input_data_minibatch = original_minibatch["input_ids"]
-            attention_mask = gen_attention_mask(original_minibatch,
+            attention_mask = gen_attention_mask(original_minibatch, new_max_len)
             # extract embeddings for original minibatch
             with torch.no_grad():
                 original_outputs = model(

@@ -429,7 +432,7 @@ def quant_cos_sims(model,
             # exclude overexpression due to case when genes are not expressed but being overexpressed
             if perturb_type != "overexpress":
                 original_minibatch_emb = remove_indices_from_emb_batch(original_minibatch_emb,
-
+                                                                       indices_to_perturb_minibatch,
                                                                        gene_dim)

             # cosine similarity between original emb and batch items

@@ -438,7 +441,7 @@ def quant_cos_sims(model,
                 minibatch_comparison = comparison_batch[i:max_range]
             elif perturb_group == True:
                 minibatch_comparison = make_comparison_batch(original_minibatch_emb,
-
+                                                             indices_to_perturb_minibatch,
                                                              perturb_group)

             cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
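The substantive change is that indices_to_perturb is now sliced with the same window used to select the minibatch, so row k of original_minibatch_emb is paired with its own perturbed positions when passed to remove_indices_from_emb_batch and make_comparison_batch. Below is a minimal sketch of that alignment with hypothetical data and a simplified stand-in for the embedding-trimming helper; only the variable names come from the diff.

    # Sketch: slicing indices_to_perturb per minibatch keeps row k of each
    # minibatch aligned with its own perturbed positions. The data and the
    # list-comprehension stand-in for remove_indices_from_emb_batch are
    # hypothetical.
    forward_batch_size = 2
    indices_to_perturb = [[0], [3], [1], [2], [5]]  # one list of positions per cell
    embs = [list(range(6)) for _ in indices_to_perturb]  # toy "embeddings"

    for i in range(0, len(embs), forward_batch_size):
        max_range = min(i + forward_batch_size, len(embs))
        minibatch_emb = embs[i:max_range]
        # the fix: slice with the same window instead of passing the full list
        indices_to_perturb_minibatch = indices_to_perturb[i:i + forward_batch_size]
        trimmed = [
            [v for j, v in enumerate(row) if j not in idxs]
            for row, idxs in zip(minibatch_emb, indices_to_perturb_minibatch)
        ]
        print(i, trimmed)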
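The other change caps new_max_len at the longest sequence actually present in the minibatch before padding or truncating each encoding to that length, and passes the same new_max_len to gen_attention_mask so the mask matches the padded inputs. A sketch of what those two helpers appear to do, with bodies inferred from the call sites (the real ones live elsewhere in Geneformer, and gen_attention_mask takes the minibatch dataset rather than raw lengths):

    import torch

    def pad_or_truncate_encoding(input_ids, pad_token_id, max_len):
        # assumed behavior: right-pad short encodings, truncate long ones,
        # so every example ends up exactly max_len tokens long
        if len(input_ids) > max_len:
            return input_ids[:max_len]
        return input_ids + [pad_token_id] * (max_len - len(input_ids))

    def gen_attention_mask(lengths, max_len):
        # assumed behavior: 1 over real tokens, 0 over padding
        return torch.tensor(
            [[1] * min(l, max_len) + [0] * max(0, max_len - l) for l in lengths]
        )

    pad_token_id = 0
    model_input_size = 4  # toy context size
    lengths = [3, 5, 2]
    new_max_len = min(max(lengths), model_input_size)  # the capping added here
    batch = [[7, 8, 9], [1, 2, 3, 4, 5], [6, 6]]
    padded = [pad_or_truncate_encoding(ids, pad_token_id, new_max_len) for ids in batch]
    mask = gen_attention_mask(lengths, new_max_len)
    print(padded)  # [[7, 8, 9, 0], [1, 2, 3, 4], [6, 6, 0, 0]]
    print(mask)    # tensor([[1, 1, 1, 0], [1, 1, 1, 1], [1, 1, 0, 0]])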
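For context on the final line of the diff, a small sketch of the cosine-similarity step, assuming cos is something like torch.nn.CosineSimilarity(dim=2) over embeddings shaped (cells, genes, hidden); the shapes here are illustrative, not the model's actual dimensions:

    import torch

    cos = torch.nn.CosineSimilarity(dim=2)
    hidden = 8
    minibatch_emb = torch.randn(2, 5, hidden)         # perturbed-cell embeddings
    minibatch_comparison = torch.randn(2, 5, hidden)  # matched original embeddings
    sims = cos(minibatch_emb, minibatch_comparison).to("cpu")
    print(sims.shape)  # torch.Size([2, 5]): one similarity per cell per gene position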