danieldk HF Staff commited on
Commit
f25859d
·
verified ·
1 Parent(s): d2d3257

Build uploaded using `kernels`.

Browse files
build/torch-cuda/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .triton_a100 import kernel_a100
2
+ from .triton_h100 import kernel_h100
3
+ from .triton_b200 import kernel_b200
4
+ from .trimul_mi300 import kernel_mi300
5
+ from .trimul_global import kernel_global
6
+
7
+ __all__ = ["kernel_a100", "kernel_h100", "kernel_b200", "kernel_mi300", "kernel_global"]
build/torch-cuda/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ ops = torch.ops._trimul_gpumode_176b4e4
3
+
4
def add_op_namespace_prefix(op_name: str):
    """Return *op_name* qualified with this build's ``torch.ops`` namespace.

    The result is in the ``namespace::op`` form that the dispatcher expects.
    """
    qualified = f"_trimul_gpumode_176b4e4::{op_name}"
    return qualified
build/torch-cuda/metadata.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"python-depends":[]}
build/torch-cuda/task.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Type definitions for TriMul task.
3
+
4
+ Input: Tuple of (input_tensor, mask, weights, config)
5
+ - input_tensor: Input tensor of shape [batch_size, seq_len, seq_len, dim]
6
+ - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
7
+ - weights: Dictionary containing model weights
8
+ - config: Dictionary containing model configuration parameters
9
+
10
+ Output: Output tensor of shape [batch_size, seq_len, seq_len, dim]
11
+ """
12
+
13
+ import torch
14
+ from typing import Tuple, Dict, Any
15
+
16
+ # Input type: (input_tensor, mask, weights, config)
17
+ input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]
18
+
19
+ # Output type: output tensor
20
+ output_t = torch.Tensor
build/torch-cuda/trimul_global.py ADDED
@@ -0,0 +1,971 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from utils import make_match_reference, DisableCuDNNTF32
2
+ from .task import input_t, output_t
3
+
4
+ import torch
5
+ from torch import nn, einsum
6
+ import math
7
+ import os
8
+ import requests
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
14
+ # in PyTorch 1.12 and later.
15
+ torch.backends.cuda.matmul.allow_tf32 = True
16
+
17
+ # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
18
+ torch.backends.cudnn.allow_tf32 = True
19
+
20
+ # Set allocator for TMA descriptors (required for on-device TMA)
21
+ def alloc_fn(size: int, alignment: int, stream=None):
22
+ return torch.empty(size, device="cuda", dtype=torch.int8)
23
+
24
+ triton.set_allocator(alloc_fn)
25
+
26
+ # os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
27
+ # os.environ['MLIR_ENABLE_DIAGNOSTICS'] = 'warnings,remarks'
28
+
29
+ # Reference code in PyTorch
30
class TriMul(nn.Module):
    """Triangle multiplicative update — reference PyTorch implementation.

    Based on
    https://github.com/lucidrains/triangle-multiplicative-module/blob/main/triangle_multiplicative_module/triangle_multiplicative_module.py
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        # Pre-normalisation over the channel dimension.
        self.norm = nn.LayerNorm(dim)

        # Linear projections into the hidden dimension (no biases anywhere).
        self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.right_proj = nn.Linear(dim, hidden_dim, bias=False)

        # Sigmoid-gated branches.
        self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.out_gate = nn.Linear(dim, hidden_dim, bias=False)

        # Output normalisation and projection back to `dim`.
        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x: [bs, seq_len, seq_len, dim]
        mask: [bs, seq_len, seq_len]

        Returns:
            output: [bs, seq_len, seq_len, dim]
        """
        normed = self.norm(x)

        # Left/right projections, zeroed where the pair mask is 0.
        pair_mask = mask.unsqueeze(-1)
        lhs = self.left_proj(normed) * pair_mask
        rhs = self.right_proj(normed) * pair_mask

        # Sigmoid gates, all computed from the normalised input.
        gate_left = self.left_gate(normed).sigmoid()
        gate_right = self.right_gate(normed).sigmoid()
        gate_out = self.out_gate(normed).sigmoid()

        lhs = lhs * gate_left
        rhs = rhs * gate_right

        # Pairwise contraction over the shared k axis:
        #   fused[b, i, j, d] = sum_k lhs[b, i, k, d] * rhs[b, j, k, d]
        fused = einsum('... i k d, ... j k d -> ... i j d', lhs, rhs)

        fused = self.to_out_norm(fused)
        fused = fused * gate_out
        return self.to_out(fused)
92
+
93
@triton.jit
def triton_sigmoid(x):
    """Logistic function computed as 1 / (1 + e^-x)."""
    denom = 1.0 + tl.exp(-x)
    return 1.0 / denom
99
+
100
def two_mm_kernel_configs_wrapper():
    """Select a per-GPU configuration provider for ``two_mm_kernel``.

    Returns one of two kinds of callable, depending on the current device:

    * capability 9 (H100): a selector ``(B, seq_len, dim) -> (BLOCK_M,
      BLOCK_N, BLOCK_K, num_stages, num_warps)`` of a pre-tuned
      configuration.  ``two_mm`` calls the returned callable with
      ``(B, seq_len, K)`` and unpacks a 5-tuple on this path, and
      ``two_mm_kernel_wrapper`` skips autotuning on capability 9, so the
      selector is the only provider used there.
    * any other device: a zero-argument function returning a list of
      ``triton.Config`` candidates for ``triton.autotune``.
    """
    if torch.cuda.get_device_capability() == (12, 0):
        # Small autotune search space for capability-12.0 devices.
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [16, 32]:
                for BLOCK_N in [16, 32, 64]:
                    for BLOCK_K in [16, 32, 64]:
                        for num_stages in [2, 3]:
                            configs.append(triton.Config({
                                'BLOCK_M': BLOCK_M,
                                'BLOCK_N': BLOCK_N,
                                'BLOCK_K': BLOCK_K,
                                'GROUP_SIZE_M': 8
                            }, num_stages=num_stages, num_warps=8))
            return configs

    elif torch.cuda.get_device_capability()[0] == 9:
        # H100: pre-tuned (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
        # keyed by (batch, seq_len, dim).
        def get_optimal_two_mm_config_h100(B, seq_len, dim):
            configs = {
                (1, 128, 128): (128, 64, 128, 2, 8),
                (1, 128, 256): (128, 64, 128, 2, 8),
                (1, 128, 384): (128, 64, 64, 3, 8),
                (1, 128, 512): (128, 64, 64, 3, 8),
                (1, 128, 768): (128, 64, 64, 3, 8),
                (1, 128, 1024): (128, 64, 64, 3, 8),
                (1, 256, 128): (128, 64, 128, 2, 8),
                (1, 256, 256): (128, 64, 128, 2, 8),
                (1, 256, 384): (128, 64, 64, 3, 8),
                (1, 256, 512): (128, 64, 64, 3, 8),
                (1, 256, 768): (128, 64, 64, 3, 8),
                (1, 256, 1024): (128, 64, 64, 3, 8),
                (1, 512, 128): (128, 64, 128, 2, 8),
                (1, 512, 256): (128, 64, 128, 2, 8),
                (1, 512, 384): (128, 64, 128, 2, 8),
                (1, 512, 512): (128, 64, 128, 2, 8),
                (1, 512, 768): (128, 64, 64, 3, 8),
                (1, 512, 1024): (128, 64, 64, 3, 8),
                (1, 1024, 128): (128, 64, 128, 2, 8),
                (1, 1024, 256): (128, 64, 64, 2, 8),
                (1, 1024, 384): (128, 64, 128, 2, 8),
                (1, 1024, 512): (128, 64, 128, 2, 8),
                (1, 1024, 768): (128, 64, 128, 2, 8),
                (1, 1024, 1024): (128, 64, 128, 2, 8),
                (2, 128, 128): (128, 64, 128, 2, 8),
                (2, 128, 256): (128, 64, 128, 2, 8),
                (2, 128, 384): (128, 64, 64, 3, 8),
                (2, 128, 512): (128, 64, 64, 3, 8),
                (2, 128, 768): (128, 64, 64, 3, 8),
                (2, 128, 1024): (128, 64, 64, 3, 8),
                (2, 256, 128): (128, 64, 128, 2, 8),
                (2, 256, 256): (128, 64, 128, 2, 8),
                (2, 256, 384): (128, 64, 128, 2, 8),
                (2, 256, 512): (128, 64, 128, 2, 8),
                (2, 256, 768): (128, 64, 64, 3, 8),
                (2, 256, 1024): (128, 64, 64, 3, 8),
                (2, 512, 128): (128, 64, 128, 2, 8),
                (2, 512, 256): (128, 64, 128, 2, 8),
                (2, 512, 384): (128, 64, 128, 2, 8),
                (2, 512, 512): (128, 64, 128, 2, 8),
                (2, 512, 768): (128, 64, 128, 2, 8),
                (2, 512, 1024): (128, 64, 128, 2, 8),
                (2, 1024, 128): (128, 64, 128, 2, 8),
                (2, 1024, 256): (128, 64, 128, 2, 8),
                (2, 1024, 384): (128, 64, 128, 2, 8),
                (2, 1024, 512): (128, 64, 128, 2, 8),
                (2, 1024, 768): (128, 64, 128, 2, 8),
                (2, 1024, 1024): (128, 64, 128, 2, 8),
            }
            return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8))  # default fallback

        # BUGFIX: this branch previously fell through to the final
        # `return two_mm_kernel_configs`, handing the H100 caller a
        # zero-argument autotune-config builder even though `two_mm` invokes
        # the returned callable as `fn(B, seq_len, K)` and unpacks a 5-tuple
        # (a TypeError at launch time).  Return the per-shape selector
        # instead; autotuning is skipped for capability 9 in
        # `two_mm_kernel_wrapper`, so the list builder is never needed here.
        return get_optimal_two_mm_config_h100

    elif torch.cuda.get_device_capability()[0] == 10 and False:
        # B200: intentionally disabled via `and False`; kept for reference.
        def get_optimal_two_mm_config(B, seq_len, dim):
            configs = {
                (1, 128, 128): (64, 128, 64, 2, 8),
                (1, 128, 256): (128, 64, 128, 2, 8),
                (1, 128, 384): (128, 64, 128, 2, 8),
                (1, 128, 512): (128, 64, 128, 2, 8),
                (1, 128, 768): (128, 64, 64, 3, 8),
                (1, 128, 1024): (128, 64, 64, 3, 8),
                (1, 256, 128): (128, 64, 128, 2, 8),
                (1, 256, 256): (128, 64, 128, 2, 8),
                (1, 256, 384): (128, 64, 128, 2, 8),
                (1, 256, 512): (128, 64, 64, 3, 8),
                (1, 256, 768): (128, 64, 64, 3, 8),
                (1, 256, 1024): (128, 64, 64, 3, 8),
                (1, 512, 128): (128, 64, 128, 2, 8),
                (1, 512, 256): (128, 64, 128, 2, 8),
                (1, 512, 384): (128, 64, 128, 2, 8),
                (1, 512, 512): (128, 64, 128, 2, 8),
                (1, 512, 768): (128, 64, 64, 3, 8),
                (1, 512, 1024): (128, 64, 64, 3, 8),
                (1, 1024, 128): (128, 64, 128, 2, 8),
                (1, 1024, 256): (128, 64, 128, 2, 8),
                (1, 1024, 384): (128, 64, 128, 2, 8),
                (1, 1024, 512): (128, 64, 128, 2, 8),
                (1, 1024, 768): (128, 64, 64, 3, 8),
                (1, 1024, 1024): (128, 64, 64, 3, 8),
                (2, 128, 128): (128, 64, 128, 2, 8),
                (2, 128, 256): (128, 64, 128, 2, 8),
                (2, 128, 384): (128, 64, 128, 2, 8),
                (2, 128, 512): (128, 64, 64, 3, 8),
                (2, 128, 768): (128, 64, 64, 3, 8),
                (2, 128, 1024): (128, 64, 64, 3, 8),
                (2, 256, 128): (128, 64, 128, 2, 8),
                (2, 256, 256): (128, 64, 128, 2, 8),
                (2, 256, 384): (128, 64, 128, 2, 8),
                (2, 256, 512): (128, 64, 64, 3, 8),
                (2, 256, 768): (128, 64, 64, 3, 8),
                (2, 256, 1024): (128, 64, 64, 3, 8),
                (2, 512, 128): (128, 64, 128, 2, 8),
                (2, 512, 256): (128, 64, 128, 2, 8),
                (2, 512, 384): (128, 64, 128, 2, 8),
                (2, 512, 512): (128, 64, 128, 2, 8),
                (2, 512, 768): (128, 64, 64, 3, 8),
                (2, 512, 1024): (128, 64, 64, 3, 8),
                (2, 1024, 128): (128, 64, 128, 2, 8),
                (2, 1024, 256): (128, 64, 128, 2, 8),
                (2, 1024, 384): (128, 64, 128, 2, 8),
                (2, 1024, 512): (128, 64, 128, 2, 8),
                (2, 1024, 768): (128, 64, 64, 3, 8),
                (2, 1024, 1024): (128, 64, 64, 3, 8),
            }
            return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8))  # default fallback

        def two_mm_kernel_configs():
            # This function is kept for compatibility but will be overridden
            return [
                triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
            ]
    elif torch.cuda.get_device_capability()[0] == 8:
        # A100
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [64]:
                for BLOCK_N in [64, 128]:
                    for BLOCK_K in [16]:
                        for num_stages in [3, 4]:
                            for num_warps in [4, 8]:
                                configs.append(triton.Config({
                                    'BLOCK_M': BLOCK_M,
                                    'BLOCK_N': BLOCK_N,
                                    'BLOCK_K': BLOCK_K,
                                    'GROUP_SIZE_M': 8
                                }, num_stages=num_stages, num_warps=num_warps))
            return configs
    else:
        # Generic fallback search space.
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [64, 128]:
                for BLOCK_N in [64, 128]:
                    for BLOCK_K in [64, 128]:
                        for num_stages in [2, 3]:
                            configs.append(triton.Config({
                                'BLOCK_M': BLOCK_M,
                                'BLOCK_N': BLOCK_N,
                                'BLOCK_K': BLOCK_K,
                                'GROUP_SIZE_M': 8
                            }, num_stages=num_stages, num_warps=8))
            return configs

    return two_mm_kernel_configs
271
+
272
def two_mm_kernel_wrapper():
    """Build (and, for most devices, autotune) the fused ``two_mm`` Triton kernel.

    Two variants of the same persistent kernel are defined:
      * capability 8 (A100): standard ``tl.load``/``tl.store`` addressing;
      * everything else: on-device TMA tensor descriptors for the loads.
    Both compute, per output tile, five A@B.T products at once (left/right
    projections, left/right gates, out gate), apply the pair mask and the
    sigmoid gating in the epilogue, and write C1, C2 and D.
    """
    if torch.cuda.get_device_capability()[0] == 8:
        @triton.jit
        def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
            # Persistent kernel using standard tl.load operations
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_M)
            num_pid_n = tl.cdiv(N, BLOCK_N)
            k_tiles = tl.cdiv(K, BLOCK_K)
            num_tiles = num_pid_m * num_pid_n

            # tile_id_c is used in the epilogue to break the dependency between
            # the prologue and the epilogue
            tile_id_c = start_pid - NUM_SMS
            num_pid_in_group = GROUP_SIZE_M * num_pid_n

            # Persistent loop over tiles
            for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
                # Calculate PID for this tile using improved swizzling
                group_id = tile_id // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id % group_size_m)
                pid_n = (tile_id % num_pid_in_group) // group_size_m

                # Calculate block offsets
                offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
                offs_k = tl.arange(0, BLOCK_K)

                # Initialize accumulators for all outputs
                accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                # Main computation loop over K dimension
                for ki in range(k_tiles):
                    k_start = ki * BLOCK_K
                    k_offsets = k_start + offs_k

                    # Create pointers for A matrix (2D flattened view)
                    a_ptrs = a_ptr + offs_am[:, None] * stride_a2 + k_offsets[None, :] * stride_a3
                    a_mask = (offs_am[:, None] < M) & (k_offsets[None, :] < K)

                    # Create pointers for B matrices [N, K] layout
                    b1_ptrs = b1_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b2_ptrs = b2_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b3_ptrs = b3_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b4_ptrs = b4_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b5_ptrs = b5_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b_mask = (offs_bn[:, None] < N) & (k_offsets[None, :] < K)

                    # Load blocks from A and all weight matrices using standard tl.load
                    a = tl.load(a_ptrs, mask=a_mask, other=0.0)
                    b1 = tl.load(b1_ptrs, mask=b_mask, other=0.0)
                    b2 = tl.load(b2_ptrs, mask=b_mask, other=0.0)
                    b3 = tl.load(b3_ptrs, mask=b_mask, other=0.0)
                    b4 = tl.load(b4_ptrs, mask=b_mask, other=0.0)
                    b5 = tl.load(b5_ptrs, mask=b_mask, other=0.0)

                    # Perform matrix multiplications using TF32
                    accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True)  # A @ B1.T
                    accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True)  # A @ B2.T
                    accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True)  # A @ B3.T
                    accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True)  # A @ B4.T
                    accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True)  # A @ B5.T

                # Store results using separate tile_id_c for epilogue
                tile_id_c += NUM_SMS
                group_id = tile_id_c // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id_c % group_size_m)
                pid_n = (tile_id_c % num_pid_in_group) // group_size_m

                # Calculate output offsets and pointers
                offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

                # Create masks for bounds checking
                d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

                # Calculate pointer addresses using 4D strides
                # NOTE(review): stride_cm/stride_cn are assigned but unused
                # below — c_offsets is built from the full 4D C strides.
                stride_cm = stride_c2  # Stride to next element in flattened M dimension
                stride_cn = stride_c3  # N is the innermost dimension

                # For D tensor: use separate D strides
                stride_dm = stride_d2  # Stride to next element in flattened M dimension
                stride_dn = stride_d3  # N is the innermost dimension

                # Decompose the flattened row index into (batch, i, j).
                off_c_batch = offs_cm // (seq_len * seq_len)
                off_c_sl1 = (offs_cm // seq_len) % seq_len
                off_c_sl2 = offs_cm % seq_len
                off_c_dim = offs_cn

                c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
                c_mask = d_mask

                c1_ptrs = c1_ptr + c_offsets
                c2_ptrs = c2_ptr + c_offsets
                d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]

                # One mask value per flattened (b, i, j) output row.
                # NOTE(review): load has no `other=`; out-of-range lanes are
                # undefined, but those rows are excluded by c_mask/d_mask at
                # store time — confirm no NaN propagation path exists.
                mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))

                # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
                mask_2d = mask[:, None]  # Convert to [BLOCK_M, 1] then broadcast
                # Apply masking only to left_proj and right_proj results (C1, C2)
                accumulator1 = tl.where(mask_2d, accumulator1, 0)
                accumulator2 = tl.where(mask_2d, accumulator2, 0)

                # Apply sigmoid to gate values
                left_gate_sigmoid = triton_sigmoid(accumulator3)
                right_gate_sigmoid = triton_sigmoid(accumulator4)
                accumulator_d = triton_sigmoid(accumulator_d)

                # Apply elementwise multiplication with gated values
                # C1 = left * left_gate, C2 = right * right_gate
                accumulator1 = accumulator1 * left_gate_sigmoid  # left * left_gate
                accumulator2 = accumulator2 * right_gate_sigmoid  # right * right_gate

                # Convert to appropriate output dtype and store with normal tl.store
                c1 = accumulator1.to(c1_ptr.dtype.element_ty)
                c2 = accumulator2.to(c2_ptr.dtype.element_ty)
                d = accumulator_d.to(d_ptr.dtype.element_ty)

                tl.store(c1_ptrs, c1, mask=c_mask)
                tl.store(c2_ptrs, c2, mask=c_mask)
                tl.store(d_ptrs, d, mask=d_mask)
    else:
        @triton.jit
        def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
            # Persistent kernel using on-device TMA descriptors
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_M)
            num_pid_n = tl.cdiv(N, BLOCK_N)
            k_tiles = tl.cdiv(K, BLOCK_K)
            num_tiles = num_pid_m * num_pid_n

            # Create on-device TMA descriptors
            a_desc = tl._experimental_make_tensor_descriptor(
                a_ptr,
                shape=[M, K],
                strides=[stride_a2, stride_a3],
                block_shape=[BLOCK_M, BLOCK_K],
            )
            b1_desc = tl._experimental_make_tensor_descriptor(
                b1_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b2_desc = tl._experimental_make_tensor_descriptor(
                b2_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b3_desc = tl._experimental_make_tensor_descriptor(
                b3_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b4_desc = tl._experimental_make_tensor_descriptor(
                b4_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b5_desc = tl._experimental_make_tensor_descriptor(
                b5_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )

            # tile_id_c is used in the epilogue to break the dependency between
            # the prologue and the epilogue
            tile_id_c = start_pid - NUM_SMS
            num_pid_in_group = GROUP_SIZE_M * num_pid_n

            # Persistent loop over tiles
            for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
                # Calculate PID for this tile using improved swizzling
                group_id = tile_id // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id % group_size_m)
                pid_n = (tile_id % num_pid_in_group) // group_size_m

                # Calculate block offsets
                offs_am = pid_m * BLOCK_M
                offs_bn = pid_n * BLOCK_N

                # Initialize accumulators for all outputs
                accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                # Main computation loop over K dimension
                for ki in range(k_tiles):
                    offs_k = ki * BLOCK_K
                    # Load blocks from A and all weight matrices using on-device TMA
                    a = a_desc.load([offs_am, offs_k])
                    b1 = b1_desc.load([offs_bn, offs_k])
                    b2 = b2_desc.load([offs_bn, offs_k])
                    b3 = b3_desc.load([offs_bn, offs_k])
                    b4 = b4_desc.load([offs_bn, offs_k])
                    b5 = b5_desc.load([offs_bn, offs_k])

                    # Perform matrix multiplications using TF32
                    accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True)  # A @ B1.T
                    accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True)  # A @ B2.T
                    accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True)  # A @ B3.T
                    accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True)  # A @ B4.T
                    accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True)  # A @ B5.T

                # Store results using separate tile_id_c for epilogue
                tile_id_c += NUM_SMS
                group_id = tile_id_c // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id_c % group_size_m)
                pid_n = (tile_id_c % num_pid_in_group) // group_size_m

                # Calculate output offsets and pointers
                offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

                # Create masks for bounds checking
                d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

                # Calculate pointer addresses using 4D strides
                # For C tensors: compute effective 2D strides from 4D strides
                # Output tensor is [B, I, J, N], flattened to [M, N] where M = B*I*J
                # NOTE(review): stride_cm/stride_cn are assigned but unused
                # below — c_offsets is built from the full 4D C strides.
                stride_cm = stride_c2  # Stride to next element in flattened M dimension
                stride_cn = stride_c3  # N is the innermost dimension

                # For D tensor: use separate D strides
                stride_dm = stride_d2  # Stride to next element in flattened M dimension
                stride_dn = stride_d3  # N is the innermost dimension

                # Decompose the flattened row index into (batch, i, j).
                off_c_batch = offs_cm // (seq_len * seq_len)
                off_c_sl1 = (offs_cm // seq_len) % seq_len
                off_c_sl2 = offs_cm % seq_len
                off_c_dim = offs_cn

                # TODO update the mask_c so we don't IMA
                c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
                # c_offsets = offs_cm[:, None] * stride_c2 + offs_cn[None, :] * stride_c3
                c_mask = d_mask

                c1_ptrs = c1_ptr + c_offsets
                c2_ptrs = c2_ptr + c_offsets
                # c1_ptrs = c1_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
                # c2_ptrs = c2_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
                d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]

                # One mask value per flattened (b, i, j) output row.
                mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))

                # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
                mask_2d = mask[:, None]  # Convert to [BLOCK_M, 1] then broadcast
                # Apply masking only to left_proj and right_proj results (C1, C2)
                accumulator1 = tl.where(mask_2d, accumulator1, 0)
                accumulator2 = tl.where(mask_2d, accumulator2, 0)

                # Apply sigmoid to gate values
                left_gate_sigmoid = triton_sigmoid(accumulator3)
                right_gate_sigmoid = triton_sigmoid(accumulator4)
                accumulator_d = triton_sigmoid(accumulator_d)

                # Apply elementwise multiplication with gated values
                # C1 = left * left_gate, C2 = right * right_gate
                accumulator1 = accumulator1 * left_gate_sigmoid  # left * left_gate
                accumulator2 = accumulator2 * right_gate_sigmoid  # right * right_gate

                # Convert to appropriate output dtype and store with normal tl.store
                c1 = accumulator1.to(c1_ptr.dtype.element_ty)
                c2 = accumulator2.to(c2_ptr.dtype.element_ty)
                d = accumulator_d.to(d_ptr.dtype.element_ty)

                tl.store(c1_ptrs, c1, mask=c_mask)
                tl.store(c2_ptrs, c2, mask=c_mask)
                tl.store(d_ptrs, d, mask=d_mask)

    # NOTE(review): get_device_capability()[0] is an int major version, so it
    # can never equal 10.2 — as written, capability-10 devices DO get the
    # autotuner applied here even though two_mm's capability-10 launch path
    # passes explicit BLOCK_* meta-parameters.  Confirm whether `10` was
    # intended in this list.
    if torch.cuda.get_device_capability()[0] not in [9, 10.2]:
        two_mm_kernel = triton.autotune(
            (two_mm_kernel_configs_wrapper())(), key=["M", "N", "K"]
        )(two_mm_kernel)

    return two_mm_kernel
568
+
569
+
570
def two_mm(A, left_proj, right_proj, left_gate, right_gate, out_gate, mask):
    """
    Persistent matrix multiplication for all weight matrices using on-device TMA descriptors.

    Args:
        A: [..., K] tensor (arbitrary leading dimensions)
        left_proj: [N, K] matrix (will be transposed)
        right_proj: [N, K] matrix (will be transposed)
        left_gate: [N, K] left gate weight matrix
        right_gate: [N, K] right gate weight matrix
        out_gate: [N, K] output gate weight matrix
        mask: mask tensor

    Returns:
        (C1, C2, D): Tuple of result tensors [..., N] with same leading dims as A
        C1 = (A @ left_proj.T) * sigmoid(A @ left_gate.T) (masked)
        C2 = (A @ right_proj.T) * sigmoid(A @ right_gate.T) (masked)
        D = sigmoid(A @ out_gate.T) (unmasked)
    """
    # Check constraints
    assert A.shape[-1] == left_proj.shape[1] == right_proj.shape[1], "Incompatible K dimensions"
    assert A.dtype == left_proj.dtype == right_proj.dtype, "Incompatible dtypes"

    # Assert that all weight matrices have the same strides (same [N, K] shape)
    assert left_proj.stride() == right_proj.stride() == left_gate.stride() == right_gate.stride() == out_gate.stride(), \
        "All weight matrices must have identical strides"

    # Get dimensions
    original_shape = A.shape[:-1]  # All dimensions except the last
    K = A.shape[-1]
    N = left_proj.shape[0]
    B, seq_len, _, _ = A.shape
    # NOTE(review): `dtype` is captured but only referenced by the
    # commented-out allocations below; outputs are hard-coded to float16.
    dtype = A.dtype

    # Flatten A to 2D for kernel processing
    A_2d = A.view(-1, K)  # [M, K] where M is product of all leading dims
    M = A_2d.shape[0]

    # Get number of streaming multiprocessors
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    # Launch persistent kernel with limited number of blocks
    grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)

    # Get original 4D strides for A and output tensors
    A_strides = A.stride()  # (stride_0, stride_1, stride_2, stride_3)

    # Create output tensors with proper 4D shape to get correct strides
    output_shape = original_shape + (N,)
    # C1 = torch.empty(output_shape, device=A.device, dtype=dtype)
    # C2 = torch.empty(output_shape, device=A.device, dtype=dtype)
    # C1/C2 are allocated [B, N, S, S] and permuted back to [B, S, S, N], so
    # their memory layout is channels-first while their logical shape matches
    # the other tensors — presumably for faster downstream reads; confirm
    # against the consumers of C1/C2.
    C1 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
    C2 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
    D = torch.empty(output_shape, device=A.device, dtype=torch.float16)

    C_strides = C1.stride()  # (stride_0, stride_1, stride_2, stride_3)
    D_strides = D.stride()  # (stride_0, stride_1, stride_2, stride_3)

    # Use optimal configuration for B200/H100 or fallback to autotuning for other GPUs
    if torch.cuda.get_device_capability()[0] == 10:
        # Get optimal configuration for B200
        # NOTE(review): the capability-10 selector in
        # two_mm_kernel_configs_wrapper is disabled (`and False`), so the
        # wrapper returns a zero-argument config builder and this 3-argument
        # call would raise TypeError on such devices; the dedicated B200
        # kernel module is presumably used instead of this global path.
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
        grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))

        two_mm_kernel_wrapper()[(grid_size,)](
            A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
            C1, C2, D, mask,
            M, N, K,
            *A_strides,  # 4D strides for A
            left_proj.stride(1), left_proj.stride(0),  # B matrices [N, K] shape strides
            *C_strides,  # 4D strides for C
            seq_len,
            *D_strides,  # 4D strides for D
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
            num_stages=num_stages, num_warps=num_warps
        )
    elif torch.cuda.get_device_capability()[0] == 9:
        # Get optimal configuration for H100
        # NOTE(review): this call passes (B, seq_len, K) and unpacks a
        # 5-tuple, i.e. it expects the wrapper to return the per-shape
        # selector get_optimal_two_mm_config_h100 on capability 9 — confirm
        # the wrapper actually returns it (and not the zero-argument
        # autotune-config builder) on this branch.
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
        grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))

        two_mm_kernel_wrapper()[(grid_size,)](
            A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
            C1, C2, D, mask,
            M, N, K,
            *A_strides,  # 4D strides for A
            left_proj.stride(1), left_proj.stride(0),  # B matrices [N, K] shape strides
            *C_strides,  # 4D strides for C
            seq_len,
            *D_strides,  # 4D strides for D
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
            num_stages=num_stages, num_warps=num_warps
        )
    else:
        # Use autotuning for other GPUs
        two_mm_kernel_wrapper()[grid](
            A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
            C1, C2, D, mask,
            M, N, K,
            *A_strides,  # 4D strides for A
            left_proj.stride(1), left_proj.stride(0),  # B matrices [N, K] shape strides
            *C_strides,  # 4D strides for C
            seq_len,
            *D_strides,  # 4D strides for D
            NUM_SMS=NUM_SMS
        )

    return C1, C2, D
678
+
679
+
680
def second_layernorm_mul(inp, hidden_dim, weight, bias, mul_operand):
    """Apply LayerNorm over the last `hidden_dim` channels of `inp`, then gate.

    The affine parameters are cast to `inp`'s dtype first, so mixed-precision
    weights do not force an upcast of the activation tensor.

    Args:
        inp: Activation tensor whose last dimension is `hidden_dim`.
        hidden_dim: Size of the normalized (last) dimension.
        weight, bias: LayerNorm affine parameters (any floating dtype).
        mul_operand: Gate tensor, broadcast-multiplied with the normed output.

    Returns:
        Normalized and gated tensor with the same shape as `inp`.
    """
    w = weight.to(inp.dtype)
    b = bias.to(inp.dtype)
    normed = torch.nn.functional.layer_norm(
        inp, (hidden_dim,), weight=w, bias=b, eps=1e-5
    )
    return normed * mul_operand
684
+
685
+ '''
686
+ @triton.autotune(
687
+ [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=4, num_stages=3)],
688
+ key=["R", "C"]
689
+ )
690
+ '''
691
# LayerNorm over the last (channel) dimension for a flattened [R, C] view.
# Each program normalizes ROW_BLOCK_SIZE rows; the whole channel dimension
# (C, padded to BLOCK_SIZE) is held in registers, so BLOCK_SIZE must cover C.
# NOTE(review): addressing assumes X and Y are row-major contiguous with a
# row stride of exactly C — no strides are passed in; confirm at call sites.
@triton.jit
def layernorm_kernel_first(
    X,      # *in*  [R, C] activations
    Y,      # *out* [R, C] normalized activations
    Weight, # [C] affine scale
    Bias,   # [C] affine shift
    R,
    C,  # aka "dim"
    eps,
    ROW_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # next power of two >= C
):
    # Rows handled by this program and the full (padded) column range.
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, BLOCK_SIZE)

    mask_row = row < R
    mask_col = cols < C

    # Simple indexing for contiguous data
    x = tl.load(
        X + row[:, None] * C + cols[None, :],
        mask=mask_row[:, None] & mask_col[None, :],
        other=0.0
    ).to(tl.float32)

    weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
    bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)

    # Statistics in fp32; masked lanes were loaded as 0 so the sums are exact.
    mean = tl.sum(x, axis=1) / C
    # Re-mask the centered values so padded columns do not pollute the variance.
    diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
    var = tl.sum(diff * diff, axis=1) / C
    rstd = 1 / tl.sqrt(var + eps)

    y_hat = (x - mean[:, None]) * rstd[:, None]
    y = y_hat * weight[None, :] + bias[None, :]

    # Store (dtype conversion to Y's element type happens implicitly).
    tl.store(
        Y + row[:, None] * C + cols[None, :],
        y,
        mask=mask_row[:, None] & mask_col[None, :]
    )
732
+
733
+
734
def get_optimal_config_ln(dim):
    """Pick a (ROW_BLOCK_SIZE, num_warps) launch config for the LN kernel.

    Tuned thresholds exist only for Hopper (compute capability 9.x); every
    other architecture — and any `dim` above 1024 — falls back to (16, 4).

    Args:
        dim: Channel count being normalized.

    Returns:
        Tuple `(ROW_BLOCK_SIZE, num_warps)`.
    """
    if torch.cuda.get_device_capability()[0] == 9:
        # (upper dim bound, num_warps) pairs, checked in ascending order.
        for limit, warps in ((256, 1), (512, 2), (1024, 4)):
            if dim <= limit:
                return (16, warps)
    return (16, 4)
747
+
748
+
749
def triton_layernorm_first(x, weight, bias, eps=1e-5, num_warps=None, ROW_BLOCK_SIZE=None):
    """LayerNorm the last dim of a [B, S, S, dim] tensor, producing float16.

    The tensor is treated as `B*S*S` independent rows of length `dim` and
    handed to `layernorm_kernel_first`. `num_warps` / `ROW_BLOCK_SIZE` may be
    forced by the caller; otherwise a per-architecture heuristic picks them.
    """
    batch, rows, cols, dim = x.shape
    assert(rows == cols)

    total_rows = batch * rows * cols

    out = torch.empty_like(x, dtype=torch.float16)

    if not num_warps or not ROW_BLOCK_SIZE:
        ROW_BLOCK_SIZE, num_warps = get_optimal_config_ln(dim)

    # The kernel keeps a whole row in registers, so dim must fit one block.
    block = triton.next_power_of_2(dim)
    assert(block <= 1024)

    grid = lambda meta: (triton.cdiv(total_rows, meta["ROW_BLOCK_SIZE"]),)

    layernorm_kernel_first[grid](
        x, out, weight, bias,
        total_rows, dim, eps,
        ROW_BLOCK_SIZE=ROW_BLOCK_SIZE,
        BLOCK_SIZE=block,
        num_warps=num_warps,
        num_stages=3,
    )

    return out
777
+
778
+ '''
779
+ def triton_layernorm_first(x, weight, bias, eps=1e-5):
780
+ B, seq_len, seq_len2, dim = x.shape
781
+ assert(seq_len == seq_len2)
782
+
783
+ R = B * seq_len * seq_len
784
+ C = dim
785
+
786
+ out = torch.empty_like(x)
787
+
788
+ BLOCK_SIZE = triton.next_power_of_2(C)
789
+ assert(BLOCK_SIZE <= 1024)
790
+
791
+ def grid(meta):
792
+ return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
793
+
794
+ layernorm_kernel_first[grid](
795
+ x, out, weight, bias,
796
+ R, C, eps,
797
+ BLOCK_SIZE=BLOCK_SIZE
798
+ )
799
+
800
+ return out
801
+ '''
802
+
803
+
804
# Fused LayerNorm + out-gate multiply over a channel-major input.
# X is addressed through (stride_batch, stride_dim): within one batch, the
# flattened (seq, seq) position advances by 1 element and a channel step is
# stride_dim — i.e. the host passes a [B, S, S, C] tensor whose channel axis
# is outermost in memory (the launcher asserts x.stride(3) == S*S).
# OutGate and Y are plain row-major [R, C].
@triton.autotune(
    [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=1, num_stages=3)],
    key=[]
)
@triton.jit
def layernorm_kernel_eltwise(
    X,         # *in*  strided activations (see layout note above)
    Y,         # *out* [R, C] contiguous, normed * gate
    Weight,    # [C] LayerNorm scale
    Bias,      # [C] LayerNorm shift
    OutGate,   # [R, C] contiguous multiplicative gate
    seq_len,
    stride_batch,
    stride_dim,
    R,
    C,  # aka "dim"
    eps,
    ROW_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,  # next power of two >= C
):
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, BLOCK_SIZE)

    # Calculate base pointer for this batch of rows.
    # A row block must never straddle a batch boundary, otherwise the single
    # `batch` index below would be wrong for part of the block.
    tl.device_assert(seq_len*seq_len % ROW_BLOCK_SIZE == 0)
    # batch_offset = (row // (stride_seq1 // stride_dim)) * stride_batch
    batch = tl.program_id(0) * ROW_BLOCK_SIZE // (seq_len * seq_len)
    seqs_off = row % (seq_len * seq_len) # TODO is this going to prevent vectorization

    # Within a batch, consecutive (i, j) positions are 1 element apart.
    off_r = batch * stride_batch + seqs_off
    off_c = cols * stride_dim

    mask_row = row < R
    mask_col = cols < C

    # Gate is stored row-major [R, C]; no `other=` — masked lanes are never used
    # for stores, but NOTE(review): an explicit other=0.0 would be safer.
    out_gate = tl.load(
        OutGate + row[:, None] * C + cols[None, :],
        mask = mask_row[:, None] & mask_col[None, :],
    )

    x = tl.load(
        X + off_r[:, None] + off_c[None, :],
        mask=mask_row[:, None] & mask_col[None, :],
        other=0.0
    ).to(tl.float32)

    weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
    bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)

    # Masked lanes loaded as 0, so plain sums give exact statistics.
    mean = tl.sum(x, axis=1) / C
    diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
    var = tl.sum(diff * diff, axis=1) / C
    rstd = 1 / tl.sqrt(var + eps)

    y_hat = (x - mean[:, None]) * rstd[:, None]
    y = y_hat * weight[None, :] + bias[None, :]

    # Fused epilogue: multiply by the (already sigmoided) out gate and
    # write the result contiguously.
    tl.store(
        Y + row[:, None] * C + cols[None, :],
        y * out_gate,
        mask=mask_row[:, None] & mask_col[None, :]
    )
866
+
867
+
868
def triton_layernorm_eltwise(x, weight, bias, out_gate, eps=1e-5):
    """Fused LayerNorm + out-gate multiply for a channel-major [B, S, S, C] tensor.

    `x` must have its channel axis outermost in memory (stride(3) == S*S),
    which is the layout produced by the preceding bmm + permute. The gate is
    a contiguous tensor of the logical shape; the result is float32 with the
    gate's layout.
    """
    batch, rows, cols, dim = x.shape
    assert(rows == cols)
    total_rows = batch * rows * cols
    # The kernel hard-codes the channel-major addressing scheme.
    assert(x.stride(3) == rows*cols)
    assert(out_gate.is_contiguous())

    out = torch.empty_like(out_gate, dtype=torch.float32)

    block = triton.next_power_of_2(dim)
    # Only the hidden_dim == 128 case is supported/tuned.
    assert(block == 128)

    grid = lambda meta: (triton.cdiv(total_rows, meta["ROW_BLOCK_SIZE"]),)

    layernorm_kernel_eltwise[grid](
        x, out, weight, bias, out_gate,
        rows,
        x.stride(0), x.stride(3),
        total_rows, dim, eps,
        BLOCK_SIZE=block,
    )

    return out
893
+
894
+
895
def kernel_global(data: input_t) -> output_t:
    """TriMul forward pass assembled from fused Triton kernels.

    Pipeline: input LayerNorm -> fused projections/gates/mask (two_mm) ->
    pairwise einsum realized as a batched matmul -> fused output LayerNorm *
    out-gate -> final linear projection.

    Args:
        data: Tuple (input, mask, weights, config) where
            - input: [batch_size, seq_len, seq_len, dim]
            - mask:  [batch_size, seq_len, seq_len]
            - weights: dict of model parameter tensors
            - config: dict of model hyperparameters

    Returns:
        Output tensor of shape [batch_size, seq_len, seq_len, dim].
    """
    input_tensor, mask, weights, config = data

    # All projection/gate weights participate in fp16 matmuls.
    left_proj_weight = weights["left_proj.weight"].to(torch.float16)
    right_proj_weight = weights["right_proj.weight"].to(torch.float16)
    left_gate_weight = weights["left_gate.weight"].to(torch.float16)
    right_gate_weight = weights["right_gate.weight"].to(torch.float16)
    out_gate_weight = weights["out_gate.weight"].to(torch.float16)

    hidden_dim = config["hidden_dim"]

    batch_size, seq_len, _, dim = input_tensor.shape

    # Input LayerNorm (fp16 output) via the Triton kernel.
    normed = triton_layernorm_first(
        input_tensor, weights['norm.weight'], weights['norm.bias']
    )

    # Fused projections: sigmoid gating and masking happen inside the kernel,
    # so `left`/`right` come back already gated and masked.
    left, right, out_gate = two_mm(
        normed, left_proj_weight, right_proj_weight,
        left_gate_weight, right_gate_weight, out_gate_weight, mask,
    )

    # einsum('... i k d, ... j k d -> ... i j d', left, right) expressed as a
    # single batched matmul over (batch * hidden) slices.
    pair = torch.bmm(
        left.permute(0, 3, 1, 2).view(-1, left.shape[1], left.shape[2]),
        right.permute(0, 3, 2, 1).view(-1, right.shape[2], right.shape[1]),
    )
    pair = pair.view(batch_size, hidden_dim, seq_len, seq_len).permute(0, 2, 3, 1)

    # Output LayerNorm fused with the out-gate multiply.
    gated = triton_layernorm_eltwise(
        pair, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate
    )

    return torch.nn.functional.linear(gated, weights['to_out.weight'])
954
+
955
+ '''
956
+ # Fill in the given weights of the model
957
+ trimul.norm.weight = nn.Parameter(weights['norm.weight'])
958
+ trimul.norm.bias = nn.Parameter(weights['norm.bias'])
959
+ trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight'])
960
+ trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight'])
961
+ trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight'])
962
+ trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight'])
963
+ trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight'])
964
+ trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight'])
965
+ trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias'])
966
+ trimul.to_out.weight = nn.Parameter(weights['to_out.weight'])
967
+
968
+ output = trimul(input_tensor, mask)
969
+
970
+ return output
971
+ '''
build/torch-cuda/trimul_gpumode/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
def _import_from_path(file_path: Path) -> ModuleType:
    """Load and execute a Python source file as a uniquely-named module.

    The module name is derived from the hash of the absolute path so that
    registering it in ``sys.modules`` can never collide with (or shadow) a
    real package name used by other imports.

    Args:
        file_path: Path to the ``.py`` file to load.

    Returns:
        The fully executed module object.

    Raises:
        ImportError: If no import spec or module can be created for the file.
    """
    # `import importlib` alone does not guarantee the `util` submodule is
    # bound as an attribute of the package; import it explicitly so
    # `importlib.util.*` below cannot raise AttributeError.
    import importlib.util

    # We cannot use the module name as-is, after adding it to `sys.modules`,
    # it would also be used for other imports. So, we make a module name that
    # depends on the path for it to be unique using the hex-encoded hash of
    # the path.
    path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
    module_name = path_hash
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    if module is None:
        raise ImportError(f"Cannot load module {module_name} from spec")
    # Register before executing so the module can import itself recursively.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)  # type: ignore
    return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-cuda/trimul_mi300.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ torch.backends.cuda.matmul.allow_tf32 = True
7
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
+
9
# Mega-fused first stage: one kernel performs
#   1. LayerNorm statistics for a BLOCK_SIZE_M chunk of rows (two passes),
#   2. normed-X @ W_4way, where W_4way interleaves the left_proj / left_gate /
#      right_proj / right_gate weights 4-wide along N,
#   3. for tiles whose N-offset lies inside H, normed-X @ W_og as well,
#   4. sigmoid gating + mask, then scatter of the gated left result into a
#      (bs, H, s1, s2) layout and the gated right result into its transposed
#      (bs, H, s2, s1) layout, plus the sigmoided out-gate into (M, H).
# NOTE(review): autotune key lists 'N', but the kernel has no argument named
# N (only M, H, K, s1, s2) — confirm the installed Triton tolerates this.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),

        # Configurations with larger block sizes for better data reuse
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),

        # Configurations with deeper K dimension
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),

        # More extreme configurations to test the limits
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=2),

        # Configurations with fewer warps
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=2),

        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (from decorator and kwargs)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    # --- PID Mapping: Based on the LARGER 4*H problem ---
    # Grouped ordering keeps GROUP_SIZE_M M-tiles together for L2 reuse.
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE) ---
    # NOTE(review): these two statistics loops address X without multiplying
    # by stride_x_k, i.e. they assume the K axis is unit-stride — confirm
    # callers always pass a contiguous [M, K] view.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # Pass 1: mean per row, accumulated K-chunk by K-chunk.
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    # Pass 2: variance per row.
    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        # Normalize the X tile on the fly — the normed activations are never
        # materialized in global memory.
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        # Some programs additionally compute the out_gate projection: only the
        # tiles whose N-offset falls inside [0, H) cover the (M, H) gate output.
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    # Epilogue for the out-gate: sigmoid, then store into the (M, H) buffer.
    if pid_n * BLOCK_SIZE_N < H:
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # The accumulator columns repeat (proj_l, gate_l, proj_r, gate_r) every 4
    # entries (W_4way is interleaved); split them out via masked reductions.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    # This tile covers H_CHUNK_SIZE hidden units (BLOCK_SIZE_N == 4 * H_CHUNK_SIZE).
    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Gating + masking fused into the epilogue.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose the flat row index into (batch, s1, s2) for the scattered
    # (bs, H, s1, s2) stores.
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1) # s2 offset uses s2 stride, s1 offset uses s1 stride
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
163
+
164
# Batched matmul over (batch, head) slices: for each (b, h) it computes
# Left[b, h] @ Right[b, h]^T-layout, i.e. out[b, h, i, j] = sum_k
# Left[b,h,i,k] * Right[b,h,k,j] using the caller's transposed right buffer.
# Grid axis 0 tiles the (s1 x s1) output with grouped ordering; axis 1 is
# one program per (batch, head) pair.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
    ],
    key=['s1', 's2', 'H'],
)
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Grid and program IDs — grouped M-ordering within the (s1 x s1) tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Second grid axis enumerates (batch, head) pairs.
    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    # fp32 accumulation even though inputs are fp16.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Reduce over the shared s2 axis in BLOCK_SIZE_K chunks.
    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)

        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        accumulator += tl.dot(a, b)

    # --- Coalesced Write ---
    # Write to a standard (bs, H, s1, s1) layout
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2

    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
236
+
237
# Final fused stage: per output row it (1) computes LayerNorm statistics over
# the H channels of the bmm result in a single sum/sum-of-squares pass,
# (2) normalizes + applies the affine params, multiplies by the precomputed
# out-gate, and (3) projects to the model dim D in one matmul — all without
# materializing intermediates.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
    ],
    key=['H', 'D'],
)
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1, # M_gate = bs*s1*s2
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # --- Grid and PID Setup for Matmul ---
    # Grouped ordering over the (M x D) output tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose M back to (b, r, c) for reordering lookups — the input is
    # stored channel-major (bs, H, s1, s1) while rows here are (b, r, c)-flat.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    # Pass 1 (single sweep): accumulate sum and sum of squares per row.
    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)

        # Accumulate sum and sum of squares in one pass
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    # Finalize statistics (E[x^2] - E[x]^2 formulation).
    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # --- Pass 2: Fused Gating and Matmul ---
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        # Re-load the input chunk, normalize + affine.
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]

        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]

        proj_ptrs = ProjW_ptr + \
                    offs_n[None, :] * stride_proj_d + \
                    offs_k[:, None] * stride_proj_h

        # Multiply by the out-gate before projecting.
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate

        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # --- Store Final Output ---
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d

    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
340
+
341
def compiledtrimul_fused_interleaved(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,  # Use the new weight matrices
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Three-kernel TriMul pipeline using the interleaved [K, 4*H] weights.

    Stage 1 (`fused_ln_dual_matmul_kernel`): LayerNorm + all five projections,
    gating and masking, writing left/right in bmm-friendly layouts plus the
    sigmoided out-gate.
    Stage 2 (`bmm_coalesced_kernel`): batched per-(batch, head) matmul.
    Stage 3 (`fused_final_kernel`): output LayerNorm, out-gate multiply and
    final projection in one pass.

    Args:
        x: Input of shape [bs, s1, s2, d].
        mask_mh: Mask laid out as [bs*s1*s2, H] rows — TODO confirm; only its
            two strides are passed through to the kernel.
        W_4way: Interleaved projection/gate weights, [K, 4*H] (see
            `pack_w_4way_efficient`).
        W_og: Out-gate weights, [K, H].
        h: Hidden dimension H.

    Returns:
        Output tensor of shape [bs, s1, s1, d] in float16.
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # Stage-1 outputs: left in (bs, H, s1, s2), right pre-transposed to
    # (bs, H, s2, s1) so stage 2 can read both coalesced.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # The grid is launched for the larger 4*H problem
    N_4way = 4 * H
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
    fused_ln_dual_matmul_kernel[grid](
        # Pointers (9)
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        # Metadata (5) - M, H, K, s1, s2
        M, H, K, s1, s2,
        # Strides (16)
        x_flat.stride(0), x_flat.stride(1),
        W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1),
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        # Constexpr (1)
        LN_EPS=1e-5
    )

    # Stage-2 output: standard (bs, H, s1, s1) layout for coalesced stores.
    bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)

    grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
    bmm_coalesced_kernel[grid_bmm](
        left_final, right_final_t, bmm_out_tmp,
        bs, s1, s2, H,
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
    )

    # --- Kernel 3: Fully Fused Final Stage ---
    final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)

    grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
    fused_final_kernel[grid_final](
        # Pointers
        bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
        # Metadata
        M, H, d, s1,
        # Strides
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        to_out_weight.stride(0), to_out_weight.stride(1), # Use strides of the corrected tensor
        final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
        # Constants
        LN_EPS=1e-5,
    )

    return final_out
412
+
413
def pack_w_4way_efficient(weights):
    """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
    role_keys = (
        'left_proj.weight',
        'left_gate.weight',
        'right_proj.weight',
        'right_gate.weight',
    )
    mats = [weights[key] for key in role_keys]
    hidden, in_dim = mats[0].shape
    # Interleave per hidden channel: row layout becomes
    # [L[0], LG[0], R[0], RG[0], L[1], ...] before the transpose.
    interleaved = torch.stack(mats, dim=1)          # [H, 4, K]
    packed = interleaved.reshape(4 * hidden, in_dim)
    return packed.t().to(torch.float16)
