# Advanced Optimizations Phase 4: FlashAttention & Beyond

This document outlines the next level of optimizations beyond what we've already implemented, targeting an additional 2-5x speedup and better training stability.
## New Optimizations Overview

### High-Impact Optimizations
- FlashAttention - 2-4x faster attention, 50% memory reduction
- FSDP (Fully Sharded Data Parallel) - Train models that don't fit on single GPU
- BF16 (bfloat16) - Better than FP16 for training stability
- Gradient Clipping - Prevent gradient explosion
- Learning Rate Finder - Automatically find optimal LR
- Automatic Batch Size Finder - Maximize GPU utilization
- TensorRT Optimization - 5-10x faster production inference
- QAT (Quantization Aware Training) - Better INT8 quantization
- Sequence Parallelism - Handle very long sequences
- Selective Activation Recompute - Advanced memory optimization
## 1. FlashAttention

Impact: 2-4x faster attention, 50% memory reduction

Why: DA3 uses Vision Transformers with attention mechanisms. FlashAttention computes attention in tiles, avoiding materializing the full attention matrix.
Implementation:

```python
# Install: pip install flash-attn
from ylff.utils.flash_attention import FlashAttentionWrapper, check_flash_attention_available

# Check availability
if check_flash_attention_available():
    # Use FlashAttention in the model.
    # Note: this requires model-specific integration;
    # DA3's attention lives in DinoV2, so the model code must be modified.
    pass
```
Challenges:
- DA3 uses custom attention in DinoV2 (alternating local/global)
- Requires modifying model source code or creating wrappers
- FlashAttention may not support all attention patterns
Status: ⏳ Utility created, requires model integration
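Until FlashAttention is wired into DinoV2, a portable middle ground is PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which dispatches to a FlashAttention-style kernel on supported GPUs and falls back to an efficient implementation elsewhere. The sketch below is illustrative only; the `FusedSelfAttention` module and its shapes are assumptions, not part of the ylff codebase or of DA3:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedSelfAttention(nn.Module):
    """Illustrative self-attention block built on PyTorch's fused SDPA.

    On CUDA with a supported dtype/shape, F.scaled_dot_product_attention
    dispatches to a FlashAttention-style kernel; otherwise it falls back
    to a memory-efficient or math implementation.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        h = self.num_heads
        # (b, n, 3d) -> (3, b, h, n, d/h), then unpack q, k, v
        q, k, v = self.qkv(x).view(b, n, 3, h, d // h).permute(2, 0, 3, 1, 4)
        # Fused attention: the full (n x n) attention matrix is never materialized
        out = F.scaled_dot_product_attention(q, k, v)
        return self.proj(out.transpose(1, 2).reshape(b, n, d))

attn = FusedSelfAttention(dim=64, num_heads=4)
y = attn(torch.randn(2, 16, 64))
```

This covers standard dense attention only; DA3's alternating local/global pattern would still need custom handling.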
## 2. FSDP (Fully Sharded Data Parallel)

Impact: Train models that exceed single-GPU memory

Why: FSDP shards parameters, gradients, and optimizer states across GPUs, allowing training of very large models.
Implementation:

```python
from ylff.utils.fsdp_utils import wrap_model_fsdp

# Wrap model with FSDP
model = wrap_model_fsdp(
    model,
    sharding_strategy="FULL_SHARD",   # Most memory efficient
    mixed_precision="bf16",           # Use BF16
    auto_wrap_policy="transformer",   # Auto-wrap transformer blocks
)
```
Benefits:
- Train models 2-4x larger than single GPU memory
- Better memory efficiency than DDP
- Works with mixed precision
Status: ✅ Implemented
## 3. BF16 (bfloat16) Support

Impact: Better training stability than FP16, at the same speed

Why: BF16 has the same exponent range as FP32, preventing the underflow issues that FP16 can hit.
Implementation:
from ylff.utils.training_utils import get_bf16_autocast_context, enable_bf16_training
# Option 1: Use BF16 autocast (recommended)
with get_bf16_autocast_context(enable=True):
output = model(inputs)
loss = loss_fn(output, targets)
# Option 2: Convert model to BF16
model = enable_bf16_training(model)
Benefits:
- More stable than FP16
- Same speed as FP16
- Better for training large models
Status: ✅ Implemented
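For reference, the autocast wrapper above can be approximated with PyTorch's built-in `torch.autocast`; this is a sketch of the underlying mechanism, not the ylff API. CPU autocast is shown so it runs anywhere; on GPU you would pass `device_type="cuda"`:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 4)
inputs = torch.randn(2, 8)

# Under BF16 autocast, matmul-heavy ops (e.g. nn.Linear) run in bfloat16
# while numerically sensitive ops stay in float32.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    output = model(inputs)
```

Because BF16 keeps FP32's 8-bit exponent, no loss scaling is needed, unlike FP16 AMP.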
## 4. Gradient Clipping

Impact: Prevents gradient explosion, more stable training
Implementation:

```python
from ylff.utils.training_utils import clip_gradients

# In the training loop: after backward(), before optimizer.step()
loss.backward()
grad_norm = clip_gradients(model, max_norm=1.0, norm_type=2.0)
optimizer.step()
```
Status: ✅ Implemented
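The helper above presumably wraps PyTorch's built-in clipping; a minimal, self-contained sketch of that mechanism:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Rescale all gradients so their global L2 norm is at most max_norm.
# The returned value is the pre-clipping norm; it is worth logging,
# since a sudden spike is an early warning of instability.
grad_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, norm_type=2.0
)
optimizer.step()
```

With FSDP, note that the sharded wrapper's own `model.clip_grad_norm_` must be used instead, since gradients live on different ranks.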
## 5. Learning Rate Finder

Impact: Automatically find the optimal learning rate
Implementation:

```python
from ylff.utils.training_utils import find_learning_rate

# Sweep learning rates and pick the best
result = find_learning_rate(
    model=model,
    train_loader=train_loader,
    loss_fn=loss_fn,
    min_lr=1e-8,
    max_lr=1.0,
    num_steps=100,
)
optimal_lr = result["best_lr"]  # Use this for training
```
Status: ✅ Implemented
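The utility above likely follows the LR range test from the Smith paper cited in the references: sweep the learning rate exponentially from `min_lr` to `max_lr`, record the loss at each step, stop when the loss diverges, and pick an LR just before the divergence point. A self-contained sketch under those assumptions (function and heuristic are illustrative, not the ylff implementation):

```python
import math
import torch
import torch.nn as nn

def lr_range_test(model, data, loss_fn, min_lr=1e-6, max_lr=1.0, num_steps=50):
    """Exponentially sweep the LR; return (lrs, losses, suggested_lr)."""
    optimizer = torch.optim.SGD(model.parameters(), lr=min_lr)
    gamma = (max_lr / min_lr) ** (1.0 / max(num_steps - 1, 1))
    lrs, losses = [], []
    for step in range(num_steps):
        lr = min_lr * gamma ** step
        for group in optimizer.param_groups:
            group["lr"] = lr
        x, y = data[step % len(data)]
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        lrs.append(lr)
        losses.append(loss.item())
        # Stop once the loss blows up relative to the best seen so far
        if not math.isfinite(losses[-1]) or losses[-1] > 4 * min(losses):
            break
    # Common heuristic: take ~1/10 of the LR at the loss minimum
    best_lr = lrs[losses.index(min(losses))] / 10
    return lrs, losses, best_lr

model = nn.Linear(4, 1)
data = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(8)]
lrs, losses, best_lr = lr_range_test(model, data, nn.MSELoss())
```

In practice the sweep is run on a throwaway copy of the model, since the high-LR steps at the end corrupt the weights.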
## 6. Automatic Batch Size Finder

Impact: Maximize GPU utilization automatically
Implementation:

```python
from ylff.utils.training_utils import find_optimal_batch_size

# Find optimal batch size
result = find_optimal_batch_size(
    model=model,
    dataset=dataset,
    loss_fn=loss_fn,
    initial_batch_size=1,
    max_batch_size=64,
)
optimal_batch = result["optimal_batch_size"]  # Use this for training
```

Status: ✅ Implemented
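A common way to implement such a finder is a doubling search: try batch sizes 1, 2, 4, ... and keep the largest one that runs without an out-of-memory error. The sketch below factors the memory probe out as a callable so the search logic is independent of any model; all names are illustrative, not the ylff API:

```python
def find_max_batch_size(try_batch, initial=1, maximum=64):
    """Double the batch size until try_batch fails; return the largest success.

    try_batch: callable taking a batch size; it should run one forward/backward
    pass and raise RuntimeError when the batch does not fit (CUDA OOM surfaces
    as RuntimeError in PyTorch).
    """
    best = None
    size = initial
    while size <= maximum:
        try:
            try_batch(size)
            best = size
            size *= 2
        except RuntimeError:
            break
    return best

# Example with a fake probe that "fits" only up to batch size 24:
def probe(size):
    if size > 24:
        raise RuntimeError("CUDA out of memory (simulated)")

print(find_max_batch_size(probe))  # -> 16, the largest power-of-two probe <= 24
```

Production finders usually back off from the discovered maximum (e.g. use ~90% of it), since activation memory fluctuates between steps.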
## 7. TensorRT Optimization

Impact: 5-10x faster inference in production

Status: ⏳ Not yet implemented (requires the TensorRT SDK)
Planned Implementation:

```python
# Export to ONNX first
export_to_onnx(model, sample_input, "model.onnx")

# Then convert to TensorRT
# Requires the TensorRT SDK (pip wheel: tensorrt)
import tensorrt as trt

# TensorRT conversion (simplified)
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)
parser.parse_from_file("model.onnx")

# Build a serialized engine (build_engine is deprecated since TensorRT 8)
config = builder.create_builder_config()
engine_bytes = builder.build_serialized_network(network, config)
```
## 8. QAT (Quantization-Aware Training)

Impact: Better INT8 quantization with minimal accuracy loss

Status: ⏳ Not yet implemented
Planned Implementation:

```python
# During training, simulate quantization with fake-quant modules
import torch
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat, convert

# Prepare model for QAT (model must be in train mode;
# fuse conv/bn/relu modules first where applicable)
model.train()
model.qconfig = get_default_qat_qconfig("fbgemm")
model = prepare_qat(model)

# Train normally (quantization effects are simulated in the forward pass)
# ...

# After training, convert to a real INT8 model
model.eval()
quantized_model = convert(model)
```
## 9. Sequence Parallelism

Impact: Handle very long sequences by splitting them across GPUs

Status: ⏳ Not yet implemented (requires model architecture support)
## 10. Selective Activation Recompute

Impact: Advanced memory optimization beyond blanket gradient checkpointing

Status: ⏳ Not yet implemented
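The idea is to checkpoint only some layers instead of every block, trading a small amount of recomputation for targeted memory savings. A sketch using PyTorch's `torch.utils.checkpoint`; the every-Nth-block selection policy is illustrative (real implementations often select by measured activation size):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class SelectiveRecompute(nn.Module):
    """Run only selected sub-modules under activation checkpointing."""

    def __init__(self, blocks, recompute_every=2):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.recompute_every = recompute_every

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            if self.training and i % self.recompute_every == 0:
                # Activations inside this block are discarded after the
                # forward pass and recomputed during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

blocks = [nn.Sequential(nn.Linear(32, 32), nn.GELU()) for _ in range(4)]
model = SelectiveRecompute(blocks)
out = model(torch.randn(8, 32))
out.sum().backward()
```

Checkpointing every block (as plain gradient checkpointing does) maximizes savings but roughly doubles forward compute; selecting only the heaviest blocks recovers most of the memory at a fraction of that cost.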
## Expected Combined Performance
With all Phase 4 optimizations:
- Training speed: +2-5x additional speedup (on top of existing 5-15x)
- Memory usage: Additional 30-50% reduction
- Training stability: Significantly improved (BF16, gradient clipping)
- Model size: Can train 2-4x larger models (FSDP)
## Implementation Priority

### Phase 4.1: Quick Wins (1-2 days)
- ✅ Gradient clipping
- ✅ BF16 support
- ✅ Learning rate finder
- ✅ Automatic batch size finder
### Phase 4.2: High Impact (3-5 days)
- ✅ FSDP support
- ⏳ FlashAttention (requires model integration)
- ⏳ TensorRT export
### Phase 4.3: Advanced (1-2 weeks)
- ⏳ QAT implementation
- ⏳ Sequence parallelism
- ⏳ Selective activation recompute
## Integration into Training

### Updated Training Function Signature
```python
def fine_tune_da3(
    # ... existing parameters ...
    # New Phase 4 parameters
    use_flash_attention: bool = False,
    use_fsdp: bool = False,
    fsdp_sharding_strategy: str = "FULL_SHARD",
    use_bf16: bool = False,               # Better stability than FP16
    gradient_clip_norm: Optional[float] = 1.0,
    find_lr: bool = False,                # Auto-find LR
    find_batch_size: bool = False,        # Auto-find batch size
    # ...
):
```
### Example Usage

```python
# Fast training with all optimizations
fine_tune_da3(
    model=model,
    training_samples_info=samples,
    # Existing optimizations
    use_amp=True,                  # Or use_bf16=True for better stability
    use_ema=True,
    use_onecycle=True,
    gradient_accumulation_steps=4,
    compile_model=True,
    # New Phase 4 optimizations
    use_bf16=True,                 # Better stability than FP16
    gradient_clip_norm=1.0,
    find_lr=True,                  # Auto-discover optimal LR
    find_batch_size=True,          # Auto-discover optimal batch size
    use_fsdp=True,                 # If the model is too large for one GPU
    use_flash_attention=True,      # If available
)
```
## Installation Requirements

### FlashAttention

```bash
# Requires specific CUDA and PyTorch versions
pip install flash-attn --no-build-isolation
```

### FSDP

```bash
# Requires PyTorch 2.0+ with distributed support
# Already included in PyTorch; no extra install needed
```

### TensorRT

```bash
# Requires the NVIDIA TensorRT SDK
# Download from: https://developer.nvidia.com/tensorrt
pip install tensorrt
```
## References

- FlashAttention: https://arxiv.org/abs/2205.14135
- FSDP: https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
- BF16: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
- LR Finder: https://arxiv.org/abs/1506.01186
- TensorRT: https://developer.nvidia.com/tensorrt
## Status Summary

| Optimization | Status | Impact | Difficulty |
|---|---|---|---|
| FlashAttention | ⏳ Utility created | 2-4x speedup | High (requires model changes) |
| FSDP | ✅ Implemented | Train larger models | Medium |
| BF16 | ✅ Implemented | Better stability | Low |
| Gradient Clipping | ✅ Implemented | Stability | Low |
| LR Finder | ✅ Implemented | Auto-tune LR | Low |
| Batch Size Finder | ✅ Implemented | Auto-tune batch | Low |
| TensorRT | ⏳ Planned | 5-10x inference | Medium |
| QAT | ⏳ Planned | Better INT8 | Medium |
| Sequence Parallelism | ⏳ Planned | Long sequences | High |
| Activation Recompute | ⏳ Planned | Memory savings | Medium |
## Next Steps
- Integrate FlashAttention into DA3's attention layers (requires model code access)
- Add TensorRT export for production inference
- Implement QAT for better quantization
- Wire up new optimizations to API endpoints
- Add comprehensive tests for all new features