|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Tuple |
|
|
|
|
|
import os |
|
|
import torch |
|
|
import functools |
|
|
import torch.distributed as dist |
|
|
from torch import Tensor |
|
|
|
|
|
from ..utils.parallel_states import nccl_info, get_teacher_student_parallel_state |
|
|
|
|
|
|
|
|
def broadcast(input_: torch.Tensor):
    """Broadcast ``input_`` in place from the first rank of the current SP group.

    The source is the global rank of local rank 0 in this sequence-parallel
    group, i.e. ``group_id * sp_size``.
    """
    source_rank = nccl_info.group_id * nccl_info.sp_size
    dist.broadcast(input_, src=source_rank, group=nccl_info.group)
|
|
|
|
|
def broadcast_within_ts_unit(input_):
    """Broadcast ``input_`` in place from the first rank of the current
    teacher-student unit group.

    The source is the global rank of local rank 0 in this unit, i.e.
    ``ts_unit_group_id * ts_unit_size``.
    """
    source_rank = nccl_info.ts_unit_group_id * nccl_info.ts_unit_size
    dist.broadcast(input_, src=source_rank, group=nccl_info.ts_unit_group)
|
|
|
|
|
def broadcast_global(input_: torch.Tensor):
    """Broadcast ``input_`` in place from global rank 0 over the default group."""
    dist.broadcast(input_, src=0, group=None)
|
|
|
|
|
def broadcast_dict(input_: dict):
    """Broadcast every tensor value of ``input_`` in place from the first rank
    of the current sequence-parallel group.

    Non-tensor values are left untouched; the dict structure itself is not
    communicated, so tensor entries must already exist with matching shapes on
    every rank.

    Args:
        input_: mapping whose tensor values are broadcast in place.
    """
    src = nccl_info.group_id * nccl_info.sp_size
    # Iterate values directly instead of re-indexing `input_[k]` twice per key.
    for value in input_.values():
        if isinstance(value, torch.Tensor):
            dist.broadcast(value, src=src, group=nccl_info.group)
|
|
|
|
|
def broadcast_dict_within_ts_unit(input_: dict):
    """Broadcast every tensor value of ``input_`` in place from the first rank
    of the current teacher-student unit group.

    Non-tensor values are left untouched; tensor entries must already exist
    with matching shapes on every rank of the unit.

    Args:
        input_: mapping whose tensor values are broadcast in place.
    """
    src = nccl_info.ts_unit_group_id * nccl_info.ts_unit_size
    # Iterate values directly instead of re-indexing `input_[k]` twice per key.
    for value in input_.values():
        if isinstance(value, torch.Tensor):
            dist.broadcast(value, src=src, group=nccl_info.ts_unit_group)
|
|
|
|
|
def _all_to_all_4D(
    input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None
) -> torch.tensor:
    """
    all-to-all for QKV

    Reshards a 4D activation tensor across the sequence-parallel group:

    * ``scatter_idx=2, gather_idx=1``: (bs, seqlen/P, hc, hs) -> (bs, seqlen, hc/P, hs)
    * ``scatter_idx=1, gather_idx=2``: (bs, seqlen, hc/P, hs) -> (bs, seqlen/P, hc, hs)

    Args:
        input (torch.tensor): a 4D tensor sharded along the scatter dim
        scatter_idx (int): dimension scattered across ranks; default 2
        gather_idx (int): dimension gathered from ranks; default 1
        group : torch process group

    Returns:
        torch.tensor: the resharded tensor

    Raises:
        RuntimeError: if (scatter_idx, gather_idx) is neither (2, 1) nor (1, 2)
    """
    assert (
        input.dim() == 4
    ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
        bs, shard_seqlen, hc, hs = input.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # Expose the rank dimension first so all_to_all_single can exchange
        # contiguous per-rank slabs: -> (P, seqlen/P, bs, hc/P, hs)
        input_t = (
            input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs)
            .transpose(0, 2)
            .contiguous()
        )

        output = torch.empty_like(input_t)
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            # NOTE(review): explicit sync after the collective — presumably to
            # guard the subsequent host-side reshape; confirm it is required.
            torch.cuda.synchronize()
        else:
            # Single rank: the exchange is the identity.
            output = input_t

        # (P, seqlen/P, bs, hc/P, hs) -> (seqlen, bs, hc/P, hs) -> (bs, seqlen, hc/P, hs)
        output = output.reshape(seqlen, bs, shard_hc, hs)
        output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)

        return output

    elif scatter_idx == 1 and gather_idx == 2:
        # (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
        bs, seqlen, shard_hc, hs = input.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size
        # (removed a redundant second dist.get_world_size(group) call here)

        # Expose the rank dimension first: -> (P, hc/P, seqlen/P, bs, hs)
        input_t = (
            input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs)
            .transpose(0, 3)
            .transpose(0, 1)
            .contiguous()
            .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs)
        )

        output = torch.empty_like(input_t)
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            # NOTE(review): explicit sync after the collective — presumably to
            # guard the subsequent host-side reshape; confirm it is required.
            torch.cuda.synchronize()
        else:
            output = input_t

        # (P, hc/P, seqlen/P, bs, hs) -> (hc, seqlen/P, bs, hs) -> (bs, seqlen/P, hc, hs)
        output = output.reshape(hc, shard_seqlen, bs, hs)
        output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)

        return output
    else:
        raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
