import torch
import torch.distributed as dist
class ModulePlugin:
    """Base class that patches a module's forward with a distributed-aware replacement."""
    def __init__(self, module, module_id, global_state=None):
        self.module = module
        self.module_id = module_id  # (module_name, index) pair identifying this instance
        self.global_state = global_state  # shared state: dist controller, timestep, plugin configs
        self.enable = True  # when False, dispatch falls through to the original forward
        self.implement_forward()
    @property
    def is_log_node(self):
        # log only from global rank 0 and the first instance of this module type
        return self.global_state.get('dist_controller').rank == 0 and self.module_id[1] == 0
    @property
    def t(self):
        return self.global_state.get('timestep')
    @property
    def p(self):
        # timestep as a fraction of the assumed 1000-step schedule
        return self.t / 1000
    def implement_forward(self):
        module = self.module
        if not hasattr(module, "old_forward"):
            # keep a handle to the original forward so it can be bypassed or restored
            module.old_forward = module.forward
        self.new_forward = self.get_new_forward()
        def forward(*args, **kwargs):
            self.update_config()  # refresh per-module config from the global state
            return self.new_forward(*args, **kwargs) if self.enable else module.old_forward(*args, **kwargs)
        module.forward = forward
def set_enable(self, enable=True):
self.enable = enable
def get_new_forward(self):
raise NotImplementedError
    def update_config(self, config: dict = None):
        if config is None:
            # pull this module type's config from the shared plugin_configs mapping
            config = self.global_state.get('plugin_configs', {}).get(self.module_id[0], {})
        for key, value in config.items():
            setattr(self, key, value)
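# A minimal usage sketch (hypothetical names; assumes `global_state` is a plain dict
# and `dist_controller` exposes `.rank` and `.adj_groups`). Patching replaces
# `module.forward` in place, so existing callers pick up the distributed version:
#
#     global_state = {
#         'dist_controller': controller,  # assumed controller object
#         'timestep': 500,
#         'plugin_configs': {'norm': {'enable': True}},
#     }
#     norm = torch.nn.GroupNorm(32, 128)
#     plugin = GroupNormPlugin(norm, module_id=('norm', 0), global_state=global_state)
#     y = norm(x)                # dispatches to the distributed forward
#     plugin.set_enable(False)   # fall back to the original forward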
class GroupNormPlugin(ModulePlugin):
    """GroupNorm whose statistics are all-reduced across ranks; assumes every rank holds an equal-sized shard."""
    def get_new_forward(self):
        module = self.module
        def new_forward(x):
            shape = x.shape
            N, C, G = shape[0], shape[1], module.num_groups
            assert C % G == 0
            x = x.reshape(N, G, -1)
            # SUM-all-reduce of per-rank means divided by world_size equals the global
            # mean only because every rank holds an equal-sized shard
            mean = x.mean(-1, keepdim=True).to(torch.float32)
            dist.all_reduce(mean)
            mean = mean / dist.get_world_size()
            var = ((x - mean.to(x.dtype)) ** 2).mean(-1, keepdim=True).to(torch.float32)
            dist.all_reduce(var)
            var = var / dist.get_world_size()
            x = (x - mean.to(x.dtype)) / (var.to(x.dtype) + module.eps).sqrt()
            x = x.view(shape)
            # broadcast the per-channel affine parameters over all other dims
            new_shape = [1 for _ in shape]
            new_shape[1] = -1
            return x * module.weight.view(new_shape) + module.bias.view(new_shape)
        return new_forward
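# Why SUM-all-reduce / world_size yields the exact global statistics (a sketch,
# assuming every rank holds an equal-sized shard of each normalization group):
# for shards x_1, ..., x_W of n elements each,
#     mu = (1 / (W * n)) * sum_r sum_i x_r[i] = (1 / W) * sum_r mean(x_r),
# i.e. the average of the per-rank means. The variance follows by the same argument,
# being the mean of squared deviations from the (already global) mu.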
class Conv3DSafeNewPligin(ModulePlugin):
    """3D conv wrapper that exchanges halo rows with adjacent ranks before convolving."""
    def __init__(self, module, module_id, global_state=None):
        super().__init__(module, module_id, global_state)
        self.kernel_size = getattr(module, 'kernel_size', (1, 1, 1))
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size, self.kernel_size)
        kernel_width = self.kernel_size[2]
        d = kernel_width - 1  # total halo rows needed across both sides
        self.padding_left = d // 2
        self.padding_right = d - self.padding_left
        self.padding_flag = self.padding_left if d > 0 else 0  # nonzero iff a halo exchange is needed
        self.rank = dist.get_rank()
        self.adj_groups = self.global_state.get('dist_controller').adj_groups  # pairwise groups {i, i+1}
def pad_context(self, h):
if self.padding_flag == 0:
return h
        # halo regions: the rows this rank shares with its left and right neighbors
        share_to_left = h[:, :, :, :self.padding_left].contiguous()
        share_to_right = h[:, :, :, -self.padding_right:].contiguous()
        if self.rank % 2:
            # odd rank: exchange with the left neighbor first, so both members of each
            # adjacent pair enter their shared group's all_gather at the same phase
            if self.rank:
                # not the first rank, so a left context exists
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
# 2. then pad the right
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
        else:
            # even rank: exchange with the right neighbor first (mirror of the odd case)
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
# 2. then pad the left
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
h_with_context = torch.cat([left_context, h, right_context], dim=3)
return h_with_context
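    # Ordering note: odd ranks enter the left-neighbor group's collective first and
    # even ranks the right-neighbor group's first, so every adjacent pair (i, i+1)
    # meets in its shared group at the same phase; the exchanges complete in two
    # parallel rounds instead of a serialized chain along the ranks.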
    def get_new_forward(self):
        module = self.module
        def new_forward(hidden_states, cache_x=None, *args, **kwargs):
            if self.padding_flag == 0:
                # size-1 kernel along the sharded dim: no halo needed
                return module.old_forward(hidden_states, cache_x, *args, **kwargs)
            hidden_states = self.pad_context(hidden_states)
            if cache_x is not None:
                cache_x = self.pad_context(cache_x)
            result = module.old_forward(hidden_states, cache_x, *args, **kwargs)
            # drop the halo rows so each rank returns exactly its own shard
            result = result[:, :, :, self.padding_left:-self.padding_right if self.padding_right > 0 else None]
            return result
        return new_forward
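# `adj_groups` is assumed to map index i to the two-rank process group {i, i+1};
# a minimal construction sketch (dist.new_group must be invoked by every rank for
# every group, which the comprehension satisfies):
#
#     adj_groups = [dist.new_group([i, i + 1]) for i in range(dist.get_world_size() - 1)]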
class Conv2DSafeNewPligin(ModulePlugin):
    """2D conv wrapper that exchanges halo rows along the height dim before convolving."""
    def __init__(self, module, module_id, global_state=None):
        super().__init__(module, module_id, global_state)
        self.kernel_size = getattr(module, 'kernel_size', (1, 1))
        self.stride = getattr(module, 'stride', (1, 1))
        if isinstance(self.kernel_size, int):
            self.kernel_size = (self.kernel_size, self.kernel_size)
        kernel_height = self.kernel_size[0]  # height dimension of the conv kernel
        d = kernel_height - 1  # total amount of padding
        self.padding_left = d // 2  # padding on the top side
        self.padding_right = d - self.padding_left  # padding on the bottom side
        self.padding = self.padding_left if d > 0 else 0
        self.rank = dist.get_rank()
        self.adj_groups = self.global_state.get('dist_controller').adj_groups
def pad_context(self, h):
if self.padding == 0:
return h
share_to_left = h[:, :, :self.padding_left].contiguous()
share_to_right = h[:, :, -self.padding_right:].contiguous()
        if self.rank % 2:
            # odd rank: exchange with the left neighbor first (same pairwise ordering as the 3D case)
if self.rank:
# not the first rank, have left context
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
# 2. then pad the right
if self.rank != dist.get_world_size() - 1:
# not the last rank, have right context
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
else:
            # even rank: exchange with the right neighbor first (mirror of the odd case)
            if self.rank != dist.get_world_size() - 1:
                # not the last rank, so a right context exists
padding_list = [torch.zeros_like(share_to_right) for _ in range(2)]
dist.all_gather(padding_list, share_to_right, group=self.adj_groups[self.rank])
right_context = padding_list[1].to(h.device, non_blocking=True)
else:
right_context = torch.zeros_like(share_to_right).to(h.device, non_blocking=True)
            # 2. then exchange with the left
            if self.rank:
                # not the first rank, so a left context exists
padding_list = [torch.zeros_like(share_to_left) for _ in range(2)]
dist.all_gather(padding_list, share_to_left, group=self.adj_groups[self.rank-1])
left_context = padding_list[0].to(h.device, non_blocking=True)
else:
left_context = torch.zeros_like(share_to_left).to(h.device, non_blocking=True)
h_with_context = torch.cat([left_context, h, right_context], dim=2)
return h_with_context
    def get_new_forward(self):
        module = self.module
        def new_forward(hidden_states: torch.Tensor) -> torch.Tensor:
            if self.padding == 0:
                return module.old_forward(hidden_states)
            hidden_states = self.pad_context(hidden_states)
            # drop the halo rows so each rank returns exactly its own shard
            hidden_states = module.old_forward(hidden_states)[:, :, self.padding_left:-self.padding_right if self.padding_right > 0 else None]
            return hidden_states
        return new_forward
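# Worked example (assuming a 3x3 kernel, stride 1, and a conv that preserves spatial
# size): d = 2, so padding_left = padding_right = 1. A rank holding H rows gains one
# halo row from each neighbor (H + 2 rows), the wrapped conv returns H + 2 rows, and
# slicing off the first and last row leaves exactly the H rows this rank owns.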
class Conv2DSafeNewPliginStride2(ModulePlugin):
    """Variant of the 2D wrapper for stride-2 convs: only a one-sided (right) halo is exchanged."""
    def __init__(self, module, module_id, global_state=None):
super().__init__(module, module_id, global_state)
self.kernel_size = getattr(module, 'kernel_size', (1, 1))
self.stride = getattr(module, 'stride', (1, 1))
if isinstance(self.kernel_size, int):
self.kernel_size = (self.kernel_size, self.kernel_size)
kernel_height = self.kernel_size[0]
d = kernel_height - 1
self.padding_left = d // 2
self.padding_right = d - self.padding_left
self.padding = self.padding_left if d > 0 else 0
self.rank = dist.get_rank()
self.adj_groups = self.global_state.get('dist_controller').adj_groups
    def pad_context(self, h):
        if self.padding == 0:
            return h
        # each rank forwards its top rows to the left neighbor and receives the
        # corresponding rows from the right neighbor (one-sided halo)
        share_to_left = h[:, :, :self.padding_left].contiguous()
        if self.rank < dist.get_world_size() - 1:
            right_context = torch.zeros_like(share_to_left)
            dist.recv(right_context, src=self.rank + 1)
        if self.rank > 0:
            dist.send(share_to_left, dst=self.rank - 1)
        if self.rank < dist.get_world_size() - 1:
            h_with_context = torch.cat([h, right_context], dim=2)
        else:
            h_with_context = h
        return h_with_context
    def get_new_forward(self):
        module = self.module
        def new_forward(hidden_states: torch.Tensor) -> torch.Tensor:
            if self.padding == 0:
                return module.old_forward(hidden_states)
            # drop the last local row, splice in the right neighbor's top row(s) via
            # pad_context, then append one zero row at the bottom before the strided conv
            hidden_states = hidden_states[:, :, :-1, :]
            hidden_states = self.pad_context(hidden_states)
            hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, 1))
            hidden_states = module.old_forward(hidden_states)
            return hidden_states
        return new_forward
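# Communication pattern of the stride-2 pad_context for world_size = 4 (each arrow
# is the single blocking dist.send per rank):
#
#     rank: 0 <- 1 <- 2 <- 3
#
# Every rank except the last posts its recv first; the last rank starts by sending,
# which unblocks its left neighbor, and the chain resolves right to left, so the
# blocking point-to-point calls cannot deadlock.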
class WanAttentionPlugin(ModulePlugin):
    """Gathers width-sharded activations so attention sees the full sequence, then re-shards."""
    def __init__(self, module, module_id, global_state=None):
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        super().__init__(module, module_id, global_state)
    def get_new_forward(self):
        module = self.module
        rank = self.rank
        world_size = self.world_size
        def new_forward(hidden_states: torch.Tensor) -> torch.Tensor:
            # collect every rank's width shard and rebuild the full tensor
            gathered_tensors = [torch.zeros_like(hidden_states) for _ in range(world_size)]
            dist.all_gather(gathered_tensors, hidden_states)
            combined_tensor = torch.cat(gathered_tensors, dim=3)
            forward_output = module.old_forward(combined_tensor)
            # slice this rank's chunk back out of the full-width output
            chunk_sizes = [t.size(3) for t in gathered_tensors]
            start_idx = sum(chunk_sizes[:rank])
            end_idx = start_idx + chunk_sizes[rank]
            local_output = forward_output[:, :, :, start_idx:end_idx].contiguous()
            return local_output
        return new_forward
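# Design note: dist.all_gather requires identical shapes on every rank, so the width
# shards are necessarily equal and chunk_sizes is uniform; start_idx then reduces to
# rank * hidden_states.size(3). Round trip for world_size = 2:
#     rank 0 holds x[..., :W//2], rank 1 holds x[..., W//2:]
#     all_gather -> both ranks rebuild the full x, attention runs over the whole
#     width, and each rank keeps only its own output slice.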