so298 committed
Commit bf55247 · verified · 1 Parent(s): 7ec90d6

Fix TypeError in make_perturbation_batch_special

Summary:

`make_perturbation_batch_special` failed with
```
TypeError: unsupported operand type(s) for *: 'Column' and 'int'
```
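because `example_cell["input_ids"]` was a Hugging Face `datasets` `Column` rather than a Python list.
This affects `datasets>=4.0.0`, where indexing a `Dataset` by column name returns a `Column` object.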

Fix:
Convert `example_cell` to a Python dict with `example_cell[:]` (as in `make_perturbation_batch`), so that `example_cell["input_ids"]` is a plain Python list that supports repetition with `*` instead of a `Column`.
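
The failure mode and the fix can be illustrated without installing `datasets`. The sketch below uses a hypothetical stand-in class, `ColumnLike`, that is iterable but does not define `*`. It is not the real `datasets` `Column` class, and `repeat_rows` is likewise an illustrative helper, not a function from Geneformer.

```python
from typing import Iterable, Iterator, List


class ColumnLike:
    """Hypothetical stand-in for a lazy column: iterable, but no `*` support."""

    def __init__(self, rows: List[List[int]]) -> None:
        self._rows = rows

    def __iter__(self) -> Iterator[List[int]]:
        return iter(self._rows)

    def __len__(self) -> int:
        return len(self._rows)


def repeat_rows(column: Iterable[List[int]], n: int) -> List[List[int]]:
    """Materialize the column as a plain list, then repeat it n times."""
    return list(column) * n


column = ColumnLike([[101, 7, 42]])

# Mirrors the failing pattern: a column-like object multiplied by an int.
try:
    column * 3
    failed = False
except TypeError:
    failed = True

# Converting to a plain list first makes the repetition well defined.
repeated = repeat_rows(column, 3)
```

In the commit itself the conversion happens once, up front, via `example_cell[:]`, so later code that indexes `example_cell["input_ids"]` receives ordinary Python lists.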

Files changed (1)
  1. geneformer/perturber_utils.py +5 -0
geneformer/perturber_utils.py CHANGED
@@ -508,6 +508,11 @@ def make_perturbation_batch(
 def make_perturbation_batch_special(
     example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
 ) -> tuple[Dataset, List[int]]:
+
+    # For datasets>=4.0.0, convert to dict to avoid format issues
+    if int(datasets.__version__.split(".")[0]) >= 4:
+        example_cell = example_cell[:]
+
     if combo_lvl == 0 and tokens_to_perturb == "all":
         if perturb_type in ["overexpress", "activate"]:
             range_start = 1
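
The version gate in this change compares the first dot-separated component of `datasets.__version__` as an integer. A minimal sketch of the same check as standalone helpers follows. The names `major_version` and `needs_dict_conversion` are hypothetical and not part of Geneformer or `datasets`, and the helpers take a plain string so they can be run without the library installed.

```python
def major_version(version: str) -> int:
    """Return the leading integer component of a dotted version string."""
    return int(version.split(".")[0])


def needs_dict_conversion(version: str) -> bool:
    """True for 4.x and later, matching the gate used in the diff."""
    return major_version(version) >= 4


# Comparing integers rather than strings keeps multi-digit majors correct:
# the string "10" sorts before "4", but the integer 10 does not.
checks = {v: needs_dict_conversion(v) for v in ("3.6.0", "4.0.0", "4.1.0.dev0", "10.2")}
```

Parsing only the major component sidesteps pre-release suffixes such as `.dev0`, which appear later in the string and would break a naive conversion of every component to `int`.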