nologik commited on
Commit
ef12347
·
verified ·
1 Parent(s): 4770883

gdn: add layers.py + expose layers ns (Qwen3.6 Qwen3_5{,Moe}GatedDeltaNet + Qwen3-Next GatedDeltaNet) for kernelize()

Browse files
build/torch210-cxx11-cu130-aarch64-linux/layers.py CHANGED
@@ -24,7 +24,10 @@ input-projection layout, so the shared core lives in the module-level
24
 
25
  ``kernels`` forbids extra class members and a custom ``__init__`` on a layer
26
  (``_validate_layer``), which is why all helpers are module-level functions, not
27
- methods.
 
 
 
28
 
29
  On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
30
  build, so ``transformers`` silently falls back to a slow pure-torch
@@ -174,7 +177,7 @@ class GatedDeltaNet(nn.Module):
174
  has_backward: bool = False
175
  can_torch_compile: bool = False
176
 
177
- def forward(self, hidden_states, cache_params=None, attention_mask=None):
178
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
179
 
180
  projected_states_qkvz = self.in_proj_qkvz(hidden_states)
@@ -200,7 +203,7 @@ class Qwen3_5GatedDeltaNet(nn.Module):
200
  has_backward: bool = False
201
  can_torch_compile: bool = False
202
 
203
- def forward(self, hidden_states, cache_params=None, attention_mask=None):
204
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
205
  batch_size, seq_len, _ = hidden_states.shape
206
 
 
24
 
25
  ``kernels`` forbids extra class members and a custom ``__init__`` on a layer
26
  (``_validate_layer``), which is why all helpers are module-level functions, not
27
+ methods. ``_validate_layer`` also requires the layer ``forward`` signature to
28
+ match the host's argument count exactly, so ``forward`` takes the same
29
+ ``**kwargs`` (``Unpack[TransformersKwargs]``) the host GDN layers carry in
30
+ transformers >= 5.10; the kernel path ignores those kwargs.
31
 
32
  On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
33
  build, so ``transformers`` silently falls back to a slow pure-torch
 
177
  has_backward: bool = False
178
  can_torch_compile: bool = False
179
 
180
+ def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
181
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
182
 
183
  projected_states_qkvz = self.in_proj_qkvz(hidden_states)
 
203
  has_backward: bool = False
204
  can_torch_compile: bool = False
205
 
206
+ def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
207
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
208
  batch_size, seq_len, _ = hidden_states.shape
209
 
build/torch211-cxx11-cu130-aarch64-linux/layers.py CHANGED
@@ -24,7 +24,10 @@ input-projection layout, so the shared core lives in the module-level
24
 
25
  ``kernels`` forbids extra class members and a custom ``__init__`` on a layer
26
  (``_validate_layer``), which is why all helpers are module-level functions, not
27
- methods.
 
 
 
28
 
29
  On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
30
  build, so ``transformers`` silently falls back to a slow pure-torch
@@ -174,7 +177,7 @@ class GatedDeltaNet(nn.Module):
174
  has_backward: bool = False
175
  can_torch_compile: bool = False
176
 
177
- def forward(self, hidden_states, cache_params=None, attention_mask=None):
178
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
179
 
180
  projected_states_qkvz = self.in_proj_qkvz(hidden_states)
@@ -200,7 +203,7 @@ class Qwen3_5GatedDeltaNet(nn.Module):
200
  has_backward: bool = False
201
  can_torch_compile: bool = False
202
 
203
- def forward(self, hidden_states, cache_params=None, attention_mask=None):
204
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
205
  batch_size, seq_len, _ = hidden_states.shape
206
 
 
24
 
25
  ``kernels`` forbids extra class members and a custom ``__init__`` on a layer
26
  (``_validate_layer``), which is why all helpers are module-level functions, not
27
+ methods. ``_validate_layer`` also requires the layer ``forward`` signature to
28
+ match the host's argument count exactly, so ``forward`` takes the same
29
+ ``**kwargs`` (``Unpack[TransformersKwargs]``) the host GDN layers carry in
30
+ transformers >= 5.10; the kernel path ignores those kwargs.
31
 
32
  On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
33
  build, so ``transformers`` silently falls back to a slow pure-torch
 
177
  has_backward: bool = False
178
  can_torch_compile: bool = False
179
 
180
+ def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
181
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
182
 
183
  projected_states_qkvz = self.in_proj_qkvz(hidden_states)
 
203
  has_backward: bool = False
