Finetuning vs Pretraining Speed
Hi Geneformer Developers,
Thank you again for your work and continued maintenance of the project. I have a general question regarding finetuning speed in Geneformer.
The problem I currently face:
- I have a dataset of roughly 150k samples, with a median of around 4000 expressed genes per cell, on which I want to finetune Geneformer for multi-task classification.
- To do so, I preprocessed the data according to the given tutorials and train the 104M model with distributed training enabled across 5 H200 GPUs.
- I found that training a single epoch (batch size 16) takes a little more than 1 hour, which seems very long once I account for Optuna optimization and higher epoch counts.
- In the newest paper on bioRxiv, your team mentioned that pretraining the 12-layer model (which I presume is similar or identical to the 104M model) for 3 epochs on ~90 million data points took ~44 hours on 8 H100s.
My questions are:
- Is this speed difference between finetuning and pretraining to be expected under the current code?
- If this is to be expected, would it be fair to say that it is attributable to the use of custom dynamic padding, `length_grouped_sampling`, and DeepSpeed during pretraining? (The paper notes a 29.4x speedup from the padding and length grouping during pretraining, which doesn't seem to apply in the same way to finetuning.)
- Are there any reasons not to use `length_grouped_sampling` for finetuning? (What I have in mind is the standard Hugging Face flag, as in the sketch below.)
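For concreteness, here is my assumption of what the paper's length grouping corresponds to; I haven't verified this against the Geneformer internals, and the output path is a placeholder:

```python
# A minimal sketch of what I mean by length grouped sampling: the
# group_by_length flag of the Hugging Face Trainer swaps in a
# LengthGroupedSampler that batches similar-length sequences together,
# so dynamic padding wastes far fewer tokens per batch.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="finetune_out",    # hypothetical output path
    per_device_train_batch_size=16,
    group_by_length=True,         # enable length-grouped sampling
    length_column_name="length",  # dataset column with sequence lengths
)
```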
Thank you for your time and patience.
Thank you for your question! The model in the bioRxiv paper is an intermediate size that is smaller than the currently provided 104M model. Length grouped sampling is faster, but comes at the cost of potentially reducing data randomization during training if specific cell types tend to have fewer genes detected and therefore shorter lengths. The benefit of length grouping is dataset-dependent: it helps more when a dataset spans many different lengths than when the lengths are similar to begin with. Other ways to reduce training time include quantization and freezing some of the layers of the model. Of note, we typically only fine-tune for 1 epoch to avoid the potential for memorizing the training data.
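For layer freezing, a minimal sketch with a BERT-style Hugging Face model (Geneformer is BERT-based, but the exact model class, checkpoint path, and attribute names depend on how you load it, so treat the names below as an illustration rather than a fixed recipe):

```python
# Illustrative sketch of freezing the embeddings and the lower encoder
# layers of a BERT-style model to cut fine-tuning time; gradients are
# then only computed for the upper layers and the task head.
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained(
    "path/to/geneformer_checkpoint",  # hypothetical checkpoint path
    num_labels=3,
)

n_frozen = 6  # e.g., freeze the lower half of a 12-layer encoder

# Freeze the input embeddings.
for param in model.bert.embeddings.parameters():
    param.requires_grad = False

# Freeze the first n_frozen transformer layers; only the upper layers
# and the classification head stay trainable.
for layer in model.bert.encoder.layer[:n_frozen]:
    for param in layer.parameters():
        param.requires_grad = False
```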
Thank you for the quick and informative response; I suppose it will always be a trade-off between speed and model quality depending on the data. Quantization also seems to be implemented only for the classifier class, but not yet for multi-task classification. I've enabled mixed-precision training for a speedup of ~2x for the time being (albeit with a very slight drop in quality).
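For reference, a sketch of the mixed-precision setting I enabled, assuming the standard TrainingArguments interface is what sits under the multi-task classifier (the output path is a placeholder):

```python
# Sketch of enabling mixed precision via Hugging Face TrainingArguments;
# bf16 suits H100/H200-class GPUs, while fp16=True is the usual
# fallback on older hardware.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="finetune_out",  # hypothetical output path
    per_device_train_batch_size=16,
    num_train_epochs=1,
    bf16=True,                  # bfloat16 mixed-precision training
)
```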