423
+
424
def get_w_og(weights):
    """ Gets the transposed [K, H] out_gate weight matrix. """
    out_gate = weights['out_gate.weight']  # stored as [H, K]
    return out_gate.to(torch.float16).t()
428
+
429
def compiledtrimul(
    x: torch.Tensor,
    mask: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    w_concat: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int
) -> torch.Tensor:
    """
    A barebones, compiled PyTorch function for the TriMul logic.

    Args:
        x: input of shape (bs, s1, s2, d); the pipeline assumes a square
           pair representation, i.e. s1 == s2.
        mask: (bs, s1, s2, 1) multiplicative mask.
        norm_weight/norm_bias: input LayerNorm affine params, shape (d,).
        w_concat: (d, 5*h) fp16 matrix holding [left, right, left_gate,
           right_gate, out_gate] projections column-concatenated.
        to_out_norm_weight/to_out_norm_bias: output LayerNorm params, (h,).
        to_out_weight: (d, h) fp16 final projection.
        h: hidden dimension of each projection.

    Returns:
        fp16 tensor of shape (bs, s1, s1, d).
    """
    bs, s1, s2, d = x.shape

    # Initial LayerNorm, then flatten rows for one large fp16 matmul.
    x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
    # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
    all_projections = torch.mm(x_norm, w_concat)

    # Split back into individual projections
    left, right, lg, rg, og = all_projections.chunk(5, dim=1)

    # Apply mask and gates
    mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
    left = left * mask_expanded * torch.sigmoid(lg)
    right = right * mask_expanded * torch.sigmoid(rg)
    out_gate = torch.sigmoid(og)

    # Reshape to (bs, h, s1, s2) so the pair contraction is a batched matmul
    # over the s2 axis: result is (bs, h, s1, s1).
    left = left.view(bs, s1, s2, h).permute(0,3,1,2)
    right = right.view(bs, s1, s2, h).permute(0,3,1,2)
    out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
    out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)

    # Apply layer norm and final gating
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * out_gate

    # Final projection
    final_out_flat = gated @ to_out_weight.t()
    # BUGFIX: the flat tensor has bs*s1*s1*d elements (pair matmul yields an
    # s1 x s1 grid), so the view must use s1 twice — previously this was
    # view(bs, s1, s2, d), which only worked by accident when s1 == s2 and
    # disagreed with the sibling A100 implementation.
    final_out = final_out_flat.view(bs, s1, s1, d)

    return final_out
474
+
475
def small_kernel_pt_path(data):
    """Plain-PyTorch fallback path for small problem sizes.

    Concatenates the five projection weights into one (d, 5h) fp16 matrix
    and dispatches to `compiledtrimul`.
    """
    input_tensor, mask, weights, config = data

    # Column order must match the chunk(5) split inside compiledtrimul.
    projection_order = (
        'left_proj.weight',
        'right_proj.weight',
        'left_gate.weight',
        'right_gate.weight',
        'out_gate.weight',
    )
    w_concat = torch.cat([weights[name] for name in projection_order], dim=0)
    w_concat = w_concat.t().contiguous().to(torch.float16)

    fp32, fp16 = torch.float32, torch.float16
    return compiledtrimul(
        x=input_tensor.to(fp32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(fp32),
        norm_bias=weights['norm.bias'].to(fp32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(fp16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(fp16),
        to_out_weight=weights['to_out.weight'].to(fp16),
        h=config["hidden_dim"],
    )
497
+
498
def kernel_mi300(data):
    """MI300 entry point: Triton pipeline for large inputs, PyTorch fallback
    for small ones."""
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    # Small sequences are faster on the plain compiled-PyTorch path.
    if s1 < 100:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]
    rows = bs * s1 * s2
    # Pre-broadcast the mask to (rows, hidden) fp16 for the fused kernel
    # (could be moved into the kernel eventually).
    mask_mh = (mask.unsqueeze(-1)
                   .expand(-1, -1, -1, hidden)
                   .reshape(rows, hidden)
                   .to(torch.float16))

    return compiledtrimul_fused_interleaved(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_mh,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=pack_w_4way_efficient(weights),
        W_og=get_w_og(weights),
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-cuda/triton_a100.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ # Set PyTorch flags for performance
7
+ torch.backends.cuda.matmul.allow_tf32 = True
8
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
9
+
10
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (now passed as arguments from the host)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    # Fuses: per-row LayerNorm over K, the [M,K]@[K,4H] projection matmul
    # (interleaved left/left-gate/right/right-gate columns), the optional
    # [M,K]@[K,H] out-gate matmul, sigmoid gating, masking, and the
    # scatter-store of left / transposed-right projections.
    # NOTE(review): correct deinterleaving requires BLOCK_SIZE_N == 4 * H_CHUNK_SIZE
    # (the reshape below assumes columns alternate roles every 4) — confirm
    # against the launch configs.

    # --- PID Mapping: Based on the LARGER 4*H problem ---
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE) ---
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # Two-pass mean/variance over the K (feature) axis.
    # NOTE(review): these two loops index columns WITHOUT stride_x_k (the
    # matmul loop below does use it) — this assumes stride_x_k == 1, i.e. a
    # contiguous x_flat; TODO confirm.
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        # Normalize the x tile in registers before feeding the tensor cores.
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        # Only the first H worth of column-blocks also compute the out-gate
        # matmul, reusing the already-normalized x tile.
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    # Store sigmoid(out_gate) for the blocks that computed it.
    if pid_n * BLOCK_SIZE_N < H:
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # Columns are interleaved [L, LG, R, RG] per hidden channel; split them
    # with a masked-sum reduction over the role axis.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Gate and mask the projections.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose flat row index into (batch, s1, s2) coordinates for the
    # 4-D scatter stores; right is written transposed (s2, s1).
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
+
136
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Batched matmul over the s2 axis: for each (batch, hidden) pair on grid
    # axis 1, computes Out[b,h] = Left[b,h] (s1 x s2) @ Right[b,h]^T.
    # Right is supplied pre-transposed (s2 x s1) so both loads are along
    # their fastest dimension; output tile is (s1 x s1), fp32 accumulation.

    # Grouped-ordering PID mapping over the (s1 x s1) output tiles for
    # better L2 reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Second grid axis enumerates (batch, hidden-channel) pairs.
    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Reduce along the shared s2 dimension in BLOCK_SIZE_K steps.
    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        accumulator += tl.dot(a, b)

    # Store the (possibly partial) output tile; accumulator is fp32 and is
    # downcast implicitly to the output tensor's dtype by the store.
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
187
+
188
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Final fused stage: per-row LayerNorm over the H channels of the BMM
    # output, multiply by the precomputed sigmoid out-gate, then project
    # [M, H] @ [H, D] into the output tensor.
    # NOTE(review): rows are decomposed assuming an s1 x s1 grid, while M is
    # passed as bs*s1*s2 by the host — consistent only when s1 == s2.

    # Grouped-ordering PID mapping over the (M x D) output tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose flat row index into (batch, row, col) of the s1 x s1 grid.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    # Single-pass moments: E[x] and E[x^2] over the H axis.
    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    mean = sum_x / H
    # Var = E[x^2] - E[x]^2 (single-pass; less numerically robust than the
    # two-pass form but cheap — inputs are LN-scale values).
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # Matmul loop: normalize + gate each H-chunk in registers, then dot with
    # the corresponding slice of the projection weight.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]
        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
        proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate
        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        # Cast the gated activations to the weight dtype (fp16) for tl.dot.
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # Scatter-store the (M x D) tile into the 4-D output.
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
261
+
262
def compiledtrimul_fused_interleaved_final(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Three-kernel fused TriMul pipeline with hand-tuned A100 launch configs.

    Stage 1: fused LayerNorm + interleaved 4-way projection matmul + out-gate
    matmul. Stage 2: batched (bs*H) pair matmul. Stage 3: fused output
    LayerNorm + gating + final projection.

    Block sizes, warps and stages below are hardcoded from A100 autotuning
    runs (see inline config comments).

    Returns a float16 tensor of shape (bs, s1, s1, d); assumes s1 == s2
    (square pair representation).
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # Right projection is written pre-transposed for coalesced stage-2 loads.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # --- Kernel 1: Fused LN + Dual Matmul ---
    N_4way = 4 * H
    # Hardcoded A100 best config: M128-N128-K32-GM8-HC32-W8-S2
    config_k1 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
    grid_k1 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)

    fused_ln_dual_matmul_kernel[grid_k1](
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        M, H, K, s1, s2,
        x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
        left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
        right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        LN_EPS=1e-5, **config_k1, num_warps=8, num_stages=2
    )

    # --- Kernel 2: Batched Matrix Multiplication ---
    bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
    # Hardcoded A100 best config: M128-N64-K32-GM8-W4-S3
    config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    grid_k2 = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)

    bmm_coalesced_kernel[grid_k2](
        left_final, right_final_t, bmm_out_tmp,
        bs, s1, s2, H,
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        **config_k2, num_warps=4, num_stages=3
    )

    # --- Kernel 3: Fully Fused Final Stage ---
    final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
    # Hardcoded A100 best config: M32-N128-K32-GM8-W4-S3
    config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    # NOTE(review): grid uses M = bs*s1*s2 while the output has bs*s1*s1
    # rows — equal only under the s1 == s2 assumption above.
    grid_k3 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)

    fused_final_kernel[grid_k3](
        bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
        M, H, d, s1,
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
        final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
        LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
    )
    return final_out
