bazaar-research's picture
Upload folder using huggingface_hub
fb11af9 verified
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import numbers
from typing import Any, Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from .comm import get_ulysses_sequence_parallel_group
from .ulysses import all_to_all_tensor
from .utils import padding_tensor_for_seqeunce_parallel, unpadding_tensor_for_seqeunce_parallel
# Lazily imported handle to the ``fused_layer_norm_cuda`` extension module.
# It stays None until AsyncUlyssesQKVProjection first runs with a q/k norm
# (``norm_type`` not None), so the extension is only required when a norm is used.
fused_layer_norm_cuda = None
def divide_qkv_linear_weight(weight: Tensor, dim: int):
    """Split a fused QKV projection weight into its q, k and v parts.

    Args:
        weight: the fused weight tensor.
        dim: the dimension along which q, k and v are concatenated.

    Returns:
        The tuple produced by chunking ``weight`` into 3 pieces along ``dim``.
    """
    return torch.chunk(weight, 3, dim)
def divide_qkv_linear_bias(bias: Tensor, dim: int):
    """Split a fused QKV projection bias into its q, k and v parts.

    Args:
        bias: the fused bias tensor, or None when the projection has no bias.
        dim: the dimension along which q, k and v are concatenated.

    Returns:
        ``(None, None, None)`` when ``bias`` is None, otherwise the tuple
        produced by chunking ``bias`` into 3 pieces along ``dim``.
    """
    if bias is None:
        return None, None, None
    return torch.chunk(bias, 3, dim)
