Fix TypeError in make_perturbation_batch_special
Browse filesSummary:
`make_perturbation_batch_special` failed with
```
TypeError: unsupported operand type(s) for *: 'Column' and 'int'
```
because example_cell["input_ids"] was a Hugging Face Column.
This is a bug for `datasets>=4.0.0`
Fix:
Convert example_cell to a Python dict (same as in make_perturbation_batch) and explicitly build the repeated input list instead of multiplying the column.
geneformer/perturber_utils.py
CHANGED
|
@@ -508,6 +508,11 @@ def make_perturbation_batch(
|
|
| 508 |
def make_perturbation_batch_special(
|
| 509 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 510 |
) -> tuple[Dataset, List[int]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 512 |
if perturb_type in ["overexpress", "activate"]:
|
| 513 |
range_start = 1
|
|
|
|
| 508 |
def make_perturbation_batch_special(
|
| 509 |
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
| 510 |
) -> tuple[Dataset, List[int]]:
|
| 511 |
+
|
| 512 |
+
# For datasets>=4.0.0, convert to dict to avoid format issues
|
| 513 |
+
if int(datasets.__version__.split(".")[0]) >= 4:
|
| 514 |
+
example_cell = example_cell[:]
|
| 515 |
+
|
| 516 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
| 517 |
if perturb_type in ["overexpress", "activate"]:
|
| 518 |
range_start = 1
|