import torch
import torch.distributed as dist
import math


class ModulePlugin:
    """Base class that monkey-patches a module's forward() so it can be routed
    through a distributed-aware implementation and toggled via `enable`."""

    def __init__(self, module, module_id, global_state=None):
        self.module = module
        self.module_id = module_id
        self.global_state = global_state
        self.enable = True
        self.implement_forward()

    @property
    def is_log_node(self):
        # Only the first replica of this module on rank 0 acts as the logging node.
        return self.global_state.get('dist_controller').rank == 0 and self.module_id[1] == 0

    @property
    def t(self):
        return self.global_state.get('timestep')

    @property
    def p(self):
        # Normalized timestep (t / 1000).
        return self.t / 1000

    def implement_forward(self):
        module = self.module
        if not hasattr(module, "old_forward"):
            module.old_forward = module.forward
        self.new_forward = self.get_new_forward()

        def forward(*args, **kwargs):
            self.update_config()
            # Fall back to the module's original forward when the plugin is disabled.
            return self.new_forward(*args, **kwargs) if self.enable else module.old_forward(*args, **kwargs)

        module.forward = forward

    def set_enable(self, enable=True):
        self.enable = enable

    def get_new_forward(self):
        raise NotImplementedError

    def update_config(self, config: dict = None):
        if config is None:
            config = self.global_state.get('plugin_configs', {}).get(self.module_id[0], {})
        for key, value in config.items():
            setattr(self, key, value)


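# Illustrative sketch (not part of the original code): the minimum a ModulePlugin
# subclass has to provide is get_new_forward(); the base class saves the wrapped
# module's original forward as `old_forward` and dispatches on `enable`. The class
# name below is hypothetical.
class _IdentityPlugin(ModulePlugin):
    """No-op plugin used only to illustrate the plugin contract."""

    def get_new_forward(self):
        module = self.module

        def new_forward(*args, **kwargs):
            # Simply delegate to the wrapped module's original forward.
            return module.old_forward(*args, **kwargs)

        return new_forward

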
class GroupNormPlugin(ModulePlugin):
    def __init__(self, module, module_id, global_state=None):
        super().__init__(module, module_id, global_state)

    def get_new_forward(self):
        module = self.module

        def new_forward(x):
            shape = x.shape
            N, C, G = shape[0], shape[1], module.num_groups
            assert C % G == 0

            x = x.reshape(N, G, -1)

            # Average the per-rank group means across all ranks (each rank holds an
            # equal-sized shard, so the mean of means equals the global mean).
            mean = x.mean(-1, keepdim=True).to(torch.float32)
            dist.all_reduce(mean)
            mean = mean / dist.get_world_size()

            # Same reduction for the variance, computed around the global mean.
            var = ((x - mean.to(x.dtype)) ** 2).mean(-1, keepdim=True).to(torch.float32)
            dist.all_reduce(var)
            var = var / dist.get_world_size()

            x = (x - mean.to(x.dtype)) / (var.to(x.dtype) + module.eps).sqrt()
            x = x.view(shape)

            # Broadcast the affine parameters over all non-channel dims.
            new_shape = [1 for _ in shape]
            new_shape[1] = -1
            return x * module.weight.view(new_shape) + module.bias.view(new_shape)

        return new_forward


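# Sanity-check sketch (an assumption behind the reduction above is that every rank
# holds an equal-sized slice along the flattened group dim). The helper below is
# hypothetical and runs on a single process: averaging per-shard statistics then
# reproduces the full GroupNorm statistics.
def _check_sharded_groupnorm_stats(x: torch.Tensor, num_shards: int = 2) -> None:
    shards = x.chunk(num_shards, dim=-1)
    # Mean of per-shard means == global mean (equal shard sizes).
    mean = sum(s.mean(-1, keepdim=True) for s in shards) / num_shards
    # Per-shard second moments around the global mean, averaged == global variance.
    var = sum(((s - mean) ** 2).mean(-1, keepdim=True) for s in shards) / num_shards
    full_mean = x.mean(-1, keepdim=True)
    full_var = ((x - full_mean) ** 2).mean(-1, keepdim=True)
    assert torch.allclose(mean, full_mean, atol=1e-6)
    assert torch.allclose(var, full_var, atol=1e-6)

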
class Conv3DSafeNewPligin(ModulePlugin):
    def __init__(self, module, module_id, global_state=None):
        super().__init__(module, module_id, global_state)

        self.kernel_size = getattr(module, 'kernel_size', (1, 1, 1))
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size, self.kernel_size)

        # Halo size along the split dim: d = kernel - 1, divided between the two sides.
        kernel_width = self.kernel_size[2]
        d = kernel_width - 1
        self.padding_left = d // 2
        self.padding_right = d - self.padding_left
        self.padding_flag = self.padding_left if d > 0 else 0

        self.rank = dist.get_rank()
        self.adj_groups = self.global_state.get('dist_controller').adj_groups

    def pad_context(self, h):
        if self.padding_flag == 0:
            return h

        # Boundary slices to exchange with the neighbouring ranks (halo exchange).
        share_to_left = h[:, :, :, :self.padding_left].contiguous()
        share_to_right = h[:, :, :, -self.padding_right:].contiguous()

        # Odd and even ranks issue the two all_gathers in opposite order so that
        # every adjacent pair agrees on which group collective runs first.
        if self.rank % 2:
            if self.rank:
                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank - 1])
                left_context = padding_list[0].to(h.device, non_blocking=True)
            else:
                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)

            if self.rank != dist.get_world_size() - 1:
                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
                right_context = padding_list[1].to(h.device, non_blocking=True)
            else:
                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
        else:
            if self.rank != dist.get_world_size() - 1:
                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
                right_context = padding_list[1].to(h.device, non_blocking=True)
            else:
                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)

            if self.rank:
                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank - 1])
                left_context = padding_list[0].to(h.device, non_blocking=True)
            else:
                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)

        h_with_context = torch.cat([left_context, h, right_context], dim=3)
        return h_with_context

    def get_new_forward(self):
        module = self.module

        def new_forward(hidden_states, cache_x=None, *args, **kwargs):
            if self.padding_flag == 0:
                return module.old_forward(hidden_states, cache_x, *args, **kwargs)

            hidden_states = self.pad_context(hidden_states)
            if cache_x is not None:
                cache_x = self.pad_context(cache_x)

            result = module.old_forward(hidden_states, cache_x, *args, **kwargs)
            # Drop the halo columns so each rank keeps only its own slice.
            result = result[:, :, :, self.padding_left:-self.padding_right if self.padding_right > 0 else None]
            return result

        return new_forward


class Conv2DSafeNewPligin(ModulePlugin):
    def __init__(self, module, module_id, global_state=None):
        super().__init__(module, module_id, global_state)

        self.kernel_size = getattr(module, 'kernel_size', (1, 1))
        self.stride = getattr(module, 'stride', (1, 1))
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)

        # Halo size along the split (height) dim.
        kernel_height = self.kernel_size[0]
        d = kernel_height - 1
        self.padding_left = d // 2
        self.padding_right = d - self.padding_left
        self.padding = self.padding_left if d > 0 else 0

        self.rank = dist.get_rank()
        self.adj_groups = self.global_state.get('dist_controller').adj_groups

    def pad_context(self, h):
        if self.padding == 0:
            return h

        share_to_left = h[:, :, :self.padding_left].contiguous()
        share_to_right = h[:, :, -self.padding_right:].contiguous()

        # Same halo exchange as the 3-D variant: odd and even ranks order the two
        # all_gathers oppositely so adjacent pairs agree on which group runs first.
        if self.rank % 2:
            if self.rank:
                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank - 1])
                left_context = padding_list[0].to(h.device, non_blocking=True)
            else:
                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)

            if self.rank != dist.get_world_size() - 1:
                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
                right_context = padding_list[1].to(h.device, non_blocking=True)
            else:
                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
        else:
            if self.rank != dist.get_world_size() - 1:
                padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
                dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
                right_context = padding_list[1].to(h.device, non_blocking=True)
            else:
                right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)

            if self.rank:
                padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
                dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank - 1])
                left_context = padding_list[0].to(h.device, non_blocking=True)
            else:
                left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)

        h_with_context = torch.cat([left_context, h, right_context], dim=2)
        return h_with_context

    def get_new_forward(self):
        module = self.module

        def new_forward(hidden_states: torch.Tensor) -> torch.Tensor:
            if self.padding == 0:
                return module.old_forward(hidden_states)

            hidden_states = self.pad_context(hidden_states)
            # Run the original conv on the padded slice, then crop away the halo rows.
            hidden_states = module.old_forward(hidden_states)[:, :, self.padding_left:-self.padding_right if self.padding_right > 0 else None]
            return hidden_states

        return new_forward


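# Single-process sketch of the halo idea used above (and in the 3-D variant):
# for a stride-1 convolution with kernel 3 and padding 1 along the split dim,
# convolving each shard after borrowing one boundary row from its neighbour
# (zeros at the outer edges, as the first/last rank use) and cropping the halo
# reproduces the convolution of the unsplit tensor. `conv` and the sizes here
# are hypothetical, purely for illustration.
def _check_conv2d_halo() -> None:
    conv = torch.nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
    x = torch.randn(1, 4, 16, 16)
    full = conv(x)

    top, bottom = x.chunk(2, dim=2)
    zeros = torch.zeros_like(top[:, :, :1])
    top_padded = torch.cat([zeros, top, bottom[:, :, :1]], dim=2)
    bottom_padded = torch.cat([top[:, :, -1:], bottom, zeros], dim=2)

    # Crop the halo rows from each shard's output and stitch them back together.
    top_out = conv(top_padded)[:, :, 1:-1]
    bottom_out = conv(bottom_padded)[:, :, 1:-1]
    stitched = torch.cat([top_out, bottom_out], dim=2)
    assert torch.allclose(stitched, full, atol=1e-5)

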
class Conv2DSafeNewPliginStride2(ModulePlugin):
    def __init__(self, module, module_id, global_state=None):
        super().__init__(module, module_id, global_state)

        self.kernel_size = getattr(module, 'kernel_size', (1, 1))
        self.stride = getattr(module, 'stride', (1, 1))
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)

        kernel_height = self.kernel_size[0]
        d = kernel_height - 1
        self.padding_left = d // 2
        self.padding_right = d - self.padding_left
        self.padding = self.padding_left if d > 0 else 0

        self.rank = dist.get_rank()
        self.adj_groups = self.global_state.get('dist_controller').adj_groups

    def pad_context(self, h):
        if self.padding == 0:
            return h

        # Strided case uses only a right-hand halo: each rank receives the first
        # rows of the next rank and sends its own first rows to the previous rank.
        share_to_left = h[:, :, :self.padding_left].contiguous()

        if self.rank < dist.get_world_size() - 1:
            right_context = torch.zeros_like(share_to_left)
            dist.recv(right_context, src=self.rank + 1)
        if self.rank > 0:
            dist.send(share_to_left, dst=self.rank - 1)

        if self.rank < dist.get_world_size() - 1:
            h_with_context = torch.cat([h, right_context], dim=2)
        else:
            h_with_context = h
        return h_with_context

    def get_new_forward(self):
        module = self.module

        def new_forward(hidden_states: torch.Tensor) -> torch.Tensor:
            if self.padding == 0:
                return module.old_forward(hidden_states)

            # Drop the last row, append the neighbour's context rows, then zero-pad
            # one row at the bottom before the strided conv.
            hidden_states = hidden_states[:, :, :-1, :]
            hidden_states = self.pad_context(hidden_states)
            hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, 1))
            hidden_states = module.old_forward(hidden_states)
            return hidden_states

        return new_forward


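# Note on the point-to-point exchange in pad_context above: dist.recv blocks, so
# the chain resolves from the last rank backwards (the last rank posts no recv,
# sends its rows to rank-1, which can then complete its recv and send, and so on).
# This is a reading of the code above, not an extra guarantee from torch.distributed.

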
class WanAttentionPlugin(ModulePlugin):
    def __init__(self, module, module_id, global_state=None):
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        super().__init__(module, module_id, global_state)

    def get_new_forward(self):
        module = self.module
        rank = self.rank
        world_size = self.world_size

        def new_forward(hidden_states: torch.Tensor) -> torch.Tensor:
            # Gather every rank's width shard so the attention module sees the full tensor.
            gathered_tensors = [torch.zeros_like(hidden_states) for _ in range(world_size)]
            dist.all_gather(gathered_tensors, hidden_states)

            combined_tensor = torch.cat(gathered_tensors, dim=3)
            forward_output = module.old_forward(combined_tensor)

            # Slice this rank's shard back out of the full output.
            chunk_sizes = [t.size(3) for t in gathered_tensors]
            start_idx = sum(chunk_sizes[:rank])
            end_idx = start_idx + chunk_sizes[rank]
            local_output = forward_output[:, :, :, start_idx:end_idx].contiguous()
            return local_output

        return new_forward
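
# Single-process sketch (hypothetical helper) of the gather -> full forward ->
# re-slice pattern above; the list of shards stands in for what all_gather collects
# from the other ranks, and `attn` is any module that needs the full width.
def _example_gather_forward_slice(attn, shards, rank):
    combined = torch.cat(shards, dim=3)                  # all_gather + cat
    out = attn(combined)                                 # forward over the full width
    widths = [s.size(3) for s in shards]
    start = sum(widths[:rank])
    return out[:, :, :, start:start + widths[rank]].contiguous()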