zaydzuhri commited on
Commit
0491866
·
verified ·
1 Parent(s): f80ddb2

Add files using upload-large-folder tool

Browse files
Files changed (50)
  1. fla/ops/based/naive.py +72 -0
  2. fla/ops/lightning_attn/chunk.py +74 -0
  3. fla/ops/linear_attn/chunk.py +65 -0
  4. fla/ops/nsa/naive.py +94 -0
  5. fla/ops/nsa/parallel.py +1435 -0
  6. fla/ops/retention/naive.py +15 -0
  7. profile_trace/iteration_11264/rank6_trace.json +0 -0
  8. profile_trace/iteration_1536/rank2_trace.json +0 -0
  9. profile_trace/iteration_18432/rank0_trace.json +0 -0
  10. profile_trace/iteration_23552/rank0_trace.json +0 -0
  11. profile_trace/iteration_23552/rank1_trace.json +0 -0
  12. profile_trace/iteration_23552/rank4_trace.json +0 -0
  13. profile_trace/iteration_23552/rank6_trace.json +0 -0
  14. profile_trace/iteration_2560/rank7_trace.json +0 -0
  15. profile_trace/iteration_27648/rank0_trace.json +0 -0
  16. profile_trace/iteration_27648/rank3_trace.json +0 -0
  17. profile_trace/iteration_27648/rank5_trace.json +0 -0
  18. profile_trace/iteration_27648/rank7_trace.json +0 -0
  19. profile_trace/iteration_29696/rank0_trace.json +0 -0
  20. profile_trace/iteration_29696/rank1_trace.json +0 -0
  21. profile_trace/iteration_29696/rank4_trace.json +0 -0
  22. profile_trace/iteration_30720/rank0_trace.json +0 -0
  23. profile_trace/iteration_30720/rank1_trace.json +0 -0
  24. profile_trace/iteration_30720/rank6_trace.json +0 -0
  25. profile_trace/iteration_30720/rank7_trace.json +0 -0
  26. profile_trace/iteration_31744/rank1_trace.json +0 -0
  27. profile_trace/iteration_36864/rank0_trace.json +0 -0
  28. profile_trace/iteration_36864/rank3_trace.json +0 -0
  29. profile_trace/iteration_37888/rank0_trace.json +0 -0
  30. profile_trace/iteration_37888/rank5_trace.json +0 -0
  31. torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc +0 -0
  32. torchtitan/components/__pycache__/optimizer.cpython-312.pyc +0 -0
  33. torchtitan/experiments/__pycache__/__init__.cpython-312.pyc +0 -0
  34. torchtitan/experiments/deepseek_v3/checkpoint.py +154 -0
  35. torchtitan/experiments/deepseek_v3/download.py +70 -0
  36. torchtitan/experiments/deepseek_v3/model.py +1325 -0
  37. torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +63 -0
  38. torchtitan/experiments/flux/model/model.py +177 -0
  39. torchtitan/experiments/flux/parallelize_flux.py +26 -0
  40. torchtitan/experiments/flux/tests/test_generate_image.py +252 -0
  41. torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py +885 -0
  42. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py +13 -0
  43. torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py +1304 -0
  44. torchtitan/experiments/llama4/__init__.py +70 -0
  45. torchtitan/experiments/llama4/infra/parallelize_llama.py +159 -0
  46. torchtitan/experiments/llama4/model/moe.py +228 -0
  47. torchtitan/experiments/llama4/train_configs/debug_model.toml +74 -0
  48. torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +65 -0
  49. torchtitan/models/__pycache__/attention.cpython-312.pyc +0 -0
  50. torchtitan/models/llama3/train_configs/llama3_405b.toml +63 -0
fla/ops/based/naive.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ from einops import rearrange
7
+
8
+
9
+ def naive_parallel_based(
10
+ q: torch.Tensor,
11
+ k: torch.Tensor,
12
+ v: torch.Tensor,
13
+ scale: Optional[float] = None,
14
+ use_norm: bool = True
15
+ ):
16
+ if scale is None:
17
+ scale = q.shape[-1] ** -0.5
18
+ q = q * scale
19
+ attn = q @ k.transpose(-2, -1)
20
+ attn = 1 + attn + 1/2 * (attn ** 2)
21
+ attn.masked_fill_(~torch.tril(torch.ones(
22
+ q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0)
23
+ o = attn @ v
24
+ if use_norm:
25
+ z = attn.sum(-1)
26
+ return o / (z[..., None] + 1e-6)
27
+ else:
28
+ return o
29
+
30
+
31
def naive_chunk_based(q, k, v, chunk_size=256):
    """Chunk-parallel reference implementation of Based attention.

    Uses the 2nd-order Taylor feature map phi(s) = 1 + s + s^2/2 and splits the
    computation into intra-chunk (quadratic) and inter-chunk (recurrent cumsum)
    parts, accumulated separately for the zero-th, first and second order terms.

    NOTE(review): assumes the sequence length (dim -2) is divisible by
    `chunk_size`; the `rearrange` calls below would fail otherwise.
    """
    q = q * (q.shape[-1] ** -0.5)
    # compute normalizer z = sum of phi(q.k) over the causal prefix
    k_cumsum = torch.cumsum(k, dim=-2)
    kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3)
    # first order term: q . cumsum(k)
    z = (q * k_cumsum).sum(-1)
    # second order term: 0.5 * (q x q) . cumsum(k x k)
    z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5
    # zero-th order term: one count per prefix position (t + 1)
    z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :]

    # compute o
    # constant (zero-th order) term of the output: running sum of values
    _o = v.cumsum(-2)

    q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size)

    k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
    v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)

    # intra-chunk: causal quadratic attention within each chunk
    # (the constant term is already covered by `_o` above)
    intra_chunk_attn = q @ k.transpose(-2, -1)
    intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2)
    intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0)
    o = intra_chunk_attn @ v

    # inter-chunk quadratic term: second-order state k x k x v from all
    # previous chunks (shifted by one chunk via the cat below)
    kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v)
    kv = kv.cumsum(2)
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)

    o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q)

    # inter-chunk linear term: first-order state k x v from previous chunks
    kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v)
    kv = kv.cumsum(2)
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
    o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q)

    o = rearrange(o, 'b h n c d -> b h (n c) d')
    o = o + _o
    # final normalization, matching naive_parallel_based
    return o / (z[..., None] + 1e-6)
fla/ops/lightning_attn/chunk.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.simple_gla.chunk import chunk_simple_gla
9
+
10
+
11
@torch.compiler.disable
def chunk_lightning_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_idx: int,
    num_layers: int,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Chunked Lightning attention, implemented by delegating to `chunk_simple_gla`
    with a per-head constant log-decay derived from the layer position.

    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        layer_idx (int):
            The index of the current layer.
        num_layers (int):
            The total number of layers. Both `layer_idx` and `num_layers` are
            used to compute the decay factor.
        scale (Optional[int]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length
            training, consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not
            supported for variable-length inputs. Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    """
    dim0, dim1, dim2 = q.shape[0], q.shape[1], q.shape[2]
    num_heads = dim1 if head_first else dim2
    # per-head log-decay slope: larger head index -> faster decay; depth-dependent
    # factor (1 - layer_idx / num_layers) weakens decay in deeper layers
    decay = -(8 / num_heads * (1 - layer_idx / num_layers))
    slopes = decay * q.new_tensor(range(num_heads), dtype=torch.float)
    # broadcast the per-head slope over batch and time to form g of shape [B, *, *]
    view_shape = (1, num_heads, 1) if head_first else (1, 1, num_heads)
    g = slopes.view(view_shape).expand(dim0, dim1, dim2).contiguous()
    return chunk_simple_gla(
        q=q,
        k=k,
        v=v,
        scale=scale,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        head_first=head_first,
        cu_seqlens=cu_seqlens
    )
fla/ops/linear_attn/chunk.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Yu Zhang, Songlin Yang
3
+
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from fla.ops.linear_attn.utils import normalize_output
9
+ from fla.ops.simple_gla import chunk_simple_gla
10
+
11
+
12
@torch.compiler.disable
def chunk_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    normalize: bool = True,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Chunked linear attention, implemented as `chunk_simple_gla` without gating
    (`g=None`), optionally followed by output normalization.

    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
        scale (Optional[int]):
            Scale factor for the linear attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[B, H, K, V]`. Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[B, H, K, V]`. Default: `False`.
        normalize (bool):
            Whether to normalize the output. Default: `True`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`
        final_state (torch.Tensor):
            Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`
    """
    # resolve the default scale once so the same value is used for both the
    # attention computation and the normalization below
    effective_scale = k.shape[-1] ** -0.5 if scale is None else scale

    o, final_state = chunk_simple_gla(
        q=q,
        k=k,
        v=v,
        scale=effective_scale,
        g=None,
        initial_state=initial_state,
        output_final_state=output_final_state,
        head_first=head_first
    )
    if normalize:
        o = normalize_output(q * effective_scale, k, o)
    return o, final_state
fla/ops/nsa/naive.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from einops import rearrange, repeat
8
+
9
+
10
def naive_nsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.LongTensor,
    block_size: int = 64,
    scale: Optional[float] = None,
    head_first: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    r"""
    Naive (loop-based) reference implementation of NSA selected-block attention.

    Args:
        q (torch.Tensor):
            queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        indices (torch.LongTensor):
            Block indices of shape `[B, H, T, S]` if `head_first=True` else `[B, T, H, S]`.
            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
        block_size (int):
            Selected block size. Default: 64.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, HQ, T, V]` if `head_first=True` else `[B, T, HQ, V]`.
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if cu_seqlens is not None:
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
    if head_first:
        # normalize everything to the [B, T, H, D] layout internally
        q, k, v, indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, indices))

    dtype = q.dtype
    # GQA group size: number of query heads sharing one KV head
    G = q.shape[2] // k.shape[2]
    BS = block_size
    # broadcast KV heads (and their selected-block indices) to every query head
    k, v, indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, indices))
    # compute in fp32 for stability; cast back to the input dtype on return
    q, k, v = map(lambda x: x.float(), (q, k, v))

    o = torch.zeros_like(v)
    varlen = True
    if cu_seqlens is None:
        varlen = False
        B, T = q.shape[:2]
        # synthesize per-sequence offsets so both paths share the loop below
        cu_seqlens = torch.cat([indices.new_tensor(range(0, B*T, T)), indices.new_tensor([B*T])])

    for i in range(len(cu_seqlens) - 1):
        if not varlen:
            # fixed-length mode: i indexes the batch dimension directly
            q_b, k_b, v_b, i_b = q[i], k[i], v[i], indices[i]
        else:
            # varlen mode: all sequences are packed at batch index 0; slice sequence i
            T = cu_seqlens[i+1] - cu_seqlens[i]
            q_b, k_b, v_b, i_b = map(lambda x: x[0][cu_seqlens[i]:cu_seqlens[i+1]], (q, k, v, indices))

        # expand each selected block index into its BS token positions
        i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
        # [T, S*BS, HQ]
        i_b = i_b.view(T, indices.shape[2], -1).transpose(1, 2)
        for i_q in range(T):
            # [HQ, D]
            q_i = q_b[i_q] * scale
            # [S*BS, HQ]
            i_i = i_b[i_q]
            # [S*BS, HQ, -1]; clamp keeps gather in-bounds — out-of-range slots are
            # neutralized by the causal mask below (their i_i exceeds i_q)
            k_i, v_i = map(lambda x: x.gather(0, i_i.clamp(0, T-1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
            # [S*BS, HQ]; causal mask: only gathered positions <= i_q may attend
            attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill(i_i > i_q, float('-inf')).softmax(0)
            if not varlen:
                o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
            else:
                o[0][cu_seqlens[i]+i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)

    if head_first:
        o = rearrange(o, 'b t h d -> b h t d')
    return o.to(dtype)
fla/ops/nsa/parallel.py ADDED
@@ -0,0 +1,1435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
3
+
4
+ import warnings
5
+ from typing import Optional, Union
6
+
7
+ import torch
8
+ import triton
9
+ import triton.language as tl
10
+ from einops import rearrange
11
+
12
+ from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_lens, prepare_token_indices
13
+ from fla.ops.nsa.utils import _bitonic_merge
14
+ from fla.ops.utils import mean_pooling
15
+ from fla.ops.utils.op import exp, log
16
+ from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous
17
+
18
+ try:
19
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
20
+ except ImportError:
21
+ warnings.warn(
22
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
23
+ category=ImportWarning
24
+ )
25
+ flash_attn_func = None
26
+
27
+
28
# Forward kernel for the compression branch of NSA. Each program handles one
# query token `i_t` and one group of G query heads sharing KV head `i_h`,
# running an online-softmax pass over the compressed (mean-pooled) KV
# representations and writing both the output slice and the per-head lse.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_compression_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    offsets,  # varlen cumulative sequence offsets, or None
    token_indices,  # (seq, position) pairs per grid step in varlen mode
    chunk_offsets,  # per-sequence start offsets into the compressed KV
    T,
    H: tl.constexpr,  # number of KV heads
    HQ: tl.constexpr,  # number of query heads
    G: tl.constexpr,  # GQA group size (query heads per KV head)
    K: tl.constexpr,
    V: tl.constexpr,
    BC: tl.constexpr,  # compressed blocks processed per loop iteration
    BS: tl.constexpr,  # compression block size
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # varlen: remap the flat grid index to (sequence, in-sequence position)
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)
    # the number of compression representations required to iterate over
    # incomplete compression blocks are not included
    NC = (i_t + 1) // BS

    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    # [G, BV]
    b_o = tl.zeros([G, BV], dtype=tl.float32)
    # max scores for the current block
    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
    # lse = log(acc) + m
    b_acc = tl.zeros([G], dtype=tl.float32)

    # online-softmax accumulation over compressed KV blocks
    for i_c in range(0, NC, BC):
        o_c = i_c + tl.arange(0, BC)

        p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
        p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c, i_v * BV), (BC, BV), (1, 0))
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BC, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [G, BC]
        b_s = tl.dot(b_q, b_k)
        # mask out representations beyond the causal horizon
        b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf'))

        # [G] — running max update; b_mp keeps the previous max for rescaling
        b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
        b_r = exp(b_mp - b_m)
        # [G, BC]
        b_p = exp(b_s - b_m[:, None])
        # [G]
        b_acc = b_acc * b_r + tl.sum(b_p, 1)

        # [G, BV]
        b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

        b_mp = b_m
    if NC == 0:
        # no complete compressed block is visible yet: output stays zero, lse = 0
        b_lse = tl.zeros([G], dtype=tl.float32)
    else:
        b_o = b_o / b_acc[:, None]
        b_lse = b_m + log(b_acc)

    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    # lse is shared across value splits; only the first split writes it
    if i_v == 0:
        tl.store(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G), b_lse.to(lse.dtype.element_ty))
128
+
129
+
130
# Backward kernel (dq) for the NSA compression branch. One program per query
# token and head group; recomputes the softmax from the saved lse and
# accumulates the query gradient over all visible compressed KV blocks.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_compression_bwd_kernel_dq(
    q,
    k,
    v,
    lse,  # saved log-sum-exp from the forward pass
    delta,  # rowwise sum of do * o (softmax backward correction term)
    do,
    dq,
    scale,
    offsets,
    token_indices,
    chunk_offsets,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # varlen: remap the flat grid index to (sequence, in-sequence position)
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    # advance the base pointers to this token's row
    q += (bos + i_t) * HQ*K
    do += (bos + i_t) * HQ*V
    lse += (bos + i_t) * HQ
    delta += (bos + i_t) * HQ
    # dq is laid out with a leading value-split axis (i_v), reduced by the host
    dq += (i_v * B * T + bos + i_t) * HQ*K

    p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + i_h * G + tl.arange(0, G)
    p_delta = delta + i_h * G + tl.arange(0, G)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)
    # the number of compression representations required to iterate over
    # incomplete compression blocks are not included
    NC = (i_t + 1) // BS

    # [G, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [G]
    b_lse = tl.load(p_lse)
    b_delta = tl.load(p_delta)

    # [G, BK]
    b_dq = tl.zeros([G, BK], dtype=tl.float32)
    for i_c in range(0, NC, BC):
        o_c = i_c + tl.arange(0, BC)
        p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
        p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (V, TC), (1, H*V), (i_v * BV, i_c), (BV, BC), (0, 1))
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BC]
        b_v = tl.load(p_v, boundary_check=(0, 1))

        # [G, BC] — recompute probabilities from the saved lse
        b_s = tl.dot(b_q, b_k)
        b_p = exp(b_s - b_lse[:, None])
        b_p = tl.where((o_c < NC)[None, :], b_p, 0)

        # [G, BV] @ [BV, BC] -> [G, BC]
        b_dp = tl.dot(b_do, b_v)
        # softmax backward: ds = p * (dp - delta)
        b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
        # [G, BC] @ [BC, BK] -> [G, BK]
        b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
    # apply the forward scaling factor once at the end
    b_dq *= scale
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
230
+
231
+
232
# Backward kernel (dk/dv) for the NSA compression branch. One program per
# tile of BC compressed-KV rows; iterates over every query token that can see
# this tile and accumulates the key and value gradients.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_compression_bwd_kernel_dkv(
    q,
    k,
    v,
    lse,  # saved log-sum-exp from the forward pass
    delta,  # rowwise sum of do * o
    do,
    dk,
    dv,
    offsets,
    chunk_indices,  # (sequence, chunk) pairs per grid step in varlen mode
    chunk_offsets,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BC: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # varlen: remap the flat grid index to (sequence, chunk) coordinates
        i_n, i_c = tl.load(chunk_indices + i_c * 2).to(tl.int32), tl.load(chunk_indices + i_c * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)

    p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
    p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))
    # dk/dv carry a leading value-split axis (i_v), reduced by the host
    p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0))

    # [BC, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BC, BK], dtype=tl.float32)
    # [BC, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BC, BV], dtype=tl.float32)

    # iterate over every query token that can attend to this tile; the first
    # token of this tile's earliest complete block is at i_c * BC * BS
    for i in range(i_c * BC * BS, T):
        o_c = i_c * BC + tl.arange(0, BC)

        p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
        # [G, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)

        p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
        p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G)
        p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G)
        # [G, BV]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [G]
        b_lse = tl.load(p_lse)
        b_delta = tl.load(p_delta)
        # [BC, G] — recompute probabilities from the saved lse
        b_s = tl.dot(b_k, tl.trans(b_q))
        b_p = exp(b_s - b_lse[None, :])
        # token i only sees compressed block o_c once the block is complete,
        # i.e. i >= (o_c + 1) * BS - 1
        b_p = tl.where((i >= max(0, (o_c + 1) * BS - 1))[:, None], b_p, 0)
        # [BC, G] @ [G, BV] -> [BC, BV]
        b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
        # [BC, BV] @ [BV, G] -> [BC, G]
        b_dp = tl.dot(b_v, tl.trans(b_do))
        # [BC, G] — softmax backward: ds = p * (dp - delta)
        b_ds = b_p * (b_dp - b_delta[None, :])
        # [BC, G] @ [G, BK] -> [BC, BK]
        b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
327
+
328
+
329
# Top-k block selection kernel for NSA. For each query token it scores all
# compressed KV blocks, converts scores to importance weights via the
# (possibly precomputed) lse, and keeps the top `S` block indices using an
# in-register bitonic merge network. Stored indices are 0-based block ids
# (the running o_i values are offset by +1 so that 0 can mean "empty").
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK'],
)
@triton.jit
def parallel_nsa_kernel_topk(
    q,
    k,  # compressed keys
    lse,  # precomputed log-sum-exp, or None to compute it here
    scale,
    block_indices,  # output: [*, H, S] selected block ids per token
    offsets,
    token_indices,
    chunk_offsets,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,  # GQA group size
    K: tl.constexpr,
    S: tl.constexpr,  # number of blocks to select
    BC: tl.constexpr,  # candidate blocks held per iteration (sort width)
    BS: tl.constexpr,  # compression block size
    BK: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # varlen: remap the flat grid index to (sequence, in-sequence position)
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        boc = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
        boc = i_b * tl.cdiv(T, BS)

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    # the number of compression representations in total
    TC = tl.cdiv(T, BS)
    # the number of compression representations required to iterate over
    # incomplete compression blocks are not included
    NC = (i_t + 1) // BS
    ################################
    # 1. lse computation
    ################################
    if lse is not None:
        # reuse the lse produced by the compression forward kernel
        b_lse = tl.load(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G))
    else:
        # recompute the lse with a standard online-softmax sweep
        # max scores for the current block
        b_m = tl.full([G], float('-inf'), dtype=tl.float32)
        # lse = log(acc) + m
        b_acc = tl.zeros([G], dtype=tl.float32)
        for i_c in range(0, NC, BC):
            o_c = i_c + tl.arange(0, BC)

            p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
            # [BK, BC]
            b_k = tl.load(p_k, boundary_check=(0, 1))

            # [G, BC]
            b_s = tl.dot(b_q, b_k)
            b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf'))

            # [G]
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            b_r = exp(b_mp - b_m)
            # [G, BC]
            b_p = exp(b_s - b_m[:, None])
            # [G]
            b_acc = b_acc * b_r + tl.sum(b_p, 1)

            b_mp = b_m
        if NC == 0:
            b_lse = tl.zeros([G], dtype=tl.float32)
        else:
            b_lse = b_m + log(b_acc)

    ################################
    # 2. topk selection
    ################################
    # running top-BC candidates: b_i importance scores, o_i block ids + 1
    # [BC]
    b_i = tl.full([BC], -1, dtype=tl.float32)
    o_i = tl.zeros([BC], dtype=tl.int32)
    # mask selecting the upper half of the sort window when merging batches
    m_i = tl.arange(0, BC) < BC//2
    for i_c in range(0, i_t // BS + 1, BC):
        o_c = i_c + tl.arange(0, BC)

        p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1))
        # [BK, BC]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [G, BC]
        b_s = tl.dot(b_q, b_k)
        b_s = tl.where((i_t // BS > o_c)[None, :], b_s, float('-inf'))
        # [G, BC] — the block containing the query itself always gets weight 1
        b_p = tl.where((i_t // BS == o_c)[None, :], float(1.0), exp(b_s - b_lse[:, None]))
        # the importance scores of the current block
        # [BC]
        b_i, b_ip = tl.sum(b_p, 0), b_i
        # candidate ids stored as block id + 1 (0 marks an invalid slot)
        o_i, o_ip = tl.where(o_c <= i_t // BS, o_c + 1, 0), o_i

        # bitonic sort of the fresh candidates by score
        n_dims: tl.constexpr = tl.standard._log2(b_i.shape[0])
        for i in tl.static_range(1, n_dims):
            b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), i, 2, n_dims)

        if i_c != 0:
            # merge with the previously kept candidates: take the better half
            # of each sorted list, then re-sort the combined window
            b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, False, n_dims)
            b_i_new = b_ip * m_i + b_i * (1 - m_i)
            o_i_new = o_ip * m_i + o_i * (1 - m_i)
            b_i, o_i = _bitonic_merge(b_i_new, o_i_new.to(tl.int32), n_dims, True, n_dims)
        else:
            b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, True, n_dims)

    # keep the first S entries (highest scores) and undo the +1 id offset
    m_top = tl.arange(0, BC//S) == 0
    b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [BC//S, S]), 0)

    p_b = tl.make_block_ptr(block_indices + (bos + i_t) * H*S, (H*S,), (1,), (i_h * S,), (S,), (0,))
    tl.store(p_b, b_top.to(p_b.dtype.element_ty))
459
+
460
+
461
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(
    q,
    k,
    v,
    o,
    lse,
    scale,
    block_indices,
    block_counts,
    offsets,
    token_indices,
    T,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    """Forward kernel for NSA selected attention.

    One program handles one query position (i_t), one value tile (i_v) and one
    (batch, kv-head) pair (i_bh).  It iterates over the (at most S) selected
    key/value blocks listed in `block_indices`, accumulating the attention
    output with an online (streaming) softmax, and stores both the output `o`
    and the log-sum-exp `lse` (used by topk selection and the backward pass).
    """
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        # variable-length mode: recover (sequence id, in-sequence position)
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V
    block_indices += (bos + i_t) * H*S + i_h * S

    # NS: number of selected blocks for this query (per-token when a tensor of
    # counts was passed, otherwise the fixed S)
    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos + i_t) * H + i_h)
    else:
        NS = S

    p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    # the Q block is kept in the shared memory throughout the whole kernel
    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
    # [G, BV]
    b_o = tl.zeros([G, BV], dtype=tl.float32)

    # online-softmax state: running max b_m and running normalizer b_acc
    b_m = tl.full([G], float('-inf'), dtype=tl.float32)
    b_acc = tl.zeros([G], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        # skip blocks that are entirely in the future or marked invalid (< 0)
        if i_s <= i_t and i_s >= 0:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
            p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
            # [BK, BS]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BS, BV]
            b_v = tl.load(p_v, boundary_check=(0, 1))
            # [G, BS]
            b_s = tl.dot(b_q, b_k)
            # causal mask within the block
            b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf'))

            # [G]
            b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m
            # rescale factor for previously accumulated state
            b_r = exp(b_mp - b_m)
            # [G, BS]
            b_p = exp(b_s - b_m[:, None])
            # [G]
            b_acc = b_acc * b_r + tl.sum(b_p, 1)
            # [G, BV]
            b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v)

            b_mp = b_m
    b_o = b_o / b_acc[:, None]
    # lse = m + log(sum exp(s - m))
    b_m += log(b_acc)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_lse, b_m.to(p_lse.dtype.element_ty))
557
+
558
+
559
@triton.heuristics({
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
})
@triton.jit
def parallel_nsa_kernel_mask(
    block_indices,
    block_counts,
    block_mask,
    T: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    NS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    """Scatter the per-token selected block indices into a dense boolean mask.

    One program handles a single (token i_t, batch i_b, head/slot i_hs) entry of
    `block_indices` [B, T, H, S] and marks position `b_i` of
    `block_mask` [B, T, H, NS] when that block is a valid (causal, within-count)
    selection.  The dense mask is consumed by the dkv backward kernel.
    """
    i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h, i_s = i_hs // S, i_hs % S

    b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s)
    if USE_BLOCK_COUNTS:
        # valid only if the block starts at/before the query and this slot is
        # within the per-token selection count
        b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h)
    else:
        b_m = b_i * BS <= i_t

    # out-of-range indices (padding / invalid selections) are dropped silently
    if b_i < NS and b_i >= 0:
        tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty))
585
+
586
+
587
@triton.jit
def parallel_nsa_bwd_kernel_preprocess(
    o,
    do,
    delta,
    B: tl.constexpr,
    V: tl.constexpr
):
    """Compute delta[i] = sum_v(o[i, v] * do[i, v]) for one flattened row.

    This is the standard flash-attention backward preprocessing term; B is V
    rounded up to the next power of two so a single masked load covers the row.
    """
    i_n = tl.program_id(0)
    o_d = tl.arange(0, B)
    m_d = o_d < V

    b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
    # accumulate the dot product in fp32 regardless of storage dtype
    b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
    b_delta = tl.sum(b_o * b_do)

    tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
604
+
605
+
606
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
    'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_bwd_kernel_dq(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dq,
    scale,
    block_indices,
    block_counts,
    offsets,
    token_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    S: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    USE_BLOCK_COUNTS: tl.constexpr
):
    """Backward kernel computing dq for NSA selected attention.

    One program handles one query position, one value tile and one
    (batch, kv-head) pair; it revisits exactly the key/value blocks selected in
    the forward pass and accumulates dq = scale * sum(ds @ k) using the stored
    `lse` and the preprocessed `delta` = rowsum(o * do).  Partial dq per value
    tile are written to a [NV, ...] buffer and summed on the host side.
    """
    i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    q += (bos + i_t) * HQ*K
    do += (bos + i_t) * HQ*V
    lse += (bos + i_t) * HQ
    delta += (bos + i_t) * HQ
    # dq has a leading NV dimension so each value tile writes a private slice
    dq += (i_v * B * T + bos + i_t) * HQ*K
    block_indices += (bos + i_t) * H*S + i_h * S

    if USE_BLOCK_COUNTS:
        NS = tl.load(block_counts + (bos + i_t) * H + i_h)
    else:
        NS = S

    k += (bos * H + i_h) * K
    v += (bos * H + i_h) * V

    p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
    p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))

    # [G, BK]
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = (b_q * scale).to(b_q.dtype)

    p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
    p_lse = lse + i_h * G + tl.arange(0, G)
    p_delta = delta + i_h * G + tl.arange(0, G)

    # [G, BV]
    b_do = tl.load(p_do, boundary_check=(0, 1))
    # [G]
    b_lse = tl.load(p_lse)
    b_delta = tl.load(p_delta)

    # [G, BK]
    b_dq = tl.zeros([G, BK], dtype=tl.float32)
    for i in range(NS):
        i_s = tl.load(block_indices + i).to(tl.int32) * BS
        # same validity condition as the forward kernel
        if i_s <= i_t and i_s >= 0:
            p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1))
            p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1))
            # [BK, BS]
            b_k = tl.load(p_k, boundary_check=(0, 1))
            # [BV, BS]
            b_v = tl.load(p_v, boundary_check=(0, 1))

            # [G, BS]; probabilities recomputed from the stored lse
            b_s = tl.dot(b_q, b_k)
            b_p = exp(b_s - b_lse[:, None])
            b_p = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p, 0)

            # [G, BV] @ [BV, BS] -> [G, BS]
            b_dp = tl.dot(b_do, b_v)
            b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None])
            # [G, BS] @ [BS, BK] -> [G, BK]
            b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k))
    b_dq *= scale

    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
712
+
713
+
714
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4]
    ],
    key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_bwd_kernel_dkv(
    q,
    k,
    v,
    lse,
    delta,
    do,
    dk,
    dv,
    block_mask,
    offsets,
    chunk_indices,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HQ: tl.constexpr,
    G: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    M: tl.constexpr,
    BS: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    """Backward kernel computing dk/dv for one selected key/value block.

    One program owns key block i_s (BS consecutive keys) of one
    (batch, kv-head) pair and scans all query positions i >= i_s * BS.
    `block_mask` [B, T, H, M] (M = number of key blocks) tells whether query i
    actually selected this block; unselected queries are skipped entirely.
    """
    i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if USE_OFFSETS:
        i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0))
    p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0))
    # dk carries a leading NV dimension (one private slice per value tile)
    p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + bos * H + i_h) * K, (T, K), (H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0))
    p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0))

    # [BS, BK]
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_dk = tl.zeros([BS, BK], dtype=tl.float32)
    # [BS, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_dv = tl.zeros([BS, BV], dtype=tl.float32)

    # only queries at/after the start of this key block can attend to it
    for i in range(i_s * BS, T):
        b_m = tl.load(block_mask + (bos + i) * H*M + i_h * M + i_s)
        if b_m:
            p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
            # [G, BK]
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_q = (b_q * scale).to(b_q.dtype)

            p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
            p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G)
            p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G)
            # [G, BV]
            b_do = tl.load(p_do, boundary_check=(0, 1))
            # [G]
            b_lse = tl.load(p_lse)
            b_delta = tl.load(p_delta)
            # [BS, G]; probabilities recomputed from the stored lse
            b_s = tl.dot(b_k, tl.trans(b_q))
            b_p = exp(b_s - b_lse[None, :])
            # causal mask: key position must not exceed query position i
            b_p = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p, 0)
            # [BS, G] @ [G, BV] -> [BS, BV]
            b_dv += tl.dot(b_p.to(b_do.dtype), b_do)
            # [BS, BV] @ [BV, G] -> [BS, G]
            b_dp = tl.dot(b_v, tl.trans(b_do))
            # [BS, G]
            b_ds = b_p * (b_dp - b_delta[None, :])
            # [BS, G] @ [G, BK] -> [BS, BK]
            b_dk += tl.dot(b_ds.to(b_q.dtype), b_q)

    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
804
+
805
+
806
def parallel_nsa_compression_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_size: int,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Launch the compressed-attention forward kernel.

    Args:
        q: queries [B, T, HQ, K].
        k: compressed keys [B, TC, H, K] (one representation per block).
        v: compressed values [B, TC, H, V].
        block_size: compression block size (BS; also used as the kernel's
            iteration tile BC).
        scale: softmax scaling factor applied to q.
        offsets: cumulative sequence lengths for variable-length mode.
        token_indices: per-token (sequence, position) pairs; required iff
            `offsets` is given.

    Returns:
        (o, lse): attention output [B, T, HQ, V] and log-sum-exp [B, T, HQ].
    """
    B, T, HQ, K, V = *q.shape, v.shape[-1]
    H = k.shape[2]
    G = HQ // H
    BC = BS = block_size
    # larger tiles are affordable on GPUs with Hopper-class shared memory
    if check_shared_mem('hopper', q.device.index):
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    else:
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, "The key dimension can not be larger than 256"

    chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None

    grid = (T, NV, B * H)
    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)

    parallel_nsa_compression_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o=o,
        lse=lse,
        scale=scale,
        offsets=offsets,
        token_indices=token_indices,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BC=BC,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse
857
+
858
+
859
def parallel_nsa_compression_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    block_size: int = 64,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Backward pass of compressed attention: returns (dq, dk, dv).

    Runs the dq kernel over all tokens, then the dkv kernel over compressed
    chunks.  dq/dk get a leading NV dimension (one slice per value tile) that
    is reduced with `.sum(0)` before returning.
    """
    B, T, HQ, K, V = *q.shape, v.shape[-1]
    H = k.shape[2]
    G = HQ // H
    BC = BS = block_size
    BK = triton.next_power_of_2(K)
    BV = min(128, triton.next_power_of_2(v.shape[-1]))
    NV = triton.cdiv(V, BV)
    if offsets is not None:
        # variable-length mode: enumerate (sequence id, chunk id) pairs so the
        # dkv grid can cover a ragged batch
        lens = prepare_lens(offsets)
        chunk_indices = torch.cat([torch.arange(n) for n in triton.cdiv(triton.cdiv(lens, BS), BC).tolist()])
        chunk_indices = torch.stack([chunk_indices.eq(0).cumsum(0) - 1, chunk_indices], 1).to(offsets)
        chunk_offsets = prepare_chunk_offsets(offsets, BS)
        NC = len(chunk_indices)
    else:
        chunk_indices, chunk_offsets = None, None
        NC = triton.cdiv(triton.cdiv(T, BS), BC)

    # delta = rowsum(o * do), shared by both backward kernels
    delta = parallel_nsa_bwd_preprocess(o, do)

    dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
    grid = (T, NV, B * H)
    parallel_nsa_compression_bwd_kernel_dq[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dq=dq,
        scale=scale,
        offsets=offsets,
        token_indices=token_indices,
        chunk_offsets=chunk_offsets,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BC=BC,
        BS=BS,
        BK=BK,
        BV=BV
    )
    dq = dq.sum(0)

    dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
    dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)

    grid = (NV, NC, B * H)
    parallel_nsa_compression_bwd_kernel_dkv[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dk=dk,
        dv=dv,
        offsets=offsets,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        BC=BC,
        BS=BS,
        BK=BK,
        BV=BV
    )
    dk = dk.sum(0)
    return dq, dk, dv
949
+
950
+
951
class ParallelNSACompressionFunction(torch.autograd.Function):
    """Autograd wrapper around the compressed-attention forward/backward kernels.

    `forward` returns both the attention output and its log-sum-exp (the latter
    is consumed by topk block selection; no gradient flows through it).
    """

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(
        ctx,
        q,
        k,
        v,
        block_size,
        scale,
        offsets
    ):
        ctx.dtype = q.dtype

        # 2-d sequence indices denoting the offsets of tokens in each sequence
        # for example, if the passed `offsets` is [0, 2, 6],
        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        token_indices = prepare_token_indices(offsets) if offsets is not None else None

        o, lse = parallel_nsa_compression_fwd(
            q=q,
            k=k,
            v=v,
            block_size=block_size,
            scale=scale,
            offsets=offsets,
            token_indices=token_indices
        )
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.offsets = offsets
        ctx.token_indices = token_indices
        ctx.block_size = block_size
        ctx.scale = scale
        return o.to(q.dtype), lse

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do, *args):
        # *args swallows the (unused) gradient w.r.t. the lse output
        q, k, v, o, lse = ctx.saved_tensors
        dq, dk, dv = parallel_nsa_compression_bwd(
            q=q,
            k=k,
            v=v,
            o=o,
            lse=lse,
            do=do,
            block_size=ctx.block_size,
            scale=ctx.scale,
            offsets=ctx.offsets,
            token_indices=ctx.token_indices
        )
        # one gradient per forward input: q, k, v, block_size, scale, offsets
        return dq.to(q), dk.to(k), dv.to(v), None, None, None
1007
+
1008
+
1009
def parallel_nsa_topk(
    q: torch.Tensor,
    k: torch.Tensor,
    lse: torch.Tensor,
    block_counts: Union[torch.LongTensor, int],
    block_size: int = 64,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None,
) -> torch.LongTensor:
    """Select the top-S key blocks per query from compression importance scores.

    Args:
        q: queries [B, T, HQ, K].
        k: compressed keys [B, TC, H, K].
        lse: log-sum-exp of the compressed attention [B, T, HQ]; may be None,
            in which case the kernel recomputes it.
        block_counts: number of blocks to select per query (int or [B, T, H]).
        block_size: compression block size.
        scale: softmax scaling factor.
        offsets: cumulative sequence lengths for variable-length mode.

    Returns:
        block_indices [B, T, H, S] (int32), S = next power of two of the
        maximum block count.
    """
    B, T, HQ, K = q.shape
    H = k.shape[2]
    G = HQ // H
    S = block_counts if isinstance(block_counts, int) else block_counts.max().item()
    # the kernel's bitonic sort requires a power-of-two slot count
    S = triton.next_power_of_2(S)
    # here we set BC = BS, but beware that they are actually decoupled
    BC = BS = block_size
    BK = triton.next_power_of_2(K)

    block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=q.device)
    token_indices = prepare_token_indices(offsets) if offsets is not None else None
    chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None
    grid = (T, B * H)
    parallel_nsa_kernel_topk[grid](
        q=q,
        k=k,
        lse=lse,
        scale=scale,
        block_indices=block_indices,
        offsets=offsets,
        token_indices=token_indices,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        S=S,
        BC=BC,
        BS=BS,
        BK=BK
    )
    return block_indices
1051
+
1052
+
1053
def parallel_nsa_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_indices: torch.LongTensor,
    block_counts: Union[torch.LongTensor, int],
    block_size: int,
    scale: float,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Launch the selected-attention forward kernel.

    Args:
        q: queries [B, T, HQ, K].
        k: keys [B, T, H, K]; v: values [B, T, H, V].
        block_indices: selected block ids [B, T, H, S].
        block_counts: per-query number of valid selections (int or tensor).
        block_size: size of each selected key/value block.
        scale: softmax scaling factor.
        offsets / token_indices: variable-length bookkeeping (both or neither).

    Returns:
        (o, lse): output [B, T, HQ, V] and log-sum-exp [B, T, HQ] (fp32).
    """
    B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BS = block_size
    # larger tiles are affordable on GPUs with Hopper-class shared memory
    if check_shared_mem('hopper', q.device.index):
        BK = min(256, triton.next_power_of_2(K))
        BV = min(256, triton.next_power_of_2(V))
    else:
        BK = min(128, triton.next_power_of_2(K))
        BV = min(128, triton.next_power_of_2(V))
    NK = triton.cdiv(K, BK)
    NV = triton.cdiv(V, BV)
    assert NK == 1, "The key dimension can not be larger than 256"

    grid = (T, NV, B * H)
    o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
    lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)

    parallel_nsa_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        o=o,
        lse=lse,
        scale=scale,
        block_indices=block_indices,
        block_counts=block_counts,
        offsets=offsets,
        token_indices=token_indices,
        T=T,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        S=S,
        BS=BS,
        BK=BK,
        BV=BV,
    )
    return o, lse
1105
+
1106
+
1107
def parallel_nsa_block_mask(
    block_indices: torch.LongTensor,
    block_counts: Union[torch.LongTensor, int],
    offsets: torch.LongTensor,
    block_size: int,
):
    """Densify per-token block selections into a boolean mask.

    Args:
        block_indices: selected block ids [B, T, H, S].
        block_counts: per-query number of valid selections (int or [B, T, H]).
        offsets: cumulative sequence lengths, or None for fixed-length batches.
        block_size: size of each selected key/value block.

    Returns:
        block_mask [B, T, H, NS] (bool), where NS is the number of key blocks
        in the longest sequence.
    """
    B, T, H, S = block_indices.shape
    BS = block_size
    # NS: number of key blocks covering the longest sequence
    if offsets is None:
        NS = triton.cdiv(T, BS)
    else:
        NS = triton.cdiv(prepare_lens(offsets).max().item(), BS)

    block_mask = torch.zeros(
        B, T, H, NS,
        dtype=torch.bool,
        device=block_indices.device,
    )
    grid = (T, B, H * S)
    parallel_nsa_kernel_mask[grid](
        block_indices=block_indices,
        block_counts=block_counts,
        block_mask=block_mask,
        T=T,
        H=H,
        S=S,
        BS=BS,
        NS=NS,
    )
    return block_mask
1132
+
1133
+
1134
def parallel_nsa_bwd_preprocess(
    o: torch.Tensor,
    do: torch.Tensor
):
    """Compute delta = sum(o * do) over the value dimension, in fp32.

    This is the standard flash-attention backward preprocessing term, one
    scalar per (batch, token, head) row of `o`.
    """
    V = o.shape[-1]
    # one fp32 scalar per row of o; shape [..., V] -> [...]
    delta = torch.empty_like(o[..., 0], dtype=torch.float32)
    grid = (delta.numel(),)
    parallel_nsa_bwd_kernel_preprocess[grid](
        o=o,
        do=do,
        delta=delta,
        B=triton.next_power_of_2(V),
        V=V,
    )
    return delta
1148
+
1149
+
1150
def parallel_nsa_bwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    o: torch.Tensor,
    lse: torch.Tensor,
    do: torch.Tensor,
    block_indices: torch.Tensor,
    block_counts: Union[torch.LongTensor, int],
    block_size: int = 64,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None,
    token_indices: Optional[torch.LongTensor] = None,
):
    """Backward pass of NSA selected attention: returns (dq, dk, dv).

    Runs the dq kernel per token, then densifies the selections into a block
    mask and runs the dkv kernel per key block.  dq/dk carry a leading NV
    dimension (one slice per value tile) reduced via `.sum(0)`.
    """
    B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
    HQ = q.shape[2]
    G = HQ // H
    BS = block_size
    BK = triton.next_power_of_2(K)
    BV = min(128, triton.next_power_of_2(v.shape[-1]))
    NV = triton.cdiv(V, BV)

    # delta = rowsum(o * do), shared by both backward kernels
    delta = parallel_nsa_bwd_preprocess(o, do)

    dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
    grid = (T, NV, B * H)
    parallel_nsa_bwd_kernel_dq[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dq=dq,
        block_indices=block_indices,
        block_counts=block_counts,
        offsets=offsets,
        token_indices=token_indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        S=S,
        BS=BS,
        BK=BK,
        BV=BV
    )
    dq = dq.sum(0)

    if offsets is not None:
        chunk_indices = prepare_chunk_indices(offsets, BS)
        NS = len(chunk_indices)
    else:
        chunk_indices = None
        NS = triton.cdiv(T, BS)

    # [B, T, H, M]
    block_mask = parallel_nsa_block_mask(block_indices, block_counts, offsets, block_size)
    dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
    dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)

    grid = (NV, NS, B * H)
    parallel_nsa_bwd_kernel_dkv[grid](
        q=q,
        k=k,
        v=v,
        lse=lse,
        delta=delta,
        do=do,
        dk=dk,
        dv=dv,
        block_mask=block_mask,
        offsets=offsets,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HQ=HQ,
        G=G,
        K=K,
        V=V,
        M=block_mask.shape[-1],
        BS=BS,
        BK=BK,
        BV=BV
    )
    dk = dk.sum(0)
    return dq, dk, dv
1243
+
1244
+
1245
# NOTE(review): decorating an autograd.Function subclass with @torch.compile is
# unusual (torch.compile targets callables, not Function classes) — confirm
# that `ParallelNSAFunction.apply` still resolves as intended under this
# decorator.
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
    """Autograd wrapper around the NSA selected-attention kernels."""

    @staticmethod
    @contiguous
    @autocast_custom_fwd
    def forward(ctx, q, k, v, block_indices, block_counts, block_size, scale, offsets):
        ctx.dtype = q.dtype

        # 2-d sequence indices denoting the offsets of tokens in each sequence
        # for example, if the passed `offsets` is [0, 2, 6],
        # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
        token_indices = prepare_token_indices(offsets) if offsets is not None else None

        o, lse = parallel_nsa_fwd(
            q=q,
            k=k,
            v=v,
            block_indices=block_indices,
            block_counts=block_counts,
            block_size=block_size,
            scale=scale,
            offsets=offsets,
            token_indices=token_indices
        )
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.block_indices = block_indices
        ctx.block_counts = block_counts
        ctx.offsets = offsets
        ctx.token_indices = token_indices
        ctx.block_size = block_size
        ctx.scale = scale
        return o.to(q.dtype)

    @staticmethod
    @contiguous
    @autocast_custom_bwd
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        dq, dk, dv = parallel_nsa_bwd(
            q=q,
            k=k,
            v=v,
            o=o,
            lse=lse,
            do=do,
            block_indices=ctx.block_indices,
            block_counts=ctx.block_counts,
            block_size=ctx.block_size,
            scale=ctx.scale,
            offsets=ctx.offsets,
            token_indices=ctx.token_indices
        )
        # BUGFIX: backward must return exactly one gradient per forward input.
        # forward takes 8 inputs (q, k, v, block_indices, block_counts,
        # block_size, scale, offsets); the previous version returned 11 values
        # (3 gradients + 8 Nones), which PyTorch rejects with "returned an
        # incorrect number of gradients".
        return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None
1300
+
1301
+
1302
def parallel_nsa_compression(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    block_size: int = 64,
    scale: float = None,
    offsets: Optional[torch.LongTensor] = None
):
    """Compressed-attention branch of NSA.

    Thin convenience wrapper that dispatches to the autograd Function and
    returns `(o, lse)`: the attention output over block-pooled keys/values and
    its log-sum-exp (consumed by topk block selection).
    """
    return ParallelNSACompressionFunction.apply(q, k, v, block_size, scale, offsets)
