File size: 6,241 Bytes
dfefe0b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 | diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 2898189ebb94..d4c9815bd342 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4824,11 +4824,10 @@ def _load_pretrained_model(
# Warmup cuda to load the weights much faster on devices
if device_map is not None and hf_quantizer is None:
expanded_device_map = expand_device_map(device_map, expected_keys)
- caching_allocator_warmup(model_to_load, expanded_device_map, dtype)
+ caching_allocator_warmup(model_to_load, expanded_device_map)
error_msgs = []
mismatched_keys = []
- has_multiple_shards = len(checkpoint_files) > 1
# Iterate on all the shards to load the weights
for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights
@@ -4865,7 +4864,7 @@ def _load_pretrained_model(
prefix if loading_base_model_from_task_state_dict else "",
)
- if low_cpu_mem_usage and shard_file is not None:
+ if low_cpu_mem_usage:
# Skip it with fsdp on ranks other than 0
if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
@@ -4893,10 +4892,8 @@ def _load_pretrained_model(
else:
model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params)
+ # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
del state_dict
- # force memory release if loading multiple shards
- if has_multiple_shards:
- gc.collect()
# Adjust offloaded weights name and save if needed
if disk_offload_index is not None and len(disk_offload_index) > 0:
@@ -5789,11 +5786,24 @@ def expand_device_map(device_map, param_names):
return new_device_map
-def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict:
+def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict):
"""This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
the model, which is actually the loading speed botteneck.
Calling this function allows to cut the model loading time by a very large margin.
+
+ A few facts related to loading speed (taking into account the use of this function):
+ - When loading a model the first time, it is usually slower than the subsequent times, because the OS is very likely
+ to cache the different state dicts (if enough ressources/RAM are available)
+ - Trying to force the OS to cache the files in advance (by e.g. accessing a small portion of them) is really hard,
+ and not a good idea in general as this is low level OS optimizations that depend on ressource usage anyway
+ - As of 18/03/2025, loading a Llama 70B model with TP takes ~1 min without file cache, and ~13s with full file cache.
+ The baseline, i.e. only loading the tensor shards on device and adjusting dtype (i.e. copying them) is ~5s with full cache.
+ These numbers are reported for TP on 4 H100 GPUs.
+ - It is useless to pre-allocate more than the model size in this function (i.e. using an `allocation_factor` > 1) as
+ cudaMalloc is not a bottleneck at all anymore
+ - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
+ However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.
"""
# Remove disk and cpu devices, and cast to proper torch.device
accelerator_device_map = {
@@ -5808,31 +5818,26 @@ def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict,
else None
)
- parameter_count = defaultdict(lambda: 0)
- allocation_factor = 1
- if torch.distributed.is_initialized() or len(set(accelerator_device_map.values())) >= 2:
- allocation_factor = 2
-
+ total_byte_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
param = model.get_parameter_or_buffer(param_name)
- param_size = int(math.prod(param.shape) * allocation_factor)
+ # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+ param_byte_count = math.prod(param.shape) * dtype_byte_size(param.dtype)
if tp_plan_regex is not None:
generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
- param_size //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
-
- parameter_count[device] += param_size
+ param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1
- dtype = dtype if dtype is not None else torch.float32
+ total_byte_count[device] += param_byte_count
# This will kick off the caching allocator to avoid having to Malloc afterwards
- for device, param_count in parameter_count.items():
- max_memory_device = None
+ for device, byte_count in total_byte_count.items():
if device.type == "cuda":
- max_memory_device = torch.cuda.mem_get_info(device.index)[0]
- # allocate only if we have enough memory
- if max_memory_device is None or max_memory_device > param_count * dtype_byte_size(dtype):
- _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+ device_memory = torch.cuda.mem_get_info(device)[0]
+ # Allow up to 95% of max device memory
+ byte_count = min(byte_count, int(0.95 * device_memory))
+ # Allocate memory
+ _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)
def get_disk_only_shard_files(device_map, weight_map):
|