Spaces:
Running
Running
| import torch | |
| from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union, Dict | |
| from torch import Tensor | |
| import torch.nn.functional as F | |
| from lavis.common.dist_utils import is_dist_avail_and_initialized | |
| from model.help_funcs import pad_and_concat | |
| from pytorch_lightning import strategies | |
| from lightning_fabric.utilities.types import _PATH | |
| try: | |
| from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict | |
| from deepspeed import DeepSpeedEngine | |
| _DEEPSPEED_AVAILABLE = True | |
| except ImportError: | |
| _DEEPSPEED_AVAILABLE = False | |
# Overwrite ``DeepSpeedEngine.module_state_dict`` (installed on the engine class
# further down, only when deepspeed could be imported).
### start overwrite ###
def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False):
    """Return ``self.module``'s state dict, optionally without frozen entries.

    Args:
        destination: optional mapping to write the entries into; forwarded to
            ``nn.Module.state_dict``.
        prefix: string prepended to every key by ``nn.Module.state_dict``.
        keep_vars: forwarded to ``nn.Module.state_dict``.
        exclude_frozen_parameters: if True, drop every entry produced by this
            module that is not a parameter with ``requires_grad=True``. Entries
            that are not ``nn.Parameter`` at all (e.g. buffers), for which
            ``get_parameter`` raises ``AttributeError``, count as frozen.

    Returns:
        The (possibly filtered) state dict. If random-LTD is enabled on the
        engine, the dict is additionally passed through deepspeed's
        ``remove_random_ltd_state_dict``.
    """
    # Keyword arguments: recent PyTorch releases deprecate positional use.
    sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
    if exclude_frozen_parameters:
        to_be_removed = []
        for key in sd:
            # Entries already present in a caller-supplied ``destination`` that
            # do not carry ``prefix`` did not come from this module; leave them.
            if not key.startswith(prefix):
                continue
            # ``get_parameter`` expects the module-relative name, so strip the
            # prefix first (otherwise every prefixed key would fail to resolve).
            name = key[len(prefix):]
            try:
                trainable = self.module.get_parameter(name).requires_grad
            except AttributeError:
                # Not an nn.Parameter (e.g. a buffer): treat it as frozen.
                trainable = False
            if not trainable:
                to_be_removed.append(key)
        # Removal is deferred so the dict is not mutated while iterating it.
        for key in to_be_removed:
            sd.pop(key)
    if self.random_ltd_enabled():
        sd = remove_random_ltd_state_dict(sd)
    return sd
# Patch the engine class only when the deepspeed import above succeeded;
# otherwise ``DeepSpeedEngine`` is not defined and nothing is changed.
if _DEEPSPEED_AVAILABLE:
    setattr(DeepSpeedEngine, "module_state_dict", module_state_dict)
### end overwrite ###
# Lightning's DeepSpeed strategy when deepspeed imported successfully; plain
# ``object`` otherwise, so that this module can still be imported without it.
_DeepSpeedStrategyBase = strategies.DeepSpeedStrategy if _DEEPSPEED_AVAILABLE else object


class MyDeepSpeedStrategy(_DeepSpeedStrategyBase):
    """DeepSpeed strategy whose checkpoints also contain frozen parameters."""

    def save_checkpoint_v1(
        self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None
    ):
        """Write ``checkpoint`` to ``filepath`` through the ``CheckpointIO`` plugin.

        Only the global-zero rank writes; all other ranks return immediately.

        Args:
            checkpoint: dict containing model and trainer state.
            filepath: path of the file to write.
            storage_options: forwarded unchanged to the ``CheckpointIO`` plugin.
        """
        if not self.is_global_zero:
            return
        self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)

    def save_checkpoint(self, checkpoint: Dict, filepath: _PATH, storage_options: Optional[Any] = None) -> None:
        """Save model/trainer state with DeepSpeed's own checkpointing.

        The destination path is taken from rank 0 on every rank. The
        ``state_dict`` and ``optimizer_states`` entries are removed from
        ``checkpoint``, because DeepSpeed's checkpointing writes the (possibly
        partitioned) weights itself. The remaining entries are passed as
        ``client_state`` under the tag ``"checkpoint"``, with frozen parameters
        kept in the saved module state.

        Args:
            checkpoint: the checkpoint state dictionary.
            filepath: write-target path, broadcast from rank 0.
            storage_options: unsupported here, because ``CheckpointIO`` is not used.

        Raises:
            TypeError: if ``storage_options`` is not None.
        """
        # Every rank must write into the same location, so adopt rank 0's path.
        filepath = self.broadcast(filepath)
        if storage_options is not None:
            raise TypeError(
                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
                f" is not supported for `{self.__class__.__name__}` as `CheckpointIO` is not used."
            )

        sharded_across_devices = self.zero_stage_3 and self._multi_device
        if sharded_across_devices and self.is_global_zero:
            print(
                "Warning: When saving the DeepSpeed Stage 3 checkpoint, "
                "each worker will save a shard of the checkpoint within a directory. "
                "If a single file is required after training, "
                "see https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html#"
                "deepspeed-zero-stage-3-single-file for instructions."
            )

        # DeepSpeed stores the module/optimizer state on its own; only the
        # remaining trainer-side entries travel along as client state.
        dropped = ("state_dict", "optimizer_states")
        client_state = {key: value for key, value in checkpoint.items() if key not in dropped}

        # MODIFIED: Save ALL parameters including frozen ones (PLM base, LLM base)
        # This makes checkpoints larger (~14GB) but self-contained
        # Original: exclude_frozen_parameters=True (only LoRA, ~500MB)
        self.deepspeed_engine.save_checkpoint(
            filepath,
            client_state=client_state,
            tag="checkpoint",
            exclude_frozen_parameters=False,
        )
