Camellia997's picture
Upload folder using huggingface_hub
e14f899 verified
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from typing import Any, Tuple
import os
import torch
import functools
import torch.distributed as dist
from torch import Tensor
from ..utils.parallel_states import nccl_info, get_teacher_student_parallel_state
def broadcast(input_: torch.Tensor):
src = nccl_info.group_id * nccl_info.sp_size
dist.broadcast(input_, src=src, group=nccl_info.group)
def broadcast_within_ts_unit(input_):
src = nccl_info.ts_unit_group_id * nccl_info.ts_unit_size
dist.broadcast(input_, src=src, group=nccl_info.ts_unit_group)
def broadcast_global(input_: torch.Tensor):
dist.broadcast(input_, src=0, group=None)
def broadcast_dict(input_: dict):
src = nccl_info.group_id * nccl_info.sp_size
for k, v in input_.items():
if isinstance(input_[k], torch.Tensor):
dist.broadcast(input_[k], src=src, group=nccl_info.group)
def broadcast_dict_within_ts_unit(input_: dict):
src = nccl_info.ts_unit_group_id * nccl_info.ts_unit_size
for k, v in input_.items():
if isinstance(input_[k], torch.Tensor):
dist.broadcast(input_[k], src=src, group=nccl_info.ts_unit_group)
def _all_to_all_4D(
input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None
) -> torch.tensor:
"""
all-to-all for QKV
Args:
input (torch.tensor): a tensor sharded along dim scatter dim
scatter_idx (int): default 1
gather_idx (int): default 2
group : torch process group
Returns:
torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
"""
assert (
input.dim() == 4
), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"
seq_world_size = dist.get_world_size(group)
if scatter_idx == 2 and gather_idx == 1:
# input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
bs, shard_seqlen, hc, hs = input.shape
seqlen = shard_seqlen * seq_world_size
shard_hc = hc // seq_world_size
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
# (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
input_t = (
input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs)
.transpose(0, 2)
.contiguous()
)
output = torch.empty_like(input_t)
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
# (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
if seq_world_size > 1:
dist.all_to_all_single(output, input_t, group=group)
torch.cuda.synchronize()
else:
output = input_t
# if scattering the seq-dim, transpose the heads back to the original dimension
output = output.reshape(seqlen, bs, shard_hc, hs)
# (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
return output
elif scatter_idx == 1 and gather_idx == 2:
# input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
bs, seqlen, shard_hc, hs = input.shape
hc = shard_hc * seq_world_size
shard_seqlen = seqlen // seq_world_size
seq_world_size = dist.get_world_size(group)
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
# (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
input_t = (
input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs)
.transpose(0, 3)
.transpose(0, 1)
.contiguous()
.reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs)
)
output = torch.empty_like(input_t)
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
# (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
if seq_world_size > 1:
dist.all_to_all_single(output, input_t, group=group)
torch.cuda.synchronize()
else:
output = input_t
# if scattering the seq-dim, transpose the heads back to the original dimension
output = output.reshape(hc, shard_seqlen, bs, hs)
# (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs)
output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)
return output
else:
raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
class SeqAllToAll4D(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
group: dist.ProcessGroup,
input: Tensor,
scatter_idx: int,
gather_idx: int,
) -> Tensor:
ctx.group = group
ctx.scatter_idx = scatter_idx
ctx.gather_idx = gather_idx
return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
return (
None,
SeqAllToAll4D.apply(
ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx
),
None,
None,
)
def all_to_all_4D(
input_: torch.Tensor,
scatter_dim: int = 2,
gather_dim: int = 1,
):
return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, gather_dim)
def _all_to_all(
input_: torch.Tensor,
world_size: int,
group: dist.ProcessGroup,
scatter_dim: int,
gather_dim: int,
):
input_list = [
t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)
]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
class _AllToAll(torch.autograd.Function):
"""All-to-all communication.
Args:
input_: input matrix
process_group: communication group
scatter_dim: scatter dimension
gather_dim: gather dimension
"""
@staticmethod
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
ctx.process_group = process_group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
ctx.world_size = dist.get_world_size(process_group)
output = _all_to_all(
input_, ctx.world_size, process_group, scatter_dim, gather_dim
)
return output
@staticmethod
def backward(ctx, grad_output):
grad_output = _all_to_all(
grad_output,
ctx.world_size,
ctx.process_group,
ctx.gather_dim,
ctx.scatter_dim,
)
return (
grad_output,
None,
None,
None,
)
def all_to_all(
input_: torch.Tensor,
scatter_dim: int = 2,
gather_dim: int = 1,
):
return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
class _AllGather(torch.autograd.Function):
"""All-gather communication with autograd support.
Args:
input_: input tensor
dim: dimension along which to concatenate
"""
@staticmethod
def forward(ctx, input_, dim):
ctx.dim = dim
world_size = nccl_info.sp_size
group = nccl_info.group
input_size = list(input_.size())
ctx.input_size = input_size[dim]
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
input_ = input_.contiguous()
dist.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=dim)
return output
@staticmethod
def backward(ctx, grad_output):
world_size = nccl_info.sp_size
rank = nccl_info.rank_within_group
dim = ctx.dim
input_size = ctx.input_size
sizes = [input_size] * world_size
grad_input_list = torch.split(grad_output, sizes, dim=dim)
grad_input = grad_input_list[rank]
return grad_input, None
def all_gather(input_: torch.Tensor, dim: int = 1):
"""Performs an all-gather operation on the input tensor along the specified dimension.
Args:
input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
dim (int, optional): Dimension along which to concatenate. Defaults to 1.
Returns:
torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
"""
return _AllGather.apply(input_, dim)
class _AllGather_TeacherStudent(torch.autograd.Function):
"""All-gather communication with autograd support.
Args:
input_: input tensor
dim: dimension along which to concatenate
"""
@staticmethod
def forward(ctx, input_, dim):
ctx.dim = dim
world_size = nccl_info.ts_unit_size
group = nccl_info.ts_unit_group
input_size = list(input_.size())
ctx.input_size = input_size[dim]
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
input_ = input_.contiguous()
dist.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=dim)
return output
@staticmethod
def backward(ctx, grad_output):
world_size = nccl_info.ts_unit_size
rank = nccl_info.rank_within_ts_unit_group
dim = ctx.dim
input_size = ctx.input_size
sizes = [input_size] * world_size
grad_input_list = torch.split(grad_output, sizes, dim=dim)
grad_input = grad_input_list[rank]
return grad_input, None
def all_gather_ts(input_: torch.Tensor, dim: int = 1):
"""Performs an all-gather operation on the input tensor along the specified dimension.
Args:
input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
dim (int, optional): Dimension along which to concatenate. Defaults to 1.
Returns:
torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
"""
return _AllGather_TeacherStudent.apply(input_, dim)
def prepare_sequence_parallel_data_wanx(
hidden_states, encoder_hidden_states, uncond_text_states, image_embeds, latents_condition
):
if nccl_info.sp_size == 1:
return (
hidden_states,
encoder_hidden_states,
uncond_text_states,
image_embeds,
latents_condition,
)
def prepare(hidden_states, encoder_hidden_states, uncond_text_states, image_embeds, latents_condition):
hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
encoder_hidden_states = all_to_all(
encoder_hidden_states, scatter_dim=1, gather_dim=0
)
uncond_text_states = all_to_all(
uncond_text_states, scatter_dim=1, gather_dim=0
)
image_embeds = all_to_all(image_embeds, scatter_dim=1, gather_dim=0)
latents_condition = all_to_all(latents_condition, scatter_dim=2, gather_dim=0)
return (
hidden_states,
encoder_hidden_states,
uncond_text_states,
image_embeds,
latents_condition,
)
sp_size = nccl_info.sp_size
frame = hidden_states.shape[2]
assert frame % sp_size == 0, "frame should be a multiple of sp_size"
(
hidden_states,
encoder_hidden_states,
uncond_text_states,
image_embeds,
latents_condition,
) = prepare(
hidden_states,
encoder_hidden_states.repeat(1, sp_size, 1),
uncond_text_states.repeat(1, sp_size, 1),
image_embeds.repeat(1, sp_size, 1),
latents_condition,
)
return hidden_states, encoder_hidden_states, uncond_text_states, image_embeds, latents_condition
def sp_parallel_dataloader_wrapper_wanx(
dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
while True:
for data_item in dataloader:
latents, text_states, uncond_text_states, image_embeds, latents_condition = data_item
latents = latents.to(device)
text_states = text_states.to(device)
uncond_text_states = uncond_text_states.to(device)
image_embeds = image_embeds.to(device)
latents_condition = latents_condition.to(device)
frame = latents.shape[2]
if frame == 1:
yield latents, text_states, uncond_text_states, image_embeds, latents_condition
else:
latents, text_states, uncond_text_states, image_embeds, latents_condition = (
prepare_sequence_parallel_data_wanx(
latents, text_states, uncond_text_states, image_embeds, latents_condition
)
)
assert (
train_batch_size * sp_size >= train_sp_batch_size
), "train_batch_size * sp_size should be greater than train_sp_batch_size"
for iter in range(train_batch_size * sp_size // train_sp_batch_size):
st_idx = iter * train_sp_batch_size
ed_idx = (iter + 1) * train_sp_batch_size
yield (
latents[st_idx:ed_idx],
text_states[st_idx:ed_idx],
uncond_text_states[st_idx:ed_idx],
image_embeds[st_idx:ed_idx],
latents_condition[st_idx:ed_idx],
)
def prepare_sequence_parallel_data_wanx_dpo(
hidden_states, encoder_hidden_states, uncond_text_states, image_embeds, latents_condition,latents_lose
):
if nccl_info.sp_size == 1:
return (
hidden_states,
encoder_hidden_states,
uncond_text_states,
image_embeds,
latents_condition,
latents_lose,
)
def prepare(hidden_states, encoder_hidden_states, uncond_text_states, image_embeds, latents_condition, latents_lose):
hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
latents_lose = all_to_all(latents_lose, scatter_dim=2, gather_dim=0)
encoder_hidden_states = all_to_all(
encoder_hidden_states, scatter_dim=1, gather_dim=0
)
uncond_text_states = all_to_all(
uncond_text_states, scatter_dim=1, gather_dim=0
)
image_embeds = all_to_all(image_embeds, scatter_dim=1, gather_dim=0)
latents_condition = all_to_all(latents_condition, scatter_dim=2, gather_dim=0)
return (
hidden_states,
encoder_hidden_states,
uncond_text_states,
image_embeds,
latents_condition,
latents_lose,
)
sp_size = nccl_info.sp_size
frame = hidden_states.shape[2]
assert frame % sp_size == 0, "frame should be a multiple of sp_size"
(
hidden_states,
encoder_hidden_states,
uncond_text_states,
image_embeds,
latents_condition,latents_lose,
) = prepare(
hidden_states,
encoder_hidden_states.repeat(1, sp_size, 1),
uncond_text_states.repeat(1, sp_size, 1),
image_embeds.repeat(1, sp_size, 1),
latents_condition,
latents_lose
)
return hidden_states, encoder_hidden_states, uncond_text_states, image_embeds, latents_condition,latents_lose
def sp_parallel_dataloader_wrapper_wanx_dpo(
dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
while True:
for data_item in dataloader:
latents, text_states, uncond_text_states, image_embeds, latents_condition,latent_lose = data_item
latents = latents.to(device)
latents_lose = latents.to(device)
text_states = text_states.to(device)
uncond_text_states = uncond_text_states.to(device)
image_embeds = image_embeds.to(device)
latents_condition = latents_condition.to(device)
frame = latents.shape[2]
if frame == 1:
yield latents, text_states, uncond_text_states, image_embeds, latents_condition,latent_lose
else:
latents, text_states, uncond_text_states, image_embeds, latents_condition, latents_lose = (
prepare_sequence_parallel_data_wanx_dpo(
latents, text_states, uncond_text_states, image_embeds, latents_condition,latent_lose
)
)
assert (
train_batch_size * sp_size >= train_sp_batch_size
), "train_batch_size * sp_size should be greater than train_sp_batch_size"
for iter in range(train_batch_size * sp_size // train_sp_batch_size):
st_idx = iter * train_sp_batch_size
ed_idx = (iter + 1) * train_sp_batch_size
yield (
latents[st_idx:ed_idx],
text_states[st_idx:ed_idx],
uncond_text_states[st_idx:ed_idx],
image_embeds[st_idx:ed_idx],
latents_condition[st_idx:ed_idx],
latents_lose[st_idx:ed_idx],
)
def prepare_sequence_parallel_data_ltx(
hidden_states, encoder_hidden_states, text_mask, uncond_text_states, uncond_text_mask
):
if nccl_info.sp_size == 1:
return (
hidden_states,
encoder_hidden_states,
text_mask,
uncond_text_states,
uncond_text_mask,
)
def prepare(hidden_states, encoder_hidden_states, text_mask, uncond_text_states, uncond_text_mask):
hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
encoder_hidden_states = all_to_all(
encoder_hidden_states, scatter_dim=1, gather_dim=0
)
text_mask = all_to_all(text_mask, scatter_dim=1, gather_dim=0)
uncond_text_states = all_to_all(
uncond_text_states, scatter_dim=1, gather_dim=0
)
uncond_text_mask = all_to_all(
uncond_text_mask, scatter_dim=1, gather_dim=0
)
return (
hidden_states,
encoder_hidden_states,
text_mask,
uncond_text_states,
uncond_text_mask,
)
sp_size = nccl_info.sp_size
frame = hidden_states.shape[2]
assert frame % sp_size == 0, "frame should be a multiple of sp_size"
(
hidden_states,
encoder_hidden_states,
text_mask,
uncond_text_states,
uncond_text_mask,
) = prepare(
hidden_states,
encoder_hidden_states.repeat(1, sp_size, 1),
text_mask.repeat(1, sp_size),
uncond_text_states.repeat(1, sp_size, 1),
uncond_text_mask.repeat(1, sp_size)
)
return hidden_states, encoder_hidden_states, text_mask, uncond_text_states, uncond_text_mask
def sp_parallel_dataloader_wrapper_ltx(
dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
while True:
for data_item in dataloader:
latents, text_states, text_mask, uncond_text_states, uncond_text_mask = data_item
latents = latents.to(device)
text_states = text_states.to(device)
text_mask = text_mask.to(device)
uncond_text_states = uncond_text_states.to(device)
uncond_text_mask = uncond_text_mask.to(device)
frame = latents.shape[2]
if frame == 1:
yield latents, text_states, text_mask, uncond_text_states, uncond_text_mask
else:
latents, text_states, text_mask, uncond_text_states, uncond_text_mask = (
prepare_sequence_parallel_data_ltx(
latents, text_states, text_mask, uncond_text_states, uncond_text_mask
)
)
assert (
train_batch_size * sp_size >= train_sp_batch_size
), "train_batch_size * sp_size should be greater than train_sp_batch_size"
for iter in range(train_batch_size * sp_size // train_sp_batch_size):
st_idx = iter * train_sp_batch_size
ed_idx = (iter + 1) * train_sp_batch_size
yield (
latents[st_idx:ed_idx],
text_states[st_idx:ed_idx],
text_mask[st_idx:ed_idx],
uncond_text_states[st_idx:ed_idx],
uncond_text_mask[st_idx:ed_idx],
)
def parallelize_model(model):
original_forward = model.forward
@functools.wraps(model.__class__.forward)
def new_forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
text_states: torch.Tensor,
text_states_2: torch.Tensor,
encoder_attention_mask: torch.Tensor,
output_features=False,
output_features_stride=8,
attention_kwargs=None,
freqs_cos=None,
freqs_sin=None,
return_dict=False,
guidance=None,
):
x = hidden_states
sp_size = nccl_info.sp_size
sp_rank = nccl_info.rank_within_group
if x.shape[-2] // 2 % sp_size == 0:
# try to split x by height
split_dim = -2
elif x.shape[-1] // 2 % sp_size == 0:
# try to split x by width
split_dim = -1
else:
raise ValueError(f"Cannot split video sequence into ulysses_degree ({sp_size}) parts evenly")
_, _, ot, oh, ow = x.shape
tt, th, tw = (
ot // self.patch_size[0],
oh // self.patch_size[1],
ow // self.patch_size[2],
)
freqs_cos, freqs_sin = self.get_rotary_pos_embed((tt, th, tw))
# patch sizes for the temporal, height, and width dimensions are 1, 2, and 2.
temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2
x = torch.chunk(x, sp_size,dim=split_dim)[sp_rank]
dim_thw = freqs_cos.shape[-1]
freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
freqs_cos = torch.chunk(freqs_cos, sp_size,dim=split_dim - 1)[sp_rank]
freqs_cos = freqs_cos.reshape(-1, dim_thw)
dim_thw = freqs_sin.shape[-1]
freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
freqs_sin = torch.chunk(freqs_sin, sp_size,dim=split_dim - 1)[sp_rank]
freqs_sin = freqs_sin.reshape(-1, dim_thw)
output = original_forward(
x,
timestep,
text_states,
text_states_2,
encoder_attention_mask,
output_features,
output_features_stride,
attention_kwargs,
freqs_cos,
freqs_sin,
return_dict,
guidance,
)
return_dict = not isinstance(output, tuple)
shape = (tt, th, tw)
if return_dict:
assert not output_features, "output_feature is not compatible with return_dict"
sample = output["x"]
sample = all_gather(sample, dim=split_dim)
output["x"] = sample
else:
sample = output[0]
sample = all_gather(sample, dim=split_dim)
if output_features:
features_list = output[1]
features_list = all_gather(features_list, dim=split_dim)
else:
features_list = None
output = (sample, features_list, shape)
return output
new_forward = new_forward.__get__(model)
model.forward = new_forward
def all_reduce_tensor_item(item):
world_size = int(os.environ["WORLD_SIZE"])
item = item.detach().clone()
dist.all_reduce(item, op=dist.ReduceOp.SUM)
item = item / nccl_info.ts_group_size if get_teacher_student_parallel_state() else item / world_size
return item
def broadcast_item(item, idx):
item_list = [item]
dist.broadcast_object_list(item_list, src=idx)
return item_list[0]