danieldk HF Staff commited on
Commit
a36a00c
·
verified ·
1 Parent(s): fd7de3f

Uploaded using `kernel-builder`.

Browse files
build/torch-cuda/_ops.py CHANGED
@@ -1,8 +1,8 @@
1
  import torch
2
- ops = torch.ops._sonic_moe_57a1b31
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
- return f"_sonic_moe_57a1b31::{op_name}"
 
1
  import torch
2
+ ops = torch.ops._sonic_moe_75daa46
3
 
4
  def add_op_namespace_prefix(op_name: str):
5
  """
6
  Prefix op by namespace.
7
  """
8
+ return f"_sonic_moe_75daa46::{op_name}"
build/torch-cuda/functional/__init__.py CHANGED
@@ -70,6 +70,7 @@ class _UpProjection(torch.autograd.Function):
70
  is_varlen_K: bool,
71
  activation_type: ActivationType,
72
  is_inference_mode_enabled: bool,
 
73
  ) -> torch.Tensor:
74
  T, H = x.shape
75
  I, H, E = w1.shape
@@ -105,6 +106,7 @@ class _UpProjection(torch.autograd.Function):
105
  activation_type=activation_type.value,
106
  is_glu_activation=is_glu_activation,
107
  is_inference_mode_enabled=is_inference_mode_enabled,
 
108
  )
109
 
110
  ctx.T = T
@@ -115,6 +117,7 @@ class _UpProjection(torch.autograd.Function):
115
  ctx.I = I
116
  ctx.is_varlen_K = is_varlen_K
117
  ctx.is_glu_activation = is_glu_activation
 
118
  ctx.stream_id = stream_id
119
 
120
  ctx.save_for_backward(
@@ -146,6 +149,7 @@ class _UpProjection(torch.autograd.Function):
146
  K = ctx.K
147
  H = ctx.H
148
  is_glu_activation = ctx.is_glu_activation
 
149
  is_varlen_K = ctx.is_varlen_K
150
  stream_id = ctx.stream_id
151
 
@@ -190,6 +194,7 @@ class _UpProjection(torch.autograd.Function):
190
  s_scatter_idx=s_scatter_idx,
191
  is_glu_activation=is_glu_activation,
192
  stream_id=stream_id,
 
193
  )
194
 
195
  _up_projection_backward_weight(
@@ -201,6 +206,7 @@ class _UpProjection(torch.autograd.Function):
201
  x_gather_idx=x_gather_idx,
202
  is_glu_activation=is_glu_activation,
203
  stream_id=stream_id,
 
204
  )
205
 
206
  dx_reduced = torch.empty(T, H, dtype=dz.dtype, device=dz.device)
@@ -215,7 +221,7 @@ class _UpProjection(torch.autograd.Function):
215
  is_varlen_K=is_varlen_K,
216
  )
217
 
218
- return dx_reduced, dw1, db1, *[None] * 12
219
 
220
 
221
  class _DownProjection(torch.autograd.Function):
@@ -486,6 +492,7 @@ def moe_general_routing_inputs(
486
  stream_id: int,
487
  activation_type: ActivationType,
488
  is_inference_mode_enabled: bool = False,
 
489
  ) -> tuple[torch.Tensor, torch.Tensor]:
490
  assert ((b1 is None) and (b2 is None)) or (
491
  (b1 is not None) and (b2 is not None)
@@ -531,6 +538,7 @@ def moe_general_routing_inputs(
531
  True, # is_varlen_K
532
  activation_type,
533
  is_inference_mode_enabled,
 
534
  )
535
 
536
  o = _DownProjection.apply(
 
70
  is_varlen_K: bool,
71
  activation_type: ActivationType,
72
  is_inference_mode_enabled: bool,
73
+ is_concatenated_gate_up: bool = False,
74
  ) -> torch.Tensor:
75
  T, H = x.shape
76
  I, H, E = w1.shape
 
106
  activation_type=activation_type.value,
107
  is_glu_activation=is_glu_activation,
108
  is_inference_mode_enabled=is_inference_mode_enabled,
109
+ is_concatenated_gate_up=is_concatenated_gate_up,
110
  )
