| |
| |
| |
| |
| @@ -21,11 +21,13 @@ |
| import inspect |
| import itertools |
| import json |
| +import math |
| import os |
| import re |
| import shutil |
| import tempfile |
| import warnings |
| +from collections import defaultdict |
| from contextlib import contextmanager |
| from dataclasses import dataclass |
| from enum import Enum |
| @@ -4816,8 +4818,13 @@ def _find_mismatched_keys( |
| folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) |
| else: |
| folder = None |
| + |
| + if device_map is not None: |
| + expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) |
| + caching_allocator_warmup(model, expanded_device_map, dtype) |
| + |
| if device_map is not None and is_safetensors: |
| - param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) |
| + param_device_map = expanded_device_map |
| str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" |
| if sharded_metadata is None: |
| archive_file = ( |
| @@ -5795,6 +5802,30 @@ def expand_device_map(device_map, param_names, start_prefix): |
| return new_device_map |
| |
| |
| +def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict: |
| + """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each |
| + device. It allows to have one large call to Malloc, instead of recursively calling it later when loading |
| + the model, which is actually the loading speed botteneck. |
| + Calling this function allows to cut the model loading time by a very large margin. |
| + """ |
| + # Remove disk and cpu devices, and cast to proper torch.device |
| + accelerator_device_map = { |
| + param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"] |
| + } |
| + parameter_count = defaultdict(lambda: 0) |
| + for param_name, device in accelerator_device_map.items(): |
| + try: |
| + param = model.get_parameter(param_name) |
| + except AttributeError: |
| + param = model.get_buffer(param_name) |
| + parameter_count[device] += math.prod(param.shape) |
| + |
| + dtype = dtype if dtype is not None else torch.float32 |
| + # This will kick off the caching allocator to avoid having to Malloc afterwards |
| + for device, param_count in parameter_count.items(): |
| + _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False) |
| + |
| + |
| def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): |
| """ |
| Returns the list of shard files containing only weights offloaded to disk. |
|
|