330
+
331
def pack_w_4way_efficient(weights):
    """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
    mats = [
        weights['left_proj.weight'],
        weights['left_gate.weight'],
        weights['right_proj.weight'],
        weights['right_gate.weight'],
    ]
    hidden, in_dim = mats[0].shape
    # Stack along a new role axis per hidden channel -> [H, 4, K], then
    # flatten so rows interleave as L[0], LG[0], R[0], RG[0], L[1], ...
    packed_rows = torch.stack(mats, dim=1).reshape(4 * hidden, in_dim)
    return packed_rows.t().to(torch.float16)
337
+
338
def get_w_og(weights):
    """ Gets the transposed [K, H] out_gate weight matrix. """
    transposed = weights['out_gate.weight'].transpose(0, 1)
    return transposed.to(dtype=torch.float16)
341
+
342
@torch.compile()
def compiledtrimul(
    x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
    w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor, h: int
) -> torch.Tensor:
    """torch.compile'd TriMul reference path used for small inputs.

    `w_concat` is (d, 5h) fp16 holding [left, right, left_gate, right_gate,
    out_gate] projections column-wise; `mask` is (bs, s1, s2, 1). Assumes
    s1 == s2. Returns an fp16 (bs, s1, s1, d) tensor.
    """
    bs, s1, s2, d = x.shape
    # Input LayerNorm in fp32, then drop to fp16 for the big projection mm.
    x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
    # One [M, d] @ [d, 5h] matmul produces all five projections at once.
    all_projections = torch.mm(x_norm, w_concat)
    left, right, lg, rg, og = all_projections.chunk(5, dim=1)
    # Mask + sigmoid gates applied element-wise per row.
    mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
    left = left * mask_expanded * torch.sigmoid(lg)
    right = right * mask_expanded * torch.sigmoid(rg)
    out_gate = torch.sigmoid(og)
    # (bs, h, s1, s2) layout so the pair contraction is a batched matmul.
    left = left.view(bs, s1, s2, h).permute(0,3,1,2)
    right = right.view(bs, s1, s2, h).permute(0,3,1,2)
    out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
    out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
    # Output LayerNorm, gate, then project back to d channels.
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * out_gate
    final_out_flat = gated @ to_out_weight.t()
    return final_out_flat.view(bs, s1, s1, d)
364
+
365
def small_kernel_pt_path(data):
    """Dispatch small problems to the compiled plain-PyTorch TriMul path."""
    input_tensor, mask, weights, config = data
    # Column order must match the chunk(5) split inside compiledtrimul.
    names = ('left_proj.weight', 'right_proj.weight', 'left_gate.weight',
             'right_gate.weight', 'out_gate.weight')
    stacked = torch.cat([weights[key] for key in names], dim=0)
    w_concat = stacked.t().contiguous().to(torch.float16)
    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
380
+
381
def kernel_a100(data):
    """A100 entry point: Triton pipeline for large inputs, compiled-PyTorch
    fallback below the tuned threshold."""
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    # Below this size the hardcoded BMM configs don't pay off.
    if s1 < 512:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]
    rows = bs * s1 * s2
    mask_mh = (mask.unsqueeze(-1)
                   .expand(-1, -1, -1, hidden)
                   .reshape(rows, hidden)
                   .to(torch.float16))

    return compiledtrimul_fused_interleaved_final(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_mh,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=pack_w_4way_efficient(weights),
        W_og=get_w_og(weights),
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-cuda/triton_b200.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ torch.backends.cuda.matmul.allow_tf32 = True
7
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
+
9
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (now passed as arguments from the host)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    # B200 copy of the fused stage-1 kernel: per-row LayerNorm over K, the
    # interleaved [M,K]@[K,4H] projection matmul, the optional [M,K]@[K,H]
    # out-gate matmul, sigmoid gating, masking, and scatter-stores of the
    # left / transposed-right projections.
    # NOTE(review): correct deinterleaving requires
    # BLOCK_SIZE_N == 4 * H_CHUNK_SIZE — confirm against launch configs.

    # --- PID Mapping: Based on the LARGER 4*H problem ---
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE) ---
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # Two-pass mean/variance over the K axis.
    # NOTE(review): these loops omit stride_x_k (the matmul loop uses it),
    # implicitly assuming a contiguous x_flat — TODO confirm.
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        # Normalize the x tile in registers before feeding tl.dot.
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        # Only the first H worth of column-blocks also calculate the
        # out_gate matmul, reusing the normalized x tile.
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    # Store sigmoid(out_gate) for the blocks that computed it.
    if pid_n * BLOCK_SIZE_N < H:
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # Columns interleave [L, LG, R, RG] per hidden channel; split them with
    # masked-sum reductions over the role axis.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Gate and mask the projections.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose flat row index into (batch, s1, s2) for the 4-D scatter
    # stores; right is written transposed (s2, s1).
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
+
136
# Batched (bs*H independent problems) tiled matmul: Out[b,h] = Left[b,h] @ Right[b,h],
# where Right is supplied pre-transposed (strides named s2-major) so both operands
# are read along their fast axis. Fixed-config variant (no autotune).
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Grid and program IDs.
    # Axis 0 tiles the s1 x s1 output; grouped ordering (GROUP_SIZE_M) improves
    # L2 reuse by keeping neighbouring M tiles together.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Axis 1 enumerates the (batch, head) pairs.
    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    # Accumulate in fp32 for accuracy even though inputs are fp16.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Reduce over the shared s2 dimension in BLOCK_SIZE_K chunks.
    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)

        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        accumulator += tl.dot(a, b)

    # Store tile; tl.store casts the fp32 accumulator to the output dtype.
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2

    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
192
+
193
# Final TriMul stage, fully fused: per-row LayerNorm over H, sigmoid-gate
# multiply, and projection H -> D, all in one kernel. Fixed-config variant.
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Grouped 1D grid over (M rows) x (D output columns).
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose the flat row index into (batch, row, col) of the s1 x s1 grid.
    # NOTE(review): this assumes M == bs * s1 * s1, i.e. the square (s1 == s2)
    # case — confirm against the caller.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    # Pass 1: accumulate sum and sum-of-squares over H for LayerNorm stats.
    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    # Biased variance via E[x^2] - E[x]^2 (single-pass; can lose precision for
    # large |mean| relative to spread, matching the original implementation).
    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # Pass 2: re-load input, normalize + affine, multiply by the sigmoid gate
    # (Gate_ptr already holds sigmoid outputs), then project to D via tl.dot.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]
        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
        # ProjW is indexed [D, H] (stride_proj_d on the output dim), i.e. the
        # un-transposed nn.Linear weight layout.
        proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate
        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # Store the (row, D-chunk) tile back into the 4D output layout.
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
266
+
267
def compiledtrimul_fused_interleaved_final(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
) -> torch.Tensor:
    """Run the full TriMul forward pass as three Triton kernel launches.

    Pipeline:
      1. fused LayerNorm + interleaved 4-way projection (+ out-gate) kernel,
      2. batched per-head matmul of the gated left/right projections,
      3. fused final LayerNorm + gate + output projection.

    Args:
        x: input of shape [bs, s1, s2, d], read in float32 by kernel 1.
        mask_mh: mask pre-expanded to [bs*s1*s2, h] in fp16.
        W_4way: packed [K, 4*h] interleaved projection weights (see
            pack_w_4way_efficient).
        W_og: out-gate weight transposed to [K, h].
        h: hidden dimension H.

    Returns:
        fp16 tensor of shape [bs, s1, s1, d].

    NOTE(review): kernel 3 is launched with M = bs*s1*s2 but internally
    decomposes rows over an s1 x s1 grid, so this path assumes s1 == s2 —
    confirm against the dispatcher.
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # Intermediate buffers: right is materialized pre-transposed so kernel 2
    # reads both operands along their fast axis.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # --- Kernel 1: Fused LN + Dual Matmul ---
    # The grid is launched for the larger 4*H problem
    N_4way = 4 * H
    # Hardcoded best config from logs: M64-N128-K64-GM8-HC32-W4-S2
    config_k1 = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)

    fused_ln_dual_matmul_kernel[grid](
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        M, H, K, s1, s2,
        x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
        left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
        right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        LN_EPS=1e-5, **config_k1, num_warps=4, num_stages=2
    )

    # --- Kernel 2: Batched Matrix Multiplication ---
    bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
    # Hardcoded best config from logs: M128-N128-K32-GM8-W8-S3
    config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)

    bmm_coalesced_kernel[grid_bmm](
        left_final, right_final_t, bmm_out_tmp,
        bs, s1, s2, H,
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        **config_k2, num_warps=8, num_stages=3
    )

    # --- Kernel 3: Fully Fused Final Stage ---
    final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
    # Hardcoded best config from logs: M32-N128-K32-GM8-W4-S3
    config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)

    fused_final_kernel[grid_final](
        bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
        M, H, d, s1,
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
        final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
        LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
    )
    return final_out
336
+
337
def pack_w_4way_efficient(weights):
    """Interleave the four projection weights into one [K, 4*H] fp16 matrix.

    Column layout of the result: for hidden unit h, columns 4*h .. 4*h+3
    hold left_proj, left_gate, right_proj and right_gate respectively, so a
    single matmul yields all four projections interleaved by role.
    """
    w_left = weights['left_proj.weight']
    w_left_gate = weights['left_gate.weight']
    w_right = weights['right_proj.weight']
    w_right_gate = weights['right_gate.weight']
    hidden, in_dim = w_left.shape
    # Stack along a new "role" axis so rows come out as (h0:L,LG,R,RG, h1:...).
    interleaved = torch.stack([w_left, w_left_gate, w_right, w_right_gate], dim=1)
    packed = interleaved.reshape(4 * hidden, in_dim)
    return packed.t().to(torch.float16)
343
+
344
def get_w_og(weights):
    """Fetch the out-gate projection weight as a [K, H] float16 matrix.

    The stored nn.Linear weight is [H, K]; it is transposed so the fused
    kernel can index it input-dimension first.
    """
    out_gate = weights['out_gate.weight']
    transposed = out_gate.t()
    return transposed.to(torch.float16)
347
+
348
@torch.compile()
def compiledtrimul(
    x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
    w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor, h: int
) -> torch.Tensor:
    """TriMul forward pass in plain PyTorch, fused by torch.compile.

    w_concat packs all five projections (left, right, left_gate, right_gate,
    out_gate) as a single [d, 5*h] matrix so one mm covers them all; the
    chunk order below must match the caller's concatenation order.
    Returns an fp16 tensor of shape [bs, s1, s1, d].
    """
    bs, s1, s2, d = x.shape
    # Input LayerNorm, then flatten to rows and drop to fp16 for the big mm.
    x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
    all_projections = torch.mm(x_norm, w_concat)
    left, right, lg, rg, og = all_projections.chunk(5, dim=1)
    # mask is expected as [bs, s1, s2, 1]; broadcast it across the hidden dim.
    mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
    left = left * mask_expanded * torch.sigmoid(lg)
    right = right * mask_expanded * torch.sigmoid(rg)
    out_gate = torch.sigmoid(og)
    # Move the hidden dim in front of the sequence dims for the batched matmul.
    left = left.view(bs, s1, s2, h).permute(0,3,1,2)
    right = right.view(bs, s1, s2, h).permute(0,3,1,2)
    # Triangular-multiplication core: contract over s2.
    out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
    out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
    # Final LayerNorm + out-gate + projection back to d.
    # NOTE(review): gating by out_gate ([bs*s1*s2, h]) against rows of shape
    # [bs*s1*s1, h] lines up only when s1 == s2 — confirm with the dispatcher.
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * out_gate
    final_out_flat = gated @ to_out_weight.t()
    return final_out_flat.view(bs, s1, s1, d)
370
+
371
def small_kernel_pt_path(data):
    """PyTorch fallback path used for small problem sizes.

    Concatenates the five projection weights into one [d, 5*h] matrix and
    dispatches to the torch.compile'd TriMul implementation. The chunk order
    inside compiledtrimul must match the concatenation order here.
    """
    input_tensor, mask, weights, config = data
    projection_names = (
        'left_proj.weight', 'right_proj.weight', 'left_gate.weight',
        'right_gate.weight', 'out_gate.weight',
    )
    stacked = torch.cat([weights[name] for name in projection_names], dim=0)
    w_concat = stacked.t().contiguous().to(torch.float16)
    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
386
+
387
def kernel_b200(data):
    """Entry point: route between the PyTorch fallback and the fused Triton path.

    Small sequence lengths (s1 < 800) go through the torch.compile fallback;
    larger ones use the three-kernel fused Triton pipeline.
    """
    input_tensor, mask, weights, config = data
    batch, seq1, seq2, dim = input_tensor.shape

    if seq1 < 800:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]
    packed_4way = pack_w_4way_efficient(weights)
    packed_og = get_w_og(weights)
    rows = batch * seq1 * seq2
    # Pre-expand the mask to one fp16 row per (b, s1, s2) position.
    mask_rows = (
        mask.unsqueeze(-1)
        .expand(-1, -1, -1, hidden)
        .reshape(rows, hidden)
        .to(torch.float16)
    )

    return compiledtrimul_fused_interleaved_final(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_rows,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=packed_4way,
        W_og=packed_og,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-cuda/triton_h100.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

# Opt in to faster matmul paths: TF32 tensor-core math and reduced-precision
# fp16 accumulation. Both trade a small amount of accuracy for throughput.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
+
9
# First TriMul stage, fused into one kernel: row LayerNorm over K, matmul
# against the interleaved [K, 4*H] weight (left/left_gate/right/right_gate
# packed 4-per-hidden-unit), optional out-gate matmul, sigmoid gating, mask
# multiply, and transposed scatter of the right projection.
# Every autotune config keeps BLOCK_SIZE_N == 4 * H_CHUNK_SIZE, which the
# de-interleave reshape below relies on.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),

        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),

        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),

        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=5),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=5),

        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (from decorator and kwargs)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    # --- PID Mapping: Based on the LARGER 4*H problem ---
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm statistics (computed only ONCE per row block) ---
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # Pass 1: mean over K.
    # NOTE(review): these two stat loops offset columns without stride_x_k,
    # unlike the matmul loop below — they implicitly assume stride_x_k == 1
    # (rows contiguous). Confirm callers always pass a contiguous x_flat.
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    # Pass 2: biased variance via sum of squared deviations.
    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Out-gate column offsets coincide with the 4-way offsets by construction.
    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        # Normalize + affine on the fly, then drop to fp16 for the tensor cores.
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        # Programs whose column range also lies inside [0, H) additionally
        # accumulate the out-gate projection, reusing the normalized tile.
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    # Store sigmoid(out_gate) for the final stage.
    if pid_n * BLOCK_SIZE_N < H:
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # De-interleave: columns come in (left, left_gate, right, right_gate)
    # groups of 4 per hidden unit; requires BLOCK_SIZE_N == 4 * H_CHUNK_SIZE.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Gate and mask both projections.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose flat row index into (batch, s1, s2) for the 4D scatter.
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    # The right projection is written transposed (s2-major) so the following
    # batched matmul can read it contiguously.
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)  # s2 offset uses s2 stride, s1 offset uses s1 stride
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
155
+
156
# Batched (bs*H independent problems) tiled matmul: Out[b,h] = Left[b,h] @ Right[b,h],
# where Right was stored pre-transposed by the previous kernel. Autotuned variant.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
    ],
    key=['s1', 's2', 'H'],
)
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Grid and program IDs: axis 0 tiles the s1 x s1 output with grouped
    # ordering for L2 reuse; axis 1 enumerates (batch, head) pairs.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    # fp32 accumulation over fp16 inputs.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Reduce over the shared s2 dimension.
    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)

        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        accumulator += tl.dot(a, b)

    # --- Coalesced Write ---
    # Write to a standard (bs, H, s1, s1) layout; tl.store casts to the
    # output tensor's dtype.
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2

    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
225
+
226
@torch.compile
def torch_pt2(left_final, right_final_t, bs, s1, s2, d, h, to_out_norm_weight, to_out_norm_bias, og_mh, to_out_weight):
    """Second half of the TriMul pipeline in plain PyTorch (torch.compile'd).

    Consumes the gated left [bs, h, s1, s2] and pre-transposed right
    [bs, h, s2, s1] projections from the Triton stage, performs the batched
    matmul, final LayerNorm, out-gate multiply, and output projection.
    """
    # (bs, h, s1, s2) @ (bs, h, s2, s1) -> (bs, h, s1, s1)
    bmm_out = torch.matmul(left_final, right_final_t)
    out_einsum_flat = bmm_out.permute(0, 2, 3, 1).reshape(bs * s1 * s1, h)
    # Apply layer norm and final gating
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * og_mh

    # Final projection
    final_out_flat = gated @ to_out_weight.t()
    # NOTE(review): final_out_flat has bs*s1*s1 rows, so this view is only
    # valid when s2 == s1 (it would raise otherwise) — confirm the dispatcher
    # guarantees square inputs, or change s2 to s1 here.
    final_out = final_out_flat.view(bs, s1, s2, d)
    return final_out
238
+
239
# Final TriMul stage, fully fused: per-row LayerNorm over H, sigmoid-gate
# multiply (gate precomputed), and projection H -> D. Autotuned variant.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
    ],
    key=['H', 'D'],
)
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,  # M_gate = bs*s1*s2
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # --- Grid and PID Setup for Matmul ---
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose M back to (b, r, c) for reordering lookups.
    # NOTE(review): decomposition is over s1*s1 rows, i.e. assumes the square
    # (s1 == s2) case — confirm against the caller.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    # Pass 1: LayerNorm statistics over H in one sweep.
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)

        # Accumulate sum and sum of squares in one pass
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    # Finalize statistics (biased variance via E[x^2] - E[x]^2)
    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # --- Pass 3: Fused Gating and Matmul ---
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]

        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]

        # ProjW is indexed [D, H] — the un-transposed nn.Linear layout.
        proj_ptrs = ProjW_ptr + \
                    offs_n[None, :] * stride_proj_d + \
                    offs_k[:, None] * stride_proj_h

        # Gate already holds sigmoid(out_gate) from the first kernel.
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate

        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # --- Store Final Output ---
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d

    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
339
+
340
def compiledtrimul_fused_interleaved(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,  # Use the new weight matrices
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
) -> torch.Tensor:
    """Hybrid TriMul path: Triton for stage 1, torch.compile for the rest.

    Launches the autotuned fused LayerNorm + interleaved projection kernel,
    then hands the gated left/right projections and the precomputed out-gate
    to torch_pt2 for the batched matmul, final LayerNorm, gating and output
    projection.

    Args:
        x: input of shape [bs, s1, s2, d], read in float32.
        mask_mh: mask pre-expanded to [bs*s1*s2, h] in fp16.
        W_4way: packed [K, 4*h] interleaved projection weights.
        W_og: out-gate weight transposed to [K, h].
        h: hidden dimension H.

    Returns:
        fp16 tensor of shape [bs, s1, s2, d] (square inputs in practice).
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # Stage-1 outputs; right is materialized pre-transposed for the matmul.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # The grid is launched for the larger 4*H problem
    N_4way = 4 * H
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
    fused_ln_dual_matmul_kernel[grid](
        # Pointers (9)
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        # Metadata (5) - M, H, K, s1, s2
        M, H, K, s1, s2,
        # Strides (16)
        x_flat.stride(0), x_flat.stride(1),
        W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1),
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        # Constexpr (1)
        LN_EPS=1e-5
    )
    return torch_pt2(
        left_final, right_final_t,
        bs=bs,
        s1=s1,
        s2=s2,
        d=d,
        h=h,
        to_out_norm_weight=to_out_norm_weight,
        to_out_norm_bias=to_out_norm_bias,
        og_mh=og_mh,
        to_out_weight=to_out_weight
    )
392
+
393
def pack_w_4way_efficient(weights):
    """
    Pack the left/left-gate/right/right-gate projection weights into one tight
    [K, 4*H] fp16 matrix, interleaved per hidden unit (L0, LG0, R0, RG0, L1, ...).
    """
    names = ('left_proj.weight', 'left_gate.weight', 'right_proj.weight', 'right_gate.weight')
    stacked = torch.stack([weights[n] for n in names], dim=0)  # [4, H, K]
    hidden, in_dim = stacked.shape[1], stacked.shape[2]
    # [4, H, K] -> [H, 4, K] -> [4H, K]: rows interleave the four projections.
    interleaved = stacked.permute(1, 0, 2).contiguous().view(4 * hidden, in_dim)
    return interleaved.t().to(torch.float16)
403
+
404
def get_w_og(weights):
    """Return the out_gate weight transposed to [K, H] and cast to fp16."""
    return weights['out_gate.weight'].t().to(torch.float16)
408
+
409
+
410
# Opt in to reduced-precision matmul fast paths: TF32 tensor cores for fp32
# matmuls, and reduced-precision accumulation for fp16 matmuls. Both trade a
# small amount of accuracy for throughput.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
412
+
413
@torch.compile
def compiledtrimul(
    x: torch.Tensor,
    mask: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    w_concat: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int
) -> torch.Tensor:
    """
    A barebones, compiled PyTorch function for the TriMul logic.

    Args:
        x: [bs, s1, s2, d] fp32 input.
        mask: [bs, s1, s2, 1] mask (already unsqueezed by the caller).
        w_concat: [d, 5*h] fp16 concatenation of the five projection weights in
            the order left, right, left_gate, right_gate, out_gate.
        h: hidden dimension.

    Returns:
        [bs, s1, s2, d] output tensor (fp16 intermediate precision).
    """
    bs, s1, s2, d = x.shape

    # Initial LayerNorm
    x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
    # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
    all_projections = torch.mm(x_norm, w_concat)

    # Split back into individual projections
    left, right, lg, rg, og = all_projections.chunk(5, dim=1)

    # Apply mask and gates
    mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
    left = left * mask_expanded * torch.sigmoid(lg)
    right = right * mask_expanded * torch.sigmoid(rg)
    out_gate = torch.sigmoid(og)

    # Reshape for einsum
    left = left.view(bs, s1, s2, h).permute(0,3,1,2)
    right = right.view(bs, s1, s2, h).permute(0,3,1,2)
    # Pairwise contraction over the second sequence axis: [bs, h, s1, s1].
    out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
    # NOTE(review): bs * s1 * s1 relies on s1 == s2 (inputs here are square
    # [bs, seq_len, seq_len, d]); bs * s1 * s2 would be the general form — confirm.
    out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)

    # Apply layer norm and final gating
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * out_gate

    # Final projection
    final_out_flat = gated @ to_out_weight.t()
    final_out = final_out_flat.view(bs, s1, s2, d)

    return final_out
459
+
460
def small_kernel_pt_path(data):
    """Run the torch.compile'd TriMul path; used for small problem sizes."""
    input_tensor, mask, weights, config = data
    # Fuse the five projection matrices into one [d, 5h] fp16 operand so the
    # compiled path needs only a single large matmul. Order must match the
    # chunk(5) split inside compiledtrimul.
    proj_keys = (
        'left_proj.weight',
        'right_proj.weight',
        'left_gate.weight',
        'right_gate.weight',
        'out_gate.weight',
    )
    w_concat = torch.cat([weights[k] for k in proj_keys], dim=0)
    w_concat = w_concat.t().contiguous().to(torch.float16)
    # Call the compiled function with prepared weights
    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float32),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float32),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
482
+
483
def kernel_h100(data):
    """
    H100 entry point: route small problems to the compiled-PyTorch path and
    large ones to the fused Triton + matmul pipeline.

    Args:
        data: (input_tensor [bs, s1, s2, d], mask [bs, s1, s2],
               weights dict, config dict with "hidden_dim").

    Returns:
        [bs, s1, s2, d] output tensor.
    """
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    # Small sequence lengths go through the torch.compile path.
    if s1 <= 512:
        return small_kernel_pt_path(data)

    H = config["hidden_dim"]

    # Pre-pack weights for the fused kernel: interleaved 4-way projection and
    # transposed out_gate.
    W_4way = pack_w_4way_efficient(weights)
    W_og = get_w_og(weights)

    M = bs * s1 * s2
    # Pre-broadcast the [bs, s1, s2] mask to a flat [M, H] fp16 tensor.
    mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16) #move into kernel possibly

    return compiledtrimul_fused_interleaved(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_mh,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=W_4way, # Pass the new 4-way matrix
        W_og=W_og, # Pass the new out_gate matrix
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=H,
    )
build/torch-rocm/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .triton_a100 import kernel_a100
2
+ from .triton_h100 import kernel_h100
3
+ from .triton_b200 import kernel_b200
4
+ from .trimul_mi300 import kernel_mi300
5
+ from .trimul_global import kernel_global
6
+
7
+ __all__ = ["kernel_a100", "kernel_h100", "kernel_b200", "kernel_mi300", "kernel_global"]
build/torch-rocm/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import torch

# Handle to the ops registered under this build's unique namespace.
ops = torch.ops._trimul_gpumode_176b4e4
3
+
4
def add_op_namespace_prefix(op_name: str) -> str:
    """
    Prefix op by namespace.

    Example: "mm" -> "_trimul_gpumode_176b4e4::mm".
    """
    prefixed = f"_trimul_gpumode_176b4e4::{op_name}"
    return prefixed
build/torch-rocm/metadata.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"python-depends":[]}
build/torch-rocm/task.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Type definitions for TriMul task.
3
+
4
+ Input: Tuple of (input_tensor, mask, weights, config)
5
+ - input_tensor: Input tensor of shape [batch_size, seq_len, seq_len, dim]
6
+ - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
7
+ - weights: Dictionary containing model weights
8
+ - config: Dictionary containing model configuration parameters
9
+
10
+ Output: Output tensor of shape [batch_size, seq_len, seq_len, dim]
11
+ """
12
+
13
+ import torch
14
+ from typing import Tuple, Dict, Any
15
+
16
# Input type: (input_tensor, mask, weights, config)
# - input_tensor: [batch_size, seq_len, seq_len, dim]
# - mask:         [batch_size, seq_len, seq_len]
# - weights:      name -> tensor mapping of model weights
# - config:       model configuration parameters (e.g. "hidden_dim")
input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]

# Output type: output tensor of shape [batch_size, seq_len, seq_len, dim]
output_t = torch.Tensor
build/torch-rocm/trimul_global.py ADDED
@@ -0,0 +1,971 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from utils import make_match_reference, DisableCuDNNTF32
2
+ from .task import input_t, output_t
3
+
4
+ import torch
5
+ from torch import nn, einsum
6
+ import math
7
+ import os
8
+ import requests
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
14
+ # in PyTorch 1.12 and later.
15
+ torch.backends.cuda.matmul.allow_tf32 = True
16
+
17
+ # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
18
+ torch.backends.cudnn.allow_tf32 = True
19
+
20
+ # Set allocator for TMA descriptors (required for on-device TMA)
21
def alloc_fn(size: int, alignment: int, stream=None):
    """Return a raw byte workspace buffer for Triton's on-device TMA descriptors."""
    workspace = torch.empty(size, device="cuda", dtype=torch.int8)
    return workspace
