# finetune-demo-lora / tests/e2e/kernels/test_geglu.py
# Source: rayraycano — "Training in progress, step 20" (commit fcca8c8, verified)
"""Tests for GEGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import torch
import torch.nn.functional as F
from axolotl.kernels.geglu import geglu_backward, geglu_forward
def test_geglu_forward_shape():
    """Verify the GEGLU forward pass returns a tensor whose shape, dtype,
    and device all match the inputs."""
    dims = (2, 3, 64)  # (batch, seq_len, hidden_dim)
    gate = torch.randn(*dims, device="cuda")
    up = torch.randn(*dims, device="cuda")

    result = geglu_forward(gate, up)

    assert result.shape == dims
    assert result.dtype == gate.dtype
    assert result.device == gate.device
def test_geglu_forward_values():
    """Check the Triton GEGLU forward result against the PyTorch
    reference expression ``F.gelu(gate) * up``."""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")

    # Reference value computed with plain PyTorch ops.
    expected = F.gelu(gate) * up
    # Kernel under test gets clones so the originals stay pristine.
    actual = geglu_forward(gate.clone(), up.clone())

    assert torch.allclose(actual, expected, rtol=1e-3)
def test_geglu_backward():
    """Compare the custom GEGLU backward pass (output and both gradients)
    with gradients produced by PyTorch autograd."""
    shape = (2, 3, 64)
    gate = torch.randn(*shape, device="cuda", requires_grad=True)
    up = torch.randn(*shape, device="cuda", requires_grad=True)
    grad_output = torch.randn(*shape, device="cuda")

    # Reference: run autograd through F.gelu(gate) * up.
    reference = F.gelu(gate) * up
    reference.backward(grad_output)

    # Kernel under test operates on detached clones so the autograd
    # graph built above is left untouched.
    h, grad_gate, grad_up = geglu_backward(
        grad_output.clone(), gate.clone().detach(), up.clone().detach()
    )

    # The backward kernel also recomputes the forward output (h).
    assert torch.allclose(h, reference, rtol=1e-3)
    assert torch.allclose(grad_gate, gate.grad, rtol=1e-3)
    assert torch.allclose(grad_up, up.grad, rtol=1e-3)
def test_geglu_inplace_preservation():
    """Assert that geglu_backward mutates all three of its argument
    tensors in place (the kernel reuses the buffers for its outputs)."""
    gate = torch.randn(2, 3, 64, device="cuda")
    up = torch.randn(2, 3, 64, device="cuda")
    grad_output = torch.randn(2, 3, 64, device="cuda")

    # Snapshot every input before the kernel runs.
    originals = {
        "gate": gate.clone(),
        "up": up.clone(),
        "grad_output": grad_output.clone(),
    }

    geglu_backward(grad_output, gate, up)

    assert not torch.equal(gate, originals["gate"]), "Gate should be modified in-place"
    assert not torch.equal(up, originals["up"]), "Up should be modified in-place"
    assert not torch.equal(
        grad_output, originals["grad_output"]
    ), "Grad output should be modified in-place"