gdn: add layers.py + expose layers ns (Qwen3.6 Qwen3_5{,Moe}GatedDeltaNet + Qwen3-Next GatedDeltaNet) for kernelize()
Browse files
build/torch210-cxx11-cu130-aarch64-linux/layers.py
CHANGED
|
@@ -24,7 +24,10 @@ input-projection layout, so the shared core lives in the module-level
|
|
| 24 |
|
| 25 |
``kernels`` forbids extra class members and a custom ``__init__`` on a layer
|
| 26 |
(``_validate_layer``), which is why all helpers are module-level functions, not
|
| 27 |
-
methods.
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
|
| 30 |
build, so ``transformers`` silently falls back to a slow pure-torch
|
|
@@ -174,7 +177,7 @@ class GatedDeltaNet(nn.Module):
|
|
| 174 |
has_backward: bool = False
|
| 175 |
can_torch_compile: bool = False
|
| 176 |
|
| 177 |
-
def forward(self, hidden_states, cache_params=None, attention_mask=None):
|
| 178 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 179 |
|
| 180 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
|
@@ -200,7 +203,7 @@ class Qwen3_5GatedDeltaNet(nn.Module):
|
|
| 200 |
has_backward: bool = False
|
| 201 |
can_torch_compile: bool = False
|
| 202 |
|
| 203 |
-
def forward(self, hidden_states, cache_params=None, attention_mask=None):
|
| 204 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 205 |
batch_size, seq_len, _ = hidden_states.shape
|
| 206 |
|
|
|
|
| 24 |
|
| 25 |
``kernels`` forbids extra class members and a custom ``__init__`` on a layer
|
| 26 |
(``_validate_layer``), which is why all helpers are module-level functions, not
|
| 27 |
+
methods. ``_validate_layer`` also requires the layer ``forward`` signature to
|
| 28 |
+
match the host's argument count exactly, so ``forward`` takes the same
|
| 29 |
+
``**kwargs`` (``Unpack[TransformersKwargs]``) the host GDN layers carry in
|
| 30 |
+
transformers >= 5.10; the kernel path ignores those kwargs.
|
| 31 |
|
| 32 |
On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
|
| 33 |
build, so ``transformers`` silently falls back to a slow pure-torch
|
|
|
|
| 177 |
has_backward: bool = False
|
| 178 |
can_torch_compile: bool = False
|
| 179 |
|
| 180 |
+
def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
|
| 181 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 182 |
|
| 183 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
|
|
|
| 203 |
has_backward: bool = False
|
| 204 |
can_torch_compile: bool = False
|
| 205 |
|
| 206 |
+
def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
|
| 207 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 208 |
batch_size, seq_len, _ = hidden_states.shape
|
| 209 |
|
build/torch211-cxx11-cu130-aarch64-linux/layers.py
CHANGED
|
@@ -24,7 +24,10 @@ input-projection layout, so the shared core lives in the module-level
|
|
| 24 |
|
| 25 |
``kernels`` forbids extra class members and a custom ``__init__`` on a layer
|
| 26 |
(``_validate_layer``), which is why all helpers are module-level functions, not
|
| 27 |
-
methods.
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
|
| 30 |
build, so ``transformers`` silently falls back to a slow pure-torch
|
|
@@ -174,7 +177,7 @@ class GatedDeltaNet(nn.Module):
|
|
| 174 |
has_backward: bool = False
|
| 175 |
can_torch_compile: bool = False
|
| 176 |
|
| 177 |
-
def forward(self, hidden_states, cache_params=None, attention_mask=None):
|
| 178 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 179 |
|
| 180 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
|
@@ -200,7 +203,7 @@ class Qwen3_5GatedDeltaNet(nn.Module):
|
|
| 200 |
has_backward: bool = False
|
| 201 |
can_torch_compile: bool = False
|
| 202 |
|
| 203 |
-
def forward(self, hidden_states, cache_params=None, attention_mask=None):
|
| 204 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 205 |
batch_size, seq_len, _ = hidden_states.shape
|
| 206 |
|
|
|
|
| 24 |
|
| 25 |
``kernels`` forbids extra class members and a custom ``__init__`` on a layer
|
| 26 |
(``_validate_layer``), which is why all helpers are module-level functions, not
|
| 27 |
+
methods. ``_validate_layer`` also requires the layer ``forward`` signature to
|
| 28 |
+
match the host's argument count exactly, so ``forward`` takes the same
|
| 29 |
+
``**kwargs`` (``Unpack[TransformersKwargs]``) the host GDN layers carry in
|
| 30 |
+
transformers >= 5.10; the kernel path ignores those kwargs.
|
| 31 |
|
| 32 |
On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
|
| 33 |
build, so ``transformers`` silently falls back to a slow pure-torch
|
|
|
|
| 177 |
has_backward: bool = False
|
| 178 |
can_torch_compile: bool = False
|
| 179 |
|
| 180 |
+
def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
|
| 181 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 182 |
|
| 183 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
|
|
|
| 203 |
has_backward: bool = False
|
| 204 |
can_torch_compile: bool = False
|
| 205 |
|
| 206 |
+
def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
|
| 207 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 208 |
batch_size, seq_len, _ = hidden_states.shape
|
| 209 |
|
build/torch212-cxx11-cu130-aarch64-linux/layers.py
CHANGED
|
@@ -24,7 +24,10 @@ input-projection layout, so the shared core lives in the module-level
|
|
| 24 |
|
| 25 |
``kernels`` forbids extra class members and a custom ``__init__`` on a layer
|
| 26 |
(``_validate_layer``), which is why all helpers are module-level functions, not
|
| 27 |
-
methods.
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
|
| 30 |
build, so ``transformers`` silently falls back to a slow pure-torch
|
|
@@ -174,7 +177,7 @@ class GatedDeltaNet(nn.Module):
|
|
| 174 |
has_backward: bool = False
|
| 175 |
can_torch_compile: bool = False
|
| 176 |
|
| 177 |
-
def forward(self, hidden_states, cache_params=None, attention_mask=None):
|
| 178 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 179 |
|
| 180 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
|
@@ -200,7 +203,7 @@ class Qwen3_5GatedDeltaNet(nn.Module):
|
|
| 200 |
has_backward: bool = False
|
| 201 |
can_torch_compile: bool = False
|
| 202 |
|
| 203 |
-
def forward(self, hidden_states, cache_params=None, attention_mask=None):
|
| 204 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 205 |
batch_size, seq_len, _ = hidden_states.shape
|
| 206 |
|
|
|
|
| 24 |
|
| 25 |
``kernels`` forbids extra class members and a custom ``__init__`` on a layer
|
| 26 |
(``_validate_layer``), which is why all helpers are module-level functions, not
|
| 27 |
+
methods. ``_validate_layer`` also requires the layer ``forward`` signature to
|
| 28 |
+
match the host's argument count exactly, so ``forward`` takes the same
|
| 29 |
+
``**kwargs`` (``Unpack[TransformersKwargs]``) the host GDN layers carry in
|
| 30 |
+
transformers >= 5.10; the kernel path ignores those kwargs.
|
| 31 |
|
| 32 |
On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
|
| 33 |
build, so ``transformers`` silently falls back to a slow pure-torch
|
|
|
|
| 177 |
has_backward: bool = False
|
| 178 |
can_torch_compile: bool = False
|
| 179 |
|
| 180 |
+
def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
|
| 181 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 182 |
|
| 183 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
|
|
|
| 203 |
has_backward: bool = False
|
| 204 |
can_torch_compile: bool = False
|
| 205 |
|
| 206 |
+
def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
|
| 207 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 208 |
batch_size, seq_len, _ = hidden_states.shape
|
| 209 |
|
build/torch212-cxx11-cu132-aarch64-linux/layers.py
CHANGED
|
@@ -24,7 +24,10 @@ input-projection layout, so the shared core lives in the module-level
|
|
| 24 |
|
| 25 |
``kernels`` forbids extra class members and a custom ``__init__`` on a layer
|
| 26 |
(``_validate_layer``), which is why all helpers are module-level functions, not
|
| 27 |
-
methods.
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
|
| 30 |
build, so ``transformers`` silently falls back to a slow pure-torch
|
|
@@ -174,7 +177,7 @@ class GatedDeltaNet(nn.Module):
|
|
| 174 |
has_backward: bool = False
|
| 175 |
can_torch_compile: bool = False
|
| 176 |
|
| 177 |
-
def forward(self, hidden_states, cache_params=None, attention_mask=None):
|
| 178 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 179 |
|
| 180 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
|
@@ -200,7 +203,7 @@ class Qwen3_5GatedDeltaNet(nn.Module):
|
|
| 200 |
has_backward: bool = False
|
| 201 |
can_torch_compile: bool = False
|
| 202 |
|
| 203 |
-
def forward(self, hidden_states, cache_params=None, attention_mask=None):
|
| 204 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 205 |
batch_size, seq_len, _ = hidden_states.shape
|
| 206 |
|
|
|
|
| 24 |
|
| 25 |
``kernels`` forbids extra class members and a custom ``__init__`` on a layer
|
| 26 |
(``_validate_layer``), which is why all helpers are module-level functions, not
|
| 27 |
+
methods. ``_validate_layer`` also requires the layer ``forward`` signature to
|
| 28 |
+
match the host's argument count exactly, so ``forward`` takes the same
|
| 29 |
+
``**kwargs`` (``Unpack[TransformersKwargs]``) the host GDN layers carry in
|
| 30 |
+
transformers >= 5.10; the kernel path ignores those kwargs.
|
| 31 |
|
| 32 |
On a DGX Spark the upstream ``fla`` / ``causal_conv1d`` fast paths have no SM121
|
| 33 |
build, so ``transformers`` silently falls back to a slow pure-torch
|
|
|
|
| 177 |
has_backward: bool = False
|
| 178 |
can_torch_compile: bool = False
|
| 179 |
|
| 180 |
+
def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
|
| 181 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 182 |
|
| 183 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states)
|
|
|
|
| 203 |
has_backward: bool = False
|
| 204 |
can_torch_compile: bool = False
|
| 205 |
|
| 206 |
+
def forward(self, hidden_states, cache_params=None, attention_mask=None, **kwargs):
|
| 207 |
hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 208 |
batch_size, seq_len, _ = hidden_states.shape
|
| 209 |
|