204
  can_torch_compile: bool = False
205
 
206
+ def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
207
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
208
  batch_size, seq_len, _ = hidden_states.shape
209
 
build/torch212-cxx11-cu130-aarch64-linux/layers.py CHANGED
@@ -24,7 +24,10 @@ input-projection layout, so the shared core lives in the module-level
24
 
25
  ``kernels`` forbids extra class members and a custom ``__init__`` on a layer
26
  (``_validate_layer``), which is why all helpers are module-level functions, not
27
- methods.
 
 
 
28
 
29
  On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
30
  build, so ``transformers`` silently falls back to a slow pure-torch
@@ -174,7 +177,7 @@ class GatedDeltaNet(nn.Module):
174
  has_backward: bool = False
175
  can_torch_compile: bool = False
176
 
177
- def forward(self, hidden_states, cache_params=None, attention_mask=None):
178
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
179
 
180
  projected_states_qkvz = self.in_proj_qkvz(hidden_states)
@@ -200,7 +203,7 @@ class Qwen3_5GatedDeltaNet(nn.Module):
200
  has_backward: bool = False
201
  can_torch_compile: bool = False
202
 
203
- def forward(self, hidden_states, cache_params=None, attention_mask=None):
204
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
205
  batch_size, seq_len, _ = hidden_states.shape
206
 
 
24
 
25
  ``kernels`` forbids extra class members and a custom ``__init__`` on a layer
26
  (``_validate_layer``), which is why all helpers are module-level functions, not
27
+ methods. ``_validate_layer`` also requires the layer ``forward`` signature to
28
+ match the host's argument count exactly, so ``forward`` takes the same
29
+ ``**kwargs`` (``Unpack[TransformersKwargs]``) the host GDN layers carry in
30
+ transformers >= 5.10; the kernel path ignores those kwargs.
31
 
32
  On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
33
  build, so ``transformers`` silently falls back to a slow pure-torch
 
177
  has_backward: bool = False
178
  can_torch_compile: bool = False
179
 
180
+ def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
181
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
182
 
183
  projected_states_qkvz = self.in_proj_qkvz(hidden_states)
 
203
  has_backward: bool = False
204
  can_torch_compile: bool = False
205
 
206
+ def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
207
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
208
  batch_size, seq_len, _ = hidden_states.shape
209
 
build/torch212-cxx11-cu132-aarch64-linux/layers.py CHANGED
@@ -24,7 +24,10 @@ input-projection layout, so the shared core lives in the module-level
24
 
25
  ``kernels`` forbids extra class members and a custom ``__init__`` on a layer
26
  (``_validate_layer``), which is why all helpers are module-level functions, not
27
- methods.
 
 
 
28
 
29
  On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
30
  build, so ``transformers`` silently falls back to a slow pure-torch
@@ -174,7 +177,7 @@ class GatedDeltaNet(nn.Module):
174
  has_backward: bool = False
175
  can_torch_compile: bool = False
176
 
177
- def forward(self, hidden_states, cache_params=None, attention_mask=None):
178
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
179
 
180
  projected_states_qkvz = self.in_proj_qkvz(hidden_states)
@@ -200,7 +203,7 @@ class Qwen3_5GatedDeltaNet(nn.Module):
200
  has_backward: bool = False
201
  can_torch_compile: bool = False
202
 
203
- def forward(self, hidden_states, cache_params=None, attention_mask=None):
204
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
205
  batch_size, seq_len, _ = hidden_states.shape
206
 
 
24
 
25
  ``kernels`` forbids extra class members and a custom ``__init__`` on a layer
26
  (``_validate_layer``), which is why all helpers are module-level functions, not
27
+ methods. ``_validate_layer`` also requires the layer ``forward`` signature to
28
+ match the host's argument count exactly, so ``forward`` takes the same
29
+ ``**kwargs`` (``Unpack[TransformersKwargs]``) the host GDN layers carry in
30
+ transformers >= 5.10; the kernel path ignores those kwargs.
31
 
32
  On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
33
  build, so ``transformers`` silently falls back to a slow pure-torch
 
177
  has_backward: bool = False
178
  can_torch_compile: bool = False
179
 
180
+ def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
181
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
182
 
183
  projected_states_qkvz = self.in_proj_qkvz(hidden_states)
 
203
  has_backward: bool = False
204
  can_torch_compile: bool = False
205
 
206
+ def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
207
  hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
208
  batch_size, seq_len, _ = hidden_states.shape
209