23
+
24
+ triton.set_allocator(alloc_fn)
25
+
26
+ # os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
27
+ # os.environ['MLIR_ENABLE_DIAGNOSTICS'] = 'warnings,remarks'
28
+
29
+ # Reference code in PyTorch
30
class TriMul(nn.Module):
    """
    Reference (unoptimized) triangle-multiplicative update: LayerNorm, gated
    left/right projections, a pairwise contraction over the second sequence
    axis, output norm + gate, and a final projection back to `dim`.
    """
    # Based on https://github.com/lucidrains/triangle-multiplicative-module/blob/main/triangle_multiplicative_module/triangle_multiplicative_module.py
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()

        self.norm = nn.LayerNorm(dim)

        self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.right_proj = nn.Linear(dim, hidden_dim, bias=False)

        self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.out_gate = nn.Linear(dim, hidden_dim, bias=False)

        self.to_out_norm = nn.LayerNorm(hidden_dim)
        self.to_out = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        x: [bs, seq_len, seq_len, dim]
        mask: [bs, seq_len, seq_len]

        Returns:
            output: [bs, seq_len, seq_len, dim]
        """
        batch_size, seq_len, _, dim = x.shape

        x = self.norm(x)

        left = self.left_proj(x)
        right = self.right_proj(x)

        # Zero out masked positions before gating.
        mask = mask.unsqueeze(-1)
        left = left * mask
        right = right * mask

        left_gate = self.left_gate(x).sigmoid()
        right_gate = self.right_gate(x).sigmoid()
        out_gate = self.out_gate(x).sigmoid()

        left = left * left_gate
        right = right * right_gate

        # Contract over the shared k axis for every (i, j) pair per channel.
        out = einsum('... i k d, ... j k d -> ... i j d', left, right)
        # This einsum is the same as the following:
        # out = torch.zeros(batch_size, seq_len, seq_len, dim, device=x.device)

        # # Compute using nested loops
        # for b in range(batch_size):
        #     for i in range(seq_len):
        #         for j in range(seq_len):
        #             # Compute each output element
        #             for k in range(seq_len):
        #                 out[b, i, j] += left[b, i, k, :] * right[b, j, k, :]

        out = self.to_out_norm(out)
        out = out * out_gate
        return self.to_out(out)
+ return self.to_out(out)
92
+
93
@triton.jit
def triton_sigmoid(x):
    """
    Compute sigmoid function: 1 / (1 + exp(-x))

    Written out explicitly so it stays inside the Triton DSL.
    """
    return 1.0 / (1.0 + tl.exp(-x))
99
+
100
def two_mm_kernel_configs_wrapper():
    """
    Select a config factory for two_mm_kernel based on the GPU generation.

    The compute capability is probed when this wrapper is called, and the
    returned value is a zero-argument callable that produces a list of
    triton.Config candidates for triton.autotune on this device.

    Returns:
        Callable[[], list[triton.Config]]
    """
    if torch.cuda.get_device_capability() == (12, 0):
        # sm_120: sweep small tiles.
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [16, 32]:
                for BLOCK_N in [16, 32, 64]:
                    for BLOCK_K in [16, 32, 64]:
                        for num_stages in [2, 3]:
                            configs.append(triton.Config({
                                'BLOCK_M': BLOCK_M,
                                'BLOCK_N': BLOCK_N,
                                'BLOCK_K': BLOCK_K,
                                'GROUP_SIZE_M': 8
                            }, num_stages=num_stages, num_warps=8))
            return configs

    elif torch.cuda.get_device_capability()[0] == 9:
        # H100: short hand-picked candidate list.
        # NOTE(review): a nested per-(B, seq_len, dim) tuning table
        # (get_optimal_two_mm_config_h100) was defined here but never called
        # or returned; it has been removed as dead code. Recover it from
        # version history if shape-keyed config selection is ever wired up.
        def two_mm_kernel_configs():
            return [
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
            ]

    # NOTE(review): an `elif torch.cuda.get_device_capability()[0] == 10 and False:`
    # branch was removed — the `and False` made it unreachable and it contained
    # only unused helpers. sm_100 (B200) therefore intentionally falls through
    # to the generic sweep below, exactly as before.
    elif torch.cuda.get_device_capability()[0] == 8:
        # A100: narrow BLOCK_K, vary stages and warps.
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [64]:
                for BLOCK_N in [64, 128]:
                    for BLOCK_K in [16]:
                        for num_stages in [3, 4]:
                            for num_warps in [4, 8]:
                                configs.append(triton.Config({
                                    'BLOCK_M': BLOCK_M,
                                    'BLOCK_N': BLOCK_N,
                                    'BLOCK_K': BLOCK_K,
                                    'GROUP_SIZE_M': 8
                                }, num_stages=num_stages, num_warps=num_warps))
            return configs
    else:
        # Generic fallback sweep for any other architecture.
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [64, 128]:
                for BLOCK_N in [64, 128]:
                    for BLOCK_K in [64, 128]:
                        for num_stages in [2, 3]:
                            configs.append(triton.Config({
                                'BLOCK_M': BLOCK_M,
                                'BLOCK_N': BLOCK_N,
                                'BLOCK_K': BLOCK_K,
                                'GROUP_SIZE_M': 8
                            }, num_stages=num_stages, num_warps=8))
            return configs

    return two_mm_kernel_configs
271
+
272
def two_mm_kernel_wrapper():
    """
    Build the fused five-matmul Triton kernel for the current GPU and return it.

    Two variants are compiled: a plain tl.load persistent kernel for sm_8x
    (A100), and a TMA-descriptor persistent kernel for everything else. Both
    compute A @ B{1..5}.T in one pass and fuse masking, sigmoid gating, and
    stores of the gated left/right projections (C1, C2) and out_gate (D).
    """
    if torch.cuda.get_device_capability()[0] == 8:
        @triton.jit
        def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
            # Persistent kernel using standard tl.load operations
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_M)
            num_pid_n = tl.cdiv(N, BLOCK_N)
            k_tiles = tl.cdiv(K, BLOCK_K)
            num_tiles = num_pid_m * num_pid_n

            # tile_id_c is used in the epilogue to break the dependency between
            # the prologue and the epilogue
            tile_id_c = start_pid - NUM_SMS
            num_pid_in_group = GROUP_SIZE_M * num_pid_n

            # Persistent loop over tiles
            for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
                # Calculate PID for this tile using improved swizzling
                group_id = tile_id // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id % group_size_m)
                pid_n = (tile_id % num_pid_in_group) // group_size_m

                # Calculate block offsets
                offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
                offs_k = tl.arange(0, BLOCK_K)

                # Initialize accumulators for all outputs
                accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                # Main computation loop over K dimension
                for ki in range(k_tiles):
                    k_start = ki * BLOCK_K
                    k_offsets = k_start + offs_k

                    # Create pointers for A matrix (2D flattened view)
                    a_ptrs = a_ptr + offs_am[:, None] * stride_a2 + k_offsets[None, :] * stride_a3
                    a_mask = (offs_am[:, None] < M) & (k_offsets[None, :] < K)

                    # Create pointers for B matrices [N, K] layout
                    b1_ptrs = b1_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b2_ptrs = b2_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b3_ptrs = b3_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b4_ptrs = b4_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b5_ptrs = b5_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
                    b_mask = (offs_bn[:, None] < N) & (k_offsets[None, :] < K)

                    # Load blocks from A and all weight matrices using standard tl.load
                    a = tl.load(a_ptrs, mask=a_mask, other=0.0)
                    b1 = tl.load(b1_ptrs, mask=b_mask, other=0.0)
                    b2 = tl.load(b2_ptrs, mask=b_mask, other=0.0)
                    b3 = tl.load(b3_ptrs, mask=b_mask, other=0.0)
                    b4 = tl.load(b4_ptrs, mask=b_mask, other=0.0)
                    b5 = tl.load(b5_ptrs, mask=b_mask, other=0.0)

                    # Perform matrix multiplications using TF32
                    accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True) # A @ B1.T
                    accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True) # A @ B2.T
                    accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True) # A @ B3.T
                    accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True) # A @ B4.T
                    accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True) # A @ B5.T

                # Store results using separate tile_id_c for epilogue
                tile_id_c += NUM_SMS
                group_id = tile_id_c // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id_c % group_size_m)
                pid_n = (tile_id_c % num_pid_in_group) // group_size_m

                # Calculate output offsets and pointers
                offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

                # Create masks for bounds checking
                d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

                # Calculate pointer addresses using 4D strides
                stride_cm = stride_c2  # Stride to next element in flattened M dimension
                stride_cn = stride_c3  # N is the innermost dimension

                # For D tensor: use separate D strides
                stride_dm = stride_d2  # Stride to next element in flattened M dimension
                stride_dn = stride_d3  # N is the innermost dimension

                # Decompose the flat row index into (batch, row, col) of the 4D output.
                off_c_batch = offs_cm // (seq_len * seq_len)
                off_c_sl1 = (offs_cm // seq_len) % seq_len
                off_c_sl2 = offs_cm % seq_len
                off_c_dim = offs_cn

                c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
                c_mask = d_mask

                c1_ptrs = c1_ptr + c_offsets
                c2_ptrs = c2_ptr + c_offsets
                d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]

                mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))

                # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
                mask_2d = mask[:, None]  # Convert to [BLOCK_M, 1] then broadcast
                # Apply masking only to left_proj and right_proj results (C1, C2)
                accumulator1 = tl.where(mask_2d, accumulator1, 0)
                accumulator2 = tl.where(mask_2d, accumulator2, 0)

                # Apply sigmoid to gate values
                left_gate_sigmoid = triton_sigmoid(accumulator3)
                right_gate_sigmoid = triton_sigmoid(accumulator4)
                accumulator_d = triton_sigmoid(accumulator_d)

                # Apply elementwise multiplication with gated values
                # C1 = left * left_gate, C2 = right * right_gate
                accumulator1 = accumulator1 * left_gate_sigmoid  # left * left_gate
                accumulator2 = accumulator2 * right_gate_sigmoid  # right * right_gate

                # Convert to appropriate output dtype and store with normal tl.store
                c1 = accumulator1.to(c1_ptr.dtype.element_ty)
                c2 = accumulator2.to(c2_ptr.dtype.element_ty)
                d = accumulator_d.to(d_ptr.dtype.element_ty)

                tl.store(c1_ptrs, c1, mask=c_mask)
                tl.store(c2_ptrs, c2, mask=c_mask)
                tl.store(d_ptrs, d, mask=d_mask)
    else:
        @triton.jit
        def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
            # Persistent kernel using on-device TMA descriptors
            start_pid = tl.program_id(axis=0)
            num_pid_m = tl.cdiv(M, BLOCK_M)
            num_pid_n = tl.cdiv(N, BLOCK_N)
            k_tiles = tl.cdiv(K, BLOCK_K)
            num_tiles = num_pid_m * num_pid_n

            # Create on-device TMA descriptors
            a_desc = tl._experimental_make_tensor_descriptor(
                a_ptr,
                shape=[M, K],
                strides=[stride_a2, stride_a3],
                block_shape=[BLOCK_M, BLOCK_K],
            )
            b1_desc = tl._experimental_make_tensor_descriptor(
                b1_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b2_desc = tl._experimental_make_tensor_descriptor(
                b2_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b3_desc = tl._experimental_make_tensor_descriptor(
                b3_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b4_desc = tl._experimental_make_tensor_descriptor(
                b4_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )
            b5_desc = tl._experimental_make_tensor_descriptor(
                b5_ptr,
                shape=[N, K],
                strides=[stride_bn, stride_bk],
                block_shape=[BLOCK_N, BLOCK_K],
            )

            # tile_id_c is used in the epilogue to break the dependency between
            # the prologue and the epilogue
            tile_id_c = start_pid - NUM_SMS
            num_pid_in_group = GROUP_SIZE_M * num_pid_n

            # Persistent loop over tiles
            for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
                # Calculate PID for this tile using improved swizzling
                group_id = tile_id // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id % group_size_m)
                pid_n = (tile_id % num_pid_in_group) // group_size_m

                # Calculate block offsets
                offs_am = pid_m * BLOCK_M
                offs_bn = pid_n * BLOCK_N

                # Initialize accumulators for all outputs
                accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
                accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

                # Main computation loop over K dimension
                for ki in range(k_tiles):
                    offs_k = ki * BLOCK_K
                    # Load blocks from A and all weight matrices using on-device TMA
                    a = a_desc.load([offs_am, offs_k])
                    b1 = b1_desc.load([offs_bn, offs_k])
                    b2 = b2_desc.load([offs_bn, offs_k])
                    b3 = b3_desc.load([offs_bn, offs_k])
                    b4 = b4_desc.load([offs_bn, offs_k])
                    b5 = b5_desc.load([offs_bn, offs_k])

                    # Perform matrix multiplications using TF32
                    accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True) # A @ B1.T
                    accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True) # A @ B2.T
                    accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True) # A @ B3.T
                    accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True) # A @ B4.T
                    accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True) # A @ B5.T

                # Store results using separate tile_id_c for epilogue
                tile_id_c += NUM_SMS
                group_id = tile_id_c // num_pid_in_group
                first_pid_m = group_id * GROUP_SIZE_M
                group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
                pid_m = first_pid_m + (tile_id_c % group_size_m)
                pid_n = (tile_id_c % num_pid_in_group) // group_size_m

                # Calculate output offsets and pointers
                offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
                offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

                # Create masks for bounds checking
                d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

                # Calculate pointer addresses using 4D strides
                # For C tensors: compute effective 2D strides from 4D strides
                # Output tensor is [B, I, J, N], flattened to [M, N] where M = B*I*J
                stride_cm = stride_c2  # Stride to next element in flattened M dimension
                stride_cn = stride_c3  # N is the innermost dimension

                # For D tensor: use separate D strides
                stride_dm = stride_d2  # Stride to next element in flattened M dimension
                stride_dn = stride_d3  # N is the innermost dimension

                off_c_batch = offs_cm // (seq_len * seq_len)
                off_c_sl1 = (offs_cm // seq_len) % seq_len
                off_c_sl2 = offs_cm % seq_len
                off_c_dim = offs_cn

                # TODO update the mask_c so we don't IMA
                c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
                # c_offsets = offs_cm[:, None] * stride_c2 + offs_cn[None, :] * stride_c3
                c_mask = d_mask

                c1_ptrs = c1_ptr + c_offsets
                c2_ptrs = c2_ptr + c_offsets
                # c1_ptrs = c1_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
                # c2_ptrs = c2_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
                d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]

                mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))

                # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
                mask_2d = mask[:, None]  # Convert to [BLOCK_M, 1] then broadcast
                # Apply masking only to left_proj and right_proj results (C1, C2)
                accumulator1 = tl.where(mask_2d, accumulator1, 0)
                accumulator2 = tl.where(mask_2d, accumulator2, 0)

                # Apply sigmoid to gate values
                left_gate_sigmoid = triton_sigmoid(accumulator3)
                right_gate_sigmoid = triton_sigmoid(accumulator4)
                accumulator_d = triton_sigmoid(accumulator_d)

                # Apply elementwise multiplication with gated values
                # C1 = left * left_gate, C2 = right * right_gate
                accumulator1 = accumulator1 * left_gate_sigmoid  # left * left_gate
                accumulator2 = accumulator2 * right_gate_sigmoid  # right * right_gate

                # Convert to appropriate output dtype and store with normal tl.store
                c1 = accumulator1.to(c1_ptr.dtype.element_ty)
                c2 = accumulator2.to(c2_ptr.dtype.element_ty)
                d = accumulator_d.to(d_ptr.dtype.element_ty)

                tl.store(c1_ptrs, c1, mask=c_mask)
                tl.store(c2_ptrs, c2, mask=c_mask)
                tl.store(d_ptrs, d, mask=d_mask)


    # NOTE(review): get_device_capability()[0] is an int major version, so the
    # float 10.2 can never match — effectively only capability 9 (H100) skips
    # autotuning here. Confirm whether [9, 10] was intended.
    if torch.cuda.get_device_capability()[0] not in [9, 10.2]:
        two_mm_kernel = triton.autotune(
            (two_mm_kernel_configs_wrapper())(), key=["M", "N", "K"]
        )(two_mm_kernel)

    return two_mm_kernel
568
+
569
+
570
+ def two_mm(A, left_proj, right_proj, left_gate, right_gate, out_gate, mask):
571
+ """
572
+ Persistent matrix multiplication for all weight matrices using on-device TMA descriptors.
573
+
574
+ Args:
575
+ A: [..., K] tensor (arbitrary leading dimensions)
576
+ left_proj: [N, K] matrix (will be transposed)
577
+ right_proj: [N, K] matrix (will be transposed)
578
+ left_gate: [N, K] left gate weight matrix
579
+ right_gate: [N, K] right gate weight matrix
580
+ out_gate: [N, K] output gate weight matrix
581
+ mask: mask tensor
582
+
583
+ Returns:
584
+ (C1, C2, D): Tuple of result tensors [..., N] with same leading dims as A
585
+ C1 = (A @ left_proj.T) * sigmoid(A @ left_gate.T) (masked)
586
+ C2 = (A @ right_proj.T) * sigmoid(A @ right_gate.T) (masked)
587
+ D = sigmoid(A @ out_gate.T) (unmasked)
588
+ """
589
+ # Check constraints
590
+ assert A.shape[-1] == left_proj.shape[1] == right_proj.shape[1], "Incompatible K dimensions"
591
+ assert A.dtype == left_proj.dtype == right_proj.dtype, "Incompatible dtypes"
592
+
593
+ # Assert that all weight matrices have the same strides (same [N, K] shape)
594
+ assert left_proj.stride() == right_proj.stride() == left_gate.stride() == right_gate.stride() == out_gate.stride(), \
595
+ "All weight matrices must have identical strides"
596
+
597
+ # Get dimensions
598
+ original_shape = A.shape[:-1] # All dimensions except the last
599
+ K = A.shape[-1]
600
+ N = left_proj.shape[0]
601
+ B, seq_len, _, _ = A.shape
602
+ dtype = A.dtype
603
+
604
+ # Flatten A to 2D for kernel processing
605
+ A_2d = A.view(-1, K) # [M, K] where M is product of all leading dims
606
+ M = A_2d.shape[0]
607
+
608
+ # Get number of streaming multiprocessors
609
+ NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
610
+
611
+ # Launch persistent kernel with limited number of blocks
612
+ grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)
613
+
614
+ # Get original 4D strides for A and output tensors
615
+ A_strides = A.stride() # (stride_0, stride_1, stride_2, stride_3)
616
+
617
+ # Create output tensors with proper 4D shape to get correct strides
618
+ output_shape = original_shape + (N,)
619
+ # C1 = torch.empty(output_shape, device=A.device, dtype=dtype)
620
+ # C2 = torch.empty(output_shape, device=A.device, dtype=dtype)
621
+ C1 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
622
+ C2 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
623
+ D = torch.empty(output_shape, device=A.device, dtype=torch.float16)
624
+
625
+ C_strides = C1.stride() # (stride_0, stride_1, stride_2, stride_3)
626
+ D_strides = D.stride() # (stride_0, stride_1, stride_2, stride_3)
627
+
628
+ # Use optimal configuration for B200/H100 or fallback to autotuning for other GPUs
629
+ if torch.cuda.get_device_capability()[0] == 10:
630
+ # Get optimal configuration for B200
631
+ BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
632
+ grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))
633
+
634
+ two_mm_kernel_wrapper()[(grid_size,)](
635
+ A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
636
+ C1, C2, D, mask,
637
+ M, N, K,
638
+ *A_strides, # 4D strides for A
639
+ left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
640
+ *C_strides, # 4D strides for C
641
+ seq_len,
642
+ *D_strides, # 4D strides for D
643
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
644
+ num_stages=num_stages, num_warps=num_warps
645
+ )
646
+ elif torch.cuda.get_device_capability()[0] == 9:
647
+ # Get optimal configuration for H100
648
+ BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
649
+ grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))
650
+
651
+ two_mm_kernel_wrapper()[(grid_size,)](
652
+ A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
653
+ C1, C2, D, mask,
654
+ M, N, K,
655
+ *A_strides, # 4D strides for A
656
+ left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
657
+ *C_strides, # 4D strides for C
658
+ seq_len,
659
+ *D_strides, # 4D strides for D
660
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
661
+ num_stages=num_stages, num_warps=num_warps
662
+ )
663
+ else:
664
+ # Use autotuning for other GPUs
665
+ two_mm_kernel_wrapper()[grid](
666
+ A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
667
+ C1, C2, D, mask,
668
+ M, N, K,
669
+ *A_strides, # 4D strides for A
670
+ left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
671
+ *C_strides, # 4D strides for C
672
+ seq_len,
673
+ *D_strides, # 4D strides for D
674
+ NUM_SMS=NUM_SMS
675
+ )
676
+
677
+ return C1, C2, D
678
+
679
+
680
def second_layernorm_mul(inp, hidden_dim, weight, bias, mul_operand):
    """Layer-normalize `inp` over its last `hidden_dim` channels, then scale elementwise by `mul_operand`."""
    # Cast the affine parameters to the input dtype so layer_norm does not
    # promote/demote the activation dtype.
    w = weight.to(inp.dtype)
    b = bias.to(inp.dtype)
    normed = torch.nn.functional.layer_norm(inp, (hidden_dim,), weight=w, bias=b, eps=1e-5)
    return normed * mul_operand
684
+
685
+ '''
686
+ @triton.autotune(
687
+ [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=4, num_stages=3)],
688
+ key=["R", "C"]
689
+ )
690
+ '''
691
@triton.jit
def layernorm_kernel_first(
    X,
    Y,
    Weight,
    Bias,
    R,
    C,  # aka "dim"
    eps,
    ROW_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Row-wise LayerNorm over the last dimension of a contiguous [R, C] view.

    Each program normalizes ROW_BLOCK_SIZE rows at once; BLOCK_SIZE is the
    next power of two >= C, so one tile covers a whole row. Assumes X and Y
    are contiguous with row stride C (the launcher flattens [B, s, s, C]).
    """
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, BLOCK_SIZE)

    mask_row = row < R
    mask_col = cols < C

    # Simple indexing for contiguous data; stats are computed in fp32.
    x = tl.load(
        X + row[:, None] * C + cols[None, :],
        mask=mask_row[:, None] & mask_col[None, :],
        other=0.0
    ).to(tl.float32)

    weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
    bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)

    mean = tl.sum(x, axis=1) / C
    # Zero out padded lanes so they do not contribute to the variance.
    diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
    var = tl.sum(diff * diff, axis=1) / C
    rstd = 1 / tl.sqrt(var + eps)

    y_hat = (x - mean[:, None]) * rstd[:, None]
    y = y_hat * weight[None, :] + bias[None, :]

    tl.store(
        Y + row[:, None] * C + cols[None, :],
        y,
        mask=mask_row[:, None] & mask_col[None, :]
    )
732
+
733
+
734
def get_optimal_config_ln(dim):
    """Pick (ROW_BLOCK_SIZE, num_warps) for the first layernorm kernel.

    Hopper (compute capability 9.x) gets a warp count scaled with `dim`;
    every other architecture (and dims above 1024 on Hopper) falls back
    to the default (16, 4).
    """
    row_block = 16
    if torch.cuda.get_device_capability()[0] == 9:
        for upper_bound, warps in ((256, 1), (512, 2), (1024, 4)):
            if dim <= upper_bound:
                return (row_block, warps)
    return (row_block, 4)
747
+
748
+
749
def triton_layernorm_first(x, weight, bias, eps=1e-5, num_warps=None, ROW_BLOCK_SIZE=None):
    """Launch layernorm_kernel_first over x viewed as [B*s*s, dim] rows.

    Returns a new float16 tensor of the same shape as x. If num_warps or
    ROW_BLOCK_SIZE is not given, both are taken from get_optimal_config_ln.
    """
    B, seq_len, seq_len2, dim = x.shape
    assert(seq_len == seq_len2)

    R = B * seq_len * seq_len
    C = dim

    # Output is forced to fp16 regardless of the input dtype.
    out = torch.empty_like(x, dtype=torch.float16)

    if not num_warps or not ROW_BLOCK_SIZE:
        ROW_BLOCK_SIZE, num_warps = get_optimal_config_ln(dim)

    # One tile must span a full row: BLOCK_SIZE is the power-of-2 ceiling of C.
    BLOCK_SIZE = triton.next_power_of_2(C)
    assert(BLOCK_SIZE <= 1024)

    def grid(meta):
        return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)

    layernorm_kernel_first[grid](
        x, out, weight, bias,
        R, C, eps,
        ROW_BLOCK_SIZE=ROW_BLOCK_SIZE,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=3
    )

    return out
777
+
778
+ '''
779
+ def triton_layernorm_first(x, weight, bias, eps=1e-5):
780
+ B, seq_len, seq_len2, dim = x.shape
781
+ assert(seq_len == seq_len2)
782
+
783
+ R = B * seq_len * seq_len
784
+ C = dim
785
+
786
+ out = torch.empty_like(x)
787
+
788
+ BLOCK_SIZE = triton.next_power_of_2(C)
789
+ assert(BLOCK_SIZE <= 1024)
790
+
791
+ def grid(meta):
792
+ return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
793
+
794
+ layernorm_kernel_first[grid](
795
+ x, out, weight, bias,
796
+ R, C, eps,
797
+ BLOCK_SIZE=BLOCK_SIZE
798
+ )
799
+
800
+ return out
801
+ '''
802
+
803
+
804
@triton.autotune(
    [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=1, num_stages=3)],
    key=[]
)
@triton.jit
def layernorm_kernel_eltwise(
    X,
    Y,
    Weight,
    Bias,
    OutGate,
    seq_len,
    stride_batch,
    stride_dim,
    R,
    C,  # aka "dim"
    eps,
    ROW_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused LayerNorm(x) * out_gate over the last dimension.

    X is read with explicit strides: the launcher passes a tensor whose
    channel axis is NOT innermost (stride_dim == seq_len*seq_len, i.e. a
    permuted channels-first layout), while Y and OutGate are addressed as
    contiguous [R, C] row-major.
    """
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, BLOCK_SIZE)

    # Calculate base pointer for this batch of rows. A row block must never
    # straddle a batch boundary, hence the divisibility check.
    tl.device_assert(seq_len*seq_len % ROW_BLOCK_SIZE == 0)
    # batch_offset = (row // (stride_seq1 // stride_dim)) * stride_batch
    batch = tl.program_id(0) * ROW_BLOCK_SIZE // (seq_len * seq_len)
    seqs_off = row % (seq_len * seq_len)  # TODO is this going to prevent vectorization

    off_r = batch * stride_batch + seqs_off
    off_c = cols * stride_dim

    mask_row = row < R
    mask_col = cols < C

    # OutGate is contiguous [R, C]; load the gate tile up front.
    out_gate = tl.load(
        OutGate + row[:, None] * C + cols[None, :],
        mask = mask_row[:, None] & mask_col[None, :],
    )

    x = tl.load(
        X + off_r[:, None] + off_c[None, :],
        mask=mask_row[:, None] & mask_col[None, :],
        other=0.0
    ).to(tl.float32)

    weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
    bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)

    # fp32 mean/variance with padded lanes masked out of the variance sum.
    mean = tl.sum(x, axis=1) / C
    diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
    var = tl.sum(diff * diff, axis=1) / C
    rstd = 1 / tl.sqrt(var + eps)

    y_hat = (x - mean[:, None]) * rstd[:, None]
    y = y_hat * weight[None, :] + bias[None, :]

    # Fused epilogue: multiply by the (already sigmoided) output gate.
    tl.store(
        Y + row[:, None] * C + cols[None, :],
        y * out_gate,
        mask=mask_row[:, None] & mask_col[None, :]
    )
866
+
867
+
868
def triton_layernorm_eltwise(x, weight, bias, out_gate, eps=1e-5):
    """Launch layernorm_kernel_eltwise: LayerNorm(x) * out_gate, fp32 output.

    x must come from the channels-first bmm output permuted to [B, s, s, C]
    (so its channel stride is s*s); out_gate must be contiguous [B, s, s, C].
    """
    B, seq_len, seq_len2, dim = x.shape
    assert(seq_len == seq_len2)
    R = B * seq_len * seq_len
    # Channel axis stride of s*s confirms the permuted channels-first layout
    # that the kernel's addressing scheme relies on.
    assert(x.stride(3) == seq_len*seq_len)
    assert(out_gate.is_contiguous())
    C = dim

    # Output inherits out_gate's (contiguous) layout, promoted to fp32.
    out = torch.empty_like(out_gate, dtype=torch.float32)

    BLOCK_SIZE = triton.next_power_of_2(C)
    # Specialized for hidden_dim == 128 — confirm before reusing elsewhere.
    assert(BLOCK_SIZE == 128)

    def grid(meta):
        return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)

    layernorm_kernel_eltwise[grid](
        x, out, weight, bias, out_gate,
        seq_len,
        x.stride(0), x.stride(3),
        R, C, eps,
        BLOCK_SIZE=BLOCK_SIZE
    )

    return out
893
+
894
+
895
def kernel_global(data: input_t) -> output_t:
    """
    TriMul forward pass (triangular multiplicative update), fused pipeline:
      1) LayerNorm on the input (Triton kernel, fp16 out)
      2) one fused persistent matmul producing gated left/right projections
         and the sigmoided output gate (two_mm)
      3) batched matmul realizing einsum('...ikd,...jkd->...ijd', left, right)
      4) fused LayerNorm * out_gate (Triton kernel)
      5) final output projection

    Args:
        data: Tuple of (input: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
        - input: Input tensor of shape [batch_size, seq_len, seq_len, dim]
        - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
        - weights: Dictionary containing model weights
        - config: Dictionary containing model configuration parameters
    """
    input_tensor, mask, weights, config = data

    # Weights are cast to fp16 once up front; the fused matmul runs in fp16.
    left_proj_weight = weights["left_proj.weight"].to(torch.float16)
    right_proj_weight = weights["right_proj.weight"].to(torch.float16)
    left_gate_weight = weights["left_gate.weight"].to(torch.float16)
    right_gate_weight = weights["right_gate.weight"].to(torch.float16)
    out_gate_weight = weights["out_gate.weight"].to(torch.float16)

    hidden_dim = config["hidden_dim"]
    # trimul = TriMul(dim=config["dim"], hidden_dim=config["hidden_dim"]).to(input_tensor.device)

    x = input_tensor

    batch_size, seq_len, _, dim = x.shape

    x = triton_layernorm_first(x, weights['norm.weight'], weights['norm.bias'])
    # x = torch.nn.functional.layer_norm(x, (dim,), eps=1e-5, weight=weights['norm.weight'], bias=weights['norm.bias'])

    # left/right come back already masked and gated; out_gate already sigmoided.
    left, right, out_gate = two_mm(x, left_proj_weight, right_proj_weight, left_gate_weight, right_gate_weight, out_gate_weight, mask)
    # left = torch.nn.functional.linear(x, weights['left_proj.weight'].to(torch.float16))
    # right = torch.nn.functional.linear(x, weights['right_proj.weight'].to(torch.float16))

    # left = left * mask.unsqueeze(-1)
    # right = right * mask.unsqueeze(-1)

    '''
    left = left.to(torch.float32)
    right = right.to(torch.float32)
    x = x.to(torch.float32)

    left_gate = left_gate.sigmoid()
    right_gate = right_gate.sigmoid()
    out_gate = out_gate.sigmoid()
    '''

    # Elementwise multiplication now handled in kernel
    # left = left * left_gate
    # right = right * right_gate

    # out = einsum('... i k d, ... j k d -> ... i j d', left, right)
    # Implemented as bmm over B*hidden_dim channel-slices of size [s, s];
    # relies on the channels-first storage of left/right created in two_mm.
    out = torch.bmm(left.permute(0, 3, 1, 2).view(-1, left.shape[1], left.shape[2]), right.permute(0, 3, 2, 1).view(-1, right.shape[2], right.shape[1]))
    out = out.view(batch_size, hidden_dim, seq_len, seq_len).permute(0, 2, 3, 1)

    # out = torch.compile(second_layernorm_mul, dynamic=False)(out, hidden_dim, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate)
    out = triton_layernorm_eltwise(out, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate)
    # out = torch.nn.functional.layer_norm(out, (hidden_dim,), eps=1e-5, weight=weights['to_out_norm.weight'].to(out.dtype), bias=weights['to_out_norm.bias'].to(out.dtype))
    # out = out * out_gate
    return torch.nn.functional.linear(out, weights['to_out.weight'])
954
+
955
+ '''
956
+ # Fill in the given weights of the model
957
+ trimul.norm.weight = nn.Parameter(weights['norm.weight'])
958
+ trimul.norm.bias = nn.Parameter(weights['norm.bias'])
959
+ trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight'])
960
+ trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight'])
961
+ trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight'])
962
+ trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight'])
963
+ trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight'])
964
+ trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight'])
965
+ trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias'])
966
+ trimul.to_out.weight = nn.Parameter(weights['to_out.weight'])
967
+
968
+ output = trimul(input_tensor, mask)
969
+
970
+ return output
971
+ '''
build/torch-rocm/trimul_gpumode/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-rocm/trimul_mi300.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ torch.backends.cuda.matmul.allow_tf32 = True
7
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
+
9
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),

        # Configurations with larger block sizes for better data reuse
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),

        # Configurations with deeper K dimension
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),

        # More extreme configurations to test the limits
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=2),

        # Configurations with fewer warps
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=2),

        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (from decorator and kwargs)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    """Fused LayerNorm + projection matmuls for TriMul (MI300 path).

    Computes, per row of the flattened [M, K] input:
      - LayerNorm stats (mean/rstd) over K,
      - one [M, 4H] matmul against W_4way whose output columns interleave
        (left_proj, left_gate, right_proj, right_gate) per hidden unit —
        inferred from the role_idx de-interleave below; confirm the weight
        packing on the host side,
      - for the first H columns, a second matmul against W_og producing the
        sigmoided output gate,
    then stores masked/gated left and right projections into channels-first
    layouts (right in transposed (s2, s1) order, per the stride names).
    NOTE(review): this assumes BLOCK_SIZE_N == 4 * H_CHUNK_SIZE so the
    reshape to (BLOCK_SIZE_M, H_CHUNK_SIZE, 4) is exact — verify launcher.
    """
    # --- PID Mapping: Based on the LARGER 4*H problem ---
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE) ---
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # Two-pass mean/variance in fp32, streamed over K in BLOCK_SIZE_K chunks.
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        # Apply the affine LayerNorm on the fly; fp16 tiles feed the dot.
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        # Some threads should calculate out_gate (only tiles covering the
        # first H output columns do this extra dot).
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    if pid_n * BLOCK_SIZE_N < H:
        # Store the already-sigmoided output gate as a flat [M, H] tensor.
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # De-interleave the 4H accumulator into the four per-hidden-unit roles.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose flattened row index into (batch, s1, s2) for the 4D stores.
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)  # s2 offset uses s2 stride, s1 offset uses s1 stride
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
163
+
164
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
    ],
    key=['s1', 's2', 'H'],
)
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Batched [s1, s2] x [s2, s1] matmul, one (batch, hidden) slice per
    program on axis 1, contracting over s2.

    Left/Right are addressed purely through their strides, so Right can be
    stored transposed (its s2 stride is named first) without a copy. The
    result is written to a standard (bs, H, s1, s1) layout for coalescing.
    """
    # Grid and program IDs (grouped ordering over the s1 x s1 output tiles
    # for L2 reuse)
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Axis 1 enumerates all (batch, hidden-channel) pairs.
    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # K-loop over the shared s2 dimension.
    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)

        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        accumulator += tl.dot(a, b)

    # --- Coalesced Write ---
    # Write to a standard (bs, H, s1, s1) layout
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2

    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
