Michael Benayoun committed on
Commit
b444fe2
·
1 Parent(s): ab3f905
Files changed (1) hide show
  1. build/torch-neuron/__init__.py +0 -67
build/torch-neuron/__init__.py CHANGED
@@ -21,9 +21,6 @@ def rmsnorm(hidden_states, weight, eps: float = 1e-6):
21
  Returns:
22
  Normalized tensor of shape (B, S, H)
23
  """
24
- # Get input shape
25
- original_shape = hidden_states.shape
26
-
27
  num_rows = 1
28
  for r in hidden_states.shape[:-1]:
29
  num_rows *= r
@@ -79,10 +76,7 @@ def rmsnorm(hidden_states, weight, eps: float = 1e-6):
79
  # Step 6: Normalize: row * rsqrt(variance + eps)
80
  # Broadcast rms_reciprocal across hidden_dim using tensor_scalar
81
  normalized = sbuf.view(dtype=dtype, shape=(rows, hidden_dim))
82
- # rms_reciprocal_fp32 = sbuf.view(dtype=nl.float32, shape=(rows, 1))
83
- # nisa.tensor_copy(dst=rms_reciprocal_fp32, src=rms_reciprocal) # Convert to fp32 for better precision in multiplication
84
  nisa.tensor_scalar(normalized, row_tile, nl.multiply, rms_reciprocal)
85
- # nisa.tensor_tensor(normalized, row_tile, rms_reciprocal, op=nl.multiply)
86
 
87
  # Step 7: Apply weight element-wise
88
  weight_tile_rows = sbuf.view(dtype=dtype, shape=(rows, hidden_dim))
@@ -102,67 +96,6 @@ def rmsnorm(hidden_states, weight, eps: float = 1e-6):
102
 
103
  return output_flat
104
 
105
- @nki.jit(platform_target="trn2")
106
- def rmsnorm_(hidden_states, weight, eps: float = 1e-6):
107
- """
108
- Optimized NKI kernel for RMSNorm.
109
- """
110
- # 1. Calculate shapes
111
- B, S, H = hidden_states.shape
112
- num_rows = B * S
113
- hidden_dim = H
114
- max_rows = nl.tile_size.pmax # Maximum hardware partition size (usually 128)
115
-
116
- # 2. Allocate Output in HBM
117
- output_flat = nl.ndarray(shape=(num_rows, hidden_dim), dtype=hidden_states.dtype, buffer=nl.hbm)
118
-
119
- # 3. FAST WEIGHT LOADING: Load the 1D weight into SBUF exactly ONCE before the loop.
120
- weight_sbuf = nl.ndarray(shape=(1, hidden_dim), dtype=weight.dtype, buffer=nl.sbuf)
121
- nisa.dma_copy(dst=weight_sbuf, src=weight.reshape((1, hidden_dim)))
122
-
123
- # 4. Process in chunks using NKI's hardware-optimized affine_range
124
- # (Assuming num_rows is perfectly divisible by max_rows for standard tiling)
125
- print("Num rows:", num_rows, "Max rows per tile:", max_rows)
126
- for i in nl.affine_range(num_rows // max_rows):
127
-
128
- # Calculate the exact memory offset for this specific chunk
129
- offset = i * max_rows
130
-
131
- # Allocate fast on-chip memory (SBUF) for our tiles
132
- in_tile = nl.ndarray(shape=(max_rows, hidden_dim), dtype=hidden_states.dtype, buffer=nl.sbuf)
133
- out_tile = nl.ndarray(shape=(max_rows, hidden_dim), dtype=hidden_states.dtype, buffer=nl.sbuf)
134
-
135
- # DMA Load: Pull just this chunk from HBM to SBUF
136
- nisa.dma_copy(dst=in_tile, src=hidden_states.reshape((num_rows, hidden_dim))[offset : offset + max_rows, :])
137
-
138
- # Step 1: Compute x^2
139
- squared = nisa.tensor_tensor(in_tile, in_tile, op=nl.multiply)
140
-
141
- # Step 2: Sum across hidden_dim (axis 1). Results in shape (max_rows, 1)
142
- square_sum = nisa.tensor_reduce(data=squared, op=nl.add, axis=1)
143
-
144
- # Step 3 & 4: Mean and Add epsilon
145
- mean = nisa.tensor_scalar(square_sum, nl.multiply, 1.0 / hidden_dim)
146
- mean_eps = nisa.tensor_scalar(mean, nl.add, eps)
147
-
148
- # Step 5: rsqrt(mean + eps)
149
- sqrt_mean = nisa.activation(data=mean_eps, op=nl.sqrt)
150
- rms_reciprocal = nisa.reciprocal(data=sqrt_mean)
151
-
152
- # Step 6: Normalize.
153
- # The hardware automatically broadcasts the (max_rows, 1) reciprocal across the (max_rows, hidden_dim) input tile.
154
- normalized = nisa.tensor_tensor(in_tile, rms_reciprocal, op=nl.multiply)
155
-
156
- # Step 7: Apply weight.
157
- # The hardware automatically broadcasts the (1, hidden_dim) weight across the (max_rows, hidden_dim) normalized tile.
158
- nisa.tensor_tensor(dst=out_tile, data0=normalized, data1=weight_sbuf, op=nl.multiply)
159
-
160
- # DMA Store: Push the result back to HBM.
161
- # BUG FIXED: Using `offset` ensures we write to the correct block in the output tensor!
162
- nisa.dma_copy(dst=output_flat[offset : offset + max_rows, :], src=out_tile)
163
-
164
- return output_flat
165
-
166
  from . import layers
167
 
168
  __all__ = [
 
21
  Returns:
22
  Normalized tensor of shape (B, S, H)
23
  """
 
 
 
24
  num_rows = 1
25
  for r in hidden_states.shape[:-1]:
26
  num_rows *= r
 
76
  # Step 6: Normalize: row * rsqrt(variance + eps)
77
  # Broadcast rms_reciprocal across hidden_dim using tensor_scalar
78
  normalized = sbuf.view(dtype=dtype, shape=(rows, hidden_dim))
 
 
79
  nisa.tensor_scalar(normalized, row_tile, nl.multiply, rms_reciprocal)
 
80
 
81
  # Step 7: Apply weight element-wise
82
  weight_tile_rows = sbuf.view(dtype=dtype, shape=(rows, hidden_dim))
 
96
 
97
  return output_flat
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  from . import layers
100
 
101
  __all__ = [