Batch Processing
Hi there,
I noticed that this encoder processes samples one at a time.
How can we process samples in batches to speed up training when fine-tuning it for a downstream task?
Hi! The batch processing issue stems from the model's architecture requiring a feature_lens parameter for variable-length audio inputs.
The core problem:
- The model uses split_with_sizes internally to handle different audio lengths in a batch
- This causes dimension mismatches when the calculated feature lengths don't perfectly align with the model's output tensors
- Unlike Whisper (which uses fixed 30-second chunks), this model was designed for flexible input lengths, making batch processing more complex
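To illustrate the failure mode, here is a pure-Python stand-in for `torch.split_with_sizes`. It splits a flat, concatenated sequence of encoder frames back into per-sample chunks and raises when the requested sizes don't add up to the number of frames, which is what happens when the computed feature lengths disagree with the actual output. The downsampling formula in `output_len` (two stride-2 convolutions, kernel 3, padding 1) is a hypothetical assumption for illustration, not the model's real code.

```python
from typing import List, Sequence


def output_len(n_frames: int, n_convs: int = 2) -> int:
    """Hypothetical length after n_convs stride-2 convs (kernel 3, padding 1)."""
    for _ in range(n_convs):
        n_frames = (n_frames - 1) // 2 + 1
    return n_frames


def split_with_sizes(flat: Sequence[float], sizes: List[int]) -> List[List[float]]:
    """Pure-Python stand-in for torch.split_with_sizes along dim 0."""
    if sum(sizes) != len(flat):
        # The batched encoder fails here when predicted lengths are off by even one frame.
        raise RuntimeError(
            f"split sizes sum to {sum(sizes)} but the input has {len(flat)} frames"
        )
    chunks, start = [], 0
    for size in sizes:
        chunks.append(list(flat[start:start + size]))
        start += size
    return chunks


sizes = [output_len(n) for n in [100, 37]]  # predicted per-sample lengths: [25, 10]
chunks = split_with_sizes([0.0] * 35, sizes)  # works: 25 + 10 == 35
```

If the encoder actually emitted 36 frames (say, because its real padding rule differs from the assumed formula), the same call would raise.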
Current workaround:
The community has implemented a fallback strategy: try batch processing first and, if it fails, automatically process samples one by one. While not ideal for training speed, it ensures stability.
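A minimal sketch of that fallback, assuming the batched encoder signals the length mismatch with a `RuntimeError` (PyTorch raises this type for shape errors). `encode_batch` is a hypothetical wrapper around the encoder, not an existing API.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def encode_with_fallback(
    samples: Sequence[T],
    encode_batch: Callable[[Sequence[T]], List[R]],
) -> List[R]:
    """Try one batched call; on a shape error, fall back to one call per sample.

    `encode_batch` (hypothetical) maps a list of inputs to a list of outputs,
    one output per input.
    """
    try:
        return encode_batch(samples)
    except RuntimeError:
        # Same encoder with batch size 1: slower, but avoids the length mismatch.
        results: List[R] = []
        for sample in samples:
            results.extend(encode_batch([sample]))
        return results
```

Catching only `RuntimeError` keeps unrelated bugs (e.g. a `TypeError` in preprocessing) from being silently hidden behind the slow path.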
For fine-tuning:
You might want to:
- Pad all audio to the same length (like Whisper does) to avoid the feature_lens issue
- Use the sequential processing with gradient accumulation to simulate larger batch sizes
- Wait for an official fix from the transformers team, though there is no guarantee one will come.
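The gradient-accumulation option works because scaling each per-sample loss by 1/N and summing the gradients gives the same result as the gradient of the mean loss over the full batch. The toy below checks this with a one-parameter linear model and a hand-written MSE gradient; the model and data are made-up stand-ins, not the actual fine-tuning setup.

```python
from typing import List, Tuple


def grad_mse(w: float, x: float, y: float) -> float:
    """d/dw of (w*x - y)**2 for a one-parameter linear model."""
    return 2.0 * (w * x - y) * x


def batch_grad(w: float, data: List[Tuple[float, float]]) -> float:
    """Gradient of the mean loss over the whole batch in one pass."""
    return sum(grad_mse(w, x, y) for x, y in data) / len(data)


def accumulated_grad(w: float, data: List[Tuple[float, float]]) -> float:
    """Process samples one at a time, scaling each contribution by 1/N."""
    n = len(data)
    acc = 0.0
    for x, y in data:
        acc += grad_mse(w, x, y) / n  # analogous to (loss / n).backward() per sample
    return acc


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0)]
full = batch_grad(0.5, data)
accum = accumulated_grad(0.5, data)
```

In PyTorch the same pattern is calling `(loss / accum_steps).backward()` for each sample and only running `optimizer.step()` and `optimizer.zero_grad()` every `accum_steps` samples.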
The fundamental issue is in the transformers library implementation rather than the model itself. The model architecture supports batching in theory, but the current implementation has this limitation.
Thank you so much for your information!
Initially, I processed audio samples individually through the encoder and manually stacked several encoded representations to form a batch, which was then fed into the classifier head. While this approach guaranteed functional correctness, it was suboptimal in terms of training efficiency.
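The per-sample approach described above can be sketched as: encode each clip on its own, reduce the variable-length output to a fixed-size vector, and stack those vectors into a batch for the classifier head. Mean pooling over time and the `encode_one` callable are hypothetical assumptions here; the post does not say how the representations were made stackable.

```python
from typing import Callable, List, Sequence

Frames = List[List[float]]  # [time, hidden]


def mean_pool(frames: Frames) -> List[float]:
    """Average the encoder frames over time, giving one vector per clip (assumed pooling)."""
    t = len(frames)
    return [sum(col) / t for col in zip(*frames)]


def encode_then_stack(
    clips: Sequence[object],
    encode_one: Callable[[object], Frames],
) -> List[List[float]]:
    """Run the (hypothetical) encoder on each clip separately, pool, and stack."""
    return [mean_pool(encode_one(clip)) for clip in clips]
```

Each clip costs one encoder call, which is why this is correct but slow: the GPU never sees more than one sequence at a time.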
After further investigation, I implemented dynamic padding so that all audio sequences in a batch share the same temporal length. I also introduced feature_lens to compute cu_seqlens, which FlashAttention-2 uses to delineate sequence boundaries and prevent cross-sample information leakage. This modification enabled proper batched processing.
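A minimal sketch of the two pieces described above: dynamic padding to the longest sequence in the batch, and turning `feature_lens` into the cumulative offsets (`cu_seqlens`) plus maximum length that FlashAttention-2's variable-length kernels (e.g. `flash_attn_varlen_func`) take to mark where each sample starts and ends. It uses plain nested lists; whether `feature_lens` should be the raw or the post-downsampling lengths depends on where attention runs in the encoder, which is an assumption left open here.

```python
from itertools import accumulate
from typing import List, Tuple

Seq = List[List[float]]  # [time, feat]


def pad_batch(
    seqs: List[Seq], pad_value: float = 0.0
) -> Tuple[List[Seq], List[List[bool]], List[int]]:
    """Pad every sequence to the longest one in this batch (dynamic padding).

    Returns padded sequences, a boolean mask (True = real frame) and feature_lens.
    """
    feature_lens = [len(s) for s in seqs]
    max_len = max(feature_lens)
    feat_dim = len(seqs[0][0])
    padded, mask = [], []
    for s in seqs:
        n_pad = max_len - len(s)
        padded.append([list(f) for f in s] + [[pad_value] * feat_dim for _ in range(n_pad)])
        mask.append([True] * len(s) + [False] * n_pad)
    return padded, mask, feature_lens


def cu_seqlens_from_lens(feature_lens: List[int]) -> Tuple[List[int], int]:
    """Offsets [0, l0, l0+l1, ...] into the packed (padding-free) token axis, plus max length."""
    return [0] + list(accumulate(feature_lens)), max(feature_lens)


padded, mask, lens = pad_batch([[[1.0]] * 3, [[2.0]] * 5, [[3.0]] * 2])
cu_seqlens, max_seqlen = cu_seqlens_from_lens(lens)  # [0, 3, 8, 10], 5
```

With these offsets, sample `i` occupies tokens `cu_seqlens[i]:cu_seqlens[i + 1]` of the packed tensor, so attention never crosses from one clip into another.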
However, a new issue emerged:
The encoder’s outputs were inconsistent between single-sample and batched inference, with discrepancies observed even across different batch sizes.
This phenomenon has been previously reported and is attributed to numerical instability introduced by lower-precision floating-point formats such as bfloat16. These formats can cause minor rounding errors during matrix multiplications, which accumulate through the model’s layers, resulting in noticeable output divergence between single and batch inference. Switching to float32 precision resolves this issue.
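The rounding effect can be reproduced without a GPU. The sketch below emulates bfloat16 by rounding a float32 bit pattern to its upper 16 bits (round to nearest, ties to even; NaN and inf are not handled) and sums the same three numbers in two different orders. That batch size changes the reduction order inside the encoder's kernels is an assumption, but the example shows how identical inputs can produce different bf16 results while float64 agrees.

```python
import struct
from typing import Iterable


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even). NaN/inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    lsb = (bits >> 16) & 1                               # lowest mantissa bit that survives
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000            # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_sum(values: Iterable[float]) -> float:
    """Accumulate left to right, rounding to bfloat16 after every addition."""
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + to_bf16(v))
    return acc


small = 2.0 ** -8  # exactly half a bfloat16 step at 1.0
order_a = [1.0, small, small]
order_b = [small, small, 1.0]

print(bf16_sum(order_a))           # 1.0  (each small term is rounded away)
print(bf16_sum(order_b))           # 1.0078125
print(sum(order_a), sum(order_b))  # 1.0078125 1.0078125
```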
Unfortunately, FlashAttention-2 only supports half-precision inputs (fp16 and bf16), not float32, so this inconsistency remains unresolved for now.
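To put a number on the single-versus-batched discrepancy, one can compare each sample's single-sample output with its slice of the batched output, ignoring padded frames. A sketch on plain nested lists (tensors would first be converted, e.g. with `.tolist()`); the function name and the choice of max absolute difference as the metric are illustrative assumptions.

```python
from typing import List

Frames = List[List[float]]  # [time, hidden]


def max_abs_diff_single_vs_batch(
    single_outputs: List[Frames],
    batched_output: List[Frames],
    feature_lens: List[int],
) -> float:
    """Largest |single - batched| over the valid (unpadded) frames of every sample."""
    worst = 0.0
    for single, padded, n in zip(single_outputs, batched_output, feature_lens):
        for s_frame, b_frame in zip(single[:n], padded[:n]):  # padding rows are skipped
            for s, b in zip(s_frame, b_frame):
                worst = max(worst, abs(s - b))
    return worst
```

Running this for several batch sizes, once in bf16 and once in float32, makes it easy to check whether the divergence really comes from precision rather than from a masking bug.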
This is a classic issue. As early as the BERT/T5 era, it was observed that fp16 leads to noticeable differences in embeddings between single and batched inputs, whereas the differences are much smaller with fp32 [https://discuss.huggingface.co/t/large-max-differences-between-single-input-processing-and-batching-with-bert-and-t5/5767]. Given that AuT itself is not large and the runtime I measured is acceptable, perhaps we should stop using FlashAttention-2.
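Without FlashAttention-2, batching can still be done with a padded batch plus a key padding mask, which is how standard (eager or SDPA-style) attention handles variable lengths. The toy below is single-head scaled dot-product attention in pure Python (float64) showing that masking the padded keys reproduces the unpadded result for the valid frames. It illustrates the masking idea only and is not the AuT attention code.

```python
import math
from typing import List

Vec = List[float]


def masked_attention(q: List[Vec], k: List[Vec], v: List[Vec], valid_len: int) -> List[Vec]:
    """Single-head softmax(q k^T / sqrt(d)) v, ignoring keys at positions >= valid_len."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [
            sum(a * b for a, b in zip(qi, kj)) * scale if j < valid_len else -math.inf
            for j, kj in enumerate(k)
        ]
        m = max(scores)                              # subtract max for numerical stability
        weights = [math.exp(s - m) for s in scores]  # exp(-inf) == 0.0, so padding gets no weight
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[0.5, 0.2], [0.1, 0.9]]
v = [[1.0, 2.0], [3.0, 4.0]]
pad = [[0.0, 0.0]] * 2

unpadded = masked_attention(q, k, v, valid_len=2)
padded = masked_attention(q + pad, k + pad, v + pad, valid_len=2)[:2]  # drop padded query rows
```

In transformers, this choice is usually made through the `attn_implementation` argument of `from_pretrained` (e.g. `"sdpa"` or `"eager"` instead of `"flash_attention_2"`) together with `torch_dtype=torch.float32`; whether the AuT encoder respects that setting would need to be checked.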