236
+
237
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
    ],
    key=['H', 'D'],
)
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,  # M_gate = bs*s1*s2
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Fused final TriMul stage.

    For each flattened output row m: LayerNorm over the H channels of the
    (bs, H, s1, s1) pairwise product, scale/shift by NormW/NormB, multiply
    elementwise by the precomputed out-gate (Gate_ptr, already sigmoided by
    the first kernel), then project H -> D via ProjW into the
    (bs, s1, s1, D) output.

    NOTE(review): row indices are decomposed by s1*s1 but the launcher
    passes M = bs*s1*s2 — this presumes s1 == s2; confirm against callers.
    """
    # --- Grid and PID Setup for Matmul ---
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    # Grouped ("super-grouped") tile ordering to improve L2 reuse.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose M back to (b, r, c) for reordering lookups
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    # Pass over H to collect LayerNorm statistics.
    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)

        # Accumulate sum and sum of squares in one pass
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    # Finalize statistics
    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # --- Pass 3: Fused Gating and Matmul ---
    # Re-load each H chunk, normalize, gate, and feed it straight into the
    # H -> D projection so the normalized activations never hit global memory.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]

        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]

        proj_ptrs = ProjW_ptr + \
            offs_n[None, :] * stride_proj_d + \
            offs_k[:, None] * stride_proj_h

        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate

        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # --- Store Final Output ---
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d

    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
340
+
341
def compiledtrimul_fused_interleaved(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,  # Use the new weight matrices
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Run the three-kernel Triton TriMul pipeline.

    Kernel 1 (fused_ln_dual_matmul_kernel): LayerNorm over the last dim of
    ``x`` fused with the interleaved 4-way projection (left/right proj+gate,
    from ``W_4way``) and the out-gate projection (``W_og``); writes the
    masked, gated left/right operands plus the sigmoided out-gate.
    Kernel 2 (bmm_coalesced_kernel): batched pairwise product per
    (batch, hidden) pair, reducing over s2.
    Kernel 3 (fused_final_kernel): LayerNorm over H, out-gate multiply and
    the final H -> d projection.

    NOTE(review): the bmm output is (bs, H, s1, s1) while kernel 3 is
    launched with M = bs*s1*s2 rows — this presumes s1 == s2; confirm.
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # fp16 intermediates so the subsequent matmuls run on tensor cores.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # The grid is launched for the larger 4*H problem
    N_4way = 4 * H
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
    fused_ln_dual_matmul_kernel[grid](
        # Pointers (9)
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        # Metadata (5) - M, H, K, s1, s2
        M, H, K, s1, s2,
        # Strides (16)
        x_flat.stride(0), x_flat.stride(1),
        W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1),
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        # Constexpr (1)
        LN_EPS=1e-5
    )

    bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)

    # One program per output tile on axis 0; axis 1 enumerates bs*H problems.
    grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
    bmm_coalesced_kernel[grid_bmm](
        left_final, right_final_t, bmm_out_tmp,
        bs, s1, s2, H,
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
    )

    # --- Kernel 3: Fully Fused Final Stage ---
    final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)

    grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
    fused_final_kernel[grid_final](
        # Pointers
        bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
        # Metadata
        M, H, d, s1,
        # Strides
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        to_out_weight.stride(0), to_out_weight.stride(1),  # Use strides of the corrected tensor
        final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
        # Constants
        LN_EPS=1e-5,
    )

    return final_out
412
+
413
def pack_w_4way_efficient(weights):
    """Pack the left/right projection and gate weights into one interleaved
    [K, 4*H] fp16 matrix (column order per hidden channel h: L, LG, R, RG)."""
    role_keys = ('left_proj.weight', 'left_gate.weight',
                 'right_proj.weight', 'right_gate.weight')
    stacked = torch.stack([weights[key] for key in role_keys], dim=0)
    h_dim, in_dim = stacked.shape[1], stacked.shape[2]
    # (4, H, K) -> (H, 4, K): the four roles become adjacent per channel,
    # matching the de-interleave step in the fused kernel.
    interleaved = stacked.permute(1, 0, 2).contiguous().view(4 * h_dim, in_dim)
    return interleaved.t().to(torch.float16)
423
+
424
def get_w_og(weights):
    """Return the out-gate weight transposed to [K, H] in fp16."""
    return weights['out_gate.weight'].to(torch.float16).t()
428
+
429
+ def compiledtrimul(
430
+ x: torch.Tensor,
431
+ mask: torch.Tensor,
432
+ norm_weight: torch.Tensor,
433
+ norm_bias: torch.Tensor,
434
+ w_concat: torch.Tensor,
435
+ to_out_norm_weight: torch.Tensor,
436
+ to_out_norm_bias: torch.Tensor,
437
+ to_out_weight: torch.Tensor,
438
+ h: int
439
+ ) -> torch.Tensor:
440
+ """
441
+ A barebones, compiled PyTorch function for the TriMul logic.
442
+ """
443
+ bs, s1, s2, d = x.shape
444
+
445
+ # Initial LayerNorm
446
+ x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
447
+ # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
448
+ all_projections = torch.mm(x_norm, w_concat)
449
+
450
+ # Split back into individual projections
451
+ left, right, lg, rg, og = all_projections.chunk(5, dim=1)
452
+
453
+ # Apply mask and gates
454
+ mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
455
+ left = left * mask_expanded * torch.sigmoid(lg)
456
+ right = right * mask_expanded * torch.sigmoid(rg)
457
+ out_gate = torch.sigmoid(og)
458
+
459
+ # Reshape for einsum
460
+ left = left.view(bs, s1, s2, h).permute(0,3,1,2)
461
+ right = right.view(bs, s1, s2, h).permute(0,3,1,2)
462
+ out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
463
+ out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
464
+
465
+ # Apply layer norm and final gating
466
+ normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
467
+ gated = normed * out_gate
468
+
469
+ # Final projection
470
+ final_out_flat = gated @ to_out_weight.t()
471
+ final_out = final_out_flat.view(bs, s1, s2, d)
472
+
473
+ return final_out
474
+
475
def small_kernel_pt_path(data):
    """Prepare fused weights and dispatch to the pure-PyTorch TriMul path."""
    input_tensor, mask, weights, config = data
    # Fuse the five projection matrices into one [d, 5h] operand so the
    # forward pass needs a single matmul.
    projection_order = (
        'left_proj.weight',
        'right_proj.weight',
        'left_gate.weight',
        'right_gate.weight',
        'out_gate.weight',
    )
    w_concat = (
        torch.cat([weights[name] for name in projection_order], dim=0)
        .t()
        .contiguous()
        .to(torch.float16)
    )
    # Call the compiled function with prepared weights
    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
497
+
498
def kernel_mi300(data):
    """TriMul entry point: PyTorch path for short sequences, the fused
    three-kernel Triton pipeline otherwise."""
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    # For short sequences the Triton launch overhead is not worth it.
    if s1 < 100:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]
    rows = bs * s1 * s2

    # Broadcast the (bs, s1, s2) mask across the hidden dim for the kernel.
    # TODO: move this expansion into the kernel itself.
    mask_mh = (
        mask.unsqueeze(-1)
        .expand(-1, -1, -1, hidden)
        .reshape(rows, hidden)
        .to(torch.float16)
    )

    return compiledtrimul_fused_interleaved(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_mh,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=pack_w_4way_efficient(weights),
        W_og=get_w_og(weights),
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-rocm/triton_a100.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
# Set PyTorch flags for performance:
# allow TF32 tensor-core math for fp32 matmuls and reduced-precision
# accumulation for fp16 matmuls (trades a little accuracy for speed).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
9
+
10
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (now passed as arguments from the host)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    """First TriMul stage, fully fused: LayerNorm over K, a [M, K] @ [K, 4H]
    matmul against the interleaved 4-way weights (L, LG, R, RG per hidden
    channel), the [M, K] @ [K, H] out-gate matmul, sigmoid gating and
    masking — writing the left operand as (bs, H, s1, s2), the right
    operand pre-transposed as (bs, H, s2, s1), and the sigmoided out-gate
    as (M, H).
    """
    # --- PID Mapping: Based on the LARGER 4*H problem ---
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE) ---
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # NOTE(review): these two statistics passes index columns WITHOUT
    # multiplying by stride_x_k, i.e. they assume x is contiguous in K
    # (stride_x_k == 1) — holds for the x_flat the launcher passes; confirm
    # for any new caller.
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        # Normalize the x tile once and reuse it for BOTH matmuls below.
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        # Only the first H columns of the 4H grid also compute the out-gate
        # matmul, so the normalized tile is read from registers, not reloaded.
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    if pid_n * BLOCK_SIZE_N < H:
        # Store the out-gate already sigmoided; the final kernel consumes it as-is.
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # The 4H columns are interleaved per hidden channel (L, LG, R, RG);
    # reshape and select each role from the accumulator.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Sigmoid-gate and mask the projections in registers.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose the flat row index into (batch, s1, s2) to scatter into the
    # (bs, H, s1, s2) left / (bs, H, s2, s1) transposed-right layouts.
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
+
136
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Batched pairwise product: for each (batch, hidden) pair, compute one
    (s1, s1) output tile of out[m, n] = sum_k left[m, k] * right_t[k, n],
    reducing over the shared s2 dimension. Grid axis 0 tiles the output;
    grid axis 1 enumerates the bs*H independent problems.
    """
    # Grouped tile ordering on axis 0 to improve L2 reuse across programs.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Axis 1 packs (batch, hidden); unpack it.
    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Reduce over s2 in BLOCK_SIZE_K chunks; masks handle the ragged tail.
    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        accumulator += tl.dot(a, b)

    # Coalesced write to the standard (bs, H, s1, s1) output layout.
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
187
+
188
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Fused final TriMul stage: LayerNorm over the H channels of the
    (bs, H, s1, s1) pairwise product, out-gate multiply (Gate_ptr is
    already sigmoided), then the H -> D projection into (bs, s1, s1, D).

    NOTE(review): row indices are decomposed by s1*s1 while the launcher
    passes M = bs*s1*s2 — presumes s1 == s2; confirm against callers.
    """
    # Grouped tile ordering over the (M, D) output for L2 reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose the flat row index into (batch, row, col) of the s1 x s1 grid.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    # Single pass over H: accumulate sum and sum-of-squares for LayerNorm.
    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # Second pass: normalize, gate, and project H -> D without materializing
    # the normalized activations in global memory.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]
        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
        proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate
        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # Scatter the tile into the (bs, s1, s1, D) output.
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
261
+
262
def compiledtrimul_fused_interleaved_final(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Run the three-kernel Triton TriMul pipeline with hardcoded A100
    launch configurations (no autotuning overhead).

    Kernel 1: fused LayerNorm + interleaved 4-way projection + out-gate.
    Kernel 2: batched (s1, s2) x (s2, s1) pairwise product per (batch, H).
    Kernel 3: fused output LayerNorm, out-gate multiply and H -> d projection.

    NOTE(review): kernel 3 is launched with M = bs*s1*s2 rows over a
    (bs, H, s1, s1) intermediate — presumes s1 == s2; confirm.
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # fp16 intermediates so the subsequent matmuls run on tensor cores.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # --- Kernel 1: Fused LN + Dual Matmul ---
    N_4way = 4 * H
    # Hardcoded A100 best config: M128-N128-K32-GM8-HC32-W8-S2
    config_k1 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
    grid_k1 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)

    fused_ln_dual_matmul_kernel[grid_k1](
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        M, H, K, s1, s2,
        x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
        left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
        right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        LN_EPS=1e-5, **config_k1, num_warps=8, num_stages=2
    )

    # --- Kernel 2: Batched Matrix Multiplication ---
    bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
    # Hardcoded A100 best config: M128-N64-K32-GM8-W4-S3
    config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    grid_k2 = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)

    bmm_coalesced_kernel[grid_k2](
        left_final, right_final_t, bmm_out_tmp,
        bs, s1, s2, H,
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        **config_k2, num_warps=4, num_stages=3
    )

    # --- Kernel 3: Fully Fused Final Stage ---
    final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
    # Hardcoded A100 best config: M32-N128-K32-GM8-W4-S3
    config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    grid_k3 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)

    fused_final_kernel[grid_k3](
        bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
        M, H, d, s1,
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
        final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
        LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
    )
    return final_out
330
+
331
def pack_w_4way_efficient(weights):
    """Interleave the L, LG, R, RG weight matrices per hidden channel into
    one tight [K, 4*H] fp16 operand for the fused dual-matmul kernel."""
    wl = weights['left_proj.weight']
    wlg = weights['left_gate.weight']
    wr = weights['right_proj.weight']
    wrg = weights['right_gate.weight']
    h_dim, k_dim = wl.shape
    # Stacking on dim=1 yields (H, 4, K): for each hidden channel the four
    # roles sit side by side, matching the kernel's de-interleave step.
    interleaved = torch.stack((wl, wlg, wr, wrg), dim=1).reshape(4 * h_dim, k_dim)
    return interleaved.t().to(torch.float16)
337
+
338
def get_w_og(weights):
    """Return the out-gate weight matrix transposed to [K, H] in fp16."""
    out_gate_weight = weights['out_gate.weight']
    return out_gate_weight.to(torch.float16).t()
341
+
342
@torch.compile()
def compiledtrimul(
    x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
    w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor, h: int
) -> torch.Tensor:
    """Pure-PyTorch TriMul forward (torch.compile'd), used for small inputs.

    All five projections are fused into the single matmul against
    ``w_concat`` ([d, 5h]; column blocks: left, right, left-gate,
    right-gate, out-gate — the order built by small_kernel_pt_path).

    NOTE(review): the pairwise product yields bs*s1*s1 rows while the
    out-gate has bs*s1*s2 rows, so this path presumes s1 == s2 — confirm.
    """
    bs, s1, s2, d = x.shape
    # Input LayerNorm in fp32, then fp16 for the big matmuls.
    x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
    all_projections = torch.mm(x_norm, w_concat)
    left, right, lg, rg, og = all_projections.chunk(5, dim=1)
    mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
    left = left * mask_expanded * torch.sigmoid(lg)
    right = right * mask_expanded * torch.sigmoid(rg)
    out_gate = torch.sigmoid(og)
    # (bs, s1, s2, h) -> (bs, h, s1, s2) so the pairwise product is a bmm.
    left = left.view(bs, s1, s2, h).permute(0,3,1,2)
    right = right.view(bs, s1, s2, h).permute(0,3,1,2)
    out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
    out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
    # Output LayerNorm, out-gate, and the final h -> d projection.
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * out_gate
    final_out_flat = gated @ to_out_weight.t()
    return final_out_flat.view(bs, s1, s1, d)
364
+
365
def small_kernel_pt_path(data):
    """Build the fused [d, 5h] weight and run the pure-PyTorch TriMul path."""
    input_tensor, mask, weights, config = data
    # Concatenate the five projections so the forward pass needs one matmul;
    # order must match the chunk() in compiledtrimul.
    projection_order = (
        'left_proj.weight',
        'right_proj.weight',
        'left_gate.weight',
        'right_gate.weight',
        'out_gate.weight',
    )
    w_concat = (
        torch.cat([weights[name] for name in projection_order], dim=0)
        .t()
        .contiguous()
        .to(torch.float16)
    )
    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
380
+
381
def kernel_a100(data):
    """TriMul entry point with A100-tuned configs: PyTorch path for short
    sequences, the fused three-kernel Triton pipeline otherwise."""
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    # Adjusted threshold based on observed BMM configs: below this, launch
    # overhead outweighs the Triton pipeline's benefit.
    if s1 < 512:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]
    rows = bs * s1 * s2
    # Broadcast the (bs, s1, s2) mask across the hidden dim for the kernel.
    mask_mh = (
        mask.unsqueeze(-1)
        .expand(-1, -1, -1, hidden)
        .reshape(rows, hidden)
        .to(torch.float16)
    )

    return compiledtrimul_fused_interleaved_final(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_mh,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=pack_w_4way_efficient(weights),
        W_og=get_w_og(weights),
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-rocm/triton_b200.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
# Enable TF32 tensor-core matmuls and reduced-precision fp16 accumulation:
# trades a little numerical precision for substantially faster GEMMs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
+
9
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (now passed as arguments from the host)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    """Fused input LayerNorm + interleaved 4-way projection (+ out-gate).

    Per M-row tile this kernel:
      1. computes LayerNorm statistics (mean/rstd) over the K features,
      2. multiplies the normalized rows by W_4way ([K, 4*H] with
         left_proj/left_gate/right_proj/right_gate interleaved per hidden
         channel — see pack_w_4way_efficient),
      3. for tiles whose N offset falls inside the first H columns, also
         multiplies by W_og and stores sigmoid(out_gate) into OutOG,
      4. de-interleaves the 4-way result, applies sigmoid gating and the
         pairwise mask, and scatters left/right activations into the
         layouts the BMM stage reads (right is written pre-transposed).
    """
    # --- PID Mapping: Based on the LARGER 4*H problem ---
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    # Grouped launch order: improves L2 reuse of X tiles across programs.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE) ---
    # NOTE(review): the two statistics loops below address X without
    # multiplying by stride_x_k (unlike the matmul loop), i.e. they assume
    # the K dimension is contiguous — confirm callers always pass a
    # contiguous x_flat.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # First pass over K: row means.
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    # Second pass over K: row variances of the centered values.
    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        # Normalize the X tile on the fly, then cast to fp16 for tensor cores.
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        #Some threads should calclate out_gate
        # (only programs whose N tile starts inside the first H columns
        # reuse the normalized X tile for the out-gate GEMM).
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    # Store sigmoid(out_gate) for the final stage.
    if pid_n * BLOCK_SIZE_N < H:
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # Columns are interleaved (lp, lg, rp, rg) per hidden channel; reshape
    # to (M, H_chunk, 4) and select each role with a masked reduction.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Sigmoid gating + pairwise mask, as in the PyTorch reference path.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose the flat row index m into (batch, s1, s2) coordinates.
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    # Left goes out in its natural layout; right is scattered transposed
    # (s2 offset paired with the s2 stride of the transposed buffer).
    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
+
136
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Batched matmul Out[b, h] = Left[b, h] @ Right[b, h], contracting s2.

    Grid axis 0 tiles the (s1 x s1) output; axis 1 enumerates the bs * H
    (batch, hidden-channel) pairs. Right is addressed through its own
    (s2, s1) strides, so the transpose is folded into the pointer math
    rather than materialized.
    """
    # Grid and program IDs
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    # Grouped tile ordering for L2 locality.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Second grid axis selects the (batch, hidden-channel) pair.
    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    # Accumulate in fp32 even though operands are fp16.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)

        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        accumulator += tl.dot(a, b)

    # Write the output tile to the standard (bs, H, s1, s1) layout.
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2

    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
192
+
193
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Fused final stage: LayerNorm over H, out-gate multiply, project to D.

    Each flat row index m maps to a (b, r, c) cell of the BMM output; the
    kernel normalizes that cell's H-vector with single-pass moment
    statistics, multiplies by the precomputed sigmoid out-gate, and
    projects down to D output channels in one GEMM loop.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    # Grouped tile ordering, same scheme as the other kernels.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose the flat row index into (batch, row, col) of the s1 x s1 grid.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    # Single-pass LayerNorm statistics: accumulate E[x] and E[x^2].
    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    mean = sum_x / H
    # var = E[x^2] - E[x]^2 (numerically less stable than two-pass, but
    # matches the existing kernel's behavior).
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # Normalize, gate, and project in a single K loop.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]
        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
        proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate
        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # Scatter the D-channel result back to the (bs, s1, s1, D) layout.
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
266
+
267
def compiledtrimul_fused_interleaved_final(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Three-kernel Triton pipeline for large problem sizes.

    Kernel 1 fuses the input LayerNorm with the interleaved 4-way
    projection (plus out-gate), kernel 2 runs the per-(batch, channel)
    (s1 x s2) @ (s2 x s1) batched matmul, and kernel 3 fuses the output
    LayerNorm, gating, and final projection. Block configurations are
    hard-coded to the best autotuned settings for this architecture.
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # Intermediates; the right operand is produced already transposed so
    # the BMM kernel reads both operands efficiently.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # --- Kernel 1: Fused LN + Dual Matmul ---
    # The grid is launched for the larger 4*H problem
    N_4way = 4 * H
    # Hardcoded best config from logs: M64-N128-K64-GM8-HC32-W4-S2
    config_k1 = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)

    fused_ln_dual_matmul_kernel[grid](
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        M, H, K, s1, s2,
        x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
        left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
        right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        LN_EPS=1e-5, **config_k1, num_warps=4, num_stages=2
    )

    # --- Kernel 2: Batched Matrix Multiplication ---
    bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
    # Hardcoded best config from logs: M128-N128-K32-GM8-W8-S3
    config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)

    bmm_coalesced_kernel[grid_bmm](
        left_final, right_final_t, bmm_out_tmp,
        bs, s1, s2, H,
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        **config_k2, num_warps=8, num_stages=3
    )

    # --- Kernel 3: Fully Fused Final Stage ---
    final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
    # Hardcoded best config from logs: M32-N128-K32-GM8-W4-S3
    config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)

    fused_final_kernel[grid_final](
        bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
        M, H, d, s1,
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
        final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
        LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
    )
    return final_out
336
+
337
def pack_w_4way_efficient(weights):
    """Pack left/right projection and gate weights into one tight [K, 4*H]
    fp16 matrix, with the four roles interleaved per hidden channel."""
    wl = weights['left_proj.weight']
    wlg = weights['left_gate.weight']
    wr = weights['right_proj.weight']
    wrg = weights['right_gate.weight']
    hidden, in_dim = wl.shape
    # stack(dim=1) yields (H, 4, K) directly, so the reshape interleaves
    # the four roles fastest along the output-column axis.
    interleaved = torch.stack((wl, wlg, wr, wrg), dim=1).reshape(4 * hidden, in_dim)
    return interleaved.t().to(torch.float16)
343
+
344
def get_w_og(weights):
    """Return the out-gate weight transposed to [K, H] in fp16."""
    out_gate_w = weights['out_gate.weight']
    return out_gate_w.transpose(0, 1).to(torch.float16)
347
+
348
@torch.compile()
def compiledtrimul(
    x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
    w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor, h: int
) -> torch.Tensor:
    """Reference TriMul implementation in plain PyTorch ops (torch.compile'd).

    w_concat is the transposed concatenation of the five projection
    weights in the order left_proj / right_proj / left_gate / right_gate /
    out_gate, so a single GEMM computes all projections; chunk(5, dim=1)
    splits them back in the same order.
    """
    bs, s1, s2, d = x.shape
    x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
    all_projections = torch.mm(x_norm, w_concat)
    left, right, lg, rg, og = all_projections.chunk(5, dim=1)
    # mask arrives as (bs, s1, s2, 1); broadcast it across hidden channels.
    mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
    left = left * mask_expanded * torch.sigmoid(lg)
    right = right * mask_expanded * torch.sigmoid(rg)
    out_gate = torch.sigmoid(og)
    # Move hidden forward so the pair interaction is a batched matmul:
    # (bs, h, s1, s2) @ (bs, h, s2, s1) -> (bs, h, s1, s1).
    left = left.view(bs, s1, s2, h).permute(0,3,1,2)
    right = right.view(bs, s1, s2, h).permute(0,3,1,2)
    out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
    out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * out_gate
    final_out_flat = gated @ to_out_weight.t()
    return final_out_flat.view(bs, s1, s1, d)
370
+
371
def small_kernel_pt_path(data):
    """Fallback path for small problem sizes.

    Builds the concatenated fp16 projection matrix once and hands off to
    the torch.compile'd reference implementation.
    """
    input_tensor, mask, weights, config = data
    # Key order must match the chunk(5, dim=1) split in compiledtrimul.
    projection_keys = (
        'left_proj.weight', 'right_proj.weight', 'left_gate.weight',
        'right_gate.weight', 'out_gate.weight',
    )
    packed = torch.cat([weights[name] for name in projection_keys], dim=0)
    w_concat = packed.t().contiguous().to(torch.float16)
    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
386
+
387
def kernel_b200(data):
    """B200 entry point.

    Sequence lengths below the tuned threshold use the compiled PyTorch
    path; larger ones use the fused three-kernel Triton pipeline.
    """
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    if s1 < 800:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]
    rows = bs * s1 * s2
    # Pre-broadcast the pairwise mask to one fp16 row per (b, i, j) position.
    mask_rows = (
        mask.unsqueeze(-1)
        .expand(-1, -1, -1, hidden)
        .reshape(rows, hidden)
        .to(torch.float16)
    )

    return compiledtrimul_fused_interleaved_final(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_rows,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=pack_w_4way_efficient(weights),
        W_og=get_w_og(weights),
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-rocm/triton_h100.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
# Enable TF32 tensor-core matmuls and reduced-precision fp16 accumulation:
# trades a little numerical precision for substantially faster GEMMs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
+
9
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),

        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),

        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),

        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=5),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=5),

        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (from decorator and kwargs)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    """Autotuned fused input LayerNorm + interleaved 4-way projection.

    Identical algorithm to the hard-coded variant: per M-row tile,
    compute LayerNorm statistics, multiply the normalized rows by the
    interleaved [K, 4*H] weight (plus the out-gate weight for tiles in
    the first H columns), then de-interleave, gate, mask and scatter the
    left/right activations (right is written pre-transposed).
    """
    # --- PID Mapping: Based on the LARGER 4*H problem ---
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    # Grouped launch order for L2 reuse of X tiles.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE) ---
    # NOTE(review): the statistics loops below address X without
    # multiplying by stride_x_k (unlike the matmul loop) — they assume a
    # contiguous K dimension; confirm callers always pass contiguous x_flat.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # First pass over K: row means.
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    # Second pass over K: row variances of centered values.
    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        # Normalize the X tile on the fly, then cast to fp16 for tensor cores.
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        #Some threads should calclate out_gate
        # (programs whose N tile starts inside the first H columns reuse
        # the normalized X tile for the out-gate GEMM).
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    # Store sigmoid(out_gate) for the final stage.
    if pid_n * BLOCK_SIZE_N < H:
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # Columns interleave (lp, lg, rp, rg) per hidden channel; reshape to
    # (M, H_chunk, 4) and select each role with a masked reduction.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Sigmoid gating + pairwise mask, as in the PyTorch reference path.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose the flat row index m into (batch, s1, s2) coordinates.
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)  # s2 offset uses s2 stride, s1 offset uses s1 stride
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
155
+
156
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
    ],
    key=['s1', 's2', 'H'],
)
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Autotuned batched matmul Out[b, h] = Left[b, h] @ Right[b, h].

    Grid axis 0 tiles the (s1 x s1) output; axis 1 enumerates the bs * H
    (batch, hidden-channel) pairs. Right is addressed through its own
    (s2, s1) strides so the transpose is folded into the pointer math.
    """
    # Grid and program IDs
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    # Grouped tile ordering for L2 locality.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Second grid axis selects the (batch, hidden-channel) pair.
    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    # Accumulate in fp32 even though operands are fp16.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)

        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        accumulator += tl.dot(a, b)

    # --- Coalesced Write ---
    # Write to a standard (bs, H, s1, s1) layout
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2

    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
225
+
226
@torch.compile
def torch_pt2(left_final, right_final_t, bs, s1, s2, d, h, to_out_norm_weight, to_out_norm_bias, og_mh, to_out_weight):
    """Second half of the TriMul pipeline, fused by torch.compile.

    Batch-multiplies the per-head left/right activations, layer-norms the
    result over the hidden dim, gates it with the precomputed sigmoid
    out-gate ``og_mh``, and projects back to the model dimension ``d``.

    NOTE(review): relies on ``F`` (presumably torch.nn.functional) being
    imported earlier in this module — confirm against the file header.
    """
    # (bs, h, s1, s2) @ (bs, h, s2, s1) -> (bs, h, s1, s1)
    bmm_out = torch.matmul(left_final, right_final_t)
    # Move the head dim innermost and flatten to rows of size h.
    out_einsum_flat = bmm_out.permute(0, 2, 3, 1).reshape(bs * s1 * s1, h)
    # Apply layer norm and final gating
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * og_mh

    # Final projection
    final_out_flat = gated @ to_out_weight.t()
    final_out = final_out_flat.view(bs, s1, s2, d)
    return final_out
238
+
239
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
    ],
    key=['H', 'D'],
)
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,  # M_gate = bs*s1*s2
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Fused epilogue: LayerNorm over H, elementwise gate, then project H -> D.

    Each program computes a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of the output,
    where a "row" m is one (batch, r, c) position decoded from a flat index.
    Pass 1 accumulates layer-norm statistics over H; pass 2 re-reads the
    input, normalizes, applies the gate, and matmuls against ProjW.

    NOTE(review): this kernel is not referenced by the call paths visible in
    this chunk — verify it is still wired up before relying on it.
    """
    # --- Grid and PID Setup for Matmul (grouped/swizzled tile order) ---
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose the flat row index back to (b, r, c) for strided addressing.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    # --- Pass 1: accumulate sum and sum-of-squares for LayerNorm stats ---
    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)

        # Accumulate sum and sum of squares in one pass
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    # Finalize statistics (biased variance, as in F.layer_norm).
    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # --- Pass 2: fused normalize + gate + matmul against ProjW ---
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        # Re-read the input chunk (second pass over global memory).
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]

        # Affine LayerNorm parameters for this H-chunk.
        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]

        # Projection weight tile, addressed as (K, N) via its two strides.
        proj_ptrs = ProjW_ptr + \
            offs_n[None, :] * stride_proj_d + \
            offs_k[:, None] * stride_proj_h

        # Per-row, per-channel gate (precomputed sigmoid values).
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate

        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        # Cast the gated activations to the weight dtype so tl.dot can use
        # the fast mixed-precision path.
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # --- Store Final Output ---
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d

    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
339
+
340
def compiledtrimul_fused_interleaved(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,  # Use the new weight matrices
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Large-problem TriMul path: one fused Triton launch plus a compiled epilogue.

    Launches ``fused_ln_dual_matmul_kernel`` (defined elsewhere in this
    module) to produce the gated/masked left and right activations and the
    sigmoid out-gate in fp16, then hands off to :func:`torch_pt2` for the
    batched matmul, output LayerNorm, gating, and final projection.

    NOTE(review): the positional launch arguments below must match the
    kernel's signature exactly — verify against its definition when editing.
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # Kernel outputs, pre-laid-out as (bs, H, s1, s2) / transposed so the
    # following batched matmul reads them without extra copies.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # The grid is launched for the larger 4*H problem
    N_4way = 4 * H
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
    fused_ln_dual_matmul_kernel[grid](
        # Pointers (9)
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        # Metadata (5) - M, H, K, s1, s2
        M, H, K, s1, s2,
        # Strides (16)
        x_flat.stride(0), x_flat.stride(1),
        W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1),
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        # Constexpr (1)
        LN_EPS=1e-5
    )
    # Epilogue: bmm + LayerNorm + out-gate + final projection.
    return torch_pt2(
        left_final, right_final_t,
        bs=bs,
        s1=s1,
        s2=s2,
        d=d,
        h=h,
        to_out_norm_weight=to_out_norm_weight,
        to_out_norm_bias=to_out_norm_bias,
        og_mh=og_mh,
        to_out_weight=to_out_weight
    )
392
+
393
def pack_w_4way_efficient(weights):
    """Interleave the four projection/gate weights into one [K, 4*H] fp16 matrix.

    Rows are interleaved per hidden unit: for each h, the rows
    (left_proj[h], left_gate[h], right_proj[h], right_gate[h]) are adjacent
    before the final transpose, so a single matmul yields all four outputs.
    """
    mats = [
        weights['left_proj.weight'],
        weights['left_gate.weight'],
        weights['right_proj.weight'],
        weights['right_gate.weight'],
    ]
    hidden, in_dim = mats[0].shape
    packed = torch.stack(mats, dim=1)             # (H, 4, K)
    packed = packed.reshape(4 * hidden, in_dim)   # interleaved rows
    return packed.t().to(torch.float16)
403
+
404
def get_w_og(weights):
    """Return the out-gate weight transposed to [K, H], cast to fp16."""
    return weights['out_gate.weight'].t().to(torch.float16)
408
+
409
+
410
# Allow TF32 tensor-core math and reduced-precision fp16 accumulation in
# cuBLAS matmuls: trades a little accuracy for substantially faster GEMMs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
412
+
413
@torch.compile
def compiledtrimul(
    x: torch.Tensor,
    mask: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    w_concat: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int
) -> torch.Tensor:
    """
    A barebones, compiled PyTorch function for the TriMul logic.

    All five input projections (left/right proj, left/right/out gates) are
    fused into a single matmul against ``w_concat`` of shape [d, 5*h];
    columns are split back with chunk(5). Mask shape is [bs, s1, s2, 1].

    NOTE(review): relies on ``F`` (presumably torch.nn.functional) being
    imported earlier in this module — confirm against the file header.
    """
    bs, s1, s2, d = x.shape

    # Initial LayerNorm
    x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
    # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
    all_projections = torch.mm(x_norm, w_concat)

    # Split back into individual projections
    left, right, lg, rg, og = all_projections.chunk(5, dim=1)

    # Apply mask and gates
    mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
    left = left * mask_expanded * torch.sigmoid(lg)
    right = right * mask_expanded * torch.sigmoid(rg)
    out_gate = torch.sigmoid(og)

    # Reshape for einsum: move the hidden dim to axis 1 so the pairwise
    # interaction is a plain batched matmul over (s1, s2) x (s2, s1).
    left = left.view(bs, s1, s2, h).permute(0,3,1,2)
    right = right.view(bs, s1, s2, h).permute(0,3,1,2)
    out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
    out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)

    # Apply layer norm and final gating
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * out_gate

    # Final projection
    final_out_flat = gated @ to_out_weight.t()
    final_out = final_out_flat.view(bs, s1, s2, d)

    return final_out
