|
|
|
|
|
from typing import Optional |
|
|
from functools import partial |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
from torch.distributed import ProcessGroup |
|
|
import torch.distributed as dist |
|
|
|
|
|
|
|
|
|
|
|
import fused_dense_lib as fused_dense_cuda |
|
|
|
|
|
from flash_attn.utils.distributed import reduce_scatter, all_reduce |
|
|
from einops import rearrange |
|
|
|
|
|
|
|
|
from HybridTensor.modules.references.fused_dense import ColumnParallelLinear, RowParallelLinear, fused_mlp_func |
|
|
from HybridTensor.modules.references.MLP import SelectiveMLPTriton |
|
|
from HybridTensor.utils.utils import arg_parser, sparse_index |
|
|
from HybridTensor.utils.profiling import cuda_profiler |
|
|
|
|
|
|
|
|
from HybridTensor.triton.gather_gemm_col import gather_matmul_col |
|
|
from HybridTensor.triton.gather_gemm_row import gather_matmul_row |
|
|
|
|
|
|
|
|
from HybridTensor.triton.heuristics.gather_gemm_col_h import gather_matmul_col as gather_matmul_col_h |
|
|
from HybridTensor.triton.heuristics.gather_gemm_row_h import gather_matmul_row as gather_matmul_row_h |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def SelectiveMLPFunc(x, fc1_w, fc2_w, index_vec, bias1 = None, bias2 = None, activation='relu', use_heuristic=True):
    """Run the two-stage selective MLP using the gather GEMM kernels.

    Only the hidden neurons named in ``index_vec`` are computed: the column
    kernel evaluates fc1 (plus activation) for those rows of ``fc1_w``, and the
    row kernel projects back through the matching rows of ``fc2_w``.

    Args:
        x: 2D input batch (assumes shape (tokens, in_features) — TODO confirm).
        fc1_w: fc1 weight consumed by the column-gather kernel.
        fc2_w: transposed fc2 weight consumed by the row-gather kernel.
        index_vec: indices of the active hidden neurons.
        bias1, bias2: optional biases for the two projections.
        activation: activation name forwarded to the column kernel.
        use_heuristic: pick the autotuned-heuristic kernel variants when True.

    Returns:
        The selective MLP output produced by the row-gather kernel.
    """
    # Choose both kernels up front so the data path below is written once.
    col_fn, row_fn = (
        (gather_matmul_col_h, gather_matmul_row_h)
        if use_heuristic
        else (gather_matmul_col, gather_matmul_row)
    )
    hidden = col_fn(x, fc1_w, index_vec, bias=bias1, activations=activation)
    return row_fn(hidden, fc2_w, index_vec, bias=bias2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MLPRouter(nn.Module):
    """Low-rank two-layer router that scores MLP hidden neurons so the main
    MLP only has to compute the neurons predicted to activate."""

    def __init__(self, embed_dim, low_rank_dim, out_dim, act_th, device=None, dtype=None):
        """
        Initializes the MLPRouter class.

        Args:
            embed_dim (int): Dimensionality of the input embeddings.
            low_rank_dim (int): Dimensionality of the intermediate layer.
            out_dim (int): Number of neurons.
            act_th (float): Threshold used by the threshold-based selection helpers.
            device (torch.device, optional): Device for the parameters.
            dtype (torch.dtype, optional): Dtype for the parameters.
        """
        super().__init__()  # zero-arg super (was legacy two-arg form)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.fc1 = nn.Linear(embed_dim, low_rank_dim, bias=False, **factory_kwargs)
        self.fc2 = nn.Linear(low_rank_dim, out_dim, bias=False, **factory_kwargs)
        self.act_th = act_th
        self.num_neurons = out_dim
        # Sentinel index that sorts after every valid neuron index; used by
        # _select_neurons_cuda_safe to mark "not selected" slots.
        self.largest = self.num_neurons + 1

    def forward(self, x):
        """
        Forward pass of the MLPRouter.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, embed_dim).

        Returns:
            torch.Tensor: Neuron scores of shape (batch_size, num_neurons).
        """
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def _select_neurons_topk(self, x, topk=None):
        """Return indices of the `topk` neurons with the largest summed
        positive score across the batch (fixed-size output)."""
        neurons = self.forward(x)
        # F.relu instead of allocating a throwaway nn.ReLU module per call.
        neurons_nonzero = F.relu(neurons)
        _, index_vec = neurons_nonzero.sum(dim=0).topk(topk, dim=0, sorted=False)
        return index_vec

    def _select_neurons(self, x, th=None):
        '''
        Threshold based selection of neurons, not CG safe
        '''
        if th is None:
            th = self.act_th

        neurons = self.forward(x)
        # Count, per neuron, how many batch rows exceed the score threshold;
        # any neuron hit at least once is selected (variable-size output).
        activated = (neurons > th).sum(dim=0)
        index_vec = activated.nonzero().flatten()
        return index_vec

    def _select_neurons_cuda_safe(self, x, th=None):
        '''
        This function is used with threshold and is used for CG safe version of the code
        '''
        if th is None:
            th = self.act_th
        neurons = self.forward(x)
        activated = (neurons > th).sum(dim=0)

        indices = torch.arange(self.num_neurons, device=activated.device)
        # NOTE(review): `activated` holds per-neuron *counts* of rows above the
        # score threshold, so reusing `th` here compares a count against a score
        # threshold. For th in (0, 1) this equals `activated > 0` (matching
        # _select_neurons); confirm intent before using th >= 1.
        selected = torch.where(activated > th, indices, torch.full_like(indices, self.largest))

        # Valid indices sort to the front; the sentinel fills the tail so the
        # output shape is static (CUDA-graph capture friendly).
        index_vec, _ = torch.sort(selected)
        index_size = ((index_vec < self.largest).sum()).to(torch.int32)

        return index_size, index_vec
|
|
|
|
|
|
|
|
|
|
|
class ParallelMLPRouter(nn.Module):
    """Tensor-parallel low-rank router that predicts which MLP neurons will
    activate. fc1 is replicated; fc2 is column-sharded across the process
    group, so each rank scores its own slice of the neurons."""

    def __init__(
        self,
        embed_dim,
        low_rank_dim,
        out_dim,
        act_th,
        process_group,
        sequence_parallel=False,
        device=None,
        dtype=None,
    ):
        """Build the two-layer router.

        Args:
            embed_dim (int): Dimensionality of the input embeddings.
            low_rank_dim (int): Dimensionality of the intermediate layer.
            out_dim (int): Output dimensionality (typically number of neurons).
            act_th (float): Activation threshold for `_select_neurons`.
            process_group (torch.distributed.ProcessGroup): Group for tensor parallelism.
            sequence_parallel (bool, optional): Use sequence parallelism. Defaults to False.
            device (torch.device, optional): Device for parameters. Defaults to None.
            dtype (torch.dtype, optional): Parameter dtype. Defaults to None.
        """
        super().__init__()
        assert process_group is not None, "ParallelMHARouter requires a process group."

        factory_kwargs = {"device": device, "dtype": dtype}
        self.embed_dim = embed_dim
        self.act_th = act_th
        self.process_group = process_group

        # Dense low-rank down-projection, replicated on every rank.
        self.fc1 = nn.Linear(embed_dim, low_rank_dim, bias=False, **factory_kwargs)
        # Up-projection sharded over neurons (columns) across the group.
        self.fc2 = ColumnParallelLinear(
            low_rank_dim,
            out_dim,
            process_group,
            bias=False,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        """Score the neurons owned by this rank.

        Args:
            x (torch.Tensor): Input of shape (batch_size, seq_len, embed_dim).

        Returns:
            torch.Tensor: Scores of shape (batch_size, seq_len, out_dim / world_size).
        """
        return self.fc2(self.fc1(x))

    def _select_neurons(self, x, th=None):
        """Threshold-based selection: indices of neurons whose score exceeds
        `th` for at least one batch row (variable-size output)."""
        threshold = self.act_th if th is None else th
        hit_counts = (self.forward(x) > threshold).sum(dim=0)
        return hit_counts.nonzero().flatten()

    def _select_neurons_topk(self, x, topk=None):
        """Top-k selection: indices of the `topk` neurons with the largest
        summed positive score across the batch (fixed-size output)."""
        scores = torch.relu(self.forward(x)).sum(dim=0)
        _, selected = scores.topk(topk, dim=0, sorted=False)
        return selected
|
|
|
|
|
class SelectiveMLP(nn.Module):
    """Two-layer MLP that can run densely or on a caller-selected subset of
    hidden neurons via the gather GEMM kernels.

    On the first selective call, fc2's weight is replaced by a transposed
    contiguous copy (`fc2_weight_t`) expected by the row-gather kernel, and the
    original parameter is released so only one copy stays resident. After that,
    the dense path multiplies by the cached transpose instead of calling fc2.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation='relu',
        layer_idx=None,
        bias1=True,
        bias2=True,
        return_residual=False,
        checkpoint_lvl=0,
        use_heuristic=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Transformer convention: hidden = 4 * d_model, out = d_model.
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.activation_fn = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

        # Lazily-built transposed copy of fc2.weight for the gather kernels.
        self.fc2_weight_t = None
        self.use_heuristic = use_heuristic

    def _init_weights(self):
        """Eagerly build the transposed fc2 weight used by the kernels."""
        self.fc2_weight_t = self.fc2.weight.t().contiguous()

    def forward(self, x, index_vec=None, index_size=None):
        """Run the MLP.

        Args:
            x: input tensor; flattened to 2D on the selective path.
            index_vec: indices of active hidden neurons; None = dense path.
            index_size: unused here (kept for interface compatibility).

        Returns:
            Output tensor, or (output, x) when `return_residual` is set.
        """
        if index_vec is not None:
            # First selective call: cache fc2^T and release the original
            # parameter so only one copy of the weight stays in memory.
            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()
                self.fc2.weight = None
                del self.fc2._parameters['weight']

            x = x.view(-1, x.size(-1))
            y = SelectiveMLPFunc(x=x, fc1_w=self.fc1.weight,
                                 fc2_w=self.fc2_weight_t, index_vec=index_vec,
                                 bias1=self.fc1.bias, bias2=self.fc2.bias,
                                 activation=self.activation, use_heuristic=self.use_heuristic)
        else:
            y = self.fc1(x)
            y = self.activation_fn(y)

            if self.fc2_weight_t is not None:
                y = torch.matmul(y, self.fc2_weight_t)
                # Bug fix: this path used to drop fc2's bias, so dense outputs
                # silently changed after _init_weights() or a selective call.
                if self.fc2.bias is not None:
                    y = y + self.fc2.bias
            else:
                y = self.fc2(y)

        return y if not self.return_residual else (y, x)
|
|
|
|
|
class ParallelSelectiveMLP(nn.Module):
    # Tensor-parallel MLP with an optional "selective" path: when an index_vec
    # of active hidden neurons is supplied, the gather GEMM kernels compute
    # only those neurons; otherwise the fused dense MLP kernel is used.
    def __init__(
        self,
        in_features,
        hidden_features,
        out_features=None,
        activation="relu",
        layer_idx=None,
        process_group: ProcessGroup = None,
        bias1=True,
        bias2=True,
        return_residual=False,
        sequence_parallel=False,
        use_heuristic=True,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu"]
        assert process_group is not None

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        # fc1 sharded over hidden neurons (columns); fc2 sharded over rows,
        # so partial outputs must be reduced across the group after fc2.
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
        )
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
        )
        self.layer_idx = layer_idx

        # NOTE(review): register_buffer returns None, and the buffer name
        # "fc2_weigth_t" is misspelled, so this line registers an unused None
        # buffer and assigns None to the attribute — it is a no-op beyond the
        # stray buffer. The attribute is re-assigned None just below anyway.
        self.fc2_weight_t = self.register_buffer("fc2_weigth_t", None)
        self.return_residual = return_residual
        # Lazily-built transposed copy of fc2.weight for the gather kernels.
        self.fc2_weight_t = None
        self.use_heuristic = use_heuristic
        # NOTE(review): self.reduce_fn is stored but forward() below calls
        # dist.all_reduce directly; only sp_forward recomputes this choice.
        self.reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce

    def _init_weights(self):
        # Eagerly build the transposed fc2 weight used by the gather kernels.
        self.fc2_weight_t = self.fc2.weight.t().contiguous()

    def forward(self, x, residual = None, index_vec = None):
        # Selective path: only the neurons in index_vec are computed.
        if index_vec is not None:

            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()

            # Flatten (batch, seq, d) -> (tokens, d) for the 2D gather kernels.
            x = x.view(-1, x.size(-1))

            out = SelectiveMLPFunc(x = x, fc1_w = self.fc1.weight,
                            fc2_w = self.fc2_weight_t, index_vec = index_vec,
                            bias1 = self.fc1.bias, bias2 = self.fc2.bias,
                            activation=self.activation, use_heuristic=self.use_heuristic)

        else:
            # Dense path: resolve the fused-kernel heuristic, then call the
            # fused MLP from flash-attn's fused_dense.
            if self.heuristic == "auto":
                dtype = (
                    x.dtype
                    if not torch.is_autocast_enabled()
                    else torch.get_autocast_gpu_dtype()
                )
                if self.activation == "gelu_approx":
                    cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                    heuristic = (
                        0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
                    )
                else:
                    heuristic = 0
            else:
                heuristic = self.heuristic
            out = fused_mlp_func(
                x,
                self.fc1.weight,
                self.fc2.weight,
                self.fc1.bias,
                self.fc2.bias,
                activation=self.activation,
                save_pre_act=self.training,
                checkpoint_lvl=self.checkpoint_lvl,
                heuristic=heuristic,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
            )

        # Sum the partial fc2 outputs across ranks.
        # NOTE(review): fused_mlp_func already receives process_group /
        # sequence_parallel — confirm it does not reduce internally, otherwise
        # the dense path is reduced twice.
        if self.process_group.size() > 1:

            dist.all_reduce(out, op=dist.ReduceOp.SUM, group=self.process_group)

        return out if not self.return_residual else (out, x)

    def sp_forward(self, x, residual = None, index_vec = None):
        # Variant of forward that overlaps the next layer's sparse-predictor
        # (router) computation with the MLP reduction on a side stream.
        # NOTE(review): this method reads self.sp_router, self.sp_stream,
        # self.sp, self.event_mlp and self.event_mlp_sp, none of which are set
        # in __init__ — they are presumably attached externally; verify before
        # calling, otherwise this raises AttributeError.
        if self.heuristic == "auto":
            dtype = (
                x.dtype
                if not torch.is_autocast_enabled()
                else torch.get_autocast_gpu_dtype()
            )
            if self.activation == "gelu_approx":
                cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                heuristic = (
                    0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
                )
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        curr_stream = torch.cuda.current_stream()
        # Decoding (single-token) step is detected by seq_len == 1.
        do_token_generation = x.size(1) == 1

        # NOTE(review): `!= None` should idiomatically be `is not None`.
        if index_vec != None:
            # Selective path is only supported during single-token decoding.
            assert x.size(1) == 1

            if self.fc2_weight_t is None:
                self.fc2_weight_t = self.fc2.weight.t().contiguous()

            out = SelectiveMLPFunc(
                rearrange(x, "b 1 d -> b d"),
                self.fc1.weight,
                self.fc2_weight_t,
                index_vec,
                self.fc1.bias,
                self.fc2.bias,
                activation=self.activation,
            )
            out = rearrange(out, "b d -> b 1 d")
        else:
            out = fused_mlp_func(
                x,
                self.fc1.weight,
                self.fc2.weight,
                self.fc1.bias,
                self.fc2.bias,
                activation=self.activation,
                save_pre_act=self.training,
                checkpoint_lvl=self.checkpoint_lvl,
                heuristic=heuristic,
                process_group=self.process_group,
                sequence_parallel=self.sequence_parallel,
            )

        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        # Mark the point where the MLP compute is done so the router stream
        # can start without waiting for the reduction below.
        if self.sp_router:
            curr_stream.record_event(self.event_mlp)

        out = reduce_fn(out, self.process_group)

        if self.sp_router:
            # Run the next layer's router on a side stream, overlapped with
            # the collective above; mlp_logit is presumably consumed via
            # state stored on self.sp — TODO confirm.
            with torch.cuda.stream(self.sp_stream):
                self.sp_stream.wait_event(self.event_mlp)
                if do_token_generation:
                    mlp_logit = self.sp(rearrange(residual, "b 1 d -> b d"))
                self.sp_stream.record_event(self.event_mlp_sp)

            # Re-join: main stream waits for the router before returning.
            curr_stream.wait_event(self.event_mlp_sp)

        return out
|
|
|
|
|
class SimpleMLP(nn.Module):
    """Plain two-layer MLP used as a dense reference implementation.

    Args:
        in_features (int): Input dimensionality.
        hidden_features (int): Hidden-layer width.
        out_features (int): Output dimensionality.
        bias (bool): Whether both linear layers carry a bias. Defaults to False.
        activation (str): "relu" (default) or "gelu_approx".
    """

    def __init__(self, in_features, hidden_features, out_features, bias=False, activation="relu"):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.activation = activation

    def forward(self, x):
        """Apply fc1 -> activation -> fc2."""
        x = self.fc1(x)
        # Bug fix: `activation` used to be stored but ignored — ReLU was
        # always applied regardless of the constructor argument.
        if self.activation == "relu":
            x = F.relu(x)
        elif self.activation == "gelu_approx":
            # tanh approximation, matching the "gelu_approx" naming used by
            # the other MLP classes in this file.
            x = F.gelu(x, approximate="tanh")
        else:
            raise ValueError(f"unsupported activation: {self.activation}")
        return self.fc2(x)
|
|
|
|
|
if __name__ == "__main__":
    # Benchmark script: compares the selective MLP (router + gather kernels)
    # against the dense forward, then sanity-checks CUDA-graph capture of the
    # selective path. Requires a CUDA device and fp16 inputs.
    args = arg_parser()

    bias = True if args.bias > 0 else False
    x = torch.randn(args.batch_size, args.in_features, device="cuda", dtype=torch.float16)
    # Random active-neuron indices out of the 4x hidden dimension.
    index_vec, _ = sparse_index(args.index_size, args.in_features*4)

    '''
    selective_mlp = SelectiveMLPTriton(args.in_features, args.hidden_features, bias=bias, device="cuda", dtype=torch.float16, activation="relu")

    out, mlp_time = cuda_profiler(selective_mlp, x, index_vec)

    out_col, col_time = cuda_profiler(gather_matmul_col, x, selective_mlp.fc1_w, index_vec, activations=selective_mlp.activation)
    out_row, row_time = cuda_profiler(gather_matmul_row, out_col, selective_mlp.fc2_w, index_vec)
    sum_time = col_time + row_time

    print(f"Index size {args.index_size}, Activated {args.index_size/(args.in_features * 4)*100}% neurons")

    print(f"Gather Col Time: {col_time} ms")
    print(f"Gather Row Time: {row_time} ms")
    # print(f"Sum Time: {sum_time} ms")

    print(f"SelectiveMLP Time: {mlp_time} ms")
    '''

    in_features = args.in_features
    hidden_features = in_features * 4
    out_features = in_features
    device = torch.device("cuda")

    model = SelectiveMLP(
        in_features, hidden_features, out_features, device=device, dtype=torch.float16, activation="relu", use_heuristic=True
    ).to(device)

    # Router predicts which of the hidden neurons to keep.
    router = MLPRouter(in_features, 1024, hidden_features, act_th = 0.5, device=device, dtype=torch.float16).to(device)

    def warmup():
        # Trigger lazy weight transposition and kernel autotuning/compilation
        # before timing anything.
        for _ in range(10):
            _ = model(x, index_vec)
            _ = model(x, None)
            _ = router._select_neurons_topk(x, args.index_size)

    warmup()

    # Timings (cuda_profiler returns (output, elapsed_ms)).
    _, router_time = cuda_profiler(router._select_neurons_topk, x, args.index_size)
    _, selective_time = cuda_profiler(model, x, index_vec)

    _, dense_time = cuda_profiler(model, x, None)

    print(f"Router time per run: {router_time:.6f} ms")
    print(f"SelectiveMLPFunc time per run: {selective_time:.6f} ms")
    print(f"Dense forward time per run: {dense_time:.6f} ms")
    print(f"Speedup: {dense_time / selective_time:.2f}x")
    router_selective_time = router_time + selective_time
    print(f"Router + SelectiveMLPFunc time per run: {router_selective_time:.6f} ms")
    print(f"Speedup: {dense_time / router_selective_time:.2f}x")

    print("\n=== CUDA Graph Tests ===")

    print("Testing CUDA Graph for Selective forward (with index_vec)...")
    # Static tensors: CUDA graphs replay against fixed memory addresses.
    static_x = x.clone()
    static_index_vec = index_vec.clone()

    # One eager call before capture so all lazy allocations happen outside
    # the graph.
    static_out_sel = model(static_x, index_vec=static_index_vec)
    torch.cuda.synchronize()

    # Capture must happen on a non-default stream.
    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        g_sel = torch.cuda.CUDAGraph()
        g_sel.capture_begin()
        static_out_sel = model(static_x, index_vec=static_index_vec)
        g_sel.capture_end()
    torch.cuda.synchronize()

    # Replay once and compare against an eager run on the same inputs.
    g_sel.replay()
    torch.cuda.synchronize()
    cuda_sel_out = static_out_sel.clone()
    regular_sel_out = model(x, index_vec=index_vec)
    if torch.allclose(cuda_sel_out, regular_sel_out, atol=1e-3):
        print("Selective forward CUDA Graph output matches regular output")
    else:
        print("Selective forward CUDA Graph output does NOT match regular output")

    def replay_sel():
        # Thin wrapper so cuda_profiler can time graph replay alone.
        g_sel.replay()

    _, selective_time_cuda = cuda_profiler(replay_sel)
    print(f"Selective forward CUDA Graph time per run: {selective_time_cuda:.6f} ms")