| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import importlib |
| import numbers |
| from typing import Any, Optional |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor |
| from torch.distributed import ProcessGroup |
|
|
| from .comm import get_ulysses_sequence_parallel_group |
| from .ulysses import all_to_all_tensor |
| from .utils import padding_tensor_for_seqeunce_parallel, unpadding_tensor_for_seqeunce_parallel |
|
|
|
|
| fused_layer_norm_cuda = None |
|
|
|
|
def divide_qkv_linear_weight(weight: Tensor, dim: int):
    """Split a fused QKV weight into three chunks along ``dim``.

    Args:
        weight: The fused weight tensor to split.
        dim: The dimension along which to split.

    Returns:
        The tuple of chunks produced by ``torch.chunk(weight, 3, dim=dim)``.
    """
    return torch.chunk(weight, 3, dim=dim)
|
|
|
|
def divide_qkv_linear_bias(bias: Tensor, dim: int):
    """Split a fused QKV bias into three chunks along ``dim``.

    Args:
        bias: The fused bias tensor, or ``None`` when the projection has no bias.
        dim: The dimension along which to split.

    Returns:
        ``(None, None, None)`` when ``bias`` is ``None``; otherwise the tuple of
        chunks produced by ``torch.chunk(bias, 3, dim=dim)``.
    """
    # Guard clause: a missing bias yields a missing bias for each projection.
    if bias is None:
        return None, None, None
    return torch.chunk(bias, 3, dim=dim)
