Add fix for new `datasets` version (>=4.0.0) to `make_perturbation_batch_special`

#562
Files changed (1) hide show
  1. geneformer/perturber_utils.py +5 -0
geneformer/perturber_utils.py CHANGED
@@ -508,6 +508,11 @@ def make_perturbation_batch(
508
  def make_perturbation_batch_special(
509
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
510
  ) -> tuple[Dataset, List[int]]:
 
 
 
 
 
511
  if combo_lvl == 0 and tokens_to_perturb == "all":
512
  if perturb_type in ["overexpress", "activate"]:
513
  range_start = 1
 
508
  def make_perturbation_batch_special(
509
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
510
  ) -> tuple[Dataset, List[int]]:
511
+
512
+ # For datasets>=4.0.0, slice with [:] to convert example_cell to a dict and avoid format issues
513
+ if int(datasets.__version__.split(".")[0]) >= 4:
514
+ example_cell = example_cell[:]
515
+
516
  if combo_lvl == 0 and tokens_to_perturb == "all":
517
  if perturb_type in ["overexpress", "activate"]:
518
  range_start = 1