| """Utility functions for training and inference.""" |
|
|
| import functools |
| from pathlib import Path |
| import pickle |
| import warnings |
| from io import BytesIO |
|
|
| import torch |
| import torch.utils._device |
| from lightning.fabric.strategies import DeepSpeedStrategy, FSDPStrategy |
| from torch.distributed.fsdp import FullStateDictConfig |
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
| from torch.distributed.fsdp import StateDictType |
|
|
|
|
def save_model_checkpoint(fabric, model, file_path):
    """Retrieve the model's ``state_dict`` and write it to ``file_path``.

    How the weights are collected depends on ``fabric.strategy``:

    * ``DeepSpeedStrategy``: ``fabric.save`` is called with ``{"model": model}`` at
      ``file_path``. After a barrier, rank 0 converts that checkpoint with
      DeepSpeed's ``convert_zero_checkpoint_to_fp32_state_dict`` into a file with
      the same name but a ``.pth`` suffix. No further barrier follows.
    * ``FSDPStrategy``: the full state dict is gathered with ``rank0_only=True``
      (offloaded to CPU when ``world_size > 1``) and rank 0 saves it with ``torch.save``.
    * any other strategy: rank 0 saves ``model.state_dict()`` with ``torch.save``.

    In the non-DeepSpeed cases, all ranks meet at a barrier after the save.

    This will be upstreamed to Fabric soon.
    """
    path = Path(file_path)

    if isinstance(fabric.strategy, DeepSpeedStrategy):
        # Imported lazily so deepspeed is only needed when this strategy is in use.
        from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

        fabric.save(path, {"model": model})
        fabric.barrier()
        if fabric.global_rank == 0:
            # Write a consolidated fp32 state dict next to the DeepSpeed checkpoint.
            convert_zero_checkpoint_to_fp32_state_dict(path, path.with_suffix(".pth"))
        return

    if not isinstance(fabric.strategy, FSDPStrategy):
        state_dict = model.state_dict()
    else:
        # Gather unsharded weights. Only rank 0 receives them, and they are moved
        # to CPU when more than one process participates.
        full_config = FullStateDictConfig(offload_to_cpu=fabric.world_size > 1, rank0_only=True)
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_config):
            state_dict = model._forward_module.state_dict()

    if fabric.global_rank == 0:
        torch.save(state_dict, path)
    fabric.barrier()
|
|
|
|
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
    """Torch function mode that skips ``torch.nn.init`` and supplies a default device/dtype.

    While the mode is active:

    * any function whose ``__module__`` is ``torch.nn.init`` returns its target
      tensor unchanged, so the tensor keeps whatever values it already had;
    * tensor factory functions (``torch.utils._device._device_constructors()``)
      get ``device`` / ``dtype`` filled in when the caller passed none;
    * when a quantization mode is selected, ``torch.nn.Linear`` is replaced by
      the matching quantized linear class and restored on exit. This only
      affects code that looks up ``torch.nn.Linear`` while the mode is active.
    """

    def __init__(self, device=None, dtype=None, quantization_mode=None):
        """
        Create tensors with given device and dtype and don't run initialization
        (but instead use "empty tensors", i.e. uninitialized memory).

        device: `torch.device` (or anything `torch.device(...)` accepts, e.g. "cuda") to work with
        dtype: `torch.dtype` to work with
        quantization_mode: optional string, quantization mode to work with, default `None`.
             Available modes: `llm.int8` bitsandbytes LLM.int8 quantization (only on GPU)
                              `gptq.int4`, `gptq.int8`: GPTQ pre-quantized models

        Raises:
            ValueError: `llm.int8` was requested without a CUDA `device`.
            RuntimeError: `quantization_mode` is not one of the modes listed above.

        Example::
            with EmptyInitOnDevice("cuda", dtype=torch.bfloat16):
               model = LLaMA.from_name('7B')
            model.load_state_dict(torch.load('llama-lit/7B/lit-llama.pth'))"""
        self.quantization_mode = quantization_mode
        self.quantized_linear_cls = None
        if self.quantization_mode == 'llm.int8':
            # Normalise through torch.device so string devices such as "cuda" (as in the
            # example above) work. A missing device can never be a GPU.
            if device is None or torch.device(device).type != "cuda":
                raise ValueError("Quantization is only supported on the GPU.")
            from .quantization import Linear8bitLt
            self.quantized_linear_cls = Linear8bitLt
        elif self.quantization_mode == 'gptq.int4':
            from .quantization import ColBlockQuantizedLinear
            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=4, tile_cols=-1)
        elif self.quantization_mode == 'gptq.int8':
            from .quantization import ColBlockQuantizedLinear
            self.quantized_linear_cls = functools.partial(ColBlockQuantizedLinear, bits=8, tile_cols=-1)
        elif self.quantization_mode is not None:
            raise RuntimeError(f"unknown quantization mode {self.quantization_mode}")
        self.device = device
        self.dtype = dtype

    def __enter__(self):
        if self.quantized_linear_cls is not None:
            # Swap the global Linear class so models built inside the context use the quantized layer.
            self.torch_linear_cls = torch.nn.Linear
            torch.nn.Linear = self.quantized_linear_cls
        return super().__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.quantized_linear_cls is not None:
            torch.nn.Linear = self.torch_linear_cls
        return super().__exit__(exc_type, exc_val, exc_tb)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Copy so the caller's kwargs dict is never mutated.
        kwargs = dict(kwargs) if kwargs else {}
        if getattr(func, "__module__", None) == "torch.nn.init":
            # Skip initialisation: hand back the tensor that would have been initialised.
            if "tensor" in kwargs:
                return kwargs["tensor"]
            return args[0]
        if func in torch.utils._device._device_constructors():
            if self.device is not None and kwargs.get("device") is None:
                kwargs["device"] = self.device
            if self.dtype is not None and kwargs.get("dtype") is None:
                kwargs["dtype"] = self.dtype
        return func(*args, **kwargs)
