problem with get_train_sampler()

#572
by scsrsao - opened

Hi...
I was trying to follow the code in your pretrainer_geneformer_w_deepspeed.py to test training Geneformer from scratch, but when I reach the call to trainer.train() (line 164), I get an error raised from inside the Trainer class:

1137 if self.train_dataset is None:
1138 raise ValueError("Trainer: training requires a train_dataset.")
...
-> 1109 dataloader_params["sampler"] = sampler_fn(dataset)
1110 dataloader_params["drop_last"] = self.args.dataloader_drop_last
1111 dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

TypeError: GeneformerPretrainer._get_train_sampler() takes 1 positional argument but 2 were given

As far as I understand, this happens because _get_train_sampler(self), defined inside GeneformerPretrainer (line 607 of pretrainer.py), takes no arguments besides self, whilst the code inside Trainer is trying to pass the dataset as a parameter (as shown above).
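The mismatch can be reproduced in isolation, without transformers or Geneformer: calling a bound method that accepts only self with one extra argument produces exactly this TypeError (class and method names here just mirror the ones in the traceback).

```python
class Pretrainer:
    # Old-style signature: accepts only self, like
    # GeneformerPretrainer._get_train_sampler in pretrainer.py.
    def _get_train_sampler(self):
        return "sampler"


trainer = Pretrainer()

# Newer Trainer versions call sampler_fn(dataset), i.e. they pass the
# dataset as an extra positional argument to the bound method.
try:
    trainer._get_train_sampler("dataset")
    raised = False
except TypeError as err:
    raised = True
    message = str(err)  # "... takes 1 positional argument but 2 were given"
```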

Do you have any guidance on how to fix this?

Thank you for your question. It's possible that the Trainer has changed compared to the version we used for pretraining, as the current model was pretrained about a year ago. You could try an older version of transformers (e.g. 4.48.2) or modify the GeneformerPretrainer to account for this discrepancy.
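For the second option, one hedged sketch is to give _get_train_sampler an optional positional parameter, so it works whether the Trainer calls it with or without the dataset. The parameter name train_dataset is an assumption based on the traceback, and the body below is a placeholder, not Geneformer's actual length-grouped sampler logic:

```python
from typing import Optional


class GeneformerPretrainer:  # stands in for the real subclass of transformers.Trainer
    def __init__(self, train_dataset=None):
        self.train_dataset = train_dataset

    # Accept an optional dataset so that both older Trainer versions
    # (which call self._get_train_sampler()) and newer ones (which call
    # self._get_train_sampler(dataset)) resolve correctly.
    def _get_train_sampler(self, train_dataset: Optional[object] = None):
        dataset = train_dataset if train_dataset is not None else self.train_dataset
        # Placeholder: the real method would build and return a sampler here.
        return list(range(len(dataset)))
```

With this change the method falls back to self.train_dataset when no argument is passed, matching the old behavior.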

ctheodoris changed discussion status to closed