|
|
|
|
|
|
|
|
class SeqAllToAll4D(torch.autograd.Function):
    """Autograd-aware wrapper around :func:`_all_to_all_4D`.

    The backward pass is the same exchange with the scatter and gather
    dimensions swapped.
    """

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:
        # Stash communication parameters for the backward pass.
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        # Inverse resharding: swap scatter/gather dims on the gradient.
        grad_input = SeqAllToAll4D.apply(
            ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx
        )
        # One gradient slot per forward argument (group, input, scatter, gather).
        return (None, grad_input, None, None)
|
|
|
|
|
|
|
|
def all_to_all_4D(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Differentiable 4D all-to-all over the current sequence-parallel group.

    Thin convenience wrapper that supplies ``nccl_info.group`` to
    :class:`SeqAllToAll4D`.
    """
    return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, gather_dim)
|
|
|
|
|
|
|
|
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    """Split ``input_`` along ``scatter_dim``, exchange the chunks across the
    group, and concatenate the received chunks along ``gather_dim``."""
    send_buffers = [
        chunk.contiguous()
        for chunk in torch.tensor_split(input_, world_size, scatter_dim)
    ]
    recv_buffers = [torch.empty_like(send_buffers[0]) for _ in range(world_size)]
    dist.all_to_all(recv_buffers, send_buffers, group=group)
    return torch.cat(recv_buffers, dim=gather_dim).contiguous()
|
|
|
|
|
|
|
|
class _AllToAll(torch.autograd.Function):
    """All-to-all communication with autograd support.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        # Remember everything backward needs to run the inverse exchange.
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        return _all_to_all(
            input_, ctx.world_size, process_group, scatter_dim, gather_dim
        )

    @staticmethod
    def backward(ctx, grad_output):
        # The inverse of an all-to-all is the same exchange with the
        # scatter and gather dimensions swapped.
        grad_input = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        # One gradient slot per forward argument.
        return (grad_input, None, None, None)
|
|
|
|
|
|
|
|
def all_to_all(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Differentiable all-to-all over the current sequence-parallel group.

    Thin convenience wrapper supplying ``nccl_info.group`` to :class:`_AllToAll`.
    """
    return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
|
|
|
|
|
|
|
|
class _AllGather(torch.autograd.Function):
    """All-gather communication with autograd support.

    Args:
        input_: input tensor
        dim: dimension along which to concatenate
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        # Remember the local shard extent so backward can split the gradient.
        ctx.input_size = input_.size(dim)

        world_size = nccl_info.sp_size
        gathered = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(gathered, input_.contiguous(), group=nccl_info.group)
        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank keeps only the gradient slice matching its own shard.
        world_size = nccl_info.sp_size
        rank = nccl_info.rank_within_group
        shard_sizes = [ctx.input_size] * world_size
        shards = torch.split(grad_output, shard_sizes, dim=ctx.dim)
        return shards[rank], None
|
|
|
|
|
|
|
|
def all_gather(input_: torch.Tensor, dim: int = 1):
    """Performs an all-gather operation on the input tensor along the specified dimension.

    Args:
        input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
        dim (int, optional): Dimension along which to concatenate. Defaults to 1.

    Returns:
        torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
    """
    return _AllGather.apply(input_, dim)
|
|
|
|
|
class _AllGather_TeacherStudent(torch.autograd.Function):
    """All-gather communication with autograd support, run over the
    teacher-student unit group instead of the SP group.

    Args:
        input_: input tensor
        dim: dimension along which to concatenate
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        # Remember the local shard extent so backward can split the gradient.
        ctx.input_size = input_.size(dim)

        world_size = nccl_info.ts_unit_size
        gathered = [torch.empty_like(input_) for _ in range(world_size)]
        dist.all_gather(gathered, input_.contiguous(), group=nccl_info.ts_unit_group)
        return torch.cat(gathered, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank keeps only the gradient slice matching its own shard.
        world_size = nccl_info.ts_unit_size
        rank = nccl_info.rank_within_ts_unit_group
        shard_sizes = [ctx.input_size] * world_size
        shards = torch.split(grad_output, shard_sizes, dim=ctx.dim)
        return shards[rank], None
|
|
|
|
|
def all_gather_ts(input_: torch.Tensor, dim: int = 1):
    """Performs an all-gather operation on the input tensor along the specified dimension.

    Uses the teacher-student unit group rather than the sequence-parallel group.

    Args:
        input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
        dim (int, optional): Dimension along which to concatenate. Defaults to 1.

    Returns:
        torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
    """
    return _AllGather_TeacherStudent.apply(input_, dim)
|
|
|
|
|
|
|
|
def prepare_sequence_parallel_data_wanx(
    hidden_states, encoder_hidden_states, uncond_text_states, image_embeds, latents_condition
):
    """Reshard one WanX batch for sequence parallelism.

    Latent-like tensors are scattered along the frame dim (dim 2) and gathered
    along batch (dim 0); conditioning tensors are first tiled ``sp_size`` times
    along their sequence dim so each rank ends up with a full copy after the
    exchange. With ``sp_size == 1`` the inputs are returned unchanged.
    """
    sp_size = nccl_info.sp_size
    if sp_size == 1:
        return (
            hidden_states,
            encoder_hidden_states,
            uncond_text_states,
            image_embeds,
            latents_condition,
        )

    assert hidden_states.shape[2] % sp_size == 0, "frame should be a multiple of sp_size"

    hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
    encoder_hidden_states = all_to_all(
        encoder_hidden_states.repeat(1, sp_size, 1), scatter_dim=1, gather_dim=0
    )
    uncond_text_states = all_to_all(
        uncond_text_states.repeat(1, sp_size, 1), scatter_dim=1, gather_dim=0
    )
    image_embeds = all_to_all(
        image_embeds.repeat(1, sp_size, 1), scatter_dim=1, gather_dim=0
    )
    latents_condition = all_to_all(latents_condition, scatter_dim=2, gather_dim=0)

    return (
        hidden_states,
        encoder_hidden_states,
        uncond_text_states,
        image_embeds,
        latents_condition,
    )
|
|
|
|
|
|
|
|
def sp_parallel_dataloader_wrapper_wanx(
    dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
    """Endless generator that moves WanX batches to ``device``, reshards them
    for sequence parallelism, and yields micro-batches of ``train_sp_batch_size``.

    Single-frame (image) samples are yielded directly without resharding.
    """
    while True:
        for data_item in dataloader:
            latents, text_states, uncond_text_states, image_embeds, latents_condition = data_item
            latents = latents.to(device)
            text_states = text_states.to(device)
            uncond_text_states = uncond_text_states.to(device)
            image_embeds = image_embeds.to(device)
            latents_condition = latents_condition.to(device)
            frame = latents.shape[2]
            if frame == 1:
                yield latents, text_states, uncond_text_states, image_embeds, latents_condition
            else:
                latents, text_states, uncond_text_states, image_embeds, latents_condition = (
                    prepare_sequence_parallel_data_wanx(
                        latents, text_states, uncond_text_states, image_embeds, latents_condition
                    )
                )
                assert (
                    train_batch_size * sp_size >= train_sp_batch_size
                ), "train_batch_size * sp_size should be greater than or equal to train_sp_batch_size"
                # Renamed from `iter` to avoid shadowing the builtin.
                for micro_idx in range(train_batch_size * sp_size // train_sp_batch_size):
                    st_idx = micro_idx * train_sp_batch_size
                    ed_idx = (micro_idx + 1) * train_sp_batch_size
                    yield (
                        latents[st_idx:ed_idx],
                        text_states[st_idx:ed_idx],
                        uncond_text_states[st_idx:ed_idx],
                        image_embeds[st_idx:ed_idx],
                        latents_condition[st_idx:ed_idx],
                    )
|
|
|
|
|
def prepare_sequence_parallel_data_wanx_dpo(
    hidden_states, encoder_hidden_states, uncond_text_states, image_embeds, latents_condition, latents_lose
):
    """Reshard one WanX DPO batch (win + lose latents) for sequence parallelism.

    Both latent tensors and the latent condition are scattered along the frame
    dim (dim 2) and gathered along batch (dim 0); conditioning tensors are
    tiled ``sp_size`` times along their sequence dim so each rank keeps a full
    copy. With ``sp_size == 1`` the inputs are returned unchanged.
    """
    sp_size = nccl_info.sp_size
    if sp_size == 1:
        return (
            hidden_states,
            encoder_hidden_states,
            uncond_text_states,
            image_embeds,
            latents_condition,
            latents_lose,
        )

    assert hidden_states.shape[2] % sp_size == 0, "frame should be a multiple of sp_size"

    hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
    latents_lose = all_to_all(latents_lose, scatter_dim=2, gather_dim=0)
    encoder_hidden_states = all_to_all(
        encoder_hidden_states.repeat(1, sp_size, 1), scatter_dim=1, gather_dim=0
    )
    uncond_text_states = all_to_all(
        uncond_text_states.repeat(1, sp_size, 1), scatter_dim=1, gather_dim=0
    )
    image_embeds = all_to_all(
        image_embeds.repeat(1, sp_size, 1), scatter_dim=1, gather_dim=0
    )
    latents_condition = all_to_all(latents_condition, scatter_dim=2, gather_dim=0)

    return (
        hidden_states,
        encoder_hidden_states,
        uncond_text_states,
        image_embeds,
        latents_condition,
        latents_lose,
    )
|
|
|
|
|
def sp_parallel_dataloader_wrapper_wanx_dpo(
    dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
    """Endless generator that moves WanX DPO batches (win + lose latents) to
    ``device``, reshards them for sequence parallelism, and yields micro-batches
    of ``train_sp_batch_size``.

    Single-frame (image) samples are yielded directly without resharding.
    """
    while True:
        for data_item in dataloader:
            latents, text_states, uncond_text_states, image_embeds, latents_condition, latent_lose = data_item
            latents = latents.to(device)
            # BUG FIX: this previously read `latents.to(device)`, silently
            # replacing the "lose" latents with a copy of the "win" latents.
            latents_lose = latent_lose.to(device)
            text_states = text_states.to(device)
            uncond_text_states = uncond_text_states.to(device)
            image_embeds = image_embeds.to(device)
            latents_condition = latents_condition.to(device)
            frame = latents.shape[2]
            if frame == 1:
                # Yield the device copy of the lose latents (was the raw
                # `latent_lose`, still on the dataloader's device).
                yield latents, text_states, uncond_text_states, image_embeds, latents_condition, latents_lose
            else:
                latents, text_states, uncond_text_states, image_embeds, latents_condition, latents_lose = (
                    prepare_sequence_parallel_data_wanx_dpo(
                        latents, text_states, uncond_text_states, image_embeds, latents_condition, latents_lose
                    )
                )
                assert (
                    train_batch_size * sp_size >= train_sp_batch_size
                ), "train_batch_size * sp_size should be greater than or equal to train_sp_batch_size"
                # Renamed from `iter` to avoid shadowing the builtin.
                for micro_idx in range(train_batch_size * sp_size // train_sp_batch_size):
                    st_idx = micro_idx * train_sp_batch_size
                    ed_idx = (micro_idx + 1) * train_sp_batch_size
                    yield (
                        latents[st_idx:ed_idx],
                        text_states[st_idx:ed_idx],
                        uncond_text_states[st_idx:ed_idx],
                        image_embeds[st_idx:ed_idx],
                        latents_condition[st_idx:ed_idx],
                        latents_lose[st_idx:ed_idx],
                    )
|
|
|
|
|
def prepare_sequence_parallel_data_ltx(
    hidden_states, encoder_hidden_states, text_mask, uncond_text_states, uncond_text_mask
):
    """Reshard one LTX batch for sequence parallelism.

    Latents are scattered along the frame dim (dim 2) and gathered along batch
    (dim 0); text states and masks are first tiled ``sp_size`` times along
    their sequence dim so each rank ends up with a full copy after the
    exchange. With ``sp_size == 1`` the inputs are returned unchanged.
    """
    sp_size = nccl_info.sp_size
    if sp_size == 1:
        return (
            hidden_states,
            encoder_hidden_states,
            text_mask,
            uncond_text_states,
            uncond_text_mask,
        )

    assert hidden_states.shape[2] % sp_size == 0, "frame should be a multiple of sp_size"

    hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
    encoder_hidden_states = all_to_all(
        encoder_hidden_states.repeat(1, sp_size, 1), scatter_dim=1, gather_dim=0
    )
    # Masks are 2D, hence the two-argument repeat.
    text_mask = all_to_all(text_mask.repeat(1, sp_size), scatter_dim=1, gather_dim=0)
    uncond_text_states = all_to_all(
        uncond_text_states.repeat(1, sp_size, 1), scatter_dim=1, gather_dim=0
    )
    uncond_text_mask = all_to_all(
        uncond_text_mask.repeat(1, sp_size), scatter_dim=1, gather_dim=0
    )

    return (
        hidden_states,
        encoder_hidden_states,
        text_mask,
        uncond_text_states,
        uncond_text_mask,
    )
|
|
|
|
|
|
|
|
def sp_parallel_dataloader_wrapper_ltx(
    dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
    """Endless generator that moves LTX batches to ``device``, reshards them
    for sequence parallelism, and yields micro-batches of ``train_sp_batch_size``.

    Single-frame (image) samples are yielded directly without resharding.
    """
    while True:
        for data_item in dataloader:
            latents, text_states, text_mask, uncond_text_states, uncond_text_mask = data_item
            latents = latents.to(device)
            text_states = text_states.to(device)
            text_mask = text_mask.to(device)
            uncond_text_states = uncond_text_states.to(device)
            uncond_text_mask = uncond_text_mask.to(device)

            if latents.shape[2] == 1:
                yield latents, text_states, text_mask, uncond_text_states, uncond_text_mask
                continue

            latents, text_states, text_mask, uncond_text_states, uncond_text_mask = (
                prepare_sequence_parallel_data_ltx(
                    latents, text_states, text_mask, uncond_text_states, uncond_text_mask
                )
            )
            assert (
                train_batch_size * sp_size >= train_sp_batch_size
            ), "train_batch_size * sp_size should be greater than train_sp_batch_size"
            num_micro_batches = train_batch_size * sp_size // train_sp_batch_size
            for micro_idx in range(num_micro_batches):
                st_idx = micro_idx * train_sp_batch_size
                ed_idx = st_idx + train_sp_batch_size
                yield (
                    latents[st_idx:ed_idx],
                    text_states[st_idx:ed_idx],
                    text_mask[st_idx:ed_idx],
                    uncond_text_states[st_idx:ed_idx],
                    uncond_text_mask[st_idx:ed_idx],
                )
|
|
|
|
|
|
|
|
def parallelize_model(model):
    """Monkey-patch ``model.forward`` to run under sequence parallelism.

    The replacement forward splits the input latent across SP ranks along one
    spatial dimension, slices the rotary-embedding tables to match, runs the
    original forward on the local shard, then all-gathers the outputs.
    Mutates ``model`` in place; returns None.
    """
    original_forward = model.forward

    @functools.wraps(model.__class__.forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        text_states: torch.Tensor,
        text_states_2: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        output_features=False,
        output_features_stride=8,
        attention_kwargs=None,
        freqs_cos=None,
        freqs_sin=None,
        return_dict=False,
        guidance=None,
    ):
        x = hidden_states
        sp_size = nccl_info.sp_size
        sp_rank = nccl_info.rank_within_group

        # Choose the spatial dim whose patchified extent divides evenly by
        # sp_size. NOTE(review): the `// 2` presumably assumes a spatial patch
        # size of 2 (matching self.patch_size[1:]) — confirm against the model.
        if x.shape[-2] // 2 % sp_size == 0:
            split_dim = -2
        elif x.shape[-1] // 2 % sp_size == 0:
            split_dim = -1
        else:
            raise ValueError(f"Cannot split video sequence into ulysses_degree ({sp_size}) parts evenly")

        _, _, ot, oh, ow = x.shape
        # Token-grid extents after patch embedding.
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )
        # Recompute RoPE tables for the FULL grid; any caller-provided
        # freqs_cos/freqs_sin arguments are intentionally overwritten here.
        freqs_cos, freqs_sin = self.get_rotary_pos_embed((tt, th, tw))

        temporal_size, h, w = x.shape[2], x.shape[3] // 2, x.shape[4] // 2

        # Keep only this rank's shard of the latent along the chosen dim.
        x = torch.chunk(x, sp_size,dim=split_dim)[sp_rank]

        # Slice the RoPE tables along the matching grid dim. `split_dim - 1`
        # maps the latent dim (-2/-1) onto the (t, h, w, d) frequency layout.
        dim_thw = freqs_cos.shape[-1]
        freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)
        freqs_cos = torch.chunk(freqs_cos, sp_size,dim=split_dim - 1)[sp_rank]
        freqs_cos = freqs_cos.reshape(-1, dim_thw)
        dim_thw = freqs_sin.shape[-1]
        freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
        freqs_sin = torch.chunk(freqs_sin, sp_size,dim=split_dim - 1)[sp_rank]
        freqs_sin = freqs_sin.reshape(-1, dim_thw)

        # Run the wrapped forward on the local shard with sliced RoPE tables.
        output = original_forward(
            x,
            timestep,
            text_states,
            text_states_2,
            encoder_attention_mask,
            output_features,
            output_features_stride,
            attention_kwargs,
            freqs_cos,
            freqs_sin,
            return_dict,
            guidance,
        )

        # NOTE(review): `return_dict` is rebound to reflect what the model
        # actually returned, shadowing the input argument of the same name.
        return_dict = not isinstance(output, tuple)
        shape = (tt, th, tw)
        if return_dict:
            assert not output_features, "output_feature is not compatible with return_dict"
            sample = output["x"]
            # Reassemble the full sample from all SP ranks.
            sample = all_gather(sample, dim=split_dim)
            output["x"] = sample
        else:
            sample = output[0]
            sample = all_gather(sample, dim=split_dim)
            if output_features:
                features_list = output[1]
                features_list = all_gather(features_list, dim=split_dim)
            else:
                features_list = None

            # Tuple outputs are normalized to (sample, features_list, shape).
            output = (sample, features_list, shape)
        return output

    # Bind the wrapper as a method of this instance and install it.
    new_forward = new_forward.__get__(model)
    model.forward = new_forward
|
|
|
|
|
|
|
|
def all_reduce_tensor_item(item):
    """Sum-reduce a tensor across all ranks and return its mean.

    The divisor is ``nccl_info.ts_group_size`` when teacher-student parallelism
    is active, otherwise the global ``WORLD_SIZE`` from the environment. The
    input is detached and cloned, so the caller's tensor is never modified.
    """
    world_size = int(os.environ["WORLD_SIZE"])
    reduced = item.detach().clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    if get_teacher_student_parallel_state():
        return reduced / nccl_info.ts_group_size
    return reduced / world_size
|
|
|
|
|
def broadcast_item(item, idx):
    """Broadcast an arbitrary picklable object from rank ``idx``.

    Returns the object received from the source rank (on the source itself,
    this is the original object).
    """
    container = [item]
    dist.broadcast_object_list(container, src=idx)
    return container[0]
|
|
|