from typing import Callable, ClassVar, Optional

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (
    BlockMask,
    create_block_mask,
    flex_attention,
)


class FlexAttention(torch.nn.Module):
    """FlexAttention module backed by ``torch.nn.attention.flex_attention``.

    The compiled ``flex_attention`` kernel, the compiled ``create_block_mask``
    helper, and the per-mask-type ``BlockMask`` objects are class variables so
    that compilation cost is paid once and shared by all instances.
    """

    flex_attn: ClassVar[Callable] = torch.compile(
        flex_attention, mode="max-autotune-no-cudagraphs"
    )
    compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
    # Mask types requested by any instance; init_attention_mask() builds one
    # BlockMask per entry for each new batch.
    used_attn_mask_types: ClassVar[set[str]] = set()
    # Maps attn_mask_type to the BlockMask built for the current batch.
    block_masks: ClassVar[dict[str, BlockMask]] = {}

    # Instance variables.
    attn_mask_type: str

    def __init__(self, attn_mask_type: str) -> None:
        super().__init__()
        if attn_mask_type not in ["causal", "block_causal"]:
            raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
        self.attn_mask_type = attn_mask_type
        FlexAttention.used_attn_mask_types.add(attn_mask_type)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        block_mask = FlexAttention.block_masks[self.attn_mask_type]
        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask)

    @staticmethod
    def _get_causal_mask_fn() -> Callable:
        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        return causal_mask

    @staticmethod
    def _get_block_causal_mask_fn(batch: torch.Tensor, eos_id: int) -> Callable:
        """Build a mask_mod that keeps attention within each EOS-delimited document.

        ``seq_idx[b, t]`` holds the index of the document that position ``t``
        belongs to; e.g. ``batch = [[5, 7, eos, 3, 8]]`` yields
        ``seq_idx = [[0, 0, 0, 1, 1]]``. A token may only attend to earlier
        tokens in the same document.
        """
        # Mark EOS positions (the last position is treated as an implicit EOS).
        mask = batch == eos_id
        mask[:, -1] = True
        # The cumulative EOS count, shifted right by one, assigns each position
        # the index of the document it belongs to.
        acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
        seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
        seq_idx[:, 1:] = acc_mask[:, :-1]

        def block_causal_mask(b, h, q_idx, kv_idx):
            return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)

        return block_causal_mask

    @staticmethod
    @torch.no_grad()
    def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
        """(Re)build the BlockMask for every mask type currently in use."""
        for attn_mask_type in FlexAttention.used_attn_mask_types:
            match attn_mask_type:
                case "causal":
                    if FlexAttention.block_masks.get(attn_mask_type, None) is not None:
                        continue
                    # The causal mask does not depend on the batch content, so
                    # build it once with a batch dimension of 1 and reuse it.
                    batch_dimension = 1
                    mask_fn = FlexAttention._get_causal_mask_fn()
                case "block_causal":
                    if eos_id is None:
                        raise RuntimeError(
                            "eos_id must be provided for block_causal mask."
                        )
                    batch_dimension = batch.shape[0]
                    mask_fn = FlexAttention._get_block_causal_mask_fn(batch, eos_id)
                case _:
                    raise RuntimeError(f"Unexpected attn_mask_type {attn_mask_type}.")

            seq_len = batch.shape[1]
            block_mask = FlexAttention.compiled_create_block_mask(
                mask_fn, batch_dimension, None, seq_len, seq_len
            )
            FlexAttention.block_masks[attn_mask_type] = block_mask


class ScaledDotProductAttention(torch.nn.Module):
    """Attention module backed by ``F.scaled_dot_product_attention`` (causal only)."""

    def __init__(self, attn_mask_type: str) -> None:
        super().__init__()
        if attn_mask_type != "causal":
            raise ValueError(
                "TorchTitan with SDPA currently only supports causal mask."
            )

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> torch.Tensor:
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)


def build_attention(use_flex_attn: bool, attn_mask_type: str) -> torch.nn.Module:
    """Return a FlexAttention or ScaledDotProductAttention module."""
    if use_flex_attn:
        return FlexAttention(attn_mask_type)
    else:
        return ScaledDotProductAttention(attn_mask_type)


def init_attention_mask(batch: torch.Tensor, eos_id: Optional[int] = None) -> None:
    """Refresh the BlockMasks used by FlexAttention modules for a new batch."""
    FlexAttention.init_attention_mask(batch, eos_id)
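

# Example usage (illustrative sketch only; the dataloader/tokenizer/projection
# names below are assumptions, not part of this module):
#
#     attn = build_attention(use_flex_attn=True, attn_mask_type="block_causal")
#     for batch in dataloader:  # batch: [batch_size, seq_len] token ids
#         # Rebuild the block masks for the new batch before the forward pass.
#         init_attention_mask(batch, eos_id=tokenizer.eos_id)
#         ...
#         out = attn(q, k, v)  # q/k/v come from the model's attention projections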