|
|
|
|
class AsyncUlyssesQKVProjection(torch.autograd.Function):
    """Q/K/V linear projections interleaved with asynchronous all-to-all calls.

    In ``forward`` each projection's all-to-all is launched with
    ``async_op=True`` before the next projection is computed, and the results
    are only waited on afterwards (presumably so that communication overlaps
    with computation). ``backward`` mirrors this with the scatter/gather
    dimensions swapped.
    """

    @staticmethod
    def forward(
        ctx: Any,
        hidden_states: Tensor,
        seq_dimension: int,
        head_dimension: int,
        q_weight: Tensor,
        q_bias: Optional[Tensor],
        k_weight: Tensor,
        k_bias: Optional[Tensor],
        v_weight: Tensor,
        v_bias: Optional[Tensor],
        norm_type: Optional[str],
        norm_q_weight: Optional[Tensor],
        norm_q_bias: Optional[Tensor],
        norm_k_weight: Optional[Tensor],
        norm_k_bias: Optional[Tensor],
        normalized_shape: int,
        eps: float,
        unpadded_dim_size: int,
        head_dim: int,
        group: Optional[ProcessGroup],
    ):
        """Project to Q/K/V, exchange them across the group, optionally normalize Q/K.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            hidden_states: Input of the three linear projections.
            seq_dimension: Dimension gathered by the forward all-to-all and
                un-padded afterwards.
            head_dimension: Dimension scattered by the forward all-to-all.
            q_weight, k_weight, v_weight: Weights passed to ``F.linear``.
            q_bias, k_bias, v_bias: Optional biases passed to ``F.linear``.
            norm_type: ``None`` (no normalization), ``"rmsnorm"`` or
                ``"layernorm"``; applied to Q and K only.
            norm_q_weight, norm_k_weight: Affine weights for the fused norm
                kernels (required when ``norm_type`` is not ``None``).
            norm_q_bias, norm_k_bias: Affine biases, used by ``"layernorm"`` only.
            normalized_shape: An integer or a shape forwarded to the fused
                norm kernels.
            eps: Epsilon forwarded to the fused norm kernels.
            unpadded_dim_size: Size passed to
                ``unpadding_tensor_for_seqeunce_parallel`` for ``seq_dimension``.
            head_dim: Size of the new trailing dimension; the last dimension of
                each projection is reshaped to ``(-1, head_dim)``.
            group: Process group for the all-to-all; when ``None`` the result of
                ``get_ulysses_sequence_parallel_group()`` is used.

        Returns:
            Tuple ``(output_q, output_k, v)``.

        Raises:
            NotImplementedError: If ``norm_type`` is neither ``None``,
                ``"rmsnorm"`` nor ``"layernorm"``.
        """
        # The fused kernel module is imported lazily and cached at module level.
        global fused_layer_norm_cuda

        sp_group = get_ulysses_sequence_parallel_group() if group is None else group

        # Q projection, then launch its all-to-all without waiting.
        q = F.linear(hidden_states, q_weight, q_bias)
        q_res = all_to_all_tensor(
            q, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
        )

        # K projection is computed while the Q exchange is in flight.
        k = F.linear(hidden_states, k_weight, k_bias)
        k_res = all_to_all_tensor(
            k, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
        )

        # V projection is computed while the K exchange is in flight.
        v = F.linear(hidden_states, v_weight, v_bias)
        v_res = all_to_all_tensor(
            v, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
        )

        # Wait for Q, strip padding, and split the last dim into (-1, head_dim).
        q = q_res()
        q = unpadding_tensor_for_seqeunce_parallel(q, seq_dimension, unpadded_dim_size)
        q = q.reshape(list(q.shape[:-1]) + [-1, head_dim]).contiguous()

        # Same post-processing for K.
        k = k_res()
        k = unpadding_tensor_for_seqeunce_parallel(k, seq_dimension, unpadded_dim_size)
        k = k.reshape(list(k.shape[:-1]) + [-1, head_dim]).contiguous()

        # Without normalization the outputs are the projections themselves and
        # no statistics are saved.
        output_q, output_k = q, k
        mean_q = mean_k = None
        invvar_q = invvar_k = None

        # Optional Q/K normalization runs before waiting on the V exchange.
        if norm_type is not None:
            if isinstance(normalized_shape, numbers.Integral):
                normalized_shape = (normalized_shape,)
            normalized_shape = torch.Size(normalized_shape)
            if fused_layer_norm_cuda is None:
                fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
            norm_q_weight = norm_q_weight.contiguous()
            norm_k_weight = norm_k_weight.contiguous()
            if norm_type == "rmsnorm":
                output_q, invvar_q = fused_layer_norm_cuda.rms_forward_affine(q, normalized_shape, norm_q_weight, eps)
                output_k, invvar_k = fused_layer_norm_cuda.rms_forward_affine(k, normalized_shape, norm_k_weight, eps)
            elif norm_type == "layernorm":
                output_q, mean_q, invvar_q = fused_layer_norm_cuda.forward_affine(
                    q, normalized_shape, norm_q_weight, norm_q_bias, eps
                )
                output_k, mean_k, invvar_k = fused_layer_norm_cuda.forward_affine(
                    k, normalized_shape, norm_k_weight, norm_k_bias, eps
                )
            else:
                raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")

        # Wait for V and apply the same post-processing (V is never normalized).
        v = v_res()
        v = unpadding_tensor_for_seqeunce_parallel(v, seq_dimension, unpadded_dim_size)
        v = v.reshape(list(v.shape[:-1]) + [-1, head_dim]).contiguous()

        # Stash non-tensor state and the tensors needed by backward. Note that
        # the pre-normalization q / k are saved, not output_q / output_k.
        ctx.sp_group = sp_group
        ctx.head_dimension = head_dimension
        ctx.seq_dimension = seq_dimension
        ctx.norm_type = norm_type
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        ctx.save_for_backward(
            hidden_states,
            q_weight,
            q_bias,
            k_weight,
            k_bias,
            v_weight,
            v_bias,
            q,
            norm_q_weight,
            norm_q_bias,
            mean_q,
            invvar_q,
            k,
            norm_k_weight,
            norm_k_bias,
            mean_k,
            invvar_k,
        )

        return output_q, output_k, v

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor):
        """Compute gradients for the inputs of ``forward``.

        Args:
            ctx: Autograd context populated by ``forward``.
            *grad_output: Gradients w.r.t. ``(output_q, output_k, v)``.

        Returns:
            A 19-tuple aligned with the non-``ctx`` parameters of ``forward``;
            entries for non-tensor arguments are ``None``.

        Raises:
            NotImplementedError: If the saved ``norm_type`` is not supported.
        """
        sp_group = ctx.sp_group
        seq_dimension = ctx.seq_dimension
        head_dimension = ctx.head_dimension
        norm_type = ctx.norm_type
        normalized_shape = ctx.normalized_shape
        eps = ctx.eps
        (
            hidden_states,
            q_weight,
            q_bias,
            k_weight,
            k_bias,
            v_weight,
            v_bias,
            q,
            norm_q_weight,
            norm_q_bias,
            mean_q,
            invvar_q,
            k,
            norm_k_weight,
            norm_k_bias,
            mean_k,
            invvar_k,
        ) = ctx.saved_tensors

        # Gradients that are only produced conditionally default to None.
        grad_q_bias = None
        grad_k_bias = None
        grad_v_bias = None
        grad_norm_q_weight = None
        grad_norm_q_bias = None
        grad_norm_k_weight = None
        grad_norm_k_bias = None

        # V has no normalization, so its gradient exchange is launched first:
        # merge the last two dims back, pad, and start the reverse all-to-all.
        grad_v = grad_output[2].contiguous()
        grad_v = grad_v.reshape(list(grad_v.shape[:-2]) + [-1]).contiguous()
        grad_v = padding_tensor_for_seqeunce_parallel(grad_v, dim=seq_dimension)
        grad_v_res = all_to_all_tensor(
            grad_v,
            scatter_dim=seq_dimension,
            gather_dim=head_dimension,
            group=sp_group,
            async_op=True,
        )

        # Norm backward for K and Q runs while the V exchange is in flight.
        if norm_type is not None:
            if norm_type == "rmsnorm":
                grad_k, grad_norm_k_weight = fused_layer_norm_cuda.rms_backward_affine(
                    grad_output[1].contiguous(),
                    invvar_k,
                    k,
                    normalized_shape,
                    norm_k_weight,
                    eps,
                    False,
                )
                grad_q, grad_norm_q_weight = fused_layer_norm_cuda.rms_backward_affine(
                    grad_output[0].contiguous(),
                    invvar_q,
                    q,
                    normalized_shape,
                    norm_q_weight,
                    eps,
                    False,
                )
            elif norm_type == "layernorm":
                grad_k, grad_norm_k_weight, grad_norm_k_bias = fused_layer_norm_cuda.backward_affine(
                    grad_output[1].contiguous(),
                    mean_k,
                    invvar_k,
                    k,
                    normalized_shape,
                    norm_k_weight,
                    norm_k_bias,
                    eps,
                    False,
                )
                grad_q, grad_norm_q_weight, grad_norm_q_bias = fused_layer_norm_cuda.backward_affine(
                    grad_output[0].contiguous(),
                    mean_q,
                    invvar_q,
                    q,
                    normalized_shape,
                    norm_q_weight,
                    norm_q_bias,
                    eps,
                    False,
                )
            else:
                raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")
        else:
            grad_k = grad_output[1].contiguous()
            grad_q = grad_output[0].contiguous()

        # Wait for the V gradient exchange.
        grad_v = grad_v_res()

        # Launch the K gradient exchange before doing the V linear backward.
        grad_k = grad_k.reshape(list(grad_k.shape[:-2]) + [-1]).contiguous()
        grad_k = padding_tensor_for_seqeunce_parallel(grad_k, dim=seq_dimension)
        grad_k_res = all_to_all_tensor(
            grad_k,
            scatter_dim=seq_dimension,
            gather_dim=head_dimension,
            group=sp_group,
            async_op=True,
        )

        # V linear backward. needs_input_grad is indexed by forward's inputs
        # (ctx excluded): 3=q_weight, 4=q_bias, 5=k_weight, 6=k_bias,
        # 7=v_weight, 8=v_bias. The bias check must use the bias index.
        grad_v_input = grad_v @ v_weight
        grad_v_weight = grad_v.transpose(-1, -2) @ hidden_states
        if v_bias is not None and ctx.needs_input_grad[8]:
            # NOTE(review): reduces dim 0 only; for inputs with extra leading
            # dims this appears to rely on autograd reducing the gradient to
            # the bias shape -- confirm.
            grad_v_bias = grad_v.sum(0)

        # Wait for the K gradient exchange.
        grad_k = grad_k_res()

        # Launch the Q gradient exchange before doing the K linear backward.
        grad_q = grad_q.reshape(list(grad_q.shape[:-2]) + [-1]).contiguous()
        grad_q = padding_tensor_for_seqeunce_parallel(grad_q, dim=seq_dimension)
        grad_q_res = all_to_all_tensor(
            grad_q,
            scatter_dim=seq_dimension,
            gather_dim=head_dimension,
            group=sp_group,
            async_op=True,
        )

        # K linear backward.
        grad_k_input = grad_k @ k_weight
        grad_k_weight = grad_k.transpose(-1, -2) @ hidden_states
        if k_bias is not None and ctx.needs_input_grad[6]:
            grad_k_bias = grad_k.sum(0)

        # Wait for the Q gradient exchange.
        grad_q = grad_q_res()

        # Q linear backward.
        grad_q_input = grad_q @ q_weight
        grad_q_weight = grad_q.transpose(-1, -2) @ hidden_states
        if q_bias is not None and ctx.needs_input_grad[4]:
            grad_q_bias = grad_q.sum(0)

        # hidden_states feeds all three projections, so its gradient is the sum.
        grad_hidden_states = grad_q_input + grad_k_input + grad_v_input

        return (
            grad_hidden_states,
            None,  # seq_dimension
            None,  # head_dimension
            grad_q_weight,
            grad_q_bias,
            grad_k_weight,
            grad_k_bias,
            grad_v_weight,
            grad_v_bias,
            None,  # norm_type
            grad_norm_q_weight,
            grad_norm_q_bias,
            grad_norm_k_weight,
            grad_norm_k_bias,
            None,  # normalized_shape
            None,  # eps
            None,  # unpadded_dim_size
            None,  # head_dim
            None,  # group
        )
