# python -m HybridTensor.modules.SelectiveMLP --batch_size 8 --index_size 512
from typing import Optional
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
import torch.distributed as dist
# import fused_dense_cuda # from apex
import fused_dense_lib as fused_dense_cuda
from flash_attn.utils.distributed import reduce_scatter, all_reduce
from einops import rearrange
# from HybridTensor.modules.MLP import SelectiveMLPFunc
from HybridTensor.modules.references.fused_dense import ColumnParallelLinear, RowParallelLinear, fused_mlp_func
from HybridTensor.modules.references.MLP import SelectiveMLPTriton
from HybridTensor.utils.utils import arg_parser, sparse_index
from HybridTensor.utils.profiling import cuda_profiler
# compiles the kernels for the first time, takes time
from HybridTensor.triton.gather_gemm_col import gather_matmul_col
from HybridTensor.triton.gather_gemm_row import gather_matmul_row
# needs to be compiled before running
from HybridTensor.triton.heuristics.gather_gemm_col_h import gather_matmul_col as gather_matmul_col_h
from HybridTensor.triton.heuristics.gather_gemm_row_h import gather_matmul_row as gather_matmul_row_h
# from HybridTensor.triton.cg_safe.gather_gemm_col_cg import gather_matmul_col
# from HybridTensor.triton.cg_safe.gather_gemm_row_cg import gather_matmul_row
def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, bias1=None, bias2=None, activation='relu', use_heuristic=True):
    """Sparse MLP forward: computes only the neurons listed in index_vec.

    fc2_w is expected to be the transposed (row-major) fc2 weight so that
    rows can be gathered with index_vec.
    """
    if use_heuristic:
        out = gather_matmul_col_h(x, fc1_w, index_vec, bias=bias1, activations=activation)
        out = gather_matmul_row_h(out, fc2_w, index_vec, bias=bias2)
    else:
        out = gather_matmul_col(x, fc1_w, index_vec, bias=bias1, activations=activation)
        out = gather_matmul_row(out, fc2_w, index_vec, bias=bias2)
    return out
# cg safe version
# def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, index_size, bias1 = None, bias2 = None, activation='relu', use_heuristic=True):
# out = gather_matmul_col(x, fc1_w, index_vec, index_size, bias = bias1, activations=activation)
# out = gather_matmul_row(out, fc2_w, index_vec, index_size, bias = bias2)
# return out
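# A minimal dense-PyTorch reference for SelectiveMLPFunc, useful for checking
# the Triton kernels. This is a sketch based on the assumed gather semantics
# (gather_matmul_col gathers rows of fc1_w, gather_matmul_row gathers rows of
# the transposed fc2 weight); it is not part of the original kernel path.
def selective_mlp_reference(x, fc1_w, fc2_w_t, index_vec, bias1=None, bias2=None):
    h = x @ fc1_w[index_vec].t()   # (batch, k): only the selected neurons
    if bias1 is not None:
        h = h + bias1[index_vec]
    h = F.relu(h)
    y = h @ fc2_w_t[index_vec]     # project back using the matching rows
    if bias2 is not None:
        y = y + bias2
    return y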
class MLPRouter(nn.Module):
def __init__(self, embed_dim, low_rank_dim, out_dim, act_th, device=None, dtype=None):
"""
        Initializes the MLPRouter class.
        Args:
            embed_dim (int): Dimensionality of the input embeddings.
            low_rank_dim (int): Dimensionality of the intermediate low-rank layer.
            out_dim (int): Number of neurons.
            act_th (float): Activation threshold used for neuron selection.
        """
super(MLPRouter, self).__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.fc1 = nn.Linear(embed_dim, low_rank_dim, bias=False, **factory_kwargs)
self.fc2 = nn.Linear(low_rank_dim, out_dim, bias=False, **factory_kwargs)
self.act_th = act_th
self.num_neurons = out_dim
self.largest = self.num_neurons + 1
def forward(self, x):
"""
Forward pass of the MHARouter.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, embed_dim).
Returns:
torch.Tensor: Output tensor of shape (batch_size, num_heads).
"""
x = self.fc1(x)
x = self.fc2(x)
return x
    def _select_neurons_topk(self, x, topk=None):
        # Top-k selection returns a fixed-size index set, so it is CUDA-graph safe.
        neurons = self.forward(x)
        neurons_nonzero = F.relu(neurons)
        _, index_vec = neurons_nonzero.sum(dim=0).topk(topk, dim=0, sorted=False)
        # index_vec, _ = index_vec.sort()
        return index_vec
def _select_neurons(self, x, th=None):
        '''
        Threshold-based selection of neurons; the output size is data-dependent,
        so this is not CUDA-graph (CG) safe.
        '''
if th is None:
th = self.act_th
neurons = self.forward(x)
activated = (neurons > th).sum(dim=0)
index_vec = activated.nonzero().flatten()
return index_vec
def _select_neurons_cuda_safe(self, x, th=None):
        '''
        Threshold-based selection that is CG safe: always returns a full-size,
        sorted index vector (inactive slots padded with the sentinel self.largest)
        together with the count of valid entries.
        '''
        if th is None:
            th = self.act_th
        neurons = self.forward(x)
        activated = (neurons > th).sum(dim=0)
        indices = torch.arange(self.num_neurons, device=activated.device)
        # Keep neurons activated in at least one row (matching _select_neurons),
        # padding inactive slots with the out-of-range sentinel.
        selected = torch.where(activated > 0, indices, torch.full_like(indices, self.largest))
        index_vec, _ = torch.sort(selected)
        index_size = ((index_vec < self.largest).sum()).to(torch.int32)
        return index_size, index_vec
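# Sketch of how the CG-safe selection pairs with the cg_safe kernel variant
# commented out above (an assumed usage, not exercised in this file):
#   index_size, index_vec = router._select_neurons_cuda_safe(x)
#   out = SelectiveMLPFunc(x, fc1_w, fc2_w_t, index_vec, index_size, ...)
# The full-size, sorted index_vec keeps shapes static across CUDA graph replays.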
class ParallelMLPRouter(nn.Module):
"""
Parallel Sparse Predictor for MHA layer.
"""
def __init__(
self,
embed_dim,
low_rank_dim,
out_dim,
act_th,
process_group,
sequence_parallel=False,
device=None,
dtype=None,
):
"""
Initializes the ParallelMHARouter class.
Args:
embed_dim (int): Dimensionality of the input embeddings.
low_rank_dim (int): Dimensionality of the intermediate layer.
out_dim (int): Output dimensionality (typically number of neurons).
process_group (torch.distributed.ProcessGroup): Process group for parallelism.
sequence_parallel (bool, optional): Whether to use sequence parallelism. Defaults to False.
device (torch.device, optional): Device to run the module on. Defaults to None.
dtype (torch.dtype, optional): Data type of the module parameters. Defaults to None.
"""
super(ParallelMLPRouter, self).__init__()
assert process_group is not None, "ParallelMHARouter requires a process group."
factory_kwargs = {"device": device, "dtype": dtype}
self.process_group = process_group
self.embed_dim = embed_dim
self.act_th = act_th
self.fc1 = nn.Linear(
embed_dim, low_rank_dim, bias=False, **factory_kwargs
)
self.fc2 = ColumnParallelLinear(
low_rank_dim,
out_dim,
process_group,
bias=False,
sequence_parallel=sequence_parallel,
**factory_kwargs,
)
# def _select_neurons(self, neurons, th=None):
# if th is None:
# th = self.act_th
# activated = (neurons > th).sum(dim=0)
# index_vec = activated.nonzero().flatten()
# return index_vec
def forward(self, x):
"""
        Forward pass of the ParallelMLPRouter.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, embed_dim).
Returns:
torch.Tensor: Output tensor of shape (batch_size, seq_len, out_dim).
"""
x = self.fc1(x)
x = self.fc2(x)
return x
def _select_neurons(self, x, th=None):
if th is None:
th = self.act_th
neurons = self.forward(x)
activated = (neurons > th).sum(dim=0)
index_vec = activated.nonzero().flatten()
return index_vec
    def _select_neurons_topk(self, x, topk=None):
        neurons = self.forward(x)
        neurons_nonzero = F.relu(neurons)
        _, index_vec = neurons_nonzero.sum(dim=0).topk(topk, dim=0, sorted=False)
        # index_vec, _ = index_vec.sort()
        return index_vec
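# Note (assumed tensor-parallel layout): fc2 above is a ColumnParallelLinear, so
# each rank scores only its local shard of out_dim neurons, and the returned
# index_vec indexes that rank's shard of the parallel MLP.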
class SelectiveMLP(nn.Module):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
activation='relu',
layer_idx=None,
bias1=True,
bias2=True,
return_residual=False,
checkpoint_lvl=0,
use_heuristic=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features if out_features is not None else in_features
hidden_features = hidden_features if hidden_features is not None else in_features * 4
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
self.activation = activation
self.activation_fn = nn.ReLU()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
# self.fc2_weight_t = self.fc2.weight.t().contiguous()
self.fc2_weight_t = None
self.use_heuristic = use_heuristic
def _init_weights(self):
        # If the fc2 weights change (e.g. after loading a checkpoint), the cached transpose must be refreshed.
self.fc2_weight_t = self.fc2.weight.t().contiguous()
def forward(self, x, index_vec=None, index_size=None):
        if index_vec is not None:
            # Sparse forward: cache the transposed fc2 weight on the first run
            # (gather_matmul_row selects rows, so fc2 must be row-major).
            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()
                # Remove the original parameter to free memory.
                self.fc2.weight = None
                del self.fc2._parameters['weight']
x = x.view(-1, x.size(-1))
# x = x.squeeze(1)
y = SelectiveMLPFunc(x = x, fc1_w = self.fc1.weight,
fc2_w = self.fc2_weight_t, index_vec = index_vec,
bias1 = self.fc1.bias, bias2 = self.fc2.bias,
activation=self.activation, use_heuristic=self.use_heuristic)
else:
# dense forward
y = self.fc1(x)
y = self.activation_fn(y)
            if self.fc2_weight_t is not None:
                # fc2.weight is freed after caching its transpose, so apply the
                # (still present) bias explicitly.
                y = torch.matmul(y, self.fc2_weight_t)
                if self.fc2.bias is not None:
                    y = y + self.fc2.bias
            else:
                y = self.fc2(y)
return y if not self.return_residual else (y, x)
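# Usage sketch (assumed workflow): if fc2 weights are modified before the first
# sparse forward, call model._init_weights() to (re)build fc2_weight_t; note
# that after the first sparse call the original fc2.weight parameter is freed.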
class ParallelSelectiveMLP(nn.Module):
def __init__(
self,
in_features,
hidden_features,
out_features=None,
activation="relu",
layer_idx=None,
process_group: ProcessGroup = None,
bias1=True,
bias2=True,
return_residual=False,
sequence_parallel=False,
use_heuristic=True,
checkpoint_lvl=0,
heuristic="auto",
device=None,
dtype=None,
):
"""
process_group is required. We're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul, gelu, then matmul.
Finally we do a reduce_scatter of the output.
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute pre_act and gelu_out in the bwd
heuristic:
-1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu
'auto': heuristic will be picked automatically:
For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
"""
assert checkpoint_lvl in [0, 1, 2]
assert activation in ["gelu_approx", "relu"]
assert process_group is not None
# assert sp_kwargs != None, "sparse predictor parameters are not passed in."
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if out_features is None:
out_features = in_features
self.activation = activation
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic
self.fc1 = ColumnParallelLinear(
in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
)
self.fc2 = RowParallelLinear(
hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
)
self.layer_idx = layer_idx
        self.return_residual = return_residual
        # Transposed fc2 weight; cached lazily on the first sparse forward.
        self.fc2_weight_t = None
self.use_heuristic = use_heuristic
self.reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
# self._init_weights()
def _init_weights(self):
        # fc2 weights need to be in row-major (transposed) format so rows can be gathered with index_vec
self.fc2_weight_t = self.fc2.weight.t().contiguous()
def forward(self, x, residual = None, index_vec = None):
# do_token_generation = x.size(1) == 1
# index_vec = None
# with torch.cuda.stream(self.curr_stream):
if index_vec is not None:
# assert x.size(1) == 1
if self.fc2_weight_t is None:
self.fc2_weight_t = self.fc2.weight.t().contiguous()
x = x.view(-1, x.size(-1))
# x = rearrange(x, "b 1 d -> b d") # slightly more expensive to use rearrange
out = SelectiveMLPFunc(x = x, fc1_w = self.fc1.weight,
fc2_w = self.fc2_weight_t, index_vec = index_vec,
bias1 = self.fc1.bias, bias2 = self.fc2.bias,
activation=self.activation, use_heuristic=self.use_heuristic)
# out = rearrange(out, "b d -> b 1 d")
# out = out.view(-1, 1, out.size(-1))
else: # normal mlp
if self.heuristic == "auto":
dtype = (
x.dtype
if not torch.is_autocast_enabled()
else torch.get_autocast_gpu_dtype()
)
if self.activation == "gelu_approx":
cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
heuristic = (
0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
)
else:
heuristic = 0
else:
heuristic = self.heuristic
out = fused_mlp_func(
x,
self.fc1.weight,
self.fc2.weight,
self.fc1.bias,
self.fc2.bias,
activation=self.activation,
save_pre_act=self.training,
checkpoint_lvl=self.checkpoint_lvl,
heuristic=heuristic,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel,
)
        if self.process_group.size() > 1:
            # self.reduce_fn (reduce_scatter/all_reduce wrapper) adds some overhead here,
            # so call torch.distributed.all_reduce directly.
            dist.all_reduce(out, op=dist.ReduceOp.SUM, group=self.process_group)
return out if not self.return_residual else (out, x)
# return out
def sp_forward(self, x, residual = None, index_vec = None):
if self.heuristic == "auto":
dtype = (
x.dtype
if not torch.is_autocast_enabled()
else torch.get_autocast_gpu_dtype()
)
if self.activation == "gelu_approx":
cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
heuristic = (
0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
)
else:
heuristic = 0
else:
heuristic = self.heuristic
curr_stream = torch.cuda.current_stream()
do_token_generation = x.size(1) == 1
# mlp_logit = None
# with torch.cuda.stream(self.curr_stream):
        # NOTE: the sp_router path below relies on attributes set externally
        # (self.sp_router, self.sp, self.sp_stream, self.event_mlp, self.event_mlp_sp);
        # they are not initialized in __init__.
        if index_vec is not None:
assert x.size(1) == 1
if self.fc2_weight_t is None:
self.fc2_weight_t = self.fc2.weight.t().contiguous()
out = SelectiveMLPFunc(
rearrange(x, "b 1 d -> b d"),
self.fc1.weight,
self.fc2_weight_t,
index_vec,
self.fc1.bias,
self.fc2.bias,
activation=self.activation,
)
out = rearrange(out, "b d -> b 1 d")
else:
out = fused_mlp_func(
x,
self.fc1.weight,
self.fc2.weight,
self.fc1.bias,
self.fc2.bias,
activation=self.activation,
save_pre_act=self.training,
checkpoint_lvl=self.checkpoint_lvl,
heuristic=heuristic,
process_group=self.process_group,
sequence_parallel=self.sequence_parallel,
)
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
if self.sp_router:
curr_stream.record_event(self.event_mlp)
# handle = torch.distributed.all_reduce(out, op=torch.distributed.ReduceOp.SUM, group=self.process_group, async_op=True)
out = reduce_fn(out, self.process_group)
if self.sp_router:
with torch.cuda.stream(self.sp_stream):
self.sp_stream.wait_event(self.event_mlp)
if do_token_generation:
mlp_logit = self.sp(rearrange(residual, "b 1 d -> b d"))
self.sp_stream.record_event(self.event_mlp_sp)
                # TODO: revisit whether this synchronization can be deferred to the next layer.
curr_stream.wait_event(self.event_mlp_sp)
return out
class SimpleMLP(nn.Module):
def __init__(self, in_features, hidden_features, out_features, bias=False, activation="relu"):
super().__init__()
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.activation = activation
    def forward(self, x):
        # Use the configured activation rather than hard-coding ReLU.
        act = F.relu if self.activation == "relu" else F.gelu
        x = act(self.fc1(x))
        x = self.fc2(x)
        return x
if __name__ == "__main__":
args = arg_parser()
    bias = args.bias > 0
x = torch.randn(args.batch_size, args.in_features, device="cuda", dtype=torch.float16)
index_vec, _ = sparse_index(args.index_size, args.in_features*4)
'''
selective_mlp = SelectiveMLPTriton(args.in_features, args.hidden_features, bias=bias, device="cuda", dtype=torch.float16, activation="relu")
out, mlp_time = cuda_profiler(selective_mlp, x, index_vec)
out_col, col_time = cuda_profiler(gather_matmul_col, x, selective_mlp.fc1_w, index_vec, activations=selective_mlp.activation)
out_row, row_time = cuda_profiler(gather_matmul_row, out_col, selective_mlp.fc2_w, index_vec)
sum_time = col_time + row_time
print(f"Index size {args.index_size}, Activated {args.index_size/(args.in_features * 4)*100}% neurons")
print(f"Gather Col Time: {col_time} ms")
print(f"Gather Row Time: {row_time} ms")
# print(f"Sum Time: {sum_time} ms")
print(f"SelectiveMLP Time: {mlp_time} ms")
'''
in_features = args.in_features
hidden_features = in_features * 4
out_features = in_features
device = torch.device("cuda")
model = SelectiveMLP(
in_features, hidden_features, out_features, device=device, dtype=torch.float16, activation="relu", use_heuristic=True
).to(device)
router = MLPRouter(in_features, 1024, hidden_features, act_th = 0.5, device=device, dtype=torch.float16).to(device)
# Warm-up GPU
def warmup():
for _ in range(10):
_ = model(x, index_vec)
_ = model(x, None)
_ = router._select_neurons_topk(x, args.index_size)
warmup()
# Measure SelectiveMLPFunc speed
_, router_time = cuda_profiler(router._select_neurons_topk, x, args.index_size)
_, selective_time = cuda_profiler(model, x, index_vec)
# Measure dense forward speed
_, dense_time = cuda_profiler(model, x, None)
print(f"Router time per run: {router_time:.6f} ms")
print(f"SelectiveMLPFunc time per run: {selective_time:.6f} ms")
print(f"Dense forward time per run: {dense_time:.6f} ms")
print(f"Speedup: {dense_time / selective_time:.2f}x")
router_selective_time = router_time + selective_time
print(f"Router + SelectiveMLPFunc time per run: {router_selective_time:.6f} ms")
print(f"Speedup: {dense_time / router_selective_time:.2f}x")
############################################
# CUDA Graph capture tests for the MLP model
############################################
print("\n=== CUDA Graph Tests ===")
# --- Selective forward (sparse mode) ---
print("Testing CUDA Graph for Selective forward (with index_vec)...")
static_x = x.clone()
static_index_vec = index_vec.clone()
# Warm-up run to allocate memory
static_out_sel = model(static_x, index_vec=static_index_vec)
torch.cuda.synchronize()
# Capture on a non-default stream
capture_stream = torch.cuda.Stream()
with torch.cuda.stream(capture_stream):
g_sel = torch.cuda.CUDAGraph()
g_sel.capture_begin()
static_out_sel = model(static_x, index_vec=static_index_vec)
g_sel.capture_end()
torch.cuda.synchronize()
# Replay and check accuracy
g_sel.replay()
torch.cuda.synchronize()
cuda_sel_out = static_out_sel.clone()
regular_sel_out = model(x, index_vec=index_vec)
if torch.allclose(cuda_sel_out, regular_sel_out, atol=1e-3):
print("Selective forward CUDA Graph output matches regular output")
else:
print("Selective forward CUDA Graph output does NOT match regular output")
def replay_sel():
g_sel.replay()
_, selective_time_cuda = cuda_profiler(replay_sel)
print(f"Selective forward CUDA Graph time per run: {selective_time_cuda:.6f} ms")