111
 
112
  ctx.T = T
 
117
  ctx.I = I
118
  ctx.is_varlen_K = is_varlen_K
119
  ctx.is_glu_activation = is_glu_activation
120
+ ctx.is_concatenated_gate_up = is_concatenated_gate_up
121
  ctx.stream_id = stream_id
122
 
123
  ctx.save_for_backward(
 
149
  K = ctx.K
150
  H = ctx.H
151
  is_glu_activation = ctx.is_glu_activation
152
+ is_concatenated_gate_up = ctx.is_concatenated_gate_up
153
  is_varlen_K = ctx.is_varlen_K
154
  stream_id = ctx.stream_id
155
 
 
194
  s_scatter_idx=s_scatter_idx,
195
  is_glu_activation=is_glu_activation,
196
  stream_id=stream_id,
197
+ is_concatenated_gate_up=is_concatenated_gate_up,
198
  )
199
 
200
  _up_projection_backward_weight(
 
206
  x_gather_idx=x_gather_idx,
207
  is_glu_activation=is_glu_activation,
208
  stream_id=stream_id,
209
+ is_concatenated_gate_up=is_concatenated_gate_up,
210
  )
211
 
212
  dx_reduced = torch.empty(T, H, dtype=dz.dtype, device=dz.device)
 
221
  is_varlen_K=is_varlen_K,
222
  )
223
 
224
+ return dx_reduced, dw1, db1, *[None] * 13
225
 
226
 
227
  class _DownProjection(torch.autograd.Function):
 
492
  stream_id: int,
493
  activation_type: ActivationType,
494
  is_inference_mode_enabled: bool = False,
495
+ is_concatenated_gate_up: bool = False,
496
  ) -> tuple[torch.Tensor, torch.Tensor]:
497
  assert ((b1 is None) and (b2 is None)) or (
498
  (b1 is not None) and (b2 is not None)
 
538
  True, # is_varlen_K
539
  activation_type,
540
  is_inference_mode_enabled,
541
+ is_concatenated_gate_up,
542
  )
543
 
544
  o = _DownProjection.apply(
build/torch-cuda/functional/backward.py CHANGED
@@ -206,6 +206,7 @@ def _up_projection_backward_act(
206
  s_scatter_idx: torch.Tensor,
207
  is_glu_activation: bool,
208
  stream_id: int,
 
209
  ) -> None:
210
  I, H, E = w1.size()
211
  if is_glu_activation:
@@ -228,9 +229,9 @@ def _up_projection_backward_act(
228
  mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
229
  current_stream = cuda.CUstream(stream_id)
230
 
231
- compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype)
232
  if compile_dx_key not in _up_projection_backward_act.compile_cache:
233
- dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation)
234
  tensormaps = [dx_module.module.generate_tensormap(None, None, None) for _ in range(2)]