459
+
460
def small_kernel_pt_path(data):
    """Small-problem path: run the fully fused torch.compile implementation.

    Concatenates the five projection weights into a single [d, 5*h] fp16
    matrix (column order: left_proj, right_proj, left_gate, right_gate,
    out_gate — this order must match compiledtrimul's chunk(5) split).
    """
    input_tensor, mask, weights, config = data

    projection_names = [
        'left_proj.weight',
        'right_proj.weight',
        'left_gate.weight',
        'right_gate.weight',
        'out_gate.weight',
    ]
    stacked = torch.cat([weights[name] for name in projection_names], dim=0)
    w_concat = stacked.t().contiguous().to(torch.float16)

    # Call the compiled function with prepared weights.
    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float32),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float32),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
482
+
483
def kernel_h100(data):
    """H100 entry point: fused Triton path for long sequences, compiled PyTorch otherwise."""
    input_tensor, mask, weights, config = data
    batch, seq1, seq2, model_dim = input_tensor.shape

    # Short sequences: the fully compiled PyTorch path is faster.
    if seq1 <= 512:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]

    packed_w = pack_w_4way_efficient(weights)
    out_gate_w = get_w_og(weights)

    rows = batch * seq1 * seq2
    # Pre-broadcast the mask to one fp16 value per (row, hidden) entry
    # (could be folded into the kernel instead).
    mask_rows = mask.unsqueeze(-1).expand(-1, -1, -1, hidden).reshape(rows, hidden).to(torch.float16)

    return compiledtrimul_fused_interleaved(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_rows,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=packed_w,
        W_og=out_gate_w,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-xpu/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from .triton_a100 import kernel_a100
2
+ from .triton_h100 import kernel_h100
3
+ from .triton_b200 import kernel_b200
4
+ from .trimul_mi300 import kernel_mi300
5
+ from .trimul_global import kernel_global
6
+
7
+ __all__ = ["kernel_a100", "kernel_h100", "kernel_b200", "kernel_mi300", "kernel_global"]
build/torch-xpu/_ops.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import torch

# Lazily-resolved namespace holding this extension's registered custom ops.
ops = torch.ops._trimul_gpumode_176b4e4


def add_op_namespace_prefix(op_name: str):
    """Qualify *op_name* with this extension's op namespace."""
    return "::".join(("_trimul_gpumode_176b4e4", op_name))
build/torch-xpu/metadata.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"python-depends":[]}
build/torch-xpu/task.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Type definitions for TriMul task.
3
+
4
+ Input: Tuple of (input_tensor, mask, weights, config)
5
+ - input_tensor: Input tensor of shape [batch_size, seq_len, seq_len, dim]
6
+ - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
7
+ - weights: Dictionary containing model weights
8
+ - config: Dictionary containing model configuration parameters
9
+
10
+ Output: Output tensor of shape [batch_size, seq_len, seq_len, dim]
11
+ """
12
+
13
+ import torch
14
+ from typing import Tuple, Dict, Any
15
+
16
+ # Input type: (input_tensor, mask, weights, config)
17
+ input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]
18
+
19
+ # Output type: output tensor
20
+ output_t = torch.Tensor
build/torch-xpu/trimul_global.py ADDED
@@ -0,0 +1,971 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from utils import make_match_reference, DisableCuDNNTF32
2
+ from .task import input_t, output_t
3
+
4
+ import torch
5
+ from torch import nn, einsum
6
+ import math
7
+ import os
8
+ import requests
9
+
10
+ import triton
11
+ import triton.language as tl
12
+
13
+ # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
14
+ # in PyTorch 1.12 and later.
15
+ torch.backends.cuda.matmul.allow_tf32 = True
16
+
17
+ # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
18
+ torch.backends.cudnn.allow_tf32 = True
19
+
20
+ # Set allocator for TMA descriptors (required for on-device TMA)
21
+ def alloc_fn(size: int, alignment: int, stream=None):
22
+ return torch.empty(size, device="cuda", dtype=torch.int8)
23
+
24
+ triton.set_allocator(alloc_fn)
25
+
26
+ # os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
27
+ # os.environ['MLIR_ENABLE_DIAGNOSTICS'] = 'warnings,remarks'
28
+
29
+ # Reference code in PyTorch
30
+ class TriMul(nn.Module):
31
+ # Based on https://github.com/lucidrains/triangle-multiplicative-module/blob/main/triangle_multiplicative_module/triangle_multiplicative_module.py
32
+ def __init__(
33
+ self,
34
+ dim: int,
35
+ hidden_dim: int,
36
+ ):
37
+ super().__init__()
38
+
39
+ self.norm = nn.LayerNorm(dim)
40
+
41
+ self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
42
+ self.right_proj = nn.Linear(dim, hidden_dim, bias=False)
43
+
44
+ self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
45
+ self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
46
+ self.out_gate = nn.Linear(dim, hidden_dim, bias=False)
47
+
48
+ self.to_out_norm = nn.LayerNorm(hidden_dim)
49
+ self.to_out = nn.Linear(hidden_dim, dim, bias=False)
50
+
51
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
52
+ """
53
+ x: [bs, seq_len, seq_len, dim]
54
+ mask: [bs, seq_len, seq_len]
55
+
56
+ Returns:
57
+ output: [bs, seq_len, seq_len, dim]
58
+ """
59
+ batch_size, seq_len, _, dim = x.shape
60
+
61
+ x = self.norm(x)
62
+
63
+ left = self.left_proj(x)
64
+ right = self.right_proj(x)
65
+
66
+ mask = mask.unsqueeze(-1)
67
+ left = left * mask
68
+ right = right * mask
69
+
70
+ left_gate = self.left_gate(x).sigmoid()
71
+ right_gate = self.right_gate(x).sigmoid()
72
+ out_gate = self.out_gate(x).sigmoid()
73
+
74
+ left = left * left_gate
75
+ right = right * right_gate
76
+
77
+ out = einsum('... i k d, ... j k d -> ... i j d', left, right)
78
+ # This einsum is the same as the following:
79
+ # out = torch.zeros(batch_size, seq_len, seq_len, dim, device=x.device)
80
+
81
+ # # Compute using nested loops
82
+ # for b in range(batch_size):
83
+ # for i in range(seq_len):
84
+ # for j in range(seq_len):
85
+ # # Compute each output element
86
+ # for k in range(seq_len):
87
+ # out[b, i, j] += left[b, i, k, :] * right[b, j, k, :]
88
+
89
+ out = self.to_out_norm(out)
90
+ out = out * out_gate
91
+ return self.to_out(out)
92
+
93
@triton.jit
def triton_sigmoid(x):
    """
    Compute sigmoid function: 1 / (1 + exp(-x))

    Stable for large positive x (exp(-x) underflows to 0, result -> 1);
    for very negative x, exp(-x) overflows to inf and the result is 0.
    """
    return 1.0 / (1.0 + tl.exp(-x))
99
+
100
def two_mm_kernel_configs_wrapper():
    """Select a per-GPU-architecture autotune config generator for two_mm_kernel.

    Branches on torch.cuda.get_device_capability() at import time and returns
    a zero-argument function producing a list of triton.Config candidates.
    """
    if torch.cuda.get_device_capability() == (12, 0):
        # Blackwell consumer (sm_120): small-tile sweep.
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [16, 32]:
                for BLOCK_N in [16, 32, 64]:
                    for BLOCK_K in [16, 32, 64]:
                        for num_stages in [2, 3]:
                            configs.append(triton.Config({
                                'BLOCK_M': BLOCK_M,
                                'BLOCK_N': BLOCK_N,
                                'BLOCK_K': BLOCK_K,
                                'GROUP_SIZE_M': 8
                            }, num_stages=num_stages, num_warps=8))
            return configs

    elif torch.cuda.get_device_capability()[0] == 9:
        # Hopper (H100): hand-measured (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
        # per (batch, seq_len, dim).
        # NOTE(review): this lookup helper is defined but never returned or
        # called in this module — it only documents the measurements.
        def get_optimal_two_mm_config_h100(B, seq_len, dim):
            configs = {
                (1, 128, 128): (128, 64, 128, 2, 8),
                (1, 128, 256): (128, 64, 128, 2, 8),
                (1, 128, 384): (128, 64, 64, 3, 8),
                (1, 128, 512): (128, 64, 64, 3, 8),
                (1, 128, 768): (128, 64, 64, 3, 8),
                (1, 128, 1024): (128, 64, 64, 3, 8),
                (1, 256, 128): (128, 64, 128, 2, 8),
                (1, 256, 256): (128, 64, 128, 2, 8),
                (1, 256, 384): (128, 64, 64, 3, 8),
                (1, 256, 512): (128, 64, 64, 3, 8),
                (1, 256, 768): (128, 64, 64, 3, 8),
                (1, 256, 1024): (128, 64, 64, 3, 8),
                (1, 512, 128): (128, 64, 128, 2, 8),
                (1, 512, 256): (128, 64, 128, 2, 8),
                (1, 512, 384): (128, 64, 128, 2, 8),
                (1, 512, 512): (128, 64, 128, 2, 8),
                (1, 512, 768): (128, 64, 64, 3, 8),
                (1, 512, 1024): (128, 64, 64, 3, 8),
                (1, 1024, 128): (128, 64, 128, 2, 8),
                (1, 1024, 256): (128, 64, 64, 2, 8),
                (1, 1024, 384): (128, 64, 128, 2, 8),
                (1, 1024, 512): (128, 64, 128, 2, 8),
                (1, 1024, 768): (128, 64, 128, 2, 8),
                (1, 1024, 1024): (128, 64, 128, 2, 8),
                (2, 128, 128): (128, 64, 128, 2, 8),
                (2, 128, 256): (128, 64, 128, 2, 8),
                (2, 128, 384): (128, 64, 64, 3, 8),
                (2, 128, 512): (128, 64, 64, 3, 8),
                (2, 128, 768): (128, 64, 64, 3, 8),
                (2, 128, 1024): (128, 64, 64, 3, 8),
                (2, 256, 128): (128, 64, 128, 2, 8),
                (2, 256, 256): (128, 64, 128, 2, 8),
                (2, 256, 384): (128, 64, 128, 2, 8),
                (2, 256, 512): (128, 64, 128, 2, 8),
                (2, 256, 768): (128, 64, 64, 3, 8),
                (2, 256, 1024): (128, 64, 64, 3, 8),
                (2, 512, 128): (128, 64, 128, 2, 8),
                (2, 512, 256): (128, 64, 128, 2, 8),
                (2, 512, 384): (128, 64, 128, 2, 8),
                (2, 512, 512): (128, 64, 128, 2, 8),
                (2, 512, 768): (128, 64, 128, 2, 8),
                (2, 512, 1024): (128, 64, 128, 2, 8),
                (2, 1024, 128): (128, 64, 128, 2, 8),
                (2, 1024, 256): (128, 64, 128, 2, 8),
                (2, 1024, 384): (128, 64, 128, 2, 8),
                (2, 1024, 512): (128, 64, 128, 2, 8),
                (2, 1024, 768): (128, 64, 128, 2, 8),
                (2, 1024, 1024): (128, 64, 128, 2, 8),
            }
            return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8))  # default fallback

        def two_mm_kernel_configs():
            # This function is kept for compatibility but will be overridden for H100
            return [
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
            ]

    # NOTE(review): the `and False` makes this Blackwell (sm_100) branch
    # deliberately unreachable; sm_100 currently falls through to `else`.
    elif torch.cuda.get_device_capability()[0] == 10 and False:
        def get_optimal_two_mm_config(B, seq_len, dim):
            configs = {
                (1, 128, 128): (64, 128, 64, 2, 8),
                (1, 128, 256): (128, 64, 128, 2, 8),
                (1, 128, 384): (128, 64, 128, 2, 8),
                (1, 128, 512): (128, 64, 128, 2, 8),
                (1, 128, 768): (128, 64, 64, 3, 8),
                (1, 128, 1024): (128, 64, 64, 3, 8),
                (1, 256, 128): (128, 64, 128, 2, 8),
                (1, 256, 256): (128, 64, 128, 2, 8),
                (1, 256, 384): (128, 64, 128, 2, 8),
                (1, 256, 512): (128, 64, 64, 3, 8),
                (1, 256, 768): (128, 64, 64, 3, 8),
                (1, 256, 1024): (128, 64, 64, 3, 8),
                (1, 512, 128): (128, 64, 128, 2, 8),
                (1, 512, 256): (128, 64, 128, 2, 8),
                (1, 512, 384): (128, 64, 128, 2, 8),
                (1, 512, 512): (128, 64, 128, 2, 8),
                (1, 512, 768): (128, 64, 64, 3, 8),
                (1, 512, 1024): (128, 64, 64, 3, 8),
                (1, 1024, 128): (128, 64, 128, 2, 8),
                (1, 1024, 256): (128, 64, 128, 2, 8),
                (1, 1024, 384): (128, 64, 128, 2, 8),
                (1, 1024, 512): (128, 64, 128, 2, 8),
                (1, 1024, 768): (128, 64, 64, 3, 8),
                (1, 1024, 1024): (128, 64, 64, 3, 8),
                (2, 128, 128): (128, 64, 128, 2, 8),
                (2, 128, 256): (128, 64, 128, 2, 8),
                (2, 128, 384): (128, 64, 128, 2, 8),
                (2, 128, 512): (128, 64, 64, 3, 8),
                (2, 128, 768): (128, 64, 64, 3, 8),
                (2, 128, 1024): (128, 64, 64, 3, 8),
                (2, 256, 128): (128, 64, 128, 2, 8),
                (2, 256, 256): (128, 64, 128, 2, 8),
                (2, 256, 384): (128, 64, 128, 2, 8),
                (2, 256, 512): (128, 64, 64, 3, 8),
                (2, 256, 768): (128, 64, 64, 3, 8),
                (2, 256, 1024): (128, 64, 64, 3, 8),
                (2, 512, 128): (128, 64, 128, 2, 8),
                (2, 512, 256): (128, 64, 128, 2, 8),
                (2, 512, 384): (128, 64, 128, 2, 8),
                (2, 512, 512): (128, 64, 128, 2, 8),
                (2, 512, 768): (128, 64, 64, 3, 8),
                (2, 512, 1024): (128, 64, 64, 3, 8),
                (2, 1024, 128): (128, 64, 128, 2, 8),
                (2, 1024, 256): (128, 64, 128, 2, 8),
                (2, 1024, 384): (128, 64, 128, 2, 8),
                (2, 1024, 512): (128, 64, 128, 2, 8),
                (2, 1024, 768): (128, 64, 64, 3, 8),
                (2, 1024, 1024): (128, 64, 64, 3, 8),
            }
            return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8))  # default fallback

        def two_mm_kernel_configs():
            # This function is kept for compatibility but will be overridden
            return [
                triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
                triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
            ]
    elif torch.cuda.get_device_capability()[0] == 8:
        # A100
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [64]:
                for BLOCK_N in [64, 128]:
                    for BLOCK_K in [16]:
                        for num_stages in [3, 4]:
                            for num_warps in [4, 8]:
                                configs.append(triton.Config({
                                    'BLOCK_M': BLOCK_M,
                                    'BLOCK_N': BLOCK_N,
                                    'BLOCK_K': BLOCK_K,
                                    'GROUP_SIZE_M': 8
                                }, num_stages=num_stages, num_warps=num_warps))
            return configs
    else:
        # Generic fallback sweep for any other architecture.
        def two_mm_kernel_configs():
            configs = []
            for BLOCK_M in [64, 128]:
                for BLOCK_N in [64, 128]:
                    for BLOCK_K in [64, 128]:
                        for num_stages in [2, 3]:
                            configs.append(triton.Config({
                                'BLOCK_M': BLOCK_M,
                                'BLOCK_N': BLOCK_N,
                                'BLOCK_K': BLOCK_K,
                                'GROUP_SIZE_M': 8
                            }, num_stages=num_stages, num_warps=8))
            return configs

    return two_mm_kernel_configs
271
+
272
+ def two_mm_kernel_wrapper():
273
+ if torch.cuda.get_device_capability()[0] == 8:
274
+ @triton.jit
275
+ def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
276
+ # Persistent kernel using standard tl.load operations
277
+ start_pid = tl.program_id(axis=0)
278
+ num_pid_m = tl.cdiv(M, BLOCK_M)
279
+ num_pid_n = tl.cdiv(N, BLOCK_N)
280
+ k_tiles = tl.cdiv(K, BLOCK_K)
281
+ num_tiles = num_pid_m * num_pid_n
282
+
283
+ # tile_id_c is used in the epilogue to break the dependency between
284
+ # the prologue and the epilogue
285
+ tile_id_c = start_pid - NUM_SMS
286
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
287
+
288
+ # Persistent loop over tiles
289
+ for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
290
+ # Calculate PID for this tile using improved swizzling
291
+ group_id = tile_id // num_pid_in_group
292
+ first_pid_m = group_id * GROUP_SIZE_M
293
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
294
+ pid_m = first_pid_m + (tile_id % group_size_m)
295
+ pid_n = (tile_id % num_pid_in_group) // group_size_m
296
+
297
+ # Calculate block offsets
298
+ offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
299
+ offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
300
+ offs_k = tl.arange(0, BLOCK_K)
301
+
302
+ # Initialize accumulators for all outputs
303
+ accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
304
+ accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
305
+ accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
306
+ accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
307
+ accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
308
+
309
+ # Main computation loop over K dimension
310
+ for ki in range(k_tiles):
311
+ k_start = ki * BLOCK_K
312
+ k_offsets = k_start + offs_k
313
+
314
+ # Create pointers for A matrix (2D flattened view)
315
+ a_ptrs = a_ptr + offs_am[:, None] * stride_a2 + k_offsets[None, :] * stride_a3
316
+ a_mask = (offs_am[:, None] < M) & (k_offsets[None, :] < K)
317
+
318
+ # Create pointers for B matrices [N, K] layout
319
+ b1_ptrs = b1_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
320
+ b2_ptrs = b2_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
321
+ b3_ptrs = b3_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
322
+ b4_ptrs = b4_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
323
+ b5_ptrs = b5_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
324
+ b_mask = (offs_bn[:, None] < N) & (k_offsets[None, :] < K)
325
+
326
+ # Load blocks from A and all weight matrices using standard tl.load
327
+ a = tl.load(a_ptrs, mask=a_mask, other=0.0)
328
+ b1 = tl.load(b1_ptrs, mask=b_mask, other=0.0)
329
+ b2 = tl.load(b2_ptrs, mask=b_mask, other=0.0)
330
+ b3 = tl.load(b3_ptrs, mask=b_mask, other=0.0)
331
+ b4 = tl.load(b4_ptrs, mask=b_mask, other=0.0)
332
+ b5 = tl.load(b5_ptrs, mask=b_mask, other=0.0)
333
+
334
+ # Perform matrix multiplications using TF32
335
+ accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True) # A @ B1.T
336
+ accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True) # A @ B2.T
337
+ accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True) # A @ B3.T
338
+ accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True) # A @ B4.T
339
+ accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True) # A @ B5.T
340
+
341
+ # Store results using separate tile_id_c for epilogue
342
+ tile_id_c += NUM_SMS
343
+ group_id = tile_id_c // num_pid_in_group
344
+ first_pid_m = group_id * GROUP_SIZE_M
345
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
346
+ pid_m = first_pid_m + (tile_id_c % group_size_m)
347
+ pid_n = (tile_id_c % num_pid_in_group) // group_size_m
348
+
349
+ # Calculate output offsets and pointers
350
+ offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
351
+ offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
352
+
353
+ # Create masks for bounds checking
354
+ d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
355
+
356
+ # Calculate pointer addresses using 4D strides
357
+ stride_cm = stride_c2 # Stride to next element in flattened M dimension
358
+ stride_cn = stride_c3 # N is the innermost dimension
359
+
360
+ # For D tensor: use separate D strides
361
+ stride_dm = stride_d2 # Stride to next element in flattened M dimension
362
+ stride_dn = stride_d3 # N is the innermost dimension
363
+
364
+ off_c_batch = offs_cm // (seq_len * seq_len)
365
+ off_c_sl1 = (offs_cm // seq_len) % seq_len
366
+ off_c_sl2 = offs_cm % seq_len
367
+ off_c_dim = offs_cn
368
+
369
+ c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
370
+ c_mask = d_mask
371
+
372
+ c1_ptrs = c1_ptr + c_offsets
373
+ c2_ptrs = c2_ptr + c_offsets
374
+ d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]
375
+
376
+ mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))
377
+
378
+ # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
379
+ mask_2d = mask[:, None] # Convert to [BLOCK_M, 1] then broadcast
380
+ # Apply masking only to left_proj and right_proj results (C1, C2)
381
+ accumulator1 = tl.where(mask_2d, accumulator1, 0)
382
+ accumulator2 = tl.where(mask_2d, accumulator2, 0)
383
+
384
+ # Apply sigmoid to gate values
385
+ left_gate_sigmoid = triton_sigmoid(accumulator3)
386
+ right_gate_sigmoid = triton_sigmoid(accumulator4)
387
+ accumulator_d = triton_sigmoid(accumulator_d)
388
+
389
+ # Apply elementwise multiplication with gated values
390
+ # C1 = left * left_gate, C2 = right * right_gate
391
+ accumulator1 = accumulator1 * left_gate_sigmoid # left * left_gate
392
+ accumulator2 = accumulator2 * right_gate_sigmoid # right * right_gate
393
+
394
+ # Convert to appropriate output dtype and store with normal tl.store
395
+ c1 = accumulator1.to(c1_ptr.dtype.element_ty)
396
+ c2 = accumulator2.to(c2_ptr.dtype.element_ty)
397
+ d = accumulator_d.to(d_ptr.dtype.element_ty)
398
+
399
+ tl.store(c1_ptrs, c1, mask=c_mask)
400
+ tl.store(c2_ptrs, c2, mask=c_mask)
401
+ tl.store(d_ptrs, d, mask=d_mask)
402
+ else:
403
@triton.jit
def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
    """Persistent fused five-way matmul with gated epilogue.

    For each [BLOCK_M, K] row-block of A this kernel computes five GEMMs
    against the transposed [N, K] weights b1..b5 in a single K-loop, then
    fuses masking and sigmoid gating in the epilogue:

        c1 = where(mask, A @ b1.T, 0) * sigmoid(A @ b3.T)
        c2 = where(mask, A @ b2.T, 0) * sigmoid(A @ b4.T)
        d  = sigmoid(A @ b5.T)

    Each program walks the tile space with stride NUM_SMS (persistent-kernel
    style); A and the weights are read through on-device TMA tensor
    descriptors. c1/c2 are scattered through 4D (stride_c0..stride_c3)
    addressing after decomposing the flat row index into (batch, i, j);
    d is stored through plain 2D (stride_d2/stride_d3) addressing.
    NOTE(review): `mask` is loaded one value per output row and used directly
    as a tl.where predicate — presumably a 0/1 (boolean) mask; confirm against
    the host-side mask tensor.
    """
    # Persistent kernel using on-device TMA descriptors
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    k_tiles = tl.cdiv(K, BLOCK_K)
    num_tiles = num_pid_m * num_pid_n

    # Create on-device TMA descriptors (A is [M, K]; all weights are [N, K]
    # and share one stride pair, so every b-descriptor is built identically).
    a_desc = tl._experimental_make_tensor_descriptor(
        a_ptr,
        shape=[M, K],
        strides=[stride_a2, stride_a3],
        block_shape=[BLOCK_M, BLOCK_K],
    )
    b1_desc = tl._experimental_make_tensor_descriptor(
        b1_ptr,
        shape=[N, K],
        strides=[stride_bn, stride_bk],
        block_shape=[BLOCK_N, BLOCK_K],
    )
    b2_desc = tl._experimental_make_tensor_descriptor(
        b2_ptr,
        shape=[N, K],
        strides=[stride_bn, stride_bk],
        block_shape=[BLOCK_N, BLOCK_K],
    )
    b3_desc = tl._experimental_make_tensor_descriptor(
        b3_ptr,
        shape=[N, K],
        strides=[stride_bn, stride_bk],
        block_shape=[BLOCK_N, BLOCK_K],
    )
    b4_desc = tl._experimental_make_tensor_descriptor(
        b4_ptr,
        shape=[N, K],
        strides=[stride_bn, stride_bk],
        block_shape=[BLOCK_N, BLOCK_K],
    )
    b5_desc = tl._experimental_make_tensor_descriptor(
        b5_ptr,
        shape=[N, K],
        strides=[stride_bn, stride_bk],
        block_shape=[BLOCK_N, BLOCK_K],
    )

    # tile_id_c is used in the epilogue to break the dependency between
    # the prologue and the epilogue
    tile_id_c = start_pid - NUM_SMS
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    # Persistent loop over tiles
    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
        # Calculate PID for this tile using improved swizzling
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (tile_id % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m

        # Calculate block offsets
        offs_am = pid_m * BLOCK_M
        offs_bn = pid_n * BLOCK_N

        # Initialize accumulators for all outputs (fp32 accumulation)
        accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        # Main computation loop over K dimension
        for ki in range(k_tiles):
            offs_k = ki * BLOCK_K
            # Load blocks from A and all weight matrices using on-device TMA
            a = a_desc.load([offs_am, offs_k])
            b1 = b1_desc.load([offs_bn, offs_k])
            b2 = b2_desc.load([offs_bn, offs_k])
            b3 = b3_desc.load([offs_bn, offs_k])
            b4 = b4_desc.load([offs_bn, offs_k])
            b5 = b5_desc.load([offs_bn, offs_k])

            # Perform matrix multiplications using TF32
            accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True)  # A @ B1.T
            accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True)  # A @ B2.T
            accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True)  # A @ B3.T
            accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True)  # A @ B4.T
            accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True)  # A @ B5.T

        # Store results using separate tile_id_c for epilogue
        tile_id_c += NUM_SMS
        group_id = tile_id_c // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (tile_id_c % group_size_m)
        pid_n = (tile_id_c % num_pid_in_group) // group_size_m

        # Calculate output offsets and pointers
        offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

        # Create masks for bounds checking
        d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

        # Calculate pointer addresses using 4D strides
        # For C tensors: compute effective 2D strides from 4D strides
        # Output tensor is [B, I, J, N], flattened to [M, N] where M = B*I*J
        stride_cm = stride_c2  # Stride to next element in flattened M dimension
        stride_cn = stride_c3  # N is the innermost dimension

        # For D tensor: use separate D strides
        stride_dm = stride_d2  # Stride to next element in flattened M dimension
        stride_dn = stride_d3  # N is the innermost dimension

        # Decompose the flat row index (b * seq_len^2 + i * seq_len + j)
        # back into (batch, i, j) for 4D scatter of c1/c2.
        off_c_batch = offs_cm // (seq_len * seq_len)
        off_c_sl1 = (offs_cm // seq_len) % seq_len
        off_c_sl2 = offs_cm % seq_len
        off_c_dim = offs_cn

        # TODO update the mask_c so we don't IMA
        c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
        c_mask = d_mask

        c1_ptrs = c1_ptr + c_offsets
        c2_ptrs = c2_ptr + c_offsets
        d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]

        # One mask value per output row, i.e. per (batch, i, j) position.
        mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))

        # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
        mask_2d = mask[:, None]  # Convert to [BLOCK_M, 1] then broadcast
        # Apply masking only to left_proj and right_proj results (C1, C2)
        accumulator1 = tl.where(mask_2d, accumulator1, 0)
        accumulator2 = tl.where(mask_2d, accumulator2, 0)

        # Apply sigmoid to gate values
        left_gate_sigmoid = triton_sigmoid(accumulator3)
        right_gate_sigmoid = triton_sigmoid(accumulator4)
        accumulator_d = triton_sigmoid(accumulator_d)

        # Apply elementwise multiplication with gated values
        # C1 = left * left_gate, C2 = right * right_gate
        accumulator1 = accumulator1 * left_gate_sigmoid  # left * left_gate
        accumulator2 = accumulator2 * right_gate_sigmoid  # right * right_gate

        # Convert to appropriate output dtype and store with normal tl.store
        c1 = accumulator1.to(c1_ptr.dtype.element_ty)
        c2 = accumulator2.to(c2_ptr.dtype.element_ty)
        d = accumulator_d.to(d_ptr.dtype.element_ty)

        tl.store(c1_ptrs, c1, mask=c_mask)
        tl.store(c2_ptrs, c2, mask=c_mask)
        tl.store(d_ptrs, d, mask=d_mask)
560
+
561
+
562
+ if torch.cuda.get_device_capability()[0] not in [9, 10.2]:
563
+ two_mm_kernel = triton.autotune(
564
+ (two_mm_kernel_configs_wrapper())(), key=["M", "N", "K"]
565
+ )(two_mm_kernel)
566
+
567
+ return two_mm_kernel
568
+
569
+
570
def two_mm(A, left_proj, right_proj, left_gate, right_gate, out_gate, mask):
    """
    Launch the fused persistent projection/gating kernel.

    Computes, in one kernel launch:
        C1 = (A @ left_proj.T)  * sigmoid(A @ left_gate.T)   (masked)
        C2 = (A @ right_proj.T) * sigmoid(A @ right_gate.T)  (masked)
        D  = sigmoid(A @ out_gate.T)                         (unmasked)

    Args:
        A: [B, seq_len, seq_len, K] input tensor.
        left_proj, right_proj, left_gate, right_gate, out_gate: [N, K]
            weight matrices (applied transposed, nn.Linear style). All five
            must share one stride layout — the kernel receives a single
            stride pair for them.
        mask: mask tensor, read as one value per flattened (b, i, j) row.

    Returns:
        (C1, C2, D): fp16 tensors of shape [B, seq_len, seq_len, N].
        C1/C2 are channel-major views (allocated as [B, N, seq_len, seq_len]
        and permuted) so the downstream bmm can consume them without a copy;
        D is a plain row-major tensor.
    """
    # Check constraints.
    assert A.shape[-1] == left_proj.shape[1] == right_proj.shape[1], "Incompatible K dimensions"
    assert A.dtype == left_proj.dtype == right_proj.dtype, "Incompatible dtypes"
    assert left_proj.stride() == right_proj.stride() == left_gate.stride() == right_gate.stride() == out_gate.stride(), \
        "All weight matrices must have identical strides"

    original_shape = A.shape[:-1]  # leading dims: (B, seq_len, seq_len)
    K = A.shape[-1]
    N = left_proj.shape[0]
    B, seq_len, _, _ = A.shape

    # The kernel treats A as a flat [M, K] matrix with M = B * seq_len^2.
    A_2d = A.view(-1, K)
    M = A_2d.shape[0]

    # One persistent program per SM, capped by the tile count.
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

    # C1/C2 are allocated channel-major ([B, N, s, s]) and exposed as
    # [B, s, s, N] views; D is a contiguous [B, s, s, N] tensor.
    output_shape = original_shape + (N,)
    C1 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
    C2 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
    D = torch.empty(output_shape, device=A.device, dtype=torch.float16)

    A_strides = A.stride()
    C_strides = C1.stride()
    D_strides = D.stride()

    kernel = two_mm_kernel_wrapper()
    # Positional arguments shared by both launch paths below.
    common_args = (
        A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
        C1, C2, D, mask,
        M, N, K,
        *A_strides,                                # 4D strides for A
        left_proj.stride(1), left_proj.stride(0),  # B matrices [N, K] strides
        *C_strides,                                # 4D strides for C
        seq_len,
        *D_strides,                                # 4D strides for D
    )

    # NOTE: the former capability-10 and capability-9 branches were
    # byte-identical; they are merged into one hand-tuned launch path.
    if torch.cuda.get_device_capability()[0] in (9, 10):
        # H100 / B200: explicit hand-tuned configuration and grid size.
        BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
        grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))
        kernel[(grid_size,)](
            *common_args,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
            GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
            num_stages=num_stages, num_warps=num_warps,
        )
    else:
        # Other GPUs: the kernel object is autotuned — let it pick the blocks.
        grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)
        kernel[grid](*common_args, NUM_SMS=NUM_SMS)

    return C1, C2, D
678
+
679
+
680
def second_layernorm_mul(inp, hidden_dim, weight, bias, mul_operand):
    """LayerNorm `inp` over its last `hidden_dim` channels, then scale the
    normalized result elementwise by `mul_operand`."""
    # Cast the affine parameters to the input dtype so layer_norm does not
    # have to promote anything.
    w = weight.to(inp.dtype)
    b = bias.to(inp.dtype)
    normalized = torch.nn.functional.layer_norm(
        inp, (hidden_dim,), weight=w, bias=b, eps=1e-5
    )
    return normalized * mul_operand
684
+
685
+ '''
686
+ @triton.autotune(
687
+ [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=4, num_stages=3)],
688
+ key=["R", "C"]
689
+ )
690
+ '''
691
@triton.jit
def layernorm_kernel_first(
    X,          # input, read as a contiguous [R, C] matrix
    Y,          # output, same [R, C] layout
    Weight,     # [C] LayerNorm scale
    Bias,       # [C] LayerNorm shift
    R,          # number of rows (B * seq_len * seq_len on the host side)
    C,          # aka "dim"
    eps,        # variance epsilon
    ROW_BLOCK_SIZE: tl.constexpr,   # rows normalized per program
    BLOCK_SIZE: tl.constexpr,       # power-of-two cover of C (single pass)
):
    """Row-wise LayerNorm over a contiguous [R, C] matrix.

    Each program handles ROW_BLOCK_SIZE rows; BLOCK_SIZE >= C so the whole
    feature dimension is loaded in one shot and mean/variance are computed
    in a single pass. Loads are upcast to fp32 for the statistics.
    """
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, BLOCK_SIZE)

    mask_row = row < R
    mask_col = cols < C

    # Simple indexing for contiguous data
    x = tl.load(
        X + row[:, None] * C + cols[None, :],
        mask=mask_row[:, None] & mask_col[None, :],
        other=0.0
    ).to(tl.float32)

    weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
    bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)

    mean = tl.sum(x, axis=1) / C
    # Zero the padded lanes so they do not contribute to the variance.
    diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
    var = tl.sum(diff * diff, axis=1) / C
    rstd = 1 / tl.sqrt(var + eps)

    y_hat = (x - mean[:, None]) * rstd[:, None]
    y = y_hat * weight[None, :] + bias[None, :]

    # Store in the output tensor's element dtype (fp16 on the host side).
    tl.store(
        Y + row[:, None] * C + cols[None, :],
        y,
        mask=mask_row[:, None] & mask_col[None, :]
    )
732
+
733
+
734
def get_optimal_config_ln(dim):
    """Pick a (ROW_BLOCK_SIZE, num_warps) launch config for the first
    layernorm kernel.

    Tuned values exist only for compute capability 9 (Hopper); every other
    device — and any Hopper dim above 1024 — falls back to (16, 4).
    """
    if torch.cuda.get_device_capability()[0] == 9:
        if dim <= 256:
            return (16, 1)
        if dim <= 512:
            return (16, 2)
        if dim <= 1024:
            return (16, 4)
    # Default for non-Hopper GPUs and for dim > 1024 on Hopper.
    return (16, 4)
747
+
748
+
749
def triton_layernorm_first(x, weight, bias, eps=1e-5, num_warps=None, ROW_BLOCK_SIZE=None):
    """LayerNorm the last dimension of a [B, S, S, dim] tensor; returns fp16.

    If no launch configuration is supplied, one is chosen per device by
    get_optimal_config_ln().
    """
    batch, s1, s2, dim = x.shape
    assert s1 == s2

    n_rows = batch * s1 * s2
    n_cols = dim

    out = torch.empty_like(x, dtype=torch.float16)

    # Fill in any missing launch parameter from the tuned defaults.
    if not num_warps or not ROW_BLOCK_SIZE:
        ROW_BLOCK_SIZE, num_warps = get_optimal_config_ln(dim)

    # The kernel covers the whole feature dim with one power-of-two block.
    block = triton.next_power_of_2(n_cols)
    assert block <= 1024

    grid = lambda meta: (triton.cdiv(n_rows, meta["ROW_BLOCK_SIZE"]),)

    layernorm_kernel_first[grid](
        x, out, weight, bias,
        n_rows, n_cols, eps,
        ROW_BLOCK_SIZE=ROW_BLOCK_SIZE,
        BLOCK_SIZE=block,
        num_warps=num_warps,
        num_stages=3,
    )

    return out
777
+
778
+ '''
779
+ def triton_layernorm_first(x, weight, bias, eps=1e-5):
780
+ B, seq_len, seq_len2, dim = x.shape
781
+ assert(seq_len == seq_len2)
782
+
783
+ R = B * seq_len * seq_len
784
+ C = dim
785
+
786
+ out = torch.empty_like(x)
787
+
788
+ BLOCK_SIZE = triton.next_power_of_2(C)
789
+ assert(BLOCK_SIZE <= 1024)
790
+
791
+ def grid(meta):
792
+ return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
793
+
794
+ layernorm_kernel_first[grid](
795
+ x, out, weight, bias,
796
+ R, C, eps,
797
+ BLOCK_SIZE=BLOCK_SIZE
798
+ )
799
+
800
+ return out
801
+ '''
802
+
803
+
804
@triton.autotune(
    [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=1, num_stages=3)],
    key=[]
)
@triton.jit
def layernorm_kernel_eltwise(
    X,             # channel-major input: row offsets built from batch stride + flat (i, j) offset
    Y,             # contiguous [R, C] output
    Weight,        # [C] LayerNorm scale
    Bias,          # [C] LayerNorm shift
    OutGate,       # contiguous [R, C] gate, multiplied into the result
    seq_len,
    stride_batch,  # X batch stride
    stride_dim,    # X channel stride
    R,
    C,  # aka "dim"
    eps,
    ROW_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,   # power-of-two cover of C
):
    """Row-wise LayerNorm of a strided (channel-major) input, fused with an
    elementwise multiply by OutGate; the product is stored contiguously.

    Requires seq_len^2 to be divisible by ROW_BLOCK_SIZE (device-asserted)
    so one program never straddles a batch boundary.
    """
    row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
    cols = tl.arange(0, BLOCK_SIZE)

    # Calculate base pointer for this batch of rows
    tl.device_assert(seq_len*seq_len % ROW_BLOCK_SIZE == 0)
    # batch_offset = (row // (stride_seq1 // stride_dim)) * stride_batch
    batch = tl.program_id(0) * ROW_BLOCK_SIZE // (seq_len * seq_len)
    seqs_off = row % (seq_len * seq_len) # TODO is this going to prevent vectorization

    off_r = batch * stride_batch + seqs_off
    off_c = cols * stride_dim

    mask_row = row < R
    mask_col = cols < C

    # Gate is contiguous [R, C]; loaded in its storage dtype.
    out_gate = tl.load(
        OutGate + row[:, None] * C + cols[None, :],
        mask = mask_row[:, None] & mask_col[None, :],
    )

    # Strided gather from the channel-major input; upcast for statistics.
    x = tl.load(
        X + off_r[:, None] + off_c[None, :],
        mask=mask_row[:, None] & mask_col[None, :],
        other=0.0
    ).to(tl.float32)

    weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
    bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)

    mean = tl.sum(x, axis=1) / C
    # Zero padded lanes so they do not contribute to the variance.
    diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
    var = tl.sum(diff * diff, axis=1) / C
    rstd = 1 / tl.sqrt(var + eps)

    y_hat = (x - mean[:, None]) * rstd[:, None]
    y = y_hat * weight[None, :] + bias[None, :]

    # Fused epilogue: normalized value times gate, stored contiguously.
    tl.store(
        Y + row[:, None] * C + cols[None, :],
        y * out_gate,
        mask=mask_row[:, None] & mask_col[None, :]
    )
866
+
867
+
868
def triton_layernorm_eltwise(x, weight, bias, out_gate, eps=1e-5):
    """LayerNorm the channel-major tensor `x` over its last dimension, fused
    with an elementwise multiply by `out_gate`; returns a contiguous fp32
    tensor shaped like `out_gate`."""
    batch, s1, s2, dim = x.shape
    assert s1 == s2
    n_rows = batch * s1 * s2
    # The kernel assumes a channel-major input (last-dim stride == s1*s2)
    # and a contiguous gate tensor.
    assert x.stride(3) == s1 * s2
    assert out_gate.is_contiguous()
    n_cols = dim

    out = torch.empty_like(out_gate, dtype=torch.float32)

    block = triton.next_power_of_2(n_cols)
    assert block == 128

    grid = lambda meta: (triton.cdiv(n_rows, meta["ROW_BLOCK_SIZE"]),)

    layernorm_kernel_eltwise[grid](
        x, out, weight, bias, out_gate,
        s1,
        x.stride(0), x.stride(3),
        n_rows, n_cols, eps,
        BLOCK_SIZE=block,
    )

    return out
893
+
894
+
895
def kernel_global(data: input_t) -> output_t:
    """
    Fused TriMul forward pass: Triton layernorm, one fused five-way
    projection/gating kernel, a cuBLAS bmm for the triangular einsum, and a
    fused output layernorm + gate multiply.

    Args:
        data: Tuple of (input, mask, weights, config)
            - input: [batch_size, seq_len, seq_len, dim] input tensor
            - mask: [batch_size, seq_len, seq_len] mask tensor
            - weights: dict of model weights (nn.Linear-style [out, in])
            - config: dict with at least "hidden_dim"

    Returns:
        Output tensor of shape [batch_size, seq_len, seq_len, dim].
    """
    input_tensor, mask, weights, config = data

    # The fused matmul kernel runs in fp16 (fp32 accumulation).
    left_proj_weight = weights["left_proj.weight"].to(torch.float16)
    right_proj_weight = weights["right_proj.weight"].to(torch.float16)
    left_gate_weight = weights["left_gate.weight"].to(torch.float16)
    right_gate_weight = weights["right_gate.weight"].to(torch.float16)
    out_gate_weight = weights["out_gate.weight"].to(torch.float16)

    hidden_dim = config["hidden_dim"]

    x = input_tensor
    batch_size, seq_len, _, dim = x.shape

    # Input layernorm (fp16 output).
    x = triton_layernorm_first(x, weights['norm.weight'], weights['norm.bias'])

    # One fused kernel produces the masked+gated left/right projections and
    # the already-sigmoided output gate. left/right come back as channel-major
    # views so the bmm below needs no copy.
    left, right, out_gate = two_mm(
        x, left_proj_weight, right_proj_weight,
        left_gate_weight, right_gate_weight, out_gate_weight, mask,
    )

    # einsum('... i k d, ... j k d -> ... i j d', left, right) expressed as a
    # batched matmul over (batch * hidden) slices.
    out = torch.bmm(
        left.permute(0, 3, 1, 2).view(-1, left.shape[1], left.shape[2]),
        right.permute(0, 3, 2, 1).view(-1, right.shape[2], right.shape[1]),
    )
    out = out.view(batch_size, hidden_dim, seq_len, seq_len).permute(0, 2, 3, 1)

    # Output layernorm fused with the out_gate elementwise multiply.
    out = triton_layernorm_eltwise(
        out, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate
    )
    return torch.nn.functional.linear(out, weights['to_out.weight'])
build/torch-xpu/trimul_gpumode/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import sys
3
+
4
+ import importlib
5
+ from pathlib import Path
6
+ from types import ModuleType
7
+
8
+ def _import_from_path(file_path: Path) -> ModuleType:
9
+ # We cannot use the module name as-is, after adding it to `sys.modules`,
10
+ # it would also be used for other imports. So, we make a module name that
11
+ # depends on the path for it to be unique using the hex-encoded hash of
12
+ # the path.
13
+ path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
+ module_name = path_hash
15
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
16
+ if spec is None:
17
+ raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
+ module = importlib.util.module_from_spec(spec)
19
+ if module is None:
20
+ raise ImportError(f"Cannot load module {module_name} from spec")
21
+ sys.modules[module_name] = module
22
+ spec.loader.exec_module(module) # type: ignore
23
+ return module
24
+
25
+
26
+ globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
build/torch-xpu/trimul_mi300.py ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ torch.backends.cuda.matmul.allow_tf32 = True
7
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
+
9
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),

        # Configurations with larger block sizes for better data reuse
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),

        # Configurations with deeper K dimension
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),

        # More extreme configurations to test the limits
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=2),

        # Configurations with fewer warps
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=2),

        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (from decorator and kwargs)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    """Fused LayerNorm + dual-projection matmul with gated epilogue.

    Each program LayerNorms a BLOCK_SIZE_M row-block of X (mean/variance in
    two chunked passes over K), then multiplies the normalized rows with a
    column-block of the interleaved 4-way weight matrix W_4way and — only for
    programs whose column block falls inside the first H columns — with the
    out-gate weights W_og. The reshape to (BLOCK_SIZE_M, H_CHUNK_SIZE, 4)
    requires BLOCK_SIZE_N == 4 * H_CHUNK_SIZE (true for every config above).
    The epilogue applies sigmoid gating and the row mask and scatters:
      - left_out  into OutLeft  via (b, h, s1, s2) strides,
      - right_out into OutRight via transposed (b, h, s2, s1) strides,
      - sigmoid(out_gate) into OutOG via (m, h) strides.
    NOTE(review): the per-column role order (0=left_proj, 1=left_gate,
    2=right_proj, 3=right_gate) is inferred from the role_idx selection —
    confirm against the host-side weight packing.
    """
    # --- PID Mapping: Based on the LARGER 4*H problem ---
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE) ---
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # Pass 1: chunked mean over the K dimension (fp32 accumulator).
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    # Pass 2: chunked variance around the mean.
    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        # Normalize the tile on the fly, then cast to fp16 for the dot.
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        # Some programs should also calculate out_gate (only those whose
        # column block overlaps the first H columns).
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    # Store the sigmoided out-gate for the programs that computed it.
    if pid_n * BLOCK_SIZE_N < H:
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # De-interleave the 4-role columns: role r lives at column 4*h + r.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Gated, masked projections.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose the flat row index into (batch, s1, s2) for 4D scatter.
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)  # s2 offset uses s2 stride, s1 offset uses s1 stride
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
163
+
164
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
    ],
    key=['s1', 's2', 'H'],
)
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Batched per-(batch, channel) matmul over the pair grid:
    #   Out[b, h] = Left[b, h] @ Right[b, h], reducing over s2.
    # Right is addressed through (s2, s1)-named strides, i.e. the producer
    # stored it pre-transposed so both operands stream along the reduction
    # dim and the (s1 x s1) result is written coalesced.
    #
    # Grid and program IDs: axis 0 tiles the output using grouped
    # ("swizzled") pid ordering for L2 reuse; axis 1 enumerates the bs*H
    # (batch, channel) pairs.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Base pointers for this (batch, channel) slice.
    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    # Accumulate in fp32 even though the inputs are fp16.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)

        # Mask ragged edges along both spatial dims and the reduction dim.
        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        accumulator += tl.dot(a, b)

    # --- Coalesced Write ---
    # Write to a standard (bs, H, s1, s1) layout; the fp32 accumulator is
    # implicitly cast to Out's element type on store.
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2

    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
236
+
237
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
    ],
    key=['H', 'D'],
)
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,  # M = number of output rows; host passes bs*s1*s2
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Fused epilogue: LayerNorm over H, elementwise multiply by the gate
    # (Gate holds already-sigmoided out-gate values from the first kernel),
    # then projection [M, H] @ [H, D] -> Out.
    # NOTE(review): rows are decomposed over an s1 x s1 grid below, while the
    # host passes M = bs*s1*s2 — consistent only when s1 == s2; confirm that
    # callers only reach this kernel with square inputs.

    # --- Grid and PID Setup for Matmul (grouped pid ordering for L2 reuse) ---
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose M back to (b, r, c) for reordering lookups into the
    # (bs, H, s1, s1) input produced by the bmm kernel.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)

        # Accumulate sum and sum of squares in one pass over H.
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    # Finalize statistics: E[x] and Var[x] = E[x^2] - E[x]^2.
    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # --- Pass 2: Fused normalize, gate, and matmul ---
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        # Re-load the input chunk and apply the affine LayerNorm.
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]

        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]

        # Projection weight tile, indexed [K, N] via (h, d) strides.
        proj_ptrs = ProjW_ptr + \
            offs_n[None, :] * stride_proj_d + \
            offs_k[:, None] * stride_proj_h

        # Gate values are pre-sigmoided; just multiply.
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate

        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # --- Store Final Output (fp32 acc implicitly cast to Out's dtype) ---
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d

    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
340
+
341
def compiledtrimul_fused_interleaved(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,  # packed [K, 4*H] interleaved left/lg/right/rg weights
    W_og: torch.Tensor,    # transposed [K, H] out-gate weights
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Run the full TriMul pipeline as three Triton kernel launches.

    Stage 1 fuses the input LayerNorm with the projection/gating matmuls and
    scatters masked left/right factors into bmm-friendly layouts; stage 2 is
    a batched matmul over the H channels; stage 3 fuses the output LayerNorm,
    out-gating, and final projection.

    NOTE(review): stages 2/3 allocate (s1, s1) spatial grids while M uses
    s1*s2 — this pipeline presumably assumes s1 == s2; confirm with callers.
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # Intermediate buffers: right factor is stored pre-transposed so the bmm
    # kernel can stream both operands along the reduction dim.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # The grid is launched for the larger 4*H problem; a subset of programs
    # additionally computes the H out-gate columns.
    N_4way = 4 * H
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
    fused_ln_dual_matmul_kernel[grid](
        # Pointers (9)
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        # Metadata (5) - M, H, K, s1, s2
        M, H, K, s1, s2,
        # Strides (16) — order must match the kernel signature exactly.
        x_flat.stride(0), x_flat.stride(1),
        W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1),
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        # Constexpr (1)
        LN_EPS=1e-5
    )

    # --- Kernel 2: Batched matmul, one (s1 x s1) tile per (batch, channel) ---
    bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)

    grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
    bmm_coalesced_kernel[grid_bmm](
        left_final, right_final_t, bmm_out_tmp,
        bs, s1, s2, H,
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
    )

    # --- Kernel 3: Fully Fused Final Stage (LN + gate + projection) ---
    final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)

    grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
    fused_final_kernel[grid_final](
        # Pointers
        bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
        # Metadata
        M, H, d, s1,
        # Strides
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        to_out_weight.stride(0), to_out_weight.stride(1),  # (d, h) weight, indexed via its own strides
        final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
        # Constants
        LN_EPS=1e-5,
    )

    return final_out
412
+
413
def pack_w_4way_efficient(weights):
    """Interleave the four projection/gate weights into one [K, 4*H] fp16 matrix.

    For each hidden unit h the packed rows appear in the fixed role order
    (left_proj, left_gate, right_proj, right_gate); the result is transposed
    so the reduction dimension K comes first for the fused matmul kernel.
    """
    mats = [
        weights['left_proj.weight'],
        weights['left_gate.weight'],
        weights['right_proj.weight'],
        weights['right_gate.weight'],
    ]
    hidden, in_dim = mats[0].shape
    # Stack on dim 1 -> [H, 4, K]: the role index varies fastest after
    # flattening the first two dims, giving the interleaved row order.
    interleaved = torch.stack(mats, dim=1)
    packed = interleaved.reshape(4 * hidden, in_dim)
    return packed.t().to(torch.float16)
423
+
424
def get_w_og(weights):
    """Return the out-gate weight transposed to [K, H] and cast to float16."""
    out_gate_weight = weights['out_gate.weight']
    return out_gate_weight.t().to(torch.float16)
428
+
429
+ def compiledtrimul(
430
+ x: torch.Tensor,
431
+ mask: torch.Tensor,
432
+ norm_weight: torch.Tensor,
433
+ norm_bias: torch.Tensor,
434
+ w_concat: torch.Tensor,
435
+ to_out_norm_weight: torch.Tensor,
436
+ to_out_norm_bias: torch.Tensor,
437
+ to_out_weight: torch.Tensor,
438
+ h: int
439
+ ) -> torch.Tensor:
440
+ """
441
+ A barebones, compiled PyTorch function for the TriMul logic.
442
+ """
443
+ bs, s1, s2, d = x.shape
444
+
445
+ # Initial LayerNorm
446
+ x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
447
+ # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
448
+ all_projections = torch.mm(x_norm, w_concat)
449
+
450
+ # Split back into individual projections
451
+ left, right, lg, rg, og = all_projections.chunk(5, dim=1)
452
+
453
+ # Apply mask and gates
454
+ mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
455
+ left = left * mask_expanded * torch.sigmoid(lg)
456
+ right = right * mask_expanded * torch.sigmoid(rg)
457
+ out_gate = torch.sigmoid(og)
458
+
459
+ # Reshape for einsum
460
+ left = left.view(bs, s1, s2, h).permute(0,3,1,2)
461
+ right = right.view(bs, s1, s2, h).permute(0,3,1,2)
462
+ out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
463
+ out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
464
+
465
+ # Apply layer norm and final gating
466
+ normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
467
+ gated = normed * out_gate
468
+
469
+ # Final projection
470
+ final_out_flat = gated @ to_out_weight.t()
471
+ final_out = final_out_flat.view(bs, s1, s2, d)
472
+
473
+ return final_out
474
+
475
def small_kernel_pt_path(data):
    """Pure-PyTorch fallback used for small problem sizes.

    Concatenates the five projection weights into a single [d, 5*h] fp16
    matrix so the whole projection stage runs as one matmul, then delegates
    to `compiledtrimul`.
    """
    input_tensor, mask, weights, config = data

    # Order here defines the chunk order that `compiledtrimul` unpacks.
    proj_names = (
        'left_proj.weight',
        'right_proj.weight',
        'left_gate.weight',
        'right_gate.weight',
        'out_gate.weight',
    )
    stacked = torch.cat([weights[name] for name in proj_names], dim=0)
    w_concat = stacked.t().contiguous().to(torch.float16)

    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
497
+
498
def kernel_mi300(data):
    """Entry point: dispatch between the Triton pipeline and the PyTorch fallback."""
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    # Small sequence lengths run faster on the plain PyTorch path.
    if s1 < 100:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]

    rows = bs * s1 * s2
    # Broadcast the pair mask across the hidden dim once, up front.
    # NOTE: could be folded into the first kernel instead.
    mask_mh = (
        mask.unsqueeze(-1)
        .expand(-1, -1, -1, hidden)
        .reshape(rows, hidden)
        .to(torch.float16)
    )

    return compiledtrimul_fused_interleaved(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_mh,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=pack_w_4way_efficient(weights),  # interleaved 4-way projection matrix
        W_og=get_w_og(weights),                 # transposed out-gate matrix
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-xpu/triton_a100.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

# Set PyTorch flags for performance: allow TF32 tensor-core math for fp32
# matmuls and reduced-precision accumulation for fp16 matmuls. Both trade a
# small amount of accuracy for throughput.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
9
+
10
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (now passed as arguments from the host)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    # Fused stage 1 of TriMul: LayerNorm over K, then two matmuls that share
    # the normalized tile — x_norm @ W_4way (columns interleaved per hidden
    # unit as left/left_gate/right/right_gate) and, for programs whose column
    # block starts below H, x_norm @ W_og for the out-gate. Masked, gated
    # left/right factors are scattered into (bs, H, s1, s2) and transposed
    # (bs, H, s2, s1) layouts for the following batched matmul.
    # Requires BLOCK_SIZE_N == 4 * H_CHUNK_SIZE (see the reshape below).

    # --- PID Mapping: Based on the LARGER 4*H problem ---
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE per row block) ---
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    # NOTE(review): these two statistics loops index X WITHOUT stride_x_k
    # (the matmul loop below does use it) — correct only while the feature
    # dim is contiguous, which holds for the x.view(M, K) the host passes.
    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    # Second pass for the variance (masked lanes load 0.0 and contribute
    # (0 - mean)^2 only through masked rows, which are never stored).
    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Out-gate columns coincide with the first H columns of the 4-way grid.
    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        # Apply the affine LayerNorm on the fly, then cast to fp16 for tl.dot.
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        # Programs covering the first H columns also accumulate the out-gate
        # matmul, reusing the normalized tile computed above.
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    # Store the sigmoided out-gate (fp32 values implicitly cast to fp16).
    if pid_n * BLOCK_SIZE_N < H:
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # Columns are interleaved per hidden unit; split the accumulator into its
    # four roles. The reshape requires BLOCK_SIZE_N == 4 * H_CHUNK_SIZE.
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    # sum-over-where is a select along the role axis (roles 0..3).
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    # Load the broadcast pair mask for this row/hidden-chunk tile.
    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Gate and mask the two factors.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose flat row index into (batch, s1, s2) for the scatter stores.
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    # Left goes to (bs, H, s1, s2); right is stored transposed, swapping the
    # roles of the s1/s2 offsets against the output's (s2, s1) strides.
    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
+
136
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Batched per-(batch, channel) matmul: Out[b, h] = Left[b, h] @ Right[b, h],
    # reducing over s2. Right is addressed through (s2, s1)-named strides —
    # the producer stored it pre-transposed. Axis 0 tiles the (s1 x s1)
    # output with grouped pid ordering for L2 reuse; axis 1 enumerates the
    # bs*H (batch, channel) pairs.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Base pointers for this (batch, channel) slice.
    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    # Accumulate in fp32 even though the inputs are fp16.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
        # Mask ragged edges along the spatial and reduction dims.
        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        accumulator += tl.dot(a, b)

    # Coalesced write into a standard (bs, H, s1, s1) layout; the fp32
    # accumulator is implicitly cast to Out's element type on store.
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
187
+
188
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,  # M = number of output rows; host passes bs*s1*s2
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Fused epilogue: LayerNorm over H, multiply by the gate (Gate holds
    # already-sigmoided out-gate values), then project [M, H] @ [H, D].
    # NOTE(review): rows are decomposed over an s1 x s1 grid below while the
    # host passes M = bs*s1*s2 — consistent only when s1 == s2; confirm that
    # callers only reach this kernel with square inputs.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    # Grouped pid ordering for L2 reuse.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose flat row index into (batch, row, col) of the (s1, s1) grid.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    # Pass 1: accumulate sum and sum of squares over H in one sweep.
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    # Finalize statistics: Var[x] = E[x^2] - E[x]^2.
    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # Pass 2: re-load, normalize, gate, and accumulate the projection.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]
        # Affine LayerNorm parameters.
        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
        # Projection weight tile, indexed [K, N] via its (d, h) strides.
        proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
        # Gate values are pre-sigmoided; just multiply.
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate
        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # Store (fp32 acc implicitly cast to Out's dtype).
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
261
+
262
def compiledtrimul_fused_interleaved_final(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Run the TriMul forward pass as three hand-tuned Triton kernels (A100 configs).

    Stage 1 fuses the input LayerNorm with the packed 4-way projection matmul
    (left/left-gate/right/right-gate interleaved in ``W_4way``) plus the
    out-gate matmul (``W_og``); stage 2 does the batched ``left @ right^T``
    contraction over s2; stage 3 fuses the output LayerNorm, out-gating, and
    the projection back to the model dimension.

    Args:
        x: [bs, s1, s2, d] float32 input.
        mask_mh: [bs*s1*s2, H] fp16 mask, pre-expanded per hidden unit.
        W_4way: [K, 4*H] fp16 interleaved projection weights (see pack_w_4way_efficient).
        W_og: [K, H] fp16 out-gate weights.
        h: hidden dimension H.
    Returns:
        [bs, s1, s1, d] fp16 output tensor.
    """
    bs, s1, s2, d = x.shape
    M, K, H = bs * s1 * s2, x.shape[-1], h
    x_flat = x.view(M, K)

    # Stage-1 outputs: gated projections pre-transposed for the BMM, and the
    # sigmoid out-gate in flat [M, H] row-major form for stage 3.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # --- Kernel 1: Fused LN + Dual Matmul ---
    # Grid is sized for the larger 4*H projection problem; out-gate work is
    # folded into the blocks whose N-tile overlaps [0, H).
    N_4way = 4 * H
    # Hardcoded A100 best config: M128-N128-K32-GM8-HC32-W8-S2
    # (note BLOCK_SIZE_N == 4 * H_CHUNK_SIZE, required by the kernel's de-interleave reshape)
    config_k1 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
    grid_k1 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)

    fused_ln_dual_matmul_kernel[grid_k1](
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        M, H, K, s1, s2,
        x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
        left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
        right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        LN_EPS=1e-5, **config_k1, num_warps=8, num_stages=2
    )

    # --- Kernel 2: Batched Matrix Multiplication ---
    # One (batch, hidden) pair per axis-1 program: out[b,h] = left[b,h] @ right[b,h]^T.
    bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
    # Hardcoded A100 best config: M128-N64-K32-GM8-W4-S3
    config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    grid_k2 = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)

    bmm_coalesced_kernel[grid_k2](
        left_final, right_final_t, bmm_out_tmp,
        bs, s1, s2, H,
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        **config_k2, num_warps=4, num_stages=3
    )

    # --- Kernel 3: Fully Fused Final Stage ---
    # NOTE(review): M = bs*s1*s2 is reused here, but kernel 3's output has
    # bs*s1*s1 rows — correct only because s1 == s2 for this task. Confirm if
    # non-square inputs ever reach this path.
    final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
    # Hardcoded A100 best config: M32-N128-K32-GM8-W4-S3
    config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    grid_k3 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)

    fused_final_kernel[grid_k3](
        bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
        M, H, d, s1,
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
        final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
        LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
    )
    return final_out