|
|
|
|
class AsyncUlyssesOutputProjection(torch.autograd.Function):
    """Output linear projection preceded by an all-to-all exchange.

    ``forward`` pads the input, performs a (blocking) all-to-all and applies
    the projection. ``backward`` launches the reverse all-to-all with
    ``async_op=True`` and computes the weight/bias gradients before waiting on
    it.
    """

    @staticmethod
    def forward(
        ctx: Any,
        hidden_states: Tensor,
        seq_dimension: int,
        head_dimension: int,
        proj_weight: Tensor,
        proj_bias: Optional[Tensor],
        unpadded_dim_size: int,
        group: Optional[ProcessGroup],
    ):
        """Pad, exchange across the group, and apply the output projection.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            hidden_states: Input tensor; padded along ``seq_dimension`` before
                the all-to-all.
            seq_dimension: Dimension scattered by the forward all-to-all.
            head_dimension: Dimension gathered by the forward all-to-all.
            proj_weight: Weight passed to ``F.linear``.
            proj_bias: Optional bias passed to ``F.linear``.
            unpadded_dim_size: Size used in ``backward`` to strip the padding
                from the input gradient along ``seq_dimension``.
            group: Process group for the all-to-all; when ``None`` the result of
                ``get_ulysses_sequence_parallel_group()`` is used.

        Returns:
            The projected tensor.
        """
        sp_group = get_ulysses_sequence_parallel_group() if group is None else group

        hidden_states = padding_tensor_for_seqeunce_parallel(hidden_states, seq_dimension)
        hidden_states = all_to_all_tensor(
            hidden_states, scatter_dim=seq_dimension, gather_dim=head_dimension, group=sp_group
        )
        o = F.linear(hidden_states, proj_weight, proj_bias)

        ctx.sp_group = sp_group
        ctx.head_dimension = head_dimension
        ctx.seq_dimension = seq_dimension
        ctx.unpadded_dim_size = unpadded_dim_size

        # The saved hidden_states is the post-all-to-all tensor, i.e. the
        # actual input of the linear layer.
        ctx.save_for_backward(
            hidden_states,
            proj_weight,
            proj_bias,
        )

        return o

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor):
        """Compute gradients for the inputs of ``forward``.

        Args:
            ctx: Autograd context populated by ``forward``.
            *grad_output: Gradient w.r.t. the projected output (one element).

        Returns:
            A 7-tuple aligned with the non-``ctx`` parameters of ``forward``;
            entries for non-tensor arguments are ``None``.
        """
        sp_group = ctx.sp_group
        head_dimension = ctx.head_dimension
        seq_dimension = ctx.seq_dimension
        unpadded_dim_size = ctx.unpadded_dim_size
        (
            hidden_states,
            proj_weight,
            proj_bias,
        ) = ctx.saved_tensors

        grad_proj_bias = None

        # Gradient w.r.t. the linear input, then launch the reverse all-to-all
        # without waiting for it.
        grad_o = grad_output[0] @ (proj_weight)
        grad_out_res = all_to_all_tensor(
            grad_o, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
        )

        # Weight/bias gradients are computed while the exchange is in flight.
        # needs_input_grad is indexed by forward's inputs (ctx excluded):
        # 3=proj_weight, 4=proj_bias. The bias check must use the bias index.
        grad_proj_weight = grad_output[0].transpose(-1, -2) @ (hidden_states)
        if proj_bias is not None and ctx.needs_input_grad[4]:
            # NOTE(review): reduces dim 0 only; for inputs with extra leading
            # dims this appears to rely on autograd reducing the gradient to
            # the bias shape -- confirm.
            grad_proj_bias = grad_output[0].sum(0)

        # Wait for the exchange and strip the padding added in forward.
        grad_o = grad_out_res()
        grad_o = unpadding_tensor_for_seqeunce_parallel(grad_o, seq_dimension, unpadded_dim_size)

        return (
            grad_o,
            None,  # seq_dimension
            None,  # head_dimension
            grad_proj_weight,
            grad_proj_bias,
            None,  # unpadded_dim_size
            None,  # group
        )
