danieldk HF Staff commited on
Commit
d2d3257
·
1 Parent(s): faed52f

Revert "Build uploaded using `kernels`."

Browse files

This reverts commit faed52fbc37782e891f85512bf2c486b4e6ff17a.

build/torch-cuda/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- from .triton_a100 import kernel_a100
2
- from .triton_h100 import kernel_h100
3
- from .triton_b200 import kernel_b200
4
- from .trimul_mi300 import kernel_mi300
5
- from .trimul_global import kernel_global
6
-
7
- __all__ = ["kernel_a100", "kernel_h100", "kernel_b200", "kernel_mi300", "kernel_global"]
 
 
 
 
 
 
 
 
build/torch-cuda/_ops.py DELETED
@@ -1,8 +0,0 @@
1
- import torch
2
- ops = torch.ops._trimul_gpumode_8e6e60d
3
-
4
- def add_op_namespace_prefix(op_name: str):
5
- """
6
- Prefix op by namespace.
7
- """
8
- return f"_trimul_gpumode_8e6e60d::{op_name}"
 
 
 
 
 
 
 
 
 
build/torch-cuda/metadata.json DELETED
@@ -1 +0,0 @@
1
- {"python-depends":[]}
 
 
build/torch-cuda/task.py DELETED
@@ -1,20 +0,0 @@
1
- """
2
- Type definitions for TriMul task.
3
-
4
- Input: Tuple of (input_tensor, mask, weights, config)
5
- - input_tensor: Input tensor of shape [batch_size, seq_len, seq_len, dim]
6
- - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
7
- - weights: Dictionary containing model weights
8
- - config: Dictionary containing model configuration parameters
9
-
10
- Output: Output tensor of shape [batch_size, seq_len, seq_len, dim]
11
- """
12
-
13
- import torch
14
- from typing import Tuple, Dict, Any
15
-
16
- # Input type: (input_tensor, mask, weights, config)
17
- input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]
18
-
19
- # Output type: output tensor
20
- output_t = torch.Tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/trimul_global.py DELETED
@@ -1,971 +0,0 @@
1
- # from utils import make_match_reference, DisableCuDNNTF32
2
- from .task import input_t, output_t
3
-
4
- import torch
5
- from torch import nn, einsum
6
- import math
7
- import os
8
- import requests
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
14
- # in PyTorch 1.12 and later.
15
- torch.backends.cuda.matmul.allow_tf32 = True
16
-
17
- # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
18
- torch.backends.cudnn.allow_tf32 = True
19
-
20
- # Set allocator for TMA descriptors (required for on-device TMA)
21
- def alloc_fn(size: int, alignment: int, stream=None):
22
- return torch.empty(size, device="cuda", dtype=torch.int8)
23
-
24
- triton.set_allocator(alloc_fn)
25
-
26
- # os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
27
- # os.environ['MLIR_ENABLE_DIAGNOSTICS'] = 'warnings,remarks'
28
-
29
- # Reference code in PyTorch
30
- class TriMul(nn.Module):
31
- # Based on https://github.com/lucidrains/triangle-multiplicative-module/blob/main/triangle_multiplicative_module/triangle_multiplicative_module.py
32
- def __init__(
33
- self,
34
- dim: int,
35
- hidden_dim: int,
36
- ):
37
- super().__init__()
38
-
39
- self.norm = nn.LayerNorm(dim)
40
-
41
- self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
42
- self.right_proj = nn.Linear(dim, hidden_dim, bias=False)
43
-
44
- self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
45
- self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
46
- self.out_gate = nn.Linear(dim, hidden_dim, bias=False)
47
-
48
- self.to_out_norm = nn.LayerNorm(hidden_dim)
49
- self.to_out = nn.Linear(hidden_dim, dim, bias=False)
50
-
51
- def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
52
- """
53
- x: [bs, seq_len, seq_len, dim]
54
- mask: [bs, seq_len, seq_len]
55
-
56
- Returns:
57
- output: [bs, seq_len, seq_len, dim]
58
- """
59
- batch_size, seq_len, _, dim = x.shape
60
-
61
- x = self.norm(x)
62
-
63
- left = self.left_proj(x)
64
- right = self.right_proj(x)
65
-
66
- mask = mask.unsqueeze(-1)
67
- left = left * mask
68
- right = right * mask
69
-
70
- left_gate = self.left_gate(x).sigmoid()
71
- right_gate = self.right_gate(x).sigmoid()
72
- out_gate = self.out_gate(x).sigmoid()
73
-
74
- left = left * left_gate
75
- right = right * right_gate
76
-
77
- out = einsum('... i k d, ... j k d -> ... i j d', left, right)
78
- # This einsum is the same as the following:
79
- # out = torch.zeros(batch_size, seq_len, seq_len, dim, device=x.device)
80
-
81
- # # Compute using nested loops
82
- # for b in range(batch_size):
83
- # for i in range(seq_len):
84
- # for j in range(seq_len):
85
- # # Compute each output element
86
- # for k in range(seq_len):
87
- # out[b, i, j] += left[b, i, k, :] * right[b, j, k, :]
88
-
89
- out = self.to_out_norm(out)
90
- out = out * out_gate
91
- return self.to_out(out)
92
-
93
- @triton.jit
94
- def triton_sigmoid(x):
95
- """
96
- Compute sigmoid function: 1 / (1 + exp(-x))
97
- """
98
- return 1.0 / (1.0 + tl.exp(-x))
99
-
100
- def two_mm_kernel_configs_wrapper():
101
- if torch.cuda.get_device_capability() == (12, 0):
102
- def two_mm_kernel_configs():
103
- configs = []
104
- for BLOCK_M in [16, 32]:
105
- for BLOCK_N in [16, 32, 64]:
106
- for BLOCK_K in [16, 32, 64]:
107
- for num_stages in [2, 3]:
108
- configs.append(triton.Config({
109
- 'BLOCK_M': BLOCK_M,
110
- 'BLOCK_N': BLOCK_N,
111
- 'BLOCK_K': BLOCK_K,
112
- 'GROUP_SIZE_M': 8
113
- }, num_stages=num_stages, num_warps=8))
114
- return configs
115
-
116
- elif torch.cuda.get_device_capability()[0] == 9:
117
- def get_optimal_two_mm_config_h100(B, seq_len, dim):
118
- configs = {
119
- (1, 128, 128): (128, 64, 128, 2, 8),
120
- (1, 128, 256): (128, 64, 128, 2, 8),
121
- (1, 128, 384): (128, 64, 64, 3, 8),
122
- (1, 128, 512): (128, 64, 64, 3, 8),
123
- (1, 128, 768): (128, 64, 64, 3, 8),
124
- (1, 128, 1024): (128, 64, 64, 3, 8),
125
- (1, 256, 128): (128, 64, 128, 2, 8),
126
- (1, 256, 256): (128, 64, 128, 2, 8),
127
- (1, 256, 384): (128, 64, 64, 3, 8),
128
- (1, 256, 512): (128, 64, 64, 3, 8),
129
- (1, 256, 768): (128, 64, 64, 3, 8),
130
- (1, 256, 1024): (128, 64, 64, 3, 8),
131
- (1, 512, 128): (128, 64, 128, 2, 8),
132
- (1, 512, 256): (128, 64, 128, 2, 8),
133
- (1, 512, 384): (128, 64, 128, 2, 8),
134
- (1, 512, 512): (128, 64, 128, 2, 8),
135
- (1, 512, 768): (128, 64, 64, 3, 8),
136
- (1, 512, 1024): (128, 64, 64, 3, 8),
137
- (1, 1024, 128): (128, 64, 128, 2, 8),
138
- (1, 1024, 256): (128, 64, 64, 2, 8),
139
- (1, 1024, 384): (128, 64, 128, 2, 8),
140
- (1, 1024, 512): (128, 64, 128, 2, 8),
141
- (1, 1024, 768): (128, 64, 128, 2, 8),
142
- (1, 1024, 1024): (128, 64, 128, 2, 8),
143
- (2, 128, 128): (128, 64, 128, 2, 8),
144
- (2, 128, 256): (128, 64, 128, 2, 8),
145
- (2, 128, 384): (128, 64, 64, 3, 8),
146
- (2, 128, 512): (128, 64, 64, 3, 8),
147
- (2, 128, 768): (128, 64, 64, 3, 8),
148
- (2, 128, 1024): (128, 64, 64, 3, 8),
149
- (2, 256, 128): (128, 64, 128, 2, 8),
150
- (2, 256, 256): (128, 64, 128, 2, 8),
151
- (2, 256, 384): (128, 64, 128, 2, 8),
152
- (2, 256, 512): (128, 64, 128, 2, 8),
153
- (2, 256, 768): (128, 64, 64, 3, 8),
154
- (2, 256, 1024): (128, 64, 64, 3, 8),
155
- (2, 512, 128): (128, 64, 128, 2, 8),
156
- (2, 512, 256): (128, 64, 128, 2, 8),
157
- (2, 512, 384): (128, 64, 128, 2, 8),
158
- (2, 512, 512): (128, 64, 128, 2, 8),
159
- (2, 512, 768): (128, 64, 128, 2, 8),
160
- (2, 512, 1024): (128, 64, 128, 2, 8),
161
- (2, 1024, 128): (128, 64, 128, 2, 8),
162
- (2, 1024, 256): (128, 64, 128, 2, 8),
163
- (2, 1024, 384): (128, 64, 128, 2, 8),
164
- (2, 1024, 512): (128, 64, 128, 2, 8),
165
- (2, 1024, 768): (128, 64, 128, 2, 8),
166
- (2, 1024, 1024): (128, 64, 128, 2, 8),
167
- }
168
- return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8)) # default fallback
169
-
170
- def two_mm_kernel_configs():
171
- # This function is kept for compatibility but will be overridden for H100
172
- return [
173
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
174
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
175
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
176
- ]
177
-
178
- elif torch.cuda.get_device_capability()[0] == 10 and False:
179
- def get_optimal_two_mm_config(B, seq_len, dim):
180
- configs = {
181
- (1, 128, 128): (64, 128, 64, 2, 8),
182
- (1, 128, 256): (128, 64, 128, 2, 8),
183
- (1, 128, 384): (128, 64, 128, 2, 8),
184
- (1, 128, 512): (128, 64, 128, 2, 8),
185
- (1, 128, 768): (128, 64, 64, 3, 8),
186
- (1, 128, 1024): (128, 64, 64, 3, 8),
187
- (1, 256, 128): (128, 64, 128, 2, 8),
188
- (1, 256, 256): (128, 64, 128, 2, 8),
189
- (1, 256, 384): (128, 64, 128, 2, 8),
190
- (1, 256, 512): (128, 64, 64, 3, 8),
191
- (1, 256, 768): (128, 64, 64, 3, 8),
192
- (1, 256, 1024): (128, 64, 64, 3, 8),
193
- (1, 512, 128): (128, 64, 128, 2, 8),
194
- (1, 512, 256): (128, 64, 128, 2, 8),
195
- (1, 512, 384): (128, 64, 128, 2, 8),
196
- (1, 512, 512): (128, 64, 128, 2, 8),
197
- (1, 512, 768): (128, 64, 64, 3, 8),
198
- (1, 512, 1024): (128, 64, 64, 3, 8),
199
- (1, 1024, 128): (128, 64, 128, 2, 8),
200
- (1, 1024, 256): (128, 64, 128, 2, 8),
201
- (1, 1024, 384): (128, 64, 128, 2, 8),
202
- (1, 1024, 512): (128, 64, 128, 2, 8),
203
- (1, 1024, 768): (128, 64, 64, 3, 8),
204
- (1, 1024, 1024): (128, 64, 64, 3, 8),
205
- (2, 128, 128): (128, 64, 128, 2, 8),
206
- (2, 128, 256): (128, 64, 128, 2, 8),
207
- (2, 128, 384): (128, 64, 128, 2, 8),
208
- (2, 128, 512): (128, 64, 64, 3, 8),
209
- (2, 128, 768): (128, 64, 64, 3, 8),
210
- (2, 128, 1024): (128, 64, 64, 3, 8),
211
- (2, 256, 128): (128, 64, 128, 2, 8),
212
- (2, 256, 256): (128, 64, 128, 2, 8),
213
- (2, 256, 384): (128, 64, 128, 2, 8),
214
- (2, 256, 512): (128, 64, 64, 3, 8),
215
- (2, 256, 768): (128, 64, 64, 3, 8),
216
- (2, 256, 1024): (128, 64, 64, 3, 8),
217
- (2, 512, 128): (128, 64, 128, 2, 8),
218
- (2, 512, 256): (128, 64, 128, 2, 8),
219
- (2, 512, 384): (128, 64, 128, 2, 8),
220
- (2, 512, 512): (128, 64, 128, 2, 8),
221
- (2, 512, 768): (128, 64, 64, 3, 8),
222
- (2, 512, 1024): (128, 64, 64, 3, 8),
223
- (2, 1024, 128): (128, 64, 128, 2, 8),
224
- (2, 1024, 256): (128, 64, 128, 2, 8),
225
- (2, 1024, 384): (128, 64, 128, 2, 8),
226
- (2, 1024, 512): (128, 64, 128, 2, 8),
227
- (2, 1024, 768): (128, 64, 64, 3, 8),
228
- (2, 1024, 1024): (128, 64, 64, 3, 8),
229
- }
230
- return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8)) # default fallback
231
-
232
- def two_mm_kernel_configs():
233
- # This function is kept for compatibility but will be overridden
234
- return [
235
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
236
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
237
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
238
- ]
239
- elif torch.cuda.get_device_capability()[0] == 8:
240
- # A100
241
- def two_mm_kernel_configs():
242
- configs = []
243
- for BLOCK_M in [64]:
244
- for BLOCK_N in [64, 128]:
245
- for BLOCK_K in [16]:
246
- for num_stages in [3, 4]:
247
- for num_warps in [4, 8]:
248
- configs.append(triton.Config({
249
- 'BLOCK_M': BLOCK_M,
250
- 'BLOCK_N': BLOCK_N,
251
- 'BLOCK_K': BLOCK_K,
252
- 'GROUP_SIZE_M': 8
253
- }, num_stages=num_stages, num_warps=num_warps))
254
- return configs
255
- else:
256
- def two_mm_kernel_configs():
257
- configs = []
258
- for BLOCK_M in [64, 128]:
259
- for BLOCK_N in [64, 128]:
260
- for BLOCK_K in [64, 128]:
261
- for num_stages in [2, 3]:
262
- configs.append(triton.Config({
263
- 'BLOCK_M': BLOCK_M,
264
- 'BLOCK_N': BLOCK_N,
265
- 'BLOCK_K': BLOCK_K,
266
- 'GROUP_SIZE_M': 8
267
- }, num_stages=num_stages, num_warps=8))
268
- return configs
269
-
270
- return two_mm_kernel_configs
271
-
272
- def two_mm_kernel_wrapper():
273
- if torch.cuda.get_device_capability()[0] == 8:
274
- @triton.jit
275
- def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
276
- # Persistent kernel using standard tl.load operations
277
- start_pid = tl.program_id(axis=0)
278
- num_pid_m = tl.cdiv(M, BLOCK_M)
279
- num_pid_n = tl.cdiv(N, BLOCK_N)
280
- k_tiles = tl.cdiv(K, BLOCK_K)
281
- num_tiles = num_pid_m * num_pid_n
282
-
283
- # tile_id_c is used in the epilogue to break the dependency between
284
- # the prologue and the epilogue
285
- tile_id_c = start_pid - NUM_SMS
286
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
287
-
288
- # Persistent loop over tiles
289
- for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
290
- # Calculate PID for this tile using improved swizzling
291
- group_id = tile_id // num_pid_in_group
292
- first_pid_m = group_id * GROUP_SIZE_M
293
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
294
- pid_m = first_pid_m + (tile_id % group_size_m)
295
- pid_n = (tile_id % num_pid_in_group) // group_size_m
296
-
297
- # Calculate block offsets
298
- offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
299
- offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
300
- offs_k = tl.arange(0, BLOCK_K)
301
-
302
- # Initialize accumulators for all outputs
303
- accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
304
- accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
305
- accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
306
- accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
307
- accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
308
-
309
- # Main computation loop over K dimension
310
- for ki in range(k_tiles):
311
- k_start = ki * BLOCK_K
312
- k_offsets = k_start + offs_k
313
-
314
- # Create pointers for A matrix (2D flattened view)
315
- a_ptrs = a_ptr + offs_am[:, None] * stride_a2 + k_offsets[None, :] * stride_a3
316
- a_mask = (offs_am[:, None] < M) & (k_offsets[None, :] < K)
317
-
318
- # Create pointers for B matrices [N, K] layout
319
- b1_ptrs = b1_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
320
- b2_ptrs = b2_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
321
- b3_ptrs = b3_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
322
- b4_ptrs = b4_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
323
- b5_ptrs = b5_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
324
- b_mask = (offs_bn[:, None] < N) & (k_offsets[None, :] < K)
325
-
326
- # Load blocks from A and all weight matrices using standard tl.load
327
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
328
- b1 = tl.load(b1_ptrs, mask=b_mask, other=0.0)
329
- b2 = tl.load(b2_ptrs, mask=b_mask, other=0.0)
330
- b3 = tl.load(b3_ptrs, mask=b_mask, other=0.0)
331
- b4 = tl.load(b4_ptrs, mask=b_mask, other=0.0)
332
- b5 = tl.load(b5_ptrs, mask=b_mask, other=0.0)
333
-
334
- # Perform matrix multiplications using TF32
335
- accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True) # A @ B1.T
336
- accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True) # A @ B2.T
337
- accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True) # A @ B3.T
338
- accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True) # A @ B4.T
339
- accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True) # A @ B5.T
340
-
341
- # Store results using separate tile_id_c for epilogue
342
- tile_id_c += NUM_SMS
343
- group_id = tile_id_c // num_pid_in_group
344
- first_pid_m = group_id * GROUP_SIZE_M
345
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
346
- pid_m = first_pid_m + (tile_id_c % group_size_m)
347
- pid_n = (tile_id_c % num_pid_in_group) // group_size_m
348
-
349
- # Calculate output offsets and pointers
350
- offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
351
- offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
352
-
353
- # Create masks for bounds checking
354
- d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
355
-
356
- # Calculate pointer addresses using 4D strides
357
- stride_cm = stride_c2 # Stride to next element in flattened M dimension
358
- stride_cn = stride_c3 # N is the innermost dimension
359
-
360
- # For D tensor: use separate D strides
361
- stride_dm = stride_d2 # Stride to next element in flattened M dimension
362
- stride_dn = stride_d3 # N is the innermost dimension
363
-
364
- off_c_batch = offs_cm // (seq_len * seq_len)
365
- off_c_sl1 = (offs_cm // seq_len) % seq_len
366
- off_c_sl2 = offs_cm % seq_len
367
- off_c_dim = offs_cn
368
-
369
- c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
370
- c_mask = d_mask
371
-
372
- c1_ptrs = c1_ptr + c_offsets
373
- c2_ptrs = c2_ptr + c_offsets
374
- d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]
375
-
376
- mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))
377
-
378
- # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
379
- mask_2d = mask[:, None] # Convert to [BLOCK_M, 1] then broadcast
380
- # Apply masking only to left_proj and right_proj results (C1, C2)
381
- accumulator1 = tl.where(mask_2d, accumulator1, 0)
382
- accumulator2 = tl.where(mask_2d, accumulator2, 0)
383
-
384
- # Apply sigmoid to gate values
385
- left_gate_sigmoid = triton_sigmoid(accumulator3)
386
- right_gate_sigmoid = triton_sigmoid(accumulator4)
387
- accumulator_d = triton_sigmoid(accumulator_d)
388
-
389
- # Apply elementwise multiplication with gated values
390
- # C1 = left * left_gate, C2 = right * right_gate
391
- accumulator1 = accumulator1 * left_gate_sigmoid # left * left_gate
392
- accumulator2 = accumulator2 * right_gate_sigmoid # right * right_gate
393
-
394
- # Convert to appropriate output dtype and store with normal tl.store
395
- c1 = accumulator1.to(c1_ptr.dtype.element_ty)
396
- c2 = accumulator2.to(c2_ptr.dtype.element_ty)
397
- d = accumulator_d.to(d_ptr.dtype.element_ty)
398
-
399
- tl.store(c1_ptrs, c1, mask=c_mask)
400
- tl.store(c2_ptrs, c2, mask=c_mask)
401
- tl.store(d_ptrs, d, mask=d_mask)
402
- else:
403
- @triton.jit
404
- def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
405
- # Persistent kernel using on-device TMA descriptors
406
- start_pid = tl.program_id(axis=0)
407
- num_pid_m = tl.cdiv(M, BLOCK_M)
408
- num_pid_n = tl.cdiv(N, BLOCK_N)
409
- k_tiles = tl.cdiv(K, BLOCK_K)
410
- num_tiles = num_pid_m * num_pid_n
411
-
412
- # Create on-device TMA descriptors
413
- a_desc = tl._experimental_make_tensor_descriptor(
414
- a_ptr,
415
- shape=[M, K],
416
- strides=[stride_a2, stride_a3],
417
- block_shape=[BLOCK_M, BLOCK_K],
418
- )
419
- b1_desc = tl._experimental_make_tensor_descriptor(
420
- b1_ptr,
421
- shape=[N, K],
422
- strides=[stride_bn, stride_bk],
423
- block_shape=[BLOCK_N, BLOCK_K],
424
- )
425
- b2_desc = tl._experimental_make_tensor_descriptor(
426
- b2_ptr,
427
- shape=[N, K],
428
- strides=[stride_bn, stride_bk],
429
- block_shape=[BLOCK_N, BLOCK_K],
430
- )
431
- b3_desc = tl._experimental_make_tensor_descriptor(
432
- b3_ptr,
433
- shape=[N, K],
434
- strides=[stride_bn, stride_bk],
435
- block_shape=[BLOCK_N, BLOCK_K],
436
- )
437
- b4_desc = tl._experimental_make_tensor_descriptor(
438
- b4_ptr,
439
- shape=[N, K],
440
- strides=[stride_bn, stride_bk],
441
- block_shape=[BLOCK_N, BLOCK_K],
442
- )
443
- b5_desc = tl._experimental_make_tensor_descriptor(
444
- b5_ptr,
445
- shape=[N, K],
446
- strides=[stride_bn, stride_bk],
447
- block_shape=[BLOCK_N, BLOCK_K],
448
- )
449
-
450
- # tile_id_c is used in the epilogue to break the dependency between
451
- # the prologue and the epilogue
452
- tile_id_c = start_pid - NUM_SMS
453
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
454
-
455
- # Persistent loop over tiles
456
- for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
457
- # Calculate PID for this tile using improved swizzling
458
- group_id = tile_id // num_pid_in_group
459
- first_pid_m = group_id * GROUP_SIZE_M
460
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
461
- pid_m = first_pid_m + (tile_id % group_size_m)
462
- pid_n = (tile_id % num_pid_in_group) // group_size_m
463
-
464
- # Calculate block offsets
465
- offs_am = pid_m * BLOCK_M
466
- offs_bn = pid_n * BLOCK_N
467
-
468
- # Initialize accumulators for all outputs
469
- accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
470
- accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
471
- accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
472
- accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
473
- accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
474
-
475
- # Main computation loop over K dimension
476
- for ki in range(k_tiles):
477
- offs_k = ki * BLOCK_K
478
- # Load blocks from A and all weight matrices using on-device TMA
479
- a = a_desc.load([offs_am, offs_k])
480
- b1 = b1_desc.load([offs_bn, offs_k])
481
- b2 = b2_desc.load([offs_bn, offs_k])
482
- b3 = b3_desc.load([offs_bn, offs_k])
483
- b4 = b4_desc.load([offs_bn, offs_k])
484
- b5 = b5_desc.load([offs_bn, offs_k])
485
-
486
- # Perform matrix multiplications using TF32
487
- accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True) # A @ B1.T
488
- accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True) # A @ B2.T
489
- accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True) # A @ B3.T
490
- accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True) # A @ B4.T
491
- accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True) # A @ B5.T
492
-
493
- # Store results using separate tile_id_c for epilogue
494
- tile_id_c += NUM_SMS
495
- group_id = tile_id_c // num_pid_in_group
496
- first_pid_m = group_id * GROUP_SIZE_M
497
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
498
- pid_m = first_pid_m + (tile_id_c % group_size_m)
499
- pid_n = (tile_id_c % num_pid_in_group) // group_size_m
500
-
501
- # Calculate output offsets and pointers
502
- offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
503
- offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
504
-
505
- # Create masks for bounds checking
506
- d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
507
-
508
- # Calculate pointer addresses using 4D strides
509
- # For C tensors: compute effective 2D strides from 4D strides
510
- # Output tensor is [B, I, J, N], flattened to [M, N] where M = B*I*J
511
- stride_cm = stride_c2 # Stride to next element in flattened M dimension
512
- stride_cn = stride_c3 # N is the innermost dimension
513
-
514
- # For D tensor: use separate D strides
515
- stride_dm = stride_d2 # Stride to next element in flattened M dimension
516
- stride_dn = stride_d3 # N is the innermost dimension
517
-
518
- off_c_batch = offs_cm // (seq_len * seq_len)
519
- off_c_sl1 = (offs_cm // seq_len) % seq_len
520
- off_c_sl2 = offs_cm % seq_len
521
- off_c_dim = offs_cn
522
-
523
- # TODO update the mask_c so we don't IMA
524
- c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
525
- # c_offsets = offs_cm[:, None] * stride_c2 + offs_cn[None, :] * stride_c3
526
- c_mask = d_mask
527
-
528
- c1_ptrs = c1_ptr + c_offsets
529
- c2_ptrs = c2_ptr + c_offsets
530
- # c1_ptrs = c1_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
531
- # c2_ptrs = c2_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
532
- d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]
533
-
534
- mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))
535
-
536
- # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
537
- mask_2d = mask[:, None] # Convert to [BLOCK_M, 1] then broadcast
538
- # Apply masking only to left_proj and right_proj results (C1, C2)
539
- accumulator1 = tl.where(mask_2d, accumulator1, 0)
540
- accumulator2 = tl.where(mask_2d, accumulator2, 0)
541
-
542
- # Apply sigmoid to gate values
543
- left_gate_sigmoid = triton_sigmoid(accumulator3)
544
- right_gate_sigmoid = triton_sigmoid(accumulator4)
545
- accumulator_d = triton_sigmoid(accumulator_d)
546
-
547
- # Apply elementwise multiplication with gated values
548
- # C1 = left * left_gate, C2 = right * right_gate
549
- accumulator1 = accumulator1 * left_gate_sigmoid # left * left_gate
550
- accumulator2 = accumulator2 * right_gate_sigmoid # right * right_gate
551
-
552
- # Convert to appropriate output dtype and store with normal tl.store
553
- c1 = accumulator1.to(c1_ptr.dtype.element_ty)
554
- c2 = accumulator2.to(c2_ptr.dtype.element_ty)
555
- d = accumulator_d.to(d_ptr.dtype.element_ty)
556
-
557
- tl.store(c1_ptrs, c1, mask=c_mask)
558
- tl.store(c2_ptrs, c2, mask=c_mask)
559
- tl.store(d_ptrs, d, mask=d_mask)
560
-
561
-
562
- if torch.cuda.get_device_capability()[0] not in [9, 10.2]:
563
- two_mm_kernel = triton.autotune(
564
- (two_mm_kernel_configs_wrapper())(), key=["M", "N", "K"]
565
- )(two_mm_kernel)
566
-
567
- return two_mm_kernel
568
-
569
-
570
- def two_mm(A, left_proj, right_proj, left_gate, right_gate, out_gate, mask):
571
- """
572
- Persistent matrix multiplication for all weight matrices using on-device TMA descriptors.
573
-
574
- Args:
575
- A: [..., K] tensor (arbitrary leading dimensions)
576
- left_proj: [N, K] matrix (will be transposed)
577
- right_proj: [N, K] matrix (will be transposed)
578
- left_gate: [N, K] left gate weight matrix
579
- right_gate: [N, K] right gate weight matrix
580
- out_gate: [N, K] output gate weight matrix
581
- mask: mask tensor
582
-
583
- Returns:
584
- (C1, C2, D): Tuple of result tensors [..., N] with same leading dims as A
585
- C1 = (A @ left_proj.T) * sigmoid(A @ left_gate.T) (masked)
586
- C2 = (A @ right_proj.T) * sigmoid(A @ right_gate.T) (masked)
587
- D = sigmoid(A @ out_gate.T) (unmasked)
588
- """
589
- # Check constraints
590
- assert A.shape[-1] == left_proj.shape[1] == right_proj.shape[1], "Incompatible K dimensions"
591
- assert A.dtype == left_proj.dtype == right_proj.dtype, "Incompatible dtypes"
592
-
593
- # Assert that all weight matrices have the same strides (same [N, K] shape)
594
- assert left_proj.stride() == right_proj.stride() == left_gate.stride() == right_gate.stride() == out_gate.stride(), \
595
- "All weight matrices must have identical strides"
596
-
597
- # Get dimensions
598
- original_shape = A.shape[:-1] # All dimensions except the last
599
- K = A.shape[-1]
600
- N = left_proj.shape[0]
601
- B, seq_len, _, _ = A.shape
602
- dtype = A.dtype
603
-
604
- # Flatten A to 2D for kernel processing
605
- A_2d = A.view(-1, K) # [M, K] where M is product of all leading dims
606
- M = A_2d.shape[0]
607
-
608
- # Get number of streaming multiprocessors
609
- NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
610
-
611
- # Launch persistent kernel with limited number of blocks
612
- grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)
613
-
614
- # Get original 4D strides for A and output tensors
615
- A_strides = A.stride() # (stride_0, stride_1, stride_2, stride_3)
616
-
617
- # Create output tensors with proper 4D shape to get correct strides
618
- output_shape = original_shape + (N,)
619
- # C1 = torch.empty(output_shape, device=A.device, dtype=dtype)
620
- # C2 = torch.empty(output_shape, device=A.device, dtype=dtype)
621
- C1 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
622
- C2 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
623
- D = torch.empty(output_shape, device=A.device, dtype=torch.float16)
624
-
625
- C_strides = C1.stride() # (stride_0, stride_1, stride_2, stride_3)
626
- D_strides = D.stride() # (stride_0, stride_1, stride_2, stride_3)
627
-
628
- # Use optimal configuration for B200/H100 or fallback to autotuning for other GPUs
629
- if torch.cuda.get_device_capability()[0] == 10:
630
- # Get optimal configuration for B200
631
- BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
632
- grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))
633
-
634
- two_mm_kernel_wrapper()[(grid_size,)](
635
- A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
636
- C1, C2, D, mask,
637
- M, N, K,
638
- *A_strides, # 4D strides for A
639
- left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
640
- *C_strides, # 4D strides for C
641
- seq_len,
642
- *D_strides, # 4D strides for D
643
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
644
- num_stages=num_stages, num_warps=num_warps
645
- )
646
- elif torch.cuda.get_device_capability()[0] == 9:
647
- # Get optimal configuration for H100
648
- BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
649
- grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))
650
-
651
- two_mm_kernel_wrapper()[(grid_size,)](
652
- A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
653
- C1, C2, D, mask,
654
- M, N, K,
655
- *A_strides, # 4D strides for A
656
- left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
657
- *C_strides, # 4D strides for C
658
- seq_len,
659
- *D_strides, # 4D strides for D
660
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
661
- num_stages=num_stages, num_warps=num_warps
662
- )
663
- else:
664
- # Use autotuning for other GPUs
665
- two_mm_kernel_wrapper()[grid](
666
- A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
667
- C1, C2, D, mask,
668
- M, N, K,
669
- *A_strides, # 4D strides for A
670
- left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
671
- *C_strides, # 4D strides for C
672
- seq_len,
673
- *D_strides, # 4D strides for D
674
- NUM_SMS=NUM_SMS
675
- )
676
-
677
- return C1, C2, D
678
-
679
-
680
- def second_layernorm_mul(inp, hidden_dim, weight, bias, mul_operand):
681
- ln = torch.nn.functional.layer_norm(inp, (hidden_dim,), eps=1e-5, weight=weight.to(inp.dtype), bias=bias.to(inp.dtype))
682
- out = ln * mul_operand
683
- return out
684
-
685
- '''
686
- @triton.autotune(
687
- [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=4, num_stages=3)],
688
- key=["R", "C"]
689
- )
690
- '''
691
- @triton.jit
692
- def layernorm_kernel_first(
693
- X,
694
- Y,
695
- Weight,
696
- Bias,
697
- R,
698
- C, # aka "dim"
699
- eps,
700
- ROW_BLOCK_SIZE: tl.constexpr,
701
- BLOCK_SIZE: tl.constexpr,
702
- ):
703
- row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
704
- cols = tl.arange(0, BLOCK_SIZE)
705
-
706
- mask_row = row < R
707
- mask_col = cols < C
708
-
709
- # Simple indexing for contiguous data
710
- x = tl.load(
711
- X + row[:, None] * C + cols[None, :],
712
- mask=mask_row[:, None] & mask_col[None, :],
713
- other=0.0
714
- ).to(tl.float32)
715
-
716
- weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
717
- bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)
718
-
719
- mean = tl.sum(x, axis=1) / C
720
- diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
721
- var = tl.sum(diff * diff, axis=1) / C
722
- rstd = 1 / tl.sqrt(var + eps)
723
-
724
- y_hat = (x - mean[:, None]) * rstd[:, None]
725
- y = y_hat * weight[None, :] + bias[None, :]
726
-
727
- tl.store(
728
- Y + row[:, None] * C + cols[None, :],
729
- y,
730
- mask=mask_row[:, None] & mask_col[None, :]
731
- )
732
-
733
-
734
- def get_optimal_config_ln(dim):
735
- config = None
736
- if torch.cuda.get_device_capability()[0] == 9:
737
- if (dim <= 256):
738
- config = (16, 1)
739
- elif dim <= 512:
740
- config = (16, 2)
741
- elif dim <= 1024:
742
- config = (16, 4)
743
-
744
- if not config:
745
- config = (16, 4)
746
- return config
747
-
748
-
749
- def triton_layernorm_first(x, weight, bias, eps=1e-5, num_warps=None, ROW_BLOCK_SIZE=None):
750
- B, seq_len, seq_len2, dim = x.shape
751
- assert(seq_len == seq_len2)
752
-
753
- R = B * seq_len * seq_len
754
- C = dim
755
-
756
- out = torch.empty_like(x, dtype=torch.float16)
757
-
758
- if not num_warps or not ROW_BLOCK_SIZE:
759
- ROW_BLOCK_SIZE, num_warps = get_optimal_config_ln(dim)
760
-
761
- BLOCK_SIZE = triton.next_power_of_2(C)
762
- assert(BLOCK_SIZE <= 1024)
763
-
764
- def grid(meta):
765
- return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
766
-
767
- layernorm_kernel_first[grid](
768
- x, out, weight, bias,
769
- R, C, eps,
770
- ROW_BLOCK_SIZE=ROW_BLOCK_SIZE,
771
- BLOCK_SIZE=BLOCK_SIZE,
772
- num_warps=num_warps,
773
- num_stages=3
774
- )
775
-
776
- return out
777
-
778
- '''
779
- def triton_layernorm_first(x, weight, bias, eps=1e-5):
780
- B, seq_len, seq_len2, dim = x.shape
781
- assert(seq_len == seq_len2)
782
-
783
- R = B * seq_len * seq_len
784
- C = dim
785
-
786
- out = torch.empty_like(x)
787
-
788
- BLOCK_SIZE = triton.next_power_of_2(C)
789
- assert(BLOCK_SIZE <= 1024)
790
-
791
- def grid(meta):
792
- return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
793
-
794
- layernorm_kernel_first[grid](
795
- x, out, weight, bias,
796
- R, C, eps,
797
- BLOCK_SIZE=BLOCK_SIZE
798
- )
799
-
800
- return out
801
- '''
802
-
803
-
804
- @triton.autotune(
805
- [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=1, num_stages=3)],
806
- key=[]
807
- )
808
- @triton.jit
809
- def layernorm_kernel_eltwise(
810
- X,
811
- Y,
812
- Weight,
813
- Bias,
814
- OutGate,
815
- seq_len,
816
- stride_batch,
817
- stride_dim,
818
- R,
819
- C, # aka "dim"
820
- eps,
821
- ROW_BLOCK_SIZE: tl.constexpr,
822
- BLOCK_SIZE: tl.constexpr,
823
- ):
824
- row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
825
- cols = tl.arange(0, BLOCK_SIZE)
826
-
827
- # Calculate base pointer for this batch of rows
828
- tl.device_assert(seq_len*seq_len % ROW_BLOCK_SIZE == 0)
829
- # batch_offset = (row // (stride_seq1 // stride_dim)) * stride_batch
830
- batch = tl.program_id(0) * ROW_BLOCK_SIZE // (seq_len * seq_len)
831
- seqs_off = row % (seq_len * seq_len) # TODO is this going to prevent vectorization
832
-
833
- off_r = batch * stride_batch + seqs_off
834
- off_c = cols * stride_dim
835
-
836
- mask_row = row < R
837
- mask_col = cols < C
838
-
839
- out_gate = tl.load(
840
- OutGate + row[:, None] * C + cols[None, :],
841
- mask = mask_row[:, None] & mask_col[None, :],
842
- )
843
-
844
- x = tl.load(
845
- X + off_r[:, None] + off_c[None, :],
846
- mask=mask_row[:, None] & mask_col[None, :],
847
- other=0.0
848
- ).to(tl.float32)
849
-
850
- weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
851
- bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)
852
-
853
- mean = tl.sum(x, axis=1) / C
854
- diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
855
- var = tl.sum(diff * diff, axis=1) / C
856
- rstd = 1 / tl.sqrt(var + eps)
857
-
858
- y_hat = (x - mean[:, None]) * rstd[:, None]
859
- y = y_hat * weight[None, :] + bias[None, :]
860
-
861
- tl.store(
862
- Y + row[:, None] * C + cols[None, :],
863
- y * out_gate,
864
- mask=mask_row[:, None] & mask_col[None, :]
865
- )
866
-
867
-
868
- def triton_layernorm_eltwise(x, weight, bias, out_gate, eps=1e-5):
869
- B, seq_len, seq_len2, dim = x.shape
870
- assert(seq_len == seq_len2)
871
- R = B * seq_len * seq_len
872
- assert(x.stride(3) == seq_len*seq_len)
873
- assert(out_gate.is_contiguous())
874
- C = dim
875
-
876
- out = torch.empty_like(out_gate, dtype=torch.float32)
877
-
878
- BLOCK_SIZE = triton.next_power_of_2(C)
879
- assert(BLOCK_SIZE == 128)
880
-
881
- def grid(meta):
882
- return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
883
-
884
- layernorm_kernel_eltwise[grid](
885
- x, out, weight, bias, out_gate,
886
- seq_len,
887
- x.stride(0), x.stride(3),
888
- R, C, eps,
889
- BLOCK_SIZE=BLOCK_SIZE
890
- )
891
-
892
- return out
893
-
894
-
895
- def kernel_global(data: input_t) -> output_t:
896
- """
897
- Reference implementation of TriMul using PyTorch.
898
-
899
- Args:
900
- data: Tuple of (input: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
901
- - input: Input tensor of shape [batch_size, seq_len, seq_len, dim]
902
- - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
903
- - weights: Dictionary containing model weights
904
- - config: Dictionary containing model configuration parameters
905
- """
906
- input_tensor, mask, weights, config = data
907
-
908
- left_proj_weight = weights["left_proj.weight"].to(torch.float16)
909
- right_proj_weight = weights["right_proj.weight"].to(torch.float16)
910
- left_gate_weight = weights["left_gate.weight"].to(torch.float16)
911
- right_gate_weight = weights["right_gate.weight"].to(torch.float16)
912
- out_gate_weight = weights["out_gate.weight"].to(torch.float16)
913
-
914
- hidden_dim = config["hidden_dim"]
915
- # trimul = TriMul(dim=config["dim"], hidden_dim=config["hidden_dim"]).to(input_tensor.device)
916
-
917
- x = input_tensor
918
-
919
- batch_size, seq_len, _, dim = x.shape
920
-
921
- x = triton_layernorm_first(x, weights['norm.weight'], weights['norm.bias'])
922
- # x = torch.nn.functional.layer_norm(x, (dim,), eps=1e-5, weight=weights['norm.weight'], bias=weights['norm.bias'])
923
-
924
- left, right, out_gate = two_mm(x, left_proj_weight, right_proj_weight, left_gate_weight, right_gate_weight, out_gate_weight, mask)
925
- # left = torch.nn.functional.linear(x, weights['left_proj.weight'].to(torch.float16))
926
- # right = torch.nn.functional.linear(x, weights['right_proj.weight'].to(torch.float16))
927
-
928
- # left = left * mask.unsqueeze(-1)
929
- # right = right * mask.unsqueeze(-1)
930
-
931
- '''
932
- left = left.to(torch.float32)
933
- right = right.to(torch.float32)
934
- x = x.to(torch.float32)
935
-
936
- left_gate = left_gate.sigmoid()
937
- right_gate = right_gate.sigmoid()
938
- out_gate = out_gate.sigmoid()
939
- '''
940
-
941
- # Elementwise multiplication now handled in kernel
942
- # left = left * left_gate
943
- # right = right * right_gate
944
-
945
- # out = einsum('... i k d, ... j k d -> ... i j d', left, right)
946
- out = torch.bmm(left.permute(0, 3, 1, 2).view(-1, left.shape[1], left.shape[2]), right.permute(0, 3, 2, 1).view(-1, right.shape[2], right.shape[1]))
947
- out = out.view(batch_size, hidden_dim, seq_len, seq_len).permute(0, 2, 3, 1)
948
-
949
- # out = torch.compile(second_layernorm_mul, dynamic=False)(out, hidden_dim, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate)
950
- out = triton_layernorm_eltwise(out, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate)
951
- # out = torch.nn.functional.layer_norm(out, (hidden_dim,), eps=1e-5, weight=weights['to_out_norm.weight'].to(out.dtype), bias=weights['to_out_norm.bias'].to(out.dtype))
952
- # out = out * out_gate
953
- return torch.nn.functional.linear(out, weights['to_out.weight'])
954
-
955
- '''
956
- # Fill in the given weights of the model
957
- trimul.norm.weight = nn.Parameter(weights['norm.weight'])
958
- trimul.norm.bias = nn.Parameter(weights['norm.bias'])
959
- trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight'])
960
- trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight'])
961
- trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight'])
962
- trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight'])
963
- trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight'])
964
- trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight'])
965
- trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias'])
966
- trimul.to_out.weight = nn.Parameter(weights['to_out.weight'])
967
-
968
- output = trimul(input_tensor, mask)
969
-
970
- return output
971
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/trimul_gpumode/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- import ctypes
2
- import sys
3
-
4
- import importlib
5
- from pathlib import Path
6
- from types import ModuleType
7
-
8
- def _import_from_path(file_path: Path) -> ModuleType:
9
- # We cannot use the module name as-is, after adding it to `sys.modules`,
10
- # it would also be used for other imports. So, we make a module name that
11
- # depends on the path for it to be unique using the hex-encoded hash of
12
- # the path.
13
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
- module_name = path_hash
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None:
17
- raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- if module is None:
20
- raise ImportError(f"Cannot load module {module_name} from spec")
21
- sys.modules[module_name] = module
22
- spec.loader.exec_module(module) # type: ignore
23
- return module
24
-
25
-
26
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/trimul_mi300.py DELETED
@@ -1,524 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
-
9
- @triton.autotune(
10
- configs=[
11
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
12
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
13
-
14
- # Configurations with larger block sizes for better data reuse
15
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),
16
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=2),
17
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),
18
-
19
- # Configurations with deeper K dimension
20
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
21
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
22
-
23
- # More extreme configurations to test the limits
24
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
25
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=2),
26
-
27
- # Configurations with fewer warps
28
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
29
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=2),
30
-
31
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
32
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
33
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
34
- ],
35
- key=['M', 'N', 'K'],
36
- )
37
- @triton.jit
38
- def fused_ln_dual_matmul_kernel(
39
- # Pointers (9)
40
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
41
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
42
- # Metadata (5)
43
- M, H, K, s1, s2,
44
- # Strides (16)
45
- stride_x_m, stride_x_k,
46
- stride_w4_k, stride_w4_n,
47
- stride_wog_k, stride_wog_n,
48
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
49
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
50
- stride_og_m, stride_og_h,
51
- stride_mask_m, stride_mask_h,
52
- # Constexpr (from decorator and kwargs)
53
- LN_EPS: tl.constexpr,
54
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
55
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
56
- ):
57
- # --- PID Mapping: Based on the LARGER 4*H problem ---
58
- pid = tl.program_id(axis=0)
59
- N_4way = 4 * H
60
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
61
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
62
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
63
- group_id = pid // num_pid_in_group
64
- first_pid_m = group_id * GROUP_SIZE_M
65
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
66
- pid_m = first_pid_m + (pid % group_size_m)
67
- pid_n = (pid % num_pid_in_group) // group_size_m
68
-
69
- # --- SHARED LayerNorm calculation (done only ONCE) ---
70
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
71
- m_mask = offs_m < M
72
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
73
-
74
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
75
- for k_offset in range(0, K, BLOCK_SIZE_K):
76
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
77
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
78
- k_mask = (k_offset + k_chunk_offs) < K
79
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
80
- mean += tl.sum(x_chunk, axis=1)
81
- mean /= K
82
-
83
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
84
- for k_offset in range(0, K, BLOCK_SIZE_K):
85
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
86
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
87
- k_mask = (k_offset + k_chunk_offs) < K
88
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
89
- x_centered = x_chunk - mean[:, None]
90
- var += tl.sum(x_centered * x_centered, axis=1)
91
- var /= K
92
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
93
-
94
- # --- Matmul Loop 1: For the 4-Way Projections ---
95
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
96
- offs_k = tl.arange(0, BLOCK_SIZE_K)
97
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
98
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
99
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
100
-
101
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
102
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
103
- k_block_start = k * BLOCK_SIZE_K;
104
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
105
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
106
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
107
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
108
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
109
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
110
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
111
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
112
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
113
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
114
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
115
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
116
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
117
-
118
- #Some threads should calclate out_gate
119
- if pid_n * BLOCK_SIZE_N < H:
120
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
121
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
122
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
123
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
124
- accumulator_og += tl.dot(x_norm_tile, w_tile)
125
-
126
- if pid_n * BLOCK_SIZE_N < H:
127
- og_out = tl.sigmoid(accumulator_og)
128
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
129
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
130
- tl.store(outg_ptrs, og_out, mask=og_mask)
131
-
132
- # --- Fusion Logic for 4-Way Part ---
133
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
134
- role_idx = tl.arange(0, 4)[None, None, :]
135
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
136
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
137
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
138
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
139
-
140
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
141
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
142
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
143
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
144
-
145
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
146
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
147
-
148
- s1s2 = s1 * s2
149
- offs_b = offs_m // s1s2
150
- offs_s1 = (offs_m % s1s2) // s2
151
- offs_s2 = offs_m % s2
152
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
153
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
154
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
155
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
156
-
157
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
158
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
159
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
160
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1) # s2 offset uses s2 stride, s1 offset uses s1 stride
161
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
162
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
163
-
164
- @triton.autotune(
165
- configs=[
166
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
167
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
168
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
169
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
170
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
171
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
172
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
173
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
174
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
175
- ],
176
- key=['s1', 's2', 'H'],
177
- )
178
- @triton.jit
179
- def bmm_coalesced_kernel(
180
- # Pointers
181
- Left_ptr, Right_ptr, Out_ptr,
182
- # Dimensions
183
- bs, s1, s2, H,
184
- # Strides
185
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
186
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
187
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
188
- # Kernel parameters
189
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
190
- GROUP_SIZE_M: tl.constexpr,
191
- ):
192
- # Grid and program IDs
193
- pid = tl.program_id(axis=0)
194
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
195
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
196
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
197
- group_id = pid // num_pid_in_group
198
- first_pid_m = group_id * GROUP_SIZE_M
199
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
200
- pid_m = first_pid_m + (pid % group_size_m)
201
- pid_n = (pid % num_pid_in_group) // group_size_m
202
-
203
- pid_bh = tl.program_id(axis=1)
204
- pid_b = pid_bh // H
205
- pid_h = pid_bh % H
206
-
207
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
208
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
209
- offs_k = tl.arange(0, BLOCK_SIZE_K)
210
-
211
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
212
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
213
-
214
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
215
-
216
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
217
- k_start = k * BLOCK_SIZE_K
218
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
219
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
220
-
221
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
222
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
223
-
224
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
225
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
226
-
227
- accumulator += tl.dot(a, b)
228
-
229
- # --- Coalesced Write ---
230
- # Write to a standard (bs, H, s1, s1) layout
231
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
232
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
233
-
234
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
235
- tl.store(out_ptrs, accumulator, mask=c_mask)
236
-
237
- @triton.autotune(
238
- configs=[
239
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
240
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
241
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
242
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
243
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
244
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
245
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
246
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
247
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
248
- ],
249
- key=['H', 'D'],
250
- )
251
- @triton.jit
252
- def fused_final_kernel(
253
- # Pointers
254
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
255
- # Metadata
256
- M, H, D, s1, # M_gate = bs*s1*s2
257
- # Strides
258
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
259
- stride_gate_m, stride_gate_h,
260
- stride_proj_d, stride_proj_h,
261
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
262
- # Constants
263
- LN_EPS: tl.constexpr,
264
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
265
- GROUP_SIZE_M: tl.constexpr,
266
- ):
267
- # --- Grid and PID Setup for Matmul ---
268
- pid = tl.program_id(axis=0)
269
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
270
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
271
-
272
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
273
- group_id = pid // num_pid_in_group
274
- first_pid_m = group_id * GROUP_SIZE_M
275
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
276
- pid_m = first_pid_m + (pid % group_size_m)
277
- pid_n = (pid % num_pid_in_group) // group_size_m
278
-
279
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
280
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
281
- m_mask = offs_m < M
282
-
283
- # Decompose M back to (b, r, c) for reordering lookups
284
- s1s1 = s1 * s1
285
- b = offs_m // s1s1
286
- r = (offs_m % s1s1) // s1
287
- c = offs_m % s1
288
-
289
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
290
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
291
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
292
-
293
- for k_offset in range(0, H, BLOCK_SIZE_K):
294
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
295
- k_mask = offs_k < H
296
-
297
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
298
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
299
-
300
- # Accumulate sum and sum of squares in one pass
301
- sum_x += tl.sum(in_chunk, axis=1)
302
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
303
-
304
- # Finalize statistics
305
- mean = sum_x / H
306
- var = (sum_x2 / H) - (mean * mean)
307
- rstd = tl.math.rsqrt(var + LN_EPS)
308
-
309
- # --- Pass 3: Fused Gating and Matmul ---
310
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
311
- for k_offset in range(0, H, BLOCK_SIZE_K):
312
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
313
- k_mask = offs_k < H
314
-
315
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
316
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
317
- a_norm = (a - mean[:, None]) * rstd[:, None]
318
-
319
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
320
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
321
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
322
-
323
- proj_ptrs = ProjW_ptr + \
324
- offs_n[None, :] * stride_proj_d + \
325
- offs_k[:, None] * stride_proj_h
326
-
327
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
328
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
329
- a_gated = a_norm * gate
330
-
331
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
332
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
333
-
334
- # --- Store Final Output ---
335
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
336
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
337
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
338
-
339
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
340
-
341
- def compiledtrimul_fused_interleaved(
342
- x: torch.Tensor,
343
- mask_mh: torch.Tensor,
344
- norm_weight: torch.Tensor,
345
- norm_bias: torch.Tensor,
346
- W_4way: torch.Tensor, # Use the new weight matrices
347
- W_og: torch.Tensor,
348
- to_out_norm_weight: torch.Tensor,
349
- to_out_norm_bias: torch.Tensor,
350
- to_out_weight: torch.Tensor,
351
- h: int,
352
- ):
353
- bs, s1, s2, d = x.shape
354
- M, K, H = bs * s1 * s2, x.shape[-1], h
355
- x_flat = x.view(M, K)
356
-
357
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
358
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
359
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
360
-
361
- # The grid is launched for the larger 4*H problem
362
- N_4way = 4 * H
363
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
364
- fused_ln_dual_matmul_kernel[grid](
365
- # Pointers (9)
366
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
367
- left_final, right_final_t, og_mh,
368
- # Metadata (5) - M, H, K, s1, s2
369
- M, H, K, s1, s2,
370
- # Strides (16)
371
- x_flat.stride(0), x_flat.stride(1),
372
- W_4way.stride(0), W_4way.stride(1),
373
- W_og.stride(0), W_og.stride(1),
374
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
375
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
376
- og_mh.stride(0), og_mh.stride(1),
377
- mask_mh.stride(0), mask_mh.stride(1),
378
- # Constexpr (1)
379
- LN_EPS=1e-5
380
- )
381
-
382
- bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
383
-
384
- grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
385
- bmm_coalesced_kernel[grid_bmm](
386
- left_final, right_final_t, bmm_out_tmp,
387
- bs, s1, s2, H,
388
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
389
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
390
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
391
- )
392
-
393
- # --- Kernel 3: Fully Fused Final Stage ---
394
- final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
395
-
396
- grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
397
- fused_final_kernel[grid_final](
398
- # Pointers
399
- bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
400
- # Metadata
401
- M, H, d, s1,
402
- # Strides
403
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
404
- og_mh.stride(0), og_mh.stride(1),
405
- to_out_weight.stride(0), to_out_weight.stride(1), # Use strides of the corrected tensor
406
- final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
407
- # Constants
408
- LN_EPS=1e-5,
409
- )
410
-
411
- return final_out
412
-
413
- def pack_w_4way_efficient(weights):
414
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
415
- WL = weights['left_proj.weight']
416
- WLG = weights['left_gate.weight']
417
- WR = weights['right_proj.weight']
418
- WRG = weights['right_gate.weight']
419
- H, K = WL.shape
420
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2)
421
- ws = ws.contiguous().view(4 * H, K)
422
- return ws.t().to(torch.float16)
423
-
424
- def get_w_og(weights):
425
- """ Gets the transposed [K, H] out_gate weight matrix. """
426
- WOG = weights['out_gate.weight']
427
- return WOG.t().to(torch.float16)
428
-
429
- def compiledtrimul(
430
- x: torch.Tensor,
431
- mask: torch.Tensor,
432
- norm_weight: torch.Tensor,
433
- norm_bias: torch.Tensor,
434
- w_concat: torch.Tensor,
435
- to_out_norm_weight: torch.Tensor,
436
- to_out_norm_bias: torch.Tensor,
437
- to_out_weight: torch.Tensor,
438
- h: int
439
- ) -> torch.Tensor:
440
- """
441
- A barebones, compiled PyTorch function for the TriMul logic.
442
- """
443
- bs, s1, s2, d = x.shape
444
-
445
- # Initial LayerNorm
446
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
447
- # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
448
- all_projections = torch.mm(x_norm, w_concat)
449
-
450
- # Split back into individual projections
451
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
452
-
453
- # Apply mask and gates
454
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
455
- left = left * mask_expanded * torch.sigmoid(lg)
456
- right = right * mask_expanded * torch.sigmoid(rg)
457
- out_gate = torch.sigmoid(og)
458
-
459
- # Reshape for einsum
460
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
461
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
462
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
463
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
464
-
465
- # Apply layer norm and final gating
466
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
467
- gated = normed * out_gate
468
-
469
- # Final projection
470
- final_out_flat = gated @ to_out_weight.t()
471
- final_out = final_out_flat.view(bs, s1, s2, d)
472
-
473
- return final_out
474
-
475
- def small_kernel_pt_path(data):
476
- input_tensor, mask, weights, config = data
477
- w_concat = torch.cat([
478
- weights['left_proj.weight'],
479
- weights['right_proj.weight'],
480
- weights['left_gate.weight'],
481
- weights['right_gate.weight'],
482
- weights['out_gate.weight']
483
- ], dim=0).t().contiguous().to(torch.float16)
484
- # Call the compiled function with prepared weights
485
- output = compiledtrimul(
486
- x=input_tensor.to(torch.float32),
487
- mask=mask.unsqueeze(-1),
488
- norm_weight=weights['norm.weight'].to(torch.float32),
489
- norm_bias=weights['norm.bias'].to(torch.float32),
490
- w_concat=w_concat,
491
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
492
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
493
- to_out_weight=weights['to_out.weight'].to(torch.float16),
494
- h=config["hidden_dim"]
495
- )
496
- return output
497
-
498
- def kernel_mi300(data):
499
- input_tensor, mask, weights, config = data
500
- bs, s1, s2, d = input_tensor.shape
501
-
502
- if s1 < 100:
503
- return small_kernel_pt_path(data)
504
-
505
- H = config["hidden_dim"]
506
-
507
- W_4way = pack_w_4way_efficient(weights)
508
- W_og = get_w_og(weights)
509
-
510
- M = bs * s1 * s2
511
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16) #move into kernel possibly
512
-
513
- return compiledtrimul_fused_interleaved(
514
- x=input_tensor.to(torch.float32),
515
- mask_mh=mask_mh,
516
- norm_weight=weights['norm.weight'].to(torch.float32),
517
- norm_bias=weights['norm.bias'].to(torch.float32),
518
- W_4way=W_4way, # Pass the new 4-way matrix
519
- W_og=W_og, # Pass the new out_gate matrix
520
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
521
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
522
- to_out_weight=weights['to_out.weight'].to(torch.float16),
523
- h=H,
524
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/triton_a100.py DELETED
@@ -1,405 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- # Set PyTorch flags for performance
7
- torch.backends.cuda.matmul.allow_tf32 = True
8
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
9
-
10
- @triton.jit
11
- def fused_ln_dual_matmul_kernel(
12
- # Pointers (9)
13
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
14
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
15
- # Metadata (5)
16
- M, H, K, s1, s2,
17
- # Strides (16)
18
- stride_x_m, stride_x_k,
19
- stride_w4_k, stride_w4_n,
20
- stride_wog_k, stride_wog_n,
21
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
22
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
23
- stride_og_m, stride_og_h,
24
- stride_mask_m, stride_mask_h,
25
- # Constexpr (now passed as arguments from the host)
26
- LN_EPS: tl.constexpr,
27
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
28
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
29
- ):
30
- # --- PID Mapping: Based on the LARGER 4*H problem ---
31
- pid = tl.program_id(axis=0)
32
- N_4way = 4 * H
33
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
34
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
35
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
36
- group_id = pid // num_pid_in_group
37
- first_pid_m = group_id * GROUP_SIZE_M
38
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
39
- pid_m = first_pid_m + (pid % group_size_m)
40
- pid_n = (pid % num_pid_in_group) // group_size_m
41
-
42
- # --- SHARED LayerNorm calculation (done only ONCE) ---
43
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
44
- m_mask = offs_m < M
45
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
46
-
47
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
48
- for k_offset in range(0, K, BLOCK_SIZE_K):
49
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
50
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
51
- k_mask = (k_offset + k_chunk_offs) < K
52
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
53
- mean += tl.sum(x_chunk, axis=1)
54
- mean /= K
55
-
56
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
57
- for k_offset in range(0, K, BLOCK_SIZE_K):
58
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
59
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
60
- k_mask = (k_offset + k_chunk_offs) < K
61
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
62
- x_centered = x_chunk - mean[:, None]
63
- var += tl.sum(x_centered * x_centered, axis=1)
64
- var /= K
65
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
66
-
67
- # --- Matmul Loop 1: For the 4-Way Projections ---
68
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
69
- offs_k = tl.arange(0, BLOCK_SIZE_K)
70
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
71
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
72
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
73
-
74
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
75
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
76
- k_block_start = k * BLOCK_SIZE_K;
77
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
78
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
79
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
80
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
81
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
82
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
83
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
84
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
85
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
86
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
87
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
88
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
89
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
90
-
91
- if pid_n * BLOCK_SIZE_N < H:
92
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
93
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
94
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
95
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
96
- accumulator_og += tl.dot(x_norm_tile, w_tile)
97
-
98
- if pid_n * BLOCK_SIZE_N < H:
99
- og_out = tl.sigmoid(accumulator_og)
100
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
101
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
102
- tl.store(outg_ptrs, og_out, mask=og_mask)
103
-
104
- # --- Fusion Logic for 4-Way Part ---
105
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
106
- role_idx = tl.arange(0, 4)[None, None, :]
107
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
108
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
109
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
110
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
111
-
112
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
113
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
114
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
115
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
116
-
117
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
118
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
119
-
120
- s1s2 = s1 * s2
121
- offs_b = offs_m // s1s2
122
- offs_s1 = (offs_m % s1s2) // s2
123
- offs_s2 = offs_m % s2
124
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
125
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
126
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
127
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
128
-
129
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
130
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
131
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
132
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
133
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
134
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
-
136
- @triton.jit
137
- def bmm_coalesced_kernel(
138
- # Pointers
139
- Left_ptr, Right_ptr, Out_ptr,
140
- # Dimensions
141
- bs, s1, s2, H,
142
- # Strides
143
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
144
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
145
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
146
- # Kernel parameters
147
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
148
- GROUP_SIZE_M: tl.constexpr,
149
- ):
150
- pid = tl.program_id(axis=0)
151
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
152
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
153
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
154
- group_id = pid // num_pid_in_group
155
- first_pid_m = group_id * GROUP_SIZE_M
156
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
157
- pid_m = first_pid_m + (pid % group_size_m)
158
- pid_n = (pid % num_pid_in_group) // group_size_m
159
-
160
- pid_bh = tl.program_id(axis=1)
161
- pid_b = pid_bh // H
162
- pid_h = pid_bh % H
163
-
164
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
165
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
166
- offs_k = tl.arange(0, BLOCK_SIZE_K)
167
-
168
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
169
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
170
-
171
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
172
-
173
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
174
- k_start = k * BLOCK_SIZE_K
175
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
176
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
177
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
178
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
179
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
180
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
181
- accumulator += tl.dot(a, b)
182
-
183
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
184
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
185
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
186
- tl.store(out_ptrs, accumulator, mask=c_mask)
187
-
188
- @triton.jit
189
- def fused_final_kernel(
190
- # Pointers
191
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
192
- # Metadata
193
- M, H, D, s1,
194
- # Strides
195
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
196
- stride_gate_m, stride_gate_h,
197
- stride_proj_d, stride_proj_h,
198
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
199
- # Constants
200
- LN_EPS: tl.constexpr,
201
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
202
- GROUP_SIZE_M: tl.constexpr,
203
- ):
204
- pid = tl.program_id(axis=0)
205
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
206
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
207
-
208
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
209
- group_id = pid // num_pid_in_group
210
- first_pid_m = group_id * GROUP_SIZE_M
211
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
212
- pid_m = first_pid_m + (pid % group_size_m)
213
- pid_n = (pid % num_pid_in_group) // group_size_m
214
-
215
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
216
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
217
- m_mask = offs_m < M
218
-
219
- s1s1 = s1 * s1
220
- b = offs_m // s1s1
221
- r = (offs_m % s1s1) // s1
222
- c = offs_m % s1
223
-
224
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
225
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
226
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
227
-
228
- for k_offset in range(0, H, BLOCK_SIZE_K):
229
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
230
- k_mask = offs_k < H
231
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
232
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
233
- sum_x += tl.sum(in_chunk, axis=1)
234
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
235
-
236
- mean = sum_x / H
237
- var = (sum_x2 / H) - (mean * mean)
238
- rstd = tl.math.rsqrt(var + LN_EPS)
239
-
240
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
241
- for k_offset in range(0, H, BLOCK_SIZE_K):
242
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
243
- k_mask = offs_k < H
244
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
245
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
246
- a_norm = (a - mean[:, None]) * rstd[:, None]
247
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
248
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
249
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
250
- proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
251
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
252
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
253
- a_gated = a_norm * gate
254
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
255
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
256
-
257
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
258
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
259
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
260
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
261
-
262
- def compiledtrimul_fused_interleaved_final(
263
- x: torch.Tensor,
264
- mask_mh: torch.Tensor,
265
- norm_weight: torch.Tensor,
266
- norm_bias: torch.Tensor,
267
- W_4way: torch.Tensor,
268
- W_og: torch.Tensor,
269
- to_out_norm_weight: torch.Tensor,
270
- to_out_norm_bias: torch.Tensor,
271
- to_out_weight: torch.Tensor,
272
- h: int,
273
- ):
274
- bs, s1, s2, d = x.shape
275
- M, K, H = bs * s1 * s2, x.shape[-1], h
276
- x_flat = x.view(M, K)
277
-
278
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
279
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
280
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
281
-
282
- # --- Kernel 1: Fused LN + Dual Matmul ---
283
- N_4way = 4 * H
284
- # Hardcoded A100 best config: M128-N128-K32-GM8-HC32-W8-S2
285
- config_k1 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
286
- grid_k1 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
287
-
288
- fused_ln_dual_matmul_kernel[grid_k1](
289
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
290
- left_final, right_final_t, og_mh,
291
- M, H, K, s1, s2,
292
- x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
293
- W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
294
- left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
295
- right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
296
- mask_mh.stride(0), mask_mh.stride(1),
297
- LN_EPS=1e-5, **config_k1, num_warps=8, num_stages=2
298
- )
299
-
300
- # --- Kernel 2: Batched Matrix Multiplication ---
301
- bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
302
- # Hardcoded A100 best config: M128-N64-K32-GM8-W4-S3
303
- config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
304
- grid_k2 = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
305
-
306
- bmm_coalesced_kernel[grid_k2](
307
- left_final, right_final_t, bmm_out_tmp,
308
- bs, s1, s2, H,
309
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
310
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
311
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
312
- **config_k2, num_warps=4, num_stages=3
313
- )
314
-
315
- # --- Kernel 3: Fully Fused Final Stage ---
316
- final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
317
- # Hardcoded A100 best config: M32-N128-K32-GM8-W4-S3
318
- config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
319
- grid_k3 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
320
-
321
- fused_final_kernel[grid_k3](
322
- bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
323
- M, H, d, s1,
324
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
325
- og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
326
- final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
327
- LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
328
- )
329
- return final_out
330
-
331
- def pack_w_4way_efficient(weights):
332
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
333
- WL, WLG, WR, WRG = (weights[k] for k in ['left_proj.weight', 'left_gate.weight', 'right_proj.weight', 'right_gate.weight'])
334
- H, K = WL.shape
335
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2).contiguous().view(4 * H, K)
336
- return ws.t().to(torch.float16)
337
-
338
- def get_w_og(weights):
339
- """ Gets the transposed [K, H] out_gate weight matrix. """
340
- return weights['out_gate.weight'].t().to(torch.float16)
341
-
342
- @torch.compile()
343
- def compiledtrimul(
344
- x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
345
- w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
346
- to_out_weight: torch.Tensor, h: int
347
- ) -> torch.Tensor:
348
- bs, s1, s2, d = x.shape
349
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
350
- all_projections = torch.mm(x_norm, w_concat)
351
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
352
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
353
- left = left * mask_expanded * torch.sigmoid(lg)
354
- right = right * mask_expanded * torch.sigmoid(rg)
355
- out_gate = torch.sigmoid(og)
356
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
357
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
358
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
359
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
360
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
361
- gated = normed * out_gate
362
- final_out_flat = gated @ to_out_weight.t()
363
- return final_out_flat.view(bs, s1, s1, d)
364
-
365
- def small_kernel_pt_path(data):
366
- input_tensor, mask, weights, config = data
367
- w_concat = torch.cat([
368
- weights['left_proj.weight'], weights['right_proj.weight'], weights['left_gate.weight'],
369
- weights['right_gate.weight'], weights['out_gate.weight']
370
- ], dim=0).t().contiguous().to(torch.float16)
371
- return compiledtrimul(
372
- x=input_tensor.to(torch.float32), mask=mask.unsqueeze(-1),
373
- norm_weight=weights['norm.weight'].to(torch.float32),
374
- norm_bias=weights['norm.bias'].to(torch.float32), w_concat=w_concat,
375
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
376
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
377
- to_out_weight=weights['to_out.weight'].to(torch.float16),
378
- h=config["hidden_dim"]
379
- )
380
-
381
- def kernel_a100(data):
382
- input_tensor, mask, weights, config = data
383
- bs, s1, s2, d = input_tensor.shape
384
-
385
- if s1 < 512: # Adjusted threshold based on observed BMM configs
386
- return small_kernel_pt_path(data)
387
-
388
- H = config["hidden_dim"]
389
- W_4way = pack_w_4way_efficient(weights)
390
- W_og = get_w_og(weights)
391
- M = bs * s1 * s2
392
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16)
393
-
394
- return compiledtrimul_fused_interleaved_final(
395
- x=input_tensor.to(torch.float32),
396
- mask_mh=mask_mh,
397
- norm_weight=weights['norm.weight'].to(torch.float32),
398
- norm_bias=weights['norm.bias'].to(torch.float32),
399
- W_4way=W_4way,
400
- W_og=W_og,
401
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
402
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
403
- to_out_weight=weights['to_out.weight'].to(torch.float16),
404
- h=H,
405
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/triton_b200.py DELETED
@@ -1,411 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
-
9
- @triton.jit
10
- def fused_ln_dual_matmul_kernel(
11
- # Pointers (9)
12
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
13
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
14
- # Metadata (5)
15
- M, H, K, s1, s2,
16
- # Strides (16)
17
- stride_x_m, stride_x_k,
18
- stride_w4_k, stride_w4_n,
19
- stride_wog_k, stride_wog_n,
20
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
21
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
22
- stride_og_m, stride_og_h,
23
- stride_mask_m, stride_mask_h,
24
- # Constexpr (now passed as arguments from the host)
25
- LN_EPS: tl.constexpr,
26
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
27
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
28
- ):
29
- # --- PID Mapping: Based on the LARGER 4*H problem ---
30
- pid = tl.program_id(axis=0)
31
- N_4way = 4 * H
32
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
33
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
34
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
35
- group_id = pid // num_pid_in_group
36
- first_pid_m = group_id * GROUP_SIZE_M
37
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
38
- pid_m = first_pid_m + (pid % group_size_m)
39
- pid_n = (pid % num_pid_in_group) // group_size_m
40
-
41
- # --- SHARED LayerNorm calculation (done only ONCE) ---
42
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
43
- m_mask = offs_m < M
44
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
45
-
46
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
47
- for k_offset in range(0, K, BLOCK_SIZE_K):
48
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
49
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
50
- k_mask = (k_offset + k_chunk_offs) < K
51
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
52
- mean += tl.sum(x_chunk, axis=1)
53
- mean /= K
54
-
55
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
56
- for k_offset in range(0, K, BLOCK_SIZE_K):
57
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
58
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
59
- k_mask = (k_offset + k_chunk_offs) < K
60
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
61
- x_centered = x_chunk - mean[:, None]
62
- var += tl.sum(x_centered * x_centered, axis=1)
63
- var /= K
64
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
65
-
66
- # --- Matmul Loop 1: For the 4-Way Projections ---
67
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
68
- offs_k = tl.arange(0, BLOCK_SIZE_K)
69
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
70
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
71
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
72
-
73
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
74
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
75
- k_block_start = k * BLOCK_SIZE_K;
76
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
77
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
78
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
79
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
80
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
81
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
82
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
83
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
84
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
85
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
86
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
87
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
88
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
89
-
90
- #Some threads should calclate out_gate
91
- if pid_n * BLOCK_SIZE_N < H:
92
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
93
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
94
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
95
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
96
- accumulator_og += tl.dot(x_norm_tile, w_tile)
97
-
98
- if pid_n * BLOCK_SIZE_N < H:
99
- og_out = tl.sigmoid(accumulator_og)
100
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
101
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
102
- tl.store(outg_ptrs, og_out, mask=og_mask)
103
-
104
- # --- Fusion Logic for 4-Way Part ---
105
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
106
- role_idx = tl.arange(0, 4)[None, None, :]
107
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
108
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
109
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
110
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
111
-
112
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
113
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
114
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
115
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
116
-
117
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
118
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
119
-
120
- s1s2 = s1 * s2
121
- offs_b = offs_m // s1s2
122
- offs_s1 = (offs_m % s1s2) // s2
123
- offs_s2 = offs_m % s2
124
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
125
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
126
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
127
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
128
-
129
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
130
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
131
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
132
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
133
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
134
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
-
136
- @triton.jit
137
- def bmm_coalesced_kernel(
138
- # Pointers
139
- Left_ptr, Right_ptr, Out_ptr,
140
- # Dimensions
141
- bs, s1, s2, H,
142
- # Strides
143
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
144
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
145
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
146
- # Kernel parameters
147
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
148
- GROUP_SIZE_M: tl.constexpr,
149
- ):
150
- # Grid and program IDs
151
- pid = tl.program_id(axis=0)
152
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
153
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
154
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
155
- group_id = pid // num_pid_in_group
156
- first_pid_m = group_id * GROUP_SIZE_M
157
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
158
- pid_m = first_pid_m + (pid % group_size_m)
159
- pid_n = (pid % num_pid_in_group) // group_size_m
160
-
161
- pid_bh = tl.program_id(axis=1)
162
- pid_b = pid_bh // H
163
- pid_h = pid_bh % H
164
-
165
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
166
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
167
- offs_k = tl.arange(0, BLOCK_SIZE_K)
168
-
169
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
170
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
171
-
172
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
173
-
174
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
175
- k_start = k * BLOCK_SIZE_K
176
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
177
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
178
-
179
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
180
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
181
-
182
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
183
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
184
-
185
- accumulator += tl.dot(a, b)
186
-
187
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
188
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
189
-
190
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
191
- tl.store(out_ptrs, accumulator, mask=c_mask)
192
-
193
- @triton.jit
194
- def fused_final_kernel(
195
- # Pointers
196
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
197
- # Metadata
198
- M, H, D, s1,
199
- # Strides
200
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
201
- stride_gate_m, stride_gate_h,
202
- stride_proj_d, stride_proj_h,
203
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
204
- # Constants
205
- LN_EPS: tl.constexpr,
206
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
207
- GROUP_SIZE_M: tl.constexpr,
208
- ):
209
- pid = tl.program_id(axis=0)
210
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
211
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
212
-
213
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
214
- group_id = pid // num_pid_in_group
215
- first_pid_m = group_id * GROUP_SIZE_M
216
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
217
- pid_m = first_pid_m + (pid % group_size_m)
218
- pid_n = (pid % num_pid_in_group) // group_size_m
219
-
220
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
221
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
222
- m_mask = offs_m < M
223
-
224
- s1s1 = s1 * s1
225
- b = offs_m // s1s1
226
- r = (offs_m % s1s1) // s1
227
- c = offs_m % s1
228
-
229
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
230
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
231
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
232
-
233
- for k_offset in range(0, H, BLOCK_SIZE_K):
234
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
235
- k_mask = offs_k < H
236
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
237
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
238
- sum_x += tl.sum(in_chunk, axis=1)
239
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
240
-
241
- mean = sum_x / H
242
- var = (sum_x2 / H) - (mean * mean)
243
- rstd = tl.math.rsqrt(var + LN_EPS)
244
-
245
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
246
- for k_offset in range(0, H, BLOCK_SIZE_K):
247
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
248
- k_mask = offs_k < H
249
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
250
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
251
- a_norm = (a - mean[:, None]) * rstd[:, None]
252
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
253
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
254
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
255
- proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
256
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
257
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
258
- a_gated = a_norm * gate
259
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
260
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
261
-
262
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
263
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
264
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
265
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
266
-
267
- def compiledtrimul_fused_interleaved_final(
268
- x: torch.Tensor,
269
- mask_mh: torch.Tensor,
270
- norm_weight: torch.Tensor,
271
- norm_bias: torch.Tensor,
272
- W_4way: torch.Tensor,
273
- W_og: torch.Tensor,
274
- to_out_norm_weight: torch.Tensor,
275
- to_out_norm_bias: torch.Tensor,
276
- to_out_weight: torch.Tensor,
277
- h: int,
278
- ):
279
- bs, s1, s2, d = x.shape
280
- M, K, H = bs * s1 * s2, x.shape[-1], h
281
- x_flat = x.view(M, K)
282
-
283
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
284
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
285
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
286
-
287
- # --- Kernel 1: Fused LN + Dual Matmul ---
288
- # The grid is launched for the larger 4*H problem
289
- N_4way = 4 * H
290
- # Hardcoded best config from logs: M64-N128-K64-GM8-HC32-W4-S2
291
- config_k1 = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
292
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
293
-
294
- fused_ln_dual_matmul_kernel[grid](
295
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
296
- left_final, right_final_t, og_mh,
297
- M, H, K, s1, s2,
298
- x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
299
- W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
300
- left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
301
- right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
302
- mask_mh.stride(0), mask_mh.stride(1),
303
- LN_EPS=1e-5, **config_k1, num_warps=4, num_stages=2
304
- )
305
-
306
- # --- Kernel 2: Batched Matrix Multiplication ---
307
- bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
308
- # Hardcoded best config from logs: M128-N128-K32-GM8-W8-S3
309
- config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
310
- grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
311
-
312
- bmm_coalesced_kernel[grid_bmm](
313
- left_final, right_final_t, bmm_out_tmp,
314
- bs, s1, s2, H,
315
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
316
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
317
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
318
- **config_k2, num_warps=8, num_stages=3
319
- )
320
-
321
- # --- Kernel 3: Fully Fused Final Stage ---
322
- final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
323
- # Hardcoded best config from logs: M32-N128-K32-GM8-W4-S3
324
- config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
325
- grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
326
-
327
- fused_final_kernel[grid_final](
328
- bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
329
- M, H, d, s1,
330
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
331
- og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
332
- final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
333
- LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
334
- )
335
- return final_out
336
-
337
- def pack_w_4way_efficient(weights):
338
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
339
- WL, WLG, WR, WRG = (weights[k] for k in ['left_proj.weight', 'left_gate.weight', 'right_proj.weight', 'right_gate.weight'])
340
- H, K = WL.shape
341
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2).contiguous().view(4 * H, K)
342
- return ws.t().to(torch.float16)
343
-
344
- def get_w_og(weights):
345
- """ Gets the transposed [K, H] out_gate weight matrix. """
346
- return weights['out_gate.weight'].t().to(torch.float16)
347
-
348
- @torch.compile()
349
- def compiledtrimul(
350
- x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
351
- w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
352
- to_out_weight: torch.Tensor, h: int
353
- ) -> torch.Tensor:
354
- bs, s1, s2, d = x.shape
355
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
356
- all_projections = torch.mm(x_norm, w_concat)
357
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
358
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
359
- left = left * mask_expanded * torch.sigmoid(lg)
360
- right = right * mask_expanded * torch.sigmoid(rg)
361
- out_gate = torch.sigmoid(og)
362
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
363
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
364
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
365
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
366
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
367
- gated = normed * out_gate
368
- final_out_flat = gated @ to_out_weight.t()
369
- return final_out_flat.view(bs, s1, s1, d)
370
-
371
- def small_kernel_pt_path(data):
372
- input_tensor, mask, weights, config = data
373
- w_concat = torch.cat([
374
- weights['left_proj.weight'], weights['right_proj.weight'], weights['left_gate.weight'],
375
- weights['right_gate.weight'], weights['out_gate.weight']
376
- ], dim=0).t().contiguous().to(torch.float16)
377
- return compiledtrimul(
378
- x=input_tensor.to(torch.float32), mask=mask.unsqueeze(-1),
379
- norm_weight=weights['norm.weight'].to(torch.float32),
380
- norm_bias=weights['norm.bias'].to(torch.float32), w_concat=w_concat,
381
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
382
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
383
- to_out_weight=weights['to_out.weight'].to(torch.float16),
384
- h=config["hidden_dim"]
385
- )
386
-
387
- def kernel_b200(data):
388
- input_tensor, mask, weights, config = data
389
- bs, s1, s2, d = input_tensor.shape
390
-
391
- if s1 < 800:
392
- return small_kernel_pt_path(data)
393
-
394
- H = config["hidden_dim"]
395
- W_4way = pack_w_4way_efficient(weights)
396
- W_og = get_w_og(weights)
397
- M = bs * s1 * s2
398
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16)
399
-
400
- return compiledtrimul_fused_interleaved_final(
401
- x=input_tensor.to(torch.float32),
402
- mask_mh=mask_mh,
403
- norm_weight=weights['norm.weight'].to(torch.float32),
404
- norm_bias=weights['norm.bias'].to(torch.float32),
405
- W_4way=W_4way,
406
- W_og=W_og,
407
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
408
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
409
- to_out_weight=weights['to_out.weight'].to(torch.float16),
410
- h=H,
411
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-cuda/triton_h100.py DELETED
@@ -1,509 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
-
9
- @triton.autotune(
10
- configs=[
11
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
12
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),
13
-
14
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
15
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
16
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
17
-
18
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=4),
19
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),
20
-
21
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=5),
22
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=5),
23
-
24
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
25
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=4),
26
- ],
27
- key=['M', 'N', 'K'],
28
- )
29
- @triton.jit
30
- def fused_ln_dual_matmul_kernel(
31
- # Pointers (9)
32
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
33
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
34
- # Metadata (5)
35
- M, H, K, s1, s2,
36
- # Strides (16)
37
- stride_x_m, stride_x_k,
38
- stride_w4_k, stride_w4_n,
39
- stride_wog_k, stride_wog_n,
40
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
41
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
42
- stride_og_m, stride_og_h,
43
- stride_mask_m, stride_mask_h,
44
- # Constexpr (from decorator and kwargs)
45
- LN_EPS: tl.constexpr,
46
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
47
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
48
- ):
49
- # --- PID Mapping: Based on the LARGER 4*H problem ---
50
- pid = tl.program_id(axis=0)
51
- N_4way = 4 * H
52
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
53
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
54
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
55
- group_id = pid // num_pid_in_group
56
- first_pid_m = group_id * GROUP_SIZE_M
57
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
58
- pid_m = first_pid_m + (pid % group_size_m)
59
- pid_n = (pid % num_pid_in_group) // group_size_m
60
-
61
- # --- SHARED LayerNorm calculation (done only ONCE) ---
62
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
63
- m_mask = offs_m < M
64
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
65
-
66
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
67
- for k_offset in range(0, K, BLOCK_SIZE_K):
68
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
69
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
70
- k_mask = (k_offset + k_chunk_offs) < K
71
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
72
- mean += tl.sum(x_chunk, axis=1)
73
- mean /= K
74
-
75
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
76
- for k_offset in range(0, K, BLOCK_SIZE_K):
77
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
78
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
79
- k_mask = (k_offset + k_chunk_offs) < K
80
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
81
- x_centered = x_chunk - mean[:, None]
82
- var += tl.sum(x_centered * x_centered, axis=1)
83
- var /= K
84
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
85
-
86
- # --- Matmul Loop 1: For the 4-Way Projections ---
87
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
88
- offs_k = tl.arange(0, BLOCK_SIZE_K)
89
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
90
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
91
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
92
-
93
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
94
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
95
- k_block_start = k * BLOCK_SIZE_K;
96
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
97
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
98
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
99
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
100
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
101
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
102
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
103
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
104
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
105
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
106
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
107
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
108
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
109
-
110
- #Some threads should calclate out_gate
111
- if pid_n * BLOCK_SIZE_N < H:
112
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
113
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
114
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
115
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
116
- accumulator_og += tl.dot(x_norm_tile, w_tile)
117
-
118
- if pid_n * BLOCK_SIZE_N < H:
119
- og_out = tl.sigmoid(accumulator_og)
120
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
121
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
122
- tl.store(outg_ptrs, og_out, mask=og_mask)
123
-
124
- # --- Fusion Logic for 4-Way Part ---
125
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
126
- role_idx = tl.arange(0, 4)[None, None, :]
127
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
128
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
129
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
130
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
131
-
132
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
133
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
134
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
135
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
136
-
137
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
138
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
139
-
140
- s1s2 = s1 * s2
141
- offs_b = offs_m // s1s2
142
- offs_s1 = (offs_m % s1s2) // s2
143
- offs_s2 = offs_m % s2
144
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
145
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
146
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
147
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
148
-
149
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
150
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
151
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
152
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1) # s2 offset uses s2 stride, s1 offset uses s1 stride
153
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
154
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
155
-
156
- @triton.autotune(
157
- configs=[
158
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
159
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
160
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
161
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
162
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
163
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
164
- ],
165
- key=['s1', 's2', 'H'],
166
- )
167
- @triton.jit
168
- def bmm_coalesced_kernel(
169
- # Pointers
170
- Left_ptr, Right_ptr, Out_ptr,
171
- # Dimensions
172
- bs, s1, s2, H,
173
- # Strides
174
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
175
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
176
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
177
- # Kernel parameters
178
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
179
- GROUP_SIZE_M: tl.constexpr,
180
- ):
181
- # Grid and program IDs
182
- pid = tl.program_id(axis=0)
183
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
184
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
185
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
186
- group_id = pid // num_pid_in_group
187
- first_pid_m = group_id * GROUP_SIZE_M
188
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
189
- pid_m = first_pid_m + (pid % group_size_m)
190
- pid_n = (pid % num_pid_in_group) // group_size_m
191
-
192
- pid_bh = tl.program_id(axis=1)
193
- pid_b = pid_bh // H
194
- pid_h = pid_bh % H
195
-
196
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
197
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
198
- offs_k = tl.arange(0, BLOCK_SIZE_K)
199
-
200
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
201
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
202
-
203
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
204
-
205
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
206
- k_start = k * BLOCK_SIZE_K
207
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
208
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
209
-
210
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
211
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
212
-
213
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
214
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
215
-
216
- accumulator += tl.dot(a, b)
217
-
218
- # --- Coalesced Write ---
219
- # Write to a standard (bs, H, s1, s1) layout
220
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
221
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
222
-
223
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
224
- tl.store(out_ptrs, accumulator, mask=c_mask)
225
-
226
- @torch.compile
227
- def torch_pt2(left_final, right_final_t, bs, s1, s2, d, h, to_out_norm_weight, to_out_norm_bias, og_mh, to_out_weight):
228
- bmm_out = torch.matmul(left_final, right_final_t)
229
- out_einsum_flat = bmm_out.permute(0, 2, 3, 1).reshape(bs * s1 * s1, h)
230
- # Apply layer norm and final gating
231
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
232
- gated = normed * og_mh
233
-
234
- # Final projection
235
- final_out_flat = gated @ to_out_weight.t()
236
- final_out = final_out_flat.view(bs, s1, s2, d)
237
- return final_out
238
-
239
- @triton.autotune(
240
- configs=[
241
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
242
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
243
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
244
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
245
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
246
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
247
- ],
248
- key=['H', 'D'],
249
- )
250
- @triton.jit
251
- def fused_final_kernel(
252
- # Pointers
253
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
254
- # Metadata
255
- M, H, D, s1, # M_gate = bs*s1*s2
256
- # Strides
257
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
258
- stride_gate_m, stride_gate_h,
259
- stride_proj_d, stride_proj_h,
260
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
261
- # Constants
262
- LN_EPS: tl.constexpr,
263
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
264
- GROUP_SIZE_M: tl.constexpr,
265
- ):
266
- # --- Grid and PID Setup for Matmul ---
267
- pid = tl.program_id(axis=0)
268
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
269
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
270
-
271
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
272
- group_id = pid // num_pid_in_group
273
- first_pid_m = group_id * GROUP_SIZE_M
274
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
275
- pid_m = first_pid_m + (pid % group_size_m)
276
- pid_n = (pid % num_pid_in_group) // group_size_m
277
-
278
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
279
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
280
- m_mask = offs_m < M
281
-
282
- # Decompose M back to (b, r, c) for reordering lookups
283
- s1s1 = s1 * s1
284
- b = offs_m // s1s1
285
- r = (offs_m % s1s1) // s1
286
- c = offs_m % s1
287
-
288
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
289
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
290
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
291
-
292
- for k_offset in range(0, H, BLOCK_SIZE_K):
293
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
294
- k_mask = offs_k < H
295
-
296
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
297
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
298
-
299
- # Accumulate sum and sum of squares in one pass
300
- sum_x += tl.sum(in_chunk, axis=1)
301
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
302
-
303
- # Finalize statistics
304
- mean = sum_x / H
305
- var = (sum_x2 / H) - (mean * mean)
306
- rstd = tl.math.rsqrt(var + LN_EPS)
307
-
308
- # --- Pass 3: Fused Gating and Matmul ---
309
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
310
- for k_offset in range(0, H, BLOCK_SIZE_K):
311
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
312
- k_mask = offs_k < H
313
-
314
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
315
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
316
- a_norm = (a - mean[:, None]) * rstd[:, None]
317
-
318
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
319
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
320
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
321
-
322
- proj_ptrs = ProjW_ptr + \
323
- offs_n[None, :] * stride_proj_d + \
324
- offs_k[:, None] * stride_proj_h
325
-
326
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
327
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
328
- a_gated = a_norm * gate
329
-
330
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
331
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
332
-
333
- # --- Store Final Output ---
334
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
335
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
336
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
337
-
338
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
339
-
340
- def compiledtrimul_fused_interleaved(
341
- x: torch.Tensor,
342
- mask_mh: torch.Tensor,
343
- norm_weight: torch.Tensor,
344
- norm_bias: torch.Tensor,
345
- W_4way: torch.Tensor, # Use the new weight matrices
346
- W_og: torch.Tensor,
347
- to_out_norm_weight: torch.Tensor,
348
- to_out_norm_bias: torch.Tensor,
349
- to_out_weight: torch.Tensor,
350
- h: int,
351
- ):
352
- bs, s1, s2, d = x.shape
353
- M, K, H = bs * s1 * s2, x.shape[-1], h
354
- x_flat = x.view(M, K)
355
-
356
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
357
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
358
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
359
-
360
- # The grid is launched for the larger 4*H problem
361
- N_4way = 4 * H
362
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
363
- fused_ln_dual_matmul_kernel[grid](
364
- # Pointers (9)
365
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
366
- left_final, right_final_t, og_mh,
367
- # Metadata (5) - M, H, K, s1, s2
368
- M, H, K, s1, s2,
369
- # Strides (16)
370
- x_flat.stride(0), x_flat.stride(1),
371
- W_4way.stride(0), W_4way.stride(1),
372
- W_og.stride(0), W_og.stride(1),
373
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
374
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
375
- og_mh.stride(0), og_mh.stride(1),
376
- mask_mh.stride(0), mask_mh.stride(1),
377
- # Constexpr (1)
378
- LN_EPS=1e-5
379
- )
380
- return torch_pt2(
381
- left_final, right_final_t,
382
- bs=bs,
383
- s1=s1,
384
- s2=s2,
385
- d=d,
386
- h=h,
387
- to_out_norm_weight=to_out_norm_weight,
388
- to_out_norm_bias=to_out_norm_bias,
389
- og_mh=og_mh,
390
- to_out_weight=to_out_weight
391
- )
392
-
393
- def pack_w_4way_efficient(weights):
394
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
395
- WL = weights['left_proj.weight']
396
- WLG = weights['left_gate.weight']
397
- WR = weights['right_proj.weight']
398
- WRG = weights['right_gate.weight']
399
- H, K = WL.shape
400
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2)
401
- ws = ws.contiguous().view(4 * H, K)
402
- return ws.t().to(torch.float16)
403
-
404
- def get_w_og(weights):
405
- """ Gets the transposed [K, H] out_gate weight matrix. """
406
- WOG = weights['out_gate.weight']
407
- return WOG.t().to(torch.float16)
408
-
409
-
410
- torch.backends.cuda.matmul.allow_tf32 = True
411
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
412
-
413
- @torch.compile
414
- def compiledtrimul(
415
- x: torch.Tensor,
416
- mask: torch.Tensor,
417
- norm_weight: torch.Tensor,
418
- norm_bias: torch.Tensor,
419
- w_concat: torch.Tensor,
420
- to_out_norm_weight: torch.Tensor,
421
- to_out_norm_bias: torch.Tensor,
422
- to_out_weight: torch.Tensor,
423
- h: int
424
- ) -> torch.Tensor:
425
- """
426
- A barebones, compiled PyTorch function for the TriMul logic.
427
- """
428
- bs, s1, s2, d = x.shape
429
-
430
- # Initial LayerNorm
431
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
432
- # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
433
- all_projections = torch.mm(x_norm, w_concat)
434
-
435
- # Split back into individual projections
436
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
437
-
438
- # Apply mask and gates
439
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
440
- left = left * mask_expanded * torch.sigmoid(lg)
441
- right = right * mask_expanded * torch.sigmoid(rg)
442
- out_gate = torch.sigmoid(og)
443
-
444
- # Reshape for einsum
445
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
446
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
447
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
448
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
449
-
450
- # Apply layer norm and final gating
451
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
452
- gated = normed * out_gate
453
-
454
- # Final projection
455
- final_out_flat = gated @ to_out_weight.t()
456
- final_out = final_out_flat.view(bs, s1, s2, d)
457
-
458
- return final_out
459
-
460
- def small_kernel_pt_path(data):
461
- input_tensor, mask, weights, config = data
462
- w_concat = torch.cat([
463
- weights['left_proj.weight'],
464
- weights['right_proj.weight'],
465
- weights['left_gate.weight'],
466
- weights['right_gate.weight'],
467
- weights['out_gate.weight']
468
- ], dim=0).t().contiguous().to(torch.float16)
469
- # Call the compiled function with prepared weights
470
- output = compiledtrimul(
471
- x=input_tensor.to(torch.float32),
472
- mask=mask.unsqueeze(-1),
473
- norm_weight=weights['norm.weight'].to(torch.float32),
474
- norm_bias=weights['norm.bias'].to(torch.float32),
475
- w_concat=w_concat,
476
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float32),
477
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float32),
478
- to_out_weight=weights['to_out.weight'].to(torch.float16),
479
- h=config["hidden_dim"]
480
- )
481
- return output
482
-
483
- def kernel_h100(data):
484
- input_tensor, mask, weights, config = data
485
- bs, s1, s2, d = input_tensor.shape
486
-
487
- if s1 <= 512:
488
- return small_kernel_pt_path(data)
489
-
490
- H = config["hidden_dim"]
491
-
492
- W_4way = pack_w_4way_efficient(weights)
493
- W_og = get_w_og(weights)
494
-
495
- M = bs * s1 * s2
496
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16) #move into kernel possibly
497
-
498
- return compiledtrimul_fused_interleaved(
499
- x=input_tensor.to(torch.float32),
500
- mask_mh=mask_mh,
501
- norm_weight=weights['norm.weight'].to(torch.float32),
502
- norm_bias=weights['norm.bias'].to(torch.float32),
503
- W_4way=W_4way, # Pass the new 4-way matrix
504
- W_og=W_og, # Pass the new out_gate matrix
505
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
506
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
507
- to_out_weight=weights['to_out.weight'].to(torch.float16),
508
- h=H,
509
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-rocm/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- from .triton_a100 import kernel_a100
2
- from .triton_h100 import kernel_h100
3
- from .triton_b200 import kernel_b200
4
- from .trimul_mi300 import kernel_mi300
5
- from .trimul_global import kernel_global
6
-
7
- __all__ = ["kernel_a100", "kernel_h100", "kernel_b200", "kernel_mi300", "kernel_global"]
 
 
 
 
 
 
 
 
build/torch-rocm/_ops.py DELETED
@@ -1,8 +0,0 @@
1
- import torch
2
- ops = torch.ops._trimul_gpumode_8e6e60d
3
-
4
- def add_op_namespace_prefix(op_name: str):
5
- """
6
- Prefix op by namespace.
7
- """
8
- return f"_trimul_gpumode_8e6e60d::{op_name}"
 
 
 
 
 
 
 
 
 
build/torch-rocm/metadata.json DELETED
@@ -1 +0,0 @@
1
- {"python-depends":[]}
 
 
build/torch-rocm/task.py DELETED
@@ -1,20 +0,0 @@
1
- """
2
- Type definitions for TriMul task.
3
-
4
- Input: Tuple of (input_tensor, mask, weights, config)
5
- - input_tensor: Input tensor of shape [batch_size, seq_len, seq_len, dim]
6
- - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
7
- - weights: Dictionary containing model weights
8
- - config: Dictionary containing model configuration parameters
9
-
10
- Output: Output tensor of shape [batch_size, seq_len, seq_len, dim]
11
- """
12
-
13
- import torch
14
- from typing import Tuple, Dict, Any
15
-
16
- # Input type: (input_tensor, mask, weights, config)
17
- input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]
18
-
19
- # Output type: output tensor
20
- output_t = torch.Tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-rocm/trimul_global.py DELETED
@@ -1,971 +0,0 @@
1
- # from utils import make_match_reference, DisableCuDNNTF32
2
- from .task import input_t, output_t
3
-
4
- import torch
5
- from torch import nn, einsum
6
- import math
7
- import os
8
- import requests
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
14
- # in PyTorch 1.12 and later.
15
- torch.backends.cuda.matmul.allow_tf32 = True
16
-
17
- # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
18
- torch.backends.cudnn.allow_tf32 = True
19
-
20
- # Set allocator for TMA descriptors (required for on-device TMA)
21
- def alloc_fn(size: int, alignment: int, stream=None):
22
- return torch.empty(size, device="cuda", dtype=torch.int8)
23
-
24
- triton.set_allocator(alloc_fn)
25
-
26
- # os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
27
- # os.environ['MLIR_ENABLE_DIAGNOSTICS'] = 'warnings,remarks'
28
-
29
- # Reference code in PyTorch
30
- class TriMul(nn.Module):
31
- # Based on https://github.com/lucidrains/triangle-multiplicative-module/blob/main/triangle_multiplicative_module/triangle_multiplicative_module.py
32
- def __init__(
33
- self,
34
- dim: int,
35
- hidden_dim: int,
36
- ):
37
- super().__init__()
38
-
39
- self.norm = nn.LayerNorm(dim)
40
-
41
- self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
42
- self.right_proj = nn.Linear(dim, hidden_dim, bias=False)
43
-
44
- self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
45
- self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
46
- self.out_gate = nn.Linear(dim, hidden_dim, bias=False)
47
-
48
- self.to_out_norm = nn.LayerNorm(hidden_dim)
49
- self.to_out = nn.Linear(hidden_dim, dim, bias=False)
50
-
51
- def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
52
- """
53
- x: [bs, seq_len, seq_len, dim]
54
- mask: [bs, seq_len, seq_len]
55
-
56
- Returns:
57
- output: [bs, seq_len, seq_len, dim]
58
- """
59
- batch_size, seq_len, _, dim = x.shape
60
-
61
- x = self.norm(x)
62
-
63
- left = self.left_proj(x)
64
- right = self.right_proj(x)
65
-
66
- mask = mask.unsqueeze(-1)
67
- left = left * mask
68
- right = right * mask
69
-
70
- left_gate = self.left_gate(x).sigmoid()
71
- right_gate = self.right_gate(x).sigmoid()
72
- out_gate = self.out_gate(x).sigmoid()
73
-
74
- left = left * left_gate
75
- right = right * right_gate
76
-
77
- out = einsum('... i k d, ... j k d -> ... i j d', left, right)
78
- # This einsum is the same as the following:
79
- # out = torch.zeros(batch_size, seq_len, seq_len, dim, device=x.device)
80
-
81
- # # Compute using nested loops
82
- # for b in range(batch_size):
83
- # for i in range(seq_len):
84
- # for j in range(seq_len):
85
- # # Compute each output element
86
- # for k in range(seq_len):
87
- # out[b, i, j] += left[b, i, k, :] * right[b, j, k, :]
88
-
89
- out = self.to_out_norm(out)
90
- out = out * out_gate
91
- return self.to_out(out)
92
-
93
- @triton.jit
94
- def triton_sigmoid(x):
95
- """
96
- Compute sigmoid function: 1 / (1 + exp(-x))
97
- """
98
- return 1.0 / (1.0 + tl.exp(-x))
99
-
100
- def two_mm_kernel_configs_wrapper():
101
- if torch.cuda.get_device_capability() == (12, 0):
102
- def two_mm_kernel_configs():
103
- configs = []
104
- for BLOCK_M in [16, 32]:
105
- for BLOCK_N in [16, 32, 64]:
106
- for BLOCK_K in [16, 32, 64]:
107
- for num_stages in [2, 3]:
108
- configs.append(triton.Config({
109
- 'BLOCK_M': BLOCK_M,
110
- 'BLOCK_N': BLOCK_N,
111
- 'BLOCK_K': BLOCK_K,
112
- 'GROUP_SIZE_M': 8
113
- }, num_stages=num_stages, num_warps=8))
114
- return configs
115
-
116
- elif torch.cuda.get_device_capability()[0] == 9:
117
- def get_optimal_two_mm_config_h100(B, seq_len, dim):
118
- configs = {
119
- (1, 128, 128): (128, 64, 128, 2, 8),
120
- (1, 128, 256): (128, 64, 128, 2, 8),
121
- (1, 128, 384): (128, 64, 64, 3, 8),
122
- (1, 128, 512): (128, 64, 64, 3, 8),
123
- (1, 128, 768): (128, 64, 64, 3, 8),
124
- (1, 128, 1024): (128, 64, 64, 3, 8),
125
- (1, 256, 128): (128, 64, 128, 2, 8),
126
- (1, 256, 256): (128, 64, 128, 2, 8),
127
- (1, 256, 384): (128, 64, 64, 3, 8),
128
- (1, 256, 512): (128, 64, 64, 3, 8),
129
- (1, 256, 768): (128, 64, 64, 3, 8),
130
- (1, 256, 1024): (128, 64, 64, 3, 8),
131
- (1, 512, 128): (128, 64, 128, 2, 8),
132
- (1, 512, 256): (128, 64, 128, 2, 8),
133
- (1, 512, 384): (128, 64, 128, 2, 8),
134
- (1, 512, 512): (128, 64, 128, 2, 8),
135
- (1, 512, 768): (128, 64, 64, 3, 8),
136
- (1, 512, 1024): (128, 64, 64, 3, 8),
137
- (1, 1024, 128): (128, 64, 128, 2, 8),
138
- (1, 1024, 256): (128, 64, 64, 2, 8),
139
- (1, 1024, 384): (128, 64, 128, 2, 8),
140
- (1, 1024, 512): (128, 64, 128, 2, 8),
141
- (1, 1024, 768): (128, 64, 128, 2, 8),
142
- (1, 1024, 1024): (128, 64, 128, 2, 8),
143
- (2, 128, 128): (128, 64, 128, 2, 8),
144
- (2, 128, 256): (128, 64, 128, 2, 8),
145
- (2, 128, 384): (128, 64, 64, 3, 8),
146
- (2, 128, 512): (128, 64, 64, 3, 8),
147
- (2, 128, 768): (128, 64, 64, 3, 8),
148
- (2, 128, 1024): (128, 64, 64, 3, 8),
149
- (2, 256, 128): (128, 64, 128, 2, 8),
150
- (2, 256, 256): (128, 64, 128, 2, 8),
151
- (2, 256, 384): (128, 64, 128, 2, 8),
152
- (2, 256, 512): (128, 64, 128, 2, 8),
153
- (2, 256, 768): (128, 64, 64, 3, 8),
154
- (2, 256, 1024): (128, 64, 64, 3, 8),
155
- (2, 512, 128): (128, 64, 128, 2, 8),
156
- (2, 512, 256): (128, 64, 128, 2, 8),
157
- (2, 512, 384): (128, 64, 128, 2, 8),
158
- (2, 512, 512): (128, 64, 128, 2, 8),
159
- (2, 512, 768): (128, 64, 128, 2, 8),
160
- (2, 512, 1024): (128, 64, 128, 2, 8),
161
- (2, 1024, 128): (128, 64, 128, 2, 8),
162
- (2, 1024, 256): (128, 64, 128, 2, 8),
163
- (2, 1024, 384): (128, 64, 128, 2, 8),
164
- (2, 1024, 512): (128, 64, 128, 2, 8),
165
- (2, 1024, 768): (128, 64, 128, 2, 8),
166
- (2, 1024, 1024): (128, 64, 128, 2, 8),
167
- }
168
- return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8)) # default fallback
169
-
170
- def two_mm_kernel_configs():
171
- # This function is kept for compatibility but will be overridden for H100
172
- return [
173
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
174
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
175
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
176
- ]
177
-
178
- elif torch.cuda.get_device_capability()[0] == 10 and False:
179
- def get_optimal_two_mm_config(B, seq_len, dim):
180
- configs = {
181
- (1, 128, 128): (64, 128, 64, 2, 8),
182
- (1, 128, 256): (128, 64, 128, 2, 8),
183
- (1, 128, 384): (128, 64, 128, 2, 8),
184
- (1, 128, 512): (128, 64, 128, 2, 8),
185
- (1, 128, 768): (128, 64, 64, 3, 8),
186
- (1, 128, 1024): (128, 64, 64, 3, 8),
187
- (1, 256, 128): (128, 64, 128, 2, 8),
188
- (1, 256, 256): (128, 64, 128, 2, 8),
189
- (1, 256, 384): (128, 64, 128, 2, 8),
190
- (1, 256, 512): (128, 64, 64, 3, 8),
191
- (1, 256, 768): (128, 64, 64, 3, 8),
192
- (1, 256, 1024): (128, 64, 64, 3, 8),
193
- (1, 512, 128): (128, 64, 128, 2, 8),
194
- (1, 512, 256): (128, 64, 128, 2, 8),
195
- (1, 512, 384): (128, 64, 128, 2, 8),
196
- (1, 512, 512): (128, 64, 128, 2, 8),
197
- (1, 512, 768): (128, 64, 64, 3, 8),
198
- (1, 512, 1024): (128, 64, 64, 3, 8),
199
- (1, 1024, 128): (128, 64, 128, 2, 8),
200
- (1, 1024, 256): (128, 64, 128, 2, 8),
201
- (1, 1024, 384): (128, 64, 128, 2, 8),
202
- (1, 1024, 512): (128, 64, 128, 2, 8),
203
- (1, 1024, 768): (128, 64, 64, 3, 8),
204
- (1, 1024, 1024): (128, 64, 64, 3, 8),
205
- (2, 128, 128): (128, 64, 128, 2, 8),
206
- (2, 128, 256): (128, 64, 128, 2, 8),
207
- (2, 128, 384): (128, 64, 128, 2, 8),
208
- (2, 128, 512): (128, 64, 64, 3, 8),
209
- (2, 128, 768): (128, 64, 64, 3, 8),
210
- (2, 128, 1024): (128, 64, 64, 3, 8),
211
- (2, 256, 128): (128, 64, 128, 2, 8),
212
- (2, 256, 256): (128, 64, 128, 2, 8),
213
- (2, 256, 384): (128, 64, 128, 2, 8),
214
- (2, 256, 512): (128, 64, 64, 3, 8),
215
- (2, 256, 768): (128, 64, 64, 3, 8),
216
- (2, 256, 1024): (128, 64, 64, 3, 8),
217
- (2, 512, 128): (128, 64, 128, 2, 8),
218
- (2, 512, 256): (128, 64, 128, 2, 8),
219
- (2, 512, 384): (128, 64, 128, 2, 8),
220
- (2, 512, 512): (128, 64, 128, 2, 8),
221
- (2, 512, 768): (128, 64, 64, 3, 8),
222
- (2, 512, 1024): (128, 64, 64, 3, 8),
223
- (2, 1024, 128): (128, 64, 128, 2, 8),
224
- (2, 1024, 256): (128, 64, 128, 2, 8),
225
- (2, 1024, 384): (128, 64, 128, 2, 8),
226
- (2, 1024, 512): (128, 64, 128, 2, 8),
227
- (2, 1024, 768): (128, 64, 64, 3, 8),
228
- (2, 1024, 1024): (128, 64, 64, 3, 8),
229
- }
230
- return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8)) # default fallback
231
-
232
- def two_mm_kernel_configs():
233
- # This function is kept for compatibility but will be overridden
234
- return [
235
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
236
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
237
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
238
- ]
239
- elif torch.cuda.get_device_capability()[0] == 8:
240
- # A100
241
- def two_mm_kernel_configs():
242
- configs = []
243
- for BLOCK_M in [64]:
244
- for BLOCK_N in [64, 128]:
245
- for BLOCK_K in [16]:
246
- for num_stages in [3, 4]:
247
- for num_warps in [4, 8]:
248
- configs.append(triton.Config({
249
- 'BLOCK_M': BLOCK_M,
250
- 'BLOCK_N': BLOCK_N,
251
- 'BLOCK_K': BLOCK_K,
252
- 'GROUP_SIZE_M': 8
253
- }, num_stages=num_stages, num_warps=num_warps))
254
- return configs
255
- else:
256
- def two_mm_kernel_configs():
257
- configs = []
258
- for BLOCK_M in [64, 128]:
259
- for BLOCK_N in [64, 128]:
260
- for BLOCK_K in [64, 128]:
261
- for num_stages in [2, 3]:
262
- configs.append(triton.Config({
263
- 'BLOCK_M': BLOCK_M,
264
- 'BLOCK_N': BLOCK_N,
265
- 'BLOCK_K': BLOCK_K,
266
- 'GROUP_SIZE_M': 8
267
- }, num_stages=num_stages, num_warps=8))
268
- return configs
269
-
270
- return two_mm_kernel_configs
271
-
272
- def two_mm_kernel_wrapper():
273
- if torch.cuda.get_device_capability()[0] == 8:
274
- @triton.jit
275
- def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
276
- # Persistent kernel using standard tl.load operations
277
- start_pid = tl.program_id(axis=0)
278
- num_pid_m = tl.cdiv(M, BLOCK_M)
279
- num_pid_n = tl.cdiv(N, BLOCK_N)
280
- k_tiles = tl.cdiv(K, BLOCK_K)
281
- num_tiles = num_pid_m * num_pid_n
282
-
283
- # tile_id_c is used in the epilogue to break the dependency between
284
- # the prologue and the epilogue
285
- tile_id_c = start_pid - NUM_SMS
286
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
287
-
288
- # Persistent loop over tiles
289
- for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
290
- # Calculate PID for this tile using improved swizzling
291
- group_id = tile_id // num_pid_in_group
292
- first_pid_m = group_id * GROUP_SIZE_M
293
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
294
- pid_m = first_pid_m + (tile_id % group_size_m)
295
- pid_n = (tile_id % num_pid_in_group) // group_size_m
296
-
297
- # Calculate block offsets
298
- offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
299
- offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
300
- offs_k = tl.arange(0, BLOCK_K)
301
-
302
- # Initialize accumulators for all outputs
303
- accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
304
- accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
305
- accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
306
- accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
307
- accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
308
-
309
- # Main computation loop over K dimension
310
- for ki in range(k_tiles):
311
- k_start = ki * BLOCK_K
312
- k_offsets = k_start + offs_k
313
-
314
- # Create pointers for A matrix (2D flattened view)
315
- a_ptrs = a_ptr + offs_am[:, None] * stride_a2 + k_offsets[None, :] * stride_a3
316
- a_mask = (offs_am[:, None] < M) & (k_offsets[None, :] < K)
317
-
318
- # Create pointers for B matrices [N, K] layout
319
- b1_ptrs = b1_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
320
- b2_ptrs = b2_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
321
- b3_ptrs = b3_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
322
- b4_ptrs = b4_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
323
- b5_ptrs = b5_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
324
- b_mask = (offs_bn[:, None] < N) & (k_offsets[None, :] < K)
325
-
326
- # Load blocks from A and all weight matrices using standard tl.load
327
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
328
- b1 = tl.load(b1_ptrs, mask=b_mask, other=0.0)
329
- b2 = tl.load(b2_ptrs, mask=b_mask, other=0.0)
330
- b3 = tl.load(b3_ptrs, mask=b_mask, other=0.0)
331
- b4 = tl.load(b4_ptrs, mask=b_mask, other=0.0)
332
- b5 = tl.load(b5_ptrs, mask=b_mask, other=0.0)
333
-
334
- # Perform matrix multiplications using TF32
335
- accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True) # A @ B1.T
336
- accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True) # A @ B2.T
337
- accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True) # A @ B3.T
338
- accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True) # A @ B4.T
339
- accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True) # A @ B5.T
340
-
341
- # Store results using separate tile_id_c for epilogue
342
- tile_id_c += NUM_SMS
343
- group_id = tile_id_c // num_pid_in_group
344
- first_pid_m = group_id * GROUP_SIZE_M
345
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
346
- pid_m = first_pid_m + (tile_id_c % group_size_m)
347
- pid_n = (tile_id_c % num_pid_in_group) // group_size_m
348
-
349
- # Calculate output offsets and pointers
350
- offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
351
- offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
352
-
353
- # Create masks for bounds checking
354
- d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
355
-
356
- # Calculate pointer addresses using 4D strides
357
- stride_cm = stride_c2 # Stride to next element in flattened M dimension
358
- stride_cn = stride_c3 # N is the innermost dimension
359
-
360
- # For D tensor: use separate D strides
361
- stride_dm = stride_d2 # Stride to next element in flattened M dimension
362
- stride_dn = stride_d3 # N is the innermost dimension
363
-
364
- off_c_batch = offs_cm // (seq_len * seq_len)
365
- off_c_sl1 = (offs_cm // seq_len) % seq_len
366
- off_c_sl2 = offs_cm % seq_len
367
- off_c_dim = offs_cn
368
-
369
- c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
370
- c_mask = d_mask
371
-
372
- c1_ptrs = c1_ptr + c_offsets
373
- c2_ptrs = c2_ptr + c_offsets
374
- d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]
375
-
376
- mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))
377
-
378
- # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
379
- mask_2d = mask[:, None] # Convert to [BLOCK_M, 1] then broadcast
380
- # Apply masking only to left_proj and right_proj results (C1, C2)
381
- accumulator1 = tl.where(mask_2d, accumulator1, 0)
382
- accumulator2 = tl.where(mask_2d, accumulator2, 0)
383
-
384
- # Apply sigmoid to gate values
385
- left_gate_sigmoid = triton_sigmoid(accumulator3)
386
- right_gate_sigmoid = triton_sigmoid(accumulator4)
387
- accumulator_d = triton_sigmoid(accumulator_d)
388
-
389
- # Apply elementwise multiplication with gated values
390
- # C1 = left * left_gate, C2 = right * right_gate
391
- accumulator1 = accumulator1 * left_gate_sigmoid # left * left_gate
392
- accumulator2 = accumulator2 * right_gate_sigmoid # right * right_gate
393
-
394
- # Convert to appropriate output dtype and store with normal tl.store
395
- c1 = accumulator1.to(c1_ptr.dtype.element_ty)
396
- c2 = accumulator2.to(c2_ptr.dtype.element_ty)
397
- d = accumulator_d.to(d_ptr.dtype.element_ty)
398
-
399
- tl.store(c1_ptrs, c1, mask=c_mask)
400
- tl.store(c2_ptrs, c2, mask=c_mask)
401
- tl.store(d_ptrs, d, mask=d_mask)
402
- else:
403
- @triton.jit
404
- def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
405
- # Persistent kernel using on-device TMA descriptors
406
- start_pid = tl.program_id(axis=0)
407
- num_pid_m = tl.cdiv(M, BLOCK_M)
408
- num_pid_n = tl.cdiv(N, BLOCK_N)
409
- k_tiles = tl.cdiv(K, BLOCK_K)
410
- num_tiles = num_pid_m * num_pid_n
411
-
412
- # Create on-device TMA descriptors
413
- a_desc = tl._experimental_make_tensor_descriptor(
414
- a_ptr,
415
- shape=[M, K],
416
- strides=[stride_a2, stride_a3],
417
- block_shape=[BLOCK_M, BLOCK_K],
418
- )
419
- b1_desc = tl._experimental_make_tensor_descriptor(
420
- b1_ptr,
421
- shape=[N, K],
422
- strides=[stride_bn, stride_bk],
423
- block_shape=[BLOCK_N, BLOCK_K],
424
- )
425
- b2_desc = tl._experimental_make_tensor_descriptor(
426
- b2_ptr,
427
- shape=[N, K],
428
- strides=[stride_bn, stride_bk],
429
- block_shape=[BLOCK_N, BLOCK_K],
430
- )
431
- b3_desc = tl._experimental_make_tensor_descriptor(
432
- b3_ptr,
433
- shape=[N, K],
434
- strides=[stride_bn, stride_bk],
435
- block_shape=[BLOCK_N, BLOCK_K],
436
- )
437
- b4_desc = tl._experimental_make_tensor_descriptor(
438
- b4_ptr,
439
- shape=[N, K],
440
- strides=[stride_bn, stride_bk],
441
- block_shape=[BLOCK_N, BLOCK_K],
442
- )
443
- b5_desc = tl._experimental_make_tensor_descriptor(
444
- b5_ptr,
445
- shape=[N, K],
446
- strides=[stride_bn, stride_bk],
447
- block_shape=[BLOCK_N, BLOCK_K],
448
- )
449
-
450
- # tile_id_c is used in the epilogue to break the dependency between
451
- # the prologue and the epilogue
452
- tile_id_c = start_pid - NUM_SMS
453
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
454
-
455
- # Persistent loop over tiles
456
- for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
457
- # Calculate PID for this tile using improved swizzling
458
- group_id = tile_id // num_pid_in_group
459
- first_pid_m = group_id * GROUP_SIZE_M
460
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
461
- pid_m = first_pid_m + (tile_id % group_size_m)
462
- pid_n = (tile_id % num_pid_in_group) // group_size_m
463
-
464
- # Calculate block offsets
465
- offs_am = pid_m * BLOCK_M
466
- offs_bn = pid_n * BLOCK_N
467
-
468
- # Initialize accumulators for all outputs
469
- accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
470
- accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
471
- accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
472
- accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
473
- accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
474
-
475
- # Main computation loop over K dimension
476
- for ki in range(k_tiles):
477
- offs_k = ki * BLOCK_K
478
- # Load blocks from A and all weight matrices using on-device TMA
479
- a = a_desc.load([offs_am, offs_k])
480
- b1 = b1_desc.load([offs_bn, offs_k])
481
- b2 = b2_desc.load([offs_bn, offs_k])
482
- b3 = b3_desc.load([offs_bn, offs_k])
483
- b4 = b4_desc.load([offs_bn, offs_k])
484
- b5 = b5_desc.load([offs_bn, offs_k])
485
-
486
- # Perform matrix multiplications using TF32
487
- accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True) # A @ B1.T
488
- accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True) # A @ B2.T
489
- accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True) # A @ B3.T
490
- accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True) # A @ B4.T
491
- accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True) # A @ B5.T
492
-
493
- # Store results using separate tile_id_c for epilogue
494
- tile_id_c += NUM_SMS
495
- group_id = tile_id_c // num_pid_in_group
496
- first_pid_m = group_id * GROUP_SIZE_M
497
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
498
- pid_m = first_pid_m + (tile_id_c % group_size_m)
499
- pid_n = (tile_id_c % num_pid_in_group) // group_size_m
500
-
501
- # Calculate output offsets and pointers
502
- offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
503
- offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
504
-
505
- # Create masks for bounds checking
506
- d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
507
-
508
- # Calculate pointer addresses using 4D strides
509
- # For C tensors: compute effective 2D strides from 4D strides
510
- # Output tensor is [B, I, J, N], flattened to [M, N] where M = B*I*J
511
- stride_cm = stride_c2 # Stride to next element in flattened M dimension
512
- stride_cn = stride_c3 # N is the innermost dimension
513
-
514
- # For D tensor: use separate D strides
515
- stride_dm = stride_d2 # Stride to next element in flattened M dimension
516
- stride_dn = stride_d3 # N is the innermost dimension
517
-
518
- off_c_batch = offs_cm // (seq_len * seq_len)
519
- off_c_sl1 = (offs_cm // seq_len) % seq_len
520
- off_c_sl2 = offs_cm % seq_len
521
- off_c_dim = offs_cn
522
-
523
- # TODO update the mask_c so we don't IMA
524
- c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
525
- # c_offsets = offs_cm[:, None] * stride_c2 + offs_cn[None, :] * stride_c3
526
- c_mask = d_mask
527
-
528
- c1_ptrs = c1_ptr + c_offsets
529
- c2_ptrs = c2_ptr + c_offsets
530
- # c1_ptrs = c1_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
531
- # c2_ptrs = c2_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
532
- d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]
533
-
534
- mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))
535
-
536
- # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
537
- mask_2d = mask[:, None] # Convert to [BLOCK_M, 1] then broadcast
538
- # Apply masking only to left_proj and right_proj results (C1, C2)
539
- accumulator1 = tl.where(mask_2d, accumulator1, 0)
540
- accumulator2 = tl.where(mask_2d, accumulator2, 0)
541
-
542
- # Apply sigmoid to gate values
543
- left_gate_sigmoid = triton_sigmoid(accumulator3)
544
- right_gate_sigmoid = triton_sigmoid(accumulator4)
545
- accumulator_d = triton_sigmoid(accumulator_d)
546
-
547
- # Apply elementwise multiplication with gated values
548
- # C1 = left * left_gate, C2 = right * right_gate
549
- accumulator1 = accumulator1 * left_gate_sigmoid # left * left_gate
550
- accumulator2 = accumulator2 * right_gate_sigmoid # right * right_gate
551
-
552
- # Convert to appropriate output dtype and store with normal tl.store
553
- c1 = accumulator1.to(c1_ptr.dtype.element_ty)
554
- c2 = accumulator2.to(c2_ptr.dtype.element_ty)
555
- d = accumulator_d.to(d_ptr.dtype.element_ty)
556
-
557
- tl.store(c1_ptrs, c1, mask=c_mask)
558
- tl.store(c2_ptrs, c2, mask=c_mask)
559
- tl.store(d_ptrs, d, mask=d_mask)
560
-
561
-
562
- if torch.cuda.get_device_capability()[0] not in [9, 10.2]:
563
- two_mm_kernel = triton.autotune(
564
- (two_mm_kernel_configs_wrapper())(), key=["M", "N", "K"]
565
- )(two_mm_kernel)
566
-
567
- return two_mm_kernel
568
-
569
-
570
- def two_mm(A, left_proj, right_proj, left_gate, right_gate, out_gate, mask):
571
- """
572
- Persistent matrix multiplication for all weight matrices using on-device TMA descriptors.
573
-
574
- Args:
575
- A: [..., K] tensor (arbitrary leading dimensions)
576
- left_proj: [N, K] matrix (will be transposed)
577
- right_proj: [N, K] matrix (will be transposed)
578
- left_gate: [N, K] left gate weight matrix
579
- right_gate: [N, K] right gate weight matrix
580
- out_gate: [N, K] output gate weight matrix
581
- mask: mask tensor
582
-
583
- Returns:
584
- (C1, C2, D): Tuple of result tensors [..., N] with same leading dims as A
585
- C1 = (A @ left_proj.T) * sigmoid(A @ left_gate.T) (masked)
586
- C2 = (A @ right_proj.T) * sigmoid(A @ right_gate.T) (masked)
587
- D = sigmoid(A @ out_gate.T) (unmasked)
588
- """
589
- # Check constraints
590
- assert A.shape[-1] == left_proj.shape[1] == right_proj.shape[1], "Incompatible K dimensions"
591
- assert A.dtype == left_proj.dtype == right_proj.dtype, "Incompatible dtypes"
592
-
593
- # Assert that all weight matrices have the same strides (same [N, K] shape)
594
- assert left_proj.stride() == right_proj.stride() == left_gate.stride() == right_gate.stride() == out_gate.stride(), \
595
- "All weight matrices must have identical strides"
596
-
597
- # Get dimensions
598
- original_shape = A.shape[:-1] # All dimensions except the last
599
- K = A.shape[-1]
600
- N = left_proj.shape[0]
601
- B, seq_len, _, _ = A.shape
602
- dtype = A.dtype
603
-
604
- # Flatten A to 2D for kernel processing
605
- A_2d = A.view(-1, K) # [M, K] where M is product of all leading dims
606
- M = A_2d.shape[0]
607
-
608
- # Get number of streaming multiprocessors
609
- NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
610
-
611
- # Launch persistent kernel with limited number of blocks
612
- grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)
613
-
614
- # Get original 4D strides for A and output tensors
615
- A_strides = A.stride() # (stride_0, stride_1, stride_2, stride_3)
616
-
617
- # Create output tensors with proper 4D shape to get correct strides
618
- output_shape = original_shape + (N,)
619
- # C1 = torch.empty(output_shape, device=A.device, dtype=dtype)
620
- # C2 = torch.empty(output_shape, device=A.device, dtype=dtype)
621
- C1 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
622
- C2 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
623
- D = torch.empty(output_shape, device=A.device, dtype=torch.float16)
624
-
625
- C_strides = C1.stride() # (stride_0, stride_1, stride_2, stride_3)
626
- D_strides = D.stride() # (stride_0, stride_1, stride_2, stride_3)
627
-
628
- # Use optimal configuration for B200/H100 or fallback to autotuning for other GPUs
629
- if torch.cuda.get_device_capability()[0] == 10:
630
- # Get optimal configuration for B200
631
- BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
632
- grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))
633
-
634
- two_mm_kernel_wrapper()[(grid_size,)](
635
- A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
636
- C1, C2, D, mask,
637
- M, N, K,
638
- *A_strides, # 4D strides for A
639
- left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
640
- *C_strides, # 4D strides for C
641
- seq_len,
642
- *D_strides, # 4D strides for D
643
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
644
- num_stages=num_stages, num_warps=num_warps
645
- )
646
- elif torch.cuda.get_device_capability()[0] == 9:
647
- # Get optimal configuration for H100
648
- BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
649
- grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))
650
-
651
- two_mm_kernel_wrapper()[(grid_size,)](
652
- A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
653
- C1, C2, D, mask,
654
- M, N, K,
655
- *A_strides, # 4D strides for A
656
- left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
657
- *C_strides, # 4D strides for C
658
- seq_len,
659
- *D_strides, # 4D strides for D
660
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
661
- num_stages=num_stages, num_warps=num_warps
662
- )
663
- else:
664
- # Use autotuning for other GPUs
665
- two_mm_kernel_wrapper()[grid](
666
- A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
667
- C1, C2, D, mask,
668
- M, N, K,
669
- *A_strides, # 4D strides for A
670
- left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
671
- *C_strides, # 4D strides for C
672
- seq_len,
673
- *D_strides, # 4D strides for D
674
- NUM_SMS=NUM_SMS
675
- )
676
-
677
- return C1, C2, D
678
-
679
-
680
- def second_layernorm_mul(inp, hidden_dim, weight, bias, mul_operand):
681
- ln = torch.nn.functional.layer_norm(inp, (hidden_dim,), eps=1e-5, weight=weight.to(inp.dtype), bias=bias.to(inp.dtype))
682
- out = ln * mul_operand
683
- return out
684
-
685
- '''
686
- @triton.autotune(
687
- [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=4, num_stages=3)],
688
- key=["R", "C"]
689
- )
690
- '''
691
- @triton.jit
692
- def layernorm_kernel_first(
693
- X,
694
- Y,
695
- Weight,
696
- Bias,
697
- R,
698
- C, # aka "dim"
699
- eps,
700
- ROW_BLOCK_SIZE: tl.constexpr,
701
- BLOCK_SIZE: tl.constexpr,
702
- ):
703
- row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
704
- cols = tl.arange(0, BLOCK_SIZE)
705
-
706
- mask_row = row < R
707
- mask_col = cols < C
708
-
709
- # Simple indexing for contiguous data
710
- x = tl.load(
711
- X + row[:, None] * C + cols[None, :],
712
- mask=mask_row[:, None] & mask_col[None, :],
713
- other=0.0
714
- ).to(tl.float32)
715
-
716
- weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
717
- bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)
718
-
719
- mean = tl.sum(x, axis=1) / C
720
- diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
721
- var = tl.sum(diff * diff, axis=1) / C
722
- rstd = 1 / tl.sqrt(var + eps)
723
-
724
- y_hat = (x - mean[:, None]) * rstd[:, None]
725
- y = y_hat * weight[None, :] + bias[None, :]
726
-
727
- tl.store(
728
- Y + row[:, None] * C + cols[None, :],
729
- y,
730
- mask=mask_row[:, None] & mask_col[None, :]
731
- )
732
-
733
-
734
- def get_optimal_config_ln(dim):
735
- config = None
736
- if torch.cuda.get_device_capability()[0] == 9:
737
- if (dim <= 256):
738
- config = (16, 1)
739
- elif dim <= 512:
740
- config = (16, 2)
741
- elif dim <= 1024:
742
- config = (16, 4)
743
-
744
- if not config:
745
- config = (16, 4)
746
- return config
747
-
748
-
749
- def triton_layernorm_first(x, weight, bias, eps=1e-5, num_warps=None, ROW_BLOCK_SIZE=None):
750
- B, seq_len, seq_len2, dim = x.shape
751
- assert(seq_len == seq_len2)
752
-
753
- R = B * seq_len * seq_len
754
- C = dim
755
-
756
- out = torch.empty_like(x, dtype=torch.float16)
757
-
758
- if not num_warps or not ROW_BLOCK_SIZE:
759
- ROW_BLOCK_SIZE, num_warps = get_optimal_config_ln(dim)
760
-
761
- BLOCK_SIZE = triton.next_power_of_2(C)
762
- assert(BLOCK_SIZE <= 1024)
763
-
764
- def grid(meta):
765
- return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
766
-
767
- layernorm_kernel_first[grid](
768
- x, out, weight, bias,
769
- R, C, eps,
770
- ROW_BLOCK_SIZE=ROW_BLOCK_SIZE,
771
- BLOCK_SIZE=BLOCK_SIZE,
772
- num_warps=num_warps,
773
- num_stages=3
774
- )
775
-
776
- return out
777
-
778
- '''
779
- def triton_layernorm_first(x, weight, bias, eps=1e-5):
780
- B, seq_len, seq_len2, dim = x.shape
781
- assert(seq_len == seq_len2)
782
-
783
- R = B * seq_len * seq_len
784
- C = dim
785
-
786
- out = torch.empty_like(x)
787
-
788
- BLOCK_SIZE = triton.next_power_of_2(C)
789
- assert(BLOCK_SIZE <= 1024)
790
-
791
- def grid(meta):
792
- return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
793
-
794
- layernorm_kernel_first[grid](
795
- x, out, weight, bias,
796
- R, C, eps,
797
- BLOCK_SIZE=BLOCK_SIZE
798
- )
799
-
800
- return out
801
- '''
802
-
803
-
804
- @triton.autotune(
805
- [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=1, num_stages=3)],
806
- key=[]
807
- )
808
- @triton.jit
809
- def layernorm_kernel_eltwise(
810
- X,
811
- Y,
812
- Weight,
813
- Bias,
814
- OutGate,
815
- seq_len,
816
- stride_batch,
817
- stride_dim,
818
- R,
819
- C, # aka "dim"
820
- eps,
821
- ROW_BLOCK_SIZE: tl.constexpr,
822
- BLOCK_SIZE: tl.constexpr,
823
- ):
824
- row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
825
- cols = tl.arange(0, BLOCK_SIZE)
826
-
827
- # Calculate base pointer for this batch of rows
828
- tl.device_assert(seq_len*seq_len % ROW_BLOCK_SIZE == 0)
829
- # batch_offset = (row // (stride_seq1 // stride_dim)) * stride_batch
830
- batch = tl.program_id(0) * ROW_BLOCK_SIZE // (seq_len * seq_len)
831
- seqs_off = row % (seq_len * seq_len) # TODO is this going to prevent vectorization
832
-
833
- off_r = batch * stride_batch + seqs_off
834
- off_c = cols * stride_dim
835
-
836
- mask_row = row < R
837
- mask_col = cols < C
838
-
839
- out_gate = tl.load(
840
- OutGate + row[:, None] * C + cols[None, :],
841
- mask = mask_row[:, None] & mask_col[None, :],
842
- )
843
-
844
- x = tl.load(
845
- X + off_r[:, None] + off_c[None, :],
846
- mask=mask_row[:, None] & mask_col[None, :],
847
- other=0.0
848
- ).to(tl.float32)
849
-
850
- weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
851
- bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)
852
-
853
- mean = tl.sum(x, axis=1) / C
854
- diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
855
- var = tl.sum(diff * diff, axis=1) / C
856
- rstd = 1 / tl.sqrt(var + eps)
857
-
858
- y_hat = (x - mean[:, None]) * rstd[:, None]
859
- y = y_hat * weight[None, :] + bias[None, :]
860
-
861
- tl.store(
862
- Y + row[:, None] * C + cols[None, :],
863
- y * out_gate,
864
- mask=mask_row[:, None] & mask_col[None, :]
865
- )
866
-
867
-
868
- def triton_layernorm_eltwise(x, weight, bias, out_gate, eps=1e-5):
869
- B, seq_len, seq_len2, dim = x.shape
870
- assert(seq_len == seq_len2)
871
- R = B * seq_len * seq_len
872
- assert(x.stride(3) == seq_len*seq_len)
873
- assert(out_gate.is_contiguous())
874
- C = dim
875
-
876
- out = torch.empty_like(out_gate, dtype=torch.float32)
877
-
878
- BLOCK_SIZE = triton.next_power_of_2(C)
879
- assert(BLOCK_SIZE == 128)
880
-
881
- def grid(meta):
882
- return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
883
-
884
- layernorm_kernel_eltwise[grid](
885
- x, out, weight, bias, out_gate,
886
- seq_len,
887
- x.stride(0), x.stride(3),
888
- R, C, eps,
889
- BLOCK_SIZE=BLOCK_SIZE
890
- )
891
-
892
- return out
893
-
894
-
895
- def kernel_global(data: input_t) -> output_t:
896
- """
897
- Reference implementation of TriMul using PyTorch.
898
-
899
- Args:
900
- data: Tuple of (input: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
901
- - input: Input tensor of shape [batch_size, seq_len, seq_len, dim]
902
- - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
903
- - weights: Dictionary containing model weights
904
- - config: Dictionary containing model configuration parameters
905
- """
906
- input_tensor, mask, weights, config = data
907
-
908
- left_proj_weight = weights["left_proj.weight"].to(torch.float16)
909
- right_proj_weight = weights["right_proj.weight"].to(torch.float16)
910
- left_gate_weight = weights["left_gate.weight"].to(torch.float16)
911
- right_gate_weight = weights["right_gate.weight"].to(torch.float16)
912
- out_gate_weight = weights["out_gate.weight"].to(torch.float16)
913
-
914
- hidden_dim = config["hidden_dim"]
915
- # trimul = TriMul(dim=config["dim"], hidden_dim=config["hidden_dim"]).to(input_tensor.device)
916
-
917
- x = input_tensor
918
-
919
- batch_size, seq_len, _, dim = x.shape
920
-
921
- x = triton_layernorm_first(x, weights['norm.weight'], weights['norm.bias'])
922
- # x = torch.nn.functional.layer_norm(x, (dim,), eps=1e-5, weight=weights['norm.weight'], bias=weights['norm.bias'])
923
-
924
- left, right, out_gate = two_mm(x, left_proj_weight, right_proj_weight, left_gate_weight, right_gate_weight, out_gate_weight, mask)
925
- # left = torch.nn.functional.linear(x, weights['left_proj.weight'].to(torch.float16))
926
- # right = torch.nn.functional.linear(x, weights['right_proj.weight'].to(torch.float16))
927
-
928
- # left = left * mask.unsqueeze(-1)
929
- # right = right * mask.unsqueeze(-1)
930
-
931
- '''
932
- left = left.to(torch.float32)
933
- right = right.to(torch.float32)
934
- x = x.to(torch.float32)
935
-
936
- left_gate = left_gate.sigmoid()
937
- right_gate = right_gate.sigmoid()
938
- out_gate = out_gate.sigmoid()
939
- '''
940
-
941
- # Elementwise multiplication now handled in kernel
942
- # left = left * left_gate
943
- # right = right * right_gate
944
-
945
- # out = einsum('... i k d, ... j k d -> ... i j d', left, right)
946
- out = torch.bmm(left.permute(0, 3, 1, 2).view(-1, left.shape[1], left.shape[2]), right.permute(0, 3, 2, 1).view(-1, right.shape[2], right.shape[1]))
947
- out = out.view(batch_size, hidden_dim, seq_len, seq_len).permute(0, 2, 3, 1)
948
-
949
- # out = torch.compile(second_layernorm_mul, dynamic=False)(out, hidden_dim, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate)
950
- out = triton_layernorm_eltwise(out, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate)
951
- # out = torch.nn.functional.layer_norm(out, (hidden_dim,), eps=1e-5, weight=weights['to_out_norm.weight'].to(out.dtype), bias=weights['to_out_norm.bias'].to(out.dtype))
952
- # out = out * out_gate
953
- return torch.nn.functional.linear(out, weights['to_out.weight'])
954
-
955
- '''
956
- # Fill in the given weights of the model
957
- trimul.norm.weight = nn.Parameter(weights['norm.weight'])
958
- trimul.norm.bias = nn.Parameter(weights['norm.bias'])
959
- trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight'])
960
- trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight'])
961
- trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight'])
962
- trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight'])
963
- trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight'])
964
- trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight'])
965
- trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias'])
966
- trimul.to_out.weight = nn.Parameter(weights['to_out.weight'])
967
-
968
- output = trimul(input_tensor, mask)
969
-
970
- return output
971
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-rocm/trimul_gpumode/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- import ctypes
2
- import sys
3
-
4
- import importlib
5
- from pathlib import Path
6
- from types import ModuleType
7
-
8
- def _import_from_path(file_path: Path) -> ModuleType:
9
- # We cannot use the module name as-is, after adding it to `sys.modules`,
10
- # it would also be used for other imports. So, we make a module name that
11
- # depends on the path for it to be unique using the hex-encoded hash of
12
- # the path.
13
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
- module_name = path_hash
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None:
17
- raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- if module is None:
20
- raise ImportError(f"Cannot load module {module_name} from spec")
21
- sys.modules[module_name] = module
22
- spec.loader.exec_module(module) # type: ignore
23
- return module
24
-
25
-
26
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-rocm/trimul_mi300.py DELETED
@@ -1,524 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
-
9
- @triton.autotune(
10
- configs=[
11
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
12
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
13
-
14
- # Configurations with larger block sizes for better data reuse
15
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),
16
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=2),
17
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),
18
-
19
- # Configurations with deeper K dimension
20
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
21
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
22
-
23
- # More extreme configurations to test the limits
24
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
25
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=2),
26
-
27
- # Configurations with fewer warps
28
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
29
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=2),
30
-
31
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
32
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
33
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
34
- ],
35
- key=['M', 'N', 'K'],
36
- )
37
- @triton.jit
38
- def fused_ln_dual_matmul_kernel(
39
- # Pointers (9)
40
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
41
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
42
- # Metadata (5)
43
- M, H, K, s1, s2,
44
- # Strides (16)
45
- stride_x_m, stride_x_k,
46
- stride_w4_k, stride_w4_n,
47
- stride_wog_k, stride_wog_n,
48
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
49
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
50
- stride_og_m, stride_og_h,
51
- stride_mask_m, stride_mask_h,
52
- # Constexpr (from decorator and kwargs)
53
- LN_EPS: tl.constexpr,
54
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
55
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
56
- ):
57
- # --- PID Mapping: Based on the LARGER 4*H problem ---
58
- pid = tl.program_id(axis=0)
59
- N_4way = 4 * H
60
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
61
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
62
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
63
- group_id = pid // num_pid_in_group
64
- first_pid_m = group_id * GROUP_SIZE_M
65
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
66
- pid_m = first_pid_m + (pid % group_size_m)
67
- pid_n = (pid % num_pid_in_group) // group_size_m
68
-
69
- # --- SHARED LayerNorm calculation (done only ONCE) ---
70
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
71
- m_mask = offs_m < M
72
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
73
-
74
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
75
- for k_offset in range(0, K, BLOCK_SIZE_K):
76
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
77
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
78
- k_mask = (k_offset + k_chunk_offs) < K
79
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
80
- mean += tl.sum(x_chunk, axis=1)
81
- mean /= K
82
-
83
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
84
- for k_offset in range(0, K, BLOCK_SIZE_K):
85
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
86
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
87
- k_mask = (k_offset + k_chunk_offs) < K
88
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
89
- x_centered = x_chunk - mean[:, None]
90
- var += tl.sum(x_centered * x_centered, axis=1)
91
- var /= K
92
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
93
-
94
- # --- Matmul Loop 1: For the 4-Way Projections ---
95
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
96
- offs_k = tl.arange(0, BLOCK_SIZE_K)
97
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
98
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
99
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
100
-
101
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
102
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
103
- k_block_start = k * BLOCK_SIZE_K;
104
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
105
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
106
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
107
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
108
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
109
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
110
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
111
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
112
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
113
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
114
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
115
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
116
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
117
-
118
- #Some threads should calclate out_gate
119
- if pid_n * BLOCK_SIZE_N < H:
120
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
121
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
122
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
123
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
124
- accumulator_og += tl.dot(x_norm_tile, w_tile)
125
-
126
- if pid_n * BLOCK_SIZE_N < H:
127
- og_out = tl.sigmoid(accumulator_og)
128
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
129
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
130
- tl.store(outg_ptrs, og_out, mask=og_mask)
131
-
132
- # --- Fusion Logic for 4-Way Part ---
133
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
134
- role_idx = tl.arange(0, 4)[None, None, :]
135
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
136
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
137
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
138
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
139
-
140
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
141
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
142
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
143
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
144
-
145
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
146
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
147
-
148
- s1s2 = s1 * s2
149
- offs_b = offs_m // s1s2
150
- offs_s1 = (offs_m % s1s2) // s2
151
- offs_s2 = offs_m % s2
152
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
153
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
154
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
155
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
156
-
157
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
158
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
159
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
160
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1) # s2 offset uses s2 stride, s1 offset uses s1 stride
161
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
162
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
163
-
164
- @triton.autotune(
165
- configs=[
166
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
167
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
168
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
169
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
170
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
171
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
172
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
173
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
174
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
175
- ],
176
- key=['s1', 's2', 'H'],
177
- )
178
- @triton.jit
179
- def bmm_coalesced_kernel(
180
- # Pointers
181
- Left_ptr, Right_ptr, Out_ptr,
182
- # Dimensions
183
- bs, s1, s2, H,
184
- # Strides
185
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
186
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
187
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
188
- # Kernel parameters
189
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
190
- GROUP_SIZE_M: tl.constexpr,
191
- ):
192
- # Grid and program IDs
193
- pid = tl.program_id(axis=0)
194
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
195
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
196
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
197
- group_id = pid // num_pid_in_group
198
- first_pid_m = group_id * GROUP_SIZE_M
199
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
200
- pid_m = first_pid_m + (pid % group_size_m)
201
- pid_n = (pid % num_pid_in_group) // group_size_m
202
-
203
- pid_bh = tl.program_id(axis=1)
204
- pid_b = pid_bh // H
205
- pid_h = pid_bh % H
206
-
207
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
208
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
209
- offs_k = tl.arange(0, BLOCK_SIZE_K)
210
-
211
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
212
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
213
-
214
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
215
-
216
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
217
- k_start = k * BLOCK_SIZE_K
218
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
219
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
220
-
221
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
222
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
223
-
224
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
225
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
226
-
227
- accumulator += tl.dot(a, b)
228
-
229
- # --- Coalesced Write ---
230
- # Write to a standard (bs, H, s1, s1) layout
231
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
232
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
233
-
234
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
235
- tl.store(out_ptrs, accumulator, mask=c_mask)
236
-
237
- @triton.autotune(
238
- configs=[
239
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
240
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
241
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
242
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
243
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
244
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
245
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
246
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
247
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
248
- ],
249
- key=['H', 'D'],
250
- )
251
- @triton.jit
252
- def fused_final_kernel(
253
- # Pointers
254
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
255
- # Metadata
256
- M, H, D, s1, # M_gate = bs*s1*s2
257
- # Strides
258
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
259
- stride_gate_m, stride_gate_h,
260
- stride_proj_d, stride_proj_h,
261
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
262
- # Constants
263
- LN_EPS: tl.constexpr,
264
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
265
- GROUP_SIZE_M: tl.constexpr,
266
- ):
267
- # --- Grid and PID Setup for Matmul ---
268
- pid = tl.program_id(axis=0)
269
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
270
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
271
-
272
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
273
- group_id = pid // num_pid_in_group
274
- first_pid_m = group_id * GROUP_SIZE_M
275
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
276
- pid_m = first_pid_m + (pid % group_size_m)
277
- pid_n = (pid % num_pid_in_group) // group_size_m
278
-
279
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
280
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
281
- m_mask = offs_m < M
282
-
283
- # Decompose M back to (b, r, c) for reordering lookups
284
- s1s1 = s1 * s1
285
- b = offs_m // s1s1
286
- r = (offs_m % s1s1) // s1
287
- c = offs_m % s1
288
-
289
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
290
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
291
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
292
-
293
- for k_offset in range(0, H, BLOCK_SIZE_K):
294
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
295
- k_mask = offs_k < H
296
-
297
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
298
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
299
-
300
- # Accumulate sum and sum of squares in one pass
301
- sum_x += tl.sum(in_chunk, axis=1)
302
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
303
-
304
- # Finalize statistics
305
- mean = sum_x / H
306
- var = (sum_x2 / H) - (mean * mean)
307
- rstd = tl.math.rsqrt(var + LN_EPS)
308
-
309
- # --- Pass 3: Fused Gating and Matmul ---
310
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
311
- for k_offset in range(0, H, BLOCK_SIZE_K):
312
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
313
- k_mask = offs_k < H
314
-
315
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
316
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
317
- a_norm = (a - mean[:, None]) * rstd[:, None]
318
-
319
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
320
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
321
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
322
-
323
- proj_ptrs = ProjW_ptr + \
324
- offs_n[None, :] * stride_proj_d + \
325
- offs_k[:, None] * stride_proj_h
326
-
327
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
328
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
329
- a_gated = a_norm * gate
330
-
331
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
332
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
333
-
334
- # --- Store Final Output ---
335
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
336
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
337
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
338
-
339
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
340
-
341
- def compiledtrimul_fused_interleaved(
342
- x: torch.Tensor,
343
- mask_mh: torch.Tensor,
344
- norm_weight: torch.Tensor,
345
- norm_bias: torch.Tensor,
346
- W_4way: torch.Tensor, # Use the new weight matrices
347
- W_og: torch.Tensor,
348
- to_out_norm_weight: torch.Tensor,
349
- to_out_norm_bias: torch.Tensor,
350
- to_out_weight: torch.Tensor,
351
- h: int,
352
- ):
353
- bs, s1, s2, d = x.shape
354
- M, K, H = bs * s1 * s2, x.shape[-1], h
355
- x_flat = x.view(M, K)
356
-
357
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
358
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
359
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
360
-
361
- # The grid is launched for the larger 4*H problem
362
- N_4way = 4 * H
363
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
364
- fused_ln_dual_matmul_kernel[grid](
365
- # Pointers (9)
366
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
367
- left_final, right_final_t, og_mh,
368
- # Metadata (5) - M, H, K, s1, s2
369
- M, H, K, s1, s2,
370
- # Strides (16)
371
- x_flat.stride(0), x_flat.stride(1),
372
- W_4way.stride(0), W_4way.stride(1),
373
- W_og.stride(0), W_og.stride(1),
374
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
375
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
376
- og_mh.stride(0), og_mh.stride(1),
377
- mask_mh.stride(0), mask_mh.stride(1),
378
- # Constexpr (1)
379
- LN_EPS=1e-5
380
- )
381
-
382
- bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
383
-
384
- grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
385
- bmm_coalesced_kernel[grid_bmm](
386
- left_final, right_final_t, bmm_out_tmp,
387
- bs, s1, s2, H,
388
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
389
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
390
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
391
- )
392
-
393
- # --- Kernel 3: Fully Fused Final Stage ---
394
- final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
395
-
396
- grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
397
- fused_final_kernel[grid_final](
398
- # Pointers
399
- bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
400
- # Metadata
401
- M, H, d, s1,
402
- # Strides
403
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
404
- og_mh.stride(0), og_mh.stride(1),
405
- to_out_weight.stride(0), to_out_weight.stride(1), # Use strides of the corrected tensor
406
- final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
407
- # Constants
408
- LN_EPS=1e-5,
409
- )
410
-
411
- return final_out
412
-
413
- def pack_w_4way_efficient(weights):
414
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
415
- WL = weights['left_proj.weight']
416
- WLG = weights['left_gate.weight']
417
- WR = weights['right_proj.weight']
418
- WRG = weights['right_gate.weight']
419
- H, K = WL.shape
420
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2)
421
- ws = ws.contiguous().view(4 * H, K)
422
- return ws.t().to(torch.float16)
423
-
424
- def get_w_og(weights):
425
- """ Gets the transposed [K, H] out_gate weight matrix. """
426
- WOG = weights['out_gate.weight']
427
- return WOG.t().to(torch.float16)
428
-
429
- def compiledtrimul(
430
- x: torch.Tensor,
431
- mask: torch.Tensor,
432
- norm_weight: torch.Tensor,
433
- norm_bias: torch.Tensor,
434
- w_concat: torch.Tensor,
435
- to_out_norm_weight: torch.Tensor,
436
- to_out_norm_bias: torch.Tensor,
437
- to_out_weight: torch.Tensor,
438
- h: int
439
- ) -> torch.Tensor:
440
- """
441
- A barebones, compiled PyTorch function for the TriMul logic.
442
- """
443
- bs, s1, s2, d = x.shape
444
-
445
- # Initial LayerNorm
446
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
447
- # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
448
- all_projections = torch.mm(x_norm, w_concat)
449
-
450
- # Split back into individual projections
451
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
452
-
453
- # Apply mask and gates
454
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
455
- left = left * mask_expanded * torch.sigmoid(lg)
456
- right = right * mask_expanded * torch.sigmoid(rg)
457
- out_gate = torch.sigmoid(og)
458
-
459
- # Reshape for einsum
460
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
461
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
462
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
463
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
464
-
465
- # Apply layer norm and final gating
466
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
467
- gated = normed * out_gate
468
-
469
- # Final projection
470
- final_out_flat = gated @ to_out_weight.t()
471
- final_out = final_out_flat.view(bs, s1, s2, d)
472
-
473
- return final_out
474
-
475
- def small_kernel_pt_path(data):
476
- input_tensor, mask, weights, config = data
477
- w_concat = torch.cat([
478
- weights['left_proj.weight'],
479
- weights['right_proj.weight'],
480
- weights['left_gate.weight'],
481
- weights['right_gate.weight'],
482
- weights['out_gate.weight']
483
- ], dim=0).t().contiguous().to(torch.float16)
484
- # Call the compiled function with prepared weights
485
- output = compiledtrimul(
486
- x=input_tensor.to(torch.float32),
487
- mask=mask.unsqueeze(-1),
488
- norm_weight=weights['norm.weight'].to(torch.float32),
489
- norm_bias=weights['norm.bias'].to(torch.float32),
490
- w_concat=w_concat,
491
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
492
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
493
- to_out_weight=weights['to_out.weight'].to(torch.float16),
494
- h=config["hidden_dim"]
495
- )
496
- return output
497
-
498
- def kernel_mi300(data):
499
- input_tensor, mask, weights, config = data
500
- bs, s1, s2, d = input_tensor.shape
501
-
502
- if s1 < 100:
503
- return small_kernel_pt_path(data)
504
-
505
- H = config["hidden_dim"]
506
-
507
- W_4way = pack_w_4way_efficient(weights)
508
- W_og = get_w_og(weights)
509
-
510
- M = bs * s1 * s2
511
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16) #move into kernel possibly
512
-
513
- return compiledtrimul_fused_interleaved(
514
- x=input_tensor.to(torch.float32),
515
- mask_mh=mask_mh,
516
- norm_weight=weights['norm.weight'].to(torch.float32),
517
- norm_bias=weights['norm.bias'].to(torch.float32),
518
- W_4way=W_4way, # Pass the new 4-way matrix
519
- W_og=W_og, # Pass the new out_gate matrix
520
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
521
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
522
- to_out_weight=weights['to_out.weight'].to(torch.float16),
523
- h=H,
524
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-rocm/triton_a100.py DELETED
@@ -1,405 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- # Set PyTorch flags for performance
7
- torch.backends.cuda.matmul.allow_tf32 = True
8
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
9
-
10
- @triton.jit
11
- def fused_ln_dual_matmul_kernel(
12
- # Pointers (9)
13
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
14
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
15
- # Metadata (5)
16
- M, H, K, s1, s2,
17
- # Strides (16)
18
- stride_x_m, stride_x_k,
19
- stride_w4_k, stride_w4_n,
20
- stride_wog_k, stride_wog_n,
21
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
22
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
23
- stride_og_m, stride_og_h,
24
- stride_mask_m, stride_mask_h,
25
- # Constexpr (now passed as arguments from the host)
26
- LN_EPS: tl.constexpr,
27
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
28
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
29
- ):
30
- # --- PID Mapping: Based on the LARGER 4*H problem ---
31
- pid = tl.program_id(axis=0)
32
- N_4way = 4 * H
33
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
34
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
35
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
36
- group_id = pid // num_pid_in_group
37
- first_pid_m = group_id * GROUP_SIZE_M
38
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
39
- pid_m = first_pid_m + (pid % group_size_m)
40
- pid_n = (pid % num_pid_in_group) // group_size_m
41
-
42
- # --- SHARED LayerNorm calculation (done only ONCE) ---
43
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
44
- m_mask = offs_m < M
45
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
46
-
47
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
48
- for k_offset in range(0, K, BLOCK_SIZE_K):
49
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
50
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
51
- k_mask = (k_offset + k_chunk_offs) < K
52
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
53
- mean += tl.sum(x_chunk, axis=1)
54
- mean /= K
55
-
56
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
57
- for k_offset in range(0, K, BLOCK_SIZE_K):
58
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
59
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
60
- k_mask = (k_offset + k_chunk_offs) < K
61
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
62
- x_centered = x_chunk - mean[:, None]
63
- var += tl.sum(x_centered * x_centered, axis=1)
64
- var /= K
65
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
66
-
67
- # --- Matmul Loop 1: For the 4-Way Projections ---
68
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
69
- offs_k = tl.arange(0, BLOCK_SIZE_K)
70
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
71
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
72
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
73
-
74
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
75
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
76
- k_block_start = k * BLOCK_SIZE_K;
77
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
78
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
79
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
80
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
81
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
82
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
83
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
84
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
85
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
86
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
87
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
88
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
89
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
90
-
91
- if pid_n * BLOCK_SIZE_N < H:
92
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
93
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
94
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
95
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
96
- accumulator_og += tl.dot(x_norm_tile, w_tile)
97
-
98
- if pid_n * BLOCK_SIZE_N < H:
99
- og_out = tl.sigmoid(accumulator_og)
100
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
101
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
102
- tl.store(outg_ptrs, og_out, mask=og_mask)
103
-
104
- # --- Fusion Logic for 4-Way Part ---
105
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
106
- role_idx = tl.arange(0, 4)[None, None, :]
107
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
108
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
109
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
110
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
111
-
112
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
113
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
114
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
115
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
116
-
117
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
118
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
119
-
120
- s1s2 = s1 * s2
121
- offs_b = offs_m // s1s2
122
- offs_s1 = (offs_m % s1s2) // s2
123
- offs_s2 = offs_m % s2
124
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
125
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
126
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
127
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
128
-
129
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
130
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
131
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
132
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
133
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
134
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
-
136
- @triton.jit
137
- def bmm_coalesced_kernel(
138
- # Pointers
139
- Left_ptr, Right_ptr, Out_ptr,
140
- # Dimensions
141
- bs, s1, s2, H,
142
- # Strides
143
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
144
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
145
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
146
- # Kernel parameters
147
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
148
- GROUP_SIZE_M: tl.constexpr,
149
- ):
150
- pid = tl.program_id(axis=0)
151
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
152
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
153
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
154
- group_id = pid // num_pid_in_group
155
- first_pid_m = group_id * GROUP_SIZE_M
156
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
157
- pid_m = first_pid_m + (pid % group_size_m)
158
- pid_n = (pid % num_pid_in_group) // group_size_m
159
-
160
- pid_bh = tl.program_id(axis=1)
161
- pid_b = pid_bh // H
162
- pid_h = pid_bh % H
163
-
164
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
165
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
166
- offs_k = tl.arange(0, BLOCK_SIZE_K)
167
-
168
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
169
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
170
-
171
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
172
-
173
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
174
- k_start = k * BLOCK_SIZE_K
175
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
176
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
177
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
178
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
179
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
180
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
181
- accumulator += tl.dot(a, b)
182
-
183
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
184
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
185
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
186
- tl.store(out_ptrs, accumulator, mask=c_mask)
187
-
188
- @triton.jit
189
- def fused_final_kernel(
190
- # Pointers
191
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
192
- # Metadata
193
- M, H, D, s1,
194
- # Strides
195
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
196
- stride_gate_m, stride_gate_h,
197
- stride_proj_d, stride_proj_h,
198
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
199
- # Constants
200
- LN_EPS: tl.constexpr,
201
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
202
- GROUP_SIZE_M: tl.constexpr,
203
- ):
204
- pid = tl.program_id(axis=0)
205
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
206
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
207
-
208
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
209
- group_id = pid // num_pid_in_group
210
- first_pid_m = group_id * GROUP_SIZE_M
211
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
212
- pid_m = first_pid_m + (pid % group_size_m)
213
- pid_n = (pid % num_pid_in_group) // group_size_m
214
-
215
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
216
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
217
- m_mask = offs_m < M
218
-
219
- s1s1 = s1 * s1
220
- b = offs_m // s1s1
221
- r = (offs_m % s1s1) // s1
222
- c = offs_m % s1
223
-
224
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
225
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
226
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
227
-
228
- for k_offset in range(0, H, BLOCK_SIZE_K):
229
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
230
- k_mask = offs_k < H
231
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
232
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
233
- sum_x += tl.sum(in_chunk, axis=1)
234
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
235
-
236
- mean = sum_x / H
237
- var = (sum_x2 / H) - (mean * mean)
238
- rstd = tl.math.rsqrt(var + LN_EPS)
239
-
240
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
241
- for k_offset in range(0, H, BLOCK_SIZE_K):
242
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
243
- k_mask = offs_k < H
244
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
245
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
246
- a_norm = (a - mean[:, None]) * rstd[:, None]
247
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
248
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
249
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
250
- proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
251
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
252
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
253
- a_gated = a_norm * gate
254
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
255
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
256
-
257
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
258
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
259
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
260
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
261
-
262
- def compiledtrimul_fused_interleaved_final(
263
- x: torch.Tensor,
264
- mask_mh: torch.Tensor,
265
- norm_weight: torch.Tensor,
266
- norm_bias: torch.Tensor,
267
- W_4way: torch.Tensor,
268
- W_og: torch.Tensor,
269
- to_out_norm_weight: torch.Tensor,
270
- to_out_norm_bias: torch.Tensor,
271
- to_out_weight: torch.Tensor,
272
- h: int,
273
- ):
274
- bs, s1, s2, d = x.shape
275
- M, K, H = bs * s1 * s2, x.shape[-1], h
276
- x_flat = x.view(M, K)
277
-
278
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
279
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
280
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
281
-
282
- # --- Kernel 1: Fused LN + Dual Matmul ---
283
- N_4way = 4 * H
284
- # Hardcoded A100 best config: M128-N128-K32-GM8-HC32-W8-S2
285
- config_k1 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
286
- grid_k1 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
287
-
288
- fused_ln_dual_matmul_kernel[grid_k1](
289
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
290
- left_final, right_final_t, og_mh,
291
- M, H, K, s1, s2,
292
- x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
293
- W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
294
- left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
295
- right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
296
- mask_mh.stride(0), mask_mh.stride(1),
297
- LN_EPS=1e-5, **config_k1, num_warps=8, num_stages=2
298
- )
299
-
300
- # --- Kernel 2: Batched Matrix Multiplication ---
301
- bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
302
- # Hardcoded A100 best config: M128-N64-K32-GM8-W4-S3
303
- config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
304
- grid_k2 = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
305
-
306
- bmm_coalesced_kernel[grid_k2](
307
- left_final, right_final_t, bmm_out_tmp,
308
- bs, s1, s2, H,
309
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
310
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
311
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
312
- **config_k2, num_warps=4, num_stages=3
313
- )
314
-
315
- # --- Kernel 3: Fully Fused Final Stage ---
316
- final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
317
- # Hardcoded A100 best config: M32-N128-K32-GM8-W4-S3
318
- config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
319
- grid_k3 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
320
-
321
- fused_final_kernel[grid_k3](
322
- bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
323
- M, H, d, s1,
324
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
325
- og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
326
- final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
327
- LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
328
- )
329
- return final_out
330
-
331
- def pack_w_4way_efficient(weights):
332
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
333
- WL, WLG, WR, WRG = (weights[k] for k in ['left_proj.weight', 'left_gate.weight', 'right_proj.weight', 'right_gate.weight'])
334
- H, K = WL.shape
335
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2).contiguous().view(4 * H, K)
336
- return ws.t().to(torch.float16)
337
-
338
- def get_w_og(weights):
339
- """ Gets the transposed [K, H] out_gate weight matrix. """
340
- return weights['out_gate.weight'].t().to(torch.float16)
341
-
342
- @torch.compile()
343
- def compiledtrimul(
344
- x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
345
- w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
346
- to_out_weight: torch.Tensor, h: int
347
- ) -> torch.Tensor:
348
- bs, s1, s2, d = x.shape
349
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
350
- all_projections = torch.mm(x_norm, w_concat)
351
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
352
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
353
- left = left * mask_expanded * torch.sigmoid(lg)
354
- right = right * mask_expanded * torch.sigmoid(rg)
355
- out_gate = torch.sigmoid(og)
356
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
357
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
358
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
359
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
360
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
361
- gated = normed * out_gate
362
- final_out_flat = gated @ to_out_weight.t()
363
- return final_out_flat.view(bs, s1, s1, d)
364
-
365
- def small_kernel_pt_path(data):
366
- input_tensor, mask, weights, config = data
367
- w_concat = torch.cat([
368
- weights['left_proj.weight'], weights['right_proj.weight'], weights['left_gate.weight'],
369
- weights['right_gate.weight'], weights['out_gate.weight']
370
- ], dim=0).t().contiguous().to(torch.float16)
371
- return compiledtrimul(
372
- x=input_tensor.to(torch.float32), mask=mask.unsqueeze(-1),
373
- norm_weight=weights['norm.weight'].to(torch.float32),
374
- norm_bias=weights['norm.bias'].to(torch.float32), w_concat=w_concat,
375
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
376
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
377
- to_out_weight=weights['to_out.weight'].to(torch.float16),
378
- h=config["hidden_dim"]
379
- )
380
-
381
- def kernel_a100(data):
382
- input_tensor, mask, weights, config = data
383
- bs, s1, s2, d = input_tensor.shape
384
-
385
- if s1 < 512: # Adjusted threshold based on observed BMM configs
386
- return small_kernel_pt_path(data)
387
-
388
- H = config["hidden_dim"]
389
- W_4way = pack_w_4way_efficient(weights)
390
- W_og = get_w_og(weights)
391
- M = bs * s1 * s2
392
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16)
393
-
394
- return compiledtrimul_fused_interleaved_final(
395
- x=input_tensor.to(torch.float32),
396
- mask_mh=mask_mh,
397
- norm_weight=weights['norm.weight'].to(torch.float32),
398
- norm_bias=weights['norm.bias'].to(torch.float32),
399
- W_4way=W_4way,
400
- W_og=W_og,
401
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
402
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
403
- to_out_weight=weights['to_out.weight'].to(torch.float16),
404
- h=H,
405
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-rocm/triton_b200.py DELETED
@@ -1,411 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
-
9
- @triton.jit
10
- def fused_ln_dual_matmul_kernel(
11
- # Pointers (9)
12
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
13
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
14
- # Metadata (5)
15
- M, H, K, s1, s2,
16
- # Strides (16)
17
- stride_x_m, stride_x_k,
18
- stride_w4_k, stride_w4_n,
19
- stride_wog_k, stride_wog_n,
20
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
21
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
22
- stride_og_m, stride_og_h,
23
- stride_mask_m, stride_mask_h,
24
- # Constexpr (now passed as arguments from the host)
25
- LN_EPS: tl.constexpr,
26
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
27
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
28
- ):
29
- # --- PID Mapping: Based on the LARGER 4*H problem ---
30
- pid = tl.program_id(axis=0)
31
- N_4way = 4 * H
32
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
33
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
34
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
35
- group_id = pid // num_pid_in_group
36
- first_pid_m = group_id * GROUP_SIZE_M
37
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
38
- pid_m = first_pid_m + (pid % group_size_m)
39
- pid_n = (pid % num_pid_in_group) // group_size_m
40
-
41
- # --- SHARED LayerNorm calculation (done only ONCE) ---
42
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
43
- m_mask = offs_m < M
44
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
45
-
46
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
47
- for k_offset in range(0, K, BLOCK_SIZE_K):
48
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
49
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
50
- k_mask = (k_offset + k_chunk_offs) < K
51
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
52
- mean += tl.sum(x_chunk, axis=1)
53
- mean /= K
54
-
55
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
56
- for k_offset in range(0, K, BLOCK_SIZE_K):
57
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
58
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
59
- k_mask = (k_offset + k_chunk_offs) < K
60
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
61
- x_centered = x_chunk - mean[:, None]
62
- var += tl.sum(x_centered * x_centered, axis=1)
63
- var /= K
64
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
65
-
66
- # --- Matmul Loop 1: For the 4-Way Projections ---
67
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
68
- offs_k = tl.arange(0, BLOCK_SIZE_K)
69
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
70
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
71
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
72
-
73
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
74
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
75
- k_block_start = k * BLOCK_SIZE_K;
76
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
77
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
78
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
79
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
80
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
81
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
82
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
83
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
84
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
85
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
86
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
87
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
88
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
89
-
90
- #Some threads should calclate out_gate
91
- if pid_n * BLOCK_SIZE_N < H:
92
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
93
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
94
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
95
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
96
- accumulator_og += tl.dot(x_norm_tile, w_tile)
97
-
98
- if pid_n * BLOCK_SIZE_N < H:
99
- og_out = tl.sigmoid(accumulator_og)
100
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
101
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
102
- tl.store(outg_ptrs, og_out, mask=og_mask)
103
-
104
- # --- Fusion Logic for 4-Way Part ---
105
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
106
- role_idx = tl.arange(0, 4)[None, None, :]
107
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
108
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
109
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
110
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
111
-
112
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
113
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
114
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
115
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
116
-
117
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
118
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
119
-
120
- s1s2 = s1 * s2
121
- offs_b = offs_m // s1s2
122
- offs_s1 = (offs_m % s1s2) // s2
123
- offs_s2 = offs_m % s2
124
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
125
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
126
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
127
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
128
-
129
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
130
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
131
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
132
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
133
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
134
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
-
136
- @triton.jit
137
- def bmm_coalesced_kernel(
138
- # Pointers
139
- Left_ptr, Right_ptr, Out_ptr,
140
- # Dimensions
141
- bs, s1, s2, H,
142
- # Strides
143
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
144
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
145
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
146
- # Kernel parameters
147
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
148
- GROUP_SIZE_M: tl.constexpr,
149
- ):
150
- # Grid and program IDs
151
- pid = tl.program_id(axis=0)
152
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
153
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
154
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
155
- group_id = pid // num_pid_in_group
156
- first_pid_m = group_id * GROUP_SIZE_M
157
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
158
- pid_m = first_pid_m + (pid % group_size_m)
159
- pid_n = (pid % num_pid_in_group) // group_size_m
160
-
161
- pid_bh = tl.program_id(axis=1)
162
- pid_b = pid_bh // H
163
- pid_h = pid_bh % H
164
-
165
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
166
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
167
- offs_k = tl.arange(0, BLOCK_SIZE_K)
168
-
169
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
170
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
171
-
172
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
173
-
174
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
175
- k_start = k * BLOCK_SIZE_K
176
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
177
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
178
-
179
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
180
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
181
-
182
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
183
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
184
-
185
- accumulator += tl.dot(a, b)
186
-
187
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
188
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
189
-
190
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
191
- tl.store(out_ptrs, accumulator, mask=c_mask)
192
-
193
- @triton.jit
194
- def fused_final_kernel(
195
- # Pointers
196
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
197
- # Metadata
198
- M, H, D, s1,
199
- # Strides
200
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
201
- stride_gate_m, stride_gate_h,
202
- stride_proj_d, stride_proj_h,
203
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
204
- # Constants
205
- LN_EPS: tl.constexpr,
206
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
207
- GROUP_SIZE_M: tl.constexpr,
208
- ):
209
- pid = tl.program_id(axis=0)
210
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
211
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
212
-
213
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
214
- group_id = pid // num_pid_in_group
215
- first_pid_m = group_id * GROUP_SIZE_M
216
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
217
- pid_m = first_pid_m + (pid % group_size_m)
218
- pid_n = (pid % num_pid_in_group) // group_size_m
219
-
220
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
221
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
222
- m_mask = offs_m < M
223
-
224
- s1s1 = s1 * s1
225
- b = offs_m // s1s1
226
- r = (offs_m % s1s1) // s1
227
- c = offs_m % s1
228
-
229
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
230
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
231
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
232
-
233
- for k_offset in range(0, H, BLOCK_SIZE_K):
234
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
235
- k_mask = offs_k < H
236
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
237
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
238
- sum_x += tl.sum(in_chunk, axis=1)
239
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
240
-
241
- mean = sum_x / H
242
- var = (sum_x2 / H) - (mean * mean)
243
- rstd = tl.math.rsqrt(var + LN_EPS)
244
-
245
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
246
- for k_offset in range(0, H, BLOCK_SIZE_K):
247
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
248
- k_mask = offs_k < H
249
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
250
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
251
- a_norm = (a - mean[:, None]) * rstd[:, None]
252
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
253
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
254
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
255
- proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
256
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
257
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
258
- a_gated = a_norm * gate
259
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
260
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
261
-
262
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
263
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
264
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
265
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
266
-
267
- def compiledtrimul_fused_interleaved_final(
268
- x: torch.Tensor,
269
- mask_mh: torch.Tensor,
270
- norm_weight: torch.Tensor,
271
- norm_bias: torch.Tensor,
272
- W_4way: torch.Tensor,
273
- W_og: torch.Tensor,
274
- to_out_norm_weight: torch.Tensor,
275
- to_out_norm_bias: torch.Tensor,
276
- to_out_weight: torch.Tensor,
277
- h: int,
278
- ):
279
- bs, s1, s2, d = x.shape
280
- M, K, H = bs * s1 * s2, x.shape[-1], h
281
- x_flat = x.view(M, K)
282
-
283
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
284
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
285
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
286
-
287
- # --- Kernel 1: Fused LN + Dual Matmul ---
288
- # The grid is launched for the larger 4*H problem
289
- N_4way = 4 * H
290
- # Hardcoded best config from logs: M64-N128-K64-GM8-HC32-W4-S2
291
- config_k1 = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
292
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
293
-
294
- fused_ln_dual_matmul_kernel[grid](
295
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
296
- left_final, right_final_t, og_mh,
297
- M, H, K, s1, s2,
298
- x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
299
- W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
300
- left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
301
- right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
302
- mask_mh.stride(0), mask_mh.stride(1),
303
- LN_EPS=1e-5, **config_k1, num_warps=4, num_stages=2
304
- )
305
-
306
- # --- Kernel 2: Batched Matrix Multiplication ---
307
- bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
308
- # Hardcoded best config from logs: M128-N128-K32-GM8-W8-S3
309
- config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
310
- grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
311
-
312
- bmm_coalesced_kernel[grid_bmm](
313
- left_final, right_final_t, bmm_out_tmp,
314
- bs, s1, s2, H,
315
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
316
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
317
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
318
- **config_k2, num_warps=8, num_stages=3
319
- )
320
-
321
- # --- Kernel 3: Fully Fused Final Stage ---
322
- final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
323
- # Hardcoded best config from logs: M32-N128-K32-GM8-W4-S3
324
- config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
325
- grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
326
-
327
- fused_final_kernel[grid_final](
328
- bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
329
- M, H, d, s1,
330
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
331
- og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
332
- final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
333
- LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
334
- )
335
- return final_out
336
-
337
- def pack_w_4way_efficient(weights):
338
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
339
- WL, WLG, WR, WRG = (weights[k] for k in ['left_proj.weight', 'left_gate.weight', 'right_proj.weight', 'right_gate.weight'])
340
- H, K = WL.shape
341
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2).contiguous().view(4 * H, K)
342
- return ws.t().to(torch.float16)
343
-
344
- def get_w_og(weights):
345
- """ Gets the transposed [K, H] out_gate weight matrix. """
346
- return weights['out_gate.weight'].t().to(torch.float16)
347
-
348
- @torch.compile()
349
- def compiledtrimul(
350
- x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
351
- w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
352
- to_out_weight: torch.Tensor, h: int
353
- ) -> torch.Tensor:
354
- bs, s1, s2, d = x.shape
355
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
356
- all_projections = torch.mm(x_norm, w_concat)
357
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
358
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
359
- left = left * mask_expanded * torch.sigmoid(lg)
360
- right = right * mask_expanded * torch.sigmoid(rg)
361
- out_gate = torch.sigmoid(og)
362
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
363
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
364
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
365
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
366
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
367
- gated = normed * out_gate
368
- final_out_flat = gated @ to_out_weight.t()
369
- return final_out_flat.view(bs, s1, s1, d)
370
-
371
- def small_kernel_pt_path(data):
372
- input_tensor, mask, weights, config = data
373
- w_concat = torch.cat([
374
- weights['left_proj.weight'], weights['right_proj.weight'], weights['left_gate.weight'],
375
- weights['right_gate.weight'], weights['out_gate.weight']
376
- ], dim=0).t().contiguous().to(torch.float16)
377
- return compiledtrimul(
378
- x=input_tensor.to(torch.float32), mask=mask.unsqueeze(-1),
379
- norm_weight=weights['norm.weight'].to(torch.float32),
380
- norm_bias=weights['norm.bias'].to(torch.float32), w_concat=w_concat,
381
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
382
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
383
- to_out_weight=weights['to_out.weight'].to(torch.float16),
384
- h=config["hidden_dim"]
385
- )
386
-
387
- def kernel_b200(data):
388
- input_tensor, mask, weights, config = data
389
- bs, s1, s2, d = input_tensor.shape
390
-
391
- if s1 < 800:
392
- return small_kernel_pt_path(data)
393
-
394
- H = config["hidden_dim"]
395
- W_4way = pack_w_4way_efficient(weights)
396
- W_og = get_w_og(weights)
397
- M = bs * s1 * s2
398
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16)
399
-
400
- return compiledtrimul_fused_interleaved_final(
401
- x=input_tensor.to(torch.float32),
402
- mask_mh=mask_mh,
403
- norm_weight=weights['norm.weight'].to(torch.float32),
404
- norm_bias=weights['norm.bias'].to(torch.float32),
405
- W_4way=W_4way,
406
- W_og=W_og,
407
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
408
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
409
- to_out_weight=weights['to_out.weight'].to(torch.float16),
410
- h=H,
411
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-rocm/triton_h100.py DELETED
@@ -1,509 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
-
9
- @triton.autotune(
10
- configs=[
11
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
12
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),
13
-
14
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
15
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
16
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
17
-
18
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=4),
19
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),
20
-
21
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=5),
22
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=5),
23
-
24
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
25
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=4),
26
- ],
27
- key=['M', 'N', 'K'],
28
- )
29
- @triton.jit
30
- def fused_ln_dual_matmul_kernel(
31
- # Pointers (9)
32
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
33
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
34
- # Metadata (5)
35
- M, H, K, s1, s2,
36
- # Strides (16)
37
- stride_x_m, stride_x_k,
38
- stride_w4_k, stride_w4_n,
39
- stride_wog_k, stride_wog_n,
40
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
41
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
42
- stride_og_m, stride_og_h,
43
- stride_mask_m, stride_mask_h,
44
- # Constexpr (from decorator and kwargs)
45
- LN_EPS: tl.constexpr,
46
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
47
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
48
- ):
49
- # --- PID Mapping: Based on the LARGER 4*H problem ---
50
- pid = tl.program_id(axis=0)
51
- N_4way = 4 * H
52
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
53
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
54
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
55
- group_id = pid // num_pid_in_group
56
- first_pid_m = group_id * GROUP_SIZE_M
57
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
58
- pid_m = first_pid_m + (pid % group_size_m)
59
- pid_n = (pid % num_pid_in_group) // group_size_m
60
-
61
- # --- SHARED LayerNorm calculation (done only ONCE) ---
62
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
63
- m_mask = offs_m < M
64
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
65
-
66
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
67
- for k_offset in range(0, K, BLOCK_SIZE_K):
68
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
69
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
70
- k_mask = (k_offset + k_chunk_offs) < K
71
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
72
- mean += tl.sum(x_chunk, axis=1)
73
- mean /= K
74
-
75
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
76
- for k_offset in range(0, K, BLOCK_SIZE_K):
77
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
78
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
79
- k_mask = (k_offset + k_chunk_offs) < K
80
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
81
- x_centered = x_chunk - mean[:, None]
82
- var += tl.sum(x_centered * x_centered, axis=1)
83
- var /= K
84
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
85
-
86
- # --- Matmul Loop 1: For the 4-Way Projections ---
87
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
88
- offs_k = tl.arange(0, BLOCK_SIZE_K)
89
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
90
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
91
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
92
-
93
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
94
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
95
- k_block_start = k * BLOCK_SIZE_K;
96
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
97
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
98
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
99
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
100
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
101
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
102
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
103
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
104
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
105
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
106
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
107
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
108
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
109
-
110
- #Some threads should calclate out_gate
111
- if pid_n * BLOCK_SIZE_N < H:
112
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
113
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
114
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
115
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
116
- accumulator_og += tl.dot(x_norm_tile, w_tile)
117
-
118
- if pid_n * BLOCK_SIZE_N < H:
119
- og_out = tl.sigmoid(accumulator_og)
120
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
121
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
122
- tl.store(outg_ptrs, og_out, mask=og_mask)
123
-
124
- # --- Fusion Logic for 4-Way Part ---
125
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
126
- role_idx = tl.arange(0, 4)[None, None, :]
127
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
128
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
129
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
130
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
131
-
132
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
133
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
134
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
135
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
136
-
137
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
138
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
139
-
140
- s1s2 = s1 * s2
141
- offs_b = offs_m // s1s2
142
- offs_s1 = (offs_m % s1s2) // s2
143
- offs_s2 = offs_m % s2
144
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
145
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
146
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
147
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
148
-
149
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
150
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
151
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
152
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1) # s2 offset uses s2 stride, s1 offset uses s1 stride
153
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
154
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
155
-
156
- @triton.autotune(
157
- configs=[
158
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
159
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
160
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
161
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
162
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
163
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
164
- ],
165
- key=['s1', 's2', 'H'],
166
- )
167
- @triton.jit
168
- def bmm_coalesced_kernel(
169
- # Pointers
170
- Left_ptr, Right_ptr, Out_ptr,
171
- # Dimensions
172
- bs, s1, s2, H,
173
- # Strides
174
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
175
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
176
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
177
- # Kernel parameters
178
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
179
- GROUP_SIZE_M: tl.constexpr,
180
- ):
181
- # Grid and program IDs
182
- pid = tl.program_id(axis=0)
183
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
184
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
185
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
186
- group_id = pid // num_pid_in_group
187
- first_pid_m = group_id * GROUP_SIZE_M
188
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
189
- pid_m = first_pid_m + (pid % group_size_m)
190
- pid_n = (pid % num_pid_in_group) // group_size_m
191
-
192
- pid_bh = tl.program_id(axis=1)
193
- pid_b = pid_bh // H
194
- pid_h = pid_bh % H
195
-
196
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
197
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
198
- offs_k = tl.arange(0, BLOCK_SIZE_K)
199
-
200
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
201
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
202
-
203
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
204
-
205
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
206
- k_start = k * BLOCK_SIZE_K
207
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
208
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
209
-
210
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
211
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
212
-
213
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
214
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
215
-
216
- accumulator += tl.dot(a, b)
217
-
218
- # --- Coalesced Write ---
219
- # Write to a standard (bs, H, s1, s1) layout
220
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
221
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
222
-
223
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
224
- tl.store(out_ptrs, accumulator, mask=c_mask)
225
-
226
- @torch.compile
227
- def torch_pt2(left_final, right_final_t, bs, s1, s2, d, h, to_out_norm_weight, to_out_norm_bias, og_mh, to_out_weight):
228
- bmm_out = torch.matmul(left_final, right_final_t)
229
- out_einsum_flat = bmm_out.permute(0, 2, 3, 1).reshape(bs * s1 * s1, h)
230
- # Apply layer norm and final gating
231
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
232
- gated = normed * og_mh
233
-
234
- # Final projection
235
- final_out_flat = gated @ to_out_weight.t()
236
- final_out = final_out_flat.view(bs, s1, s2, d)
237
- return final_out
238
-
239
- @triton.autotune(
240
- configs=[
241
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
242
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
243
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
244
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
245
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
246
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
247
- ],
248
- key=['H', 'D'],
249
- )
250
- @triton.jit
251
- def fused_final_kernel(
252
- # Pointers
253
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
254
- # Metadata
255
- M, H, D, s1, # M_gate = bs*s1*s2
256
- # Strides
257
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
258
- stride_gate_m, stride_gate_h,
259
- stride_proj_d, stride_proj_h,
260
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
261
- # Constants
262
- LN_EPS: tl.constexpr,
263
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
264
- GROUP_SIZE_M: tl.constexpr,
265
- ):
266
- # --- Grid and PID Setup for Matmul ---
267
- pid = tl.program_id(axis=0)
268
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
269
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
270
-
271
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
272
- group_id = pid // num_pid_in_group
273
- first_pid_m = group_id * GROUP_SIZE_M
274
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
275
- pid_m = first_pid_m + (pid % group_size_m)
276
- pid_n = (pid % num_pid_in_group) // group_size_m
277
-
278
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
279
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
280
- m_mask = offs_m < M
281
-
282
- # Decompose M back to (b, r, c) for reordering lookups
283
- s1s1 = s1 * s1
284
- b = offs_m // s1s1
285
- r = (offs_m % s1s1) // s1
286
- c = offs_m % s1
287
-
288
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
289
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
290
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
291
-
292
- for k_offset in range(0, H, BLOCK_SIZE_K):
293
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
294
- k_mask = offs_k < H
295
-
296
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
297
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
298
-
299
- # Accumulate sum and sum of squares in one pass
300
- sum_x += tl.sum(in_chunk, axis=1)
301
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
302
-
303
- # Finalize statistics
304
- mean = sum_x / H
305
- var = (sum_x2 / H) - (mean * mean)
306
- rstd = tl.math.rsqrt(var + LN_EPS)
307
-
308
- # --- Pass 3: Fused Gating and Matmul ---
309
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
310
- for k_offset in range(0, H, BLOCK_SIZE_K):
311
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
312
- k_mask = offs_k < H
313
-
314
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
315
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
316
- a_norm = (a - mean[:, None]) * rstd[:, None]
317
-
318
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
319
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
320
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
321
-
322
- proj_ptrs = ProjW_ptr + \
323
- offs_n[None, :] * stride_proj_d + \
324
- offs_k[:, None] * stride_proj_h
325
-
326
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
327
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
328
- a_gated = a_norm * gate
329
-
330
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
331
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
332
-
333
- # --- Store Final Output ---
334
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
335
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
336
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
337
-
338
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
339
-
340
- def compiledtrimul_fused_interleaved(
341
- x: torch.Tensor,
342
- mask_mh: torch.Tensor,
343
- norm_weight: torch.Tensor,
344
- norm_bias: torch.Tensor,
345
- W_4way: torch.Tensor, # Use the new weight matrices
346
- W_og: torch.Tensor,
347
- to_out_norm_weight: torch.Tensor,
348
- to_out_norm_bias: torch.Tensor,
349
- to_out_weight: torch.Tensor,
350
- h: int,
351
- ):
352
- bs, s1, s2, d = x.shape
353
- M, K, H = bs * s1 * s2, x.shape[-1], h
354
- x_flat = x.view(M, K)
355
-
356
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
357
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
358
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
359
-
360
- # The grid is launched for the larger 4*H problem
361
- N_4way = 4 * H
362
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
363
- fused_ln_dual_matmul_kernel[grid](
364
- # Pointers (9)
365
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
366
- left_final, right_final_t, og_mh,
367
- # Metadata (5) - M, H, K, s1, s2
368
- M, H, K, s1, s2,
369
- # Strides (16)
370
- x_flat.stride(0), x_flat.stride(1),
371
- W_4way.stride(0), W_4way.stride(1),
372
- W_og.stride(0), W_og.stride(1),
373
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
374
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
375
- og_mh.stride(0), og_mh.stride(1),
376
- mask_mh.stride(0), mask_mh.stride(1),
377
- # Constexpr (1)
378
- LN_EPS=1e-5
379
- )
380
- return torch_pt2(
381
- left_final, right_final_t,
382
- bs=bs,
383
- s1=s1,
384
- s2=s2,
385
- d=d,
386
- h=h,
387
- to_out_norm_weight=to_out_norm_weight,
388
- to_out_norm_bias=to_out_norm_bias,
389
- og_mh=og_mh,
390
- to_out_weight=to_out_weight
391
- )
392
-
393
- def pack_w_4way_efficient(weights):
394
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
395
- WL = weights['left_proj.weight']
396
- WLG = weights['left_gate.weight']
397
- WR = weights['right_proj.weight']
398
- WRG = weights['right_gate.weight']
399
- H, K = WL.shape
400
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2)
401
- ws = ws.contiguous().view(4 * H, K)
402
- return ws.t().to(torch.float16)
403
-
404
- def get_w_og(weights):
405
- """ Gets the transposed [K, H] out_gate weight matrix. """
406
- WOG = weights['out_gate.weight']
407
- return WOG.t().to(torch.float16)
408
-
409
-
410
- torch.backends.cuda.matmul.allow_tf32 = True
411
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
412
-
413
- @torch.compile
414
- def compiledtrimul(
415
- x: torch.Tensor,
416
- mask: torch.Tensor,
417
- norm_weight: torch.Tensor,
418
- norm_bias: torch.Tensor,
419
- w_concat: torch.Tensor,
420
- to_out_norm_weight: torch.Tensor,
421
- to_out_norm_bias: torch.Tensor,
422
- to_out_weight: torch.Tensor,
423
- h: int
424
- ) -> torch.Tensor:
425
- """
426
- A barebones, compiled PyTorch function for the TriMul logic.
427
- """
428
- bs, s1, s2, d = x.shape
429
-
430
- # Initial LayerNorm
431
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
432
- # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
433
- all_projections = torch.mm(x_norm, w_concat)
434
-
435
- # Split back into individual projections
436
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
437
-
438
- # Apply mask and gates
439
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
440
- left = left * mask_expanded * torch.sigmoid(lg)
441
- right = right * mask_expanded * torch.sigmoid(rg)
442
- out_gate = torch.sigmoid(og)
443
-
444
- # Reshape for einsum
445
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
446
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
447
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
448
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
449
-
450
- # Apply layer norm and final gating
451
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
452
- gated = normed * out_gate
453
-
454
- # Final projection
455
- final_out_flat = gated @ to_out_weight.t()
456
- final_out = final_out_flat.view(bs, s1, s2, d)
457
-
458
- return final_out
459
-
460
- def small_kernel_pt_path(data):
461
- input_tensor, mask, weights, config = data
462
- w_concat = torch.cat([
463
- weights['left_proj.weight'],
464
- weights['right_proj.weight'],
465
- weights['left_gate.weight'],
466
- weights['right_gate.weight'],
467
- weights['out_gate.weight']
468
- ], dim=0).t().contiguous().to(torch.float16)
469
- # Call the compiled function with prepared weights
470
- output = compiledtrimul(
471
- x=input_tensor.to(torch.float32),
472
- mask=mask.unsqueeze(-1),
473
- norm_weight=weights['norm.weight'].to(torch.float32),
474
- norm_bias=weights['norm.bias'].to(torch.float32),
475
- w_concat=w_concat,
476
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float32),
477
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float32),
478
- to_out_weight=weights['to_out.weight'].to(torch.float16),
479
- h=config["hidden_dim"]
480
- )
481
- return output
482
-
483
- def kernel_h100(data):
484
- input_tensor, mask, weights, config = data
485
- bs, s1, s2, d = input_tensor.shape
486
-
487
- if s1 <= 512:
488
- return small_kernel_pt_path(data)
489
-
490
- H = config["hidden_dim"]
491
-
492
- W_4way = pack_w_4way_efficient(weights)
493
- W_og = get_w_og(weights)
494
-
495
- M = bs * s1 * s2
496
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16) #move into kernel possibly
497
-
498
- return compiledtrimul_fused_interleaved(
499
- x=input_tensor.to(torch.float32),
500
- mask_mh=mask_mh,
501
- norm_weight=weights['norm.weight'].to(torch.float32),
502
- norm_bias=weights['norm.bias'].to(torch.float32),
503
- W_4way=W_4way, # Pass the new 4-way matrix
504
- W_og=W_og, # Pass the new out_gate matrix
505
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
506
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
507
- to_out_weight=weights['to_out.weight'].to(torch.float16),
508
- h=H,
509
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-xpu/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- from .triton_a100 import kernel_a100
2
- from .triton_h100 import kernel_h100
3
- from .triton_b200 import kernel_b200
4
- from .trimul_mi300 import kernel_mi300
5
- from .trimul_global import kernel_global
6
-
7
- __all__ = ["kernel_a100", "kernel_h100", "kernel_b200", "kernel_mi300", "kernel_global"]
 
 
 
 
 
 
 
 
build/torch-xpu/_ops.py DELETED
@@ -1,8 +0,0 @@
1
- import torch
2
- ops = torch.ops._trimul_gpumode_8e6e60d
3
-
4
- def add_op_namespace_prefix(op_name: str):
5
- """
6
- Prefix op by namespace.
7
- """
8
- return f"_trimul_gpumode_8e6e60d::{op_name}"
 
 
 
 
 
 
 
 
 
build/torch-xpu/metadata.json DELETED
@@ -1 +0,0 @@
1
- {"python-depends":[]}
 
 
build/torch-xpu/task.py DELETED
@@ -1,20 +0,0 @@
1
- """
2
- Type definitions for TriMul task.
3
-
4
- Input: Tuple of (input_tensor, mask, weights, config)
5
- - input_tensor: Input tensor of shape [batch_size, seq_len, seq_len, dim]
6
- - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
7
- - weights: Dictionary containing model weights
8
- - config: Dictionary containing model configuration parameters
9
-
10
- Output: Output tensor of shape [batch_size, seq_len, seq_len, dim]
11
- """
12
-
13
- import torch
14
- from typing import Tuple, Dict, Any
15
-
16
- # Input type: (input_tensor, mask, weights, config)
17
- input_t = Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, Any]]
18
-
19
- # Output type: output tensor
20
- output_t = torch.Tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-xpu/trimul_global.py DELETED
@@ -1,971 +0,0 @@
1
- # from utils import make_match_reference, DisableCuDNNTF32
2
- from .task import input_t, output_t
3
-
4
- import torch
5
- from torch import nn, einsum
6
- import math
7
- import os
8
- import requests
9
-
10
- import triton
11
- import triton.language as tl
12
-
13
- # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
14
- # in PyTorch 1.12 and later.
15
- torch.backends.cuda.matmul.allow_tf32 = True
16
-
17
- # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
18
- torch.backends.cudnn.allow_tf32 = True
19
-
20
- # Set allocator for TMA descriptors (required for on-device TMA)
21
- def alloc_fn(size: int, alignment: int, stream=None):
22
- return torch.empty(size, device="cuda", dtype=torch.int8)
23
-
24
- triton.set_allocator(alloc_fn)
25
-
26
- # os.environ['TRITON_PRINT_AUTOTUNING'] = '1'
27
- # os.environ['MLIR_ENABLE_DIAGNOSTICS'] = 'warnings,remarks'
28
-
29
- # Reference code in PyTorch
30
- class TriMul(nn.Module):
31
- # Based on https://github.com/lucidrains/triangle-multiplicative-module/blob/main/triangle_multiplicative_module/triangle_multiplicative_module.py
32
- def __init__(
33
- self,
34
- dim: int,
35
- hidden_dim: int,
36
- ):
37
- super().__init__()
38
-
39
- self.norm = nn.LayerNorm(dim)
40
-
41
- self.left_proj = nn.Linear(dim, hidden_dim, bias=False)
42
- self.right_proj = nn.Linear(dim, hidden_dim, bias=False)
43
-
44
- self.left_gate = nn.Linear(dim, hidden_dim, bias=False)
45
- self.right_gate = nn.Linear(dim, hidden_dim, bias=False)
46
- self.out_gate = nn.Linear(dim, hidden_dim, bias=False)
47
-
48
- self.to_out_norm = nn.LayerNorm(hidden_dim)
49
- self.to_out = nn.Linear(hidden_dim, dim, bias=False)
50
-
51
- def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
52
- """
53
- x: [bs, seq_len, seq_len, dim]
54
- mask: [bs, seq_len, seq_len]
55
-
56
- Returns:
57
- output: [bs, seq_len, seq_len, dim]
58
- """
59
- batch_size, seq_len, _, dim = x.shape
60
-
61
- x = self.norm(x)
62
-
63
- left = self.left_proj(x)
64
- right = self.right_proj(x)
65
-
66
- mask = mask.unsqueeze(-1)
67
- left = left * mask
68
- right = right * mask
69
-
70
- left_gate = self.left_gate(x).sigmoid()
71
- right_gate = self.right_gate(x).sigmoid()
72
- out_gate = self.out_gate(x).sigmoid()
73
-
74
- left = left * left_gate
75
- right = right * right_gate
76
-
77
- out = einsum('... i k d, ... j k d -> ... i j d', left, right)
78
- # This einsum is the same as the following:
79
- # out = torch.zeros(batch_size, seq_len, seq_len, dim, device=x.device)
80
-
81
- # # Compute using nested loops
82
- # for b in range(batch_size):
83
- # for i in range(seq_len):
84
- # for j in range(seq_len):
85
- # # Compute each output element
86
- # for k in range(seq_len):
87
- # out[b, i, j] += left[b, i, k, :] * right[b, j, k, :]
88
-
89
- out = self.to_out_norm(out)
90
- out = out * out_gate
91
- return self.to_out(out)
92
-
93
- @triton.jit
94
- def triton_sigmoid(x):
95
- """
96
- Compute sigmoid function: 1 / (1 + exp(-x))
97
- """
98
- return 1.0 / (1.0 + tl.exp(-x))
99
-
100
- def two_mm_kernel_configs_wrapper():
101
- if torch.cuda.get_device_capability() == (12, 0):
102
- def two_mm_kernel_configs():
103
- configs = []
104
- for BLOCK_M in [16, 32]:
105
- for BLOCK_N in [16, 32, 64]:
106
- for BLOCK_K in [16, 32, 64]:
107
- for num_stages in [2, 3]:
108
- configs.append(triton.Config({
109
- 'BLOCK_M': BLOCK_M,
110
- 'BLOCK_N': BLOCK_N,
111
- 'BLOCK_K': BLOCK_K,
112
- 'GROUP_SIZE_M': 8
113
- }, num_stages=num_stages, num_warps=8))
114
- return configs
115
-
116
- elif torch.cuda.get_device_capability()[0] == 9:
117
- def get_optimal_two_mm_config_h100(B, seq_len, dim):
118
- configs = {
119
- (1, 128, 128): (128, 64, 128, 2, 8),
120
- (1, 128, 256): (128, 64, 128, 2, 8),
121
- (1, 128, 384): (128, 64, 64, 3, 8),
122
- (1, 128, 512): (128, 64, 64, 3, 8),
123
- (1, 128, 768): (128, 64, 64, 3, 8),
124
- (1, 128, 1024): (128, 64, 64, 3, 8),
125
- (1, 256, 128): (128, 64, 128, 2, 8),
126
- (1, 256, 256): (128, 64, 128, 2, 8),
127
- (1, 256, 384): (128, 64, 64, 3, 8),
128
- (1, 256, 512): (128, 64, 64, 3, 8),
129
- (1, 256, 768): (128, 64, 64, 3, 8),
130
- (1, 256, 1024): (128, 64, 64, 3, 8),
131
- (1, 512, 128): (128, 64, 128, 2, 8),
132
- (1, 512, 256): (128, 64, 128, 2, 8),
133
- (1, 512, 384): (128, 64, 128, 2, 8),
134
- (1, 512, 512): (128, 64, 128, 2, 8),
135
- (1, 512, 768): (128, 64, 64, 3, 8),
136
- (1, 512, 1024): (128, 64, 64, 3, 8),
137
- (1, 1024, 128): (128, 64, 128, 2, 8),
138
- (1, 1024, 256): (128, 64, 64, 2, 8),
139
- (1, 1024, 384): (128, 64, 128, 2, 8),
140
- (1, 1024, 512): (128, 64, 128, 2, 8),
141
- (1, 1024, 768): (128, 64, 128, 2, 8),
142
- (1, 1024, 1024): (128, 64, 128, 2, 8),
143
- (2, 128, 128): (128, 64, 128, 2, 8),
144
- (2, 128, 256): (128, 64, 128, 2, 8),
145
- (2, 128, 384): (128, 64, 64, 3, 8),
146
- (2, 128, 512): (128, 64, 64, 3, 8),
147
- (2, 128, 768): (128, 64, 64, 3, 8),
148
- (2, 128, 1024): (128, 64, 64, 3, 8),
149
- (2, 256, 128): (128, 64, 128, 2, 8),
150
- (2, 256, 256): (128, 64, 128, 2, 8),
151
- (2, 256, 384): (128, 64, 128, 2, 8),
152
- (2, 256, 512): (128, 64, 128, 2, 8),
153
- (2, 256, 768): (128, 64, 64, 3, 8),
154
- (2, 256, 1024): (128, 64, 64, 3, 8),
155
- (2, 512, 128): (128, 64, 128, 2, 8),
156
- (2, 512, 256): (128, 64, 128, 2, 8),
157
- (2, 512, 384): (128, 64, 128, 2, 8),
158
- (2, 512, 512): (128, 64, 128, 2, 8),
159
- (2, 512, 768): (128, 64, 128, 2, 8),
160
- (2, 512, 1024): (128, 64, 128, 2, 8),
161
- (2, 1024, 128): (128, 64, 128, 2, 8),
162
- (2, 1024, 256): (128, 64, 128, 2, 8),
163
- (2, 1024, 384): (128, 64, 128, 2, 8),
164
- (2, 1024, 512): (128, 64, 128, 2, 8),
165
- (2, 1024, 768): (128, 64, 128, 2, 8),
166
- (2, 1024, 1024): (128, 64, 128, 2, 8),
167
- }
168
- return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8)) # default fallback
169
-
170
- def two_mm_kernel_configs():
171
- # This function is kept for compatibility but will be overridden for H100
172
- return [
173
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
174
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
175
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
176
- ]
177
-
178
- elif torch.cuda.get_device_capability()[0] == 10 and False:
179
- def get_optimal_two_mm_config(B, seq_len, dim):
180
- configs = {
181
- (1, 128, 128): (64, 128, 64, 2, 8),
182
- (1, 128, 256): (128, 64, 128, 2, 8),
183
- (1, 128, 384): (128, 64, 128, 2, 8),
184
- (1, 128, 512): (128, 64, 128, 2, 8),
185
- (1, 128, 768): (128, 64, 64, 3, 8),
186
- (1, 128, 1024): (128, 64, 64, 3, 8),
187
- (1, 256, 128): (128, 64, 128, 2, 8),
188
- (1, 256, 256): (128, 64, 128, 2, 8),
189
- (1, 256, 384): (128, 64, 128, 2, 8),
190
- (1, 256, 512): (128, 64, 64, 3, 8),
191
- (1, 256, 768): (128, 64, 64, 3, 8),
192
- (1, 256, 1024): (128, 64, 64, 3, 8),
193
- (1, 512, 128): (128, 64, 128, 2, 8),
194
- (1, 512, 256): (128, 64, 128, 2, 8),
195
- (1, 512, 384): (128, 64, 128, 2, 8),
196
- (1, 512, 512): (128, 64, 128, 2, 8),
197
- (1, 512, 768): (128, 64, 64, 3, 8),
198
- (1, 512, 1024): (128, 64, 64, 3, 8),
199
- (1, 1024, 128): (128, 64, 128, 2, 8),
200
- (1, 1024, 256): (128, 64, 128, 2, 8),
201
- (1, 1024, 384): (128, 64, 128, 2, 8),
202
- (1, 1024, 512): (128, 64, 128, 2, 8),
203
- (1, 1024, 768): (128, 64, 64, 3, 8),
204
- (1, 1024, 1024): (128, 64, 64, 3, 8),
205
- (2, 128, 128): (128, 64, 128, 2, 8),
206
- (2, 128, 256): (128, 64, 128, 2, 8),
207
- (2, 128, 384): (128, 64, 128, 2, 8),
208
- (2, 128, 512): (128, 64, 64, 3, 8),
209
- (2, 128, 768): (128, 64, 64, 3, 8),
210
- (2, 128, 1024): (128, 64, 64, 3, 8),
211
- (2, 256, 128): (128, 64, 128, 2, 8),
212
- (2, 256, 256): (128, 64, 128, 2, 8),
213
- (2, 256, 384): (128, 64, 128, 2, 8),
214
- (2, 256, 512): (128, 64, 64, 3, 8),
215
- (2, 256, 768): (128, 64, 64, 3, 8),
216
- (2, 256, 1024): (128, 64, 64, 3, 8),
217
- (2, 512, 128): (128, 64, 128, 2, 8),
218
- (2, 512, 256): (128, 64, 128, 2, 8),
219
- (2, 512, 384): (128, 64, 128, 2, 8),
220
- (2, 512, 512): (128, 64, 128, 2, 8),
221
- (2, 512, 768): (128, 64, 64, 3, 8),
222
- (2, 512, 1024): (128, 64, 64, 3, 8),
223
- (2, 1024, 128): (128, 64, 128, 2, 8),
224
- (2, 1024, 256): (128, 64, 128, 2, 8),
225
- (2, 1024, 384): (128, 64, 128, 2, 8),
226
- (2, 1024, 512): (128, 64, 128, 2, 8),
227
- (2, 1024, 768): (128, 64, 64, 3, 8),
228
- (2, 1024, 1024): (128, 64, 64, 3, 8),
229
- }
230
- return configs.get((B, seq_len, dim), (64, 64, 32, 2, 8)) # default fallback
231
-
232
- def two_mm_kernel_configs():
233
- # This function is kept for compatibility but will be overridden
234
- return [
235
- triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
236
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),
237
- triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
238
- ]
239
- elif torch.cuda.get_device_capability()[0] == 8:
240
- # A100
241
- def two_mm_kernel_configs():
242
- configs = []
243
- for BLOCK_M in [64]:
244
- for BLOCK_N in [64, 128]:
245
- for BLOCK_K in [16]:
246
- for num_stages in [3, 4]:
247
- for num_warps in [4, 8]:
248
- configs.append(triton.Config({
249
- 'BLOCK_M': BLOCK_M,
250
- 'BLOCK_N': BLOCK_N,
251
- 'BLOCK_K': BLOCK_K,
252
- 'GROUP_SIZE_M': 8
253
- }, num_stages=num_stages, num_warps=num_warps))
254
- return configs
255
- else:
256
- def two_mm_kernel_configs():
257
- configs = []
258
- for BLOCK_M in [64, 128]:
259
- for BLOCK_N in [64, 128]:
260
- for BLOCK_K in [64, 128]:
261
- for num_stages in [2, 3]:
262
- configs.append(triton.Config({
263
- 'BLOCK_M': BLOCK_M,
264
- 'BLOCK_N': BLOCK_N,
265
- 'BLOCK_K': BLOCK_K,
266
- 'GROUP_SIZE_M': 8
267
- }, num_stages=num_stages, num_warps=8))
268
- return configs
269
-
270
- return two_mm_kernel_configs
271
-
272
- def two_mm_kernel_wrapper():
273
- if torch.cuda.get_device_capability()[0] == 8:
274
- @triton.jit
275
- def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
276
- # Persistent kernel using standard tl.load operations
277
- start_pid = tl.program_id(axis=0)
278
- num_pid_m = tl.cdiv(M, BLOCK_M)
279
- num_pid_n = tl.cdiv(N, BLOCK_N)
280
- k_tiles = tl.cdiv(K, BLOCK_K)
281
- num_tiles = num_pid_m * num_pid_n
282
-
283
- # tile_id_c is used in the epilogue to break the dependency between
284
- # the prologue and the epilogue
285
- tile_id_c = start_pid - NUM_SMS
286
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
287
-
288
- # Persistent loop over tiles
289
- for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
290
- # Calculate PID for this tile using improved swizzling
291
- group_id = tile_id // num_pid_in_group
292
- first_pid_m = group_id * GROUP_SIZE_M
293
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
294
- pid_m = first_pid_m + (tile_id % group_size_m)
295
- pid_n = (tile_id % num_pid_in_group) // group_size_m
296
-
297
- # Calculate block offsets
298
- offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
299
- offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
300
- offs_k = tl.arange(0, BLOCK_K)
301
-
302
- # Initialize accumulators for all outputs
303
- accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
304
- accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
305
- accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
306
- accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
307
- accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
308
-
309
- # Main computation loop over K dimension
310
- for ki in range(k_tiles):
311
- k_start = ki * BLOCK_K
312
- k_offsets = k_start + offs_k
313
-
314
- # Create pointers for A matrix (2D flattened view)
315
- a_ptrs = a_ptr + offs_am[:, None] * stride_a2 + k_offsets[None, :] * stride_a3
316
- a_mask = (offs_am[:, None] < M) & (k_offsets[None, :] < K)
317
-
318
- # Create pointers for B matrices [N, K] layout
319
- b1_ptrs = b1_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
320
- b2_ptrs = b2_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
321
- b3_ptrs = b3_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
322
- b4_ptrs = b4_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
323
- b5_ptrs = b5_ptr + offs_bn[:, None] * stride_bn + k_offsets[None, :] * stride_bk
324
- b_mask = (offs_bn[:, None] < N) & (k_offsets[None, :] < K)
325
-
326
- # Load blocks from A and all weight matrices using standard tl.load
327
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
328
- b1 = tl.load(b1_ptrs, mask=b_mask, other=0.0)
329
- b2 = tl.load(b2_ptrs, mask=b_mask, other=0.0)
330
- b3 = tl.load(b3_ptrs, mask=b_mask, other=0.0)
331
- b4 = tl.load(b4_ptrs, mask=b_mask, other=0.0)
332
- b5 = tl.load(b5_ptrs, mask=b_mask, other=0.0)
333
-
334
- # Perform matrix multiplications using TF32
335
- accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True) # A @ B1.T
336
- accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True) # A @ B2.T
337
- accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True) # A @ B3.T
338
- accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True) # A @ B4.T
339
- accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True) # A @ B5.T
340
-
341
- # Store results using separate tile_id_c for epilogue
342
- tile_id_c += NUM_SMS
343
- group_id = tile_id_c // num_pid_in_group
344
- first_pid_m = group_id * GROUP_SIZE_M
345
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
346
- pid_m = first_pid_m + (tile_id_c % group_size_m)
347
- pid_n = (tile_id_c % num_pid_in_group) // group_size_m
348
-
349
- # Calculate output offsets and pointers
350
- offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
351
- offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
352
-
353
- # Create masks for bounds checking
354
- d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
355
-
356
- # Calculate pointer addresses using 4D strides
357
- stride_cm = stride_c2 # Stride to next element in flattened M dimension
358
- stride_cn = stride_c3 # N is the innermost dimension
359
-
360
- # For D tensor: use separate D strides
361
- stride_dm = stride_d2 # Stride to next element in flattened M dimension
362
- stride_dn = stride_d3 # N is the innermost dimension
363
-
364
- off_c_batch = offs_cm // (seq_len * seq_len)
365
- off_c_sl1 = (offs_cm // seq_len) % seq_len
366
- off_c_sl2 = offs_cm % seq_len
367
- off_c_dim = offs_cn
368
-
369
- c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
370
- c_mask = d_mask
371
-
372
- c1_ptrs = c1_ptr + c_offsets
373
- c2_ptrs = c2_ptr + c_offsets
374
- d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]
375
-
376
- mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))
377
-
378
- # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
379
- mask_2d = mask[:, None] # Convert to [BLOCK_M, 1] then broadcast
380
- # Apply masking only to left_proj and right_proj results (C1, C2)
381
- accumulator1 = tl.where(mask_2d, accumulator1, 0)
382
- accumulator2 = tl.where(mask_2d, accumulator2, 0)
383
-
384
- # Apply sigmoid to gate values
385
- left_gate_sigmoid = triton_sigmoid(accumulator3)
386
- right_gate_sigmoid = triton_sigmoid(accumulator4)
387
- accumulator_d = triton_sigmoid(accumulator_d)
388
-
389
- # Apply elementwise multiplication with gated values
390
- # C1 = left * left_gate, C2 = right * right_gate
391
- accumulator1 = accumulator1 * left_gate_sigmoid # left * left_gate
392
- accumulator2 = accumulator2 * right_gate_sigmoid # right * right_gate
393
-
394
- # Convert to appropriate output dtype and store with normal tl.store
395
- c1 = accumulator1.to(c1_ptr.dtype.element_ty)
396
- c2 = accumulator2.to(c2_ptr.dtype.element_ty)
397
- d = accumulator_d.to(d_ptr.dtype.element_ty)
398
-
399
- tl.store(c1_ptrs, c1, mask=c_mask)
400
- tl.store(c2_ptrs, c2, mask=c_mask)
401
- tl.store(d_ptrs, d, mask=d_mask)
402
- else:
403
- @triton.jit
404
- def two_mm_kernel(a_ptr, b1_ptr, b2_ptr, b3_ptr, b4_ptr, b5_ptr, c1_ptr, c2_ptr, d_ptr, mask_ptr, M, N, K, stride_a0, stride_a1, stride_a2, stride_a3, stride_bk, stride_bn, stride_c0, stride_c1, stride_c2, stride_c3, seq_len, stride_d0, stride_d1, stride_d2, stride_d3, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, NUM_SMS: tl.constexpr):
405
- # Persistent kernel using on-device TMA descriptors
406
- start_pid = tl.program_id(axis=0)
407
- num_pid_m = tl.cdiv(M, BLOCK_M)
408
- num_pid_n = tl.cdiv(N, BLOCK_N)
409
- k_tiles = tl.cdiv(K, BLOCK_K)
410
- num_tiles = num_pid_m * num_pid_n
411
-
412
- # Create on-device TMA descriptors
413
- a_desc = tl._experimental_make_tensor_descriptor(
414
- a_ptr,
415
- shape=[M, K],
416
- strides=[stride_a2, stride_a3],
417
- block_shape=[BLOCK_M, BLOCK_K],
418
- )
419
- b1_desc = tl._experimental_make_tensor_descriptor(
420
- b1_ptr,
421
- shape=[N, K],
422
- strides=[stride_bn, stride_bk],
423
- block_shape=[BLOCK_N, BLOCK_K],
424
- )
425
- b2_desc = tl._experimental_make_tensor_descriptor(
426
- b2_ptr,
427
- shape=[N, K],
428
- strides=[stride_bn, stride_bk],
429
- block_shape=[BLOCK_N, BLOCK_K],
430
- )
431
- b3_desc = tl._experimental_make_tensor_descriptor(
432
- b3_ptr,
433
- shape=[N, K],
434
- strides=[stride_bn, stride_bk],
435
- block_shape=[BLOCK_N, BLOCK_K],
436
- )
437
- b4_desc = tl._experimental_make_tensor_descriptor(
438
- b4_ptr,
439
- shape=[N, K],
440
- strides=[stride_bn, stride_bk],
441
- block_shape=[BLOCK_N, BLOCK_K],
442
- )
443
- b5_desc = tl._experimental_make_tensor_descriptor(
444
- b5_ptr,
445
- shape=[N, K],
446
- strides=[stride_bn, stride_bk],
447
- block_shape=[BLOCK_N, BLOCK_K],
448
- )
449
-
450
- # tile_id_c is used in the epilogue to break the dependency between
451
- # the prologue and the epilogue
452
- tile_id_c = start_pid - NUM_SMS
453
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
454
-
455
- # Persistent loop over tiles
456
- for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=False):
457
- # Calculate PID for this tile using improved swizzling
458
- group_id = tile_id // num_pid_in_group
459
- first_pid_m = group_id * GROUP_SIZE_M
460
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
461
- pid_m = first_pid_m + (tile_id % group_size_m)
462
- pid_n = (tile_id % num_pid_in_group) // group_size_m
463
-
464
- # Calculate block offsets
465
- offs_am = pid_m * BLOCK_M
466
- offs_bn = pid_n * BLOCK_N
467
-
468
- # Initialize accumulators for all outputs
469
- accumulator1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
470
- accumulator2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
471
- accumulator3 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
472
- accumulator4 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
473
- accumulator_d = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
474
-
475
- # Main computation loop over K dimension
476
- for ki in range(k_tiles):
477
- offs_k = ki * BLOCK_K
478
- # Load blocks from A and all weight matrices using on-device TMA
479
- a = a_desc.load([offs_am, offs_k])
480
- b1 = b1_desc.load([offs_bn, offs_k])
481
- b2 = b2_desc.load([offs_bn, offs_k])
482
- b3 = b3_desc.load([offs_bn, offs_k])
483
- b4 = b4_desc.load([offs_bn, offs_k])
484
- b5 = b5_desc.load([offs_bn, offs_k])
485
-
486
- # Perform matrix multiplications using TF32
487
- accumulator1 = tl.dot(a, b1.T, accumulator1, allow_tf32=True) # A @ B1.T
488
- accumulator2 = tl.dot(a, b2.T, accumulator2, allow_tf32=True) # A @ B2.T
489
- accumulator3 = tl.dot(a, b3.T, accumulator3, allow_tf32=True) # A @ B3.T
490
- accumulator4 = tl.dot(a, b4.T, accumulator4, allow_tf32=True) # A @ B4.T
491
- accumulator_d = tl.dot(a, b5.T, accumulator_d, allow_tf32=True) # A @ B5.T
492
-
493
- # Store results using separate tile_id_c for epilogue
494
- tile_id_c += NUM_SMS
495
- group_id = tile_id_c // num_pid_in_group
496
- first_pid_m = group_id * GROUP_SIZE_M
497
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
498
- pid_m = first_pid_m + (tile_id_c % group_size_m)
499
- pid_n = (tile_id_c % num_pid_in_group) // group_size_m
500
-
501
- # Calculate output offsets and pointers
502
- offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
503
- offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
504
-
505
- # Create masks for bounds checking
506
- d_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
507
-
508
- # Calculate pointer addresses using 4D strides
509
- # For C tensors: compute effective 2D strides from 4D strides
510
- # Output tensor is [B, I, J, N], flattened to [M, N] where M = B*I*J
511
- stride_cm = stride_c2 # Stride to next element in flattened M dimension
512
- stride_cn = stride_c3 # N is the innermost dimension
513
-
514
- # For D tensor: use separate D strides
515
- stride_dm = stride_d2 # Stride to next element in flattened M dimension
516
- stride_dn = stride_d3 # N is the innermost dimension
517
-
518
- off_c_batch = offs_cm // (seq_len * seq_len)
519
- off_c_sl1 = (offs_cm // seq_len) % seq_len
520
- off_c_sl2 = offs_cm % seq_len
521
- off_c_dim = offs_cn
522
-
523
- # TODO update the mask_c so we don't IMA
524
- c_offsets = (off_c_batch * stride_c0 + off_c_sl1 * stride_c1 + off_c_sl2 * stride_c2)[:, None] + off_c_dim[None, :] * stride_c3
525
- # c_offsets = offs_cm[:, None] * stride_c2 + offs_cn[None, :] * stride_c3
526
- c_mask = d_mask
527
-
528
- c1_ptrs = c1_ptr + c_offsets
529
- c2_ptrs = c2_ptr + c_offsets
530
- # c1_ptrs = c1_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
531
- # c2_ptrs = c2_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
532
- d_ptrs = d_ptr + stride_dm * offs_cm[:, None] + stride_dn * offs_cn[None, :]
533
-
534
- mask = tl.load(mask_ptr + offs_cm, mask=(offs_cm < M))
535
-
536
- # Broadcast mask to match accumulator dimensions [BLOCK_M, BLOCK_N]
537
- mask_2d = mask[:, None] # Convert to [BLOCK_M, 1] then broadcast
538
- # Apply masking only to left_proj and right_proj results (C1, C2)
539
- accumulator1 = tl.where(mask_2d, accumulator1, 0)
540
- accumulator2 = tl.where(mask_2d, accumulator2, 0)
541
-
542
- # Apply sigmoid to gate values
543
- left_gate_sigmoid = triton_sigmoid(accumulator3)
544
- right_gate_sigmoid = triton_sigmoid(accumulator4)
545
- accumulator_d = triton_sigmoid(accumulator_d)
546
-
547
- # Apply elementwise multiplication with gated values
548
- # C1 = left * left_gate, C2 = right * right_gate
549
- accumulator1 = accumulator1 * left_gate_sigmoid # left * left_gate
550
- accumulator2 = accumulator2 * right_gate_sigmoid # right * right_gate
551
-
552
- # Convert to appropriate output dtype and store with normal tl.store
553
- c1 = accumulator1.to(c1_ptr.dtype.element_ty)
554
- c2 = accumulator2.to(c2_ptr.dtype.element_ty)
555
- d = accumulator_d.to(d_ptr.dtype.element_ty)
556
-
557
- tl.store(c1_ptrs, c1, mask=c_mask)
558
- tl.store(c2_ptrs, c2, mask=c_mask)
559
- tl.store(d_ptrs, d, mask=d_mask)
560
-
561
-
562
- if torch.cuda.get_device_capability()[0] not in [9, 10.2]:
563
- two_mm_kernel = triton.autotune(
564
- (two_mm_kernel_configs_wrapper())(), key=["M", "N", "K"]
565
- )(two_mm_kernel)
566
-
567
- return two_mm_kernel
568
-
569
-
570
- def two_mm(A, left_proj, right_proj, left_gate, right_gate, out_gate, mask):
571
- """
572
- Persistent matrix multiplication for all weight matrices using on-device TMA descriptors.
573
-
574
- Args:
575
- A: [..., K] tensor (arbitrary leading dimensions)
576
- left_proj: [N, K] matrix (will be transposed)
577
- right_proj: [N, K] matrix (will be transposed)
578
- left_gate: [N, K] left gate weight matrix
579
- right_gate: [N, K] right gate weight matrix
580
- out_gate: [N, K] output gate weight matrix
581
- mask: mask tensor
582
-
583
- Returns:
584
- (C1, C2, D): Tuple of result tensors [..., N] with same leading dims as A
585
- C1 = (A @ left_proj.T) * sigmoid(A @ left_gate.T) (masked)
586
- C2 = (A @ right_proj.T) * sigmoid(A @ right_gate.T) (masked)
587
- D = sigmoid(A @ out_gate.T) (unmasked)
588
- """
589
- # Check constraints
590
- assert A.shape[-1] == left_proj.shape[1] == right_proj.shape[1], "Incompatible K dimensions"
591
- assert A.dtype == left_proj.dtype == right_proj.dtype, "Incompatible dtypes"
592
-
593
- # Assert that all weight matrices have the same strides (same [N, K] shape)
594
- assert left_proj.stride() == right_proj.stride() == left_gate.stride() == right_gate.stride() == out_gate.stride(), \
595
- "All weight matrices must have identical strides"
596
-
597
- # Get dimensions
598
- original_shape = A.shape[:-1] # All dimensions except the last
599
- K = A.shape[-1]
600
- N = left_proj.shape[0]
601
- B, seq_len, _, _ = A.shape
602
- dtype = A.dtype
603
-
604
- # Flatten A to 2D for kernel processing
605
- A_2d = A.view(-1, K) # [M, K] where M is product of all leading dims
606
- M = A_2d.shape[0]
607
-
608
- # Get number of streaming multiprocessors
609
- NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
610
-
611
- # Launch persistent kernel with limited number of blocks
612
- grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"])),)
613
-
614
- # Get original 4D strides for A and output tensors
615
- A_strides = A.stride() # (stride_0, stride_1, stride_2, stride_3)
616
-
617
- # Create output tensors with proper 4D shape to get correct strides
618
- output_shape = original_shape + (N,)
619
- # C1 = torch.empty(output_shape, device=A.device, dtype=dtype)
620
- # C2 = torch.empty(output_shape, device=A.device, dtype=dtype)
621
- C1 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
622
- C2 = torch.empty(B, N, seq_len, seq_len, device=A.device, dtype=torch.float16).permute(0, 2, 3, 1)
623
- D = torch.empty(output_shape, device=A.device, dtype=torch.float16)
624
-
625
- C_strides = C1.stride() # (stride_0, stride_1, stride_2, stride_3)
626
- D_strides = D.stride() # (stride_0, stride_1, stride_2, stride_3)
627
-
628
- # Use optimal configuration for B200/H100 or fallback to autotuning for other GPUs
629
- if torch.cuda.get_device_capability()[0] == 10:
630
- # Get optimal configuration for B200
631
- BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
632
- grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))
633
-
634
- two_mm_kernel_wrapper()[(grid_size,)](
635
- A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
636
- C1, C2, D, mask,
637
- M, N, K,
638
- *A_strides, # 4D strides for A
639
- left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
640
- *C_strides, # 4D strides for C
641
- seq_len,
642
- *D_strides, # 4D strides for D
643
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
644
- num_stages=num_stages, num_warps=num_warps
645
- )
646
- elif torch.cuda.get_device_capability()[0] == 9:
647
- # Get optimal configuration for H100
648
- BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps = (two_mm_kernel_configs_wrapper())(B, seq_len, K)
649
- grid_size = min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N))
650
-
651
- two_mm_kernel_wrapper()[(grid_size,)](
652
- A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
653
- C1, C2, D, mask,
654
- M, N, K,
655
- *A_strides, # 4D strides for A
656
- left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
657
- *C_strides, # 4D strides for C
658
- seq_len,
659
- *D_strides, # 4D strides for D
660
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, NUM_SMS=NUM_SMS,
661
- num_stages=num_stages, num_warps=num_warps
662
- )
663
- else:
664
- # Use autotuning for other GPUs
665
- two_mm_kernel_wrapper()[grid](
666
- A_2d, left_proj, right_proj, left_gate, right_gate, out_gate,
667
- C1, C2, D, mask,
668
- M, N, K,
669
- *A_strides, # 4D strides for A
670
- left_proj.stride(1), left_proj.stride(0), # B matrices [N, K] shape strides
671
- *C_strides, # 4D strides for C
672
- seq_len,
673
- *D_strides, # 4D strides for D
674
- NUM_SMS=NUM_SMS
675
- )
676
-
677
- return C1, C2, D
678
-
679
-
680
- def second_layernorm_mul(inp, hidden_dim, weight, bias, mul_operand):
681
- ln = torch.nn.functional.layer_norm(inp, (hidden_dim,), eps=1e-5, weight=weight.to(inp.dtype), bias=bias.to(inp.dtype))
682
- out = ln * mul_operand
683
- return out
684
-
685
- '''
686
- @triton.autotune(
687
- [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=4, num_stages=3)],
688
- key=["R", "C"]
689
- )
690
- '''
691
- @triton.jit
692
- def layernorm_kernel_first(
693
- X,
694
- Y,
695
- Weight,
696
- Bias,
697
- R,
698
- C, # aka "dim"
699
- eps,
700
- ROW_BLOCK_SIZE: tl.constexpr,
701
- BLOCK_SIZE: tl.constexpr,
702
- ):
703
- row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
704
- cols = tl.arange(0, BLOCK_SIZE)
705
-
706
- mask_row = row < R
707
- mask_col = cols < C
708
-
709
- # Simple indexing for contiguous data
710
- x = tl.load(
711
- X + row[:, None] * C + cols[None, :],
712
- mask=mask_row[:, None] & mask_col[None, :],
713
- other=0.0
714
- ).to(tl.float32)
715
-
716
- weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
717
- bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)
718
-
719
- mean = tl.sum(x, axis=1) / C
720
- diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
721
- var = tl.sum(diff * diff, axis=1) / C
722
- rstd = 1 / tl.sqrt(var + eps)
723
-
724
- y_hat = (x - mean[:, None]) * rstd[:, None]
725
- y = y_hat * weight[None, :] + bias[None, :]
726
-
727
- tl.store(
728
- Y + row[:, None] * C + cols[None, :],
729
- y,
730
- mask=mask_row[:, None] & mask_col[None, :]
731
- )
732
-
733
-
734
- def get_optimal_config_ln(dim):
735
- config = None
736
- if torch.cuda.get_device_capability()[0] == 9:
737
- if (dim <= 256):
738
- config = (16, 1)
739
- elif dim <= 512:
740
- config = (16, 2)
741
- elif dim <= 1024:
742
- config = (16, 4)
743
-
744
- if not config:
745
- config = (16, 4)
746
- return config
747
-
748
-
749
- def triton_layernorm_first(x, weight, bias, eps=1e-5, num_warps=None, ROW_BLOCK_SIZE=None):
750
- B, seq_len, seq_len2, dim = x.shape
751
- assert(seq_len == seq_len2)
752
-
753
- R = B * seq_len * seq_len
754
- C = dim
755
-
756
- out = torch.empty_like(x, dtype=torch.float16)
757
-
758
- if not num_warps or not ROW_BLOCK_SIZE:
759
- ROW_BLOCK_SIZE, num_warps = get_optimal_config_ln(dim)
760
-
761
- BLOCK_SIZE = triton.next_power_of_2(C)
762
- assert(BLOCK_SIZE <= 1024)
763
-
764
- def grid(meta):
765
- return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
766
-
767
- layernorm_kernel_first[grid](
768
- x, out, weight, bias,
769
- R, C, eps,
770
- ROW_BLOCK_SIZE=ROW_BLOCK_SIZE,
771
- BLOCK_SIZE=BLOCK_SIZE,
772
- num_warps=num_warps,
773
- num_stages=3
774
- )
775
-
776
- return out
777
-
778
- '''
779
- def triton_layernorm_first(x, weight, bias, eps=1e-5):
780
- B, seq_len, seq_len2, dim = x.shape
781
- assert(seq_len == seq_len2)
782
-
783
- R = B * seq_len * seq_len
784
- C = dim
785
-
786
- out = torch.empty_like(x)
787
-
788
- BLOCK_SIZE = triton.next_power_of_2(C)
789
- assert(BLOCK_SIZE <= 1024)
790
-
791
- def grid(meta):
792
- return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
793
-
794
- layernorm_kernel_first[grid](
795
- x, out, weight, bias,
796
- R, C, eps,
797
- BLOCK_SIZE=BLOCK_SIZE
798
- )
799
-
800
- return out
801
- '''
802
-
803
-
804
- @triton.autotune(
805
- [triton.Config({"ROW_BLOCK_SIZE": 16}, num_warps=1, num_stages=3)],
806
- key=[]
807
- )
808
- @triton.jit
809
- def layernorm_kernel_eltwise(
810
- X,
811
- Y,
812
- Weight,
813
- Bias,
814
- OutGate,
815
- seq_len,
816
- stride_batch,
817
- stride_dim,
818
- R,
819
- C, # aka "dim"
820
- eps,
821
- ROW_BLOCK_SIZE: tl.constexpr,
822
- BLOCK_SIZE: tl.constexpr,
823
- ):
824
- row = tl.program_id(0) * ROW_BLOCK_SIZE + tl.arange(0, ROW_BLOCK_SIZE)
825
- cols = tl.arange(0, BLOCK_SIZE)
826
-
827
- # Calculate base pointer for this batch of rows
828
- tl.device_assert(seq_len*seq_len % ROW_BLOCK_SIZE == 0)
829
- # batch_offset = (row // (stride_seq1 // stride_dim)) * stride_batch
830
- batch = tl.program_id(0) * ROW_BLOCK_SIZE // (seq_len * seq_len)
831
- seqs_off = row % (seq_len * seq_len) # TODO is this going to prevent vectorization
832
-
833
- off_r = batch * stride_batch + seqs_off
834
- off_c = cols * stride_dim
835
-
836
- mask_row = row < R
837
- mask_col = cols < C
838
-
839
- out_gate = tl.load(
840
- OutGate + row[:, None] * C + cols[None, :],
841
- mask = mask_row[:, None] & mask_col[None, :],
842
- )
843
-
844
- x = tl.load(
845
- X + off_r[:, None] + off_c[None, :],
846
- mask=mask_row[:, None] & mask_col[None, :],
847
- other=0.0
848
- ).to(tl.float32)
849
-
850
- weight = tl.load(Weight + cols, mask=mask_col, other=0.0).to(tl.float32)
851
- bias = tl.load(Bias + cols, mask=mask_col, other=0.0).to(tl.float32)
852
-
853
- mean = tl.sum(x, axis=1) / C
854
- diff = tl.where(mask_row[:, None] & mask_col[None, :], x - mean[:, None], 0)
855
- var = tl.sum(diff * diff, axis=1) / C
856
- rstd = 1 / tl.sqrt(var + eps)
857
-
858
- y_hat = (x - mean[:, None]) * rstd[:, None]
859
- y = y_hat * weight[None, :] + bias[None, :]
860
-
861
- tl.store(
862
- Y + row[:, None] * C + cols[None, :],
863
- y * out_gate,
864
- mask=mask_row[:, None] & mask_col[None, :]
865
- )
866
-
867
-
868
- def triton_layernorm_eltwise(x, weight, bias, out_gate, eps=1e-5):
869
- B, seq_len, seq_len2, dim = x.shape
870
- assert(seq_len == seq_len2)
871
- R = B * seq_len * seq_len
872
- assert(x.stride(3) == seq_len*seq_len)
873
- assert(out_gate.is_contiguous())
874
- C = dim
875
-
876
- out = torch.empty_like(out_gate, dtype=torch.float32)
877
-
878
- BLOCK_SIZE = triton.next_power_of_2(C)
879
- assert(BLOCK_SIZE == 128)
880
-
881
- def grid(meta):
882
- return (triton.cdiv(R, meta["ROW_BLOCK_SIZE"]),)
883
-
884
- layernorm_kernel_eltwise[grid](
885
- x, out, weight, bias, out_gate,
886
- seq_len,
887
- x.stride(0), x.stride(3),
888
- R, C, eps,
889
- BLOCK_SIZE=BLOCK_SIZE
890
- )
891
-
892
- return out
893
-
894
-
895
- def kernel_global(data: input_t) -> output_t:
896
- """
897
- Reference implementation of TriMul using PyTorch.
898
-
899
- Args:
900
- data: Tuple of (input: torch.Tensor, mask: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
901
- - input: Input tensor of shape [batch_size, seq_len, seq_len, dim]
902
- - mask: Mask tensor of shape [batch_size, seq_len, seq_len]
903
- - weights: Dictionary containing model weights
904
- - config: Dictionary containing model configuration parameters
905
- """
906
- input_tensor, mask, weights, config = data
907
-
908
- left_proj_weight = weights["left_proj.weight"].to(torch.float16)
909
- right_proj_weight = weights["right_proj.weight"].to(torch.float16)
910
- left_gate_weight = weights["left_gate.weight"].to(torch.float16)
911
- right_gate_weight = weights["right_gate.weight"].to(torch.float16)
912
- out_gate_weight = weights["out_gate.weight"].to(torch.float16)
913
-
914
- hidden_dim = config["hidden_dim"]
915
- # trimul = TriMul(dim=config["dim"], hidden_dim=config["hidden_dim"]).to(input_tensor.device)
916
-
917
- x = input_tensor
918
-
919
- batch_size, seq_len, _, dim = x.shape
920
-
921
- x = triton_layernorm_first(x, weights['norm.weight'], weights['norm.bias'])
922
- # x = torch.nn.functional.layer_norm(x, (dim,), eps=1e-5, weight=weights['norm.weight'], bias=weights['norm.bias'])
923
-
924
- left, right, out_gate = two_mm(x, left_proj_weight, right_proj_weight, left_gate_weight, right_gate_weight, out_gate_weight, mask)
925
- # left = torch.nn.functional.linear(x, weights['left_proj.weight'].to(torch.float16))
926
- # right = torch.nn.functional.linear(x, weights['right_proj.weight'].to(torch.float16))
927
-
928
- # left = left * mask.unsqueeze(-1)
929
- # right = right * mask.unsqueeze(-1)
930
-
931
- '''
932
- left = left.to(torch.float32)
933
- right = right.to(torch.float32)
934
- x = x.to(torch.float32)
935
-
936
- left_gate = left_gate.sigmoid()
937
- right_gate = right_gate.sigmoid()
938
- out_gate = out_gate.sigmoid()
939
- '''
940
-
941
- # Elementwise multiplication now handled in kernel
942
- # left = left * left_gate
943
- # right = right * right_gate
944
-
945
- # out = einsum('... i k d, ... j k d -> ... i j d', left, right)
946
- out = torch.bmm(left.permute(0, 3, 1, 2).view(-1, left.shape[1], left.shape[2]), right.permute(0, 3, 2, 1).view(-1, right.shape[2], right.shape[1]))
947
- out = out.view(batch_size, hidden_dim, seq_len, seq_len).permute(0, 2, 3, 1)
948
-
949
- # out = torch.compile(second_layernorm_mul, dynamic=False)(out, hidden_dim, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate)
950
- out = triton_layernorm_eltwise(out, weights['to_out_norm.weight'], weights['to_out_norm.bias'], out_gate)
951
- # out = torch.nn.functional.layer_norm(out, (hidden_dim,), eps=1e-5, weight=weights['to_out_norm.weight'].to(out.dtype), bias=weights['to_out_norm.bias'].to(out.dtype))
952
- # out = out * out_gate
953
- return torch.nn.functional.linear(out, weights['to_out.weight'])
954
-
955
- '''
956
- # Fill in the given weights of the model
957
- trimul.norm.weight = nn.Parameter(weights['norm.weight'])
958
- trimul.norm.bias = nn.Parameter(weights['norm.bias'])
959
- trimul.left_proj.weight = nn.Parameter(weights['left_proj.weight'])
960
- trimul.right_proj.weight = nn.Parameter(weights['right_proj.weight'])
961
- trimul.left_gate.weight = nn.Parameter(weights['left_gate.weight'])
962
- trimul.right_gate.weight = nn.Parameter(weights['right_gate.weight'])
963
- trimul.out_gate.weight = nn.Parameter(weights['out_gate.weight'])
964
- trimul.to_out_norm.weight = nn.Parameter(weights['to_out_norm.weight'])
965
- trimul.to_out_norm.bias = nn.Parameter(weights['to_out_norm.bias'])
966
- trimul.to_out.weight = nn.Parameter(weights['to_out.weight'])
967
-
968
- output = trimul(input_tensor, mask)
969
-
970
- return output
971
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-xpu/trimul_gpumode/__init__.py DELETED
@@ -1,26 +0,0 @@
1
- import ctypes
2
- import sys
3
-
4
- import importlib
5
- from pathlib import Path
6
- from types import ModuleType
7
-
8
- def _import_from_path(file_path: Path) -> ModuleType:
9
- # We cannot use the module name as-is, after adding it to `sys.modules`,
10
- # it would also be used for other imports. So, we make a module name that
11
- # depends on the path for it to be unique using the hex-encoded hash of
12
- # the path.
13
- path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value)
14
- module_name = path_hash
15
- spec = importlib.util.spec_from_file_location(module_name, file_path)
16
- if spec is None:
17
- raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
18
- module = importlib.util.module_from_spec(spec)
19
- if module is None:
20
- raise ImportError(f"Cannot load module {module_name} from spec")
21
- sys.modules[module_name] = module
22
- spec.loader.exec_module(module) # type: ignore
23
- return module
24
-
25
-
26
- globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py")))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-xpu/trimul_mi300.py DELETED
@@ -1,524 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
-
9
- @triton.autotune(
10
- configs=[
11
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
12
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
13
-
14
- # Configurations with larger block sizes for better data reuse
15
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),
16
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=2),
17
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=2),
18
-
19
- # Configurations with deeper K dimension
20
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
21
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
22
-
23
- # More extreme configurations to test the limits
24
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=2),
25
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=2),
26
-
27
- # Configurations with fewer warps
28
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=2),
29
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=2),
30
-
31
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
32
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
33
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
34
- ],
35
- key=['M', 'N', 'K'],
36
- )
37
- @triton.jit
38
- def fused_ln_dual_matmul_kernel(
39
- # Pointers (9)
40
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
41
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
42
- # Metadata (5)
43
- M, H, K, s1, s2,
44
- # Strides (16)
45
- stride_x_m, stride_x_k,
46
- stride_w4_k, stride_w4_n,
47
- stride_wog_k, stride_wog_n,
48
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
49
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
50
- stride_og_m, stride_og_h,
51
- stride_mask_m, stride_mask_h,
52
- # Constexpr (from decorator and kwargs)
53
- LN_EPS: tl.constexpr,
54
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
55
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
56
- ):
57
- # --- PID Mapping: Based on the LARGER 4*H problem ---
58
- pid = tl.program_id(axis=0)
59
- N_4way = 4 * H
60
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
61
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
62
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
63
- group_id = pid // num_pid_in_group
64
- first_pid_m = group_id * GROUP_SIZE_M
65
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
66
- pid_m = first_pid_m + (pid % group_size_m)
67
- pid_n = (pid % num_pid_in_group) // group_size_m
68
-
69
- # --- SHARED LayerNorm calculation (done only ONCE) ---
70
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
71
- m_mask = offs_m < M
72
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
73
-
74
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
75
- for k_offset in range(0, K, BLOCK_SIZE_K):
76
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
77
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
78
- k_mask = (k_offset + k_chunk_offs) < K
79
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
80
- mean += tl.sum(x_chunk, axis=1)
81
- mean /= K
82
-
83
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
84
- for k_offset in range(0, K, BLOCK_SIZE_K):
85
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
86
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
87
- k_mask = (k_offset + k_chunk_offs) < K
88
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
89
- x_centered = x_chunk - mean[:, None]
90
- var += tl.sum(x_centered * x_centered, axis=1)
91
- var /= K
92
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
93
-
94
- # --- Matmul Loop 1: For the 4-Way Projections ---
95
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
96
- offs_k = tl.arange(0, BLOCK_SIZE_K)
97
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
98
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
99
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
100
-
101
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
102
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
103
- k_block_start = k * BLOCK_SIZE_K;
104
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
105
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
106
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
107
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
108
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
109
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
110
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
111
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
112
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
113
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
114
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
115
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
116
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
117
-
118
- #Some threads should calclate out_gate
119
- if pid_n * BLOCK_SIZE_N < H:
120
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
121
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
122
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
123
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
124
- accumulator_og += tl.dot(x_norm_tile, w_tile)
125
-
126
- if pid_n * BLOCK_SIZE_N < H:
127
- og_out = tl.sigmoid(accumulator_og)
128
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
129
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
130
- tl.store(outg_ptrs, og_out, mask=og_mask)
131
-
132
- # --- Fusion Logic for 4-Way Part ---
133
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
134
- role_idx = tl.arange(0, 4)[None, None, :]
135
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
136
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
137
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
138
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
139
-
140
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
141
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
142
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
143
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
144
-
145
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
146
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
147
-
148
- s1s2 = s1 * s2
149
- offs_b = offs_m // s1s2
150
- offs_s1 = (offs_m % s1s2) // s2
151
- offs_s2 = offs_m % s2
152
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
153
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
154
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
155
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
156
-
157
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
158
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
159
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
160
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1) # s2 offset uses s2 stride, s1 offset uses s1 stride
161
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
162
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
163
-
164
- @triton.autotune(
165
- configs=[
166
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
167
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
168
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
169
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
170
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
171
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
172
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
173
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
174
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
175
- ],
176
- key=['s1', 's2', 'H'],
177
- )
178
- @triton.jit
179
- def bmm_coalesced_kernel(
180
- # Pointers
181
- Left_ptr, Right_ptr, Out_ptr,
182
- # Dimensions
183
- bs, s1, s2, H,
184
- # Strides
185
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
186
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
187
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
188
- # Kernel parameters
189
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
190
- GROUP_SIZE_M: tl.constexpr,
191
- ):
192
- # Grid and program IDs
193
- pid = tl.program_id(axis=0)
194
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
195
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
196
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
197
- group_id = pid // num_pid_in_group
198
- first_pid_m = group_id * GROUP_SIZE_M
199
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
200
- pid_m = first_pid_m + (pid % group_size_m)
201
- pid_n = (pid % num_pid_in_group) // group_size_m
202
-
203
- pid_bh = tl.program_id(axis=1)
204
- pid_b = pid_bh // H
205
- pid_h = pid_bh % H
206
-
207
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
208
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
209
- offs_k = tl.arange(0, BLOCK_SIZE_K)
210
-
211
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
212
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
213
-
214
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
215
-
216
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
217
- k_start = k * BLOCK_SIZE_K
218
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
219
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
220
-
221
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
222
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
223
-
224
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
225
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
226
-
227
- accumulator += tl.dot(a, b)
228
-
229
- # --- Coalesced Write ---
230
- # Write to a standard (bs, H, s1, s1) layout
231
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
232
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
233
-
234
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
235
- tl.store(out_ptrs, accumulator, mask=c_mask)
236
-
237
- @triton.autotune(
238
- configs=[
239
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
240
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
241
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
242
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
243
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
244
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
245
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
246
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
247
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
248
- ],
249
- key=['H', 'D'],
250
- )
251
- @triton.jit
252
- def fused_final_kernel(
253
- # Pointers
254
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
255
- # Metadata
256
- M, H, D, s1, # M_gate = bs*s1*s2
257
- # Strides
258
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
259
- stride_gate_m, stride_gate_h,
260
- stride_proj_d, stride_proj_h,
261
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
262
- # Constants
263
- LN_EPS: tl.constexpr,
264
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
265
- GROUP_SIZE_M: tl.constexpr,
266
- ):
267
- # --- Grid and PID Setup for Matmul ---
268
- pid = tl.program_id(axis=0)
269
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
270
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
271
-
272
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
273
- group_id = pid // num_pid_in_group
274
- first_pid_m = group_id * GROUP_SIZE_M
275
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
276
- pid_m = first_pid_m + (pid % group_size_m)
277
- pid_n = (pid % num_pid_in_group) // group_size_m
278
-
279
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
280
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
281
- m_mask = offs_m < M
282
-
283
- # Decompose M back to (b, r, c) for reordering lookups
284
- s1s1 = s1 * s1
285
- b = offs_m // s1s1
286
- r = (offs_m % s1s1) // s1
287
- c = offs_m % s1
288
-
289
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
290
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
291
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
292
-
293
- for k_offset in range(0, H, BLOCK_SIZE_K):
294
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
295
- k_mask = offs_k < H
296
-
297
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
298
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
299
-
300
- # Accumulate sum and sum of squares in one pass
301
- sum_x += tl.sum(in_chunk, axis=1)
302
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
303
-
304
- # Finalize statistics
305
- mean = sum_x / H
306
- var = (sum_x2 / H) - (mean * mean)
307
- rstd = tl.math.rsqrt(var + LN_EPS)
308
-
309
- # --- Pass 3: Fused Gating and Matmul ---
310
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
311
- for k_offset in range(0, H, BLOCK_SIZE_K):
312
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
313
- k_mask = offs_k < H
314
-
315
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
316
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
317
- a_norm = (a - mean[:, None]) * rstd[:, None]
318
-
319
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
320
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
321
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
322
-
323
- proj_ptrs = ProjW_ptr + \
324
- offs_n[None, :] * stride_proj_d + \
325
- offs_k[:, None] * stride_proj_h
326
-
327
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
328
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
329
- a_gated = a_norm * gate
330
-
331
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
332
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
333
-
334
- # --- Store Final Output ---
335
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
336
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
337
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
338
-
339
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
340
-
341
- def compiledtrimul_fused_interleaved(
342
- x: torch.Tensor,
343
- mask_mh: torch.Tensor,
344
- norm_weight: torch.Tensor,
345
- norm_bias: torch.Tensor,
346
- W_4way: torch.Tensor, # Use the new weight matrices
347
- W_og: torch.Tensor,
348
- to_out_norm_weight: torch.Tensor,
349
- to_out_norm_bias: torch.Tensor,
350
- to_out_weight: torch.Tensor,
351
- h: int,
352
- ):
353
- bs, s1, s2, d = x.shape
354
- M, K, H = bs * s1 * s2, x.shape[-1], h
355
- x_flat = x.view(M, K)
356
-
357
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
358
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
359
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
360
-
361
- # The grid is launched for the larger 4*H problem
362
- N_4way = 4 * H
363
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
364
- fused_ln_dual_matmul_kernel[grid](
365
- # Pointers (9)
366
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
367
- left_final, right_final_t, og_mh,
368
- # Metadata (5) - M, H, K, s1, s2
369
- M, H, K, s1, s2,
370
- # Strides (16)
371
- x_flat.stride(0), x_flat.stride(1),
372
- W_4way.stride(0), W_4way.stride(1),
373
- W_og.stride(0), W_og.stride(1),
374
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
375
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
376
- og_mh.stride(0), og_mh.stride(1),
377
- mask_mh.stride(0), mask_mh.stride(1),
378
- # Constexpr (1)
379
- LN_EPS=1e-5
380
- )
381
-
382
- bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
383
-
384
- grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
385
- bmm_coalesced_kernel[grid_bmm](
386
- left_final, right_final_t, bmm_out_tmp,
387
- bs, s1, s2, H,
388
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
389
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
390
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
391
- )
392
-
393
- # --- Kernel 3: Fully Fused Final Stage ---
394
- final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
395
-
396
- grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
397
- fused_final_kernel[grid_final](
398
- # Pointers
399
- bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
400
- # Metadata
401
- M, H, d, s1,
402
- # Strides
403
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
404
- og_mh.stride(0), og_mh.stride(1),
405
- to_out_weight.stride(0), to_out_weight.stride(1), # Use strides of the corrected tensor
406
- final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
407
- # Constants
408
- LN_EPS=1e-5,
409
- )
410
-
411
- return final_out
412
-
413
- def pack_w_4way_efficient(weights):
414
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
415
- WL = weights['left_proj.weight']
416
- WLG = weights['left_gate.weight']
417
- WR = weights['right_proj.weight']
418
- WRG = weights['right_gate.weight']
419
- H, K = WL.shape
420
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2)
421
- ws = ws.contiguous().view(4 * H, K)
422
- return ws.t().to(torch.float16)
423
-
424
- def get_w_og(weights):
425
- """ Gets the transposed [K, H] out_gate weight matrix. """
426
- WOG = weights['out_gate.weight']
427
- return WOG.t().to(torch.float16)
428
-
429
- def compiledtrimul(
430
- x: torch.Tensor,
431
- mask: torch.Tensor,
432
- norm_weight: torch.Tensor,
433
- norm_bias: torch.Tensor,
434
- w_concat: torch.Tensor,
435
- to_out_norm_weight: torch.Tensor,
436
- to_out_norm_bias: torch.Tensor,
437
- to_out_weight: torch.Tensor,
438
- h: int
439
- ) -> torch.Tensor:
440
- """
441
- A barebones, compiled PyTorch function for the TriMul logic.
442
- """
443
- bs, s1, s2, d = x.shape
444
-
445
- # Initial LayerNorm
446
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
447
- # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
448
- all_projections = torch.mm(x_norm, w_concat)
449
-
450
- # Split back into individual projections
451
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
452
-
453
- # Apply mask and gates
454
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
455
- left = left * mask_expanded * torch.sigmoid(lg)
456
- right = right * mask_expanded * torch.sigmoid(rg)
457
- out_gate = torch.sigmoid(og)
458
-
459
- # Reshape for einsum
460
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
461
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
462
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
463
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
464
-
465
- # Apply layer norm and final gating
466
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
467
- gated = normed * out_gate
468
-
469
- # Final projection
470
- final_out_flat = gated @ to_out_weight.t()
471
- final_out = final_out_flat.view(bs, s1, s2, d)
472
-
473
- return final_out
474
-
475
- def small_kernel_pt_path(data):
476
- input_tensor, mask, weights, config = data
477
- w_concat = torch.cat([
478
- weights['left_proj.weight'],
479
- weights['right_proj.weight'],
480
- weights['left_gate.weight'],
481
- weights['right_gate.weight'],
482
- weights['out_gate.weight']
483
- ], dim=0).t().contiguous().to(torch.float16)
484
- # Call the compiled function with prepared weights
485
- output = compiledtrimul(
486
- x=input_tensor.to(torch.float32),
487
- mask=mask.unsqueeze(-1),
488
- norm_weight=weights['norm.weight'].to(torch.float32),
489
- norm_bias=weights['norm.bias'].to(torch.float32),
490
- w_concat=w_concat,
491
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
492
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
493
- to_out_weight=weights['to_out.weight'].to(torch.float16),
494
- h=config["hidden_dim"]
495
- )
496
- return output
497
-
498
- def kernel_mi300(data):
499
- input_tensor, mask, weights, config = data
500
- bs, s1, s2, d = input_tensor.shape
501
-
502
- if s1 < 100:
503
- return small_kernel_pt_path(data)
504
-
505
- H = config["hidden_dim"]
506
-
507
- W_4way = pack_w_4way_efficient(weights)
508
- W_og = get_w_og(weights)
509
-
510
- M = bs * s1 * s2
511
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16) #move into kernel possibly
512
-
513
- return compiledtrimul_fused_interleaved(
514
- x=input_tensor.to(torch.float32),
515
- mask_mh=mask_mh,
516
- norm_weight=weights['norm.weight'].to(torch.float32),
517
- norm_bias=weights['norm.bias'].to(torch.float32),
518
- W_4way=W_4way, # Pass the new 4-way matrix
519
- W_og=W_og, # Pass the new out_gate matrix
520
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
521
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
522
- to_out_weight=weights['to_out.weight'].to(torch.float16),
523
- h=H,
524
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-xpu/triton_a100.py DELETED
@@ -1,405 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- # Set PyTorch flags for performance
7
- torch.backends.cuda.matmul.allow_tf32 = True
8
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
9
-
10
- @triton.jit
11
- def fused_ln_dual_matmul_kernel(
12
- # Pointers (9)
13
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
14
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
15
- # Metadata (5)
16
- M, H, K, s1, s2,
17
- # Strides (16)
18
- stride_x_m, stride_x_k,
19
- stride_w4_k, stride_w4_n,
20
- stride_wog_k, stride_wog_n,
21
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
22
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
23
- stride_og_m, stride_og_h,
24
- stride_mask_m, stride_mask_h,
25
- # Constexpr (now passed as arguments from the host)
26
- LN_EPS: tl.constexpr,
27
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
28
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
29
- ):
30
- # --- PID Mapping: Based on the LARGER 4*H problem ---
31
- pid = tl.program_id(axis=0)
32
- N_4way = 4 * H
33
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
34
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
35
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
36
- group_id = pid // num_pid_in_group
37
- first_pid_m = group_id * GROUP_SIZE_M
38
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
39
- pid_m = first_pid_m + (pid % group_size_m)
40
- pid_n = (pid % num_pid_in_group) // group_size_m
41
-
42
- # --- SHARED LayerNorm calculation (done only ONCE) ---
43
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
44
- m_mask = offs_m < M
45
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
46
-
47
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
48
- for k_offset in range(0, K, BLOCK_SIZE_K):
49
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
50
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
51
- k_mask = (k_offset + k_chunk_offs) < K
52
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
53
- mean += tl.sum(x_chunk, axis=1)
54
- mean /= K
55
-
56
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
57
- for k_offset in range(0, K, BLOCK_SIZE_K):
58
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
59
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
60
- k_mask = (k_offset + k_chunk_offs) < K
61
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
62
- x_centered = x_chunk - mean[:, None]
63
- var += tl.sum(x_centered * x_centered, axis=1)
64
- var /= K
65
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
66
-
67
- # --- Matmul Loop 1: For the 4-Way Projections ---
68
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
69
- offs_k = tl.arange(0, BLOCK_SIZE_K)
70
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
71
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
72
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
73
-
74
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
75
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
76
- k_block_start = k * BLOCK_SIZE_K;
77
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
78
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
79
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
80
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
81
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
82
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
83
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
84
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
85
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
86
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
87
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
88
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
89
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
90
-
91
- if pid_n * BLOCK_SIZE_N < H:
92
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
93
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
94
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
95
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
96
- accumulator_og += tl.dot(x_norm_tile, w_tile)
97
-
98
- if pid_n * BLOCK_SIZE_N < H:
99
- og_out = tl.sigmoid(accumulator_og)
100
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
101
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
102
- tl.store(outg_ptrs, og_out, mask=og_mask)
103
-
104
- # --- Fusion Logic for 4-Way Part ---
105
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
106
- role_idx = tl.arange(0, 4)[None, None, :]
107
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
108
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
109
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
110
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
111
-
112
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
113
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
114
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
115
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
116
-
117
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
118
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
119
-
120
- s1s2 = s1 * s2
121
- offs_b = offs_m // s1s2
122
- offs_s1 = (offs_m % s1s2) // s2
123
- offs_s2 = offs_m % s2
124
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
125
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
126
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
127
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
128
-
129
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
130
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
131
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
132
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
133
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
134
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
-
136
- @triton.jit
137
- def bmm_coalesced_kernel(
138
- # Pointers
139
- Left_ptr, Right_ptr, Out_ptr,
140
- # Dimensions
141
- bs, s1, s2, H,
142
- # Strides
143
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
144
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
145
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
146
- # Kernel parameters
147
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
148
- GROUP_SIZE_M: tl.constexpr,
149
- ):
150
- pid = tl.program_id(axis=0)
151
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
152
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
153
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
154
- group_id = pid // num_pid_in_group
155
- first_pid_m = group_id * GROUP_SIZE_M
156
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
157
- pid_m = first_pid_m + (pid % group_size_m)
158
- pid_n = (pid % num_pid_in_group) // group_size_m
159
-
160
- pid_bh = tl.program_id(axis=1)
161
- pid_b = pid_bh // H
162
- pid_h = pid_bh % H
163
-
164
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
165
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
166
- offs_k = tl.arange(0, BLOCK_SIZE_K)
167
-
168
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
169
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
170
-
171
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
172
-
173
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
174
- k_start = k * BLOCK_SIZE_K
175
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
176
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
177
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
178
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
179
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
180
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
181
- accumulator += tl.dot(a, b)
182
-
183
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
184
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
185
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
186
- tl.store(out_ptrs, accumulator, mask=c_mask)
187
-
188
- @triton.jit
189
- def fused_final_kernel(
190
- # Pointers
191
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
192
- # Metadata
193
- M, H, D, s1,
194
- # Strides
195
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
196
- stride_gate_m, stride_gate_h,
197
- stride_proj_d, stride_proj_h,
198
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
199
- # Constants
200
- LN_EPS: tl.constexpr,
201
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
202
- GROUP_SIZE_M: tl.constexpr,
203
- ):
204
- pid = tl.program_id(axis=0)
205
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
206
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
207
-
208
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
209
- group_id = pid // num_pid_in_group
210
- first_pid_m = group_id * GROUP_SIZE_M
211
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
212
- pid_m = first_pid_m + (pid % group_size_m)
213
- pid_n = (pid % num_pid_in_group) // group_size_m
214
-
215
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
216
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
217
- m_mask = offs_m < M
218
-
219
- s1s1 = s1 * s1
220
- b = offs_m // s1s1
221
- r = (offs_m % s1s1) // s1
222
- c = offs_m % s1
223
-
224
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
225
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
226
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
227
-
228
- for k_offset in range(0, H, BLOCK_SIZE_K):
229
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
230
- k_mask = offs_k < H
231
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
232
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
233
- sum_x += tl.sum(in_chunk, axis=1)
234
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
235
-
236
- mean = sum_x / H
237
- var = (sum_x2 / H) - (mean * mean)
238
- rstd = tl.math.rsqrt(var + LN_EPS)
239
-
240
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
241
- for k_offset in range(0, H, BLOCK_SIZE_K):
242
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
243
- k_mask = offs_k < H
244
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
245
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
246
- a_norm = (a - mean[:, None]) * rstd[:, None]
247
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
248
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
249
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
250
- proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
251
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
252
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
253
- a_gated = a_norm * gate
254
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
255
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
256
-
257
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
258
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
259
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
260
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
261
-
262
- def compiledtrimul_fused_interleaved_final(
263
- x: torch.Tensor,
264
- mask_mh: torch.Tensor,
265
- norm_weight: torch.Tensor,
266
- norm_bias: torch.Tensor,
267
- W_4way: torch.Tensor,
268
- W_og: torch.Tensor,
269
- to_out_norm_weight: torch.Tensor,
270
- to_out_norm_bias: torch.Tensor,
271
- to_out_weight: torch.Tensor,
272
- h: int,
273
- ):
274
- bs, s1, s2, d = x.shape
275
- M, K, H = bs * s1 * s2, x.shape[-1], h
276
- x_flat = x.view(M, K)
277
-
278
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
279
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
280
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
281
-
282
- # --- Kernel 1: Fused LN + Dual Matmul ---
283
- N_4way = 4 * H
284
- # Hardcoded A100 best config: M128-N128-K32-GM8-HC32-W8-S2
285
- config_k1 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
286
- grid_k1 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
287
-
288
- fused_ln_dual_matmul_kernel[grid_k1](
289
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
290
- left_final, right_final_t, og_mh,
291
- M, H, K, s1, s2,
292
- x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
293
- W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
294
- left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
295
- right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
296
- mask_mh.stride(0), mask_mh.stride(1),
297
- LN_EPS=1e-5, **config_k1, num_warps=8, num_stages=2
298
- )
299
-
300
- # --- Kernel 2: Batched Matrix Multiplication ---
301
- bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
302
- # Hardcoded A100 best config: M128-N64-K32-GM8-W4-S3
303
- config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
304
- grid_k2 = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
305
-
306
- bmm_coalesced_kernel[grid_k2](
307
- left_final, right_final_t, bmm_out_tmp,
308
- bs, s1, s2, H,
309
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
310
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
311
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
312
- **config_k2, num_warps=4, num_stages=3
313
- )
314
-
315
- # --- Kernel 3: Fully Fused Final Stage ---
316
- final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
317
- # Hardcoded A100 best config: M32-N128-K32-GM8-W4-S3
318
- config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
319
- grid_k3 = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
320
-
321
- fused_final_kernel[grid_k3](
322
- bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
323
- M, H, d, s1,
324
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
325
- og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
326
- final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
327
- LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
328
- )
329
- return final_out
330
-
331
- def pack_w_4way_efficient(weights):
332
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
333
- WL, WLG, WR, WRG = (weights[k] for k in ['left_proj.weight', 'left_gate.weight', 'right_proj.weight', 'right_gate.weight'])
334
- H, K = WL.shape
335
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2).contiguous().view(4 * H, K)
336
- return ws.t().to(torch.float16)
337
-
338
- def get_w_og(weights):
339
- """ Gets the transposed [K, H] out_gate weight matrix. """
340
- return weights['out_gate.weight'].t().to(torch.float16)
341
-
342
- @torch.compile()
343
- def compiledtrimul(
344
- x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
345
- w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
346
- to_out_weight: torch.Tensor, h: int
347
- ) -> torch.Tensor:
348
- bs, s1, s2, d = x.shape
349
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
350
- all_projections = torch.mm(x_norm, w_concat)
351
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
352
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
353
- left = left * mask_expanded * torch.sigmoid(lg)
354
- right = right * mask_expanded * torch.sigmoid(rg)
355
- out_gate = torch.sigmoid(og)
356
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
357
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
358
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
359
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
360
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
361
- gated = normed * out_gate
362
- final_out_flat = gated @ to_out_weight.t()
363
- return final_out_flat.view(bs, s1, s1, d)
364
-
365
- def small_kernel_pt_path(data):
366
- input_tensor, mask, weights, config = data
367
- w_concat = torch.cat([
368
- weights['left_proj.weight'], weights['right_proj.weight'], weights['left_gate.weight'],
369
- weights['right_gate.weight'], weights['out_gate.weight']
370
- ], dim=0).t().contiguous().to(torch.float16)
371
- return compiledtrimul(
372
- x=input_tensor.to(torch.float32), mask=mask.unsqueeze(-1),
373
- norm_weight=weights['norm.weight'].to(torch.float32),
374
- norm_bias=weights['norm.bias'].to(torch.float32), w_concat=w_concat,
375
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
376
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
377
- to_out_weight=weights['to_out.weight'].to(torch.float16),
378
- h=config["hidden_dim"]
379
- )
380
-
381
- def kernel_a100(data):
382
- input_tensor, mask, weights, config = data
383
- bs, s1, s2, d = input_tensor.shape
384
-
385
- if s1 < 512: # Adjusted threshold based on observed BMM configs
386
- return small_kernel_pt_path(data)
387
-
388
- H = config["hidden_dim"]
389
- W_4way = pack_w_4way_efficient(weights)
390
- W_og = get_w_og(weights)
391
- M = bs * s1 * s2
392
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16)
393
-
394
- return compiledtrimul_fused_interleaved_final(
395
- x=input_tensor.to(torch.float32),
396
- mask_mh=mask_mh,
397
- norm_weight=weights['norm.weight'].to(torch.float32),
398
- norm_bias=weights['norm.bias'].to(torch.float32),
399
- W_4way=W_4way,
400
- W_og=W_og,
401
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
402
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
403
- to_out_weight=weights['to_out.weight'].to(torch.float16),
404
- h=H,
405
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-xpu/triton_b200.py DELETED
@@ -1,411 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
-
9
- @triton.jit
10
- def fused_ln_dual_matmul_kernel(
11
- # Pointers (9)
12
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
13
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
14
- # Metadata (5)
15
- M, H, K, s1, s2,
16
- # Strides (16)
17
- stride_x_m, stride_x_k,
18
- stride_w4_k, stride_w4_n,
19
- stride_wog_k, stride_wog_n,
20
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
21
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
22
- stride_og_m, stride_og_h,
23
- stride_mask_m, stride_mask_h,
24
- # Constexpr (now passed as arguments from the host)
25
- LN_EPS: tl.constexpr,
26
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
27
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
28
- ):
29
- # --- PID Mapping: Based on the LARGER 4*H problem ---
30
- pid = tl.program_id(axis=0)
31
- N_4way = 4 * H
32
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
33
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
34
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
35
- group_id = pid // num_pid_in_group
36
- first_pid_m = group_id * GROUP_SIZE_M
37
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
38
- pid_m = first_pid_m + (pid % group_size_m)
39
- pid_n = (pid % num_pid_in_group) // group_size_m
40
-
41
- # --- SHARED LayerNorm calculation (done only ONCE) ---
42
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
43
- m_mask = offs_m < M
44
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
45
-
46
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
47
- for k_offset in range(0, K, BLOCK_SIZE_K):
48
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
49
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
50
- k_mask = (k_offset + k_chunk_offs) < K
51
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
52
- mean += tl.sum(x_chunk, axis=1)
53
- mean /= K
54
-
55
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
56
- for k_offset in range(0, K, BLOCK_SIZE_K):
57
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
58
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
59
- k_mask = (k_offset + k_chunk_offs) < K
60
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
61
- x_centered = x_chunk - mean[:, None]
62
- var += tl.sum(x_centered * x_centered, axis=1)
63
- var /= K
64
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
65
-
66
- # --- Matmul Loop 1: For the 4-Way Projections ---
67
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
68
- offs_k = tl.arange(0, BLOCK_SIZE_K)
69
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
70
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
71
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
72
-
73
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
74
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
75
- k_block_start = k * BLOCK_SIZE_K;
76
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
77
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
78
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
79
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
80
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
81
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
82
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
83
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
84
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
85
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
86
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
87
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
88
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
89
-
90
- #Some threads should calclate out_gate
91
- if pid_n * BLOCK_SIZE_N < H:
92
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
93
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
94
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
95
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
96
- accumulator_og += tl.dot(x_norm_tile, w_tile)
97
-
98
- if pid_n * BLOCK_SIZE_N < H:
99
- og_out = tl.sigmoid(accumulator_og)
100
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
101
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
102
- tl.store(outg_ptrs, og_out, mask=og_mask)
103
-
104
- # --- Fusion Logic for 4-Way Part ---
105
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
106
- role_idx = tl.arange(0, 4)[None, None, :]
107
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
108
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
109
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
110
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
111
-
112
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
113
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
114
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
115
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
116
-
117
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
118
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
119
-
120
- s1s2 = s1 * s2
121
- offs_b = offs_m // s1s2
122
- offs_s1 = (offs_m % s1s2) // s2
123
- offs_s2 = offs_m % s2
124
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
125
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
126
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
127
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
128
-
129
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
130
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
131
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
132
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1)
133
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
134
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
135
-
136
- @triton.jit
137
- def bmm_coalesced_kernel(
138
- # Pointers
139
- Left_ptr, Right_ptr, Out_ptr,
140
- # Dimensions
141
- bs, s1, s2, H,
142
- # Strides
143
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
144
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
145
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
146
- # Kernel parameters
147
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
148
- GROUP_SIZE_M: tl.constexpr,
149
- ):
150
- # Grid and program IDs
151
- pid = tl.program_id(axis=0)
152
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
153
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
154
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
155
- group_id = pid // num_pid_in_group
156
- first_pid_m = group_id * GROUP_SIZE_M
157
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
158
- pid_m = first_pid_m + (pid % group_size_m)
159
- pid_n = (pid % num_pid_in_group) // group_size_m
160
-
161
- pid_bh = tl.program_id(axis=1)
162
- pid_b = pid_bh // H
163
- pid_h = pid_bh % H
164
-
165
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
166
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
167
- offs_k = tl.arange(0, BLOCK_SIZE_K)
168
-
169
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
170
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
171
-
172
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
173
-
174
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
175
- k_start = k * BLOCK_SIZE_K
176
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
177
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
178
-
179
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
180
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
181
-
182
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
183
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
184
-
185
- accumulator += tl.dot(a, b)
186
-
187
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
188
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
189
-
190
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
191
- tl.store(out_ptrs, accumulator, mask=c_mask)
192
-
193
- @triton.jit
194
- def fused_final_kernel(
195
- # Pointers
196
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
197
- # Metadata
198
- M, H, D, s1,
199
- # Strides
200
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
201
- stride_gate_m, stride_gate_h,
202
- stride_proj_d, stride_proj_h,
203
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
204
- # Constants
205
- LN_EPS: tl.constexpr,
206
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
207
- GROUP_SIZE_M: tl.constexpr,
208
- ):
209
- pid = tl.program_id(axis=0)
210
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
211
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
212
-
213
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
214
- group_id = pid // num_pid_in_group
215
- first_pid_m = group_id * GROUP_SIZE_M
216
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
217
- pid_m = first_pid_m + (pid % group_size_m)
218
- pid_n = (pid % num_pid_in_group) // group_size_m
219
-
220
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
221
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
222
- m_mask = offs_m < M
223
-
224
- s1s1 = s1 * s1
225
- b = offs_m // s1s1
226
- r = (offs_m % s1s1) // s1
227
- c = offs_m % s1
228
-
229
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
230
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
231
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
232
-
233
- for k_offset in range(0, H, BLOCK_SIZE_K):
234
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
235
- k_mask = offs_k < H
236
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
237
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
238
- sum_x += tl.sum(in_chunk, axis=1)
239
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
240
-
241
- mean = sum_x / H
242
- var = (sum_x2 / H) - (mean * mean)
243
- rstd = tl.math.rsqrt(var + LN_EPS)
244
-
245
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
246
- for k_offset in range(0, H, BLOCK_SIZE_K):
247
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
248
- k_mask = offs_k < H
249
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
250
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
251
- a_norm = (a - mean[:, None]) * rstd[:, None]
252
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
253
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
254
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
255
- proj_ptrs = ProjW_ptr + offs_n[None, :] * stride_proj_d + offs_k[:, None] * stride_proj_h
256
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
257
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
258
- a_gated = a_norm * gate
259
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
260
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
261
-
262
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
263
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
264
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
265
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
266
-
267
- def compiledtrimul_fused_interleaved_final(
268
- x: torch.Tensor,
269
- mask_mh: torch.Tensor,
270
- norm_weight: torch.Tensor,
271
- norm_bias: torch.Tensor,
272
- W_4way: torch.Tensor,
273
- W_og: torch.Tensor,
274
- to_out_norm_weight: torch.Tensor,
275
- to_out_norm_bias: torch.Tensor,
276
- to_out_weight: torch.Tensor,
277
- h: int,
278
- ):
279
- bs, s1, s2, d = x.shape
280
- M, K, H = bs * s1 * s2, x.shape[-1], h
281
- x_flat = x.view(M, K)
282
-
283
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
284
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
285
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
286
-
287
- # --- Kernel 1: Fused LN + Dual Matmul ---
288
- # The grid is launched for the larger 4*H problem
289
- N_4way = 4 * H
290
- # Hardcoded best config from logs: M64-N128-K64-GM8-HC32-W4-S2
291
- config_k1 = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}
292
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
293
-
294
- fused_ln_dual_matmul_kernel[grid](
295
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
296
- left_final, right_final_t, og_mh,
297
- M, H, K, s1, s2,
298
- x_flat.stride(0), x_flat.stride(1), W_4way.stride(0), W_4way.stride(1),
299
- W_og.stride(0), W_og.stride(1), left_final.stride(0), left_final.stride(1),
300
- left_final.stride(2), left_final.stride(3), right_final_t.stride(0), right_final_t.stride(1),
301
- right_final_t.stride(2), right_final_t.stride(3), og_mh.stride(0), og_mh.stride(1),
302
- mask_mh.stride(0), mask_mh.stride(1),
303
- LN_EPS=1e-5, **config_k1, num_warps=4, num_stages=2
304
- )
305
-
306
- # --- Kernel 2: Batched Matrix Multiplication ---
307
- bmm_out_tmp = torch.empty((bs, H, s1, s1), device=x.device, dtype=torch.float16)
308
- # Hardcoded best config from logs: M128-N128-K32-GM8-W8-S3
309
- config_k2 = {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
310
- grid_bmm = lambda meta: (triton.cdiv(s1, meta['BLOCK_SIZE_M']) * triton.cdiv(s1, meta['BLOCK_SIZE_N']), bs * H)
311
-
312
- bmm_coalesced_kernel[grid_bmm](
313
- left_final, right_final_t, bmm_out_tmp,
314
- bs, s1, s2, H,
315
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
316
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
317
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
318
- **config_k2, num_warps=8, num_stages=3
319
- )
320
-
321
- # --- Kernel 3: Fully Fused Final Stage ---
322
- final_out = torch.empty((bs, s1, s1, d), device=x.device, dtype=torch.float16)
323
- # Hardcoded best config from logs: M32-N128-K32-GM8-W4-S3
324
- config_k3 = {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
325
- grid_final = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(d, meta['BLOCK_SIZE_N']),)
326
-
327
- fused_final_kernel[grid_final](
328
- bmm_out_tmp, og_mh, to_out_norm_weight, to_out_norm_bias, to_out_weight, final_out,
329
- M, H, d, s1,
330
- bmm_out_tmp.stride(0), bmm_out_tmp.stride(1), bmm_out_tmp.stride(2), bmm_out_tmp.stride(3),
331
- og_mh.stride(0), og_mh.stride(1), to_out_weight.stride(0), to_out_weight.stride(1),
332
- final_out.stride(0), final_out.stride(1), final_out.stride(2), final_out.stride(3),
333
- LN_EPS=1e-5, **config_k3, num_warps=4, num_stages=3
334
- )
335
- return final_out
336
-
337
- def pack_w_4way_efficient(weights):
338
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
339
- WL, WLG, WR, WRG = (weights[k] for k in ['left_proj.weight', 'left_gate.weight', 'right_proj.weight', 'right_gate.weight'])
340
- H, K = WL.shape
341
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2).contiguous().view(4 * H, K)
342
- return ws.t().to(torch.float16)
343
-
344
- def get_w_og(weights):
345
- """ Gets the transposed [K, H] out_gate weight matrix. """
346
- return weights['out_gate.weight'].t().to(torch.float16)
347
-
348
- @torch.compile()
349
- def compiledtrimul(
350
- x: torch.Tensor, mask: torch.Tensor, norm_weight: torch.Tensor, norm_bias: torch.Tensor,
351
- w_concat: torch.Tensor, to_out_norm_weight: torch.Tensor, to_out_norm_bias: torch.Tensor,
352
- to_out_weight: torch.Tensor, h: int
353
- ) -> torch.Tensor:
354
- bs, s1, s2, d = x.shape
355
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
356
- all_projections = torch.mm(x_norm, w_concat)
357
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
358
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
359
- left = left * mask_expanded * torch.sigmoid(lg)
360
- right = right * mask_expanded * torch.sigmoid(rg)
361
- out_gate = torch.sigmoid(og)
362
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
363
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
364
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
365
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
366
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
367
- gated = normed * out_gate
368
- final_out_flat = gated @ to_out_weight.t()
369
- return final_out_flat.view(bs, s1, s1, d)
370
-
371
- def small_kernel_pt_path(data):
372
- input_tensor, mask, weights, config = data
373
- w_concat = torch.cat([
374
- weights['left_proj.weight'], weights['right_proj.weight'], weights['left_gate.weight'],
375
- weights['right_gate.weight'], weights['out_gate.weight']
376
- ], dim=0).t().contiguous().to(torch.float16)
377
- return compiledtrimul(
378
- x=input_tensor.to(torch.float32), mask=mask.unsqueeze(-1),
379
- norm_weight=weights['norm.weight'].to(torch.float32),
380
- norm_bias=weights['norm.bias'].to(torch.float32), w_concat=w_concat,
381
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
382
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
383
- to_out_weight=weights['to_out.weight'].to(torch.float16),
384
- h=config["hidden_dim"]
385
- )
386
-
387
- def kernel_b200(data):
388
- input_tensor, mask, weights, config = data
389
- bs, s1, s2, d = input_tensor.shape
390
-
391
- if s1 < 800:
392
- return small_kernel_pt_path(data)
393
-
394
- H = config["hidden_dim"]
395
- W_4way = pack_w_4way_efficient(weights)
396
- W_og = get_w_og(weights)
397
- M = bs * s1 * s2
398
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16)
399
-
400
- return compiledtrimul_fused_interleaved_final(
401
- x=input_tensor.to(torch.float32),
402
- mask_mh=mask_mh,
403
- norm_weight=weights['norm.weight'].to(torch.float32),
404
- norm_bias=weights['norm.bias'].to(torch.float32),
405
- W_4way=W_4way,
406
- W_og=W_og,
407
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
408
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
409
- to_out_weight=weights['to_out.weight'].to(torch.float16),
410
- h=H,
411
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch-xpu/triton_h100.py DELETED
@@ -1,509 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import triton
4
- import triton.language as tl
5
-
6
- torch.backends.cuda.matmul.allow_tf32 = True
7
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
8
-
9
- @triton.autotune(
10
- configs=[
11
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
12
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),
13
-
14
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=3),
15
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=8, num_stages=4),
16
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=8, num_stages=4),
17
-
18
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=4),
19
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=3),
20
-
21
- triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 16}, num_warps=4, num_stages=5),
22
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 64}, num_warps=4, num_stages=5),
23
-
24
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=4, num_stages=3),
25
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8, 'H_CHUNK_SIZE': 32}, num_warps=2, num_stages=4),
26
- ],
27
- key=['M', 'N', 'K'],
28
- )
29
- @triton.jit
30
- def fused_ln_dual_matmul_kernel(
31
- # Pointers (9)
32
- X_ptr, W_4way_ptr, W_og_ptr, Mask_ptr, Norm_Weight_ptr, Norm_Bias_ptr,
33
- OutLeft_ptr, OutRight_ptr, OutOG_ptr,
34
- # Metadata (5)
35
- M, H, K, s1, s2,
36
- # Strides (16)
37
- stride_x_m, stride_x_k,
38
- stride_w4_k, stride_w4_n,
39
- stride_wog_k, stride_wog_n,
40
- stride_ol_bs, stride_ol_h, stride_ol_s1, stride_ol_s2,
41
- stride_or_t_bs, stride_or_t_h, stride_or_t_s2, stride_or_t_s1,
42
- stride_og_m, stride_og_h,
43
- stride_mask_m, stride_mask_h,
44
- # Constexpr (from decorator and kwargs)
45
- LN_EPS: tl.constexpr,
46
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
47
- GROUP_SIZE_M: tl.constexpr, H_CHUNK_SIZE: tl.constexpr,
48
- ):
49
- # --- PID Mapping: Based on the LARGER 4*H problem ---
50
- pid = tl.program_id(axis=0)
51
- N_4way = 4 * H
52
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
53
- num_pid_n = tl.cdiv(N_4way, BLOCK_SIZE_N)
54
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
55
- group_id = pid // num_pid_in_group
56
- first_pid_m = group_id * GROUP_SIZE_M
57
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
58
- pid_m = first_pid_m + (pid % group_size_m)
59
- pid_n = (pid % num_pid_in_group) // group_size_m
60
-
61
- # --- SHARED LayerNorm calculation (done only ONCE) ---
62
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
63
- m_mask = offs_m < M
64
- x_rows_base_ptr = X_ptr + offs_m[:, None] * stride_x_m
65
-
66
- mean = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
67
- for k_offset in range(0, K, BLOCK_SIZE_K):
68
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
69
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
70
- k_mask = (k_offset + k_chunk_offs) < K
71
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
72
- mean += tl.sum(x_chunk, axis=1)
73
- mean /= K
74
-
75
- var = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
76
- for k_offset in range(0, K, BLOCK_SIZE_K):
77
- k_chunk_offs = tl.arange(0, BLOCK_SIZE_K)
78
- x_ptrs = x_rows_base_ptr + (k_offset + k_chunk_offs)[None, :]
79
- k_mask = (k_offset + k_chunk_offs) < K
80
- x_chunk = tl.load(x_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
81
- x_centered = x_chunk - mean[:, None]
82
- var += tl.sum(x_centered * x_centered, axis=1)
83
- var /= K
84
- rstd = 1.0 / tl.sqrt(var + LN_EPS)
85
-
86
- # --- Matmul Loop 1: For the 4-Way Projections ---
87
- offs_n_4way = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
88
- offs_k = tl.arange(0, BLOCK_SIZE_K)
89
- w_4way_ptrs_base = W_4way_ptr + (offs_n_4way[None, :] * stride_w4_n)
90
- accumulator_4way = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
91
- accumulator_og = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
92
-
93
- offs_n_og = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
94
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
95
- k_block_start = k * BLOCK_SIZE_K;
96
- x_ptrs = x_rows_base_ptr + (k_block_start + offs_k)[None, :] * stride_x_k
97
- w_ptrs = w_4way_ptrs_base + (k_block_start + offs_k)[:, None] * stride_w4_k
98
- x_mask = (offs_m[:, None] < M) & ((k_block_start + offs_k)[None, :] < K)
99
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_4way[None, :] < N_4way)
100
- x_tile = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32)
101
- norm_w_ptrs = Norm_Weight_ptr + k_block_start + offs_k
102
- norm_b_ptrs = Norm_Bias_ptr + k_block_start + offs_k
103
- nw = tl.load(norm_w_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
104
- nb = tl.load(norm_b_ptrs, mask=(k_block_start + offs_k) < K, other=0.0)
105
- x_norm_tile = (x_tile - mean[:, None]) * rstd[:, None]
106
- x_norm_tile = (x_norm_tile * nw[None, :] + nb[None, :]).to(tl.float16)
107
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
108
- accumulator_4way += tl.dot(x_norm_tile, w_tile)
109
-
110
- #Some threads should calclate out_gate
111
- if pid_n * BLOCK_SIZE_N < H:
112
- w_og_ptrs_base = W_og_ptr + (offs_n_og[None, :] * stride_wog_n)
113
- w_ptrs = w_og_ptrs_base + (k_block_start + offs_k)[:, None] * stride_wog_k
114
- w_mask = ((k_block_start + offs_k)[:, None] < K) & (offs_n_og[None, :] < H);
115
- w_tile = tl.load(w_ptrs, mask=w_mask, other=0.0)
116
- accumulator_og += tl.dot(x_norm_tile, w_tile)
117
-
118
- if pid_n * BLOCK_SIZE_N < H:
119
- og_out = tl.sigmoid(accumulator_og)
120
- outg_ptrs = OutOG_ptr + offs_m[:, None] * stride_og_m + offs_n_og[None, :] * stride_og_h
121
- og_mask = m_mask[:, None] & (offs_n_og[None, :] < H)
122
- tl.store(outg_ptrs, og_out, mask=og_mask)
123
-
124
- # --- Fusion Logic for 4-Way Part ---
125
- acc_reshaped = tl.reshape(accumulator_4way, (BLOCK_SIZE_M, H_CHUNK_SIZE, 4))
126
- role_idx = tl.arange(0, 4)[None, None, :]
127
- left_proj = tl.sum(tl.where(role_idx == 0, acc_reshaped, 0.0), axis=2)
128
- left_gate = tl.sum(tl.where(role_idx == 1, acc_reshaped, 0.0), axis=2)
129
- right_proj = tl.sum(tl.where(role_idx == 2, acc_reshaped, 0.0), axis=2)
130
- right_gate = tl.sum(tl.where(role_idx == 3, acc_reshaped, 0.0), axis=2)
131
-
132
- offs_h_chunk = (pid_n * H_CHUNK_SIZE) + tl.arange(0, H_CHUNK_SIZE)
133
- mask_ptrs = Mask_ptr + offs_m[:, None] * stride_mask_m + offs_h_chunk[None, :] * stride_mask_h
134
- m_mask_h = m_mask[:, None] & (offs_h_chunk[None, :] < H)
135
- mask_tile = tl.load(mask_ptrs, mask=m_mask_h, other=0.0)
136
-
137
- left_out = left_proj * tl.sigmoid(left_gate) * mask_tile
138
- right_out = right_proj * tl.sigmoid(right_gate) * mask_tile
139
-
140
- s1s2 = s1 * s2
141
- offs_b = offs_m // s1s2
142
- offs_s1 = (offs_m % s1s2) // s2
143
- offs_s2 = offs_m % s2
144
- offs_b_2d = tl.reshape(offs_b, (BLOCK_SIZE_M, 1))
145
- offs_h_2d = tl.reshape(offs_h_chunk, (1, H_CHUNK_SIZE))
146
- offs_s1_2d = tl.reshape(offs_s1, (BLOCK_SIZE_M, 1))
147
- offs_s2_2d = tl.reshape(offs_s2, (BLOCK_SIZE_M, 1))
148
-
149
- outl_ptrs = OutLeft_ptr + (offs_b_2d * stride_ol_bs + offs_h_2d * stride_ol_h +
150
- offs_s1_2d * stride_ol_s1 + offs_s2_2d * stride_ol_s2)
151
- outr_ptrs_t = OutRight_ptr + (offs_b_2d * stride_or_t_bs + offs_h_2d * stride_or_t_h +
152
- offs_s2_2d * stride_or_t_s2 + offs_s1_2d * stride_or_t_s1) # s2 offset uses s2 stride, s1 offset uses s1 stride
153
- tl.store(outl_ptrs, left_out, mask=m_mask_h)
154
- tl.store(outr_ptrs_t, right_out, mask=m_mask_h)
155
-
156
- @triton.autotune(
157
- configs=[
158
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
159
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
160
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
161
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=3),
162
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
163
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
164
- ],
165
- key=['s1', 's2', 'H'],
166
- )
167
- @triton.jit
168
- def bmm_coalesced_kernel(
169
- # Pointers
170
- Left_ptr, Right_ptr, Out_ptr,
171
- # Dimensions
172
- bs, s1, s2, H,
173
- # Strides
174
- stride_l_bs, stride_l_h, stride_l_s1, stride_l_s2,
175
- stride_r_bs, stride_r_h, stride_r_s2, stride_r_s1,
176
- stride_o_bs, stride_o_h, stride_o_s1, stride_o_s2,
177
- # Kernel parameters
178
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
179
- GROUP_SIZE_M: tl.constexpr,
180
- ):
181
- # Grid and program IDs
182
- pid = tl.program_id(axis=0)
183
- num_pid_m = tl.cdiv(s1, BLOCK_SIZE_M)
184
- num_pid_n = tl.cdiv(s1, BLOCK_SIZE_N)
185
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
186
- group_id = pid // num_pid_in_group
187
- first_pid_m = group_id * GROUP_SIZE_M
188
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
189
- pid_m = first_pid_m + (pid % group_size_m)
190
- pid_n = (pid % num_pid_in_group) // group_size_m
191
-
192
- pid_bh = tl.program_id(axis=1)
193
- pid_b = pid_bh // H
194
- pid_h = pid_bh % H
195
-
196
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
197
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
198
- offs_k = tl.arange(0, BLOCK_SIZE_K)
199
-
200
- left_ptrs_base = Left_ptr + pid_b * stride_l_bs + pid_h * stride_l_h
201
- right_ptrs_base = Right_ptr + pid_b * stride_r_bs + pid_h * stride_r_h
202
-
203
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
204
-
205
- for k in range(0, tl.cdiv(s2, BLOCK_SIZE_K)):
206
- k_start = k * BLOCK_SIZE_K
207
- a_ptrs = left_ptrs_base + (offs_m[:, None] * stride_l_s1 + (k_start + offs_k[None, :]) * stride_l_s2)
208
- b_ptrs = right_ptrs_base + ((k_start + offs_k[:, None]) * stride_r_s2 + offs_n[None, :] * stride_r_s1)
209
-
210
- a_mask = (offs_m[:, None] < s1) & ((k_start + offs_k[None, :]) < s2)
211
- b_mask = ((k_start + offs_k[:, None]) < s2) & (offs_n[None, :] < s1)
212
-
213
- a = tl.load(a_ptrs, mask=a_mask, other=0.0)
214
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
215
-
216
- accumulator += tl.dot(a, b)
217
-
218
- # --- Coalesced Write ---
219
- # Write to a standard (bs, H, s1, s1) layout
220
- out_ptrs = Out_ptr + pid_b * stride_o_bs + pid_h * stride_o_h + \
221
- offs_m[:, None] * stride_o_s1 + offs_n[None, :] * stride_o_s2
222
-
223
- c_mask = (offs_m[:, None] < s1) & (offs_n[None, :] < s1)
224
- tl.store(out_ptrs, accumulator, mask=c_mask)
225
-
226
- @torch.compile
227
- def torch_pt2(left_final, right_final_t, bs, s1, s2, d, h, to_out_norm_weight, to_out_norm_bias, og_mh, to_out_weight):
228
- bmm_out = torch.matmul(left_final, right_final_t)
229
- out_einsum_flat = bmm_out.permute(0, 2, 3, 1).reshape(bs * s1 * s1, h)
230
- # Apply layer norm and final gating
231
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
232
- gated = normed * og_mh
233
-
234
- # Final projection
235
- final_out_flat = gated @ to_out_weight.t()
236
- final_out = final_out_flat.view(bs, s1, s2, d)
237
- return final_out
238
-
239
- @triton.autotune(
240
- configs=[
241
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
242
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
243
- triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=3),
244
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
245
- triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_warps=8, num_stages=4),
246
- triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_warps=4, num_stages=4),
247
- ],
248
- key=['H', 'D'],
249
- )
250
- @triton.jit
251
- def fused_final_kernel(
252
- # Pointers
253
- In_ptr, Gate_ptr, NormW_ptr, NormB_ptr, ProjW_ptr, Out_ptr,
254
- # Metadata
255
- M, H, D, s1, # M_gate = bs*s1*s2
256
- # Strides
257
- stride_in_bs, stride_in_h, stride_in_s1_row, stride_in_s1_col,
258
- stride_gate_m, stride_gate_h,
259
- stride_proj_d, stride_proj_h,
260
- stride_out_bs, stride_out_s1_row, stride_out_s1_col, stride_out_d,
261
- # Constants
262
- LN_EPS: tl.constexpr,
263
- BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
264
- GROUP_SIZE_M: tl.constexpr,
265
- ):
266
- # --- Grid and PID Setup for Matmul ---
267
- pid = tl.program_id(axis=0)
268
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
269
- num_pid_n = tl.cdiv(D, BLOCK_SIZE_N)
270
-
271
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
272
- group_id = pid // num_pid_in_group
273
- first_pid_m = group_id * GROUP_SIZE_M
274
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
275
- pid_m = first_pid_m + (pid % group_size_m)
276
- pid_n = (pid % num_pid_in_group) // group_size_m
277
-
278
- offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
279
- offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
280
- m_mask = offs_m < M
281
-
282
- # Decompose M back to (b, r, c) for reordering lookups
283
- s1s1 = s1 * s1
284
- b = offs_m // s1s1
285
- r = (offs_m % s1s1) // s1
286
- c = offs_m % s1
287
-
288
- sum_x = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
289
- sum_x2 = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
290
- in_ptr_base = In_ptr + b * stride_in_bs + r * stride_in_s1_row + c * stride_in_s1_col
291
-
292
- for k_offset in range(0, H, BLOCK_SIZE_K):
293
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
294
- k_mask = offs_k < H
295
-
296
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
297
- in_chunk = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0).to(tl.float32)
298
-
299
- # Accumulate sum and sum of squares in one pass
300
- sum_x += tl.sum(in_chunk, axis=1)
301
- sum_x2 += tl.sum(in_chunk * in_chunk, axis=1)
302
-
303
- # Finalize statistics
304
- mean = sum_x / H
305
- var = (sum_x2 / H) - (mean * mean)
306
- rstd = tl.math.rsqrt(var + LN_EPS)
307
-
308
- # --- Pass 3: Fused Gating and Matmul ---
309
- acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
310
- for k_offset in range(0, H, BLOCK_SIZE_K):
311
- offs_k = k_offset + tl.arange(0, BLOCK_SIZE_K)
312
- k_mask = offs_k < H
313
-
314
- in_ptrs = in_ptr_base[:, None] + offs_k[None, :] * stride_in_h
315
- a = tl.load(in_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
316
- a_norm = (a - mean[:, None]) * rstd[:, None]
317
-
318
- norm_w = tl.load(NormW_ptr + offs_k, mask=k_mask, other=0.0)
319
- norm_b = tl.load(NormB_ptr + offs_k, mask=k_mask, other=0.0)
320
- a_norm = a_norm * norm_w[None, :] + norm_b[None, :]
321
-
322
- proj_ptrs = ProjW_ptr + \
323
- offs_n[None, :] * stride_proj_d + \
324
- offs_k[:, None] * stride_proj_h
325
-
326
- gate_ptrs = Gate_ptr + offs_m[:, None] * stride_gate_m + offs_k[None, :] * stride_gate_h
327
- gate = tl.load(gate_ptrs, mask=m_mask[:, None] & k_mask[None, :], other=0.0)
328
- a_gated = a_norm * gate
329
-
330
- b_w = tl.load(proj_ptrs, mask=k_mask[:, None] & (offs_n[None, :] < D), other=0.0)
331
- acc += tl.dot(a_gated.to(b_w.dtype), b_w)
332
-
333
- # --- Store Final Output ---
334
- offs_d = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
335
- out_ptr_base = Out_ptr + b*stride_out_bs + r*stride_out_s1_row + c*stride_out_s1_col
336
- out_ptrs = out_ptr_base[:, None] + offs_d[None, :] * stride_out_d
337
-
338
- tl.store(out_ptrs, acc, mask=m_mask[:, None] & (offs_d[None, :] < D))
339
-
340
- def compiledtrimul_fused_interleaved(
341
- x: torch.Tensor,
342
- mask_mh: torch.Tensor,
343
- norm_weight: torch.Tensor,
344
- norm_bias: torch.Tensor,
345
- W_4way: torch.Tensor, # Use the new weight matrices
346
- W_og: torch.Tensor,
347
- to_out_norm_weight: torch.Tensor,
348
- to_out_norm_bias: torch.Tensor,
349
- to_out_weight: torch.Tensor,
350
- h: int,
351
- ):
352
- bs, s1, s2, d = x.shape
353
- M, K, H = bs * s1 * s2, x.shape[-1], h
354
- x_flat = x.view(M, K)
355
-
356
- left_final = torch.empty((bs, H, s1, s2), device=x.device, dtype=torch.float16)
357
- right_final_t = torch.empty((bs, H, s2, s1), device=x.device, dtype=torch.float16)
358
- og_mh = torch.empty((M, H), device=x.device, dtype=torch.float16)
359
-
360
- # The grid is launched for the larger 4*H problem
361
- N_4way = 4 * H
362
- grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N_4way, meta['BLOCK_SIZE_N']),)
363
- fused_ln_dual_matmul_kernel[grid](
364
- # Pointers (9)
365
- x_flat, W_4way, W_og, mask_mh, norm_weight, norm_bias,
366
- left_final, right_final_t, og_mh,
367
- # Metadata (5) - M, H, K, s1, s2
368
- M, H, K, s1, s2,
369
- # Strides (16)
370
- x_flat.stride(0), x_flat.stride(1),
371
- W_4way.stride(0), W_4way.stride(1),
372
- W_og.stride(0), W_og.stride(1),
373
- left_final.stride(0), left_final.stride(1), left_final.stride(2), left_final.stride(3),
374
- right_final_t.stride(0), right_final_t.stride(1), right_final_t.stride(2), right_final_t.stride(3),
375
- og_mh.stride(0), og_mh.stride(1),
376
- mask_mh.stride(0), mask_mh.stride(1),
377
- # Constexpr (1)
378
- LN_EPS=1e-5
379
- )
380
- return torch_pt2(
381
- left_final, right_final_t,
382
- bs=bs,
383
- s1=s1,
384
- s2=s2,
385
- d=d,
386
- h=h,
387
- to_out_norm_weight=to_out_norm_weight,
388
- to_out_norm_bias=to_out_norm_bias,
389
- og_mh=og_mh,
390
- to_out_weight=to_out_weight
391
- )
392
-
393
- def pack_w_4way_efficient(weights):
394
- """ Packs L, LG, R, RG into a tight [K, 4*H] matrix. """
395
- WL = weights['left_proj.weight']
396
- WLG = weights['left_gate.weight']
397
- WR = weights['right_proj.weight']
398
- WRG = weights['right_gate.weight']
399
- H, K = WL.shape
400
- ws = torch.stack([WL, WLG, WR, WRG], dim=0).permute(1, 0, 2)
401
- ws = ws.contiguous().view(4 * H, K)
402
- return ws.t().to(torch.float16)
403
-
404
- def get_w_og(weights):
405
- """ Gets the transposed [K, H] out_gate weight matrix. """
406
- WOG = weights['out_gate.weight']
407
- return WOG.t().to(torch.float16)
408
-
409
-
410
- torch.backends.cuda.matmul.allow_tf32 = True
411
- torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
412
-
413
- @torch.compile
414
- def compiledtrimul(
415
- x: torch.Tensor,
416
- mask: torch.Tensor,
417
- norm_weight: torch.Tensor,
418
- norm_bias: torch.Tensor,
419
- w_concat: torch.Tensor,
420
- to_out_norm_weight: torch.Tensor,
421
- to_out_norm_bias: torch.Tensor,
422
- to_out_weight: torch.Tensor,
423
- h: int
424
- ) -> torch.Tensor:
425
- """
426
- A barebones, compiled PyTorch function for the TriMul logic.
427
- """
428
- bs, s1, s2, d = x.shape
429
-
430
- # Initial LayerNorm
431
- x_norm = F.layer_norm(x, (d,), norm_weight, norm_bias).view((bs * s1 * s2, d)).to(torch.float16)
432
- # Single large matmul: [M, d] @ [d, 5h] = [M, 5h]
433
- all_projections = torch.mm(x_norm, w_concat)
434
-
435
- # Split back into individual projections
436
- left, right, lg, rg, og = all_projections.chunk(5, dim=1)
437
-
438
- # Apply mask and gates
439
- mask_expanded = mask.expand(-1, -1, -1, h).reshape(-1, h)
440
- left = left * mask_expanded * torch.sigmoid(lg)
441
- right = right * mask_expanded * torch.sigmoid(rg)
442
- out_gate = torch.sigmoid(og)
443
-
444
- # Reshape for einsum
445
- left = left.view(bs, s1, s2, h).permute(0,3,1,2)
446
- right = right.view(bs, s1, s2, h).permute(0,3,1,2)
447
- out_p = torch.matmul(left.to(torch.float16), right.to(torch.float16).transpose(-1, -2))
448
- out_einsum_flat = out_p.permute(0,2,3,1).reshape(bs * s1 * s1, h)
449
-
450
- # Apply layer norm and final gating
451
- normed = F.layer_norm(out_einsum_flat, (h,), to_out_norm_weight, to_out_norm_bias).to(torch.float16)
452
- gated = normed * out_gate
453
-
454
- # Final projection
455
- final_out_flat = gated @ to_out_weight.t()
456
- final_out = final_out_flat.view(bs, s1, s2, d)
457
-
458
- return final_out
459
-
460
- def small_kernel_pt_path(data):
461
- input_tensor, mask, weights, config = data
462
- w_concat = torch.cat([
463
- weights['left_proj.weight'],
464
- weights['right_proj.weight'],
465
- weights['left_gate.weight'],
466
- weights['right_gate.weight'],
467
- weights['out_gate.weight']
468
- ], dim=0).t().contiguous().to(torch.float16)
469
- # Call the compiled function with prepared weights
470
- output = compiledtrimul(
471
- x=input_tensor.to(torch.float32),
472
- mask=mask.unsqueeze(-1),
473
- norm_weight=weights['norm.weight'].to(torch.float32),
474
- norm_bias=weights['norm.bias'].to(torch.float32),
475
- w_concat=w_concat,
476
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float32),
477
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float32),
478
- to_out_weight=weights['to_out.weight'].to(torch.float16),
479
- h=config["hidden_dim"]
480
- )
481
- return output
482
-
483
- def kernel_h100(data):
484
- input_tensor, mask, weights, config = data
485
- bs, s1, s2, d = input_tensor.shape
486
-
487
- if s1 <= 512:
488
- return small_kernel_pt_path(data)
489
-
490
- H = config["hidden_dim"]
491
-
492
- W_4way = pack_w_4way_efficient(weights)
493
- W_og = get_w_og(weights)
494
-
495
- M = bs * s1 * s2
496
- mask_mh = mask.unsqueeze(-1).expand(-1, -1, -1, H).reshape(M, H).to(torch.float16) #move into kernel possibly
497
-
498
- return compiledtrimul_fused_interleaved(
499
- x=input_tensor.to(torch.float32),
500
- mask_mh=mask_mh,
501
- norm_weight=weights['norm.weight'].to(torch.float32),
502
- norm_bias=weights['norm.bias'].to(torch.float32),
503
- W_4way=W_4way, # Pass the new 4-way matrix
504
- W_og=W_og, # Pass the new out_gate matrix
505
- to_out_norm_weight=weights['to_out_norm.weight'].to(torch.float16),
506
- to_out_norm_bias=weights['to_out_norm.bias'].to(torch.float16),
507
- to_out_weight=weights['to_out.weight'].to(torch.float16),
508
- h=H,
509
- )