# PolarSparsity/HybridTensor/modules/SelectiveBlock.py
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth
try:
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:
layer_norm_fn, RMSNorm = None, None
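

# The fused dropout + add + LayerNorm path requires flash-attn's Triton
# kernels; without them layer_norm_fn is None and fused_dropout_add_ln must
# stay False (enforced by the assert in __init__ below).
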
class SelectBlock(nn.Module):
def __init__(
self,
dim,
mixer_cls=None,
mlp_cls=None,
mlp_router=None,
mha_router=None,
norm_cls=nn.LayerNorm,
dropout_cls=nn.Dropout,
prenorm=True,
resid_dropout1=0.0,
resid_dropout2=0.0,
drop_path1=0.0,
drop_path2=0.0,
fused_dropout_add_ln=False,
return_residual=False,
residual_in_fp32=False,
sequence_parallel=False,
mark_shared_params=False,
):
"""
For prenorm=True, this Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
Here we do: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, etc.
If you want to do concurrency with CUDA graphs, your shapes must remain fixed
(batch_size, seq_len, etc.) across captures and replays. Also avoid any operations
that cause dynamic shape changes or memory allocations.
"""
super().__init__()
self.prenorm = prenorm
self.fused_dropout_add_ln = fused_dropout_add_ln
self.return_residual = return_residual
self.residual_in_fp32 = residual_in_fp32
if self.residual_in_fp32:
assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
assert mixer_cls is not None and mlp_cls is not None, (
"mixer_cls and mlp_cls cannot be None in SelectBlock"
)
# MHA & MLP submodules
self.mixer = mixer_cls(dim)
self.dropout1 = dropout_cls(resid_dropout1)
self.drop_path1 = StochasticDepth(drop_path1, mode="row")
self.norm1 = norm_cls(dim)
self.mlp = mlp_cls(dim)
        # Number of hidden neurons in the MLP (rows of fc1), used for the
        # "skip the router above this top-k" threshold in the decode paths
        self.total_neurons = self.mlp.fc1.weight.shape[0]
        # Routers (either may be None, in which case that path runs dense)
        self.mlp_router = mlp_router(dim) if mlp_router is not None else None
        self.mha_router = mha_router(dim) if mha_router is not None else None
if not isinstance(self.mlp, nn.Identity):
self.dropout2 = dropout_cls(resid_dropout2)
self.drop_path2 = StochasticDepth(drop_path2, mode="row")
self.norm2 = norm_cls(dim)
if self.fused_dropout_add_ln:
assert layer_norm_fn is not None, "Triton layer_norm_fn not installed"
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(self.dropout1, nn.Dropout)
# Mark the norm parameters for sequence parallel / shared params if needed
if sequence_parallel:
for p in self.norm1.parameters():
p._sequence_parallel = True
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._sequence_parallel = True
if mark_shared_params:
for p in self.norm1.parameters():
p._shared_params = True
if hasattr(self, "norm2"):
for p in self.norm2.parameters():
p._shared_params = True
        # Runtime routing state; mlp_topk is set per call via forward()
        self.mlp_topk = None
        self.skip_mlp_router = False
        self.skip_attn_router = False
        # Extra CUDA streams so the routers can overlap with MHA during decode.
        # More-negative priority means higher priority; out-of-range values are
        # clamped to the device's supported range.
        self.sparse_stream = torch.cuda.Stream(device="cuda", priority=0)
        self.main_stream = torch.cuda.Stream(device="cuda", priority=-5)
# We'll record events to coordinate concurrency
self.mha_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.mlp_event = torch.cuda.Event(enable_timing=False, blocking=False)
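        # record_event / wait_event let a consumer stream wait on exactly the
        # producer's recorded work instead of a full device synchronization.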
        # mark_shared_params doubles as the tensor-parallel flag
        self.use_tensor_parallel = mark_shared_params
        if self.use_tensor_parallel:
            # expose the routers to the mixer and MLP so the TP decode path
            # can overlap routing with the mixer's communication
            self.mlp.router = self.mlp_router
            self.mixer.router = self.mha_router
self.mlp_topk_layers = None # this will be a dictionary of layer_idx -> topk value
self.attn_topk_layers = None # this will be a dictionary of layer_idx -> topk value
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
    def prefill_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_kwargs=None, mixer_subset=None):
        """Dense prefill forward: full MHA and MLP, no routing."""
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
if mixer_subset is not None:
residual = residual[:, mixer_subset]
if not isinstance(self.mlp, nn.Identity):
if not self.fused_dropout_add_ln:
dropped = self.drop_path2(self.dropout2(hidden_states))
if dropped.shape != residual.shape:
dropped = dropped.view(residual.shape)
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
if hidden_states.shape != residual.shape:
hidden_states = hidden_states.view(residual.shape)
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm2.weight,
self.norm2.bias,
residual=residual,
eps=self.norm2.eps,
dropout_p=self.dropout2.p if self.training else 0.0,
rowscale=rowscale2,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm2, RMSNorm),
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
    def decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """Single-GPU decode forward using both routers.

        The MLP router runs on `sparse_stream` concurrently with the attention
        router + MHA on `main_stream`; events join the streams before the MLP
        consumes the router's index vector.

        Args:
            hidden_states (Tensor): decode-step activations of shape (batch_size, 1, dim).
            residual (Optional[Tensor]): running residual from the previous block. Defaults to None.
            mixer_subset: optional subset of positions to keep after the mixer. Defaults to None.
            mixer_kwargs: extra keyword arguments forwarded to the mixer.
        """
curr_stream = torch.cuda.current_stream()
# We want to run MHA & mlp_router in parallel on different streams
router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim)
self.main_stream.wait_stream(curr_stream)
self.sparse_stream.wait_stream(curr_stream)
main_stream = self.main_stream
        # Note: unlike mlp_sparse_forward / tp_decode_forward, this path does
        # not recompute skip_mlp_router from the top-k threshold; it reuses the
        # last value set on the module.
# [Sparse stream] mlp_router
if not self.skip_mlp_router:
with torch.cuda.stream(self.sparse_stream):
                index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk=self.mlp_topk)
self.sparse_stream.record_event(self.mlp_event)
# [Main stream] MHA
with torch.cuda.stream(main_stream):
batch_head_idx = self.mha_router._select_heads(router_inputs)
hidden_states = self.mixer(
hidden_states,
batch_head_idx=batch_head_idx,
**mixer_kwargs
)
main_stream.record_event(self.mha_event)
# Now we unify after both are done, then do the next steps
with torch.cuda.stream(main_stream):
# Wait on router & MHA
curr_stream.wait_stream(main_stream)
main_stream.wait_event(self.mha_event)
# normal residual / layernorm
if mixer_subset is not None:
residual = residual[:, mixer_subset]
if not isinstance(self.mlp, nn.Identity):
if not self.fused_dropout_add_ln:
dropped = self.drop_path2(self.dropout2(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm2(
residual.to(dtype=self.norm2.weight.dtype)
)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
if hidden_states.shape != residual.shape:
hidden_states = hidden_states.view(residual.shape)
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm2.weight,
self.norm2.bias,
residual=residual,
eps=self.norm2.eps,
dropout_p=self.dropout2.p if self.training else 0.0,
rowscale=rowscale2,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm2, RMSNorm),
)
if self.skip_mlp_router:
hidden_states = self.mlp(hidden_states, index_vec=None)
else:
curr_stream.wait_stream(self.sparse_stream)
main_stream.wait_event(self.mlp_event)
hidden_states = self.mlp(hidden_states, index_vec=index_vec)
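        # Re-join: make the caller's stream wait for everything issued on the
        # side streams before this method returns.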
curr_stream.wait_stream(main_stream)
curr_stream.wait_stream(self.sparse_stream)
return hidden_states, residual
    def tp_decode_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """Tensor-parallel decode forward.

        The attention router runs synchronously on the current stream; the MLP
        router is launched on `sparse_stream` so it can overlap with the MHA
        (including its all-reduce).
        """
curr_stream = torch.cuda.current_stream()
self.sparse_stream.wait_stream(curr_stream)
router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim)
        # Skip the MLP router when top-k already covers most (>80%) of the
        # neurons, since routing then buys little over a dense MLP.
        if self.mlp_topk > 0.8 * self.total_neurons:
            self.skip_mlp_router = True
        else:
            self.skip_mlp_router = False
# attention router is synchronous
batch_head_idx = self.mha_router._select_heads(router_inputs)
# mlp router is asynchronous
if not self.skip_mlp_router:
with torch.cuda.stream(self.sparse_stream):
                index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk=self.mlp_topk)
self.sparse_stream.record_event(self.mlp_event)
hidden_states = self.mixer(hidden_states, **mixer_kwargs, batch_head_idx=batch_head_idx)
if mixer_subset is not None:
residual = residual[:, mixer_subset]
if not isinstance(self.mlp, nn.Identity):
if not self.fused_dropout_add_ln:
dropped = self.drop_path2(self.dropout2(hidden_states))
if dropped.shape != residual.shape:
dropped = dropped.view(residual.shape)
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
if hidden_states.shape != residual.shape:
hidden_states = hidden_states.view(residual.shape)
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm2.weight,
self.norm2.bias,
residual=residual,
eps=self.norm2.eps,
dropout_p=self.dropout2.p if self.training else 0.0,
rowscale=rowscale2,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm2, RMSNorm),
)
if self.skip_mlp_router:
hidden_states = self.mlp(hidden_states, index_vec=None)
else:
curr_stream.wait_event(self.mlp_event)
hidden_states = self.mlp(hidden_states, index_vec=index_vec)
return hidden_states, residual
def attn_sparse_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
"""
Decode Forward with Sparse Attention Router
"""
        # Attention-only routing: select heads, then run the mixer synchronously.
        router_inputs = hidden_states.squeeze(1)  # shape (batch_size, dim)
        batch_head_idx = self.mha_router._select_heads(router_inputs)
# print(f"hidden_states shape: {hidden_states.shape}")
# print(f"hidden states: {hidden_states}")
hidden_states = self.mixer(hidden_states, batch_head_idx=batch_head_idx, **mixer_kwargs)
# normal residual / layernorm
if mixer_subset is not None:
residual = residual[:, mixer_subset]
if not isinstance(self.mlp, nn.Identity):
if not self.fused_dropout_add_ln:
dropped = self.drop_path2(self.dropout2(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm2(
residual.to(dtype=self.norm2.weight.dtype)
)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
                rowscale2 = self.drop_path2(
                    torch.ones(
                        hidden_states.shape[:-1],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                )
            if hidden_states.shape != residual.shape:
                hidden_states = hidden_states.view(residual.shape)
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm2.weight,
                self.norm2.bias,
                residual=residual,
                eps=self.norm2.eps,
                dropout_p=self.dropout2.p if self.training else 0.0,
                rowscale=rowscale2,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                is_rms_norm=isinstance(self.norm2, RMSNorm),
            )
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
    def mlp_sparse_forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, mixer_subset=None, mixer_kwargs=None):
        """Single-GPU decode forward with the MLP router only (attention runs dense).

        Args:
            hidden_states (Tensor): decode-step activations of shape (batch_size, 1, dim).
            residual (Optional[Tensor]): running residual from the previous block. Defaults to None.
            mixer_subset: optional subset of positions to keep after the mixer. Defaults to None.
            mixer_kwargs: extra keyword arguments forwarded to the mixer.
        """
curr_stream = torch.cuda.current_stream()
# We want to run MHA & mlp_router in parallel on different streams
router_inputs = hidden_states.squeeze(1) # shape (batch_size, dim)
self.main_stream.wait_stream(curr_stream)
self.sparse_stream.wait_stream(curr_stream)
main_stream = self.main_stream
        # Skip the MLP router when top-k already covers most (>80%) of the
        # neurons, since routing then buys little over a dense MLP.
        if self.mlp_topk > 0.8 * self.total_neurons:
            self.skip_mlp_router = True
        else:
            self.skip_mlp_router = False
# [Sparse stream] mlp_router
if not self.skip_mlp_router:
with torch.cuda.stream(self.sparse_stream):
                index_vec = self.mlp_router._select_neurons_topk(router_inputs, topk=self.mlp_topk)
self.sparse_stream.record_event(self.mlp_event)
# [Main stream] MHA
with torch.cuda.stream(main_stream):
            # attention runs dense in this path (no head routing)
hidden_states = self.mixer(
hidden_states,
batch_head_idx=None,
**mixer_kwargs
)
main_stream.record_event(self.mha_event)
# Now we unify after both are done, then do the next steps
with torch.cuda.stream(main_stream):
# Wait on router & MHA
curr_stream.wait_stream(main_stream)
main_stream.wait_event(self.mha_event)
# normal residual / layernorm
if mixer_subset is not None:
residual = residual[:, mixer_subset]
if not isinstance(self.mlp, nn.Identity):
if not self.fused_dropout_add_ln:
dropped = self.drop_path2(self.dropout2(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm2(
residual.to(dtype=self.norm2.weight.dtype)
)
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
if self.drop_path2.p == 0 or not self.training:
rowscale2 = None
else:
rowscale2 = self.drop_path2(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
if hidden_states.shape != residual.shape:
hidden_states = hidden_states.view(residual.shape)
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm2.weight,
self.norm2.bias,
residual=residual,
eps=self.norm2.eps,
dropout_p=self.dropout2.p if self.training else 0.0,
rowscale=rowscale2,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm2, RMSNorm),
)
if self.skip_mlp_router:
hidden_states = self.mlp(hidden_states, index_vec=None)
else:
curr_stream.wait_stream(self.sparse_stream)
main_stream.wait_event(self.mlp_event)
hidden_states = self.mlp(hidden_states, index_vec=index_vec)
curr_stream.wait_stream(main_stream)
curr_stream.wait_stream(self.sparse_stream)
return hidden_states, residual
def forward(
self,
hidden_states: Tensor,
residual: Optional[Tensor] = None,
mixer_subset=None,
mixer_kwargs=None,
mlp_topk=None,
attn_topk=None,
):
"""
This forward pass includes concurrency logic in the decode branch.
If you're capturing with a CUDA graph, the concurrency (two-stream usage) must be
inside the captured region so that the replay reproduces the parallel streams.
"""
        # Per-call overrides of the routing budgets (e.g., for sparsity sweeps)
if mlp_topk is not None:
self.mlp_topk = mlp_topk
if attn_topk is not None:
self.mha_router.topk = attn_topk
if mixer_kwargs is None:
mixer_kwargs = {"inference_params": None}
else:
# Ensure 'inference_params' key exists
if "inference_params" not in mixer_kwargs:
mixer_kwargs["inference_params"] = None
if self.prenorm:
            # --- 1) Pre-norm dropout / add / layernorm
if not self.fused_dropout_add_ln:
dropped = self.drop_path1(self.dropout1(hidden_states))
residual = (dropped + residual) if residual is not None else dropped
hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
# fused dropout + add + layernorm
if self.drop_path1.p == 0 or not self.training:
rowscale1 = None
else:
rowscale1 = self.drop_path1(
torch.ones(
hidden_states.shape[:-1],
device=hidden_states.device,
dtype=hidden_states.dtype,
)
)
if residual is not None and hidden_states.shape != residual.shape:
hidden_states = hidden_states.view(residual.shape)
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm1.weight,
self.norm1.bias,
residual=residual,
eps=self.norm1.eps,
dropout_p=self.dropout1.p if self.training else 0.0,
rowscale=rowscale1,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
is_rms_norm=isinstance(self.norm1, RMSNorm),
)
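                # layer_norm_fn fuses dropout + residual-add + (RMS/Layer)Norm in
                # a single Triton kernel; with prenorm=True it returns both the
                # normalized output and the updated residual.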
if mixer_subset is not None:
mixer_kwargs["mixer_subset"] = mixer_subset
# Check if we are in the prefill or decode stage
prefill_stage = (
mixer_kwargs["inference_params"] is None
or mixer_kwargs["inference_params"].seqlen_offset == 0
)
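            # seqlen_offset == 0 means nothing is cached yet, i.e. this call
            # processes the whole prompt (prefill); otherwise we decode one
            # token at a time and the sparse paths apply.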
if prefill_stage:
# --- 2) Prefill stage (no concurrency): just do normal forward
hidden_states, residual = self.prefill_forward(hidden_states, residual, mixer_kwargs, mixer_subset)
else:
# --- 3) Decode stage:
if self.mlp_router is None:
# decode stage with only attention router, works with both single gpu and tensor parallel
hidden_states, residual = self.attn_sparse_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
else:
if not self.use_tensor_parallel:
if self.mha_router is None:
# decode stage with mlp routers (opt models and single gpu)
hidden_states, residual = self.mlp_sparse_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
else:
# decode stage with mlp and attention routers (opt models and single gpu)
hidden_states, residual = self.decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
else:
# uses both mlp and attention routers in tensor parallel
hidden_states, residual = self.tp_decode_forward(hidden_states, residual, mixer_subset, mixer_kwargs)
return hidden_states, residual
else:
# post-norm architecture not implemented here
raise NotImplementedError
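

# --- Illustrative sketch (not part of the model) -------------------------------
# A minimal, self-contained demonstration of the two-stream overlap pattern
# used in decode_forward above: a cheap "router" computation runs on a side
# stream while heavier work runs on another stream, and events join them at
# the end. The tensor sizes and the stand-in computations are hypothetical.
def _demo_stream_overlap():
    if not torch.cuda.is_available():
        return None
    x = torch.randn(8, 1024, device="cuda")
    side = torch.cuda.Stream()
    main = torch.cuda.Stream()
    side_done = torch.cuda.Event()
    main_done = torch.cuda.Event()
    curr = torch.cuda.current_stream()
    # Both side streams must first wait for work already queued by the caller.
    side.wait_stream(curr)
    main.wait_stream(curr)
    with torch.cuda.stream(side):
        # Stand-in for the MLP router: pick the top-k columns per row.
        index_vec = x.abs().topk(128, dim=-1).indices
        side.record_event(side_done)
    with torch.cuda.stream(main):
        # Stand-in for the MHA: heavier work that overlaps with the router.
        y = x @ x.t()
        main.record_event(main_done)
    # Join: the caller's stream may consume both results only after the events.
    curr.wait_event(main_done)
    curr.wait_event(side_done)
    return y, index_vec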