| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Contains helpers to split tensors into shards.""" |
|
|
| from dataclasses import dataclass, field |
| from typing import Any, Callable, Dict, List, Optional, TypeVar, Union |
|
|
| from .. import logging |
|
|
|
|
# Generic tensor type: the splitting logic is framework-agnostic and only touches
# tensors through the two callbacks below (size and storage-id).
TensorT = TypeVar("TensorT")
# Callback returning the size (in bytes) a tensor will take once saved on disk.
TensorSizeFn_T = Callable[[TensorT], int]
# Callback returning a unique identifier for a tensor's underlying storage, or None
# when storage-sharing detection is not supported by the framework.
StorageIDFn_T = Callable[[TensorT], Optional[Any]]


# Default maximum size of a single shard, as a string parsed by `parse_size_to_int`.
MAX_SHARD_SIZE = "5GB"
# Decimal (SI) unit multipliers understood by `parse_size_to_int`.
SIZE_UNITS = {
    "TB": 10**12,
    "GB": 10**9,
    "MB": 10**6,
    "KB": 10**3,
}
|
|
|
|
| logger = logging.get_logger(__file__) |
|
|
|
|
@dataclass
class StateDictSplit:
    """Result of splitting a state dict into shards.

    Holds the mapping between tensor names and the file(s) they were assigned to,
    plus free-form metadata about the split.
    """

    # Computed in `__post_init__`, never passed by the caller.
    is_sharded: bool = field(init=False)
    metadata: Dict[str, Any]
    filename_to_tensors: Dict[str, List[str]]
    tensor_to_filename: Dict[str, str]

    def __post_init__(self):
        # The split counts as "sharded" as soon as more than one file is produced.
        nb_files = len(self.filename_to_tensors)
        self.is_sharded = nb_files > 1
|
|
|
|
def split_state_dict_into_shards_factory(
    state_dict: Dict[str, TensorT],
    *,
    get_storage_size: TensorSizeFn_T,
    filename_pattern: str,
    get_storage_id: StorageIDFn_T = lambda tensor: None,
    max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
) -> StateDictSplit:
    """
    Split a model state dictionary in shards so that each shard is smaller than a given size.

    The shards are determined by iterating through the `state_dict` in the order of its keys. There is no optimization
    made to make each shard as close as possible to the maximum size passed. For example, if the limit is 10GB and we
    have tensors of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], [6+2+2GB] and not
    [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's tensor is bigger than `max_shard_size`, it will end up in its own shard which will have a
    size greater than `max_shard_size`.

    </Tip>

    Args:
        state_dict (`Dict[str, Tensor]`):
            The state dictionary to save. String values (e.g. from bnb serialization) are skipped and do not appear
            in the returned index.
        get_storage_size (`Callable[[Tensor], int]`):
            A function that returns the size of a tensor when saved on disk in bytes.
        get_storage_id (`Callable[[Tensor], Optional[Any]]`, *optional*):
            A function that returns a unique identifier to a tensor storage. Multiple different tensors can share the
            same underlying storage. This identifier is guaranteed to be unique and constant for this tensor's storage
            during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id.
        filename_pattern (`str`, *optional*):
            The pattern to generate the files names in which the model will be saved. Pattern must be a string that
            can be formatted with `filename_pattern.format(suffix=...)` and must contain the keyword `suffix`
        max_shard_size (`int` or `str`, *optional*):
            The maximum size of each shard, in bytes. Defaults to 5GB.

    Returns:
        [`StateDictSplit`]: A `StateDictSplit` object containing the shards and the index to retrieve them.
    """
    # Maps a storage id to all the state_dict keys sharing that storage. Only the first
    # key of each group is assigned to a shard during the main loop; the other keys are
    # re-attached to that same shard afterwards (their size is counted only once).
    storage_id_to_tensors: Dict[Any, List[str]] = {}

    shard_list: List[Dict[str, TensorT]] = []
    current_shard: Dict[str, TensorT] = {}
    current_shard_size = 0
    total_size = 0

    if isinstance(max_shard_size, str):
        max_shard_size = parse_size_to_int(max_shard_size)

    for key, tensor in state_dict.items():
        # String entries are not tensors: they are not assigned to any shard and are
        # excluded from the returned index.
        if isinstance(tensor, str):
            logger.info("Skipping tensor %s as it is a string (bnb serialization)", key)
            continue

        # If this tensor shares its underlying storage with a tensor seen before, do
        # not store it twice: remember the key and re-attach it to that shard later.
        storage_id = get_storage_id(tensor)
        if storage_id is not None:
            if storage_id in storage_id_to_tensors:
                storage_id_to_tensors[storage_id].append(key)
                continue
            else:
                # First tensor seen with this storage: processed normally below.
                storage_id_to_tensors[storage_id] = [key]

        tensor_size = get_storage_size(tensor)

        # A tensor bigger than the limit gets a shard of its own (that shard will
        # exceed `max_shard_size`, as documented in the Tip above).
        if tensor_size > max_shard_size:
            total_size += tensor_size
            shard_list.append({key: tensor})
            continue

        # If adding this tensor would tip the current shard over the limit, close the
        # current shard and start a new one.
        if current_shard_size + tensor_size > max_shard_size:
            shard_list.append(current_shard)
            current_shard = {}
            current_shard_size = 0

        current_shard[key] = tensor
        current_shard_size += tensor_size
        total_size += tensor_size

    # Flush the last, possibly partial, shard.
    if len(current_shard) > 0:
        shard_list.append(current_shard)
    nb_shards = len(shard_list)

    # Re-attach tensors sharing storage to the shard holding the first key of their
    # group. (The storage id itself is not needed here, only the grouped keys.)
    for keys in storage_id_to_tensors.values():
        for shard in shard_list:
            if keys[0] in shard:
                for key in keys:
                    shard[key] = state_dict[key]
                break

    # Single-file case: no suffix in the filename. Build the index from the shard's
    # own content — not from `state_dict.keys()` — so that skipped string entries are
    # excluded, consistently with the multi-shard case below.
    if nb_shards == 1:
        filename = filename_pattern.format(suffix="")
        only_shard = shard_list[0]
        return StateDictSplit(
            metadata={"total_size": total_size},
            filename_to_tensors={filename: list(only_shard.keys())},
            tensor_to_filename={key: filename for key in only_shard},
        )

    # Multi-file case: build the key <-> filename index with 1-based, zero-padded
    # "-00001-of-00005"-style suffixes.
    tensor_name_to_filename = {}
    filename_to_tensors = {}
    for idx, shard in enumerate(shard_list):
        filename = filename_pattern.format(suffix=f"-{idx + 1:05d}-of-{nb_shards:05d}")
        for key in shard:
            tensor_name_to_filename[key] = filename
        filename_to_tensors[filename] = list(shard.keys())

    return StateDictSplit(
        metadata={"total_size": total_size},
        filename_to_tensors=filename_to_tensors,
        tensor_to_filename=tensor_name_to_filename,
    )
|
|
|
|
def parse_size_to_int(size_as_str: Union[int, str]) -> int:
    """
    Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).

    Supported units are "TB", "GB", "MB", "KB".

    Args:
        size_as_str (`str` or `int`): The size to convert. Will be directly returned if an `int`.

    Raises:
        ValueError: If the unit is not one of "TB", "GB", "MB", "KB", or if the numeric part cannot be parsed.

    Example:

    ```py
    >>> parse_size_to_int("5MB")
    5000000
    ```
    """
    # An `int` is assumed to already be a number of bytes (honors the documented
    # contract; previously an int argument crashed on `.strip()`).
    if isinstance(size_as_str, int):
        return size_as_str

    size_as_str = size_as_str.strip()

    # Parse `unit`: all supported units are exactly 2 characters long.
    unit = size_as_str[-2:].upper()
    if unit not in SIZE_UNITS:
        raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
    multiplier = SIZE_UNITS[unit]

    # Parse `value` (floats like "1.5GB" are allowed; the result is truncated to int).
    try:
        value = float(size_as_str[:-2].strip())
    except ValueError as e:
        raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e

    return int(value * multiplier)
|
|