import torch

# Configuration for batch chunk sizes based on model size and dtype.
# Reference point: a 4B model with bf16 uses batch_chunk_size = 64.
#
# Mapping: (min_params_billions, max_params_billions) -> recommended batch
# chunk size for 16-bit dtypes (bf16/fp16). Each range is half-open:
# min <= params < max, so a model sitting exactly on a boundary (e.g. 3.0B)
# belongs to the bucket that *starts* at that boundary.
BATCH_CHUNK_MAPPING = {
    (0.0, 1.0): 256,     # e.g. 0.6B and similar
    (1.0, 3.0): 128,     # e.g. 1.7B, 2B
    (3.0, 6.0): 64,      # e.g. 3B, 4B (reference: 4B bf16 -> 64)
    (6.0, 12.0): 32,     # e.g. 6B, 7B, 8B, 10B
    (12.0, 25.0): 16,    # e.g. 14B, 20B
    (25.0, 1000.0): 8,   # e.g. 32B, 70B and larger
}

# Chunk size used when the parameter count falls outside every range above
# (negative counts, or 1000B parameters and more).
_DEFAULT_BATCH_CHUNK_SIZE = 32

# Divisor applied to the 16-bit baseline for wider dtypes: each element takes
# this many times the memory of a bf16/fp16 element.
_DTYPE_CHUNK_DIVISORS = {
    torch.float32: 2,
    torch.float64: 4,
}


def get_batch_chunk_size(model_params_count: int, model_dtype: torch.dtype) -> int:
    """Determine an appropriate batch chunk size from model size and dtype.

    The parameter count is converted to billions and looked up in
    ``BATCH_CHUNK_MAPPING`` (half-open ranges, ``min <= params < max``). The
    resulting 16-bit baseline is then divided for wider dtypes (2x for
    float32, 4x for float64), never dropping below 1.

    Args:
        model_params_count: Total number of parameters in the model.
        model_dtype: The data type used for computation (activations).

    Returns:
        Recommended batch chunk size (always >= 1). Parameter counts that
        match no range in ``BATCH_CHUNK_MAPPING`` use a baseline of 32.
    """
    params_billions = model_params_count / 1e9

    chunk_size = _DEFAULT_BATCH_CHUNK_SIZE
    for (min_b, max_b), size in BATCH_CHUNK_MAPPING.items():
        if min_b <= params_billions < max_b:
            chunk_size = size
            break

    # Dtypes not listed (bf16, fp16, ...) keep the 16-bit baseline unchanged.
    divisor = _DTYPE_CHUNK_DIVISORS.get(model_dtype, 1)
    return max(1, chunk_size // divisor)