# Hosting-page header (not code):
#   SHRAM / __attention__sliding_window_attention.py
#   smithblack-0 -- "Update architecture and tokenizer" (a86502d, verified)
# src/shram/model/attention/sliding_window_attention.py
"""Local sliding-window attention path for SHRAM.

This file defines `SlidingWindowAttention`, the local short-range attention path
used inside the SHRAM hybrid layer.

In the masked-continuation variant, the local cache no longer returns a
semantically dense visible frame. Instead, `LocalSlidingWindowLayerCache.update`
returns:

- the keys and values of the retained local window memory concatenated with the
  current chunk
- an aligned active mask over that returned frame
- the aligned positions of every slot in that returned frame

This module consumes that returned frame directly and constructs effective local
causal/window visibility from the active mask and the positions. It does not own
cache retention policy; it owns only local attention semantics.
"""
import math
from typing import Any
import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
from .configuration import ShramConfig
from .rope import RotaryEmbedding
class SlidingWindowAttention(nn.Module):
"""Causal local sliding-window attention for one SHRAM layer.
Args:
config: SHRAM config. Must expose `hidden_size`,
`num_sliding_window_heads`, `head_dim`, `window_size`,
`attention_dropout`, and `local_rope_theta`.
Raises:
NotImplementedError: If `attention_dropout != 0.0`.
"""
def __init__(self, config: ShramConfig) -> None:
super().__init__()
self.hidden_size = config.embedding_width
self.num_heads = config.num_sliding_window_heads
self.head_dim = config.head_dim
self.window_size = config.window_size
self.attention_dropout = config.attention_dropout
if self.attention_dropout != 0.0:
raise NotImplementedError(
"SlidingWindowAttention currently supports only "
"attention_dropout == 0.0."
)
self.inner_dim = self.num_heads * self.head_dim
# Standard MHA projections for the local path.
self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
# The local path always uses default-mode RoPE with its own theta.
self.rope = RotaryEmbedding(
mode="default",
head_dim=self.head_dim,
theta=config.local_rope_theta,
maximum_sequence_length=config.inference_sequence_length,
)
def forward(
self,
x: torch.Tensor,
position_ids: torch.Tensor,
active_mask: torch.Tensor,
cache: LocalSlidingWindowLayerCache | None = None,
) -> torch.Tensor:
"""Apply local causal sliding-window attention.
Args:
x: Input tensor of shape `(B, N, hidden_size)`.
position_ids: Position tensor of shape `(B, N)`.
active_mask: Current-chunk active mask of shape `(B, N)`, where
`True` means active.
cache: Optional `LocalSlidingWindowLayerCache`.
Returns:
Output tensor of shape `(B, N, hidden_size)`.
"""
batch_size, query_len, _ = x.shape
self._validate_position_shape(x, position_ids)
self._validate_active_mask_shape(x, active_mask)
# (B, N, H*D) -> (B, H, N, D)
q = self.q_proj(x).view(
batch_size,
query_len,
self.num_heads,
self.head_dim,
).transpose(1, 2)
k = self.k_proj(x).view(
batch_size,
query_len,
self.num_heads,
self.head_dim,
).transpose(1, 2)
v = self.v_proj(x).view(
batch_size,
query_len,
self.num_heads,
self.head_dim,
).transpose(1, 2)
q, k, attention_scaling = self.rope(q, k, position_ids)
# The cache returns the current-step visible local frame, not merely the
# retained next-step cache buffer.
if cache is not None:
k_full, v_full, full_active_mask, full_positions = cache.update(
k, v, active_mask, position_ids
)
else:
k_full, v_full, full_active_mask, full_positions = k, v, active_mask, position_ids
block_mask = self._make_block_mask(
active_mask=full_active_mask,
positions=full_positions,
batch_size=batch_size,
num_heads=self.num_heads,
query_len=query_len,
kv_len=k_full.shape[-2],
window_size=self.window_size,
device=x.device,
)
attn_output = flex_attention(
q,
k_full,
v_full,
block_mask=block_mask,
scale=attention_scaling / math.sqrt(self.head_dim),
)
# (B, H, N, D) -> (B, N, H*D) -> (B, N, hidden_size)
attn_output = (
attn_output.transpose(1, 2)
.contiguous()
.view(batch_size, query_len, self.inner_dim)
)
return self.o_proj(attn_output)
def _validate_position_shape(
self,
x: torch.Tensor,
position_ids: torch.Tensor,
) -> None:
"""Validate the position tensor shape expected by local RoPE."""
if position_ids.shape != x.shape[:2]:
raise ValueError(
f"position_ids must have shape {tuple(x.shape[:2])}, "
f"got {tuple(position_ids.shape)}."
)
def _validate_active_mask_shape(
self,
x: torch.Tensor,
active_mask: torch.Tensor,
) -> None:
"""Validate the current-chunk active-mask contract."""
if active_mask.shape != x.shape[:2]:
raise ValueError(
f"active_mask must have shape {tuple(x.shape[:2])}, "
f"got {tuple(active_mask.shape)}."
)
if active_mask.dtype != torch.bool:
raise ValueError(
f"active_mask must have dtype torch.bool, got {active_mask.dtype}."
)
def _make_block_mask(
self,
active_mask: torch.Tensor,
positions: torch.Tensor,
batch_size: int,
num_heads: int,
query_len: int,
kv_len: int,
window_size: int,
device: torch.device,
) -> Any:
"""Create the FlexAttention block mask for masked local continuation.
The returned local frame is chronological in raw buffer order; dead
positions may remain inside it. Liveness is carried by `active_mask`.
Causality and window distance are determined from `positions`, which
holds the absolute sequence position of every slot in the composite
frame. Using absolute positions rather than a cumsum over the active
mask eliminates the data-dependent computation that blocks torch.compile.
"""
query_offset = kv_len - query_len
def sliding_window_mask(
batch_idx: torch.Tensor,
head_idx: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
) -> torch.Tensor:
q_abs = query_offset + q_idx
query_is_active = active_mask[batch_idx, q_abs]
key_is_active = active_mask[batch_idx, kv_idx]
q_pos = positions[batch_idx, q_abs]
k_pos = positions[batch_idx, kv_idx]
is_causal = k_pos <= q_pos
in_window = (q_pos - k_pos) < window_size
return query_is_active & key_is_active & is_causal & in_window
return create_block_mask(
sliding_window_mask,
B=batch_size,
H=num_heads,
Q_LEN=query_len,
KV_LEN=kv_len,
device=device,
)