Add fix for new datasets version to make_perturbation_batch_special
#562
by
tusharagashe
- opened
geneformer/perturber_utils.py
CHANGED
|
@@ -508,6 +508,11 @@ def make_perturbation_batch(
|
|
| 508 |
def make_perturbation_batch_special(
|
| 509 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 510 |
) -> tuple[Dataset, List[int]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 512 |
if perturb_type in ["overexpress", "activate"]:
|
| 513 |
range_start = 1
|
|
|
|
| 508 |
def make_perturbation_batch_special(
|
| 509 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 510 |
) -> tuple[Dataset, List[int]]:
|
| 511 |
+
|
| 512 |
+
# For datasets>=4.0.0, convert to dict to avoid format issues
|
| 513 |
+
if int(datasets.__version__.split(".")[0]) >= 4:
|
| 514 |
+
example_cell = example_cell[:]
|
| 515 |
+
|
| 516 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 517 |
if perturb_type in ["overexpress", "activate"]:
|
| 518 |
range_start = 1
|