330
+
331
def pack_w_4way_efficient(weights):
    """Interleave the four projection weights into one [K, 4*H] fp16 matrix.

    Packed column order alternates per hidden unit — L, LG, R, RG — so a
    single matmul yields all four projections with interleaved columns.
    """
    w_left = weights['left_proj.weight']
    w_left_gate = weights['left_gate.weight']
    w_right = weights['right_proj.weight']
    w_right_gate = weights['right_gate.weight']
    hidden, in_dim = w_left.shape
    # (H, 4, K) -> (4*H, K): row i*4+r is role r of hidden unit i.
    interleaved = torch.stack((w_left, w_left_gate, w_right, w_right_gate), dim=1)
    packed = interleaved.reshape(4 * hidden, in_dim)
    return packed.t().to(torch.float16)
337
+
338
def get_w_og(weights):
    """Return the out-gate weight transposed to [K, H] and cast to fp16."""
    out_gate_w = weights['out_gate.weight']
    return out_gate_w.t().to(torch.float16)
341
+
342
@torch.compile()
def compiledtrimul(
    x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
    w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor, h: int
) -> torch.Tensor:
    """Pure-PyTorch TriMul forward pass, used as the small-problem fallback.

    Args:
        x: [bs, s1, s2, d] float32 input.
        mask: broadcastable to [bs, s1, s2, 1]; zeroes masked positions.
        w_concat: [d, 5*h] fp16; column blocks ordered left, right, left_gate,
            right_gate, out_gate (matches the cat order in small_kernel_pt_path).
        h: hidden dimension.
    Returns:
        [bs, s1, s1, d] fp16 output.
    """
    bs, s1, s2, d = x.shape
    # Input LayerNorm, then one fused matmul producing all five projections.
    x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
    all_projections = torch.mm(x_norm, w_concat)
    left, right, lg, rg, og = all_projections.chunk(5, dim=1)
    mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
    # Gate and mask the left/right projections before the triangular matmul.
    left = left * mask_expanded * torch.sigmoid(lg)
    right = right * mask_expanded * torch.sigmoid(rg)
    out_gate = torch.sigmoid(og)
    # [bs, h, s1, s2] layout so matmul contracts over the s2 axis.
    left = left.view(bs, s1, s2, h).permute(0,3,1,2)
    right = right.view(bs, s1, s2, h).permute(0,3,1,2)
    out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
    out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
    # Output LayerNorm over the hidden dim, then out-gating.
    # NOTE(review): out_gate has bs*s1*s2 rows while normed has bs*s1*s1 —
    # this relies on s1 == s2 holding for all inputs; confirm with callers.
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * out_gate
    final_out_flat = gated @ to_out_weight.t()
    return final_out_flat.view(bs, s1, s1, d)