|
|
|
|
def async_ulysses_qkv_projection(
    hidden_states: Tensor = None,
    seq_dimension: int = None,
    head_dimension: int = None,
    q_weight: Tensor = None,
    q_bias: Optional[Tensor] = None,
    k_weight: Tensor = None,
    k_bias: Optional[Tensor] = None,
    v_weight: Tensor = None,
    v_bias: Optional[Tensor] = None,
    norm_type: str = None,
    norm_q_weight: Optional[Tensor] = None,
    norm_q_bias: Optional[Tensor] = None,
    norm_k_weight: Optional[Tensor] = None,
    norm_k_bias: Optional[Tensor] = None,
    normalized_shape: Optional[int] = None,
    eps: Optional[float] = None,
    unpadded_dim_size: int = None,
    head_dim: int = None,
    group: Optional[ProcessGroup] = None,
):
    """Keyword-friendly entry point for ``AsyncUlyssesQKVProjection``.

    All parameters are forwarded unchanged, positionally, in the order of
    ``AsyncUlyssesQKVProjection.forward``'s parameters (after ``ctx``).

    Returns:
        Whatever ``AsyncUlyssesQKVProjection.apply`` returns for these arguments.
    """
    # Order matters: it must match the forward() signature one-to-one.
    forward_args = (
        hidden_states,
        seq_dimension,
        head_dimension,
        q_weight,
        q_bias,
        k_weight,
        k_bias,
        v_weight,
        v_bias,
        norm_type,
        norm_q_weight,
        norm_q_bias,
        norm_k_weight,
        norm_k_bias,
        normalized_shape,
        eps,
        unpadded_dim_size,
        head_dim,
        group,
    )
    return AsyncUlyssesQKVProjection.apply(*forward_args)
|
|
|
|
def async_ulysses_output_projection(
    hidden_states: Optional[Tensor] = None,
    seq_dimension: int = None,
    head_dimension: int = None,
    proj_weight: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    unpadded_dim_size: Optional[int] = None,
    group: Optional[ProcessGroup] = None,
):
    """Keyword-friendly entry point for ``AsyncUlyssesOutputProjection``.

    All parameters are forwarded unchanged, positionally, in the order of
    ``AsyncUlyssesOutputProjection.forward``'s parameters (after ``ctx``).

    Returns:
        Whatever ``AsyncUlyssesOutputProjection.apply`` returns for these
        arguments.
    """
    # Order matters: it must match the forward() signature one-to-one.
    forward_args = (
        hidden_states,
        seq_dimension,
        head_dimension,
        proj_weight,
        proj_bias,
        unpadded_dim_size,
        group,
    )
    return AsyncUlyssesOutputProjection.apply(*forward_args)
|
|