from typing import Any, Tuple

import torch
import torch.distributed as dist
from torch import Tensor

from fastvideo.utils.parallel_states import nccl_info


def broadcast(input_: torch.Tensor):
    """Broadcast ``input_`` in place from the first rank of the sequence-parallel group."""
    src = nccl_info.group_id * nccl_info.sp_size
    dist.broadcast(input_, src=src, group=nccl_info.group)
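
# Usage sketch (shapes are illustrative; assumes torch.distributed and the
# sequence-parallel group in `nccl_info` are already initialized):
#
#   latents = torch.randn(2, 16, 32, 32, device="cuda")
#   broadcast(latents)
#   # in place: every rank in the group now holds the copy from the group's
#   # first rank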


def _all_to_all_4D(input: torch.Tensor,
                   scatter_idx: int = 2,
                   gather_idx: int = 1,
                   group=None) -> torch.Tensor:
    """All-to-all resharding for 4D QKV tensors.

    Args:
        input (torch.Tensor): tensor sharded along the scatter dimension
        scatter_idx (int): dimension scattered across ranks (default 2, the head dim)
        gather_idx (int): dimension gathered across ranks (default 1, the sequence dim)
        group: torch.distributed process group

    Returns:
        torch.Tensor: resharded tensor, i.e. (bs, seqlen, hc/P, hs) for
        scatter_idx=2, gather_idx=1 and (bs, seqlen/P, hc, hs) for
        scatter_idx=1, gather_idx=2.
    """
    assert (
        input.dim() == 4
    ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # input is sharded along the sequence dim: (bs, seqlen/P, hc, hs)
        bs, shard_seqlen, hc, hs = input.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # (bs, seqlen/P, hc, hs) -> (bs, seqlen/P, P, hc/P, hs)
        #                        -> (P, seqlen/P, bs, hc/P, hs)
        # so each rank's head shard is contiguous on the leading dimension
        input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc,
                                 hs).transpose(0, 2).contiguous())

        output = torch.empty_like(input_t)
        # exchange head shards for sequence shards across the group
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t

        # the leading dim now indexes the source rank's sequence shard:
        # (P, seqlen/P, bs, hc/P, hs) -> (seqlen, bs, hc/P, hs)
        output = output.reshape(seqlen, bs, shard_hc, hs)

        # (seqlen, bs, hc/P, hs) -> (bs, seqlen, hc/P, hs)
        output = output.transpose(0, 1).contiguous().reshape(
            bs, seqlen, shard_hc, hs)

        return output

    elif scatter_idx == 1 and gather_idx == 2:
        # input is sharded along the head dim: (bs, seqlen, hc/P, hs)
        bs, seqlen, shard_hc, hs = input.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size

        # (bs, seqlen, hc/P, hs) -> (bs, P, seqlen/P, hc/P, hs)
        #                        -> (P, hc/P, seqlen/P, bs, hs)
        # so each rank's sequence shard is contiguous on the leading dimension
        input_t = (input.reshape(
            bs, seq_world_size, shard_seqlen, shard_hc,
            hs).transpose(0, 3).transpose(0, 1).contiguous().reshape(
                seq_world_size, shard_hc, shard_seqlen, bs, hs))

        output = torch.empty_like(input_t)
        # exchange sequence shards for head shards across the group
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t

        # the leading dim now indexes the source rank's head shard:
        # (P, hc/P, seqlen/P, bs, hs) -> (hc, seqlen/P, bs, hs)
        output = output.reshape(hc, shard_seqlen, bs, hs)

        # (hc, seqlen/P, bs, hs) -> (bs, seqlen/P, hc, hs)
        output = output.transpose(0, 2).contiguous().reshape(
            bs, shard_seqlen, hc, hs)

        return output
    else:
        raise RuntimeError(
            "(scatter_idx, gather_idx) must be (2, 1) or (1, 2)")
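
# Concrete example (illustrative numbers): with sp_size P = 4, a query tensor of
# shape (bs=2, seqlen/P=256, hc=24, hs=128) on each rank becomes
# (2, 1024, 6, 128) after _all_to_all_4D(q, scatter_idx=2, gather_idx=1): every
# rank then sees the full sequence but only 24 / 4 = 6 attention heads.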


class SeqAllToAll4D(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any,
                 *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # the backward of an all-to-all is the all-to-all with the scatter and
        # gather dimensions swapped
        return (
            None,
            SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx,
                                ctx.scatter_idx),
            None,
            None,
        )


def all_to_all_4D(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim,
                               gather_dim)
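
# Usage sketch for sequence-parallel attention (hypothetical names; `q`, `k`,
# `v` are (bs, seqlen/P, hc, hs) tensors sharded along the sequence dim, and
# `attention` stands in for whatever attention kernel is used):
#
#   q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)  # -> (bs, seqlen, hc/P, hs)
#   k = all_to_all_4D(k, scatter_dim=2, gather_dim=1)
#   v = all_to_all_4D(v, scatter_dim=2, gather_dim=1)
#   out = attention(q, k, v)                            # local heads, full sequence
#   out = all_to_all_4D(out, scatter_dim=1, gather_dim=2)  # back to (bs, seqlen/P, hc, hs)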


def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    # split the local tensor into world_size chunks along scatter_dim, exchange
    # one chunk with every rank, then concatenate the received chunks along
    # gather_dim
    input_list = [
        t.contiguous()
        for t in torch.tensor_split(input_, world_size, scatter_dim)
    ]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):
    """All-to-all communication with autograd support.

    Args:
        input_: input tensor
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        output = _all_to_all(input_, ctx.world_size, process_group,
                             scatter_dim, gather_dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # reverse the forward exchange: scatter along the gather dim and
        # gather along the scatter dim
        grad_output = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        return (
            grad_output,
            None,
            None,
            None,
        )


def all_to_all(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
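
# Usage sketch (illustrative shapes): scatter the head dim and gather the
# sequence dim of a (bs, seqlen/P, hc, hs) tensor, as in all_to_all_4D but via
# the generic chunk-based path:
#
#   hidden = torch.randn(2, 256, 24, 128, device="cuda")
#   hidden = all_to_all(hidden, scatter_dim=2, gather_dim=1)
#   # with sp_size = 4 -> (2, 1024, 6, 128)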


class _AllGather(torch.autograd.Function):
    """All-gather communication with autograd support.

    Args:
        input_: input tensor
        dim: dimension along which to concatenate
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        world_size = nccl_info.sp_size
        group = nccl_info.group
        input_size = list(input_.size())

        ctx.input_size = input_size[dim]

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        input_ = input_.contiguous()
        dist.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        world_size = nccl_info.sp_size
        rank = nccl_info.rank_within_group
        dim = ctx.dim
        input_size = ctx.input_size

        # the gradient of an all-gather is a slice: each rank keeps only the
        # portion of grad_output that corresponds to its own input shard
        sizes = [input_size] * world_size

        grad_input_list = torch.split(grad_output, sizes, dim=dim)
        grad_input = grad_input_list[rank]

        return grad_input, None


def all_gather(input_: torch.Tensor, dim: int = 1):
    """Performs an all-gather operation on the input tensor along the specified dimension.

    Args:
        input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
        dim (int, optional): Dimension along which to concatenate. Defaults to 1.

    Returns:
        torch.Tensor: Output tensor after the all-gather, concatenated along `dim`.
    """
    return _AllGather.apply(input_, dim)
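
# Usage sketch (illustrative shapes): reassemble sequence-parallel shards along
# the sharded dimension.
#
#   shard = torch.randn(2, 256, 24, 128, device="cuda")  # local shard, dim 1 sharded
#   full = all_gather(shard, dim=1)                       # with sp_size = 4 -> (2, 1024, 24, 128)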


def prepare_sequence_parallel_data(
    encoder_hidden_states, pooled_prompt_embeds, text_ids, caption
):
    if nccl_info.sp_size == 1:
        return (
            encoder_hidden_states,
            pooled_prompt_embeds,
            text_ids,
            caption,
        )

    def prepare(
        encoder_hidden_states, pooled_prompt_embeds, text_ids, caption
    ):
        # scatter the sp_size-repeated dim 1 and gather along the batch dim, so
        # every rank ends up with the batch items from all ranks in the group
        encoder_hidden_states = all_to_all(
            encoder_hidden_states, scatter_dim=1, gather_dim=0
        )
        pooled_prompt_embeds = all_to_all(
            pooled_prompt_embeds, scatter_dim=1, gather_dim=0
        )
        text_ids = all_to_all(text_ids, scatter_dim=1, gather_dim=0)
        return (
            encoder_hidden_states,
            pooled_prompt_embeds,
            text_ids,
            caption,
        )

    sp_size = nccl_info.sp_size

    (
        encoder_hidden_states,
        pooled_prompt_embeds,
        text_ids,
        caption,
    ) = prepare(
        # repeat along dim 1 so the all-to-all above can hand one copy to each rank
        encoder_hidden_states.repeat(1, sp_size, 1),
        pooled_prompt_embeds.repeat(1, sp_size, 1, 1),
        text_ids.repeat(1, sp_size),
        caption,
    )

    return encoder_hidden_states, pooled_prompt_embeds, text_ids, caption
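
# Net effect (illustrative, sp_size = 2): a per-rank encoder_hidden_states of
# shape (1, L, D) is repeated to (1, 2*L, D), and the all-to-all turns it into
# (2, L, D) on every rank, i.e. the text conditioning for both ranks' samples is
# replicated across the sequence-parallel group.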


def sp_parallel_dataloader_wrapper(
    dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
    while True:
        for data_item in dataloader:
            encoder_hidden_states, pooled_prompt_embeds, text_ids, caption = data_item
            encoder_hidden_states = encoder_hidden_states.to(device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(device)
            text_ids = text_ids.to(device)
            # NOTE: `frame` is hardcoded, so the single-frame fast path below is
            # effectively disabled and every batch goes through
            # prepare_sequence_parallel_data.
            frame = 19
            if frame == 1:
                yield encoder_hidden_states, pooled_prompt_embeds, text_ids, caption
            else:
                encoder_hidden_states, pooled_prompt_embeds, text_ids, caption = prepare_sequence_parallel_data(
                    encoder_hidden_states, pooled_prompt_embeds, text_ids, caption
                )
                assert (
                    train_batch_size * sp_size >= train_sp_batch_size
                ), "train_batch_size * sp_size should be greater than or equal to train_sp_batch_size"
                for idx in range(train_batch_size * sp_size // train_sp_batch_size):
                    st_idx = idx * train_sp_batch_size
                    ed_idx = (idx + 1) * train_sp_batch_size
                    # slice each micro-batch from the full batch instead of
                    # rebinding the batch variables, otherwise later iterations
                    # would slice an already-sliced tensor
                    yield (
                        encoder_hidden_states[st_idx:ed_idx],
                        pooled_prompt_embeds[st_idx:ed_idx],
                        text_ids[st_idx:ed_idx],
                        caption,
                    )
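
# Usage sketch (hypothetical loader and sizes): wrap a standard DataLoader so
# each step yields a sequence-parallel micro-batch already resharded for the
# group.
#
#   loader = sp_parallel_dataloader_wrapper(
#       dataloader, device="cuda", train_batch_size=1, sp_size=4,
#       train_sp_batch_size=1)
#   encoder_hidden_states, pooled_prompt_embeds, text_ids, caption = next(loader)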