364
+
365
def small_kernel_pt_path(data):
    """Small-problem fallback: concatenate the five projection weights and
    dispatch through the torch.compile'd reference implementation."""
    input_tensor, mask, weights, config = data
    # Column-block order must match the chunk(5) order in compiledtrimul.
    projection_names = [
        'left_proj.weight', 'right_proj.weight', 'left_gate.weight',
        'right_gate.weight', 'out_gate.weight',
    ]
    w_concat = torch.cat([weights[name] for name in projection_names], dim=0)
    w_concat = w_concat.t().contiguous().to(torch.float16)
    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
380
+
381
def kernel_a100(data):
    """A100 entry point: route small problems to the compiled PyTorch path
    and large ones to the three hand-tuned Triton kernels."""
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    # Below this sequence length the PyTorch path is faster (threshold chosen
    # from observed BMM configs).
    if s1 < 512:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]
    w_4way = pack_w_4way_efficient(weights)
    w_og = get_w_og(weights)
    rows = bs * s1 * s2
    # Pre-expand the mask to one fp16 row per (b, i, j) position.
    mask_rows = mask.unsqueeze(-1).expand(-1, -1, -1, hidden).reshape(rows, hidden).to(torch.float16)

    return compiledtrimul_fused_interleaved_final(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_rows,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=w_4way,
        W_og=w_og,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-xpu/triton_b200.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ torch.backends.cuda.matmul.allow_tf32 = True
7
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
+
9
# Stage-1 kernel: fused input LayerNorm + matmul against the interleaved
# [K, 4*H] projection weight (L/LG/R/RG per hidden unit) + matmul against the
# [K, H] out-gate weight. Writes gated left/right projections directly in the
# transposed layouts stage 2 consumes, and the sigmoid out-gate as [M, H].
@triton.jit
def fused_ln_dual_matmul_kernel(
    # Pointers (9)
    X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
    OutLeft_ptr, OutRight_ptr, OutOG_ptr,
    # Metadata (5)
    M, H, K, s1, s2,
    # Strides (16)
    stride_x_m, stride_x_k,
    stride_w4_k, stride_w4_n,
    stride_wog_k, stride_wog_n,
    stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
    stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
    stride_og_m, stride_og_h,
    stride_mask_m, stride_mask_h,
    # Constexpr (now passed as arguments from the host)
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
):
    # --- PID Mapping: Based on the LARGER 4*H problem ---
    # Grouped (swizzled) program ordering: GROUP_SIZE_M row-blocks are walked
    # together to improve L2 reuse of the weight tiles.
    pid = tl.program_id(axis=0)
    N_4way = 4 * H
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # --- SHARED LayerNorm calculation (done only ONCE) ---
    # Two-pass mean/variance over the K (channel) axis for this row block.
    # NOTE(review): these two loops index X without multiplying by stride_x_k
    # (the matmul loop below does) — correct only while stride_x_k == 1, i.e.
    # x_flat is contiguous on its last dim; confirm if the host ever changes.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    m_mask = offs_m < M
    x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m

    mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        mean += tl.sum(x_chunk, axis=1)
    mean /= K

    var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for k_offset in range(0, K, BLOCK_SIZE_K):
        k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
        k_mask = (k_offset + k_chunk_offs) < K
        x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        x_centered = x_chunk - mean[:, None]
        var += tl.sum(x_centered * x_centered, axis=1)
    var /= K
    rstd = 1.0 / tl.sqrt(var + LN_EPS)

    # --- Matmul Loop 1: For the 4-Way Projections ---
    # The normalized X tile (fp16) is reused for both the 4-way and the
    # out-gate matmuls, so LayerNorm is applied exactly once per K chunk.
    offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
    accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_block_start = k * BLOCK_SIZE_K;
        x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
        w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
        x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
        w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
        x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
        norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
        norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
        nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
        x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
        x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
        w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
        accumulator_4way += tl.dot(x_norm_tile, w_tile)

        # Only blocks whose N-tile overlaps [0, H) also accumulate the
        # out-gate projection (the out-gate problem is 4x narrower).
        if pid_n * BLOCK_SIZE_N < H:
            w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
            w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
            w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
            w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
            accumulator_og += tl.dot(x_norm_tile, w_tile)

    # Out-gate is stored already passed through sigmoid, as fp16 [M, H].
    if pid_n * BLOCK_SIZE_N < H:
        og_out = tl.sigmoid(accumulator_og)
        outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
        og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
        tl.store(outg_ptrs, og_out, mask=og_mask)

    # --- Fusion Logic for 4-Way Part ---
    # De-interleave the accumulator: columns alternate L, LG, R, RG per hidden
    # unit, so a (M, H_CHUNK, 4) reshape separates the roles. This requires
    # BLOCK_SIZE_N == 4 * H_CHUNK_SIZE (all host configs satisfy it).
    acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
    role_idx = tl.arange(0, 4)[None, None, :]
    left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
    left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
    right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
    right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)

    offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
    mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
    m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
    mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)

    # Gate with sigmoid and apply the position mask in-register.
    left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
    right_out = right_proj * tl.sigmoid(right_gate) * mask_tile

    # Decompose the flat row index M -> (batch, s1, s2) to scatter into the
    # [bs, H, s1, s2] left / [bs, H, s2, s1] transposed-right layouts.
    s1s2 = s1 * s2
    offs_b = offs_m // s1s2
    offs_s1 = (offs_m % s1s2) // s2
    offs_s2 = offs_m % s2
    offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
    offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
    offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
    offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))

    outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
                               offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
    outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
                                  offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
    tl.store(outl_ptrs, left_out, mask=m_mask_h)
    tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
+
136
# Stage-2 kernel: batched matmul Out[b,h] = Left[b,h] @ Right[b,h]^T,
# contracting over the s2 axis. Axis-1 of the launch grid enumerates the
# bs*H (batch, hidden) pairs; axis-0 tiles the s1 x s1 output.
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Grid and program IDs: grouped (swizzled) ordering over the s1 x s1 tiles
    # for better L2 locality, same scheme as the stage-1 kernel.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Second grid axis selects the (batch, hidden-channel) pair.
    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    # fp32 accumulator; inputs are fp16 tiles (tensor-core tl.dot).
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Contract over s2. Right was stored pre-transposed by stage 1, so both
    # loads walk its s2 axis via stride_r_s2.
    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)

        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        accumulator += tl.dot(a, b)

    # Store converts the fp32 accumulator to the output tensor's dtype.
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
               offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2

    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
192
+
193
# Stage-3 kernel: for each flat output row m = (b, r, c), fuses LayerNorm
# over the H hidden channels of the BMM result, multiplication by the
# precomputed sigmoid out-gate, and the [H -> D] output projection.
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Grouped (swizzled) tile ordering over the M x D output, as in stage 1.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose the flat row index into (batch, row, col) of the s1 x s1 grid.
    # NOTE(review): this assumes M == bs * s1 * s1; the host passes bs*s1*s2,
    # which matches only because s1 == s2 for this task — confirm.
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    # One-pass LayerNorm statistics: accumulate sum and sum-of-squares over H,
    # then var = E[x^2] - E[x]^2.
    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # Second pass: re-load each H chunk, normalize, scale/shift, gate, and
    # multiply into the output projection (contracting over H).
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H
        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]
        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
        proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate
        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        # Cast the gated activations to the weight dtype (fp16) for tl.dot.
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # Scatter the D-dim output tile back into the [bs, s1, s1, D] layout.
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
266
+
267
def compiledtrimul_fused_interleaved_final(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,
    W_og: torch.Tensor,
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Three-stage Triton TriMul (B200 configs): (1) fused LayerNorm + packed
    4-way projection matmul + out-gate matmul, (2) batched left @ right^T,
    (3) fused output LayerNorm + gating + [H -> d] projection.

    Returns a [bs, s1, s1, d] fp16 tensor.
    """
    bs, s1, s2, d = x.shape
    M = bs * s1 * s2
    K = x.shape[-1]
    H = h
    x_flat = x.view(M, K)

    # Stage-1 outputs: gated left/right projections laid out for the BMM, and
    # the sigmoid out-gate kept row-major as [M, H] for stage 3.
    left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)

    # --- Stage 1: fused LN + dual matmul over the packed 4*H weight ---
    # Grid is launched for the larger 4*H problem.
    # Hardcoded best config from logs: M64-N128-K64-GM8-HC32-W4-S2
    n_total = 4 * H
    cfg1 = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64,
            'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
    launch1 = lambda meta: (
        triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(n_total, meta['BLOCK_SIZE_N']),
    )
    fused_ln_dual_matmul_kernel[launch1](
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        M, H, K, s1, s2,
        x_flat.stride(0), x_flat.stride(1),
        W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1),
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        LN_EPS=1e-5, **cfg1, num_warps=4, num_stages=2,
    )

    # --- Stage 2: batched matmul per (batch, hidden) pair ---
    # Hardcoded best config from logs: M128-N128-K32-GM8-W8-S3
    bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
    cfg2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    launch2 = lambda meta: (
        triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']),
        bs * H,
    )
    bmm_coalesced_kernel[launch2](
        left_final, right_final_t, bmm_out_tmp,
        bs, s1, s2, H,
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        **cfg2, num_warps=8, num_stages=3,
    )

    # --- Stage 3: fully fused final LN + gate + output projection ---
    # Hardcoded best config from logs: M32-N128-K32-GM8-W4-S3
    final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
    cfg3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
    launch3 = lambda meta: (
        triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),
    )
    fused_final_kernel[launch3](
        bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
        M, H, d, s1,
        bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        to_out_weight.stride(0), to_out_weight.stride(1),
        final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
        LN_EPS=1e-5, **cfg3, num_warps=4, num_stages=3,
    )
    return final_out
336
+
337
def pack_w_4way_efficient(weights):
    """Pack left/left-gate/right/right-gate weights into one [K, 4*H] fp16
    matrix, interleaved per hidden unit in the order L, LG, R, RG."""
    names = ('left_proj.weight', 'left_gate.weight', 'right_proj.weight', 'right_gate.weight')
    mats = [weights[name] for name in names]
    hidden, in_dim = mats[0].shape
    # (4, H, K) -> (H, 4, K) -> (4*H, K): interleaves the four roles.
    packed = torch.stack(mats, dim=0).transpose(0, 1).reshape(4 * hidden, in_dim)
    return packed.t().to(torch.float16)
343
+
344
def get_w_og(weights):
    """Return the out_gate weight as a transposed [K, H] fp16 matrix."""
    og_weight = weights['out_gate.weight']
    return og_weight.t().to(torch.float16)
347
+
348
@torch.compile()
def compiledtrimul(
    x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
    w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor, h: int
) -> torch.Tensor:
    """Pure-PyTorch TriMul forward pass, used as the small-problem fallback.

    Args:
        x: [bs, s1, s2, d] float32 input.
        mask: broadcastable to [bs, s1, s2, 1]; zeroes masked positions.
        w_concat: [d, 5*h] fp16; column blocks ordered left, right, left_gate,
            right_gate, out_gate (matches the cat order in small_kernel_pt_path).
        h: hidden dimension.
    Returns:
        [bs, s1, s1, d] fp16 output.
    """
    bs, s1, s2, d = x.shape
    # Input LayerNorm, then one fused matmul producing all five projections.
    x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
    all_projections = torch.mm(x_norm, w_concat)
    left, right, lg, rg, og = all_projections.chunk(5, dim=1)
    mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
    # Gate and mask the left/right projections before the triangular matmul.
    left = left * mask_expanded * torch.sigmoid(lg)
    right = right * mask_expanded * torch.sigmoid(rg)
    out_gate = torch.sigmoid(og)
    # [bs, h, s1, s2] layout so matmul contracts over the s2 axis.
    left = left.view(bs, s1, s2, h).permute(0,3,1,2)
    right = right.view(bs, s1, s2, h).permute(0,3,1,2)
    out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
    out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
    # Output LayerNorm over the hidden dim, then out-gating.
    # NOTE(review): out_gate has bs*s1*s2 rows while normed has bs*s1*s1 —
    # this relies on s1 == s2 holding for all inputs; confirm with callers.
    normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
    gated = normed * out_gate
    final_out_flat = gated @ to_out_weight.t()
    return final_out_flat.view(bs, s1, s1, d)
370
+
371
def small_kernel_pt_path(data):
    """Small-problem fallback: build the concatenated [d, 5*h] projection
    weight once and run the torch.compile'd reference implementation."""
    input_tensor, mask, weights, config = data
    # Order must line up with the chunk(5) unpacking inside compiledtrimul.
    names = [
        'left_proj.weight', 'right_proj.weight', 'left_gate.weight',
        'right_gate.weight', 'out_gate.weight',
    ]
    stacked = torch.cat([weights[n] for n in names], dim=0)
    w_concat = stacked.t().contiguous().to(torch.float16)
    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
386
+
387
def kernel_b200(data):
    """B200 entry point: small problems go through the compiled PyTorch path,
    large ones through the three hand-tuned Triton kernels."""
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    # Below this sequence length the PyTorch path wins on B200.
    if s1 < 800:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]
    w_4way = pack_w_4way_efficient(weights)
    w_og = get_w_og(weights)
    rows = bs * s1 * s2
    # Pre-expand the mask to one fp16 row per (b, i, j) position.
    mask_rows = mask.unsqueeze(-1).expand(-1, -1, -1, hidden).reshape(rows, hidden).to(torch.float16)

    return compiledtrimul_fused_interleaved_final(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_rows,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=w_4way,
        W_og=w_og,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )
build/torch-xpu/triton_h100.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ torch.backends.cuda.matmul.allow_tf32 = True
7
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
+
9
+ @triton.autotune(
10
+ configs=[
11
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
12
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),
13
+
14
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
15
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
16
+ triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
17
+
18
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=4),
19
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),
20
+
21
+ triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=5),
22
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=5),
23
+
24
+ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
25
+ triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=4),
26
+ ],
27
+ key=['M', 'N', 'K'],
28
+ )
29
+ @triton.jit
30
+ def fused_ln_dual_matmul_kernel(
31
+ # Pointers (9)
32
+ X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
33
+ OutLeft_ptr, OutRight_ptr, OutOG_ptr,
34
+ # Metadata (5)
35
+ M, H, K, s1, s2,
36
+ # Strides (16)
37
+ stride_x_m, stride_x_k,
38
+ stride_w4_k, stride_w4_n,
39
+ stride_wog_k, stride_wog_n,
40
+ stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
41
+ stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
42
+ stride_og_m, stride_og_h,
43
+ stride_mask_m, stride_mask_h,
44
+ # Constexpr (from decorator and kwargs)
45
+ LN_EPS: tl.constexpr,
46
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
47
+ GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
48
+ ):
49
+ # --- PID Mapping: Based on the LARGER 4*H problem ---
50
+ pid = tl.program_id(axis=0)
51
+ N_4way = 4 * H
52
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
53
+ num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
54
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
55
+ group_id = pid // num_pid_in_group
56
+ first_pid_m = group_id * GROUP_SIZE_M
57
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
58
+ pid_m = first_pid_m + (pid % group_size_m)
59
+ pid_n = (pid % num_pid_in_group) // group_size_m
60
+
61
+ # --- SHARED LayerNorm calculation (done only ONCE) ---
62
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
63
+ m_mask = offs_m < M
64
+ x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
65
+
66
+ mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
67
+ for k_offset in range(0, K, BLOCK_SIZE_K):
68
+ k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
69
+ x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
70
+ k_mask = (k_offset + k_chunk_offs) < K
71
+ x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
72
+ mean += tl.sum(x_chunk, axis=1)
73
+ mean /= K
74
+
75
+ var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
76
+ for k_offset in range(0, K, BLOCK_SIZE_K):
77
+ k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
78
+ x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
79
+ k_mask = (k_offset + k_chunk_offs) < K
80
+ x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
81
+ x_centered = x_chunk - mean[:, None]
82
+ var += tl.sum(x_centered * x_centered, axis=1)
83
+ var /= K
84
+ rstd = 1.0 / tl.sqrt(var + LN_EPS)
85
+
86
+ # --- Matmul Loop 1: For the 4-Way Projections ---
87
+ offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
88
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
89
+ w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
90
+ accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
91
+ accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
92
+
93
+ offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
94
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
95
+ k_block_start = k * BLOCK_SIZE_K;
96
+ x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
97
+ w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
98
+ x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
99
+ w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
100
+ x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
101
+ norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
102
+ norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
103
+ nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
104
+ nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
105
+ x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
106
+ x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
107
+ w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
108
+ accumulator_4way += tl.dot(x_norm_tile, w_tile)
109
+
110
+ #Some threads should calclate out_gate
111
+ if pid_n * BLOCK_SIZE_N < H:
112
+ w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
113
+ w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
114
+ w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
115
+ w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
116
+ accumulator_og += tl.dot(x_norm_tile, w_tile)
117
+
118
+ if pid_n * BLOCK_SIZE_N < H:
119
+ og_out = tl.sigmoid(accumulator_og)
120
+ outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
121
+ og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
122
+ tl.store(outg_ptrs, og_out, mask=og_mask)
123
+
124
+ # --- Fusion Logic for 4-Way Part ---
125
+ acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
126
+ role_idx = tl.arange(0, 4)[None, None, :]
127
+ left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
128
+ left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
129
+ right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
130
+ right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
131
+
132
+ offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
133
+ mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
134
+ m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
135
+ mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
136
+
137
+ left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
138
+ right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
139
+
140
+ s1s2 = s1 * s2
141
+ offs_b = offs_m // s1s2
142
+ offs_s1 = (offs_m % s1s2) // s2
143
+ offs_s2 = offs_m % s2
144
+ offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
145
+ offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
146
+ offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
147
+ offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
148
+
149
+ outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
150
+ offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
151
+ outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
152
+ offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1) # s2 offset uses s2 stride, s1 offset uses s1 stride
153
+ tl.store(outl_ptrs, left_out, mask=m_mask_h)
154
+ tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
155
+
156
@triton.autotune(
    configs=[
        # Tile-shape / pipeline candidates; autotune benchmarks them once per
        # (s1, s2, H) key and caches the winner.
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
    ],
    key=['s1', 's2', 'H'],
)
@triton.jit
def bmm_coalesced_kernel(
    # Pointers
    Left_ptr, Right_ptr, Out_ptr,
    # Dimensions
    bs, s1, s2, H,
    # Strides
    stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
    stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
    stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
    # Kernel parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ):
    """Per-(batch, head) matmul contracted over s2, with a coalesced store.

    For each (b, h) slice selected by grid axis 1, computes an (s1, s1)
    block of Left[b, h] (s1 x s2) times Right[b, h] addressed through
    (stride_r_s2, stride_r_s1) — i.e. the right operand is read in its
    pre-transposed layout — accumulating in fp32 and writing to Out in a
    standard (bs, H, s1, s1) layout.
    """
    # Grid and program IDs
    # Grouped ("swizzled") tile order: GROUP_SIZE_M rows of output tiles are
    # walked together, improving L2 reuse of the right-operand tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Grid axis 1 enumerates flattened (batch, head) pairs.
    pid_bh = tl.program_id(axis=1)
    pid_b = pid_bh // H
    pid_h = pid_bh % H

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Base of this program's (b, h) slice in each operand.
    left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
    right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h

    # fp32 accumulator for numerical stability regardless of input dtype.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # March the contraction (s2) dimension in BLOCK_SIZE_K chunks; masked
    # loads zero-fill the ragged final chunk so tl.dot stays full-width.
    for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
        k_start = k * BLOCK_SIZE_K
        a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
        b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)

        a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
        b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)

        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)

        accumulator += tl.dot(a, b)

    # --- Coalesced Write ---
    # Write to a standard (bs, H, s1, s1) layout
    out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
        offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2

    c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
    tl.store(out_ptrs, accumulator, mask=c_mask)
225
+
226
+ @torch.compile
227
+ def torch_pt2(left_final, right_final_t, bs, s1, s2, d, h, to_out_norm_weight, to_out_norm_bias, og_mh, to_out_weight):
228
+ bmm_out = torch.matmul(left_final, right_final_t)
229
+ out_einsum_flat = bmm_out.permute(0, 2, 3, 1).reshape(bs * s1 * s1, h)
230
+ # Apply layer norm and final gating
231
+ normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
232
+ gated = normed * og_mh
233
+
234
+ # Final projection
235
+ final_out_flat = gated @ to_out_weight.t()
236
+ final_out = final_out_flat.view(bs, s1, s2, d)
237
+ return final_out
238
+
239
@triton.autotune(
    configs=[
        # Tile-shape / pipeline candidates, cached per (H, D) key.
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
    ],
    key=['H', 'D'],
)
@triton.jit
def fused_final_kernel(
    # Pointers
    In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
    # Metadata
    M, H, D, s1,  # M rows total; decomposed below as bs*s1*s1 (assumes s1 == s2 — confirm)
    # Strides
    stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
    stride_gate_m, stride_gate_h,
    stride_proj_d, stride_proj_h,
    stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
    # Constants
    LN_EPS: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ):
    """Fused output LayerNorm + sigmoid-gating + projection to D.

    Each program handles a (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of the
    (M, D) output. Per row it first accumulates sum / sum-of-squares over
    the H hidden channels (one-pass mean/variance), then re-reads the row
    in K-chunks, normalizes, applies the affine LayerNorm parameters,
    multiplies by the precomputed gate, and contracts with ProjW.
    """
    # --- Grid and PID Setup for Matmul ---
    # Grouped tile swizzle (same scheme as the other matmul kernels) for
    # better L2 locality.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)

    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    m_mask = offs_m < M

    # Decompose M back to (b, r, c) for reordering lookups
    # (row index -> batch, row, column of the s1 x s1 pair matrix).
    s1s1 = s1 * s1
    b = offs_m // s1s1
    r = (offs_m % s1s1) // s1
    c = offs_m % s1

    # One-pass statistics: accumulate sum and sum-of-squares per row.
    sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col

    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)

        # Accumulate sum and sum of squares in one pass
        sum_x += tl.sum(in_chunk, axis=1)
        sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)

    # Finalize statistics (E[x^2] - E[x]^2 form of the variance).
    mean = sum_x / H
    var = (sum_x2 / H) - (mean * mean)
    rstd = tl.math.rsqrt(var + LN_EPS)

    # --- Fused Gating and Matmul ---
    # Second pass over H: normalize, affine-transform, gate, and contract
    # against the projection weight into a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k_offset in range(0, H, BLOCK_SIZE_K):
        offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
        k_mask = offs_k < H

        in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
        a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_norm = (a - mean[:, None]) * rstd[:, None]

        norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
        norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
        a_norm = a_norm * norm_w[None, :] + norm_b[None, :]

        # ProjW is addressed as [D, H]: rows select output channels.
        proj_ptrs = ProjW_ptr + \
            offs_n[None, :] * stride_proj_d + \
            offs_k[:, None] * stride_proj_h

        # Gate values are precomputed sigmoids laid out as (M, H).
        gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
        gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
        a_gated = a_norm * gate

        b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
        # Cast the activations to the weight dtype so tl.dot gets matching
        # operand types (e.g. fp16 weights).
        acc += tl.dot(a_gated.to(b_w.dtype), b_w)

    # --- Store Final Output ---
    offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
    out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d

    tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
339
+
340
def compiledtrimul_fused_interleaved(
    x: torch.Tensor,
    mask_mh: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    W_4way: torch.Tensor,  # interleaved [K, 4*H] projection weights
    W_og: torch.Tensor,    # [K, H] out-gate weights
    to_out_norm_weight: torch.Tensor,
    to_out_norm_bias: torch.Tensor,
    to_out_weight: torch.Tensor,
    h: int,
):
    """Run the fused LayerNorm + 4-way projection Triton kernel, then finish
    the TriMul computation with the compiled torch tail (`torch_pt2`).

    Allocates the kernel's three fp16 outputs (left projections, transposed
    right projections, and the sigmoid out-gate) and launches one 1-D grid
    sized for the larger 4*H projection problem.
    """
    bs, s1, s2, d = x.shape
    rows = bs * s1 * s2   # flattened row count (M)
    dim = x.shape[-1]     # input feature dimension (K)
    x_flat = x.view(rows, dim)

    # Kernel outputs: left in (bs, H, s1, s2), right pre-transposed to
    # (bs, H, s2, s1) for the subsequent batched matmul, gate flat (M, H).
    left_final = torch.empty((bs, h, s1, s2), device=x.device, dtype=torch.float16)
    right_final_t = torch.empty((bs, h, s2, s1), device=x.device, dtype=torch.float16)
    og_mh = torch.empty((rows, h), device=x.device, dtype=torch.float16)

    # The grid is sized for the larger 4*H projection problem.
    n_cols = 4 * h
    grid = lambda meta: (triton.cdiv(rows, meta['BLOCK_SIZE_M']) * triton.cdiv(n_cols, meta['BLOCK_SIZE_N']),)
    fused_ln_dual_matmul_kernel[grid](
        # Pointers (9)
        x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
        left_final, right_final_t, og_mh,
        # Metadata (5): M, H, K, s1, s2
        rows, h, dim, s1, s2,
        # Strides (16)
        x_flat.stride(0), x_flat.stride(1),
        W_4way.stride(0), W_4way.stride(1),
        W_og.stride(0), W_og.stride(1),
        left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
        right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
        og_mh.stride(0), og_mh.stride(1),
        mask_mh.stride(0), mask_mh.stride(1),
        # Constexpr (1)
        LN_EPS=1e-5,
    )

    return torch_pt2(
        left_final,
        right_final_t,
        bs=bs,
        s1=s1,
        s2=s2,
        d=d,
        h=h,
        to_out_norm_weight=to_out_norm_weight,
        to_out_norm_bias=to_out_norm_bias,
        og_mh=og_mh,
        to_out_weight=to_out_weight,
    )
392
+
393
def pack_w_4way_efficient(weights):
    """Interleave the four projection weights into one tight [K, 4*H] fp16 matrix.

    Output column 4*i + r holds head i of role r, with roles ordered
    (left_proj, left_gate, right_proj, right_gate).
    """
    roles = (
        weights['left_proj.weight'],
        weights['left_gate.weight'],
        weights['right_proj.weight'],
        weights['right_gate.weight'],
    )
    num_heads, in_dim = roles[0].shape
    # (H, 4, K) -> (4*H, K): head-major with the four roles adjacent.
    interleaved = torch.stack(roles, dim=1).reshape(4 * num_heads, in_dim)
    return interleaved.t().to(torch.float16)
403
+
404
def get_w_og(weights):
    """Return the out_gate weight transposed to [K, H] in fp16."""
    return weights['out_gate.weight'].t().to(torch.float16)
408
+
409
+
410
# Opt into faster, lower-precision cuBLAS matmul modes: TF32 tensor-core
# math for fp32 matmuls, and reduced-precision accumulation for fp16 ones.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
412
+
413
+ @torch.compile
414
+ def compiledtrimul(
415
+ x: torch.Tensor,
416
+ mask: torch.Tensor,
417
+ norm_weight: torch.Tensor,
418
+ norm_bias: torch.Tensor,
419
+ w_concat: torch.Tensor,
420
+ to_out_norm_weight: torch.Tensor,
421
+ to_out_norm_bias: torch.Tensor,
422
+ to_out_weight: torch.Tensor,
423
+ h: int
424
+ ) -> torch.Tensor:
425
+ """
426
+ A barebones, compiled PyTorch function for the TriMul logic.
427
+ """
428
+ bs, s1, s2, d = x.shape
429
+
430
+ # Initial LayerNorm
431
+ x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
432
+ # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
433
+ all_projections = torch.mm(x_norm, w_concat)
434
+
435
+ # Split back into individual projections
436
+ left, right, lg, rg, og = all_projections.chunk(5, dim=1)
437
+
438
+ # Apply mask and gates
439
+ mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
440
+ left = left * mask_expanded * torch.sigmoid(lg)
441
+ right = right * mask_expanded * torch.sigmoid(rg)
442
+ out_gate = torch.sigmoid(og)
443
+
444
+ # Reshape for einsum
445
+ left = left.view(bs, s1, s2, h).permute(0,3,1,2)
446
+ right = right.view(bs, s1, s2, h).permute(0,3,1,2)
447
+ out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
448
+ out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
449
+
450
+ # Apply layer norm and final gating
451
+ normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
452
+ gated = normed * out_gate
453
+
454
+ # Final projection
455
+ final_out_flat = gated @ to_out_weight.t()
456
+ final_out = final_out_flat.view(bs, s1, s2, d)
457
+
458
+ return final_out
459
+
460
def small_kernel_pt_path(data):
    """Small-problem fallback: run TriMul entirely through the compiled
    PyTorch path (`compiledtrimul`) instead of the custom Triton kernels.
    """
    input_tensor, mask, weights, config = data

    # Fuse the five projection matrices into one [d, 5h] fp16 operand, in
    # the exact order compiledtrimul's chunk(5) expects.
    projection_order = (
        'left_proj.weight',
        'right_proj.weight',
        'left_gate.weight',
        'right_gate.weight',
        'out_gate.weight',
    )
    w_concat = torch.cat([weights[name] for name in projection_order], dim=0)
    w_concat = w_concat.t().contiguous().to(torch.float16)

    return compiledtrimul(
        x=input_tensor.to(torch.float32),
        mask=mask.unsqueeze(-1),
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        w_concat=w_concat,
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float32),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float32),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=config["hidden_dim"],
    )
482
+
483
def kernel_h100(data):
    """H100 entry point for TriMul: dispatch between the compiled PyTorch
    path (small sequences) and the fused Triton path (large sequences).
    """
    input_tensor, mask, weights, config = data
    bs, s1, s2, d = input_tensor.shape

    # Below this size the torch.compile path wins; skip the Triton kernels.
    if s1 <= 512:
        return small_kernel_pt_path(data)

    hidden = config["hidden_dim"]
    rows = bs * s1 * s2

    # Broadcast the (bs, s1, s2) mask across heads into a flat (M, H) fp16
    # view for the fused kernel (could be moved into the kernel itself).
    mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, hidden).reshape(rows, hidden).to(torch.float16)

    return compiledtrimul_fused_interleaved(
        x=input_tensor.to(torch.float32),
        mask_mh=mask_mh,
        norm_weight=weights['norm.weight'].to(torch.float32),
        norm_bias=weights['norm.bias'].to(torch.float32),
        W_4way=pack_w_4way_efficient(weights),  # interleaved 4-way matrix
        W_og=get_w_og(weights),                 # transposed out-gate matrix
        to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
        to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
        to_out_weight=weights['to_out.weight'].to(torch.float16),
        h=hidden,
    )