1318
+
1319
+
1320
def parallel_nsa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g_cmp: torch.Tensor,
    g_slc: torch.Tensor,
    g_swa: torch.Tensor,
    block_indices: Optional[torch.LongTensor] = None,
    block_counts: Union[torch.LongTensor, int] = 16,
    block_size: int = 64,
    window_size: int = 0,
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    r"""
    Native Sparse Attention: gated combination of the compressed, selected and
    sliding-window attention branches.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
            GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g_cmp (torch.Tensor):
            Gate score for compressed attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
        g_slc (torch.Tensor):
            Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
        g_swa (torch.Tensor):
            Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
        block_indices (torch.LongTensor):
            Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
            `S` is the number of selected blocks for each query token, which is set to 16 in the paper.
            If `g_cmp` is provided, the passed `block_indices` will be ignored.
        block_counts (Optional[Union[torch.LongTensor, int]]):
            Number of selected blocks for each query.
            If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`,
            each query can select the same number of blocks.
            If not provided, it will default to 16.
        block_size (int):
            Selected block size. Default: 64.
        window_size (int):
            Sliding window size. Default: 0.
        scale (Optional[int]):
            Scale factor for attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
    """
    assert block_counts is not None, "block counts must be provided for selection"
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if cu_seqlens is not None:
        assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
    if head_first:
        # normalize everything to [B, T, H, ...] layout internally
        q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v))
        g_cmp, g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h') if x is not None else None, (g_cmp, g_slc, g_swa))
        if not isinstance(block_counts, int):
            block_counts = rearrange(block_counts, 'b h t -> b t h')
    assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"

    # compressed keys/values: one mean-pooled representation per block
    k_cmp, v_cmp = mean_pooling(k, block_size, cu_seqlens), mean_pooling(v, block_size, cu_seqlens)
    o_cmp, lse_cmp = None, None
    if g_cmp is not None:
        o_cmp, lse_cmp = parallel_nsa_compression(
            q=q,
            k=k_cmp,
            v=v_cmp,
            block_size=block_size,
            scale=scale,
            offsets=cu_seqlens
        )
        if block_indices is not None:
            warnings.warn("`block_indices` will be ignored when `g_cmp` is provided")
        # derive block selections from the compression branch's scores
        block_indices = parallel_nsa_topk(
            q=q,
            k=k_cmp,
            lse=lse_cmp,
            block_counts=block_counts,
            block_size=block_size,
            scale=scale,
            offsets=cu_seqlens
        )
    o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, scale, cu_seqlens)
    # gated sum of the three branches: o = g_slc*o_slc + g_cmp*o_cmp + g_swa*o_swa
    o = o_slc * g_slc.unsqueeze(-1)
    if o_cmp is not None:
        o = torch.addcmul(o, o_cmp, g_cmp.unsqueeze(-1))
    if window_size > 0:
        if cu_seqlens is not None:
            max_seqlen = q.shape[1]
            o_swa = flash_attn_varlen_func(
                q.squeeze(0), k.squeeze(0), v.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=True,
                window_size=(window_size-1, 0)
            ).unsqueeze(0)
        else:
            o_swa = flash_attn_func(
                q, k, v,
                causal=True,
                window_size=(window_size-1, 0)
            )
        o = torch.addcmul(o, o_swa, g_swa.unsqueeze(-1))
    if head_first:
        o = rearrange(o, 'b t h d -> b h t d')
    return o
fla/ops/retention/naive.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import torch
4
+
5
+
6
def naive_retention(q, k, v):
    """Reference (non-fused) implementation of multi-scale retention.

    Args:
        q, k, v: tensors of shape (batch, n_heads, seq_len, d_head).

    Returns:
        Tensor of shape (batch, n_heads, seq_len, d_head), cast back to the
        input dtype.
    """
    orig_type = q.dtype
    # Compute in fp32 for numerical stability; cast back at the end.
    q, k, v = q.float(), k.float(), v.float()
    _, n_heads, seq_len, d_head = q.shape
    # Per-head log2 decay rates: log2(1 - 2^(-5 - h)) for head index h, so
    # earlier heads decay faster than later ones (RetNet-style schedule).
    s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2()
    n = q.new_tensor(range(seq_len), dtype=torch.float)
    # Causal decay mask per head: n[h, i, j] = decay_h^(i-j) for j <= i, 0 otherwise.
    n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n)
    # Scaled query-key scores with the decay mask folded into the einsum.
    s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype))
    o = torch.einsum('bhqk,bhkd->bhqd', s, v)
    return o.to(orig_type)
profile_trace/iteration_11264/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_1536/rank2_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_18432/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_23552/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_23552/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_23552/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_23552/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_2560/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_27648/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_27648/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_27648/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_27648/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_29696/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_29696/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_29696/rank4_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_30720/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_30720/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_30720/rank6_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_30720/rank7_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_31744/rank1_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_36864/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_36864/rank3_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_37888/rank0_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
profile_trace/iteration_37888/rank5_trace.json ADDED
The diff for this file is too large to render. See raw diff
 
torchtitan/components/__pycache__/lr_scheduler.cpython-312.pyc ADDED
Binary file (7.71 kB). View file
 
torchtitan/components/__pycache__/optimizer.cpython-312.pyc ADDED
Binary file (14.5 kB). View file
 
torchtitan/experiments/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (252 Bytes). View file
 
torchtitan/experiments/deepseek_v3/checkpoint.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ from typing import Dict, Optional, Set, Tuple
11
+
12
+ import torch
13
+ from safetensors import safe_open
14
+
15
+ from transformers.utils import cached_file
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
21
+
22
+
23
def read_weights_from_json(file_path: str) -> Optional[Dict[str, str]]:
    """Read the ``weight_map`` dict from a HF safetensors index JSON file.

    Args:
        file_path: Path to a ``model.safetensors.index.json``-style file.

    Returns:
        The parameter-name -> shard-file mapping, or ``None`` if the file
        cannot be read/parsed or contains no ``weight_map`` dictionary.
    """
    try:
        # Keep the try body minimal: only the I/O and parse can raise here.
        with open(file_path, "r") as file:
            data = json.load(file)
    # FIX: the original `except (json.JSONDecodeError, Exception)` tuple was
    # redundant — JSONDecodeError is already an Exception subclass.
    except Exception as e:
        logger.info(f"An error occurred while reading the JSON file: {str(e)}")
        return None

    if isinstance(data, dict) and isinstance(data.get("weight_map"), dict):
        return data["weight_map"]
    logger.info("No 'weight_map' dictionary found in the JSON file.")
    return None
36
+
37
+
38
def get_hf_weight_map_and_path(
    model_id: str,
) -> Tuple[Dict[str, str], str]:
    """Get the weight map for a given HF model id and also the cache path for loading the weights.

    Args:
        model_id: Hugging Face model repo id (e.g. "deepseek-ai/DeepSeek-V2").

    Returns:
        Tuple of (weight_map, directory containing the safetensor shards).

    Raises:
        Whatever `cached_file` raises when the index file is not in the cache.
    """
    try:
        index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
    except Exception:
        # FIX: the hint previously read "`python download.py {model_id}" with
        # an unterminated backtick; close the quoted command.
        logger.error(
            f"Model `{model_id}` not found in HF cache. "
            f"You can download the model using `python download.py {model_id}`"
        )
        # Bare `raise` preserves the original traceback (was `raise e`).
        raise

    weight_map = read_weights_from_json(index_file)
    weight_path = os.path.dirname(index_file)
    logger.info(f"Loading weights from: {weight_path}")
    return weight_map, weight_path
55
+
56
+
57
def get_needed_files(
    state_dict: Dict[str, torch.Tensor], weight_map: Dict[str, str]
) -> Set[str]:
    """Return the set of shard files that hold the parameters of `state_dict`.

    Raises:
        ValueError: if a parameter ending in "weight" has no shard entry in
            `weight_map` (non-"weight" entries are silently skipped).
    """
    needed_files: Set[str] = set()
    for param_name in state_dict:
        shard = weight_map.get(param_name)
        if shard:
            needed_files.add(shard)
            continue
        if param_name.endswith("weight"):
            raise ValueError(
                f"Parameter {param_name} not found in weight map, please check..."
            )
    logger.info(f"Needed files: {needed_files}")
    return needed_files
71
+
72
+
73
def load_safetensor_file(
    full_path: str, device: torch.device
) -> Dict[str, torch.Tensor]:
    """Read every tensor stored in the safetensors file at `full_path` onto `device`."""
    with safe_open(full_path, framework="pt", device=device) as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}
    logger.info(f"Loaded {len(tensors)} tensors from {full_path}")
    return tensors
82
+
83
+
84
def load_safetensor_weights(
    model: torch.nn.Module,
    weight_map: Dict[str, str],
    file_location: str,
    device: torch.device,
):
    """
    Load safetensor weights into a `nn.Module`.

    Args:
        model (Module): The PyTorch module to load weights into. It may be a
            model chunk or a full model.
        weight_map (Dict[str, str]): Mapping of model parameters to file names.
        file_location (str): Directory containing the weight files.
        device (torch.device): The device to load tensors onto.

    Raises:
        ValueError: if a checkpoint tensor's shape does not match the model's.
        RuntimeError: if any model parameter was not found in any shard.
    """
    model_state_dict = model.state_dict()
    needed_files = get_needed_files(model_state_dict, weight_map)
    updated_states: Set[str] = set()

    for file in needed_files:
        full_path = os.path.join(file_location, file)
        try:
            checkpoint = load_safetensor_file(full_path, "cpu")
        except FileNotFoundError:
            logger.error(f"File not found: {full_path}")
            # FIX: previously execution fell through to the code below with
            # `checkpoint` undefined (NameError) or stale from the previous
            # shard. Skip this shard; missing params are reported at the end.
            continue
        except Exception as e:
            logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
            continue

        matched_keys = set(checkpoint.keys()) & set(model_state_dict.keys())
        for key in matched_keys:
            # Check shape before overwriting the model's entry.
            if model_state_dict[key].shape != checkpoint[key].shape:
                raise ValueError(
                    f"Shape mismatch for {key}: "
                    f"model needs {model_state_dict[key].shape}, but "
                    f"checkpoint has {checkpoint[key].shape}"
                )
            model_state_dict[key] = checkpoint[key].to(device)

        updated_states.update(matched_keys)

    # Refuse a partial load: every model parameter must have been filled.
    missing_keys = set(model_state_dict.keys()) - updated_states
    if missing_keys:
        raise RuntimeError(
            f"Partially updated state dict. Missing parameters: {missing_keys}"
        )

    # assign=True adopts the checkpoint tensors directly instead of copying.
    model.load_state_dict(model_state_dict, strict=False, assign=True)
    logger.info(f"Successfully loaded {len(updated_states)} weights into model")
134
+
135
+
136
def load_weights_from_hf(
    model: torch.nn.Module,
    distribution: str,
    device: torch.device,
):
    """
    Load weights in Hugging Face format (index file + multiple safetensor
    shards) into `model`.

    `distribution` is the HF model repo id whose cached checkpoint is used.
    NOTE(review): the original doc mentioned permuting wq/wk by attention
    heads, but no permutation happens here — confirm whether callers rely
    on the checkpoint layout matching the model directly.
    """
    weight_map, weight_path = get_hf_weight_map_and_path(distribution)
    load_safetensor_weights(model, weight_map, weight_path, device)
torchtitan/experiments/deepseek_v3/download.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Usage:
8
+ # Downloads a given model to the HF Cache. Pass in a listed option ala "v3" or your own custom model path.
9
+ # python download.py {model_id} [custom_model_path]
10
+ # Examples:
11
+ # python download.py v2 # Use predefined model: deepseek-ai/DeepSeek-V2
12
+ # python download.py custom "deepseek-ai/new-model" # Download a custom model path
13
+
14
+ # Available models:
15
+ # "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
16
+ # "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
17
+ # "v2": "deepseek-ai/DeepSeek-V2",
18
+ # "v3": "deepseek-ai/deepseek-v3",
19
+ # "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
20
+ # "custom": None, # Placeholder for custom models
21
+
22
+
23
+ import sys
24
+
25
+ from transformers import AutoModelForCausalLM
26
+
27
+
28
+ MODELS = {
29
+ "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat",
30
+ "v2-lite": "deepseek-ai/DeepSeek-V2-Lite",
31
+ "v2": "deepseek-ai/DeepSeek-V2",
32
+ "v3": "deepseek-ai/deepseek-v3",
33
+ "v3-0324": "deepseek-ai/DeepSeek-V3-0324",
34
+ "custom": None, # For custom (any) models
35
+ }
36
+
37
+
38
def print_usage():
    """Print CLI help for download.py, then terminate with exit status 1."""
    print("Usage:")
    print(" python download.py [model_version]")
    print(" python download.py custom [custom_model_path]")
    print("\nAvailable predefined models:")
    for key, model in MODELS.items():
        if key == "custom":  # placeholder entry, not a downloadable model
            continue
        print(f" {key}: {model}")
    print("\nFor custom models:")
    print(" custom: Specify your own model path")
    print(' Example: python download.py custom "organization/model-name"')
    sys.exit(1)
50
+
51
+
52
# Process command line arguments
# argv[1] must be one of the MODELS keys; anything else (or no args) prints
# help and exits.
if len(sys.argv) < 2 or sys.argv[1] not in MODELS:
    print_usage()

if sys.argv[1] == "custom":
    # "custom" requires the HF repo path as the second argument.
    if len(sys.argv) != 3:
        print("Error: Custom model requires a model path")
        print_usage()
    model_id = sys.argv[2]
    print(f"Using custom model: {model_id}")
else:
    model_id = MODELS[sys.argv[1]]
    print(f"Downloading model: {model_id}")

# Instantiating the model pulls the checkpoint into the HF cache as a side
# effect; the returned module itself is not used further.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
)
torchtitan/experiments/deepseek_v3/model.py ADDED
@@ -0,0 +1,1325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This code is based on model definition of `deepseek-ai/DeepSeek-V3-Base` on
8
+ # Hugging Face Model Hub. Url:
9
+ # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py
10
+ # https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/resolve/main/configuration_deepseek.py
11
+ #
12
+ # It has been modified from its original forms to accommodate naming convention
13
+ # and usage patterns of the TorchTitan project.
14
+
15
+ # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+ """ PyTorch DeepSeek model."""
29
+ import math
30
+ from typing import Optional, Tuple
31
+
32
+ import torch
33
+ import torch.distributed as dist
34
+
35
+ import torch.distributed._symmetric_memory as symm_mem
36
+ import torch.nn.functional as F
37
+ import torch.utils.checkpoint
38
+
39
+ from attn_mask_utils import _prepare_4d_causal_attention_mask
40
+ from indices import generate_permute_indices
41
+ from model_config import ModelArgs
42
+ from symm_mem_recipes import OnDeviceAllToAllV
43
+ from torch import nn
44
+ from torch.distributed._functional_collectives import all_to_all_single_autograd
45
+
46
+ from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
47
+ ALIGN_SIZE_M,
48
+ grouped_gemm_forward,
49
+ )
50
+
51
# Get model parallel subgroup by name:
# e.g. "pp", "ep", None
def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup:
    """Return the process group for mesh dimension `dim_name` of the current device mesh.

    NOTE(review): relies on the private `_mesh_resources` API — confirm it
    still exists on the torch version in use.
    """
    glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh()
    return glob.get_group(dim_name)
56
+
57
+
58
class RMSNorm(nn.Module):
    """Root-mean-square layer norm: no mean subtraction, no bias, learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalize in fp32 for numerical stability, then cast back.
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (x * inv_rms).to(orig_dtype)
70
+
71
+
72
class RotaryEmbedding(nn.Module):
    """Precomputes and caches cos/sin tables for rotary position embeddings."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponent = torch.arange(0, dim, 2).float().to(device) / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponent), persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(positions, self.inv_freq.to(positions.device))
        # Concatenating (freqs, freqs) differs from the paper's interleaving,
        # but gives the same result under the matching permutation.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Grow the cache lazily if a longer sequence shows up.
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
112
+
113
+
114
class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the initial cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        # Shrink positions by the scaling factor so long contexts reuse the
        # angle range seen during training.
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        positions = positions / self.scaling_factor

        freqs = torch.outer(positions, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
140
+
141
+
142
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Deepseek
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
    ):
        # Must be set before super().__init__, which builds the initial cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # NTK-aware rescaling: inflate the base when the requested length
            # exceeds the training context, then recompute inv_freq. The
            # cached inv_freq buffer is overwritten for subsequent calls.
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings)
                - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
179
+
180
+
181
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
    num_rotations, dim, base=10000, max_position_embeddings=2048
):
    """Invert the RoPE wavelength formula: the (fractional) dimension whose
    frequency completes `num_rotations` turns over the training context."""
    ratio = max_position_embeddings / (num_rotations * 2 * math.pi)
    return dim * math.log(ratio) / (2 * math.log(base))
188
+
189
+
190
# Find dim range bounds based on rotations
def yarn_find_correction_range(
    low_rot, high_rot, dim, base=10000, max_position_embeddings=2048
):
    """Return integer [low, high] dimension bounds for the YaRN interpolation ramp."""
    low_dim = yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    high_dim = yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    # Clamp values just in case
    return max(math.floor(low_dim), 0), min(math.ceil(high_dim), dim - 1)
201
+
202
+
203
def yarn_get_mscale(scale=1, mscale=1):
    """Attention-magnitude correction used by YaRN; identity when scale <= 1."""
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
207
+
208
+
209
def yarn_linear_ramp_mask(min, max, dim):
    """Linear ramp from 0 (at index `min`) to 1 (at index `max`) over `dim`
    entries, clamped to [0, 1].

    NOTE: the parameter names shadow builtins to stay aligned with the
    upstream DeepSeek implementation.
    """
    if min == max:
        max += 0.001  # Prevent singularity (division by zero below)

    ramp = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    return ramp.clamp(0, 1)