class AsyncUlyssesQKVProjection(torch.autograd.Function):
    """q/k/v projections fused with Ulysses all-to-all, overlapping compute and comms.

    Forward computes the q, k and v projections of ``hidden_states`` one after
    another. As soon as each one is ready it launches an asynchronous
    all-to-all (scatter ``head_dimension``, gather ``seq_dimension``), so the
    next projection overlaps the previous communication. q and k can
    optionally be normalized first with kernels from ``fused_layer_norm_cuda``
    (``norm_type`` of ``"rmsnorm"`` or ``"layernorm"``). The gathered outputs
    are stripped of sequence padding down to ``unpadded_dim_size``.

    Backward mirrors this. Each output gradient is padded along
    ``seq_dimension`` and sent back with an asynchronous all-to-all (scatter
    ``seq_dimension``, gather ``head_dimension``). Once it is *collected*, the
    norm backward (for q/k) and the linear backward are applied in the
    original local layout.
    """

    @staticmethod
    def _qk_norm_forward(x, norm_type, normalized_shape, weight, bias, eps):
        """Apply the optional fused q/k normalization.

        Returns ``(output, mean, invvar)``. When ``norm_type`` is None, ``x`` is
        returned unchanged with ``mean``/``invvar`` set to None; ``mean`` is
        also None for rmsnorm.
        """
        global fused_layer_norm_cuda
        if norm_type is None:
            return x, None, None
        if fused_layer_norm_cuda is None:
            fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
        if norm_type == "rmsnorm":
            output, invvar = fused_layer_norm_cuda.rms_forward_affine(x, normalized_shape, weight, eps)
            return output, None, invvar
        if norm_type == "layernorm":
            output, mean, invvar = fused_layer_norm_cuda.forward_affine(x, normalized_shape, weight, bias, eps)
            return output, mean, invvar
        raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")

    @staticmethod
    def _qk_norm_backward(grad, x, mean, invvar, norm_type, normalized_shape, weight, bias, eps):
        """Backward of :meth:`_qk_norm_forward`.

        ``grad`` must already be in the same layout as the saved pre-norm
        input ``x``. Returns ``(grad_x, grad_weight, grad_bias)``; the affine
        grads are None when not produced (no norm, or no bias for rmsnorm).
        """
        if norm_type is None:
            return grad, None, None
        if norm_type == "rmsnorm":
            grad_x, grad_weight = fused_layer_norm_cuda.rms_backward_affine(
                grad, invvar, x, normalized_shape, weight, eps, False
            )
            return grad_x, grad_weight, None
        if norm_type == "layernorm":
            grad_x, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
                grad, mean, invvar, x, normalized_shape, weight, bias, eps, False
            )
            return grad_x, grad_weight, grad_bias
        raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")

    @staticmethod
    def _projection_grads(grad_output, hidden_states, weight, needs_weight_grad, needs_bias_grad):
        """Backward of ``F.linear(hidden_states, weight, bias)``.

        All leading (non-feature) dims are flattened, so inputs of any rank
        accepted by ``F.linear`` (``(*, in) -> (*, out)``) yield correctly
        shaped ``(out, in)`` weight and ``(out,)`` bias gradients.

        Returns:
            ``(grad_input, grad_weight, grad_bias)``; the weight/bias entries
            are None when not requested.
        """
        grad_input = grad_output @ weight
        grad_weight = None
        grad_bias = None
        if needs_weight_grad or needs_bias_grad:
            grad_2d = grad_output.reshape(-1, grad_output.shape[-1])
            if needs_weight_grad:
                grad_weight = grad_2d.t() @ hidden_states.reshape(-1, hidden_states.shape[-1])
            if needs_bias_grad:
                grad_bias = grad_2d.sum(0)
        return grad_input, grad_weight, grad_bias

    @staticmethod
    def forward(
        ctx: Any,
        hidden_states: Tensor,
        seq_dimension: int,
        head_dimension: int,
        q_weight: Tensor,
        q_bias: Tensor,
        k_weight: Tensor,
        k_bias: Tensor,
        v_weight: Tensor,
        v_bias: Tensor,
        norm_type: str,
        norm_q_weight: Tensor,
        norm_q_bias: Tensor,
        norm_k_weight: Tensor,
        norm_k_bias: Tensor,
        normalized_shape: int,
        eps: float,
        unpadded_dim_size: int,
        head_dim: int,
        group: ProcessGroup,
    ):
        """Project, optionally normalize, and all-to-all q/k/v with overlap.

        Args:
            hidden_states: input to the three linear projections.
            seq_dimension / head_dimension: dims gathered / scattered by the
                all-to-all.
            q_weight, q_bias, k_weight, k_bias, v_weight, v_bias: projection
                parameters passed to ``F.linear`` (biases may be None).
            norm_type: None, ``"rmsnorm"`` or ``"layernorm"``, applied to q and k.
            norm_q_weight, norm_q_bias, norm_k_weight, norm_k_bias: affine
                parameters of the q/k norms.
            normalized_shape: int or shape passed to the fused norm kernels.
            eps: norm epsilon.
            unpadded_dim_size: sequence length after removing padding.
            head_dim: unused here; kept for interface compatibility.
            group: process group; defaults to the Ulysses SP group when None.

        Returns:
            ``(q, k, v)`` after communication and unpadding.

        Raises:
            NotImplementedError: for an unsupported ``norm_type``.
        """
        sp_group = get_ulysses_sequence_parallel_group() if group is None else group

        if norm_type is not None:
            if norm_type not in ("rmsnorm", "layernorm"):
                raise NotImplementedError(f"{norm_type} is not supported in async-ulysses now!")
            if isinstance(normalized_shape, numbers.Integral):
                normalized_shape = (normalized_shape,)
            normalized_shape = torch.Size(normalized_shape)
            # Make the affine weights contiguous before handing them to the fused kernels.
            norm_q_weight = norm_q_weight.contiguous()
            norm_k_weight = norm_k_weight.contiguous()

        # q: project, (optionally) normalize, then launch its all-to-all.
        q = F.linear(hidden_states, q_weight, q_bias)
        output_q, mean_q, invvar_q = AsyncUlyssesQKVProjection._qk_norm_forward(
            q, norm_type, normalized_shape, norm_q_weight, norm_q_bias, eps
        )
        output_q_res = all_to_all_tensor(
            output_q, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
        )

        # k: same as q; this compute overlaps the in-flight q communication.
        k = F.linear(hidden_states, k_weight, k_bias)
        output_k, mean_k, invvar_k = AsyncUlyssesQKVProjection._qk_norm_forward(
            k, norm_type, normalized_shape, norm_k_weight, norm_k_bias, eps
        )
        output_k_res = all_to_all_tensor(
            output_k, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
        )

        # v: no normalization.
        v = F.linear(hidden_states, v_weight, v_bias)
        v_res = all_to_all_tensor(v, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True)

        # Collect the communications and strip sequence padding.
        output_q = unpadding_tensor_for_seqeunce_parallel(output_q_res(), seq_dimension, unpadded_dim_size)
        output_k = unpadding_tensor_for_seqeunce_parallel(output_k_res(), seq_dimension, unpadded_dim_size)
        v = unpadding_tensor_for_seqeunce_parallel(v_res(), seq_dimension, unpadded_dim_size)

        ctx.sp_group = sp_group
        ctx.head_dimension = head_dimension
        ctx.seq_dimension = seq_dimension
        ctx.norm_type = norm_type
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        # Pre-norm q/k are saved because the norm backward is called with them.
        ctx.save_for_backward(
            hidden_states,
            q_weight,
            q_bias,
            k_weight,
            k_bias,
            v_weight,
            v_bias,
            q,
            norm_q_weight,
            norm_q_bias,
            mean_q,
            invvar_q,
            k,
            norm_k_weight,
            norm_k_bias,
            mean_k,
            invvar_k,
        )
        return output_q, output_k, v

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor):
        """Reverse the all-to-alls, then run norm and linear backward.

        ``grad_output`` is ``(grad_q, grad_k, grad_v)``. The returned tuple
        has one entry per ``forward`` input (19); non-tensor inputs get None.
        """
        sp_group = ctx.sp_group
        seq_dimension = ctx.seq_dimension
        head_dimension = ctx.head_dimension
        norm_type = ctx.norm_type
        normalized_shape = ctx.normalized_shape
        eps = ctx.eps
        (
            hidden_states,
            q_weight,
            q_bias,
            k_weight,
            k_bias,
            v_weight,
            v_bias,
            q,
            norm_q_weight,
            norm_q_bias,
            mean_q,
            invvar_q,
            k,
            norm_k_weight,
            norm_k_bias,
            mean_k,
            invvar_k,
        ) = ctx.saved_tensors
        # Input positions: 3/4 q_weight/q_bias, 5/6 k_weight/k_bias, 7/8 v_weight/v_bias.
        needs = ctx.needs_input_grad
        fn = AsyncUlyssesQKVProjection

        def launch_grad_all_to_all(grad):
            """Pad ``grad`` along the sequence and launch its async reverse all-to-all."""
            grad = padding_tensor_for_seqeunce_parallel(grad.contiguous(), dim=seq_dimension)
            return all_to_all_tensor(
                grad,
                scatter_dim=seq_dimension,
                gather_dim=head_dimension,
                group=sp_group,
                async_op=True,
            )

        # v: nothing is ready to overlap with yet, so collect immediately.
        grad_v = launch_grad_all_to_all(grad_output[2])()

        # k: launch, and overlap it with the v projection backward.
        grad_k_res = launch_grad_all_to_all(grad_output[1])
        grad_v_input, grad_v_weight, grad_v_bias = fn._projection_grads(
            grad_v, hidden_states, v_weight, needs[7], v_bias is not None and needs[8]
        )
        grad_k = grad_k_res()

        # q: launch, and overlap it with the k norm + projection backward.
        grad_q_res = launch_grad_all_to_all(grad_output[0])
        # The norm backward must run on the *collected* grad_k so that it matches
        # the layout of the saved pre-norm ``k``.
        grad_k, grad_norm_k_weight, grad_norm_k_bias = fn._qk_norm_backward(
            grad_k, k, mean_k, invvar_k, norm_type, normalized_shape, norm_k_weight, norm_k_bias, eps
        )
        grad_k_input, grad_k_weight, grad_k_bias = fn._projection_grads(
            grad_k, hidden_states, k_weight, needs[5], k_bias is not None and needs[6]
        )

        grad_q = grad_q_res()
        grad_q, grad_norm_q_weight, grad_norm_q_bias = fn._qk_norm_backward(
            grad_q, q, mean_q, invvar_q, norm_type, normalized_shape, norm_q_weight, norm_q_bias, eps
        )
        grad_q_input, grad_q_weight, grad_q_bias = fn._projection_grads(
            grad_q, hidden_states, q_weight, needs[3], q_bias is not None and needs[4]
        )

        # hidden_states feeds all three projections, so its gradients add up.
        grad_hidden_states = grad_q_input + grad_k_input + grad_v_input
        return (
            grad_hidden_states,
            None,
            None,
            grad_q_weight,
            grad_q_bias,
            grad_k_weight,
            grad_k_bias,
            grad_v_weight,
            grad_v_bias,
            None,
            grad_norm_q_weight,
            grad_norm_q_bias,
            grad_norm_k_weight,
            grad_norm_k_bias,
            None,
            None,
            None,
            None,
            None,
        )