def pl_concat_all_gather(tensor, padding=False, fill_value=0):
    """Gather ``tensor`` from every process and combine the pieces into one tensor.

    If no distributed process group is available/initialized, ``tensor`` is
    returned unchanged.

    Args:
        tensor: the local tensor to gather.
        padding: if True, combine the pieces with ``pad_and_concat`` (from
            ``model.help_funcs``; presumably pads mismatched dims with
            ``fill_value`` before concatenating — confirm there) and detach the
            result. If False, join them with ``torch.cat`` along dim 0, which
            requires all non-leading dims to match.
        fill_value: forwarded to ``pad_and_concat`` when ``padding`` is True.

    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    if is_dist_avail_and_initialized():
        pieces = gather_all_tensors(tensor)
        if not padding:
            return torch.cat(pieces, dim=0)
        return pad_and_concat(pieces, fill_value=fill_value).detach()
    # Single-process run: nothing to gather.
    return tensor
def gather_all_tensors(*args: Any, **kwargs: Any) -> Any:
    """Public entry point that forwards every argument to ``_gather_all_tensors``.

    Returns whatever ``_gather_all_tensors`` returns: one tensor per rank of the
    process group, possibly of different shapes.
    """
    gathered = _gather_all_tensors(*args, **kwargs)
    return gathered
def _gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
    """Gather ``result`` from every process in ``group``; the list is available on all ranks.

    All ranks must pass tensors with the same number of dimensions, but the size
    of each dimension may differ between ranks. In that case each tensor is
    zero-padded up to the element-wise maximum shape, gathered, and then trimmed
    back to its owner's original shape.

    Args:
        result: the local tensor to share.
        group: process group to gather over; ``None`` means the world group.

    Returns:
        A list with one entry per rank of ``group``; entry ``i`` is rank ``i``'s tensor.
    """
    group = torch.distributed.group.WORLD if group is None else group

    # Collectives below need contiguous memory.
    result = result.contiguous()
    world_size = torch.distributed.get_world_size(group)
    torch.distributed.barrier(group=group)

    # 0-d tensors all share the same (empty) shape.
    if result.ndim == 0:
        return _simple_gather_all_tensors(result, group, world_size)

    # Step 1: exchange every rank's shape and compute the element-wise maximum.
    local_size = torch.tensor(result.shape, device=result.device)
    local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    torch.distributed.all_gather(local_sizes, local_size, group=group)
    max_size = torch.stack(local_sizes).max(dim=0).values

    # Step 2: identical shapes everywhere -> plain gather, no padding needed.
    if all(torch.equal(size, max_size) for size in local_sizes):
        return _simple_gather_all_tensors(result, group, world_size)

    # Step 3: pad at the end of each dim. F.pad lists (before, after) pairs
    # starting from the LAST dimension, hence the reversal.
    pad_by = (max_size - local_size).detach().cpu()
    pad_dims = []
    for amount in reversed(pad_by.tolist()):
        pad_dims.extend((0, amount))
    result_padded = F.pad(result, pad_dims)

    gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
    torch.distributed.all_gather(gathered_result, result_padded, group)

    # Step 4: cut every gathered tensor back to the shape its rank reported.
    return [
        padded[tuple(slice(dim) for dim in size.tolist())]
        for padded, size in zip(gathered_result, local_sizes)
    ]
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
    """All-gather ``result`` when every rank's tensor has the same shape.

    Args:
        result: the local tensor to share.
        group: the process group to gather over.
        world_size: number of ranks in ``group`` (one receive buffer per rank).

    Returns:
        A list of ``world_size`` tensors; entry ``i`` holds rank ``i``'s tensor.
    """
    # Receive buffers must match the local tensor's shape/dtype/device.
    buffers = [torch.zeros_like(result) for _ in range(world_size)]
    torch.distributed.all_gather(buffers, result, group)
    return buffers