| |
| |
|
|
|
|
| from typing import Tuple |
|
|
| import torch |
| import triton |
| import triton.language as tl |
|
|
|
|
@triton.jit
def hierarchical_sae_forward_kernel(
    loss_per_batch_ptr,  # out: [B] float32 loss accumulator (atomic_add across D-blocks)
    final_recon_ptr,  # out: [B, D] reconstruction after all K codes are applied
    indices_ptr,  # in: [B, K] dictionary-row indices of the active codes
    weight_ptr,  # in: [F, D] decoder / dictionary weight matrix
    bias_ptr,  # in: [D] decoder bias
    vals_ptr,  # in: [B, K] activation values paired with indices_ptr
    target_ptr,  # in: [B, D] reconstruction target
    B: tl.constexpr,  # batch size
    D: tl.constexpr,  # feature dimension
    K: tl.constexpr,  # active codes per row
    BLOCK_D: tl.constexpr,  # tile width along D
    LOOP_NUM_STAGES: tl.constexpr,  # pipelining depth for the K-loop
    BLOCK_B: tl.constexpr,  # tile height along B
):
    # Each program owns one BLOCK_B x BLOCK_D tile.  It builds the cumulative
    # reconstruction code-by-code and accumulates squared error against the
    # target after EVERY step, so the loss covers all K partial
    # reconstructions (hierarchical objective), not just the final one.
    tl.static_assert((D % BLOCK_D) == 0)
    tl.static_assert((B % BLOCK_B) == 0)
    tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
    tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
    tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")


    # int64 program ids so the pointer arithmetic below cannot overflow int32.
    pid_b = tl.program_id(axis=0).to(tl.int64)
    pid_d = tl.program_id(axis=1).to(tl.int64)


    batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
    batch_offsets = batch_offsets.to(tl.int64)
    # NOTE(review): tl.multiple_of/max_contiguous return the annotated tensor;
    # the results are discarded here, so the hints may not take effect — confirm.
    tl.multiple_of(batch_offsets, BLOCK_B)


    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    offset_d = offset_d.to(tl.int64)


    tl.multiple_of(offset_d, BLOCK_D)
    tl.max_contiguous(offset_d, BLOCK_D)


    # Flat [BLOCK_B, BLOCK_D] offsets into the row-major (B, D) tensors.
    batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]


    bias_tile = tl.load(bias_ptr + offset_d).to(tl.float32)


    # The reconstruction starts from the bias alone ("zero codes" baseline).
    recon = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
    recon += bias_tile[None, :]


    target = tl.load(target_ptr + batch_d_offset).to(tl.float32)


    loss_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)


    # Row-local base pointers into the [B, K] index/value arrays.
    row_idx_ptr = indices_ptr + batch_offsets * K
    row_val_ptr = vals_ptr + batch_offsets * K


    # Software pipelining: load code 0 before the loop so each iteration can
    # prefetch the data for step t+1 while consuming step t.
    idx = tl.load(row_idx_ptr).to(tl.int64)
    val = tl.load(row_val_ptr).to(tl.float32)
    val = val[:, None]
    weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)


    for t in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
        # Apply code t, then score the resulting partial reconstruction.
        recon += weight_tile * val
        diff = recon - target
        loss_accum += diff * diff


        # Prefetch step t+1 (skipped on the final iteration).
        if t + 1 < K:
            idx_next = tl.load(row_idx_ptr + (t + 1)).to(tl.int64)
            val_next = tl.load(row_val_ptr + (t + 1)).to(tl.float32)
            weight_next = tl.load(weight_ptr + idx_next[:, None] * D + offset_d[None, :]).to(tl.float32)


            idx = idx_next
            val = val_next[:, None]
            weight_tile = weight_next


    # Reduce this tile's squared error over its D-slice and fold it into the
    # per-row totals; the other D-blocks contribute the remaining columns.
    loss_tile = tl.sum(loss_accum, axis=1)
    tl.atomic_add(
        loss_per_batch_ptr + batch_offsets,
        loss_tile,
        sem="relaxed",  # pure accumulation; no ordering needed beyond atomicity
    )
    # recon now holds the full reconstruction (bias + all K codes); the
    # backward kernel reloads it and peels the codes off in reverse.
    tl.store(
        final_recon_ptr + batch_d_offset,
        recon,
    )
