| """Full MoSRAH sparse path for SHRAM. |
| |
| This module coordinates the routed sparse attention path used inside the SHRAM |
| hybrid attention layer. The underlying mechanics already live in verified |
| subunits. The responsibility here is to connect those subunits without |
| corrupting their bridge contracts. |
| |
| In particular, this path must preserve three architectural distinctions: |
| |
| - selected head indices are not routing probabilities |
| - packed position semantics are chosen before BEA, not inside it |
| - weighted reduction must consume the router's unbiased renormalized |
| probabilities after token-choice order has been restored |
| """ |
|
|
| import torch |
| from torch import nn |
|
|
| from .__cache__mosrah_cache import MoSRAHCache |
| from .configuration import ShramConfig |
| from .__attention__bottlenecked_ensemble_attention import BottleneckedEnsembleAttention |
| from .__attention__expert_packing import ( |
| pack_experts, |
| setup_packing, |
| unpack_experts, |
| ) |
| from .__attention__router import MoSRAHRouter |
| from .__attention__positions_converter import SparseMoSRAHPositions |
|
|
|
|
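# Illustrative sketch of the weighted-reduction contract from the module
# docstring. B, N, K, and d are hypothetical sizes (batch, tokens, selected
# heads per token, model width), not ShramConfig fields:
#
#   token_choice_outputs: (B, N, K, d)  -- expert outputs, token-choice order
#   routing_probs:        (B, N, K)     -- unbiased, renormalized by the router
#   sparse_output = (token_choice_outputs * routing_probs.unsqueeze(-1)).sum(dim=2)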
class MoSRAHLayer(nn.Module):
    """Full routed sparse attention path for SHRAM.

    The MoSRAH path consumes model-space hidden states together with
    authoritative per-token positions and returns the model-space sparse-path
    contribution, the router's load-balance loss, and the router's MaxVio
    routing-imbalance scalar.
    """
    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.num_experts = config.num_mosrah_heads

        self.router = MoSRAHRouter(config)
        self.positions = SparseMoSRAHPositions(config)
        self.bea = BottleneckedEnsembleAttention(config)

    def num_mosrah_parameters(self) -> int:
        """Return the total number of trainable parameters in this MoSRAH layer."""
        return sum(p.numel() for p in self.parameters())
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: MoSRAHCache | None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full MoSRAH sparse path.

        Args:
            hidden_states: Model-space hidden states x of shape (B, N, d).
            position_ids: Authoritative per-token positions of shape (B, N).
            active_mask: Current-chunk active mask of shape (B, N), where True
                means the token is semantically live. Forwarded to the router
                so dead tokens are excluded from routing statistics, and to
                pack_experts so dead outer tokens do not become semantically
                active packed entries.
            cache: Optional layer-local MoSRAH cache. Pass None for uncached
                execution and the layer-local cache instance for cached
                execution.

        Returns:
            sparse_output: Model-space sparse-path output of shape (B, N, d).
            load_balance_loss: Scalar router load-balance loss.
            max_vio: Detached scalar routing-imbalance summary. Passed through
                unchanged from the router; see MoSRAHRouter for semantics.
        """
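        # Step 1: token-choice routing. selected_heads carries head indices,
        # not probabilities; routing_probs carries the unbiased, renormalized
        # probabilities that the final weighted reduction must consume.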
        selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
            hidden_states, active_mask
        )
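
        # Step 2: pack tokens per head. pack_experts consumes the outer active
        # mask and returns the liveness mask for packed entries; that result is
        # bound to packed_active_mask so it cannot shadow the outer active_mask.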
        flattened_selected_heads, permutation, inverse_permutation = setup_packing(
            selected_heads
        )
        packed_hidden_states, packed_positions, unpacking_mask, packed_active_mask = pack_experts(
            hidden_states=hidden_states,
            position_ids=position_ids,
            selected_heads=selected_heads,
            num_experts=self.num_experts,
            flattened_selected_heads=flattened_selected_heads,
            permutation=permutation,
            outer_active_mask=active_mask,
        )
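
        # Step 3: fix packed position semantics here, before BEA; BEA never
        # reinterprets positions itself.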
        bea_positions = self.positions(
            packed_positions=packed_positions,
            cache=cache,
        )
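
        # Step 4: run bottlenecked ensemble attention over the packed layout,
        # gating on the packed active mask rather than the outer one.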
        packed_outputs = self.bea(
            packed_embeddings=packed_hidden_states,
            position_ids=bea_positions,
            active_mask=packed_active_mask,
            cache=cache,
        )
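
        # Step 5: restore token-choice order, then reduce with the router's
        # renormalized probabilities (the third distinction in the module
        # docstring).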
        token_choice_outputs = unpack_experts(
            expert_outputs=packed_outputs,
            selected_heads=selected_heads,
            unpacking_mask=unpacking_mask,
            inverse_permutation=inverse_permutation,
        )
        final_output = (
            token_choice_outputs * routing_probs.unsqueeze(-1)
        ).sum(dim=2)

        return final_output, load_balance_loss, max_vio