235
  _up_projection_backward_act.compile_cache[compile_dx_key] = cute.compile(
236
  dx_module,
@@ -244,9 +245,9 @@ def _up_projection_backward_act(
244
  mE_permute_order,
245
  current_stream,
246
  )
247
- _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"] = tensormaps
248
 
249
- dx_tensormaps = _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"]
250
  _up_projection_backward_act.compile_cache[compile_dx_key](
251
  mDz,
252
  mW1_trans,
@@ -273,6 +274,7 @@ def _up_projection_backward_weight(
273
  x_gather_idx: torch.Tensor,
274
  is_glu_activation: bool,
275
  stream_id: int,
 
276
  ) -> None:
277
  I, H, E = dw1.size()
278
  if is_glu_activation:
@@ -293,9 +295,9 @@ def _up_projection_backward_weight(
293
  mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
294
  current_stream = cuda.CUstream(stream_id)
295
 
296
- compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype)
297
  if compile_dw1_key not in _up_projection_backward_weight.compile_cache:
298
- dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation)
299
  tensormaps = [dw1_module.module.generate_tensormap(None, None, None) for _ in range(1)]
300
  _up_projection_backward_weight.compile_cache[compile_dw1_key] = cute.compile(
301
  dw1_module,
@@ -308,9 +310,9 @@ def _up_projection_backward_weight(
308
  mE_permute_order,
309
  current_stream,
310
  )
311
- _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"] = tensormaps
312
 
313
- dw1_tensormaps = _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"]
314
  _up_projection_backward_weight.compile_cache[compile_dw1_key](
315
  mX_trans,
316
  mDz_trans,
@@ -406,14 +408,14 @@ def _down_projection_backward_act(
406
  mE_permute_order,
407
  current_stream,
408
  )
409
- _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"] = tensormaps
410
 
411
  if ds_partial is None:
412
  ds_partial_N = _down_projection_backward_act.compile_cache["ds_partial_N"]
413
  ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
414
  mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)
415
 
416
- dz_tensormaps = _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"]
417
  _down_projection_backward_act.compile_cache[compile_dz_key](
418
  mDout,
419
  mW2_trans,
@@ -520,9 +522,9 @@ def _down_projection_backward_weight(
520
  mE_permute_order,
521
  current_stream,
522
  )
523
- _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"] = tensormaps
524
 
525
- dw2_tensormaps = _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"]
526
  _down_projection_backward_weight.compile_cache[compile_dw2_key](
527
  mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, dw2_tensormaps, mE_permute_order, current_stream
528
  )
 
206
  s_scatter_idx: torch.Tensor,
207
  is_glu_activation: bool,
208
  stream_id: int,
209
+ is_concatenated_gate_up: bool = False,
210
  ) -> None:
211
  I, H, E = w1.size()
212
  if is_glu_activation:
 
229
  mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
230
  current_stream = cuda.CUstream(stream_id)
231
 
232
+ compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype, is_concatenated_gate_up)
233
  if compile_dx_key not in _up_projection_backward_act.compile_cache:
234
+ dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation, is_concatenated_gate_up=is_concatenated_gate_up)
235
  tensormaps = [dx_module.module.generate_tensormap(None, None, None) for _ in range(2)]
236
  _up_projection_backward_act.compile_cache[compile_dx_key] = cute.compile(
237
  dx_module,
 
245
  mE_permute_order,
246
  current_stream,
247
  )
248
+ _up_projection_backward_act.compile_cache[(TENSORMAP, compile_dx_key)] = tensormaps
249
 
250
+ dx_tensormaps = _up_projection_backward_act.compile_cache[(TENSORMAP, compile_dx_key)]
251
  _up_projection_backward_act.compile_cache[compile_dx_key](
252
  mDz,
253
  mW1_trans,
 
274
  x_gather_idx: torch.Tensor,
275
  is_glu_activation: bool,
276
  stream_id: int,
277
+ is_concatenated_gate_up: bool = False,
278
  ) -> None:
279
  I, H, E = dw1.size()
280
  if is_glu_activation:
 
295
  mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id)
296
  current_stream = cuda.CUstream(stream_id)
297
 
298
+ compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype, is_concatenated_gate_up)
299
  if compile_dw1_key not in _up_projection_backward_weight.compile_cache:
300
+ dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation, is_concatenated_gate_up=is_concatenated_gate_up)
301
  tensormaps = [dw1_module.module.generate_tensormap(None, None, None) for _ in range(1)]
302
  _up_projection_backward_weight.compile_cache[compile_dw1_key] = cute.compile(
303
  dw1_module,
 
310
  mE_permute_order,
311
  current_stream,
312
  )
313
+ _up_projection_backward_weight.compile_cache[(TENSORMAP, compile_dw1_key)] = tensormaps
314
 
315
+ dw1_tensormaps = _up_projection_backward_weight.compile_cache[(TENSORMAP, compile_dw1_key)]
316
  _up_projection_backward_weight.compile_cache[compile_dw1_key](
317
  mX_trans,
318
  mDz_trans,
 
408
  mE_permute_order,
409
  current_stream,
410
  )