|
|
|
|
| |
|
|
|
|
class NotYetLoadedTensor:
    """Placeholder for a tensor in a zip-format torch checkpoint whose data is not yet read.

    Metadata (``shape``, ``dtype``, ``size()``, …) is served from ``metatensor``,
    which is built from the unloaded storage passed to :meth:`rebuild`. In this
    module, ``LazyLoadingUnpickler`` creates that storage on the ``meta`` device.
    The real data is read from the archive on demand:

    * each time the object takes part in a torch function call
      (via ``__torch_function__``);
    * each time one of the attributes in ``_LOADING_ATTRS`` is accessed.

    Loaded tensors are not cached.
    """

    # Attributes answered from the metadata-only tensor without touching the archive.
    _METADATA_ATTRS = frozenset({
        "dtype",
        "grad",
        "grad_fn",
        "layout",
        "names",
        "ndim",
        "output_nr",
        "requires_grad",
        "retains_grad",
        "shape",
        "size",
        "volatile",
    })
    # Attributes that need the real data; the tensor is loaded to serve them.
    _LOADING_ATTRS = frozenset({"contiguous"})

    def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args):
        # Tensor with the target shape/stride/dtype, built from the unloaded storage.
        self.metatensor = metatensor
        # Object exposing `.zipfile`, whose `get_storage_from_record` reads the data.
        self.archiveinfo = archiveinfo
        # Storage persistent id, unpacked as (name, storage_cls, key, location, numel).
        self.storageinfo = storageinfo
        # Arguments of `_rebuild_tensor_v2` that follow the storage.
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild(
        cls,
        storage,
        storage_offset,
        size,
        stride,
        requires_grad,
        backward_hooks,
        metadata=None,
        archiveinfo=None,
    ):
        """Stand-in for ``torch._utils._rebuild_tensor_v2`` that defers reading the data.

        ``storage`` must carry an ``archiveinfo`` attribute holding its persistent id.
        """
        rebuild_args = (
            storage_offset,
            size,
            stride,
            requires_grad,
            backward_hooks,
            metadata,
        )
        metatensor = torch._utils._rebuild_tensor_v2(storage, *rebuild_args)
        return cls(metatensor, archiveinfo, storage.archiveinfo, rebuild_args)

    def _load_tensor(self):
        """Read this tensor's storage from the archive and rebuild the real tensor."""
        _, _, key, _, numel = self.storageinfo
        dtype = self.metatensor.dtype
        nbytes = numel * torch._utils._element_size(dtype)

        untyped = (
            self.archiveinfo.zipfile.get_storage_from_record(
                f"data/{key}",
                nbytes,
                torch.UntypedStorage,
            )
            ._typed_storage()
            ._untyped_storage
        )
        with warnings.catch_warnings():
            # Presumably silences TypedStorage's deprecation warning.
            warnings.simplefilter("ignore")
            storage = torch.storage.TypedStorage(
                wrap_storage=untyped, dtype=dtype, _internal=True
            )
        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)

    @staticmethod
    def _materialize(obj):
        """Replace NotYetLoadedTensor instances with loaded tensors.

        Recurses into plain lists, tuples and dicts; other objects are returned as-is.
        """
        if isinstance(obj, NotYetLoadedTensor):
            return obj._load_tensor()
        if type(obj) is list or type(obj) is tuple:
            return type(obj)(NotYetLoadedTensor._materialize(item) for item in obj)
        if type(obj) is dict:
            return {k: NotYetLoadedTensor._materialize(v) for k, v in obj.items()}
        return obj

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Load every lazy tensor among the arguments, then call ``func`` on real tensors.

        Both positional and keyword arguments are handled, including tensors nested
        in lists/tuples/dicts (e.g. ``torch.cat([a, b])``).
        """
        loaded_args = cls._materialize(list(args))
        loaded_kwargs = cls._materialize(dict(kwargs)) if kwargs else {}
        return func(*loaded_args, **loaded_kwargs)

    def __getattr__(self, name):
        """Serve metadata attributes lazily; load data only for attributes that need it."""
        # Called only for names not found normally. The class-level lookups below
        # therefore never recurse.
        if name in NotYetLoadedTensor._METADATA_ATTRS:
            return getattr(self.metatensor, name)
        if name in NotYetLoadedTensor._LOADING_ATTRS:
            return getattr(self._load_tensor(), name)
        raise AttributeError(f"{type(self)} does not have {name}")

    def __repr__(self):
        return f"NotYetLoadedTensor({repr(self.metatensor)})"
|
|
|
|
class LazyLoadingUnpickler(pickle.Unpickler):
    """Unpickler for a torch zip checkpoint's ``data.pkl`` that defers reading tensor data.

    Tensors are rebuilt as ``NotYetLoadedTensor`` objects backed by ``meta``
    storages. They read their bytes from ``zipfile`` only when used.

    NOTE(review): this is still ``pickle`` — arbitrary classes named in the pickle
    are resolved via ``find_class``, so only load checkpoints from trusted sources.
    """

    def __init__(self, file, zipfile):
        super().__init__(file)
        # Archive reader handed to each NotYetLoadedTensor (through `archiveinfo=self`).
        self.zipfile = zipfile

    def find_class(self, module, name):
        """Redirect ``torch._utils._rebuild_tensor_v2`` to the lazy rebuild; resolve others normally."""
        if module == "torch._utils" and name == "_rebuild_tensor_v2":
            return functools.partial(NotYetLoadedTensor.rebuild, archiveinfo=self)
        return super().find_class(module, name)

    def persistent_load(self, pid):
        """Return an empty ``meta`` storage tagged with its persistent id instead of real data.

        ``pid`` is unpacked as (name, storage_cls, key, location, numel). The whole
        tuple is kept on the storage as ``archiveinfo`` so the data can be fetched later.
        """
        _name, storage_cls, _key, _location, _numel = pid
        with warnings.catch_warnings():
            # Instantiating legacy typed storage classes may emit warnings; they are irrelevant here.
            warnings.simplefilter("ignore")
            s = torch.storage.TypedStorage(dtype=storage_cls().dtype, device="meta")
        s.archiveinfo = pid
        return s
|
|
|
|
def lazy_load(fn):
    """Unpickle a zip-format ``torch.save`` file without reading tensor data up front.

    Only the archive's ``data.pkl`` record is read eagerly. Tensors in the result
    are ``NotYetLoadedTensor`` placeholders that keep a reference to the archive
    reader and fetch their bytes when used.

    NOTE(review): this runs pickle on the file — only load trusted checkpoints.

    Args:
        fn: path (``str`` or path-like) of the checkpoint file.

    Returns:
        The unpickled object (typically a state dict).
    """
    reader = torch._C.PyTorchFileReader(str(fn))
    pickled = reader.get_record("data.pkl")
    with BytesIO(pickled) as stream:
        return LazyLoadingUnpickler(stream, reader).load()
|
|