| |
| |
| |
| |
| @@ -4824,11 +4824,10 @@ def _load_pretrained_model( |
| # Warmup cuda to load the weights much faster on devices |
| if device_map is not None and hf_quantizer is None: |
| expanded_device_map = expand_device_map(device_map, expected_keys) |
| - caching_allocator_warmup(model_to_load, expanded_device_map, dtype) |
| + caching_allocator_warmup(model_to_load, expanded_device_map) |
| |
| error_msgs = [] |
| mismatched_keys = [] |
| - has_multiple_shards = len(checkpoint_files) > 1 |
| # Iterate on all the shards to load the weights |
| for shard_file in checkpoint_files: |
| # Skip the load for shards that only contain disk-offloaded weights |
| @@ -4865,7 +4864,7 @@ def _load_pretrained_model( |
| prefix if loading_base_model_from_task_state_dict else "", |
| ) |
| |
| - if low_cpu_mem_usage and shard_file is not None: |
| + if low_cpu_mem_usage: |
| # Skip it with fsdp on ranks other than 0 |
| if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): |
| disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( |
| @@ -4893,10 +4892,8 @@ def _load_pretrained_model( |
| else: |
| model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) |
| |
| + # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop |
| del state_dict |
| - # force memory release if loading multiple shards |
| - if has_multiple_shards: |
| - gc.collect() |
| |
| # Adjust offloaded weights name and save if needed |
| if disk_offload_index is not None and len(disk_offload_index) > 0: |
| @@ -5789,11 +5786,24 @@ def expand_device_map(device_map, param_names): |
| return new_device_map |
| |
| |
| -def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict: |
| +def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict): |
| """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each |
| device. It allows to have one large call to Malloc, instead of recursively calling it later when loading |
| the model, which is actually the loading speed botteneck. |
| Calling this function allows to cut the model loading time by a very large margin. |
| + |
| + A few facts related to loading speed (taking into account the use of this function): |
| + - When loading a model the first time, it is usually slower than the subsequent times, because the OS is very likely |
| + to cache the different state dicts (if enough ressources/RAM are available) |
| + - Trying to force the OS to cache the files in advance (by e.g. accessing a small portion of them) is really hard, |
| + and not a good idea in general as this is low level OS optimizations that depend on ressource usage anyway |
| + - As of 18/03/2025, loading a Llama 70B model with TP takes ~1 min without file cache, and ~13s with full file cache. |
| + The baseline, i.e. only loading the tensor shards on device and adjusting dtype (i.e. copying them) is ~5s with full cache. |
| + These numbers are reported for TP on 4 H100 GPUs. |
| + - It is useless to pre-allocate more than the model size in this function (i.e. using an `allocation_factor` > 1) as |
| + cudaMalloc is not a bottleneck at all anymore |
| + - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices. |
| + However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end. |
| """ |
| # Remove disk and cpu devices, and cast to proper torch.device |
| accelerator_device_map = { |
| @@ -5808,31 +5818,26 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, |
| else None |
| ) |
| |
| - parameter_count = defaultdict(lambda: 0) |
| - allocation_factor = 1 |
| - if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2: |
| - allocation_factor = 2 |
| - |
| + total_byte_count = defaultdict(lambda: 0) |
| for param_name, device in accelerator_device_map.items(): |
| param = model.get_parameter_or_buffer(param_name) |
| - param_size = int(math.prod(param.shape) * allocation_factor) |
| + # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules` |
| + param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype) |
| |
| if tp_plan_regex is not None: |
| generic_name = re.sub(r"\.\d+\.", ".*.", param_name) |
| - param_size //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1 |
| - |
| - parameter_count[device] += param_size |
| + param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1 |
| |
| - dtype = dtype if dtype is not None else torch.float32 |
| + total_byte_count[device] += param_byte_count |
| |
| # This will kick off the caching allocator to avoid having to Malloc afterwards |
| - for device, param_count in parameter_count.items(): |
| - max_memory_device = None |
| + for device, byte_count in total_byte_count.items(): |
| if device.type == "cuda": |
| - max_memory_device = torch.cuda.mem_get_info(device.index)[0] |
| - # allocate only if we have enough memory |
| - if max_memory_device is None or max_memory_device > param_count * dtype_byte_size(dtype): |
| - _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False) |
| + device_memory = torch.cuda.mem_get_info(device)[0] |
| + # Allow up to 95% of max device memory |
| + byte_count = min(byte_count, int(0.95 * device_memory)) |
| + # Allocate memory |
| + _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False) |
| |
| |
| def get_disk_only_shard_files(device_map, weight_map): |
|
|