harness / diffs /36380.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 3ebd0eacfa63..d7abc3bf7e22 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -21,11 +21,13 @@
import inspect
import itertools
import json
+import math
import os
import re
import shutil
import tempfile
import warnings
+from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -4816,8 +4818,13 @@ def _find_mismatched_keys(
folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
else:
folder = None
+
+ if device_map is not None:
+ expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
+ caching_allocator_warmup(model, expanded_device_map, dtype)
+
if device_map is not None and is_safetensors:
- param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
+ param_device_map = expanded_device_map
str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
archive_file = (
@@ -5795,6 +5802,30 @@ def expand_device_map(device_map, param_names, start_prefix):
return new_device_map
+def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: Dict, dtype: torch.dtype) -> Dict:
+ """This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
+ device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
+ the model, which is actually the loading speed botteneck.
+ Calling this function allows to cut the model loading time by a very large margin.
+ """
+ # Remove disk and cpu devices, and cast to proper torch.device
+ accelerator_device_map = {
+ param: torch.device(device) for param, device in expanded_device_map.items() if device not in ["cpu", "disk"]
+ }
+ parameter_count = defaultdict(lambda: 0)
+ for param_name, device in accelerator_device_map.items():
+ try:
+ param = model.get_parameter(param_name)
+ except AttributeError:
+ param = model.get_buffer(param_name)
+ parameter_count[device] += math.prod(param.shape)
+
+ dtype = dtype if dtype is not None else torch.float32
+ # This will kick off the caching allocator to avoid having to Malloc afterwards
+ for device, param_count in parameter_count.items():
+ _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+
+
def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
"""
Returns the list of shard files containing only weights offloaded to disk.