Update geneformer/perturber_utils.py
#361
by
hchen725
- opened
- geneformer/perturber_utils.py +44 -13
geneformer/perturber_utils.py
CHANGED
|
@@ -228,21 +228,41 @@ def overexpress_indices(example):
|
|
| 228 |
example["length"] = len(example["input_ids"])
|
| 229 |
return example
|
| 230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
| 233 |
-
def overexpress_tokens(example,
|
|
|
|
| 234 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
| 235 |
if example["perturb_index"] != [-100]:
|
| 236 |
example = delete_indices(example)
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
# truncate to max input size, must also truncate original emb to be comparable
|
| 243 |
-
if len(example["input_ids"]) >
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
| 246 |
example["length"] = len(example["input_ids"])
|
| 247 |
return example
|
| 248 |
|
|
@@ -259,6 +279,12 @@ def truncate_by_n_overflow(example):
|
|
| 259 |
example["length"] = len(example["input_ids"])
|
| 260 |
return example
|
| 261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
| 264 |
# indices_to_remove is list of indices to remove
|
|
@@ -321,7 +347,7 @@ def remove_perturbed_indices_set(
|
|
| 321 |
|
| 322 |
|
| 323 |
def make_perturbation_batch(
|
| 324 |
-
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 325 |
) -> tuple[Dataset, List[int]]:
|
| 326 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 327 |
if perturb_type in ["overexpress", "activate"]:
|
|
@@ -383,9 +409,14 @@ def make_perturbation_batch(
|
|
| 383 |
delete_indices, num_proc=num_proc_i
|
| 384 |
)
|
| 385 |
elif perturb_type == "overexpress":
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
|
| 390 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
| 391 |
|
|
@@ -758,4 +789,4 @@ class GeneIdHandler:
|
|
| 758 |
return self.ens_to_symbol(self.token_to_ens(token))
|
| 759 |
|
| 760 |
def symbol_to_token(self, symbol):
|
| 761 |
-
return self.ens_to_token(self.symbol_to_ens(symbol))
|
|
|
|
| 228 |
example["length"] = len(example["input_ids"])
|
| 229 |
return example
|
| 230 |
|
| 231 |
+
# if a CLS token is present, move the perturbed tokens to the 1st position (right after CLS) rather than the 0th
|
| 232 |
+
def overexpress_indices_special(example):
    """Move the tokens at ``example["perturb_index"]`` to position 1 onwards.

    Intended for encodings that begin with a CLS token: the selected tokens
    are placed immediately after position 0 instead of at position 0. When
    several positions are given, the moved tokens keep their original
    relative order.

    Args:
        example: mapping with ``"input_ids"`` (a mutable list of token ids)
            and ``"perturb_index"`` (a list of positions, optionally nested
            one level, in which case it is flattened with ``flatten_list``).

    Returns:
        The same ``example`` object, with ``"input_ids"`` reordered in place
        and ``"length"`` set to the current number of input ids.
    """
    indices = example["perturb_index"]
    if any(isinstance(el, list) for el in indices):
        indices = flatten_list(indices)

    input_ids = example["input_ids"]
    # Read every selected token before mutating the list: popping and
    # re-inserting one position at a time shifts the remaining (lower)
    # positions by one, which would move the wrong tokens whenever more
    # than one index is supplied.
    positions = sorted(set(indices))
    moved = [input_ids[position] for position in positions]
    # Delete from the highest position down so pending positions stay valid.
    for position in reversed(positions):
        del input_ids[position]
    input_ids[1:1] = moved

    example["length"] = len(input_ids)
    return example
|
| 241 |
|
| 242 |
# for genes_to_perturb given as a list of genes to overexpress that are not necessarily expressed in the cell
|
| 243 |
+
def overexpress_tokens(example, special_token):
    """Insert ``example["tokens_to_perturb"]`` at the front of the encoding.

    The tokens are inserted in their given order at position 1 when
    ``special_token`` is truthy (so position 0 is left untouched) and at
    position 0 otherwise. The result is then cut back to the length the
    example had on entry.

    Args:
        example: mapping with ``"input_ids"`` (mutable list of token ids),
            ``"length"`` (length on entry), ``"perturb_index"`` and
            ``"tokens_to_perturb"``.
        special_token: whether the encoding carries special tokens at its
            first and last positions that must stay in place.

    Returns:
        The updated example, with ``"length"`` refreshed.
    """
    original_len = example["length"]
    # -100 indicates tokens to overexpress are not present in rank value encoding
    if example["perturb_index"] != [-100]:
        # NOTE(review): delete_indices is defined elsewhere in this file;
        # presumably it drops the tokens at perturb_index - confirm.
        example = delete_indices(example)

    # Insert the tokens directly. No index-based move is applied here:
    # perturb_index is either the [-100] sentinel or refers to positions that
    # were recorded before the deletion above and are no longer valid.
    insert_at = 1 if special_token else 0
    example["input_ids"][insert_at:insert_at] = list(example["tokens_to_perturb"])

    # truncate to max input size, must also truncate original emb to be comparable
    if len(example["input_ids"]) > original_len:
        if special_token:
            # Keep the final token and drop as many tokens before it as
            # needed; removing a single element is not enough when more than
            # one token was inserted.
            keep = max(original_len - 1, 0)
            example["input_ids"] = example["input_ids"][:keep] + [
                example["input_ids"][-1]
            ]
        else:
            example["input_ids"] = example["input_ids"][:original_len]
    example["length"] = len(example["input_ids"])
    return example
|
| 268 |
|
|
|
|
| 279 |
example["length"] = len(example["input_ids"])
|
| 280 |
return example
|
| 281 |
|
| 282 |
+
def truncate_by_n_overflow_special(example):
    """Shorten ``example["input_ids"]`` by ``example["n_overflow"]`` tokens.

    The target length is ``example["length"] - example["n_overflow"]``. The
    final token of the current sequence is always kept as the last element;
    the tokens dropped are the ones immediately before it.
    (Presumably the final token is a special end-of-sequence token - confirm.)

    Args:
        example: mapping with ``"input_ids"``, ``"length"`` and
            ``"n_overflow"``.

    Returns:
        The same ``example`` with ``"input_ids"`` replaced by the shortened
        list and ``"length"`` refreshed.
    """
    token_ids = example["input_ids"]
    target_len = example["length"] - example["n_overflow"]
    # Reserve one slot of the target length for the trailing token.
    leading = token_ids[: target_len - 1]
    example["input_ids"] = leading + [token_ids[-1]]
    example["length"] = len(example["input_ids"])
    return example
|
| 287 |
+
|
| 288 |
|
| 289 |
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
| 290 |
# indices_to_remove is list of indices to remove
|
|
|
|
| 347 |
|
| 348 |
|
| 349 |
def make_perturbation_batch(
|
| 350 |
+
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc, special_token
|
| 351 |
) -> tuple[Dataset, List[int]]:
|
| 352 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 353 |
if perturb_type in ["overexpress", "activate"]:
|
|
|
|
| 409 |
delete_indices, num_proc=num_proc_i
|
| 410 |
)
|
| 411 |
elif perturb_type == "overexpress":
|
| 412 |
+
if special_token:
|
| 413 |
+
perturbation_dataset = perturbation_dataset.map(
|
| 414 |
+
overexpress_indices_special, num_proc=num_proc_i
|
| 415 |
+
)
|
| 416 |
+
else:
|
| 417 |
+
perturbation_dataset = perturbation_dataset.map(
|
| 418 |
+
overexpress_indices, num_proc=num_proc_i
|
| 419 |
+
)
|
| 420 |
|
| 421 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
| 422 |
|
|
|
|
| 789 |
return self.ens_to_symbol(self.token_to_ens(token))
|
| 790 |
|
| 791 |
def symbol_to_token(self, symbol):
    """Convert ``symbol`` to a token by chaining two lookups.

    The symbol is first passed to ``self.symbol_to_ens``; that result is then
    passed to ``self.ens_to_token``, whose return value is returned as-is.
    """
    ens = self.symbol_to_ens(symbol)
    return self.ens_to_token(ens)
|