tusharagashe committed on
Commit
926b9a1
·
verified ·
1 Parent(s): 1930588

Add fix for new datasets version to make_perturbation_batch_special

Browse files

Added the same fix used in make_perturbation_batch() to make_perturbation_batch_special() to account for the new version of the datasets library (datasets>=4.0.0).

Files changed (1) hide show
  1. geneformer/perturber_utils.py +5 -0
geneformer/perturber_utils.py CHANGED
@@ -508,6 +508,11 @@ def make_perturbation_batch(
508
  def make_perturbation_batch_special(
509
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
510
  ) -> tuple[Dataset, List[int]]:
 
 
 
 
 
511
  if combo_lvl == 0 and tokens_to_perturb == "all":
512
  if perturb_type in ["overexpress", "activate"]:
513
  range_start = 1
 
508
  def make_perturbation_batch_special(
509
  example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
510
  ) -> tuple[Dataset, List[int]]:
511
+
512
+ # For datasets>=4.0.0, convert to dict to avoid format issues
513
+ if int(datasets.__version__.split(".")[0]) >= 4:
514
+ example_cell = example_cell[:]
515
+
516
  if combo_lvl == 0 and tokens_to_perturb == "all":
517
  if perturb_type in ["overexpress", "activate"]:
518
  range_start = 1