|
|
|
|
@triton.jit
def hierarchical_sae_backward_kernel(
    weight_grad_ptr,  # out: [F, D] weight gradient (atomic accumulation, pre-zeroed)
    vals_grad_ptr,  # out: [B, K] value gradient (atomic accumulation, pre-zeroed)
    bias_grad_ptr,  # out: [D] bias gradient (atomic accumulation, pre-zeroed)
    final_recon_ptr,  # in: [B, D] final reconstruction saved by the forward kernel
    indices_ptr,  # in: [B, K] dictionary-row indices of the active codes
    weight_ptr,  # in: [F, D] decoder / dictionary weight matrix
    vals_ptr,  # in: [B, K] activation values paired with indices_ptr
    target_ptr,  # in: [B, D] reconstruction target
    B: tl.constexpr,  # batch size
    D: tl.constexpr,  # feature dimension
    K: tl.constexpr,  # active codes per row
    BLOCK_D: tl.constexpr,  # tile width along D
    LOOP_NUM_STAGES: tl.constexpr,  # pipelining depth for the K-loop
    BLOCK_B: tl.constexpr,  # tile height along B
):
    # Gradient of the hierarchical loss, recomputed without storing the K
    # intermediate reconstructions: starting from the FINAL reconstruction,
    # the loop walks steps K-1 .. 0, subtracting each code's contribution to
    # recover the previous partial reconstruction.  `suffix` accumulates the
    # sum of d(loss)/d(recon_t) over all steps t >= current step, which is
    # exactly the upstream gradient seen by code `step`.
    tl.static_assert((D % BLOCK_D) == 0)
    tl.static_assert((B % BLOCK_B) == 0)
    tl.static_assert((K & (K - 1)) == 0, f"{K=} must be a power of 2")
    tl.static_assert((BLOCK_D & (BLOCK_D - 1)) == 0, f"{BLOCK_D=} must be a power of 2")
    tl.static_assert((BLOCK_B & (BLOCK_B - 1)) == 0, f"{BLOCK_B=} must be a power of 2")


    pid_b = tl.program_id(axis=0).to(tl.int64)
    pid_d = tl.program_id(axis=1).to(tl.int64)


    batch_offsets = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
    batch_offsets = batch_offsets.to(tl.int64)
    # NOTE(review): the return values of these hints are discarded, so they
    # may not take effect — same pattern as the forward kernel; confirm.
    tl.multiple_of(batch_offsets, BLOCK_B)


    offset_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    offset_d = offset_d.to(tl.int64)


    tl.multiple_of(offset_d, BLOCK_D)
    tl.max_contiguous(offset_d, BLOCK_D)


    # Flat [BLOCK_B, BLOCK_D] offsets into the row-major (B, D) tensors.
    batch_d_offset = batch_offsets[:, None] * D + offset_d[None, :]


    recon = tl.load(final_recon_ptr + batch_d_offset).to(tl.float32)
    target = tl.load(target_ptr + batch_d_offset).to(tl.float32)


    suffix = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
    bias_accum = tl.zeros([BLOCK_B, BLOCK_D], dtype=tl.float32)
    # d(mean over B*K*D squared errors)/d(diff) contributes this constant factor.
    scale = tl.full((), 2.0 / (B * K * D), dtype=tl.float32)


    row_idx_ptr = indices_ptr + batch_offsets * K
    row_val_ptr = vals_ptr + batch_offsets * K
    k_offsets = tl.arange(0, K)
    # Per-row value gradients are built in registers and flushed once at the end.
    val_grad_tile = tl.zeros([BLOCK_B, K], dtype=tl.float32)


    # Pre-load the last code; the loop tail prefetches the preceding step.
    step = K - 1
    idx = tl.load(row_idx_ptr + step).to(tl.int64)
    val = tl.load(row_val_ptr + step).to(tl.float32)
    weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)


    for _ in tl.range(0, K, num_stages=LOOP_NUM_STAGES):
        curr_step = step


        # Gradient of the loss term at the current partial reconstruction;
        # every step <= curr_step also feeds this term, hence the suffix sum.
        diff = recon - target
        grad_curr = diff * scale
        suffix += grad_curr
        bias_accum += grad_curr


        # dL/dW[idx_k] = val_k * sum_{t >= k} grad_t.  Rows may collide
        # (same dictionary index across batch rows), hence the atomic.
        val_broadcast = val[:, None]
        contrib = suffix * val_broadcast
        tl.atomic_add(
            weight_grad_ptr + idx[:, None] * D + offset_d[None, :],
            contrib,
            sem="relaxed",
        )


        # dL/dval_k = <W[idx_k], suffix>, partial over this D-slice; the
        # masked where() scatters it into column curr_step since Triton has
        # no dynamic per-element register indexing.
        dot_partial = tl.sum(weight_tile * suffix, axis=1)
        mask_curr = k_offsets[None, :] == curr_step
        val_grad_tile = tl.where(mask_curr, dot_partial[:, None], val_grad_tile)


        # Peel off this code to recover the previous partial reconstruction.
        recon -= weight_tile * val_broadcast


        # Prefetch step-1 (skipped once step 0 has been processed).
        if curr_step > 0:
            step = curr_step - 1
            idx = tl.load(row_idx_ptr + step).to(tl.int64)
            val = tl.load(row_val_ptr + step).to(tl.float32)
            weight_tile = tl.load(weight_ptr + idx[:, None] * D + offset_d[None, :]).to(tl.float32)


    # Bias feeds every partial reconstruction, so its gradient is the sum of
    # all per-step gradients, reduced over this tile's batch rows.
    bias_grad_tile = tl.sum(bias_accum, axis=0)
    tl.atomic_add(
        bias_grad_ptr + offset_d,
        bias_grad_tile,
        sem="relaxed",
    )


    # Combine the per-D-block partial dot products into the full value grads.
    row_val_grad_ptr = vals_grad_ptr + batch_offsets[:, None] * K + k_offsets[None, :]
    tl.atomic_add(
        row_val_grad_ptr,
        val_grad_tile,
        sem="relaxed",
    )