class AsyncUlyssesOutputProjection(torch.autograd.Function):
    """Ulysses all-to-all followed by the attention output projection.

    Forward pads ``hidden_states`` along ``seq_dimension`` and runs a blocking
    all-to-all (scatter ``seq_dimension``, gather ``head_dimension``). It then
    applies ``F.linear``.

    Backward computes the input gradient first and sends it back with an
    asynchronous all-to-all. The weight/bias gradients are computed while that
    communication is in flight, and the sequence padding is stripped at the end.
    """

    @staticmethod
    def forward(
        ctx: Any,
        hidden_states: Tensor,
        seq_dimension: int,
        head_dimension: int,
        proj_weight: Tensor,
        proj_bias: Tensor,
        unpadded_dim_size: int,
        group: ProcessGroup,
    ):
        """All-to-all ``hidden_states`` and apply the output projection.

        Args:
            hidden_states: attention output before communication.
            seq_dimension / head_dimension: dims scattered / gathered by the
                all-to-all.
            proj_weight, proj_bias: parameters passed to ``F.linear`` (bias may
                be None).
            unpadded_dim_size: sequence length used to strip the padding from
                the input gradient in backward.
            group: process group; defaults to the Ulysses SP group when None.
        """
        sp_group = get_ulysses_sequence_parallel_group() if group is None else group
        hidden_states = padding_tensor_for_seqeunce_parallel(hidden_states, seq_dimension)
        hidden_states = all_to_all_tensor(
            hidden_states, scatter_dim=seq_dimension, gather_dim=head_dimension, group=sp_group
        )
        o = F.linear(hidden_states, proj_weight, proj_bias)

        ctx.sp_group = sp_group
        ctx.head_dimension = head_dimension
        ctx.seq_dimension = seq_dimension
        ctx.unpadded_dim_size = unpadded_dim_size
        # The post-communication activations are what the weight gradient needs.
        ctx.save_for_backward(
            hidden_states,
            proj_weight,
            proj_bias,
        )
        return o

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor):
        """Linear backward, overlapping the input-grad all-to-all with weight/bias grads.

        Returns one entry per ``forward`` input (7); non-tensor inputs get None.
        """
        sp_group = ctx.sp_group
        head_dimension = ctx.head_dimension
        seq_dimension = ctx.seq_dimension
        unpadded_dim_size = ctx.unpadded_dim_size
        (
            hidden_states,
            proj_weight,
            proj_bias,
        ) = ctx.saved_tensors
        grad_out = grad_output[0]

        # Input gradient first, so that its communication overlaps the rest.
        grad_o = grad_out @ proj_weight
        grad_out_res = all_to_all_tensor(
            grad_o, scatter_dim=head_dimension, gather_dim=seq_dimension, group=sp_group, async_op=True
        )

        # Input positions: 3 = proj_weight, 4 = proj_bias.
        grad_proj_weight = None
        grad_proj_bias = None
        needs_weight_grad = ctx.needs_input_grad[3]
        needs_bias_grad = proj_bias is not None and ctx.needs_input_grad[4]
        if needs_weight_grad or needs_bias_grad:
            # Flatten leading dims so any input rank accepted by F.linear yields
            # (out, in) / (out,) gradients.
            grad_2d = grad_out.reshape(-1, grad_out.shape[-1])
            if needs_weight_grad:
                grad_proj_weight = grad_2d.t() @ hidden_states.reshape(-1, hidden_states.shape[-1])
            if needs_bias_grad:
                grad_proj_bias = grad_2d.sum(0)

        grad_o = unpadding_tensor_for_seqeunce_parallel(grad_out_res(), seq_dimension, unpadded_dim_size)
        return (
            grad_o,
            None,
            None,
            grad_proj_weight,
            grad_proj_bias,
            None,
            None,
        )