216
+
217
+
218
class YarnRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding with YaRN scaling: NTK-by-parts frequency interpolation
    plus an attention-magnitude correction (mscale) baked into the cached
    cos/sin tables."""

    def __init__(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        # All YaRN hyperparameters must be set before super().__init__,
        # which immediately builds the cos/sin cache via _set_cos_sin_cache.
        self.scaling_factor = scaling_factor
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        dim = self.dim

        # Extrapolated (original) and interpolated (scaled) frequency bands.
        freq_extra = 1.0 / (
            self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        # Dimension band over which to blend between the two, derived from
        # the beta_fast/beta_slow rotation counts.
        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        # mask == 1 keeps the original (extrapolated) frequency; 0 uses the
        # interpolated one; in between, a linear blend.
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(seq_len, device=device, dtype=torch.float32)

        freqs = torch.outer(t, inv_freq)

        # Magnitude correction ratio; multiplied into both cos and sin below
        # so attention logits keep a stable scale under context extension.
        _mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False
        )
        self.register_buffer(
            "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False
        )
283
+
284
+
285
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
291
+
292
+
293
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # Regroup the head dim from (d//2, 2) pairs into (first-half, second-half)
    # halves, the layout rotate_half and the concatenated cos/sin expect —
    # presumably because checkpoints store rotary dims interleaved; confirm
    # against the weight layout.
    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
327
+
328
+
329
class MLP(nn.Module):
    """SwiGLU feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    # SiLU is stateless/parameter-free, so one shared class-level instance is safe.
    act_fn = nn.SiLU()

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.config = config
        # Explicit sizes win over the config defaults.
        self.hidden_size = (
            hidden_size if hidden_size is not None else config.hidden_size
        )
        self.intermediate_size = (
            intermediate_size
            if intermediate_size is not None
            else config.intermediate_size
        )

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
347
+
348
+
349
class MoEGate(nn.Module):
    """Top-k expert router: scores tokens against each routed expert and
    returns the chosen expert indices plus routing weights. Supports plain
    greedy top-k and DeepSeek-V3 style "noaux_tc" node-limited routing."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.scoring_func = config.scoring_func
        self.seq_aux = config.seq_aux
        self.topk_method = config.topk_method
        self.n_group = config.n_group
        self.topk_group = config.topk_group

        # topk selection algorithm
        self.norm_topk_prob = config.norm_topk_prob
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        if self.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                # Changed from torch.empty to torch.rand to avoid non-even
                # distribution for runs without actual weigths
                torch.rand((self.n_routed_experts))
            )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Kaiming init for the gate projection only; the correction bias (if
        # any) keeps its torch.rand init from __init__.
        import torch.nn.init as init

        init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states):
        """Route a (bsz, seq_len, hidden) batch; returns (topk_idx, topk_weight),
        each of shape (bsz * seq_len, top_k)."""
        bsz, seq_len, h = hidden_states.shape
        # compute gating score
        hidden_states = hidden_states.view(-1, h)
        # Score in fp32 regardless of model dtype for routing stability.
        logits = F.linear(
            hidden_states.type(torch.float32), self.weight.type(torch.float32), None
        )
        if self.scoring_func == "sigmoid":
            scores = logits.sigmoid()
        elif self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1, dtype=torch.float32)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        # select top-k experts
        if self.topk_method == "noaux_tc":
            # Bias-corrected scores are used only for *choosing* experts; the
            # returned weights are gathered from the unbiased `scores`.
            scores_for_choice = scores.view(
                bsz * seq_len, -1
            ) + self.e_score_correction_bias.unsqueeze(0)
            # Rank expert groups by the sum of their two best experts.
            group_scores = (
                scores_for_choice.view(bsz * seq_len, self.n_group, -1)
                .topk(2, dim=-1)[0]
                .sum(dim=-1)
            )  # [n, n_group]
            group_idx = torch.topk(
                group_scores, k=self.topk_group, dim=-1, sorted=False
            )[
                1
            ]  # [n, top_k_group]
            group_mask = torch.zeros_like(group_scores)  # [n, n_group]
            group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
            # Expand the group mask back to per-expert granularity.
            score_mask = (
                group_mask.unsqueeze(-1)
                .expand(
                    bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group
                )
                .reshape(bsz * seq_len, -1)
            )  # [n, e]
            tmp_scores = scores_for_choice.masked_fill(
                ~score_mask.bool(), 0.0
            )  # [n, e]
            _, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False)
            topk_weight = scores.gather(1, topk_idx)
        elif self.topk_method == "greedy":
            topk_weight, topk_idx = torch.topk(
                scores, k=self.top_k, dim=-1, sorted=False
            )
        else:
            raise NotImplementedError(
                f"insupportable TopK function for MoE gating: {self.topk_method}"
            )

        # norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator
        topk_weight = (
            topk_weight * self.routed_scaling_factor
        )  # must multiply the scaling factor

        return topk_idx, topk_weight
444
+
445
+
446
class MoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Routed experts are sharded across the expert-parallel (EP) group; each
    forward pass shuffles tokens to the ranks owning their selected experts,
    runs the local experts, and shuffles the results back. Two shuffle
    implementations are supported (see `shuffle_method`).
    """

    # Class attributes:
    # Two shuffle methods supported:
    # 1. "torch_all_to_all"
    # 2. "symm_mem" (see `setup_symm_mem` below)
    shuffle_method = "torch_all_to_all"

    # Symmetric memory buffers shared by all MoE instances across layers
    token_send_buf: Optional[torch.Tensor] = None
    token_gather_buf: Optional[torch.Tensor] = None

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok

        # ep_size is the number of ranks in expert dimension
        if config.ep_size <= 1:
            raise ValueError(
                "For code simplicity, this model only supports distributed experts, "
                "thus EP size must be > 1, please modify your model config"
            )
        self.ep_group = get_group("ep")
        assert config.ep_size == self.ep_group.size()
        self.ep_size = config.ep_size
        self.ep_rank = self.ep_group.rank()
        self.experts_per_rank = config.n_routed_experts // config.ep_size
        # Use ModuleDict instead of ModuleList to preserve absolute expert
        # IDs while avoiding `None` experts. The absolute expert IDs match
        # with checkpoint FQNs.
        self.experts = nn.ModuleDict()
        for i in range(self.experts_per_rank):
            abs_expert_id = self.ep_rank * self.experts_per_rank + i
            self.experts[str(abs_expert_id)] = MLP(
                config, intermediate_size=config.moe_intermediate_size
            )
        self.gate = MoEGate(config)
        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = MLP(
                config=config, intermediate_size=intermediate_size
            )

    def combine_experts(self, submod_name):
        # Concatenate the weights of `submod_name` (e.g. "gate_proj") across
        # all local experts into a single parameter, so they can be consumed
        # by a grouped GEMM. The per-expert weights are detached from their
        # original modules (set to None) to avoid double registration.
        all_weights = []
        for expert in self.experts.values():
            lin = expert.get_submodule(submod_name)
            all_weights.append(lin.weight)
            lin.weight = None

        concat_weight = torch.cat(all_weights)
        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))

    # This function is used to create a symm mem buffer for MoE's. It is for
    # shuffling tokens fully "on-device", as compared to traditional torch
    # all_to_all APIs which require a GPU-to-CPU sync of the splits. If a user
    # calls this function, the `shuffle_method` would switch from
    # `torch_all_to_all` to `symm_mem`.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        # Switch shuffle method
        self.shuffle_method = "symm_mem"

        # Combine expert weights
        print("Combining expert weights for Group GEMM")
        self.combine_experts("gate_proj")
        self.combine_experts("up_proj")
        self.combine_experts("down_proj")

        # Assuming worst case, 2x tokens are routed to one EP rank
        overflow = 2
        OnDeviceAllToAllV.max_output_len = (
            self.config.max_seq_len * self.num_experts_per_tok * overflow
        )

        # Symmetric memory buffers are shared by all MoE instances across
        # layers, we only need to initialize them once
        if MoE.token_send_buf is not None:
            return

        # Input buffer for DP-to-EP shuffle
        MoE.token_send_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok,  # seq len * top k (flattened)
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        # Input buffer for EP-to-DP shuffle
        MoE.token_gather_buf = symm_mem.empty(
            self.config.max_seq_len
            * self.num_experts_per_tok  # seq len * top k (flattened)
            * overflow,
            self.config.hidden_size,  # hidden dim
            dtype=dtype,
            device=device,
        )
        print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")

    def get_send_buf(self):
        # [Why detach?] During a first forward-backward step, the buffer would
        # be included in a computational graph. In a second step, autograd will
        # return an error saying "Trying to backward through the graph a second
        # time (or directly access saved tensors more than once)". This is
        # because the buffer is still in the graph, and autograd is trying to
        # backward through the graph a second time. To avoid this, we detach the
        # buffer from the graph. `detach()` returns a new tensor, which shares
        # the same storage with the original one.
        self.token_send_buf.grad = None
        return self.token_send_buf.detach()

    def get_gather_buf(self):
        # See [Why detach?] in `get_send_buf`
        self.token_gather_buf.grad = None
        return self.token_gather_buf.detach()

    def forward(self, hidden_states):
        # Route each token through its top-k experts, then add the shared
        # experts' output (if configured) on top.
        identity = hidden_states
        orig_shape = hidden_states.shape
        # for each token, select top-k experts, and compute the weight for each expert
        topk_idx, topk_weight = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        if self.shuffle_method == "symm_mem":
            y = self.moe_on_device(hidden_states, topk_idx, topk_weight)
        else:  # "torch_all_to_all"
            y = self.moe_forward(hidden_states, topk_idx, topk_weight)

        y = y.view(*orig_shape)
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_forward(self, x, topk_ids, topk_weight):
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        # `idxs // top_k` maps a flattened (token, k) slot back to its token row.
        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # DP to EP token shuffle. This part needs gradient.
        if self.shuffle_method == "symm_mem":
            # Move input to the `token_send_buf` symm mem
            token_send_buf = self.get_send_buf()
            token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
            # Note: `out=` avoids copy, but it is not differentiable
            # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
            token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
                token_send_buf,
                input_splits,
                self.ep_group,
            )
            with torch.no_grad():
                # Received tokens from all other ranks. TODO: use mask instead
                received = output_splits.sum()
            # TODO: don't use `received`
            gathered_tokens = token_gather_buf[:received]
        else:  # "torch_all_to_all"
            # Prepare input and output splits
            with torch.no_grad():
                output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(
                    dim=1
                )
            gathered_tokens = all_to_all_single_autograd(
                sorted_tokens,
                output_splits.tolist(),
                input_splits.tolist(),
                self.ep_group,
            )

        # This part prepares a 1D tensor with the same length as
        # `gathered_tokens`. The 1D tensor is filled with local expert IDs which
        # the tokens in `gathered_tokens` are headed for. This part doesn't need
        # gradient.
        with torch.no_grad():
            gatherd_idxs = (
                torch.arange(
                    tokens_per_expert_group.numel(),
                    device=tokens_per_expert_group.device,
                )
                % self.experts_per_rank
            )
            gatherd_idxs = gatherd_idxs.repeat_interleave(tokens_per_expert_group)

        # Prepare buffer for tokens processed by experts
        if self.shuffle_method == "symm_mem":
            # Take necessary space from `token_gather_buf` symm mem because we are
            # going to send them out after expert processing
            processed_tokens = self.get_gather_buf()[: gathered_tokens.shape[0]]
        else:  # "torch_all_to_all"
            processed_tokens = torch.empty_like(gathered_tokens)

        # This part processes the tokens routed to the local experts.
        # TODO: can we use group GEMM here?
        for i, expert in enumerate(self.experts.values()):
            processed_tokens[gatherd_idxs == i] = expert(
                gathered_tokens[gatherd_idxs == i]
            )

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        if self.shuffle_method == "symm_mem":
            token_return_buf, _ = OnDeviceAllToAllV.apply(
                processed_tokens,
                output_splits,
                self.ep_group,
            )
            returned_tokens = token_return_buf[: sorted_tokens_shape[0]]
        else:  # "torch_all_to_all"
            returned_tokens = all_to_all_single_autograd(
                processed_tokens,
                input_splits.tolist(),
                output_splits.tolist(),
                self.ep_group,
            )

        # Undo the initial expert-major sort, then combine the top-k expert
        # outputs per token with the gate weights.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out

    def moe_on_device(self, x, topk_ids, topk_weight):
        # Symmetric-memory variant of `moe_forward`: the token shuffle stays
        # on-device and the local experts run as three grouped GEMMs over the
        # combined expert weights (see `combine_experts`).
        #
        # This part sorts the token indices so that tokens routed to the same expert reside consecutively.
        # An implication is that tokens to the same "expert group" (i.e., device) are also consecutive.
        # Since this is an "artificial" index creation (final outcome being
        # `idxs`), we don't need gradients here.
        with torch.no_grad():
            # [seq_len, n_routed_experts]
            cnts = topk_ids.new_zeros((topk_ids.shape[0], self.config.n_routed_experts))
            # Fill 1 to the selected experts
            cnts.scatter_(1, topk_ids, 1)
            tokens_per_expert = cnts.sum(dim=0)
            # Token indices for each expert
            idxs = topk_ids.view(-1).argsort()
            sorted_tokens_shape = idxs.shape + x.shape[1:]

        sorted_tokens = x[idxs // topk_ids.shape[1]]
        assert sorted_tokens.shape == sorted_tokens_shape

        # This part exchanges the information about the number of tokens sent
        # and received by each expert. We can understand this information as
        # "side band", which is not part of the actual data. Thus no gradient
        # is needed.
        with torch.no_grad():
            # Sum the tokens over local experts, then we get tokens per EP rank,
            # which is the input splits
            tokens_per_expert_group = tokens_per_expert.new_empty(
                tokens_per_expert.shape[0]
            )
            dist.all_to_all_single(
                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
            )
            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)

        # Move input to the `token_send_buf` symm mem
        token_send_buf = self.get_send_buf()
        token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
        # Note: `out=` avoids copy, but it is not differentiable
        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
        token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
            token_send_buf,
            input_splits,
            self.ep_group,
        )

        # We need to permute the received tokens so that tokens for the same expert are contiguous.
        # This part prepares a 1D tensor `permuted_indices` for such permutation.
        # This part doesn't need gradient.
        with torch.no_grad():
            permuted_indices, m_sizes = generate_permute_indices(
                tokens_per_expert_group,
                self.experts_per_rank,
                self.ep_size,
                token_gather_buf.shape[0],
                ALIGN_SIZE_M,
            )

        # Permute the received tokens so that tokens for the same expert are contiguous.
        contig_tokens = token_gather_buf[permuted_indices]

        # Run the first grouped GEMM
        w1 = self.get_parameter("gate_proj_weight")
        gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)

        # Run the second grouped GEMM
        w3 = self.get_parameter("up_proj_weight")
        up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)

        # Apply activation
        hidden_outputs = MLP.act_fn(gate_proj) * up_proj

        # Run the third grouped GEMM
        w2 = self.get_parameter("down_proj_weight")
        hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)

        # Prepare buffer for tokens processed by experts
        # Take necessary space from `token_gather_buf` symm mem because we are
        # going to send them out after expert processing
        processed_tokens = self.get_gather_buf()

        # Move into Symmetric Memory for the return shuffle
        processed_tokens[permuted_indices] = hidden_outputs

        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
        # The input/output splits are just a reverse of the previous shuffle.
        token_return_buf, _ = OnDeviceAllToAllV.apply(
            processed_tokens,
            output_splits,
            self.ep_group,
        )
        returned_tokens = token_return_buf[: sorted_tokens_shape[0]]

        # Undo the initial expert-major sort, then combine the top-k expert
        # outputs per token with the gate weights.
        output_tokens = torch.empty_like(returned_tokens)
        output_tokens[idxs] = returned_tokens
        final_out = (
            output_tokens.view(*topk_ids.shape, -1)
            .type(topk_weight.dtype)
            .mul_(topk_weight.unsqueeze(dim=-1))
            .sum(dim=1)
            .type(returned_tokens.dtype)
        )
        return final_out