|
|
|
|
def _hierarchical_sae_forward(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    bias: torch.Tensor,
    target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused forward kernel.

    Returns a tuple of (scalar mean loss, ``[B, D]`` float32 final
    reconstruction); the reconstruction is saved for the backward pass.
    """
    batch_size, top_k = indices.shape
    dim = weight.shape[1]
    device = weight.device

    # The kernel accumulates into `loss_per_batch` with atomic adds, so it
    # must start zeroed; `final_recon` is fully overwritten, so empty is fine.
    loss_per_batch = torch.zeros((batch_size,), dtype=torch.float32, device=device)
    final_recon = torch.empty((batch_size, dim), dtype=torch.float32, device=device)

    def grid(meta):
        # One program per (BLOCK_B rows, BLOCK_D columns) tile.
        return (batch_size // meta["BLOCK_B"], dim // meta["BLOCK_D"])

    hierarchical_sae_forward_kernel[grid](
        loss_per_batch,
        final_recon,
        indices,
        weight,
        bias,
        vals,
        target,
        B=batch_size,
        D=dim,
        K=top_k,
        BLOCK_D=64,
        LOOP_NUM_STAGES=4,
        BLOCK_B=1,
        num_warps=2,
        num_stages=2,
    )
    # The kernel sums raw squared errors; normalize to the mean over all
    # B * K * D terms here on the host side.
    loss = loss_per_batch.sum() / (batch_size * top_k * dim)
    return loss, final_recon
|
|
|
|
def _hierarchical_sae_backward(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    target: torch.Tensor,
    final_recon: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Launch the fused backward kernel; returns ``(dWeight, dVals, dBias)``."""
    batch_size, top_k = indices.shape
    num_features, dim = weight.shape
    device = weight.device

    # All three gradient buffers are written via atomic adds inside the
    # kernel, so each must start zeroed.
    dW = torch.zeros((num_features, dim), dtype=torch.float32, device=device)
    dVals = torch.zeros((batch_size, top_k), dtype=torch.float32, device=device)
    db = torch.zeros((dim,), dtype=torch.float32, device=device)

    def grid(meta):
        # One program per (BLOCK_B rows, BLOCK_D columns) tile.
        return (batch_size // meta["BLOCK_B"], dim // meta["BLOCK_D"])

    hierarchical_sae_backward_kernel[grid](
        dW,
        dVals,
        db,
        final_recon,
        indices,
        weight,
        vals,
        target,
        B=batch_size,
        D=dim,
        K=top_k,
        BLOCK_D=32,
        LOOP_NUM_STAGES=16,
        BLOCK_B=16,
        num_warps=8,
        num_stages=8,
    )
    return dW, dVals, db
|
|
|
|
class HierarchicalSAELossFunction(torch.autograd.Function):
    """Autograd bridge for the fused Triton hierarchical-SAE loss kernels."""

    @staticmethod
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        indices: torch.Tensor,
        weight: torch.Tensor,
        vals: torch.Tensor,
        bias: torch.Tensor,
        target: torch.Tensor,
    ):
        # The forward launcher also returns the final reconstruction, which
        # the backward kernel needs to replay the K steps in reverse.
        loss, final_recon = _hierarchical_sae_forward(indices, weight, vals, bias, target)
        ctx.save_for_backward(indices, weight, vals, target, final_recon)
        return loss

    @staticmethod
    @torch.amp.custom_bwd(device_type="cuda")
    def backward(ctx, grad):
        indices, weight, vals, target, final_recon = ctx.saved_tensors
        dW, dVals, db = _hierarchical_sae_backward(indices, weight, vals, target, final_recon)

        # The kernels compute gradients of the raw loss; chain in the
        # incoming upstream gradient (a scalar for a scalar loss).
        if grad is not None:
            for buf in (dW, dVals, db):
                buf.mul_(grad)

        # Positions mirror forward's inputs: indices and target get no grad.
        return None, dW, dVals, db, None
|
|
|
|
def triton_hierarchical_sae_loss(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    bias: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """Compute the hierarchical SAE loss with the fused Triton kernels.

    Public entry point: dispatches through the custom autograd Function so
    gradients flow to ``weight``, ``vals`` and ``bias``.
    """
    loss = HierarchicalSAELossFunction.apply(indices, weight, vals, bias, target)
    return loss
|
|
|
|
def hierarchical_sae_loss(
    indices: torch.Tensor,
    weight: torch.Tensor,
    vals: torch.Tensor,
    bias: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """Pure-PyTorch reference for the Triton hierarchical SAE loss.

    Materializes all K cumulative partial reconstructions
    (``bias + cumsum_k(weight[indices_k] * vals_k)``) and returns the mean
    squared error of every partial reconstruction against ``target``.
    """
    # [B, K, D] dictionary rows for each active code, in float32.
    codes = weight[indices].float()
    scaled = codes * vals.unsqueeze(-1)
    # Running reconstruction after each of the K codes.
    partial_recons = scaled.cumsum(dim=1) + bias.float()
    residual = partial_recons.to(torch.float32) - target.float().unsqueeze(1)
    return residual.square().mean()
|
|