411
+ _down_projection_backward_act.compile_cache[(TENSORMAP, compile_dz_key)] = tensormaps
412
 
413
  if ds_partial is None:
414
  ds_partial_N = _down_projection_backward_act.compile_cache["ds_partial_N"]
415
  ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device)
416
  mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id)
417
 
418
+ dz_tensormaps = _down_projection_backward_act.compile_cache[(TENSORMAP, compile_dz_key)]
419
  _down_projection_backward_act.compile_cache[compile_dz_key](
420
  mDout,
421
  mW2_trans,
 
522
  mE_permute_order,
523
  current_stream,
524
  )
525
+ _down_projection_backward_weight.compile_cache[(TENSORMAP, compile_dw2_key)] = tensormaps
526
 
527
+ dw2_tensormaps = _down_projection_backward_weight.compile_cache[(TENSORMAP, compile_dw2_key)]
528
  _down_projection_backward_weight.compile_cache[compile_dw2_key](
529
  mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, dw2_tensormaps, mE_permute_order, current_stream
530
  )
build/torch-cuda/functional/forward.py CHANGED
@@ -65,6 +65,7 @@ def _up_projection_forward(
65
  activation_type: str,
66
  is_glu_activation: bool,
67
  is_inference_mode_enabled: bool = False,
 
68
  ) -> None:
69
  I, H, E = w1.size()
70
  if is_glu_activation:
@@ -89,10 +90,10 @@ def _up_projection_forward(
89
 
90
  current_stream = cuda.CUstream(stream_id)
91
 
92
- compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled)
93
  if compile_w1_key not in _up_projection_forward.compile_cache:
94
  w1_module = HopperWgmma_MoE_Up_proj_Fwd(
95
- E, H, I, activation_type=ActivationType(activation_type), inference_mode=is_inference_mode_enabled
96
  )
97
  tensormaps = [w1_module.module.generate_tensormap(None, None, None) for _ in range(2)]
