| | |
| |
|
| | """Utility functions for training and inference.""" |
| | import inspect |
| | import math |
| | import os |
| | import pickle |
| | import shutil |
| | import sys |
| | from dataclasses import asdict, is_dataclass |
| | from io import BytesIO |
| | from pathlib import Path |
| | from typing import ( |
| | TYPE_CHECKING, |
| | Any, |
| | Dict, |
| | Iterable, |
| | List, |
| | Literal, |
| | Mapping, |
| | Optional, |
| | TypeVar, |
| | Union, |
| | ) |
| |
|
| | import lightning as L |
| | import torch |
| | import torch.nn as nn |
| | import torch.utils._device |
| | import yaml |
| | from lightning.fabric.loggers import CSVLogger, TensorBoardLogger |
| | from lightning.fabric.strategies import FSDPStrategy |
| | from lightning.fabric.utilities.load import _lazy_load as lazy_load |
| | from lightning.pytorch.loggers import WandbLogger |
| | from lightning.pytorch.cli import instantiate_class |
| | from torch.serialization import normalize_storage_type |
| | from typing_extensions import Self |
| |
|
| | if TYPE_CHECKING: |
| | from litgpt import GPT, Config |
| |
|
| |
|
| | def init_out_dir(out_dir: Path) -> Path: |
| | if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ: |
| | return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir |
| | return out_dir |
| |
|
| |
|
| | def find_resume_path( |
| | resume: Union[bool, Literal["auto"], Path], out_dir: Path |
| | ) -> Optional[Path]: |
| | if not resume or isinstance(resume, Path): |
| | return resume |
| |
|
| | resume_path = max( |
| | out_dir.rglob("step-*/*.pth"), |
| | key=(lambda p: int(p.parent.name.split("-")[1])), |
| | default=None, |
| | ) |
| | if resume == "auto": |
| | return resume_path |
| | if resume is True and resume_path is None: |
| | raise FileNotFoundError( |
| | f"You passed `--resume=True`, but no checkpont file was found in `--out_dir={out_dir}`." |
| | ) |
| | return resume_path |
| |
|
| |
|
| | def find_multiple(n: int, k: int) -> int: |
| | assert k > 0 |
| | if n % k == 0: |
| | return n |
| | return n + k - (n % k) |
| |
|
| |
|
| | def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int: |
| | total = 0 |
| | for p in module.parameters(): |
| | if requires_grad is None or p.requires_grad == requires_grad: |
| | if hasattr(p, "quant_state"): |
| | |
| | total += math.prod(p.quant_state.shape) |
| | else: |
| | total += p.numel() |
| | return total |
| |
|
| |
|
| | def reset_parameters(module: nn.Module) -> None: |
| | """Calls `reset_parameters` on the module and all its submodules.""" |
| | for mod in module.modules(): |
| | if callable(getattr(mod, "reset_parameters", None)): |
| | mod.reset_parameters() |
| |
|
| |
|
| | def check_valid_checkpoint_dir( |
| | checkpoint_dir: Path, |
| | model_filename: str = "lit_model.pth", |
| | verbose: bool = True, |
| | raise_error: bool = False, |
| | ) -> None: |
| | files = { |
| | model_filename: (checkpoint_dir / model_filename).is_file(), |
| | "model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(), |
| | "tokenizer.json OR tokenizer.model": ( |
| | checkpoint_dir / "tokenizer.json" |
| | ).is_file() |
| | or (checkpoint_dir / "tokenizer.model").is_file(), |
| | "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(), |
| | } |
| | if checkpoint_dir.is_dir(): |
| | if all(files.values()): |
| | |
| | return |
| | problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" |
| | else: |
| | problem = " is not a checkpoint directory" |
| |
|
| | |
| | available = list(Path("checkpoints").glob("*/*")) |
| | if available: |
| | options = "\n".join([""] + [repr(str(p.resolve())) for p in available]) |
| | extra = f"\nYou have downloaded locally:{options}\n" |
| | else: |
| | extra = "" |
| |
|
| | if verbose: |
| | error_message = ( |
| | f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." |
| | "\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n" |
| | f"{extra}\nSee all download options by running:\n litgpt download" |
| | ) |
| | print(error_message, file=sys.stderr) |
| |
|
| | if raise_error: |
| | raise FileNotFoundError( |
| | f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." |
| | ) |
| | else: |
| | raise SystemExit(1) |
| |
|
| |
|
| | class SavingProxyForStorage: |
| | def __init__(self, obj, saver, protocol_version=5): |
| | self.protocol_version = protocol_version |
| | self.saver = saver |
| | if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)): |
| | raise TypeError(f"expected storage, not {type(obj)}") |
| |
|
| | |
| | if isinstance(obj, torch.storage.TypedStorage): |
| | |
| | storage = obj._untyped_storage |
| | storage_type_str = obj._pickle_storage_type() |
| | storage_type = getattr(torch, storage_type_str) |
| | storage_numel = obj._size() |
| | else: |
| | storage = obj |
| | storage_type = normalize_storage_type(type(obj)) |
| | storage_numel = storage.nbytes() |
| |
|
| | storage_key = saver._write_storage_and_return_key(storage) |
| | location = torch.serialization.location_tag(storage) |
| |
|
| | self.storage_info = ( |
| | "storage", |
| | storage_type, |
| | storage_key, |
| | location, |
| | storage_numel, |
| | ) |
| |
|
| | def __reduce_ex__(self, protocol_version): |
| | assert False, "this should be handled with out of band" |
| |
|
| |
|
| | class SavingProxyForTensor: |
| | def __init__(self, tensor, saver, protocol_version=5): |
| | self.protocol_version = protocol_version |
| | self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version) |
| | if reduce_args[0] == torch._utils._rebuild_tensor_v2: |
| | |
| | (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args |
| | assert isinstance( |
| | storage, torch.storage.TypedStorage |
| | ), "Please check for updates" |
| | storage_proxy = SavingProxyForStorage( |
| | storage, saver, protocol_version=protocol_version |
| | ) |
| | self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args) |
| | else: |
| | (storage, *other_reduce_args) = reduce_args |
| | assert isinstance( |
| | storage, torch.storage.TypedStorage |
| | ), "Please check for updates" |
| | storage_proxy = SavingProxyForStorage( |
| | storage, saver, protocol_version=protocol_version |
| | ) |
| | self.reduce_args = (storage_proxy, *other_reduce_args) |
| |
|
| | def __reduce_ex__(self, protocol_version): |
| | if protocol_version != self.protocol_version: |
| | raise RuntimeError( |
| | f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}" |
| | ) |
| | return self.reduce_ret_fn, self.reduce_args |
| |
|
| |
|
| | class IncrementalPyTorchPickler(pickle.Pickler): |
| | def __init__(self, saver, *args, **kwargs): |
| | super().__init__(*args, **kwargs) |
| | self.storage_dtypes = {} |
| | self.saver = saver |
| | self.id_map = {} |
| |
|
| | |
| | def persistent_id(self, obj): |
| | |
| | |
| | |
| | |
| | |
| | if isinstance(obj, SavingProxyForStorage): |
| | return obj.storage_info |
| |
|
| | if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj): |
| | if isinstance(obj, torch.storage.TypedStorage): |
| | |
| | |
| | storage = obj._untyped_storage |
| | storage_dtype = obj.dtype |
| | storage_type_str = obj._pickle_storage_type() |
| | storage_type = getattr(torch, storage_type_str) |
| | storage_numel = obj._size() |
| |
|
| | else: |
| | storage = obj |
| | storage_dtype = torch.uint8 |
| | storage_type = normalize_storage_type(type(obj)) |
| | storage_numel = storage.nbytes() |
| |
|
| | |
| | |
| | |
| | if storage.data_ptr() != 0: |
| | if storage.data_ptr() in self.storage_dtypes: |
| | if storage_dtype != self.storage_dtypes[storage.data_ptr()]: |
| | raise RuntimeError( |
| | "Cannot save multiple tensors or storages that view the same data as different types" |
| | ) |
| | else: |
| | self.storage_dtypes[storage.data_ptr()] = storage_dtype |
| |
|
| | storage_key = self.id_map.get(storage._cdata) |
| | if storage_key is None: |
| | storage_key = self.saver._write_storage_and_return_key(storage) |
| | self.id_map[storage._cdata] = storage_key |
| | location = torch.serialization.location_tag(storage) |
| |
|
| | return ("storage", storage_type, storage_key, location, storage_numel) |
| |
|
| | return None |
| |
|
| |
|
| | class incremental_save: |
| | def __init__(self, name): |
| | self.name = name |
| | self.zipfile = torch._C.PyTorchFileWriter(str(name)) |
| | self.has_saved = False |
| | self.next_key = 0 |
| |
|
| | def __enter__(self): |
| | return self |
| |
|
| | def store_early(self, tensor): |
| | if isinstance(tensor, torch.Tensor): |
| | return SavingProxyForTensor(tensor, self) |
| | raise TypeError(f"can only store tensors early, not {type(tensor)}") |
| |
|
| | def save(self, obj): |
| | if self.has_saved: |
| | raise RuntimeError("have already saved") |
| | |
| | data_buf = BytesIO() |
| | pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) |
| | pickler.dump(obj) |
| | data_value = data_buf.getvalue() |
| | self.zipfile.write_record("data.pkl", data_value, len(data_value)) |
| | self.has_saved = True |
| |
|
| | def _write_storage_and_return_key(self, storage): |
| | if self.has_saved: |
| | raise RuntimeError("have already saved") |
| | key = self.next_key |
| | self.next_key += 1 |
| | name = f"data/{key}" |
| | if storage.device.type != "cpu": |
| | storage = storage.cpu() |
| | num_bytes = storage.nbytes() |
| | self.zipfile.write_record(name, storage.data_ptr(), num_bytes) |
| | return key |
| |
|
| | def __exit__(self, type, value, traceback): |
| | self.zipfile.write_end_of_file() |
| |
|
| |
|
| | T = TypeVar("T") |
| |
|
| |
|
| | def chunked_cross_entropy( |
| | logits: Union[torch.Tensor, List[torch.Tensor]], |
| | targets: torch.Tensor, |
| | chunk_size: int = 128, |
| | ignore_index: int = -100, |
| | ) -> torch.Tensor: |
| | |
| | |
| | |
| | |
| |
|
| | |
| | if isinstance(logits, list): |
| | |
| | if chunk_size == 0: |
| | logits = torch.cat(logits, dim=1) |
| | logits = logits.reshape(-1, logits.size(-1)) |
| | targets = targets.reshape(-1) |
| | return torch.nn.functional.cross_entropy( |
| | logits, targets, ignore_index=ignore_index |
| | ) |
| |
|
| | |
| | logit_chunks = [ |
| | logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits |
| | ] |
| | target_chunks = [ |
| | target_chunk.reshape(-1) |
| | for target_chunk in targets.split(logits[0].size(1), dim=1) |
| | ] |
| | loss_chunks = [ |
| | torch.nn.functional.cross_entropy( |
| | logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none" |
| | ) |
| | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) |
| | ] |
| | non_masked_elems = (targets != ignore_index).sum() |
| | |
| | return torch.cat(loss_chunks).sum() / non_masked_elems.maximum( |
| | torch.ones_like(non_masked_elems) |
| | ) |
| |
|
| | |
| | logits = logits.reshape(-1, logits.size(-1)) |
| | targets = targets.reshape(-1) |
| | if chunk_size == 0: |
| | return torch.nn.functional.cross_entropy( |
| | logits, targets, ignore_index=ignore_index |
| | ) |
| |
|
| | |
| | logit_chunks = logits.split(chunk_size) |
| | target_chunks = targets.split(chunk_size) |
| | loss_chunks = [ |
| | torch.nn.functional.cross_entropy( |
| | logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none" |
| | ) |
| | for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) |
| | ] |
| | non_masked_elems = (targets != ignore_index).sum() |
| | |
| | |
| | |
| | |
| | return torch.cat(loss_chunks).sum() / non_masked_elems.maximum( |
| | torch.ones_like(non_masked_elems) |
| | ) |
| |
|
| |
|
| | def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict: |
| | for checkpoint_name, attribute_name in mapping.items(): |
| | full_checkpoint_name = prefix + checkpoint_name |
| | if full_checkpoint_name in state_dict: |
| | full_attribute_name = prefix + attribute_name |
| | state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name) |
| | return state_dict |
| |
|
| |
|
| | def get_default_supported_precision(training: bool) -> str: |
| | """Return default precision that is supported by the hardware: either `bf16` or `16`. |
| | |
| | Args: |
| | training: `-mixed` or `-true` version of the precision to use |
| | |
| | Returns: |
| | default precision that is suitable for the task and is supported by the hardware |
| | """ |
| | from lightning.fabric.accelerators import MPSAccelerator |
| |
|
| | if MPSAccelerator.is_available() or ( |
| | torch.cuda.is_available() and not torch.cuda.is_bf16_supported() |
| | ): |
| | return "16-mixed" if training else "16-true" |
| | return "bf16-mixed" if training else "bf16-true" |
| |
|
| |
|
| | def load_checkpoint( |
| | fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True |
| | ) -> None: |
| | if isinstance(fabric.strategy, FSDPStrategy): |
| | fabric.load_raw(checkpoint_path, model, strict=strict) |
| | else: |
| | state_dict = lazy_load(checkpoint_path) |
| | state_dict = state_dict.get("model", state_dict) |
| | model.load_state_dict(state_dict, strict=strict) |
| |
|
| |
|
| | def flops_per_param( |
| | max_seq_length: int, n_layer: int, n_embd: int, n_params: int |
| | ) -> int: |
| | flops_per_token = ( |
| | 2 * n_params |
| | ) |
| | |
| | |
| | flops_per_seq = flops_per_token * max_seq_length |
| | attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2)) |
| | return flops_per_seq + attn_flops_per_seq |
| |
|
| |
|
| | def estimate_flops(model: "GPT", training: bool) -> int: |
| | """Measures estimated FLOPs for MFU. |
| | |
| | Refs: |
| | * https://ar5iv.labs.arxiv.org/html/2205.05198#A1 |
| | * https://ar5iv.labs.arxiv.org/html/2204.02311#A2 |
| | """ |
| | |
| | |
| | |
| | |
| | n_trainable_params = num_parameters(model, requires_grad=True) |
| | trainable_flops = flops_per_param( |
| | model.max_seq_length, |
| | model.config.n_layer, |
| | model.config.n_embd, |
| | n_trainable_params, |
| | ) |
| | |
| | ops_per_step = 3 if training else 1 |
| | n_frozen_params = num_parameters(model, requires_grad=False) |
| | frozen_flops = flops_per_param( |
| | model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params |
| | ) |
| | |
| | frozen_ops_per_step = 2 if training else 1 |
| | return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops |
| |
|
| |
|
| | class CycleIterator: |
| | """An iterator that cycles through an iterable indefinitely. |
| | |
| | Example: |
| | >>> iterator = CycleIterator([1, 2, 3]) |
| | >>> [next(iterator) for _ in range(5)] |
| | [1, 2, 3, 1, 2] |
| | |
| | Note: |
| | Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable. |
| | """ |
| |
|
| | def __init__(self, iterable: Iterable) -> None: |
| | self.iterable = iterable |
| | self.epoch = 0 |
| | self._iterator = None |
| |
|
| | def __next__(self) -> Any: |
| | if self._iterator is None: |
| | self._iterator = iter(self.iterable) |
| | try: |
| | return next(self._iterator) |
| | except StopIteration: |
| | self._iterator = iter(self.iterable) |
| | self.epoch += 1 |
| | return next(self._iterator) |
| |
|
| | def __iter__(self) -> Self: |
| | return self |
| |
|
| |
|
| | def copy_config_files(source_dir: Path, out_dir: Path) -> None: |
| | """Copies the specified configuration and tokenizer files into the output directory.""" |
| |
|
| | config_files = ["config.json", "generation_config.json", "model_config.yaml"] |
| | tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"] |
| |
|
| | for file_name in config_files + tokenizer_files: |
| | src_path = source_dir / file_name |
| | if src_path.exists(): |
| | shutil.copy(src_path, out_dir) |
| |
|
| |
|
| | def CLI(*args: Any, **kwargs: Any) -> Any: |
| | from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options |
| |
|
| | set_docstring_parse_options(attribute_docstrings=True) |
| | set_config_read_mode(urls_enabled=True) |
| |
|
| | return CLI(*args, **kwargs) |
| |
|
| |
|
| | def capture_hparams() -> Dict[str, Any]: |
| | """Captures the local variables ('hyperparameters') from where this function gets called.""" |
| | caller_frame = inspect.currentframe().f_back |
| | locals_of_caller = caller_frame.f_locals |
| | hparams = {} |
| | for name, value in locals_of_caller.items(): |
| | if value is None or isinstance(value, (int, float, str, bool, Path)): |
| | hparams[name] = value |
| | elif is_dataclass(value): |
| | hparams[name] = asdict(value) |
| | else: |
| | hparams[name] = str(value) |
| | return hparams |
| |
|
| |
|
| | def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None: |
| | """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint.""" |
| | from jsonargparse import capture_parser |
| |
|
| | |
| | |
| | |
| | known_commands = [ |
| | ("finetune_full",), |
| | ("finetune_lora",), |
| | ("finetune_adapter",), |
| | ("finetune_adapter_v2",), |
| | ("finetune",), |
| | ("pretrain",), |
| | ] |
| | for known_command in known_commands: |
| | unwanted = slice(1, 1 + len(known_command)) |
| | if tuple(sys.argv[unwanted]) == known_command: |
| | sys.argv[unwanted] = [] |
| |
|
| | parser = capture_parser(lambda: CLI(function)) |
| | config = parser.parse_args() |
| | parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True) |
| |
|
| |
|
| | def save_config(config: "Config", checkpoint_dir: Path) -> None: |
| | config_dict = asdict(config) |
| | with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp: |
| | yaml.dump(config_dict, fp) |
| |
|
| |
|
| | def parse_devices(devices: Union[str, int]) -> int: |
| | if devices in (-1, "auto"): |
| | return torch.cuda.device_count() or 1 |
| | if isinstance(devices, int) and devices > 0: |
| | return devices |
| | raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}") |
| |
|
| |
|
| | def choose_logger( |
| | logger_name: Literal["csv", "tensorboard", "wandb"], |
| | out_dir: Path, |
| | name: str, |
| | log_interval: int = 1, |
| | resume: Optional[bool] = None, |
| | **kwargs: Any, |
| | ): |
| | if logger_name == "csv": |
| | return CSVLogger( |
| | root_dir=(out_dir / "logs"), |
| | name="csv", |
| | flush_logs_every_n_steps=log_interval, |
| | **kwargs, |
| | ) |
| | if logger_name == "tensorboard": |
| | return TensorBoardLogger( |
| | root_dir=(out_dir / "logs"), name="tensorboard", **kwargs |
| | ) |
| | if logger_name == "wandb": |
| | return WandbLogger(project=name, resume=resume, **kwargs) |
| | raise ValueError( |
| | f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'." |
| | ) |
| |
|
| |
|
| | def get_argument_names(cls): |
| | sig = inspect.signature(cls.__init__) |
| | return { |
| | name |
| | for name, param in sig.parameters.items() |
| | if param.kind |
| | in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY] |
| | } |
| |
|
| |
|
| | def instantiate_bnb_optimizer(optimizer, model_parameters): |
| | if (isinstance(optimizer, str) and "AdamW" not in optimizer) or ( |
| | isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "") |
| | ): |
| | raise ValueError( |
| | "The chosen quantization format only supports the AdamW optimizer." |
| | ) |
| |
|
| | import bitsandbytes as bnb |
| |
|
| | if isinstance(optimizer, str): |
| | optimizer = bnb.optim.PagedAdamW(model_parameters) |
| | else: |
| | optim_args = get_argument_names(bnb.optim.PagedAdamW) |
| | allowed_kwargs = { |
| | key: optimizer["init_args"][key] |
| | for key in optim_args & optimizer["init_args"].keys() |
| | } |
| | optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs) |
| | return optimizer |
| |
|
| |
|
| | def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs): |
| | if isinstance(optimizer, str): |
| | optimizer_cls = getattr(torch.optim, optimizer) |
| | optimizer = optimizer_cls(model_parameters, **kwargs) |
| | else: |
| | optimizer = dict(optimizer) |
| | optimizer["init_args"].update(kwargs) |
| | optimizer = instantiate_class(model_parameters, optimizer) |
| | return optimizer |
| |
|
| |
|
| | def extend_checkpoint_dir(checkpoint_dir: Path) -> Path: |
| | new_checkpoint_dir = "checkpoints" / checkpoint_dir |
| | should_return_new_dir = ( |
| | not checkpoint_dir.is_dir() |
| | and checkpoint_dir.parts[0] != "checkpoints" |
| | and not checkpoint_dir.is_absolute() |
| | and new_checkpoint_dir.exists() |
| | ) |
| | return new_checkpoint_dir if should_return_new_dir else checkpoint_dir |
| |
|