Michael Benayoun committed on
Commit ·
b444fe2
1
Parent(s): ab3f905
debug
Browse files
build/torch-neuron/__init__.py
CHANGED
|
@@ -21,9 +21,6 @@ def rmsnorm(hidden_states, weight, eps: float = 1e-6):
|
|
| 21 |
Returns:
|
| 22 |
Normalized tensor of shape (B, S, H)
|
| 23 |
"""
|
| 24 |
-
# Get input shape
|
| 25 |
-
original_shape = hidden_states.shape
|
| 26 |
-
|
| 27 |
num_rows = 1
|
| 28 |
for r in hidden_states.shape[:-1]:
|
| 29 |
num_rows *= r
|
|
@@ -79,10 +76,7 @@ def rmsnorm(hidden_states, weight, eps: float = 1e-6):
|
|
| 79 |
# Step 6: Normalize: row * rsqrt(variance + eps)
|
| 80 |
# Broadcast rms_reciprocal across hidden_dim using tensor_scalar
|
| 81 |
normalized = sbuf.view(dtype=dtype, shape=(rows, hidden_dim))
|
| 82 |
-
# rms_reciprocal_fp32 = sbuf.view(dtype=nl.float32, shape=(rows, 1))
|
| 83 |
-
# nisa.tensor_copy(dst=rms_reciprocal_fp32, src=rms_reciprocal) # Convert to fp32 for better precision in multiplication
|
| 84 |
nisa.tensor_scalar(normalized, row_tile, nl.multiply, rms_reciprocal)
|
| 85 |
-
# nisa.tensor_tensor(normalized, row_tile, rms_reciprocal, op=nl.multiply)
|
| 86 |
|
| 87 |
# Step 7: Apply weight element-wise
|
| 88 |
weight_tile_rows = sbuf.view(dtype=dtype, shape=(rows, hidden_dim))
|
|
@@ -102,67 +96,6 @@ def rmsnorm(hidden_states, weight, eps: float = 1e-6):
|
|
| 102 |
|
| 103 |
return output_flat
|
| 104 |
|
| 105 |
-
@nki.jit(platform_target="trn2")
|
| 106 |
-
def rmsnorm_(hidden_states, weight, eps: float = 1e-6):
|
| 107 |
-
"""
|
| 108 |
-
Optimized NKI kernel for RMSNorm.
|
| 109 |
-
"""
|
| 110 |
-
# 1. Calculate shapes
|
| 111 |
-
B, S, H = hidden_states.shape
|
| 112 |
-
num_rows = B * S
|
| 113 |
-
hidden_dim = H
|
| 114 |
-
max_rows = nl.tile_size.pmax # Maximum hardware partition size (usually 128)
|
| 115 |
-
|
| 116 |
-
# 2. Allocate Output in HBM
|
| 117 |
-
output_flat = nl.ndarray(shape=(num_rows, hidden_dim), dtype=hidden_states.dtype, buffer=nl.hbm)
|
| 118 |
-
|
| 119 |
-
# 3. FAST WEIGHT LOADING: Load the 1D weight into SBUF exactly ONCE before the loop.
|
| 120 |
-
weight_sbuf = nl.ndarray(shape=(1, hidden_dim), dtype=weight.dtype, buffer=nl.sbuf)
|
| 121 |
-
nisa.dma_copy(dst=weight_sbuf, src=weight.reshape((1, hidden_dim)))
|
| 122 |
-
|
| 123 |
-
# 4. Process in chunks using NKI's hardware-optimized affine_range
|
| 124 |
-
# (Assuming num_rows is perfectly divisible by max_rows for standard tiling)
|
| 125 |
-
print("Num rows:", num_rows, "Max rows per tile:", max_rows)
|
| 126 |
-
for i in nl.affine_range(num_rows // max_rows):
|
| 127 |
-
|
| 128 |
-
# Calculate the exact memory offset for this specific chunk
|
| 129 |
-
offset = i * max_rows
|
| 130 |
-
|
| 131 |
-
# Allocate fast on-chip memory (SBUF) for our tiles
|
| 132 |
-
in_tile = nl.ndarray(shape=(max_rows, hidden_dim), dtype=hidden_states.dtype, buffer=nl.sbuf)
|
| 133 |
-
out_tile = nl.ndarray(shape=(max_rows, hidden_dim), dtype=hidden_states.dtype, buffer=nl.sbuf)
|
| 134 |
-
|
| 135 |
-
# DMA Load: Pull just this chunk from HBM to SBUF
|
| 136 |
-
nisa.dma_copy(dst=in_tile, src=hidden_states.reshape((num_rows, hidden_dim))[offset : offset + max_rows, :])
|
| 137 |
-
|
| 138 |
-
# Step 1: Compute x^2
|
| 139 |
-
squared = nisa.tensor_tensor(in_tile, in_tile, op=nl.multiply)
|
| 140 |
-
|
| 141 |
-
# Step 2: Sum across hidden_dim (axis 1). Results in shape (max_rows, 1)
|
| 142 |
-
square_sum = nisa.tensor_reduce(data=squared, op=nl.add, axis=1)
|
| 143 |
-
|
| 144 |
-
# Step 3 & 4: Mean and Add epsilon
|
| 145 |
-
mean = nisa.tensor_scalar(square_sum, nl.multiply, 1.0 / hidden_dim)
|
| 146 |
-
mean_eps = nisa.tensor_scalar(mean, nl.add, eps)
|
| 147 |
-
|
| 148 |
-
# Step 5: rsqrt(mean + eps)
|
| 149 |
-
sqrt_mean = nisa.activation(data=mean_eps, op=nl.sqrt)
|
| 150 |
-
rms_reciprocal = nisa.reciprocal(data=sqrt_mean)
|
| 151 |
-
|
| 152 |
-
# Step 6: Normalize.
|
| 153 |
-
# The hardware automatically broadcasts the (max_rows, 1) reciprocal across the (max_rows, hidden_dim) input tile.
|
| 154 |
-
normalized = nisa.tensor_tensor(in_tile, rms_reciprocal, op=nl.multiply)
|
| 155 |
-
|
| 156 |
-
# Step 7: Apply weight.
|
| 157 |
-
# The hardware automatically broadcasts the (1, hidden_dim) weight across the (max_rows, hidden_dim) normalized tile.
|
| 158 |
-
nisa.tensor_tensor(dst=out_tile, data0=normalized, data1=weight_sbuf, op=nl.multiply)
|
| 159 |
-
|
| 160 |
-
# DMA Store: Push the result back to HBM.
|
| 161 |
-
# BUG FIXED: Using `offset` ensures we write to the correct block in the output tensor!
|
| 162 |
-
nisa.dma_copy(dst=output_flat[offset : offset + max_rows, :], src=out_tile)
|
| 163 |
-
|
| 164 |
-
return output_flat
|
| 165 |
-
|
| 166 |
from . import layers
|
| 167 |
|
| 168 |
__all__ = [
|
|
|
|
| 21 |
Returns:
|
| 22 |
Normalized tensor of shape (B, S, H)
|
| 23 |
"""
|
|
|
|
|
|
|
|
|
|
| 24 |
num_rows = 1
|
| 25 |
for r in hidden_states.shape[:-1]:
|
| 26 |
num_rows *= r
|
|
|
|
| 76 |
# Step 6: Normalize: row * rsqrt(variance + eps)
|
| 77 |
# Broadcast rms_reciprocal across hidden_dim using tensor_scalar
|
| 78 |
normalized = sbuf.view(dtype=dtype, shape=(rows, hidden_dim))
|
|
|
|
|
|
|
| 79 |
nisa.tensor_scalar(normalized, row_tile, nl.multiply, rms_reciprocal)
|
|
|
|
| 80 |
|
| 81 |
# Step 7: Apply weight element-wise
|
| 82 |
weight_tile_rows = sbuf.view(dtype=dtype, shape=(rows, hidden_dim))
|
|
|
|
| 96 |
|
| 97 |
return output_flat
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
from . import layers
|
| 100 |
|
| 101 |
__all__ = [
|