98
  _up_projection_forward.compile_cache[compile_w1_key] = cute.compile(
@@ -109,9 +110,9 @@ def _up_projection_forward(
109
  mE_permute_order,
110
  current_stream,
111
  )
112
- _up_projection_forward.compile_cache[TENSORMAP] = tensormaps
113
 
114
- w1_tensormaps = _up_projection_forward.compile_cache[TENSORMAP]
115
  _up_projection_forward.compile_cache[compile_w1_key](
116
  mX,
117
  mW1,
@@ -168,9 +169,9 @@ def _down_projection_forward(
168
  _down_projection_forward.compile_cache[compile_w2_key] = cute.compile(
169
  w2_module, mY1, mW2, mY2, mB2, mE_offset, mX_gather, tensormaps[0], mE_permute_order, current_stream
170
  )
171
- _down_projection_forward.compile_cache[TENSORMAP] = tensormaps
172
 
173
- w2_tensormaps = _down_projection_forward.compile_cache[TENSORMAP]
174
  _down_projection_forward.compile_cache[compile_w2_key](
175
  mY1, mW2, mY2, mB2, mE_offset, mX_gather, w2_tensormaps[0], mE_permute_order, current_stream
176
  )
 
65
  activation_type: str,
66
  is_glu_activation: bool,
67
  is_inference_mode_enabled: bool = False,
68
+ is_concatenated_gate_up: bool = False,
69
  ) -> None:
70
  I, H, E = w1.size()
71
  if is_glu_activation:
 
90
 
91
  current_stream = cuda.CUstream(stream_id)
92
 
93
+ compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled, is_concatenated_gate_up)
94
  if compile_w1_key not in _up_projection_forward.compile_cache:
95
  w1_module = HopperWgmma_MoE_Up_proj_Fwd(
96
+ E, H, I, activation_type=ActivationType(activation_type), inference_mode=is_inference_mode_enabled, is_concatenated_gate_up=is_concatenated_gate_up,
97
  )
98
  tensormaps = [w1_module.module.generate_tensormap(None, None, None) for _ in range(2)]
99
  _up_projection_forward.compile_cache[compile_w1_key] = cute.compile(
 
110
  mE_permute_order,
111
  current_stream,
112
  )
113
+ _up_projection_forward.compile_cache[(TENSORMAP, compile_w1_key)] = tensormaps
114
 
115
+ w1_tensormaps = _up_projection_forward.compile_cache[(TENSORMAP, compile_w1_key)]
116
  _up_projection_forward.compile_cache[compile_w1_key](
117
  mX,
118
  mW1,
 
169
  _down_projection_forward.compile_cache[compile_w2_key] = cute.compile(
170
  w2_module, mY1, mW2, mY2, mB2, mE_offset, mX_gather, tensormaps[0], mE_permute_order, current_stream
171
  )
172
+ _down_projection_forward.compile_cache[(TENSORMAP, compile_w2_key)] = tensormaps
173
 
174
+ w2_tensormaps = _down_projection_forward.compile_cache[(TENSORMAP, compile_w2_key)]
175
  _down_projection_forward.compile_cache[compile_w2_key](
176
  mY1, mW2, mY2, mB2, mE_offset, mX_gather, w2_tensormaps[0], mE_permute_order, current_stream
177
  )
build/torch-cuda/functional/moe_config.py CHANGED
@@ -37,9 +37,10 @@ class HopperGEMMConfig:
37
 
38
 
39
  class HopperWgmma_MoE_Up_proj_Fwd:
40
- def __init__(self, E: int, H: int, I: int, activation_type: ActivationType, inference_mode=False):
41
  super().__init__()
42
  is_glu_activation = is_glu(activation_type)
 
43
  if is_glu_activation:
44
  assert (
45
  H % 64 == 0 and H >= 512 and I % 64 == 0
@@ -127,6 +128,18 @@ class HopperWgmma_MoE_Up_proj_Fwd:
127
  def __call__(
128
  self, mX, mW1, mZ, mY1, mB1, mE_offset, mX_gather, mD_tensormap, mY1_tensormap, mE_permute_order, stream
129
  ):
 
 
 
 
 
 
 
 
 
 
 
 
130
  return self.module(
131
  mX,
132
  mW1,
@@ -424,7 +437,8 @@ class HopperWgmma_MoE_Down_proj_WeightGrad_Bwd:
424
 
425
 
426
  class HopperWgmma_MoE_Up_proj_ActGrad_Bwd:
427
- def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
 
428
  super().__init__()
429
  if is_glu_activation:
430
  assert (
@@ -478,6 +492,17 @@ class HopperWgmma_MoE_Up_proj_ActGrad_Bwd:
478
  def __call__(
479
  self, mDz, mW1_trans, mDx_expanded, mE_offset, mX_gather, mS_scatter, tensormaps, mE_permute_order, stream
480
  ):
 
 
 
 
 
 
 
 
 
 
 
481
  return self.module(
482
  mDz,
483
  mW1_trans,
@@ -504,7 +529,8 @@ class HopperWgmma_MoE_Up_proj_ActGrad_Bwd:
504
 
505
 
506
  class HopperWgmma_MoE_Up_proj_WeightGrad_Bwd:
507
- def __init__(self, E: int, H: int, I: int, is_glu_activation: bool):
 
508
  super().__init__()
509
  if is_glu_activation:
510
  assert (
@@ -556,6 +582,18 @@ class HopperWgmma_MoE_Up_proj_WeightGrad_Bwd:
556
 
557
  @cute.jit
558
  def __call__(self, mX_trans, mDz_trans, mDw1_trans, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
 
 
 
 
 
 
 
 
 
 
 
 
559
  return self.module(
560
  mX_trans,
561
  mDz_trans,
 
37
 
38
 
39
  class HopperWgmma_MoE_Up_proj_Fwd:
40
+ def __init__(self, E: int, H: int, I: int, activation_type: ActivationType, inference_mode=False, is_concatenated_gate_up: bool = False):
41
  super().__init__()
42
  is_glu_activation = is_glu(activation_type)
43
+ self.is_concatenated_gate_up = is_concatenated_gate_up
44
  if is_glu_activation:
45
  assert (
46
  H % 64 == 0 and H >= 512 and I % 64 == 0
 
128
  def __call__(
129
  self, mX, mW1, mZ, mY1, mB1, mE_offset, mX_gather, mD_tensormap, mY1_tensormap, mE_permute_order, stream
130
  ):
131
+ if const_expr(self.is_concatenated_gate_up):
132
+ # mW1 is (2*I, H, E) concatenated [gate; up]. Reshape N dim to ((2, I))
133
+ # so TMA reads interleaved pairs from the two halves.
134
+ half_N = mW1.shape[0] // 2
135
+ mW1 = cute.make_tensor(
136
+ mW1.iterator,
137
+ cute.make_layout(
138
+ ((2, half_N), mW1.shape[1], mW1.shape[2]),
139
+ stride=((half_N * mW1.stride[0], mW1.stride[0]), mW1.stride[1], mW1.stride[2]),
140
+ ),
141
+ )
142
+
143
  return self.module(
144
  mX,
145
  mW1,
 
437
 
438
 
439
  class HopperWgmma_MoE_Up_proj_ActGrad_Bwd:
440
+ def __init__(self, E: int, H: int, I: int, is_glu_activation: bool, is_concatenated_gate_up: bool = False):
441
+ self.is_concatenated_gate_up = is_concatenated_gate_up
442
  super().__init__()
443
  if is_glu_activation:
444
  assert (
 
492
  def __call__(
493
  self, mDz, mW1_trans, mDx_expanded, mE_offset, mX_gather, mS_scatter, tensormaps, mE_permute_order, stream
494
  ):
495
+ if const_expr(self.is_concatenated_gate_up):
496
+ # mW1_trans is (H, 2*I, E) with concatenated N dim (dim 1).
497
+ # Reshape dim 1 to ((2, I)) so TMA reads interleaved from concatenated memory.
498
+ half_N = mW1_trans.shape[1] // 2
499
+ mW1_trans = cute.make_tensor(
500
+ mW1_trans.iterator,
501
+ cute.make_layout(
502
+ (mW1_trans.shape[0], (2, half_N), mW1_trans.shape[2]),
503
+ stride=(mW1_trans.stride[0], (half_N * mW1_trans.stride[1], mW1_trans.stride[1]), mW1_trans.stride[2]),
504
+ ),
505
+ )
506
  return self.module(
507
  mDz,
508
  mW1_trans,
 
529
 
530
 
531
  class HopperWgmma_MoE_Up_proj_WeightGrad_Bwd:
532
+ def __init__(self, E: int, H: int, I: int, is_glu_activation: bool, is_concatenated_gate_up: bool = False):
533
+ self.is_concatenated_gate_up = is_concatenated_gate_up
534
  super().__init__()
535
  if is_glu_activation:
536
  assert (
 
582
 
583
  @cute.jit
584
  def __call__(self, mX_trans, mDz_trans, mDw1_trans, mE_offset, mX_gather, tensormaps, mE_permute_order, stream):
585
+ if const_expr(self.is_concatenated_gate_up):
586
+ # mDw1_trans is (H, 2*I, E) — output in concatenated layout.
587
+ # Reshape dim 1 to ((2, I)) so GEMM writes interleaved results
588
+ # to the correct concatenated memory positions.
589
+ half_N = mDw1_trans.shape[1] // 2
590
+ mDw1_trans = cute.make_tensor(
591
+ mDw1_trans.iterator,
592
+ cute.make_layout(
593
+ (mDw1_trans.shape[0], (2, half_N), mDw1_trans.shape[2]),
594
+ stride=(mDw1_trans.stride[0], (half_N * mDw1_trans.stride[1], mDw1_trans.stride[1]), mDw1_trans.stride[2]),
595
+ ),
596
+ )
597
  return self.module(
598
  mX_trans,
599
  mDz_trans,