802
+
803
+
804
class Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # DeepSeek-style multi-head latent attention (MLA): queries and KV are
    # projected through low-rank bottlenecks, and each head dim is split into
    # a RoPE part (`qk_rope_head_dim`) and a non-RoPE part (`qk_nope_head_dim`).

    def __init__(self, config: ModelArgs, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        # Full per-head query/key dim = non-RoPE part + RoPE part
        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim

        self.is_causal = True

        # Query projection: direct, or low-rank (a -> layernorm -> b) when
        # `q_lora_rank` is configured.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(
                self.hidden_size, self.num_heads * self.q_head_dim, bias=False
            )
        else:
            self.q_a_proj = nn.Linear(
                self.hidden_size, config.q_lora_rank, bias=config.attention_bias
            )
            self.q_a_layernorm = RMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(
                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False
            )

        # KV down-projection also emits the shared (multi-query) RoPE key part.
        self.kv_a_proj_with_mqa = nn.Linear(
            self.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = RMSNorm(config.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            config.kv_lora_rank,
            self.num_heads
            * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),
            bias=False,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=config.attention_bias,
        )
        self._init_rope()

        # YaRN rescales the softmax temperature via `mscale` when
        # `mscale_all_dim` is set in the rope-scaling config.
        self.softmax_scale = self.q_head_dim ** (-0.5)
        if self.config.rope_scaling is not None:
            mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
            scaling_factor = self.config.rope_scaling["factor"]
            if mscale_all_dim:
                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
                self.softmax_scale = self.softmax_scale * mscale * mscale

    def _init_rope(self):
        # Select the rotary-embedding implementation based on the
        # `rope_scaling` config: none, linear, dynamic-NTK, or YaRN.
        if self.config.rope_scaling is None:
            self.rotary_emb = RotaryEmbedding(
                self.qk_rope_head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "yarn":
                # Forward only the YaRN-specific keys that are present.
                kwargs = {
                    key: self.config.rope_scaling[key]
                    for key in [
                        "original_max_position_embeddings",
                        "beta_fast",
                        "beta_slow",
                        "mscale",
                        "mscale_all_dim",
                    ]
                    if key in self.config.rope_scaling
                }
                self.rotary_emb = YarnRotaryEmbedding(
                    self.qk_rope_head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                    **kwargs,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Compute MLA self-attention over `hidden_states`.

        Args:
            hidden_states: input of shape `(batch, seq_len, hidden_size)`.
            attention_mask: optional 2D padding mask; expanded to 4D here.
            position_ids: optional positions used by rotary embeddings.

        Returns:
            Attention output of shape `(batch, seq_len, hidden_size)`.
            (Despite the annotation, only the output tensor is returned.)
        """
        bsz, q_len, _ = hidden_states.size()

        # Query projection (direct or low-rank), reshaped to
        # [bsz, num_heads, q_len, q_head_dim], then split into the
        # non-RoPE and RoPE components.
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # KV path: down-project, split off the shared RoPE key (one "head"
        # broadcast across all heads), then up-project to per-head K/V.
        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
        kv = (
            self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
            .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            .transpose(1, 2)
        )

        k_nope, value_states = torch.split(
            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
        )
        kv_seq_len = value_states.shape[-2]

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Rotary embedding is applied only to the RoPE components.
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

        # Reassemble full queries/keys: [non-RoPE part | RoPE part].
        # `k_pe` (1 head) broadcasts across all heads on assignment.
        query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
        query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

        key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
        key_states[:, :, :, : self.qk_nope_head_dim] = k_nope
        key_states[:, :, :, self.qk_nope_head_dim :] = k_pe

        if attention_mask is not None:
            # Attention mask was made 4D because the `attn_weights` above is 4D.
            # We probably can make this mask smarter if we want to pack sequences
            # together, instead of using padding. This optimization can be used in
            # inference. For training, if we want to pack sequences, data loader
            # will pass in a mask containing such info.
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,  # None, or user provided mask in 2D
                (bsz, q_len),
                hidden_states,
                0,  # past_key_values_length, 0 when training
            )
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # `is_causal=True` only when no explicit mask is given; otherwise the
        # 4D mask already encodes causality + padding.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query=query_states,
            key=key_states,
            value=value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout,
            is_causal=attention_mask is None,
            scale=self.softmax_scale,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output
990
+
991
+
992
class DecoderLayer(nn.Module):
    """One transformer decoder layer: pre-norm attention followed by a
    pre-norm feed-forward block (dense MLP or MoE), each with a residual
    connection."""

    def __init__(self, config: ModelArgs, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Attention(config=config, layer_idx=layer_idx)

        # A layer uses MoE instead of the dense MLP when routed experts are
        # configured, the layer index is past the initial dense layers, and
        # it matches the MoE layer frequency.
        use_moe = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        )
        if use_moe:
            self.mlp = MoE(config)
        else:
            self.mlp = MLP(config)

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
        """
        # Attention sub-block: pre-norm, attend, add residual.
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: pre-norm, MLP/MoE, add residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        return hidden_states
1045
+
1046
+
1047
+ Deepseek_INPUTS_DOCSTRING = r"""
1048
+ Args:
1049
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1050
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1051
+ it.
1052
+
1053
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1054
+ [`PreTrainedTokenizer.__call__`] for details.
1055
+
1056
+ [What are input IDs?](../glossary#input-ids)
1057
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1058
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1059
+
1060
+ - 1 for tokens that are **not masked**,
1061
+ - 0 for tokens that are **masked**.
1062
+
1063
+ [What are attention masks?](../glossary#attention-mask)
1064
+
1065
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1066
+ [`PreTrainedTokenizer.__call__`] for details.
1067
+
1068
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
1069
+ `past_key_values`).
1070
+
1071
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
1072
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
1073
+ information on the default strategy.
1074
+
1075
+ - 1 indicates the head is **not masked**,
1076
+ - 0 indicates the head is **masked**.
1077
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1078
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1079
+ config.n_positions - 1]`.
1080
+
1081
+ [What are position IDs?](../glossary#position-ids)
1082
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
1083
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
1084
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
1085
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
1086
+
1087
+ Two formats are allowed:
1088
+ - a [`~cache_utils.Cache`] instance;
1089
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1090
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
1091
+ cache format.
1092
+
1093
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
1094
+ legacy cache format will be returned.
1095
+
1096
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
1097
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
1098
+ of shape `(batch_size, sequence_length)`.
1099
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1100
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1101
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1102
+ model's internal embedding lookup matrix.
1103
+ use_cache (`bool`, *optional*):
1104
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1105
+ `past_key_values`).
1106
+ output_attentions (`bool`, *optional*):
1107
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1108
+ tensors for more detail.
1109
+ output_hidden_states (`bool`, *optional*):
1110
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1111
+ more detail.
1112
+ return_dict (`bool`, *optional*):
1113
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1114
+ """
1115
+
1116
+
1117
class DeepseekModel(torch.nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DecoderLayer`]

    Only the pipeline stage identified by `config.stage_idx` is materialized:
    the embedding exists only on stage 0, the final norm only on the last
    stage, and each stage owns a contiguous slice of the decoder layers.

    Args:
        config: ModelArgs
    """

    def __init__(self, config: ModelArgs):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        # Creating model parts related to my stage
        assert (
            config.stage_idx < config.num_stages
        ), f"Stage {config.stage_idx} is not in the model"
        print(f"Creating model stage {config.stage_idx} of {config.num_stages}")

        # Token embedding only exists on the first pipeline stage.
        self.embed_tokens = (
            nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
            if config.stage_idx == 0
            else None
        )

        self.layers = torch.nn.ModuleDict()
        division = config.num_hidden_layers // config.num_stages
        residual = config.num_hidden_layers % config.num_stages
        # Some earlier stages may have 1 more layer than latter stages because
        # the division may have residual; this is more even than giving the
        # entire residual to the last stage.
        layers_per_stage = [
            division + 1 if stage < residual else division
            for stage in range(config.num_stages)
        ]
        assert sum(layers_per_stage) == config.num_hidden_layers
        # Keys are absolute layer IDs (as strings) so FQNs match checkpoints.
        layer_id_start = sum(layers_per_stage[: config.stage_idx])
        layer_id_end = layer_id_start + layers_per_stage[config.stage_idx]
        for layer_id in range(layer_id_start, layer_id_end):
            self.layers[str(layer_id)] = DecoderLayer(config, layer_id)

        # Final RMSNorm only exists on the last pipeline stage.
        self.norm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if config.stage_idx == config.num_stages - 1
            else None
        )

        # Initialize weights and apply final processing
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Standard normal init for Linear/Embedding weights; zero biases and
        # the padding row of the embedding.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Run this pipeline stage.

        Args:
            tokens: token IDs on the first stage; hidden states (already
                embedded by a previous stage) on later stages.
            attention_mask: optional padding mask, forwarded to each layer.
            position_ids: optional positions, forwarded to each layer.

        Returns:
            Hidden states after this stage's layers (normalized on the last
            stage).
        """
        # Embedding (first stage only; later stages receive hidden states).
        hidden_states = (
            self.embed_tokens(tokens) if self.embed_tokens is not None else tokens
        )

        # decoder layers
        for decoder_layer in self.layers.values():
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )

        hidden_states = (
            self.norm(hidden_states) if self.norm is not None else hidden_states
        )
        return hidden_states
1202
+
1203
+
1204
class DeepseekForCausalLM(torch.nn.Module):
    """Causal-LM wrapper: `DeepseekModel` plus an LM head on the last
    pipeline stage. Intermediate stages return hidden states unchanged."""

    def __init__(self, config):
        super().__init__()
        self.model = DeepseekModel(config)
        # LM head only exists on the last pipeline stage.
        self.lm_head = (
            nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            if config.stage_idx == config.num_stages - 1
            else None
        )

        # Initialize weights and apply final processing
        # self.post_init()

    def forward(
        self,
        tokens: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, DeepseekForCausalLM

        >>> model = DeepseekForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        hidden_states = self.model(
            tokens,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        # Only the last stage projects to vocabulary logits; other stages
        # pass hidden states through to the next stage.
        logits = (
            self.lm_head(hidden_states) if self.lm_head is not None else hidden_states
        )
        return logits

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs,
    ):
        # HF-style generation hook: crop already-cached tokens from
        # `input_ids`, trim the attention mask to the cache window, and
        # derive `position_ids` from the mask when not provided.
        if past_key_values is not None:
            # Assuming isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Reorder every layer's cached key/value tensors along the batch
        # dimension to follow beam-search reordering.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past

    # Setup Symmetric Memory for MoE token shuffle.
    # Supports inference currently.
    def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
        # Enable the on-device ("symm_mem") shuffle on every MoE layer;
        # dense-MLP layers are skipped.
        for layer in self.model.layers.values():
            if not isinstance(layer.mlp, MoE):
                continue
            layer.mlp.setup_symm_mem(dtype, device)
torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import triton
8
+ import triton.language as tl
9
+
10
+
11
@triton.jit
def get_tid():
    """Return this thread's (x, y, z) indices within its block via PTX %tid."""
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %tid.x;
        mov.u32 $1, %tid.y;
        mov.u32 $2, %tid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )
25
+
26
+
27
@triton.jit
def get_ntid():
    """Return the block's (x, y, z) thread-count dimensions via PTX %ntid."""
    return tl.inline_asm_elementwise(
        """
        mov.u32 $0, %ntid.x;
        mov.u32 $1, %ntid.y;
        mov.u32 $2, %ntid.z;
        """,
        "=r,=r,=r",
        [],
        dtype=(tl.uint32, tl.uint32, tl.uint32),
        is_pure=True,
        pack=1,
    )
41
+
42
+
43
@triton.jit
def get_flat_tid():
    """Linearize the 3-D thread index into a single flat id (x fastest-varying)."""
    x, y, z = get_tid()
    nx, ny, _ = get_ntid()
    # Horner form of z*ny*nx + y*nx + x
    return (z * ny + y) * nx + x
48
+
49
+
50
@triton.jit
def get_flat_bid():
    """Linearize the 3-D program (block) id into a single flat id (axis 0 fastest)."""
    n0 = tl.num_programs(0)
    n1 = tl.num_programs(1)
    # Horner form of pid2*n1*n0 + pid1*n0 + pid0
    return (tl.program_id(2) * n1 + tl.program_id(1)) * n0 + tl.program_id(0)
57
+
58
+
59
@triton.jit
def sync_threads():
    """Block-wide barrier: PTX `bar.sync 0` synchronizes all threads in the CTA."""
    tl.inline_asm_elementwise(
        # is_pure=False so the compiler cannot hoist or eliminate the barrier
        "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
    )
torchtitan/experiments/flux/model/model.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass, field
8
+
9
+ import torch
10
+
11
+ from torch import nn, Tensor
12
+ from torchtitan.components.tokenizer import Tokenizer
13
+ from torchtitan.config_manager import JobConfig
14
+
15
+ from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams
16
+ from torchtitan.experiments.flux.model.layers import (
17
+ DoubleStreamBlock,
18
+ EmbedND,
19
+ LastLayer,
20
+ MLPEmbedder,
21
+ SingleStreamBlock,
22
+ timestep_embedding,
23
+ )
24
+
25
+ from torchtitan.protocols.train_spec import BaseModelArgs, ModelProtocol
26
+ from torchtitan.tools.logging import logger
27
+
28
+
29
@dataclass
class FluxModelArgs(BaseModelArgs):
    """Configuration for the Flux flow-matching transformer.

    Defaults correspond to the full-size Flux model (19 double-stream and
    38 single-stream blocks at hidden size 3072).
    """

    in_channels: int = 64
    out_channels: int = 64
    vec_in_dim: int = 768
    context_in_dim: int = 512
    hidden_size: int = 3072
    mlp_ratio: float = 4.0
    num_heads: int = 24
    depth: int = 19
    depth_single_blocks: int = 38
    axes_dim: tuple = (16, 56, 56)
    theta: int = 10_000
    qkv_bias: bool = True
    guidance_embed: bool = True
    autoencoder_params: AutoEncoderParams = field(default_factory=AutoEncoderParams)

    def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
        """Refresh fields from the job config (tokenizer currently unused)."""
        # context_in_dim is the same as the T5 embedding dimension
        # NOTE(review): this assigns the max T5 encoding *length* to an embedding
        # *dimension* field — confirm the intended semantics against the encoder.
        self.context_in_dim = job_config.encoder.max_t5_encoding_len

    def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
        """Return (parameter count, FLOPs); FLOPs are a placeholder of 1 for now."""
        # TODO(jianiw): Add the number of flops for the autoencoder
        nparams = sum(p.numel() for p in model.parameters())
        logger.warning("FLUX model hasn't implemented the get_nparams_and_flops() function")
        return nparams, 1
55
+
56
+
57
class FluxModel(nn.Module, ModelProtocol):
    """
    Transformer model for flow matching on sequences.

    Args:
        model_args: FluxModelArgs.

    Attributes:
        model_args (TransformerModelArgs): Model configuration arguments.
    """

    def __init__(self, model_args: FluxModelArgs):
        super().__init__()

        self.model_args = model_args
        self.in_channels = model_args.in_channels
        self.out_channels = model_args.out_channels
        # per-head dim must be an integer for attention
        if model_args.hidden_size % model_args.num_heads != 0:
            raise ValueError(
                f"Hidden size {model_args.hidden_size} must be divisible by num_heads {model_args.num_heads}"
            )
        pe_dim = model_args.hidden_size // model_args.num_heads
        # axes_dim entries must sum to the per-head dim so rotary axes fill it exactly
        if sum(model_args.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {model_args.axes_dim} but expected positional dim {pe_dim}"
            )
        self.hidden_size = model_args.hidden_size
        self.num_heads = model_args.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim
        )
        # input projections: image patches, timestep, CLIP vector, text tokens
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(model_args.vec_in_dim, self.hidden_size)
        # identity when guidance embedding is disabled
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if model_args.guidance_embed
            else nn.Identity()
        )
        self.txt_in = nn.Linear(model_args.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=model_args.mlp_ratio,
                    qkv_bias=model_args.qkv_bias,
                )
                for _ in range(model_args.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size, self.num_heads, mlp_ratio=model_args.mlp_ratio
                )
                for _ in range(model_args.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def init_weights(self, buffer_device=None):
        # TODO(jianiw): replace placeholder with real weight init
        for param in self.parameters():
            param.data.uniform_(0, 0.1)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        """Denoising forward pass.

        Args:
            img: packed image latents, (batch, img_seq, in_channels).
            img_ids / txt_ids: positional ids for the rotary embedder.
            txt: text-encoder embeddings, (batch, txt_seq, context_in_dim).
            timesteps: diffusion timestep per batch element.
            y: pooled (CLIP) conditioning vector.
            guidance: guidance strength; required when guidance_embed is on.

        Returns:
            Predicted velocity/noise with shape (batch, img_seq, out_channels).
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        # conditioning vector: timestep (+ optional guidance) + pooled text vector
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.model_args.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # joint rotary positions over [text, image] sequence order
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # single-stream blocks operate on the concatenated [txt, img] sequence
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # drop the text prefix, keeping only the image tokens
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img

    @classmethod
    def from_model_args(cls, model_args: FluxModelArgs) -> "FluxModel":
        """
        Initialize a Flux model from a FluxModelArgs object.

        Args:
            model_args (FluxModelArgs): Model configuration arguments.

        Returns:
            FluxModel: FluxModel model.

        """
        return cls(model_args)
torchtitan/experiments/flux/parallelize_flux.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # This file applies the PT-D parallelisms (except pipeline parallelism) and various
8
+ # training techniques (e.g. activation checkpointing and compile) to the Llama model.
9
+
10
+
11
+ import torch.nn as nn
12
+
13
+ from torch.distributed.device_mesh import DeviceMesh
14
+
15
+ from torchtitan.config_manager import JobConfig
16
+ from torchtitan.distributed import ParallelDims
17
+
18
+
19
def parallelize_flux(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """Apply parallelism strategies to the Flux model.

    Currently a no-op placeholder: the model is returned unchanged until a
    model-parallel strategy is implemented.
    """
    # TODO: Add model parallel strategy here
    return model
torchtitan/experiments/flux/tests/test_generate_image.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import os
9
+ import time
10
+ from typing import Callable
11
+
12
+ import torch
13
+ from einops import rearrange
14
+
15
+ from PIL import ExifTags, Image
16
+
17
+ from torch import Tensor
18
+
19
+ from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer
20
+
21
+ from torchtitan.experiments.flux.model.autoencoder import (
22
+ AutoEncoder,
23
+ AutoEncoderParams,
24
+ load_ae,
25
+ )
26
+ from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder
27
+
28
+ from torchtitan.experiments.flux.model.model import FluxModel, FluxModelArgs
29
+ from torchtitan.experiments.flux.utils import (
30
+ create_position_encoding_for_latents,
31
+ generate_noise_latent,
32
+ pack_latents,
33
+ preprocess_flux_data,
34
+ unpack_latents,
35
+ )
36
+
37
+
38
+ def time_shift(mu: float, sigma: float, t: Tensor):
39
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
40
+
41
+
42
+ def get_lin_function(
43
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
44
+ ) -> Callable[[float], float]:
45
+ m = (y2 - y1) / (x2 - x1)
46
+ b = y1 - m * x1
47
+ return lambda x: m * x + b
48
+
49
+
50
def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """Build a denoising schedule of num_steps+1 timesteps from 1.0 down to 0.0.

    When ``shift`` is True the schedule is warped toward high timesteps for
    longer image sequences (higher-signal images).
    """
    # one extra entry so the schedule ends exactly at zero
    schedule = torch.linspace(1, 0, num_steps + 1)

    if shift:
        # estimate mu by linear interpolation between two reference seq lengths
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        schedule = time_shift(mu, 1.0, schedule)

    return schedule.tolist()
67
+
68
+
69
class TestGenerateImage:
    """End-to-end smoke test: denoise from noise and save the decoded image."""

    def test_generate_image(self):
        """
        Run a forward pass of flux model to generate an image.
        """
        # generation hyperparameters (hard-coded for the test)
        name = "flux-dev"
        img_width = 512
        img_height = 512
        seed = None
        prompt = (
            "a photo of a forest with mist swirling around the tree trunks. The word "
            '"FLUX" is painted over it in big, red brush strokes with visible texture'
        )
        device = "cuda"
        num_steps = None
        loop = False
        guidance = 3.5
        output_dir = "output"
        add_sampling_metadata = True

        # '|' separates a primary prompt from optional additional prompts
        prompt = prompt.split("|")
        if len(prompt) == 1:
            prompt = prompt[0]
            additional_prompts = None
        else:
            additional_prompts = prompt[1:]
            prompt = prompt[0]

        assert not (
            (additional_prompts is not None) and loop
        ), "Do not provide additional prompts and set loop to True"

        torch_device = torch.device(device)
        if num_steps is None:
            num_steps = 30

        # allow for packing and conversion to latent space
        img_height = 16 * (img_height // 16)
        img_width = 16 * (img_width // 16)

        # init all components
        model = FluxModel(FluxModelArgs()).to(device=torch_device, dtype=torch.bfloat16)

        ae = load_ae(
            ckpt_path="assets/autoencoder/ae.safetensors",
            autoencoder_params=AutoEncoderParams(),
            device=torch_device,
            dtype=torch.bfloat16,
        )
        clip_tokenizer = FluxTokenizer(
            model_path="openai/clip-vit-large-patch14", max_length=77
        )
        t5_tokenizer = FluxTokenizer(model_path="google/t5-v1_1-small", max_length=512)
        clip_encoder = FluxEmbedder(version="openai/clip-vit-large-patch14").to(
            torch_device, dtype=torch.bfloat16
        )
        t5_encoder = FluxEmbedder(version="google/t5-v1_1-small").to(
            torch_device, dtype=torch.bfloat16
        )

        rng = torch.Generator(device="cpu")

        # derive a random seed when none was provided so the output name is unique
        if seed is None:
            seed = rng.seed()
        print(f"Generating with seed {seed}:\n{prompt}")
        t0 = time.perf_counter()
        output_name = os.path.join(output_dir, f"img_{seed}.jpg")

        # Tokenize the prompt, on CPU
        clip_tokens = clip_tokenizer.encode(prompt)
        t5_tokens = t5_tokenizer.encode(prompt)

        # encode the tokens to embeddings (autoencoder not needed for text)
        batch = preprocess_flux_data(
            device=torch_device,
            dtype=torch.bfloat16,
            autoencoder=None,
            clip_encoder=clip_encoder,
            t5_encoder=t5_encoder,
            batch={
                "clip_tokens": clip_tokens,
                "t5_tokens": t5_tokens,
            },
        )

        img = self._generate_images(
            device=torch_device,
            dtype=torch.bfloat16,
            model=model,
            decoder=ae,
            img_width=img_width,
            img_height=img_height,
            denoising_steps=num_steps,
            seed=seed,
            clip_encodings=batch["clip_encodings"],
            t5_encodings=batch["t5_encodings"],
            guidance=guidance,
        )

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        print(f"Done in {t1 - t0:.1f}s.")

        self._save_image(name, output_name, img, add_sampling_metadata, prompt)

    def _generate_images(
        self,
        device: torch.device,
        dtype: torch.dtype,
        model: FluxModel,
        decoder: AutoEncoder,
        # image params:
        img_width: int,
        img_height: int,
        # sampling params:
        denoising_steps: int,
        seed: int,
        clip_encodings: torch.Tensor,
        t5_encodings: torch.Tensor,
        guidance: float = 4.0,
    ):
        """Run the Euler denoising loop and decode latents back to image space."""

        bsz = clip_encodings.shape[0]
        latents = generate_noise_latent(bsz, img_height, img_width, device, dtype, seed)
        _, latent_channels, latent_height, latent_width = latents.shape

        # create denoising schedule
        # NOTE(review): get_schedule's second argument is documented as the image
        # sequence length, but latent_channels is passed here — verify intent.
        timesteps = get_schedule(denoising_steps, latent_channels, shift=True)

        # create positional encodings
        POSITION_DIM = 3  # constant for Flux flow model
        latent_pos_enc = create_position_encoding_for_latents(
            bsz, latent_height, latent_width, POSITION_DIM
        ).to(latents)
        text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents)

        # convert img-like latents into sequences of patches
        latents = pack_latents(latents)

        # this is ignored for schnell
        guidance_vec = torch.full((bsz,), guidance, device=device, dtype=dtype)
        # Euler integration over adjacent timestep pairs
        for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
            t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device)
            pred = model(
                img=latents,
                img_ids=latent_pos_enc,
                txt=t5_encodings,
                txt_ids=text_pos_enc,
                y=clip_encodings,
                timesteps=t_vec,
                guidance=guidance_vec,
            )

            latents = latents + (t_prev - t_curr) * pred

        # convert sequences of patches into img-like latents
        latents = unpack_latents(latents, latent_height, latent_width)

        img = decoder.decode(latents)
        return img

    def _save_image(
        self,
        name: str,
        output_name: str,
        x: torch.Tensor,
        add_sampling_metadata: bool,
        prompt: str,
    ):
        """Clamp, rescale to [0, 255], attach EXIF metadata, and save as JPEG."""
        print(f"Saving {output_name}")
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = rearrange(x[0], "c h w -> h w c")

        # map [-1, 1] -> [0, 255]
        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

        exif_data = Image.Exif()
        exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
        exif_data[ExifTags.Base.Make] = "Black Forest Labs"
        exif_data[ExifTags.Base.Model] = name
        if add_sampling_metadata:
            exif_data[ExifTags.Base.ImageDescription] = prompt
        img.save(output_name, exif=exif_data, quality=95, subsampling=0)
torchtitan/experiments/kernels/triton_mg_group_gemm/simpleMoE.py ADDED
@@ -0,0 +1,885 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import logging
9
+ import math
10
+ import time
11
+
12
+ from typing import Dict, List, Tuple
13
+
14
+ # import numpy as np
15
+ import torch #
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import torch.optim as optim
19
+
20
+ # from torchao_pr.mg_grouped_gemm import mg_grouped_gemm
21
+
22
# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

# Try to import the optimized MG GEMM implementation
try:
    from torchao_pr.mg_grouped_gemm import (  # grouped_gemm_backward,
        grouped_gemm_forward,
    )

    has_mg_gemm = True
except ImportError:
    # Fallback: MixtureOfExperts will loop over experts manually.
    logging.warning("MG GEMM implementation not found. Will use manual looping only.")
    has_mg_gemm = False
37
+
38
+
39
class Router(nn.Module):
    """
    Router module that assigns tokens to experts.
    """

    def __init__(self, input_dim: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.top_k = top_k

        # Routing layer
        self.router = nn.Linear(input_dim, num_experts)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """
        Route input tokens to experts.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim)

        Returns:
            Tuple containing:
                - router_logits: Raw routing probabilities
                - dispatch_tensor: tensor of routing weights, nonzero only at each
                  token's top-k experts (normalized to sum to 1 per token)
                - expert_indices: per-expert index tensors of the tokens routed to it
        """
        batch_size, seq_len, _ = x.shape

        # Flatten batch and sequence dimensions
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Compute routing probabilities
        router_logits = self.router(x_flat)  # (batch_size * seq_len, num_experts)

        # Apply softmax to get probabilities
        router_probs = F.softmax(router_logits, dim=-1)

        # Get top-k experts for each token
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # Normalize top-k probabilities
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Create dispatch tensor (one-hot representation of assignments)
        dispatch_tensor = torch.zeros_like(router_probs)
        # NOTE(review): token_indices is computed but never used below
        token_indices = (
            torch.arange(router_probs.size(0), device=router_probs.device)
            .unsqueeze(1)
            .expand(-1, self.top_k)
        )
        dispatch_tensor.scatter_(1, top_k_indices, top_k_probs)  # .unsqueeze(-1))

        # For each expert, get the indices of tokens routed to it
        expert_indices = []
        for expert_idx in range(self.num_experts):
            # Get indices of tokens that have non-zero probability for this expert
            indices = torch.nonzero(dispatch_tensor[:, expert_idx] > 0, as_tuple=True)[
                0
            ]
            expert_indices.append(indices)

        return router_logits, dispatch_tensor, expert_indices
102
+
103
+
104
class Expert(nn.Module):
    """A single feed-forward expert: Linear -> GELU -> Linear, no biases."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up to the hidden width, apply GELU, then project back out."""
        return self.fc2(self.activation(self.fc1(x)))
120
+
121
+
122
class MixtureOfExperts(nn.Module):
    """
    Mixture of Experts layer with support for both manual looping and grouped GEMM.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.top_k = top_k
        # grouped GEMM only if requested AND the kernel imported successfully
        self.use_mg_gemm = use_mg_gemm and has_mg_gemm

        # Router
        self.router = Router(input_dim, num_experts, top_k)

        # Create expert modules
        if self.use_mg_gemm:
            # For MG GEMM, we need a single weight tensor for all experts
            # First layer (input -> hidden); scaled ~ fan-in init
            self.expert_fc1_weight = nn.Parameter(
                torch.randn(num_experts * hidden_dim, input_dim) / math.sqrt(input_dim)
            )
            # self.expert_fc1_bias = nn.Parameter(torch.zeros(num_experts * hidden_dim))

            # Second layer (hidden -> output)
            self.expert_fc2_weight = nn.Parameter(
                torch.randn(num_experts * output_dim, hidden_dim)
                / math.sqrt(hidden_dim)
            )
            # self.expert_fc2_bias = nn.Parameter(torch.zeros(num_experts * output_dim))
        else:
            # For manual looping, create separate experts
            self.experts = nn.ModuleList(
                [Expert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
            )

    def forward_manual_loop(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass using manual looping over experts.

        Returns (output, router_logits) where output has shape
        (batch_size, seq_len, output_dim).
        """
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Initialize output tensor
        final_output = torch.zeros(
            batch_size * seq_len, self.output_dim, device=x.device
        )

        # Process each expert
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                # Get tokens routed to this expert
                expert_inputs = x_flat[indices]  # (num_tokens_for_expert, input_dim)

                # Process tokens through expert
                expert_outputs = self.experts[expert_idx](
                    expert_inputs
                )  # (num_tokens_for_expert, output_dim)

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Add to final output (tokens may be hit by several experts)
                final_output.index_add_(0, indices, scaled_outputs)

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward_mg_gemm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass using the grouped (MG) GEMM kernel.

        Tokens are gathered into contiguous per-expert groups, both expert
        layers are applied via grouped GEMMs, and the results are scattered
        back to their original token positions.

        NOTE(review): the debug print() calls here look like leftovers from
        development; consider routing them through logging.
        """
        batch_size, seq_len, _ = x.shape
        x_flat = x.reshape(-1, self.input_dim)  # (batch_size * seq_len, input_dim)
        total_tokens = batch_size * seq_len

        # Get routing information
        router_logits, dispatch_tensor, expert_indices = self.router(x)

        # Get token counts for each expert (group sizes for the grouped GEMM)
        token_counts = [indices.numel() for indices in expert_indices]
        m_sizes = torch.tensor(token_counts, dtype=torch.int32, device=x.device)

        print(f"Token counts per expert: {token_counts}")
        print(f"m_sizes: {m_sizes}")

        # Create the combined input tensor: tokens grouped expert-by-expert
        combined_input = torch.zeros(sum(token_counts), self.input_dim, device=x.device)

        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                combined_input[start_idx:end_idx] = x_flat[indices]
                start_idx = end_idx

        print(f"combined_input shape: {combined_input.shape}")

        # First layer: input -> hidden
        fc1_weight_reshaped = self.expert_fc1_weight.reshape(
            self.num_experts, self.hidden_dim, self.input_dim
        )
        fc1_weight_combined = fc1_weight_reshaped.reshape(-1, self.input_dim)

        print(f"fc1_weight_combined shape: {fc1_weight_combined.shape}")

        # Run the grouped GEMM
        hidden_outputs = grouped_gemm_forward(
            combined_input, fc1_weight_combined, m_sizes
        )

        print(f"hidden_outputs shape after first GEMM: {hidden_outputs.shape}")

        # Apply activation
        hidden_outputs = F.gelu(hidden_outputs)

        print(f"hidden_outputs shape after activation: {hidden_outputs.shape}")

        # Second layer: hidden -> output
        # Reshape hidden_outputs to match expected dimensions: slice out each
        # expert's hidden_dim columns for its own rows
        reshaped_hidden_outputs = []
        start_idx = 0

        for expert_idx, count in enumerate(token_counts):
            if count > 0:
                end_idx = start_idx + count
                # Take this expert's outputs and reshape to [count, hidden_dim]
                expert_output = hidden_outputs[
                    start_idx:end_idx,
                    expert_idx * self.hidden_dim : (expert_idx + 1) * self.hidden_dim,
                ]
                reshaped_hidden_outputs.append(expert_output)
                start_idx = end_idx

        # Concatenate all reshaped outputs (same per-expert row order as before)
        hidden_outputs = torch.cat(reshaped_hidden_outputs, dim=0)

        # Reshape expert weights for second layer
        fc2_weight_reshaped = self.expert_fc2_weight.reshape(
            self.num_experts, self.output_dim, self.hidden_dim
        )
        fc2_weight_combined = fc2_weight_reshaped.reshape(-1, self.hidden_dim)

        print(f"fc2_weight_combined shape: {fc2_weight_combined.shape}")

        # Run the second grouped GEMM (reuses m_sizes: group sizes unchanged)
        expert_outputs_combined = grouped_gemm_forward(
            hidden_outputs, fc2_weight_combined, m_sizes
        )

        # Initialize final output tensor with correct shape
        final_output = torch.zeros(total_tokens, self.output_dim, device=x.device)

        # Distribute the outputs back to the original token positions
        start_idx = 0
        for expert_idx, indices in enumerate(expert_indices):
            if indices.numel() > 0:
                end_idx = start_idx + indices.numel()
                # Get this expert's outputs
                expert_outputs = expert_outputs_combined[start_idx:end_idx]

                print(
                    f"Expert {expert_idx} - indices shape: {indices.shape}, expert_outputs shape: {expert_outputs.shape}"
                )

                # Scale outputs by router probabilities
                scaled_outputs = expert_outputs * dispatch_tensor[
                    indices, expert_idx
                ].unsqueeze(1)

                # Ensure dimensions match before using index_add_
                if scaled_outputs.shape[1] != final_output.shape[1]:
                    # print(
                    #     f"Reshaping: Dimension mismatch: scaled_outputs {scaled_outputs.shape}, final_output {final_output.shape}"
                    # )
                    # Reshape if needed - make sure output_dim is correct
                    scaled_outputs = scaled_outputs[:, : self.output_dim]

                # Add to final output
                final_output.index_add_(0, indices, scaled_outputs)

                start_idx = end_idx

        # Reshape back to original dimensions
        output = final_output.reshape(batch_size, seq_len, self.output_dim)

        return output, router_logits

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Dispatch to the grouped-GEMM path only when the kernel is available
        if self.use_mg_gemm and has_mg_gemm:
            return self.forward_mg_gemm(x)
        else:
            return self.forward_manual_loop(x)
329
+
330
+
331
class MoEModel(nn.Module):
    """
    Simple model using MoE layers: embedding -> MoE FFN -> vocab projection.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dim: int,
        num_experts: int,
        top_k: int = 2,
        use_mg_gemm: bool = False,
    ):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # MoE layer maps embed_dim -> embed_dim so it can feed the output head
        self.moe_layer = MixtureOfExperts(
            input_dim=embed_dim,
            hidden_dim=hidden_dim,
            output_dim=embed_dim,
            num_experts=num_experts,
            top_k=top_k,
            use_mg_gemm=use_mg_gemm,
        )
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (logits, router_logits) for input token ids x."""
        # x shape: (batch_size, seq_len)
        embedded = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        moe_output, router_logits = self.moe_layer(
            embedded
        )  # (batch_size, seq_len, embed_dim)
        logits = self.output_layer(moe_output)  # (batch_size, seq_len, vocab_size)
        return logits, router_logits
365
+
366
+
367
def compute_load_balancing_loss(
    router_logits: torch.Tensor, num_experts: int
) -> torch.Tensor:
    """
    Compute the load balancing loss for MoE training.

    Encourages uniform routing across experts: the loss equals 1.0 when the
    probability mass is spread evenly and grows (up to num_experts) as the
    routing collapses onto fewer experts.

    Args:
        router_logits (torch.Tensor): Router logits of shape (batch_size * seq_len, num_experts)
        num_experts (int): Number of experts

    Returns:
        torch.Tensor: Scalar load balancing loss (1.0 at perfectly uniform routing)
    """
    # Convert logits to routing probabilities
    router_probs = F.softmax(
        router_logits, dim=-1
    )  # (batch_size * seq_len, num_experts)

    # Fraction of total probability mass routed to each expert
    router_probs_sum = router_probs.sum(dim=0)  # (num_experts,)
    router_probs_sum = router_probs_sum / router_probs_sum.sum()

    # Squared-norm penalty, scaled so a uniform distribution yields exactly 1.0
    # (removed an unused `mean_prob` local from the original implementation)
    load_balancing_loss = num_experts * torch.sum(router_probs_sum * router_probs_sum)

    return load_balancing_loss
398
+
399
+
400
def generate_sample_data(
    batch_size: int, seq_len: int, vocab_size: int, device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create one random (inputs, targets) batch of token ids for training.

    Args:
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        device (str): Device to place the tensors on

    Returns:
        Tuple of input tokens and target tokens, each of shape
        (batch_size, seq_len) with values in [0, vocab_size).
    """
    shape = (batch_size, seq_len)
    # Inputs and targets are drawn independently; this is synthetic data used
    # only for benchmarking/smoke-training.
    inputs = torch.randint(0, vocab_size, shape, device=device)
    targets = torch.randint(0, vocab_size, shape, device=device)
    return inputs, targets
422
+
423
+
424
def train_epoch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
    load_balance_coef: float = 0.01,
) -> Dict[str, float]:
    """
    Train the model for one epoch on randomly generated batches.

    Args:
        model (nn.Module): Model to train
        optimizer (torch.optim.Optimizer): Optimizer
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches per epoch
        device (str): Device to use
        load_balance_coef (float): Coefficient for load balancing loss

    Returns:
        Dict with average "loss", average "acc" and wall-clock "time" (seconds).
    """
    model.train()
    loss_sum = 0.0
    acc_sum = 0.0
    epoch_start = time.time()

    for batch_idx in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)

        optimizer.zero_grad()
        logits, router_logits = model(inputs)

        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for cross entropy.
        logits_flat = logits.reshape(-1, vocab_size)
        targets_flat = targets.reshape(-1)

        ce_loss = F.cross_entropy(logits_flat, targets_flat)
        # Auxiliary loss that pushes the router toward uniform expert usage.
        lb_loss = compute_load_balancing_loss(
            router_logits, model.moe_layer.num_experts
        )
        loss = ce_loss + load_balance_coef * lb_loss

        loss.backward()
        optimizer.step()

        # Token-level accuracy on this batch.
        preds = logits_flat.argmax(dim=-1)
        correct = (preds == targets_flat).float().sum()
        acc = correct / (batch_size * seq_len)

        loss_sum += loss.item()
        acc_sum += acc.item()

        # Periodic progress log every 10 batches.
        if (batch_idx + 1) % 10 == 0:
            logging.info(
                f"Batch {batch_idx + 1}/{num_batches} | "
                f"Loss: {loss.item():.4f} | "
                f"CE Loss: {ce_loss.item():.4f} | "
                f"LB Loss: {lb_loss.item():.4f} | "
                f"Acc: {acc.item():.4f}"
            )

    return {
        "loss": loss_sum / num_batches,
        "acc": acc_sum / num_batches,
        "time": time.time() - epoch_start,
    }
508
+
509
+
510
def evaluate(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Evaluate the model on randomly generated batches (no gradient tracking).

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for evaluation
        device (str): Device to use

    Returns:
        Dict with average "loss" and average "acc".
    """
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0

    with torch.no_grad():
        for _ in range(num_batches):
            inputs, targets = generate_sample_data(
                batch_size, seq_len, vocab_size, device
            )

            logits, router_logits = model(inputs)

            # Flatten for cross entropy; the load-balancing loss is not
            # included in evaluation.
            logits_flat = logits.reshape(-1, vocab_size)
            targets_flat = targets.reshape(-1)

            loss = F.cross_entropy(logits_flat, targets_flat)

            # Token-level accuracy on this batch.
            preds = logits_flat.argmax(dim=-1)
            correct = (preds == targets_flat).float().sum()
            acc = correct / (batch_size * seq_len)

            loss_sum += loss.item()
            acc_sum += acc.item()

    return {"loss": loss_sum / num_batches, "acc": acc_sum / num_batches}
567
+
568
+
569
def measure_performance(
    model: nn.Module,
    batch_size: int,
    seq_len: int,
    vocab_size: int,
    num_batches: int,
    device: str,
) -> Dict[str, float]:
    """
    Measure forward and backward pass performance.

    Args:
        model (nn.Module): Model to evaluate
        batch_size (int): Batch size
        seq_len (int): Sequence length
        vocab_size (int): Vocabulary size
        num_batches (int): Number of batches for measurement
        device (str): Device to use

    Returns:
        Dict with "forward_time", "backward_time" and "total_time" in
        milliseconds per batch.  Note that "backward_time" covers a full
        forward+backward step (the forward is needed to build the graph).
    """
    model.train()

    # Dummy optimizer so backward()/zero_grad() mirror a real training step.
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    def _sync() -> None:
        # BUGFIX: the original called torch.cuda.synchronize() unconditionally,
        # which raises when the benchmark runs on CPU (a supported --device).
        # On CPU there is nothing to synchronize.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # Warmup: run a few full steps so lazy init / kernel autotuning settles.
    for _ in range(5):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    # Measure forward pass time (inference only, no autograd).
    _sync()
    forward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        with torch.no_grad():
            logits, router_logits = model(inputs)

    _sync()
    forward_end = time.time()
    forward_time = (forward_end - forward_start) / num_batches

    # Measure backward pass time (forward + backward + zero_grad per batch).
    _sync()
    backward_start = time.time()

    for _ in range(num_batches):
        inputs, targets = generate_sample_data(batch_size, seq_len, vocab_size, device)
        logits, router_logits = model(inputs)
        loss = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
        loss.backward()
        optimizer.zero_grad()

    _sync()
    backward_end = time.time()
    backward_time = (backward_end - backward_start) / num_batches

    return {
        "forward_time": forward_time * 1000,  # Convert to ms
        "backward_time": backward_time * 1000,  # Convert to ms
        "total_time": (forward_time + backward_time) * 1000,  # Convert to ms
    }
637
+
638
+
639
def compare_methods(args):
    """
    Compare manual looping and MG GEMM implementations.

    Builds one model per implementation (MG GEMM only when available), times
    each with measure_performance, and logs the results plus speedups.
    """
    device = torch.device(args.device)

    def _build_model(use_mg: bool):
        # Both models share the same configuration; only the expert-compute
        # path (manual loop vs MG GEMM) differs.
        return MoEModel(
            vocab_size=args.vocab_size,
            embed_dim=args.embed_dim,
            hidden_dim=args.hidden_dim,
            num_experts=args.num_experts,
            top_k=args.top_k,
            use_mg_gemm=use_mg,
        ).to(device)

    manual_model = _build_model(False)
    mg_model = _build_model(True) if has_mg_gemm else None

    def _measure(model):
        return measure_performance(
            model,
            args.batch_size,
            args.seq_len,
            args.vocab_size,
            args.perf_batches,
            device,
        )

    logging.info("Measuring performance of manual looping method...")
    manual_perf = _measure(manual_model)

    if mg_model is not None:
        logging.info("Measuring performance of MG GEMM method...")
        mg_perf = _measure(mg_model)
    else:
        # Zeroed placeholder keeps the speedup math below well-defined.
        mg_perf = {"forward_time": 0, "backward_time": 0, "total_time": 0}

    logging.info("\n===== Performance Comparison =====")
    logging.info("Model Configuration:")
    logging.info(f" - Batch Size: {args.batch_size}")
    logging.info(f" - Sequence Length: {args.seq_len}")
    logging.info(f" - Embed Dimension: {args.embed_dim}")
    logging.info(f" - Hidden Dimension: {args.hidden_dim}")
    logging.info(f" - Number of Experts: {args.num_experts}")
    logging.info(f" - Top-K: {args.top_k}")
    logging.info("")

    logging.info("Manual Looping Method:")
    logging.info(f" - Forward Time: {manual_perf['forward_time']:.2f} ms")
    logging.info(f" - Backward Time: {manual_perf['backward_time']:.2f} ms")
    logging.info(f" - Total Time: {manual_perf['total_time']:.2f} ms")
    logging.info("")

    if mg_model is not None:
        logging.info("MG GEMM Method:")
        logging.info(f" - Forward Time: {mg_perf['forward_time']:.2f} ms")
        logging.info(f" - Backward Time: {mg_perf['backward_time']:.2f} ms")
        logging.info(f" - Total Time: {mg_perf['total_time']:.2f} ms")
        logging.info("")

        def _speedup(key):
            # Guard against division by zero when a measurement is missing.
            return manual_perf[key] / mg_perf[key] if mg_perf[key] > 0 else 0

        forward_speedup = _speedup("forward_time")
        backward_speedup = _speedup("backward_time")
        total_speedup = _speedup("total_time")

        logging.info("Speedup (MG GEMM vs Manual):")
        logging.info(f" - Forward Speedup: {forward_speedup:.2f}x")
        logging.info(f" - Backward Speedup: {backward_speedup:.2f}x")
        logging.info(f" - Total Speedup: {total_speedup:.2f}x")
    else:
        logging.info("MG GEMM method not available.")
738
+
739
+
740
def train_model(args):
    """
    Train an MoE model.

    Builds the model from CLI args, then alternates train_epoch / evaluate
    for args.epochs epochs, logging metrics each epoch.
    """
    device = torch.device(args.device)

    # MG GEMM is only used when both requested and actually importable.
    use_mg = args.use_mg_gemm and has_mg_gemm

    model = MoEModel(
        vocab_size=args.vocab_size,
        embed_dim=args.embed_dim,
        hidden_dim=args.hidden_dim,
        num_experts=args.num_experts,
        top_k=args.top_k,
        use_mg_gemm=use_mg,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Log model information
    logging.info("Model configuration:")
    logging.info(f" - Vocabulary Size: {args.vocab_size}")
    logging.info(f" - Embedding Dimension: {args.embed_dim}")
    logging.info(f" - Hidden Dimension: {args.hidden_dim}")
    logging.info(f" - Number of Experts: {args.num_experts}")
    logging.info(f" - Top-K: {args.top_k}")
    logging.info(f" - Using MG GEMM: {use_mg}")

    for epoch in range(args.epochs):
        logging.info(f"\nEpoch {epoch + 1}/{args.epochs}")

        train_metrics = train_epoch(
            model=model,
            optimizer=optimizer,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            vocab_size=args.vocab_size,
            num_batches=args.train_batches,
            device=device,
            load_balance_coef=args.load_balance_coef,
        )

        eval_metrics = evaluate(
            model=model,
            batch_size=args.batch_size,
            seq_len=args.seq_len,
            vocab_size=args.vocab_size,
            num_batches=args.eval_batches,
            device=device,
        )

        logging.info(
            f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['acc']:.4f}"
        )
        logging.info(
            f"Eval Loss: {eval_metrics['loss']:.4f} | Eval Acc: {eval_metrics['acc']:.4f}"
        )
        logging.info(f"Epoch Time: {train_metrics['time']:.2f} seconds")
802
+
803
+
804
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train MoE model")

    # Model parameters
    parser.add_argument("--vocab_size", type=int, default=10000, help="Vocabulary size")
    parser.add_argument("--embed_dim", type=int, default=512, help="Embedding dimension")
    parser.add_argument(
        "--hidden_dim", type=int, default=1024, help="Hidden dimension in experts"
    )
    parser.add_argument("--num_experts", type=int, default=8, help="Number of experts")
    parser.add_argument("--top_k", type=int, default=2, help="Top-k experts to route to")

    # Training parameters
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--train_batches",
        type=int,
        default=100,
        help="Number of training batches per epoch",
    )
    parser.add_argument(
        "--eval_batches", type=int, default=20, help="Number of evaluation batches"
    )
    parser.add_argument(
        "--perf_batches",
        type=int,
        default=50,
        help="Number of batches for performance testing",
    )
    parser.add_argument(
        "--load_balance_coef",
        type=float,
        default=0.01,
        help="Load balancing loss coefficient",
    )

    # Runtime parameters
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use (cuda or cpu)",
    )
    parser.add_argument(
        "--use_mg_gemm",
        action="store_true",
        help="Use MG GEMM implementation if available",
    )
    parser.add_argument(
        "--compare",
        action="store_true",
        help="Compare manual and MG GEMM implementations",
    )
    parser.add_argument("--train", action="store_true", help="Train the model")

    args = parser.parse_args()

    # Fall back to CPU when CUDA was requested but is not actually present.
    if args.device == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA not available, using CPU instead.")
        args.device = "cpu"

    # Log basic environment information up front.
    logging.info(f"PyTorch version: {torch.__version__}")
    logging.info(f"Device: {args.device}")
    logging.info(f"MG GEMM available: {has_mg_gemm}")

    # Dispatch: --compare and --train are explicit; with neither flag the
    # comparison is run by default.
    if args.compare:
        compare_methods(args)
    elif args.train:
        train_model(args)
    else:
        compare_methods(args)
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mg_grouped_gemm import grouped_gemm_forward
8
+ from .tma_autotuning import ALIGN_SIZE_M
9
+
10
+ __all__ = [
11
+ "grouped_gemm_forward",
12
+ "ALIGN_SIZE_M",
13
+ ]
torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/mg_grouped_gemm.py ADDED
@@ -0,0 +1,1304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # credit - flat index forward kernel is derived from FBGemm:
8
+ # https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gemm/triton_gemm
9
+
10
+ # pyre-unsafe
11
+ import functools
12
+ import logging
13
+
14
+ import os
15
+ import sys
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+
20
+ import triton
21
+ import triton.language as tl
22
+ from triton import Config as TConfig
23
+
24
+ from triton.runtime import driver # @manual
25
+
26
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
27
+
28
+ from tma_autotuning import (
29
+ ALIGN_SIZE_M,
30
+ _NV_CONFIGS,
31
+ CudaUtils,
32
+ early_config_prune,
33
+ TmaDescriptorHelper,
34
+ )
35
+
36
+
37
+ # Configure logging
38
+ logging.basicConfig(
39
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
40
+ )
41
+
42
+ # ============== Start Triton Kernels ===============
43
+
44
+
45
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_hopper(
    a_desc_ptr,  # TMA descriptor for input A (flat grouped rows)
    b_desc_ptr,  # TMA descriptor for weights B
    c_ptr,  # raw output pointer; per-group store descriptors are built on device
    workspace,  # scratch buffer: TMA_SIZE bytes per program for store descriptors
    m_sizes,  # [G] row counts per group
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    USE_EPILOGUE_SUBTILING: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel for Hopper.
    For simplicity, we always use TMA Load and TMA Store

    Persistent-kernel scheduling: tiles of all groups are numbered with one
    flat index, and program ``tbidx`` processes tiles tbidx, tbidx + NUM_SMS,
    tbidx + 2*NUM_SMS, ...  In this variant the weight/output N dimension is
    partitioned per group (n_size = N // G), i.e. group g owns columns
    [g * n_size, (g + 1) * n_size).
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty  # output dtype

    # Each program builds its own on-device TMA store descriptor in its
    # private TMA_SIZE-byte slice of the workspace.
    c_desc_ptr = workspace + (tbidx * TMA_SIZE)  # for TMA Store

    M_end = 0
    M_start = 0
    processed_tiles = 0
    # Size of individual weight matrix
    n_size = N // G
    n_start = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)  # rows belonging to group g
        M_end = M_start + m_size
        n_start = n_size * g  # this group's column offset into N

        if m_size > 0:
            # Process this group (empty groups contribute no tiles)

            # Acquire hold on c_desc_ptr for TMA Store: the descriptor views
            # the [m_size, n_size] output window of this group starting at
            # row M_start.
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * n_size,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim each tile of this group whose flat index maps to this
            # program, advancing by NUM_SMS per claimed tile.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # columnwise
                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                # fp32 accumulator regardless of the I/O dtype.
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)
                global_n_offset = (n_start + n_offset).to(tl.int32)

                # Reduce along K.  NOTE(review): assumes the TMA descriptors
                # handle any K tail; no explicit k mask here — confirm K is
                # padded/divisible at the call site.
                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [global_n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    accumulator += tl.dot(a, b.T)

                # Store using TMA

                # Store offsets are relative to the group window created
                # above, hence the group-local (not global) m_offset here.
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)

                if USE_EPILOGUE_SUBTILING:
                    # Split the accumulator into two half-width stores
                    # (epilogue subtiling) instead of one full-width store.
                    acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                    acc = tl.permute(acc, (0, 2, 1))
                    acc0, acc1 = tl.split(acc)
                    c0 = acc0.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c0, [m_offset, n_offset]
                    )
                    c1 = acc1.to(c_dtype)
                    tl._experimental_descriptor_store(
                        c_desc_ptr, c1, [m_offset, n_offset + BLOCK_SIZE_N // 2]
                    )
                else:
                    tl._experimental_descriptor_store(
                        c_desc_ptr,
                        accumulator.to(c_dtype),
                        [m_offset, n_offset],
                    )
                # move to next tile in group
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
+ processed_tiles += group_num_tiles
173
+
174
+
175
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_tma(
    a_desc_ptr,  # TMA descriptor for input A (flat grouped rows)
    b_desc_ptr,  # TMA descriptor for weights B
    c_ptr,  # raw output pointer; per-group store descriptors are built on device
    workspace,  # scratch: TMA_SIZE bytes per program for store descriptors
    m_sizes,  # [G] row counts per group
    a_scale_ptr,  # unused in this kernel body (reserved for scaled/FP8 paths)
    b_scale_ptr,  # unused in this kernel body (reserved for scaled/FP8 paths)
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    USE_FP8: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For simplicity, we always use TMA Load and TMA Store

    Unlike the Hopper variant above, every group multiplies against the full
    [N, K] weight block (n_size = N for each group); only the M rows are
    partitioned by group.  Scheduling is the same persistent flat-tile-index
    scheme (programs advance by NUM_SMS).
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty

    # Per-program slice of the workspace for the on-device store descriptor.
    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups
        # reset to new M offset
        M_start = M_end
        m_size = tl.load(m_sizes + g)  # rows belonging to group g
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group (empty groups contribute no tiles)
            n_size = N

            # TMA Store prep: descriptor views this group's [m_size, N]
            # output window starting at row M_start.
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=c_ptr + M_start * N,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
                global_size=[m_size, n_size],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            # Claim each tile of this group whose flat index maps to this
            # program, advancing by NUM_SMS per claimed tile.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                # fp32 accumulator regardless of the I/O dtype.
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                # Reduce along K.  NOTE(review): no explicit k mask — assumes
                # the TMA descriptors handle any K tail.
                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # input block [M,K]
                    a = tl._experimental_descriptor_load(
                        a_desc_ptr,
                        [m_offset, k_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_K],
                        c_dtype,
                    )
                    # weight block [N, K]
                    b = tl._experimental_descriptor_load(
                        b_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    accumulator += tl.dot(a, b.T)

                # Store using TMA
                # Store offsets are relative to the group window created
                # above, hence the group-local m_offset.
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, n_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
+ processed_tiles += group_num_tiles
290
+
291
+
292
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_forward_no_tma(
    a_ptr,  # input A, flat grouped rows [MG, K]
    b_ptr,  # weights B [N, K]
    c_ptr,  # output C [MG, N]
    workspace,  # unused in the non-TMA path (kept for signature parity)
    m_sizes,  # [G] row counts per group
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    Flat index style forward kernel.
    For bc and Ampere, we never use TMA Load and TMA Store

    Same persistent flat-tile-index scheduling as the TMA variants, but with
    plain pointer arithmetic loads/stores and explicit bounds masks on the M
    and N edges.  NOTE(review): the K loop carries no k mask, so K is assumed
    to be a multiple of BLOCK_SIZE_K — confirm at the call site.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = c_ptr.dtype.element_ty
    c_desc_ptr = None

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups; reset to new M offset.
        M_start = M_end
        m_size = tl.load(m_sizes + g)  # rows belonging to group g
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group (empty groups contribute no tiles)
            n_size = N

            # tiles for this group
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
            group_num_tiles = num_m_tiles * num_n_tiles

            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                tile_m_index = group_index % num_m_tiles
                tile_n_index = group_index // num_m_tiles

                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                n_offset = (tile_n_index * BLOCK_SIZE_N).to(tl.int32)

                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
                offs_k = tl.arange(0, BLOCK_SIZE_K)

                # Row-major addressing: A rows stride by K, B rows stride by K.
                a_ptrs = a_ptr + (M_start + offs_am[:, None]) * K + offs_k[None, :]
                b_ptrs = b_ptr + (offs_bn[:, None]) * K + offs_k[None, :]

                for k_offset in range(0, K, BLOCK_SIZE_K):
                    # Load with bounds checking.
                    # BUGFIX: pass other=0.0 — without it, masked-off lanes
                    # are undefined and would pollute the tl.dot accumulator.
                    a = tl.load(a_ptrs, mask=offs_am[:, None] < m_size, other=0.0)
                    b = tl.load(b_ptrs, mask=offs_bn[:, None] < n_size, other=0.0)

                    # Main matmul
                    accumulator += tl.dot(a, b.T)

                    # Update pointers for next block
                    a_ptrs += BLOCK_SIZE_K
                    b_ptrs += BLOCK_SIZE_K

                # Store without TMA
                offs_am = tile_m_index * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                offs_bn = tile_n_index * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

                c = accumulator.to(c_dtype)

                tl.store(
                    c_ptr
                    + (M_start + offs_am[:, None]) * N  # Row stride is N
                    + offs_bn[None, :],  # Column offset
                    c,
                    # BUGFIX: elementwise masks must be combined with `&`;
                    # Python `and` is not valid on Triton tensor values.
                    mask=(offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size),
                )
                # Move to the next tile
                tbidx += NUM_SMS
            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
+ processed_tiles += group_num_tiles
396
+
397
+
398
+ """
399
+ Backward pass for grouped GEMM with Triton, where grouping is M*G
400
+ We compute gradients with respect to both input (`grad_x`) and weights (`grad_w`).
401
+ """
402
+
403
+
404
# ---- dx flat linear indexed ----
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dx_tma(
    grad_output_desc_ptr,  # [MG, N]
    w_desc_ptr,  # [N, K]
    grad_input_ptr,  # output grad_x [MG, K]
    workspace,  # for TMA store
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
) -> None:
    """
    TMA-optimized kernel for computing gradients with respect to input (dx).
    For the forward pass Y = X @ W.T, the backward for input is:
        grad_X = grad_Y @ W

    This maps to [MG, N] @ [N, K] -> [MG, K]

    Key differences from forward:
    1. W is used directly and not transposed
    2. The reduction dimension is now N (not K)
    3. Output is [M, K] instead of [M, N]

    Scheduling is the same persistent flat-tile-index scheme as the forward
    kernels: programs claim tiles tbidx, tbidx + NUM_SMS, ... across groups.
    """
    tbidx = tl.program_id(0)  # thread block index

    c_dtype = grad_input_ptr.dtype.element_ty
    # Per-program slice of the workspace for the on-device store descriptor.
    c_desc_ptr = workspace + (tbidx * TMA_SIZE)

    M_end = 0
    processed_tiles = 0

    for g in range(G):
        # Move down along groups - same as forward
        M_start = M_end
        m_size = tl.load(m_sizes + g)  # rows belonging to group g
        M_end = M_start + m_size

        if m_size > 0:
            # Process this group (empty groups contribute no tiles)
            # tiles for this group - now producing [M, K] output
            num_m_tiles = tl.cdiv(m_size, BLOCK_SIZE_M)
            num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
            group_num_tiles = num_m_tiles * num_k_tiles

            # TMA Store prep for [M, K] output: descriptor views this group's
            # [m_size, K] window of grad_input starting at row M_start.
            tl.extra.cuda.experimental_device_tensormap_create2d(
                desc_ptr=c_desc_ptr,
                global_address=grad_input_ptr + M_start * K,
                load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
                global_size=[m_size, K],
                element_ty=c_dtype,
            )
            tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

            # Claim each tile of this group whose flat index maps to this
            # program, advancing by NUM_SMS per claimed tile.
            while tbidx >= processed_tiles and tbidx < (
                processed_tiles + group_num_tiles
            ):
                group_index = tbidx - processed_tiles

                # Different tiling scheme for [M, K] output
                tile_m_index = group_index % num_m_tiles
                tile_k_index = group_index // num_m_tiles

                # for grad_input block [M, K]; fp32 accumulator.
                accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)

                # Position in full matrix
                m_offset = (M_start + (tile_m_index * BLOCK_SIZE_M)).to(tl.int32)
                k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                # reduce along N dimension (instead of K in forward)
                # NOTE(review): no explicit n mask — assumes the TMA
                # descriptors handle any N tail.
                for n_offset in range(0, N, BLOCK_SIZE_N):
                    # grad_output block [M, N]
                    grad_output = tl._experimental_descriptor_load(
                        grad_output_desc_ptr,
                        [m_offset, n_offset],
                        [BLOCK_SIZE_M, BLOCK_SIZE_N],
                        c_dtype,
                    )

                    # weight block [N, K] - no transpose needed
                    w = tl._experimental_descriptor_load(
                        w_desc_ptr,
                        [n_offset, k_offset],
                        [BLOCK_SIZE_N, BLOCK_SIZE_K],
                        c_dtype,
                    )

                    # grad_x = grad_output @ w
                    # reducing along N dimension
                    accumulator += tl.dot(grad_output, w)

                # Store using TMA; offsets are relative to the group window
                # created above, hence the group-local m_offset.
                m_offset = (tile_m_index * BLOCK_SIZE_M).to(tl.int32)
                # k_offset = (tile_k_index * BLOCK_SIZE_K).to(tl.int32)

                tl._experimental_descriptor_store(
                    c_desc_ptr,
                    accumulator.to(c_dtype),
                    [m_offset, k_offset],
                )

                # Move to the next tile
                tbidx += NUM_SMS

            # Update the total tiles count for the next group
            processed_tiles += group_num_tiles
+ processed_tiles += group_num_tiles
528
+
529
+
530
+ # ---- dw flat linear indexed ----
531
+
532
+
533
@triton.autotune(
    configs=_NV_CONFIGS,
    key=["G", "M_BUCKET", "N", "K"],
    prune_configs_by={"early_config_prune": early_config_prune},
)
@triton.jit
def _kernel_mg_dw_tma(
    x_desc_ptr,  # input descriptor [M_total, K] (raw pointer if not USE_TMA_LOAD)
    grad_output_desc_ptr,  # grad_output descriptor [M_total, N] (raw pointer if not USE_TMA_LOAD)
    grad_weight_ptr,  # output grad_w [N, K]
    workspace,  # TMA store descriptor for grad_w (unused scratch if not USE_TMA_STORE)
    m_sizes,  # group sizes [G]
    # problem sizes
    G: tl.constexpr,
    M_BUCKET: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    # config
    NUM_SMS: tl.constexpr,
    USE_TMA_LOAD: tl.constexpr,
    USE_TMA_STORE: tl.constexpr,
    TMA_SIZE: tl.constexpr,
    # tiles
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,  # block size for the M reduction dimension
) -> None:
    """
    TMA-optimized kernel for computing gradients with respect to weights (dw).
    Uses a flat (strided by NUM_SMS) tile index structure similar to forward.

    For the forward pass Y = X @ W.T, the backward for weights is:
        grad_W = grad_Y.T @ X

    Where:
      - grad_Y is [MG, N]
      - X is [MG, K]
      - grad_W (the result) is [N, K]

    Each program accumulates a full [BLOCK_SIZE_N, BLOCK_SIZE_K] output tile
    by reducing over the entire M dimension of every group, so no atomics or
    cross-program reduction are needed.
    """
    # Get thread block index
    tbidx = tl.program_id(0)

    # Get output data type
    c_dtype = grad_weight_ptr.dtype.element_ty

    # Calculate number of output tiles
    num_n_tiles = tl.cdiv(N, BLOCK_SIZE_N)
    num_k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    total_output_tiles = num_n_tiles * num_k_tiles

    # Process tiles in a strided manner across SMs
    for tile_idx in range(tbidx, total_output_tiles, NUM_SMS):
        # Calculate tile indices (N-major)
        tile_n_idx = tile_idx % num_n_tiles
        tile_k_idx = tile_idx // num_n_tiles

        # Calculate global offsets of this output tile
        n_offset = tile_n_idx * BLOCK_SIZE_N
        k_offset = tile_k_idx * BLOCK_SIZE_K

        # Initialize accumulator for this output tile [N, K]
        accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)

        # Reduce over every group's rows; grad_W sums contributions from
        # all tokens regardless of group membership.
        M_end = 0
        for g in range(G):
            # Get group boundaries
            M_start = M_end
            m_size = tl.load(m_sizes + g)
            M_end = M_start + m_size

            # Only process if group is non-empty
            if m_size > 0:
                # Process this group in chunks along the M dimension
                for m_offset in range(0, m_size, BLOCK_SIZE_M):
                    # Actual valid rows in this chunk (handles the boundary)
                    m_block_size = tl.minimum(BLOCK_SIZE_M, m_size - m_offset)

                    # Only process if we have actual work to do
                    if m_block_size > 0:
                        # Global row offset for this chunk
                        m_global_offset = M_start + m_offset

                        if USE_TMA_LOAD:
                            # Load input chunk [M_chunk, K] using TMA
                            x_block = tl._experimental_descriptor_load(
                                x_desc_ptr,
                                [m_global_offset, k_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_K],
                                c_dtype,
                            )

                            # Load grad_output chunk [M_chunk, N] using TMA
                            grad_output_block = tl._experimental_descriptor_load(
                                grad_output_desc_ptr,
                                [m_global_offset, n_offset],
                                [BLOCK_SIZE_M, BLOCK_SIZE_N],
                                c_dtype,
                            )

                            # TMA loads a full BLOCK_SIZE_M rows; mask out the
                            # rows past the group boundary so they do not
                            # contribute to the reduction.
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            m_mask = offs_m < m_block_size

                            # Zero out invalid elements
                            x_block = tl.where(m_mask[:, None], x_block, 0.0)
                            grad_output_block = tl.where(
                                m_mask[:, None], grad_output_block, 0.0
                            )
                        else:
                            # Manual (non-TMA) load with bounds checking;
                            # here x_desc_ptr / grad_output_desc_ptr are raw
                            # row-major data pointers.
                            offs_m = tl.arange(0, BLOCK_SIZE_M)
                            offs_n = tl.arange(0, BLOCK_SIZE_N)
                            offs_k = tl.arange(0, BLOCK_SIZE_K)

                            # Create masks
                            m_mask = offs_m < m_block_size
                            n_mask = offs_n < N - n_offset
                            k_mask = offs_k < K - k_offset

                            # Combined masks
                            mk_mask = m_mask[:, None] & k_mask[None, :]
                            mn_mask = m_mask[:, None] & n_mask[None, :]

                            # Global row offsets for loading
                            m_global_offs = m_global_offset + offs_m

                            # Load x block [M_chunk, K]
                            x_block = tl.load(
                                x_desc_ptr
                                + m_global_offs[:, None] * K
                                + (k_offset + offs_k)[None, :],
                                mask=mk_mask,
                                other=0.0,
                            )

                            # Load grad_output block [M_chunk, N]
                            grad_output_block = tl.load(
                                grad_output_desc_ptr
                                + m_global_offs[:, None] * N
                                + (n_offset + offs_n)[None, :],
                                mask=mn_mask,
                                other=0.0,
                            )

                        # Partial contribution: grad_W += grad_Y.T @ X;
                        # transpose grad_output for the matmul and accumulate
                        # in fp32 for numerical stability.
                        contribution = tl.dot(
                            grad_output_block.to(tl.float32).T,  # [N, M_chunk]
                            x_block.to(tl.float32),  # [M_chunk, K]
                        )

                        # Accumulate
                        accumulator += contribution

        # Store the completed [N, K] tile
        if USE_TMA_STORE:
            # Store using TMA; `workspace` is the host-prepared grad_w
            # store descriptor in this mode.
            tl._experimental_descriptor_store(
                workspace,  # TMA store descriptor
                accumulator.to(c_dtype),
                [n_offset, k_offset],
            )
        else:
            # Manual store with bounds checking
            offs_n = tl.arange(0, BLOCK_SIZE_N)
            offs_k = tl.arange(0, BLOCK_SIZE_K)

            # Create masks for bounds checking
            n_mask = offs_n < N - n_offset
            k_mask = offs_k < K - k_offset
            output_mask = n_mask[:, None] & k_mask[None, :]

            # Store the result
            tl.store(
                grad_weight_ptr
                + (n_offset + offs_n)[:, None] * K
                + (k_offset + offs_k)[None, :],
                accumulator.to(c_dtype),
                mask=output_mask,
            )
716
+
717
+
718
+ # ======== End Triton kernels ========
719
+
720
+ # ======== Triton wrapper functions ========
721
+
722
+ # ----- main forward pass wrapper -----
723
+
724
+
725
def grouped_gemm_forward(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    M*G style grouped GEMM with TMA support.

    Args:
        x: Input tensor, shape [M_total, K], where M_total is the sum of
            all group sizes.
        w: Weight tensor, shape [N, K].
        m_sizes: Group sizes tensor, shape [G].
        tma_size: Size of each TMA descriptor in bytes.

    Returns:
        Output tensor produced by ``_kernel_mg_forward_hopper``.

    Raises:
        NotImplementedError: If the device does not support TMA.

    # Removed for now - FP8 support is triggered by passing x_scale and w_scale tensors.
    """
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Grouped GEMM without TMA is not supported yet")

    G = m_sizes.shape[0]

    assert x.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    # Total input size is [M_total, K] where M_total is the sum of all group sizes
    M_total, K = x.shape
    N = w.shape[0]  # N is the same for all groups

    assert K == w.shape[1], f"Input K ({K}) must match weight K ({w.shape[1]})"

    # Verify that all group sizes are multiples of ALIGN_SIZE_M
    # This check is commented out because it will involve a GPU-CPU sync
    # assert torch.remainder(m_sizes, ALIGN_SIZE_M).max() == 0, "Group sizes must be a multiple of ALIGN_SIZE_M"

    # NOTE(review): the output is allocated as [M_total, N // G], but the
    # comment above the original allocation said [M_total, N] and the
    # backward path (grouped_gemm_backward) treats grad_output as
    # [M_total, N] with N == w.shape[0]. Presumably the forward kernel
    # writes N // G columns per row (per-expert slice of a stacked weight);
    # confirm against _kernel_mg_forward_hopper before relying on shape N.
    y = torch.empty((M_total, N // G), device=x.device, dtype=x.dtype)

    if M_total == 0:
        return y

    NUM_SMS = CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True
    USE_EPILOGUE_SUBTILING = False

    # TMA descriptor helper; fall back to raw tensors when TMA loads are off
    desc_helper = None
    desc_x = x
    desc_w = w
    workspace = None

    if USE_TMA_LOAD:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)
        desc_helper.init_tma_descriptor("x")
        desc_helper.init_tma_descriptor("w")
        desc_x = desc_helper.get_tma_descriptor_kernel_param("x")
        desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    if USE_TMA_STORE:
        # One tma_size-byte descriptor slot per SM for on-device stores
        workspace = torch.empty(
            NUM_SMS * desc_helper.tma_size,
            device=x.device,
            dtype=torch.uint8,
        )

    def grid(META):
        # Descriptors are (re)filled here because the block sizes come from
        # the autotuner-selected config (META).
        if USE_TMA_LOAD:
            nonlocal desc_helper
            desc_helper.fill_2d_tma_descriptor(
                "x",
                x.data_ptr(),
                M_total,
                K,
                META["BLOCK_SIZE_M"],
                META["BLOCK_SIZE_K"],
                x.element_size(),
            )

            desc_helper.fill_2d_tma_descriptor(
                "w",
                w.data_ptr(),
                N,
                K,
                META["BLOCK_SIZE_N"],
                META["BLOCK_SIZE_K"],
                w.element_size(),
            )
        # Persistent kernel: one program per SM
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    _kernel_mg_forward_hopper[grid](
        desc_x,
        desc_w,
        y,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K,
        NUM_SMS,
        TMA_SIZE=tma_size,
        USE_EPILOGUE_SUBTILING=USE_EPILOGUE_SUBTILING,
    )

    return y
828
+
829
+
830
+ # ======== Improved Backward =============
831
def grouped_gemm_backward(
    grad_output: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Unified backward pass for grouped GEMM with M*G grouping.
    Uses optimized TMA-based implementations for both dx and dw when available.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        x: Input tensor from forward pass, shape [M_total, K]
        w: Weight tensor from forward pass, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        Tuple of gradients with respect to x and w: (grad_x, grad_w)

    Raises:
        ValueError: If tensor dimensions are inconsistent with each other
            or with the sum of ``m_sizes``.
    """
    # Lazy %-style args instead of f-strings: the message is only formatted
    # if INFO logging is actually enabled (this runs on every backward call).
    logging.info("Starting unified grouped_gemm_backward")

    # Query SM count once up front; it is reused by both dx and dw kernels.
    NUM_SMS = CudaUtils.get_num_sms()

    # Basic validation
    G = m_sizes.shape[0]
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape
    N_w, K_w = w.shape

    # Check dimensions
    if K_x != K_w:
        raise ValueError(f"K dimension mismatch: x has K={K_x}, w has K={K_w}")
    if M_total != M_grad:
        raise ValueError(
            f"M dimension mismatch: x has M={M_total}, grad_output has M={M_grad}"
        )

    # Check total M matches sum of group sizes (note: .item() syncs GPU->CPU)
    sum_m_sizes = m_sizes.sum().item()
    if M_total != sum_m_sizes:
        raise ValueError(
            f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"
        )

    # Make sure inputs are contiguous
    grad_output = grad_output.contiguous()
    x = x.contiguous()
    w = w.contiguous()
    m_sizes = m_sizes.contiguous()

    # Check TMA support. NOTE(review): the dx/dw helpers below always use
    # TMA internally; this flag only downgrades the request and logs.
    can_use_tma = use_tma and CudaUtils.verify_tma()
    if use_tma and not can_use_tma:
        logging.info("TMA requested but not supported on this device")
        use_tma = False

    # Compute grad_x using the flat-linear TMA implementation
    try:
        logging.info("Computing grad_x with flat linear kernel")

        grad_x = grouped_gemm_dx_tma(
            grad_output=grad_output,
            w=w,
            m_sizes=m_sizes,
            num_sms=NUM_SMS,
            tma_size=tma_size,
        )

    except Exception as e:
        logging.error("Error in grad_x computation: %s", e)
        raise

    # Compute grad_w using the flat-linear TMA implementation
    try:
        logging.info("Computing grad_w with flat linear kernel")

        grad_w = grouped_gemm_dw_tma(
            x, grad_output, m_sizes, num_sms=NUM_SMS, tma_size=tma_size
        )
    except Exception as e:
        logging.error("Error in grad_w computation: %s", e)
        raise

    return grad_x, grad_w
922
+
923
+
924
+ # ----- dx backward pass wrapper -----
925
+
926
+
927
def grouped_gemm_dx_tma(
    grad_output: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized backward pass wrapper for computing the gradient with respect
    to the input (dx), using TMA patterns similar to the forward pass.

    Args:
        grad_output: Gradient of output, shape [M_total, N]
        w: Weight tensor, shape [N, K]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: Number of SMs to launch (one persistent program per SM)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_x: Gradient with respect to x, shape [M_total, K]

    Raises:
        NotImplementedError: If the device does not support TMA.
    """
    # (Fix: the original carried a second, duplicated docstring here as a
    # dead bare-string statement; it has been removed.)
    if not CudaUtils.verify_tma():
        raise NotImplementedError("Optimized dx computation requires TMA support")

    G = m_sizes.shape[0]

    assert grad_output.is_contiguous()
    assert w.is_contiguous()
    assert m_sizes.is_contiguous()

    M_total, N_grad = grad_output.shape
    N_w, K = w.shape

    # Check dimensions
    assert N_grad == N_w, f"Grad_output N ({N_grad}) must match weight N ({N_w})"

    # Verify that the sum of m_sizes matches M_total (.item() syncs GPU->CPU)
    sum_m_sizes = m_sizes.sum().item()
    assert (
        M_total == sum_m_sizes
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_x) with shape [M_total, K]
    grad_x = torch.empty(
        (M_total, K), device=grad_output.device, dtype=grad_output.dtype
    )

    NUM_SMS = num_sms  # CudaUtils.get_num_sms()
    USE_TMA_LOAD = True
    USE_TMA_STORE = True

    # Set up TMA descriptors
    desc_helper = TmaDescriptorHelper(tma_size=tma_size)
    desc_helper.init_tma_descriptor("grad_output")
    desc_helper.init_tma_descriptor("w")
    desc_grad_output = desc_helper.get_tma_descriptor_kernel_param("grad_output")
    desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

    # Allocate workspace for the on-device TMA store descriptors (one
    # tma_size-byte slot per SM)
    workspace = torch.empty(
        NUM_SMS * desc_helper.tma_size,
        device=grad_output.device,
        dtype=torch.uint8,
    )

    def grid(META):
        # Fill TMA descriptors here because the block sizes come from the
        # autotuner-selected config (META)
        desc_helper.fill_2d_tma_descriptor(
            "grad_output",
            grad_output.data_ptr(),
            M_total,
            N_grad,
            META["BLOCK_SIZE_M"],
            META["BLOCK_SIZE_N"],
            grad_output.element_size(),
        )

        desc_helper.fill_2d_tma_descriptor(
            "w",
            w.data_ptr(),
            N_w,
            K,
            META["BLOCK_SIZE_N"],
            META["BLOCK_SIZE_K"],
            w.element_size(),
        )
        return (NUM_SMS,)

    M_BUCKET = triton.next_power_of_2(M_total)

    # Launch the flat linear kernel for computing grad_x
    _kernel_mg_dx_tma[grid](
        desc_grad_output,
        desc_w,
        grad_x,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N_grad,  # N dimension is now the reduction dimension
        K,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_x
1053
+
1054
+
1055
+ # ======== dw wrapper function ==========
1056
+
1057
+
1058
def grouped_gemm_dw_tma(
    x: torch.Tensor,
    grad_output: torch.Tensor,
    m_sizes: torch.Tensor,
    num_sms: int = 132,
    tma_size: int = 128,
) -> torch.Tensor:
    """
    Optimized flat linear kernel computation of gradients with respect to
    weights (dw) using TMA.

    For the forward pass Y = X @ W.T, the backward for weights is:
        grad_W = grad_Y.T @ X

    Args:
        x: Input tensor, shape [M_total, K]
        grad_output: Gradient of output, shape [M_total, N]
        m_sizes: Group sizes tensor, shape [G]
        num_sms: Number of SMs to launch (one persistent program per SM)
        tma_size: Size of TMA descriptor in bytes

    Returns:
        grad_w: Gradient with respect to weights, shape [N, K]
    """
    # Check TMA support.
    # NOTE(review): has_tma_support is computed but currently unused — the
    # USE_TMA_* flags below are hardcoded True (see TODO).
    has_tma_support = CudaUtils.verify_tma()

    # Get group count
    G = m_sizes.shape[0]

    # Ensure contiguous tensors
    x = x.contiguous()
    grad_output = grad_output.contiguous()
    m_sizes = m_sizes.contiguous()

    # Get dimensions
    M_total, K_x = x.shape
    M_grad, N = grad_output.shape

    # Check dimensions
    assert M_total == M_grad, f"x M ({M_total}) must match grad_output M ({M_grad})"

    # Verify that the sum of m_sizes matches M_total (.item() syncs GPU->CPU)
    sum_m_sizes = m_sizes.sum().item()
    assert (
        sum_m_sizes == M_total
    ), f"Sum of m_sizes ({sum_m_sizes}) must match M_total ({M_total})"

    # Create output tensor (grad_w) with shape [N, K]; zero-initialized since
    # the kernel accumulates over all groups into each tile
    grad_w = torch.zeros((N, K_x), device=x.device, dtype=x.dtype)

    NUM_SMS = num_sms

    # TODO - hardcoded for now...but should set TMA flags based on hardware support
    USE_TMA_LOAD = True  # has_tma_support
    USE_TMA_STORE = True  # has_tma_support

    # Set up TMA descriptors or direct pointers. When a TMA path is disabled
    # the kernel receives the raw tensor and does masked pointer arithmetic.
    if USE_TMA_LOAD or USE_TMA_STORE:
        desc_helper = TmaDescriptorHelper(tma_size=tma_size)

        if USE_TMA_LOAD:
            desc_helper.init_tma_descriptor("x")
            desc_helper.init_tma_descriptor("grad_output")
            x_desc = desc_helper.get_tma_descriptor_kernel_param("x")
            grad_output_desc = desc_helper.get_tma_descriptor_kernel_param(
                "grad_output"
            )
        else:
            x_desc = x
            grad_output_desc = grad_output

        if USE_TMA_STORE:
            # In TMA-store mode, `workspace` carries the grad_w store descriptor
            desc_helper.init_tma_descriptor("grad_w")
            workspace = desc_helper.get_tma_descriptor_kernel_param("grad_w")
        else:
            workspace = torch.empty(1, device=x.device, dtype=torch.uint8)
    else:
        # If not using TMA, just use the tensors directly
        x_desc = x
        grad_output_desc = grad_output
        workspace = torch.empty(1, device=x.device, dtype=torch.uint8)

    # M_BUCKET for the autotune key
    M_BUCKET = triton.next_power_of_2(M_total)

    # Define grid for kernel launch; descriptors are filled here because the
    # block sizes come from the autotuner-selected config (META)
    def grid(META):
        if USE_TMA_LOAD or USE_TMA_STORE:

            if USE_TMA_LOAD:
                desc_helper.fill_2d_tma_descriptor(
                    "x",
                    x.data_ptr(),
                    M_total,
                    K_x,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_K"],
                    x.element_size(),
                )

                desc_helper.fill_2d_tma_descriptor(
                    "grad_output",
                    grad_output.data_ptr(),
                    M_total,
                    N,
                    META["BLOCK_SIZE_M"],
                    META["BLOCK_SIZE_N"],
                    grad_output.element_size(),
                )

            if USE_TMA_STORE:
                desc_helper.fill_2d_tma_descriptor(
                    "grad_w",
                    grad_w.data_ptr(),
                    N,
                    K_x,
                    META["BLOCK_SIZE_N"],
                    META["BLOCK_SIZE_K"],
                    grad_w.element_size(),
                )

        # Return grid size - one block per SM for balanced work distribution
        return (NUM_SMS,)

    # Launch the optimized kernel
    _kernel_mg_dw_tma[grid](
        x_desc,
        grad_output_desc,
        grad_w,
        workspace,
        m_sizes,
        G,
        M_BUCKET,
        N,
        K_x,
        NUM_SMS,
        USE_TMA_LOAD,
        USE_TMA_STORE,
        TMA_SIZE=tma_size,
    )

    return grad_w
1199
+
1200
+
1201
+ # ======== End Backwards Wrapper Functions =============
1202
+
1203
+ # ======== PyTorch wrapper functions ========
1204
+
1205
+
1206
class GroupedGEMM_mg(torch.autograd.Function):
    """
    Autograd function for GroupedGEMM with M*G grouping.

    FP8 quantization is not implemented yet; ``using_fp8`` is accepted only
    for API compatibility with ``mg_grouped_gemm`` and must be False.
    """

    @staticmethod
    def forward(ctx, x, w, m_sizes, use_tma=True, tma_size=128, using_fp8=False):
        """
        Forward pass of GroupedGEMM.

        Args:
            ctx: Autograd context
            x: Input tensor, shape [M_total, K]
            w: Weight tensor, shape [N, K]
            m_sizes: Tensor of shape [G] containing the size of each group
            use_tma: Whether to try using TMA acceleration (if available)
            tma_size: Size of TMA descriptor in bytes
            using_fp8: Accepted for API compatibility; FP8 is not supported.

        Returns:
            Output tensor, shape [M_total, N]
        """
        if using_fp8:
            raise NotImplementedError("FP8 grouped GEMM is not supported yet")

        # Fix: the original passed using_fp8=False here, but
        # grouped_gemm_forward has no such parameter (TypeError at runtime).
        output = grouped_gemm_forward(x=x, w=w, m_sizes=m_sizes, tma_size=tma_size)

        # Save inputs and parameters for backward pass (only once; the
        # original called save_for_backward twice)
        ctx.save_for_backward(x, w, m_sizes)
        ctx.use_tma = use_tma
        ctx.tma_size = tma_size

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of M*G GroupedGEMM.

        Args:
            ctx: Autograd context
            grad_output: Gradient of output, shape [M_total, N]

        Returns:
            Tuple of gradients, one entry per forward input:
            - grad_x: Gradient with respect to x, shape [M_total, K]
            - grad_w: Gradient with respect to w, shape [N, K]
            - None for m_sizes, use_tma, tma_size, using_fp8
              (not differentiable)
        """
        # Retrieve saved tensors and parameters
        x, w, m_sizes = ctx.saved_tensors

        use_tma = ctx.use_tma
        tma_size = ctx.tma_size

        # Compute gradients using the unified implementation
        grad_x, grad_w = grouped_gemm_backward(
            grad_output=grad_output,
            x=x,
            w=w,
            m_sizes=m_sizes,
            use_tma=use_tma,
            tma_size=tma_size,
        )

        # Fix: backward must return one gradient per forward input
        # (the original returned only 4 values for a >4-input forward).
        return grad_x, grad_w, None, None, None, None
1279
+
1280
+
1281
def mg_grouped_gemm(
    x: torch.Tensor,
    w: torch.Tensor,
    m_sizes: torch.Tensor,
    use_tma: bool = True,
    tma_size: int = 128,
    using_fp8: bool = False,
) -> torch.Tensor:
    """
    Unified differentiable grouped GEMM operation for M*G grouped GEMM.

    Args:
        x: Input tensor, shape [M_total, K]
        w: Weight tensor, shape [N, K]
        m_sizes: Tensor of shape [G] containing the size of each group
        use_tma: Whether to try using TMA acceleration (if available)
        tma_size: Size of TMA descriptor in bytes
        using_fp8: Accepted for API compatibility; FP8 quantization is not
            implemented and this must be False.

    Returns:
        Output tensor, shape [M_total, N]

    Raises:
        NotImplementedError: If ``using_fp8`` is True.
    """
    # FP8 is not wired up yet; fail loudly instead of silently ignoring it.
    if using_fp8:
        raise NotImplementedError("FP8 grouped GEMM is not supported yet")
    # Fix: the original forwarded using_fp8 as a 6th positional argument,
    # which GroupedGEMM_mg.forward did not accept (TypeError at runtime).
    return GroupedGEMM_mg.apply(x, w, m_sizes, use_tma, tma_size)
torchtitan/experiments/llama4/__init__.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from torchtitan.components.loss import build_cross_entropy_loss
8
+ from torchtitan.components.lr_scheduler import build_lr_schedulers
9
+ from torchtitan.components.optimizer import build_optimizers
10
+ from torchtitan.datasets.hf_datasets import build_hf_dataloader
11
+ from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
12
+ from torchtitan.models.llama3 import pipeline_llama
13
+ from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
14
+
15
+ from .infra.parallelize_llama import parallelize_llama
16
+ from .model.args import TransformerModelArgs
17
+ from .model.model import Transformer
18
+
19
# Public API of the llama4 experiment package.
__all__ = [
    "TransformerModelArgs",
    "Transformer",
    "llama4_configs",
]


# Named Llama 4 model configurations, selectable by flavor name.
llama4_configs = {
    # Tiny configuration for smoke tests / debugging.
    "debugmodel": TransformerModelArgs(
        dim=256,
        n_layers=8,
        n_heads=16,
        rope_theta=500000,
    ),
    # 17B-activated-parameter MoE with 16 experts.
    "17bx16e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=16,
        interleave_moe_layer_step=1,
    ),
    # 17B-activated-parameter MoE with 128 experts.
    # NOTE(review): unlike "17bx16e", no interleave_moe_layer_step is set
    # here — presumably the TransformerModelArgs default applies; confirm.
    "17bx128e": TransformerModelArgs(
        dim=5120,
        n_layers=48,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        num_experts=128,
    ),
}


# Register the "llama4" train spec so the trainer can build this model and
# its optimizer/dataloader/tokenizer/loss components by name.
register_train_spec(
    TrainSpec(
        name="llama4",
        cls=Transformer,
        config=llama4_configs,
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_tiktoken_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)
torchtitan/experiments/llama4/infra/parallelize_llama.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.distributed.device_mesh import DeviceMesh
11
+
12
+ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
13
+ from torchtitan.distributed import ParallelDims
14
+
15
+ from torchtitan.models.llama3.parallelize_llama import (
16
+ apply_ac,
17
+ apply_compile,
18
+ apply_ddp,
19
+ apply_fsdp,
20
+ apply_tp,
21
+ )
22
+ from torchtitan.tools.logging import logger
23
+
24
+
25
def parallelize_llama(
    model: nn.Module,
    world_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism to the model.

    The application order matters: TP first, then AC wrapping, then
    per-block compile, then FSDP/HSDP (or DDP) outermost.

    NOTE: The passed-in model preferably should be on meta device. Otherwise,
    the model must fit on GPU or CPU memory.
    """

    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=parallel_dims.loss_parallel_enabled,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

        # MoE layers need their own TP plan on top of the dense-layer plan.
        apply_moe_tp(model, world_mesh["tp"])

    if job_config.activation_checkpoint.mode != "none":
        if (
            job_config.activation_checkpoint.mode == "selective"
            and job_config.model.use_flex_attn
        ):
            raise ValueError(
                "FlexAttention is not compatible with selective AC yet. "
                "See https://github.com/pytorch/pytorch/issues/147879"
            )
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

        # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
        torch._dynamo.config.capture_scalar_outputs = True

    if (
        parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled
    ):  # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            # HSDP: replicate across one mesh dim, shard across the other
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model
124
+
125
+
126
def apply_moe_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
):
    """Attach a tensor-parallel plan to the MoE sub-modules of every
    transformer block in ``model`` over ``tp_mesh``."""
    # Imports are kept local so this module can be loaded without pulling
    # in the TP machinery.
    from torch.distributed.tensor import Partial, Replicate, Shard
    from torch.distributed.tensor.parallel import (
        parallelize_module,
        PrepareModuleInputOutput,
    )

    from .expert_parallel import NoParallel, TensorParallel

    for block in model.layers.values():
        # Shard activations on the sequence dim at the MoE boundary:
        # all-gather on the way in, reduce-scatter on the way out.
        moe_io = PrepareModuleInputOutput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
            use_local_input=True,
            output_layouts=(Partial(),),
            desired_output_layouts=(Shard(1),),
        )
        plan = {
            "moe": moe_io,
            # Router computation is replicated on every rank.
            "moe.router.gate": NoParallel(),
            # Experts take Replicate input and produce Partial output.
            "moe.experts": TensorParallel(),
            "moe.shared_expert": TensorParallel(),
        }
        parallelize_module(
            module=block,
            device_mesh=tp_mesh,
            parallelize_plan=plan,
        )
torchtitan/experiments/llama4/model/moe.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from .args import TransformerModelArgs
12
+
13
+
14
+ class GroupedExperts(nn.Module):
15
+ def __init__(
16
+ self,
17
+ dim: int,
18
+ hidden_dim: int,
19
+ num_experts: int,
20
+ ):
21
+ super().__init__()
22
+ self.num_experts = num_experts
23
+ self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
24
+ self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
25
+ self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
26
+
27
+ def forward(
28
+ self,
29
+ x: torch.Tensor,
30
+ num_local_tokens_per_expert: torch.Tensor | None = None,
31
+ ) -> torch.Tensor:
32
+ if num_local_tokens_per_expert is not None:
33
+ # a tuple of tensors indexed by experts
34
+ # each with shape (tokens_per_expert(varying), dim)
35
+ x = torch.split(
36
+ x,
37
+ split_size_or_sections=num_local_tokens_per_expert.tolist(),
38
+ dim=0,
39
+ )
40
+ out_experts_splits = []
41
+ for expert_idx, x_expert in enumerate(x):
42
+ w1, w2, w3 = (
43
+ self.w1[expert_idx],
44
+ self.w2[expert_idx],
45
+ self.w3[expert_idx],
46
+ )
47
+ h = F.silu(torch.matmul(x_expert, w1))
48
+ h = h * torch.matmul(x_expert, w3)
49
+ h = torch.matmul(h, w2)
50
+ # h shape (tokens_per_expert(varying), dim)
51
+ out_experts_splits.append(h)
52
+ out = torch.cat(out_experts_splits, dim=0)
53
+
54
+ # TODO:optimize with GroupedGEMM
55
+ # https://github.com/pytorch/pytorch/pull/150374
56
+ # _gouped_mm requires shapes to be multiple of 8
57
+ # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32)
58
+ # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16))
59
+ # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
60
+ # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
61
+ else:
62
+ # x shape (num_experts, tokens_per_expert, dim)
63
+ h = F.silu(torch.bmm(x, self.w1))
64
+ h = h * torch.bmm(x, self.w3)
65
+ # out shape (num_experts, tokens_per_expert, dim)
66
+ out = torch.bmm(h, self.w2)
67
+ return out
68
+
69
+ def init_weights(self, init_std: float):
70
+ nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
71
+ nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
72
+ nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
73
+
74
+
75
class TokenChoiceTopKRouter(nn.Module):
    """This class implements token-choice routing. In token-choice top-K routing, each token is
    routed to top K experts based on the router scores.

    Args:
        dim (int): Dimension of input tokens.
        num_experts (int): Number of experts in each moe layer.
        top_k (int): Number of experts each token will be routed to in token-choice routing.
        use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False.
    """

    def __init__(
        self,
        dim: int,
        num_experts: int,
        top_k: int,
        use_sigmoid: bool = False,
    ):
        super().__init__()
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.num_experts = num_experts
        self.top_k = top_k
        self.use_sigmoid = use_sigmoid

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``.

        Returns:
            top_scores (torch.Tensor):
                Router scores of the routed slots, grouped by expert index,
                with shape ``(bs*slen*top_k,)``.
            token_indices (torch.Tensor):
                Source token index for each routed slot, with shape ``(bs*slen*top_k,)``.
            num_local_tokens_per_expert (torch.Tensor):
                Number of tokens assigned to each expert, integer tensor with
                shape ``(num_experts,)``.
        """
        # scores shape (bs*slen, num_experts)
        scores = self.gate(x)

        # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
        if self.use_sigmoid:
            scores = torch.sigmoid(scores.to(torch.float32)).to(x.dtype)
        else:
            scores = F.softmax(scores.to(torch.float32), dim=1).to(x.dtype)

        # top scores shape (bs*slen, top_k)
        top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=1)
        # NOTE: optionally renormalize so each token's top-k scores sum to one:
        # top_scores /= top_scores.sum(dim=-1, keepdim=True).to(x.dtype)

        # Count tokens routed to each expert. bincount (rather than histc)
        # supports the int64 topk indices on both CPU and CUDA and returns
        # integer counts, which is what the downstream torch.split expects.
        num_local_tokens_per_expert = torch.bincount(
            selected_experts_indices.view(-1), minlength=self.num_experts
        )
        # Stable argsort groups the flattened (token, k) slots by expert index
        # while preserving token order within each expert.
        # token_indices_experts_sorted shape (bs*slen*top_k,)
        token_indices_experts_sorted = torch.argsort(
            selected_experts_indices.view(-1), stable=True
        )
        top_scores = top_scores.view(-1)[token_indices_experts_sorted]
        # Each token occupies top_k consecutive flattened slots; integer
        # division recovers the original token index.
        token_indices_experts_sorted = token_indices_experts_sorted // self.top_k

        return top_scores, token_indices_experts_sorted, num_local_tokens_per_expert

    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std)
146
+
147
+
148
+ # TODO: implement load balancing auxiliary loss for token-choice routing
149
class MoE(nn.Module):
    """Token-choice Mixture-of-Experts layer.

    Combines a top-K token-choice router, a bank of routed experts, and an
    optional always-active shared expert whose output every token receives.
    """

    def __init__(self, model_args: TransformerModelArgs):
        super().__init__()
        dim = model_args.dim
        # Start from the dense-FFN sizing rule: 4*dim scaled by 2/3, then the
        # optional multiplier (same recipe as the Llama FeedForward).
        hidden_dim = 4 * model_args.dim
        ffn_dim_multiplier = model_args.ffn_dim_multiplier
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)

        num_experts = model_args.num_experts

        hidden_dim_denom = 1
        if model_args.auto_scale_hidden_dim:
            # Shrink each expert so active params per token (top_k routed
            # experts plus the shared expert, if any) roughly match the dense
            # baseline.
            hidden_dim_denom = model_args.top_k + int(model_args.use_shared_expert)

        if model_args.auto_scale_hidden_dim:
            hidden_dim = int(hidden_dim / hidden_dim_denom)
            # Round hidden_dim up to the next multiple_of boundary.
            hidden_dim += -hidden_dim % model_args.multiple_of

        self.experts = GroupedExperts(
            dim=dim, hidden_dim=hidden_dim, num_experts=num_experts
        )
        self.router = TokenChoiceTopKRouter(
            dim=dim, num_experts=num_experts, top_k=model_args.top_k
        )
        # Optional shared expert modeled as a single-expert GroupedExperts.
        self.shared_expert = (
            GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1)
            if model_args.use_shared_expert
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``.

        Returns:
            out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
        """
        bs, slen, dim = x.shape
        # top_scores and selected_indices shape (bs*slen*top_k,)
        # num_local_tokens_per_expert shape (num_experts,)
        (
            top_scores,
            token_indices,
            num_local_tokens_per_expert,
        ) = self.router(x.reshape(bs * slen, dim))

        # shape (bs*slen*top_k, dim)
        token_indices = token_indices.reshape(-1, 1).expand(-1, dim)

        # shape (bs*slen*top_k, dim)
        # Gather each routed slot's source token; a token routed to top_k
        # experts appears top_k times.
        routed_input = torch.gather(
            x.view(-1, dim),
            dim=0,
            index=token_indices,
        )
        # Scale tokens by their router scores before the expert MLPs.
        routed_input = routed_input * top_scores.reshape(-1, 1)

        # shape (bs*slen*top_k, dim)
        routed_output = self.experts(routed_input, num_local_tokens_per_expert)

        # shared expert
        if self.shared_expert is not None:
            # The num_experts=1 shared expert expects (1, tokens, dim) input.
            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
                bs * slen, dim
            )
        else:
            out = torch.zeros_like(x.reshape(bs * slen, dim))

        # Scatter-add routed expert outputs back onto their source tokens.
        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
        out = out.reshape(bs, slen, dim)
        return out

    def init_weights(self, init_std: float):
        self.experts.init_weights(init_std)
        self.router.init_weights(init_std)
        if self.shared_expert is not None:
            self.shared_expert.init_weights(init_std)
torchtitan/experiments/llama4/train_configs/debug_model.toml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [job]
2
+ dump_folder = "./outputs"
3
+ description = "Llama 4 debug training"
4
+ print_args = false
5
+ use_for_integration_test = true
6
+
7
+ [profiling]
8
+ enable_profiling = false
9
+ save_traces_folder = "profile_trace"
10
+ profile_freq = 10
11
+ enable_memory_snapshot = false
12
+ save_memory_snapshot_folder = "memory_snapshot"
13
+
14
+ [metrics]
15
+ log_freq = 1
16
+ disable_color_printing = false
17
+ enable_tensorboard = false
18
+ save_tb_folder = "tb"
19
+ enable_wandb = false
20
+
21
+ [model]
22
+ name = "llama4"
23
+ flavor = "debugmodel"
24
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
25
+ # test tokenizer.model, for debug purpose only
26
+ tokenizer_path = "./tests/assets/test_tiktoken.model"
27
+ # converters = "float8"
28
+ use_flex_attn = false
29
+ attn_mask_type = "causal" # causal / block_causal
30
+
31
+ [optimizer]
32
+ name = "AdamW"
33
+ lr = 4e-3
34
+ eps = 1e-15
35
+
36
+ [lr_scheduler]
37
+ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
38
+ decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
39
+ decay_type = "linear"
40
+ lr_min = 0.1
41
+
42
+ [training]
43
+ batch_size = 8
44
+ seq_len = 2048
45
+ max_norm = 1.0 # grad norm clipping
46
+ steps = 10
47
+ compile = false
48
+ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
49
+
50
+ [parallelism]
51
+ data_parallel_replicate_degree = 1
52
+ data_parallel_shard_degree = -1
53
+ fsdp_reshard_after_forward = "default" # default / never / always
54
+ tensor_parallel_degree = 1
55
+ enable_async_tensor_parallel = false
56
+ pipeline_parallel_degree = 1
57
+ context_parallel_degree = 1
58
+
59
+ [checkpoint]
60
+ enable_checkpoint = false
61
+ folder = "checkpoint"
62
+ interval = 10
63
+ model_weights_only = false
64
+ export_dtype = "float32"
65
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
66
+
67
+ [activation_checkpoint]
68
+ mode = 'none' # ['none', 'selective', 'full']
69
+ selective_ac_option = '2' # a positive int N checkpoints every Nth layer; 'op' selects the op-based policy
70
+
71
+ [float8]
72
+ enable_fsdp_float8_all_gather = false
73
+ precompute_float8_dynamic_scale_for_fsdp = false
74
+ filter_fqns = "output,router.gate"
torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TODO: this toml config is still under development
2
+
3
+ [job]
4
+ dump_folder = "./outputs"
5
+ description = "Llama 4 Maverick 17Bx128E training"
6
+
7
+ [profiling]
8
+ enable_profiling = false
9
+ save_traces_folder = "profile_trace"
10
+ profile_freq = 100
11
+
12
+ [metrics]
13
+ log_freq = 10
14
+ enable_tensorboard = false
15
+ save_tb_folder = "tb"
16
+
17
+ [model]
18
+ name = "llama4"
19
+ flavor = "17bx128e"
20
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
21
+ tokenizer_path = "./assets/tokenizer/tokenizer.model"
22
+ # converters = "float8"
23
+
24
+ [optimizer]
25
+ name = "AdamW"
26
+ lr = 4e-3
27
+ eps = 1e-15
28
+
29
+ [lr_scheduler]
30
+ warmup_steps = 600
31
+ lr_min = 0.1
32
+
33
+ [training]
34
+ batch_size = 1
35
+ seq_len = 8192
36
+ max_norm = 1.0 # grad norm clipping
37
+ steps = 3000
38
+ compile = false
39
+ dataset = "c4"
40
+
41
+ [parallelism]
42
+ data_parallel_replicate_degree = 1
43
+ data_parallel_shard_degree = -1
44
+ tensor_parallel_degree = 8
45
+ enable_async_tensor_parallel = false
46
+ pipeline_parallel_degree = 4
47
+ # pipeline_parallel_schedule = "interleaved1f1b"
48
+ # pipeline_parallel_microbatches = 2
49
+ context_parallel_degree = 1
50
+
51
+ [checkpoint]
52
+ enable_checkpoint = false
53
+ folder = "checkpoint"
54
+ interval = 500
55
+ model_weights_only = false
56
+ export_dtype = "float32"
57
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
58
+
59
+ [activation_checkpoint]
60
+ mode = 'full' # ['none', 'selective', 'full']
61
+
62
+ [float8]
63
+ enable_fsdp_float8_all_gather = false
64
+ precompute_float8_dynamic_scale_for_fsdp = false
65
+ filter_fqns = "output,router.gate"
torchtitan/models/__pycache__/attention.cpython-312.pyc ADDED
Binary file (6.33 kB). View file
 
torchtitan/models/llama3/train_configs/llama3_405b.toml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # torchtitan Config.toml
2
+ # NOTE: this toml config is a preset for 128 H100 GPUs.
3
+
4
+ [job]
5
+ dump_folder = "./outputs"
6
+ description = "Llama 3 405B training"
7
+
8
+ [profiling]
9
+ enable_profiling = true
10
+ save_traces_folder = "profile_trace"
11
+ profile_freq = 100
12
+
13
+ [metrics]
14
+ log_freq = 10
15
+ enable_tensorboard = true
16
+ save_tb_folder = "tb"
17
+
18
+ [model]
19
+ name = "llama3"
20
+ flavor = "405B"
21
+ norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm
22
+ tokenizer_path = "./assets/tokenizer/original/tokenizer.model"
23
+ converters = "float8"
24
+
25
+ [optimizer]
26
+ name = "AdamW"
27
+ lr = 8e-5
28
+ eps = 1e-8
29
+
30
+ [lr_scheduler]
31
+ warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps
32
+
33
+ [training]
34
+ batch_size = 2
35
+ seq_len = 8192
36
+ max_norm = 1.0 # grad norm clipping
37
+ steps = 3000
38
+ compile = true
39
+ dataset = "c4"
40
+
41
+ [parallelism]
42
+ data_parallel_replicate_degree = 1
43
+ data_parallel_shard_degree = -1
44
+ tensor_parallel_degree = 8 # 8-way TP
45
+ enable_async_tensor_parallel = true
46
+ pipeline_parallel_degree = 1
47
+ context_parallel_degree = 1
48
+
49
+ [checkpoint]
50
+ enable_checkpoint = false
51
+ folder = "checkpoint"
52
+ interval = 500
53
+ model_weights_only = false
54
+ export_dtype = "float32"
55
+ async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
56
+
57
+ [activation_checkpoint]
58
+ mode = 'full' # ['none', 'selective', 'full']
59
+
60
+ [float8]
61
+ enable_fsdp_float8_all_gather = true
62
+ precompute_float8_dynamic_scale_for_fsdp = true
63
+ filter_fqns = "output"