def async_ulysses_qkv_projection(
    hidden_states: Tensor = None,
    seq_dimension: int = None,
    head_dimension: int = None,
    q_weight: Tensor = None,
    q_bias: Optional[Tensor] = None,
    k_weight: Tensor = None,
    k_bias: Optional[Tensor] = None,
    v_weight: Tensor = None,
    v_bias: Optional[Tensor] = None,
    norm_type: str = None,
    norm_q_weight: Optional[Tensor] = None,
    norm_q_bias: Optional[Tensor] = None,
    norm_k_weight: Optional[Tensor] = None,
    norm_k_bias: Optional[Tensor] = None,
    normalized_shape: Optional[int] = None,
    eps: Optional[float] = None,
    unpadded_dim_size: int = None,
    head_dim: int = None,
    group: Optional[ProcessGroup] = None,
):
    """Keyword-friendly entry point for ``AsyncUlyssesQKVProjection``.

    ``autograd.Function.apply`` takes positional arguments, so this wrapper
    forwards every argument in the order ``forward`` declares them.

    Returns:
        The ``(q, k, v)`` tuple produced by ``AsyncUlyssesQKVProjection.forward``.
    """
    positional = (
        hidden_states,
        seq_dimension,
        head_dimension,
        q_weight,
        q_bias,
        k_weight,
        k_bias,
        v_weight,
        v_bias,
        norm_type,
        norm_q_weight,
        norm_q_bias,
        norm_k_weight,
        norm_k_bias,
        normalized_shape,
        eps,
        unpadded_dim_size,
        head_dim,
        group,
    )
    return AsyncUlyssesQKVProjection.apply(*positional)
def async_ulysses_output_projection(
    hidden_states: Optional[Tensor] = None,
    seq_dimension: int = None,
    head_dimension: int = None,
    proj_weight: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    unpadded_dim_size: Optional[int] = None,
    group: Optional[ProcessGroup] = None,
):
    """Keyword-friendly entry point for ``AsyncUlyssesOutputProjection``.

    Forwards every argument positionally, in ``forward``'s declared order,
    because ``autograd.Function.apply`` takes positional arguments.

    Returns:
        The projected tensor produced by ``AsyncUlyssesOutputProjection.forward``.
    """
    positional = (
        hidden_states,
        seq_dimension,
        head_dimension,
        proj_weight,
        proj_bias,
        unpadded_dim_size,
        group,
    )
    return AsyncUlyssesOutputProjection.apply(*positional)