| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | from functools import partial |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import triton |
| | import triton.language as tl |
| | from einops import rearrange |
| | from torch.distributed import DeviceMesh |
| | from torch.distributed.tensor import DTensor, Replicate, Shard, distribute_module |
| | from torch.distributed.tensor.parallel import ParallelStyle |
| |
|
| | from fla.utils import get_multiprocessor_count, input_guard |
| |
|
| |
|
def layer_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    prenorm: bool = False,
    upcast: bool = False
):
    """Pure-PyTorch reference LayerNorm over the last dimension.

    Args:
        x: Input tensor; normalized over its last dimension.
        weight: Optional affine scale (may be ``None``).
        bias: Optional affine shift (may be ``None``).
        residual: Optional tensor added to ``x`` before normalization.
        eps: Value added to the variance for numerical stability.
        prenorm: If True, also return the (residual-added) pre-norm input.
        upcast: If True, compute in float32 (``x``, ``residual`` and the affine
            parameters are cast to float32 first).

    Returns:
        ``out`` in ``x``'s original dtype, or ``(out, x_plus_residual)`` when
        ``prenorm`` is True.
    """
    dtype = x.dtype
    if upcast:
        weight = weight.float() if weight is not None else None
        bias = bias.float() if bias is not None else None
        x = x.float()
        residual = residual.float() if residual is not None else None
    if residual is not None:
        x = (x + residual).to(x.dtype)
    # Normalize in the affine parameters' dtype (as before); fall back to the
    # bias dtype, then to x's own dtype, when those parameters are absent.
    if weight is not None:
        norm_dtype = weight.dtype
    elif bias is not None:
        norm_dtype = bias.dtype
    else:
        norm_dtype = x.dtype
    out = F.layer_norm(x.to(norm_dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(dtype)
    return out if not prenorm else (out, x)
| |
|
| |
|
def rms_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    prenorm: bool = False,
    upcast: bool = False
):
    """Pure-PyTorch reference RMSNorm over the last dimension.

    Computes ``x / sqrt(mean(x**2) + eps) * weight + bias``, where the scale and
    shift are skipped when ``weight`` / ``bias`` is ``None``.

    Args:
        x: Input tensor; normalized over its last dimension.
        weight: Optional scale (may be ``None``).
        bias: Optional shift (may be ``None``).
        residual: Optional tensor added to ``x`` before normalization.
        eps: Value added to the mean square for numerical stability.
        prenorm: If True, also return the (residual-added) pre-norm input.
        upcast: If True, compute in float32.

    Returns:
        ``out`` in ``x``'s original dtype, or ``(out, x_plus_residual)`` when
        ``prenorm`` is True.
    """
    dtype = x.dtype
    if upcast:
        weight = weight.float() if weight is not None else None
        bias = bias.float() if bias is not None else None
        x = x.float()
        residual = residual.float() if residual is not None else None
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = x * rstd
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    out = out.to(dtype)
    return out if not prenorm else (out, x)
| |
|
| |
|
def group_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    num_groups: int,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    is_rms_norm: bool = False,
    prenorm: bool = False,
    upcast: bool = False
):
    """Pure-PyTorch reference GroupNorm over the last dimension.

    The last dimension is split into ``num_groups`` contiguous groups and each
    group is normalized independently (mean-centered unless ``is_rms_norm``).

    Args:
        x: Input tensor whose last dimension is divisible by ``num_groups``.
        weight: Optional scale with the same size as x's last dim (may be ``None``).
        bias: Optional shift with the same size as x's last dim (may be ``None``).
        num_groups: Number of groups the last dimension is split into.
        residual: Optional tensor added to ``x`` before normalization.
        eps: Value added to the variance for numerical stability.
        is_rms_norm: If True, skip mean-centering (RMSNorm per group).
        prenorm: If True, also return the (residual-added) pre-norm input.
        upcast: If True, compute in float32.

    Returns:
        ``out`` in ``x``'s original dtype, or ``(out, x_plus_residual)`` when
        ``prenorm`` is True.
    """
    dtype = x.dtype
    if upcast:
        weight = weight.float() if weight is not None else None
        bias = bias.float() if bias is not None else None
        x = x.float()
        residual = residual.float() if residual is not None else None
    if residual is not None:
        x = (x + residual).to(x.dtype)
    residual = x
    # (..., G * D) -> (..., G, D): same split as "... (g d) -> ... g d".
    x = x.unflatten(-1, (num_groups, -1))
    if not is_rms_norm:
        x = x - x.mean(dim=-1, keepdim=True)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = x * rstd
    # weight/bias may be None (e.g. GroupNormRef(elementwise_affine=False)).
    if weight is not None:
        out = out * weight.unflatten(-1, (num_groups, -1))
    if bias is not None:
        out = out + bias.unflatten(-1, (num_groups, -1))
    out = out.flatten(-2).to(dtype)
    return out if not prenorm else (out, residual)
| |
|
| |
|
class GroupNormRef(nn.Module):
    """Reference GroupNorm module backed by ``group_norm_ref``.

    The last dimension (``hidden_size``) is split into ``num_groups`` groups that
    are normalized independently; ``forward`` always computes in float32
    (``upcast=True``) and casts back to the input dtype.
    """

    def __init__(
        self,
        num_groups: int,
        hidden_size: int,
        elementwise_affine: bool = True,
        bias: bool = False,
        eps: float = 1e-5,
        is_rms_norm: bool = False
    ) -> None:
        """Create the module.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_groups``.
        """
        super().__init__()

        if hidden_size % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        self.num_groups = num_groups
        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        self.is_rms_norm = is_rms_norm

        # Register as None first so the attributes always exist.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size))
            if bias:
                self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight to ones and bias (if any) to zeros."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        if self.is_rms_norm:
            s += f", is_rms_norm={self.is_rms_norm}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, residual=None, prenorm=False):
        """Normalize ``x`` (plus optional ``residual``); see ``group_norm_ref``."""
        return group_norm_ref(
            x,
            self.weight,
            self.bias,
            num_groups=self.num_groups,
            residual=residual,
            eps=self.eps,
            is_rms_norm=self.is_rms_norm,
            prenorm=prenorm,
            upcast=True
        )
| |
|
| |
|
# Forward norm kernel used by layer_norm_fwd when D <= 512: each program
# normalizes a tile of BT consecutive rows of the row-major (T, D) input,
# optionally adding a residual first and applying a per-group affine transform.
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
        for BT in [8, 16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['D', 'NB', 'HAS_RESIDUAL', 'STORE_RESIDUAL_OUT', 'IS_RMS_NORM'],
)
@triton.jit
def layer_norm_fwd_kernel(
    x,  # input, read as a row-major (T, D) matrix
    y,  # output, same layout as x
    w,  # weight with G * D elements (read only if HAS_WEIGHT)
    b,  # bias with G * D elements (read only if HAS_BIAS)
    res,  # residual added to x, same layout as x (read only if HAS_RESIDUAL)
    res_out,  # receives x + res, same layout as x (written only if STORE_RESIDUAL_OUT)
    mean,  # per-row mean, T entries (written only when not IS_RMS_NORM)
    rstd,  # per-row 1 / sqrt(var + eps), T entries
    eps,  # added to the variance for numerical stability
    T,  # number of rows
    G: tl.constexpr,  # number of groups; row t uses the weight/bias slice t % G
    D: tl.constexpr,  # number of valid columns per row
    BT: tl.constexpr,  # rows per program (autotuned)
    BD: tl.constexpr,  # column block width; the launcher guarantees BD >= D
    NB: tl.constexpr,  # unused in the body; only part of the autotune key
    IS_RMS_NORM: tl.constexpr,  # RMSNorm (no mean subtraction) instead of LayerNorm
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr
):
    i_t = tl.program_id(0)

    # Row indices of this tile and the group each row belongs to.
    o_t = i_t * BT + tl.arange(0, BT)
    o_g = o_t % G
    # Column offsets; m_d masks the padding columns D..BD-1.
    o_d = tl.arange(0, BD)
    m_d = o_d < D

    # Load the (BT, BD) tile in float32; boundary_check guards the ragged last
    # tile and the columns past D.
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    if HAS_RESIDUAL:
        p_res = tl.make_block_ptr(res, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
        b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        # Materialize x (+ res) in res_out's dtype.
        p_res_out = tl.make_block_ptr(res_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
        tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
    if not IS_RMS_NORM:
        # LayerNorm: per-row mean over the D valid columns.
        # NOTE(review): the sum below includes the padding columns and assumes
        # the boundary-checked load (no padding_option) yields zeros there.
        b_mean = tl.sum(b_x, axis=1) / D
        p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
        tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
        # Centered values with padding columns forced to zero before squaring.
        b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    else:
        # RMSNorm: the "variance" is the uncentered mean of squares.
        b_xbar = tl.where(m_d[None, :], b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)

    p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))

    # Per-row affine parameters: row r reads the slice starting at (r % G) * D.
    if HAS_WEIGHT:
        b_w = tl.load(w + o_g[:, None] * D + o_d[None, :], mask=m_d[None, :]).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_g[:, None] * D + o_d[None, :], mask=m_d[None, :]).to(tl.float32)
    # b_mean only exists in the LayerNorm branch; the constexpr condition keeps
    # it out of the RMSNorm specialization.
    b_x_hat = (b_x - b_mean[:, None]) * b_rstd[:, None] if not IS_RMS_NORM else b_x * b_rstd[:, None]
    b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b

    # Write output (out-of-range rows/columns are dropped by boundary_check).
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
| |
|
| |
|
# Forward norm kernel used by layer_norm_fwd when D > 512: one program per row
# (launched on a (T,) grid), with the whole row held in a single BD-wide block.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['D', 'HAS_RESIDUAL', 'STORE_RESIDUAL_OUT', 'IS_RMS_NORM'],
)
@triton.jit
def layer_norm_fwd_kernel1(
    x,  # input rows, each D contiguous elements
    y,  # output, same layout as x
    w,  # weight with G * D elements (read only if HAS_WEIGHT)
    b,  # bias with G * D elements (read only if HAS_BIAS)
    res,  # residual, same layout as x (read only if HAS_RESIDUAL)
    res_out,  # receives x + res (written only if STORE_RESIDUAL_OUT)
    mean,  # per-row mean (written only when not IS_RMS_NORM)
    rstd,  # per-row 1 / sqrt(var + eps)
    eps,  # added to the variance for numerical stability
    G: tl.constexpr,  # number of groups; row t uses the weight/bias slice t % G
    D: tl.constexpr,  # number of valid columns per row
    BD: tl.constexpr,  # column block width; the launcher guarantees BD >= D
    IS_RMS_NORM: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr
):
    i_t = tl.program_id(0)
    i_g = i_t % G

    # Advance every row-major pointer to the start of row i_t.
    x += i_t * D
    y += i_t * D
    if HAS_RESIDUAL:
        res += i_t * D
    if STORE_RESIDUAL_OUT:
        res_out += i_t * D

    o_d = tl.arange(0, BD)
    m_d = o_d < D
    # Padding columns load as 0 so they do not affect the sums below.
    b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        b_x += tl.load(res + o_d, mask=m_d, other=0.0).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        # tl.store converts the float32 value to res_out's element type.
        tl.store(res_out + o_d, b_x, mask=m_d)
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=0) / D
        tl.store(mean + i_t, b_mean)
        b_xbar = tl.where(m_d, b_x - b_mean, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    else:
        # RMSNorm: uncentered mean of squares.
        b_xbar = tl.where(m_d, b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)
    tl.store(rstd + i_t, b_rstd)

    # Affine parameters for this row's group.
    if HAS_WEIGHT:
        b_w = tl.load(w + i_g * D + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + i_g * D + o_d, mask=m_d).to(tl.float32)
    b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
    b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b

    # Write output (masked to the D valid columns).
    tl.store(y + o_d, b_y, mask=m_d)
| |
|
| |
|
# Backward norm kernel used by layer_norm_bwd when D <= 512.
# Programs are laid out group-major: program i_s handles group i_g = i_s // GS
# and the i_sg = i_s % GS-th chunk of BS group-rows within that group, processed
# BT rows at a time. Partial dw/db sums are written to row i_s of dw/db and
# reduced on the host.
@triton.heuristics({
    'RECOMPUTE_OUTPUT': lambda args: args['y'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps, num_stages=num_stages)
        for BT in [8, 16, 32, 64]
        for num_warps in [1, 2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['D', 'NB', 'HAS_DRESIDUAL', 'STORE_DRESIDUAL', 'IS_RMS_NORM'],
)
@triton.jit
def layer_norm_bwd_kernel(
    x,  # forward input (after residual add), T rows of D elements
    w,  # weight with G * D elements (read only if HAS_WEIGHT)
    b,  # bias with G * D elements (read only if HAS_BIAS)
    y,  # recomputed forward output (written only if RECOMPUTE_OUTPUT)
    dy,  # gradient w.r.t. the output, same layout as x
    dx,  # gradient w.r.t. x, same layout as x
    dw,  # partial weight gradients, one D-row per program
    db,  # partial bias gradients, one D-row per program
    dres,  # incoming residual gradient, added to dx (read only if HAS_DRESIDUAL)
    dres_in,  # copy of dx in its own dtype (written only if STORE_DRESIDUAL)
    mean,  # per-row mean saved by the forward (LayerNorm only)
    rstd,  # per-row 1 / sqrt(var + eps) saved by the forward
    T,  # total number of rows
    G: tl.constexpr,  # number of groups; row r belongs to group r % G
    D: tl.constexpr,  # number of valid columns per row
    BS: tl.constexpr,  # group-rows per program
    BT: tl.constexpr,  # group-rows per loop iteration (autotuned)
    BD: tl.constexpr,  # column block width, BD >= D
    NB: tl.constexpr,  # unused in the body; only part of the autotune key
    GS: tl.constexpr,  # programs per group
    IS_RMS_NORM: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    i_s = tl.program_id(0)
    i_g, i_sg = i_s // GS, i_s % GS

    o_d = tl.arange(0, BD)
    m_d = o_d < D
    # Every row handled by this program uses the same group's affine params.
    if HAS_WEIGHT:
        b_w = tl.load(w + i_g * D + o_d, mask=m_d).to(tl.float32)
        b_dw = tl.zeros((BT, BD), dtype=tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + i_g * D + o_d, mask=m_d, other=0.0).to(tl.float32)
        b_db = tl.zeros((BT, BD), dtype=tl.float32)

    # T is rebound to the exclusive end of this program's group-row range; the
    # block pointers below use it as their row bound.
    T = min(i_sg * BS + BS, T // G)
    for i_t in range(i_sg * BS, T, BT):
        # View group i_g as a (group-rows, D) matrix: base offset i_g * D, row stride G * D.
        p_x = tl.make_block_ptr(x + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0))
        p_dy = tl.make_block_ptr(dy + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0))
        p_dx = tl.make_block_ptr(dx + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0))
        # Load data to SRAM
        b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
        b_dy = tl.load(p_dy, boundary_check=(0, 1)).to(tl.float32)

        # Per-row statistics for the same rows (stride G in the flat arrays).
        if not IS_RMS_NORM:
            p_mean = tl.make_block_ptr(mean + i_g, (T,), (G,), (i_t,), (BT,), (0,))
            b_mean = tl.load(p_mean, boundary_check=(0,))
        p_rstd = tl.make_block_ptr(rstd + i_g, (T,), (G,), (i_t,), (BT,), (0,))
        b_rstd = tl.load(p_rstd, boundary_check=(0,))
        # Compute dx
        b_xhat = (b_x - b_mean[:, None]) * b_rstd[:, None] if not IS_RMS_NORM else b_x * b_rstd[:, None]
        b_xhat = tl.where(m_d[None, :], b_xhat, 0.0)

        b_y = b_xhat * b_w[None, :] if HAS_WEIGHT else b_xhat
        if HAS_BIAS:
            b_y = b_y + b_b[None, :]
        if RECOMPUTE_OUTPUT:
            p_y = tl.make_block_ptr(y + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0))
            tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))

        # Gradient w.r.t. xhat: dy scaled by the weight when present.
        b_wdy = b_dy

        # m_t excludes rows past this program's range from the dw/db sums.
        if HAS_WEIGHT or HAS_BIAS:
            m_t = (i_t + tl.arange(0, BT)) < T
        if HAS_WEIGHT:
            b_wdy = b_dy * b_w
            b_dw += tl.where(m_t[:, None], b_dy * b_xhat, 0.0)
        if HAS_BIAS:
            b_db += tl.where(m_t[:, None], b_dy, 0.0)
        if not IS_RMS_NORM:
            # LayerNorm: dx = rstd * (wdy - xhat * mean(xhat * wdy) - mean(wdy))
            b_c1 = tl.sum(b_xhat * b_wdy, axis=1) / D
            b_c2 = tl.sum(b_wdy, axis=1) / D
            b_dx = (b_wdy - (b_xhat * b_c1[:, None] + b_c2[:, None])) * b_rstd[:, None]
        else:
            # RMSNorm: dx = rstd * (wdy - xhat * mean(xhat * wdy))
            b_c1 = tl.sum(b_xhat * b_wdy, axis=1) / D
            b_dx = (b_wdy - b_xhat * b_c1[:, None]) * b_rstd[:, None]
        if HAS_DRESIDUAL:
            p_dres = tl.make_block_ptr(dres + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0))
            b_dres = tl.load(p_dres, boundary_check=(0, 1)).to(tl.float32)
            b_dx += b_dres
        # Write dx
        if STORE_DRESIDUAL:
            p_dres_in = tl.make_block_ptr(dres_in + i_g * D, (T, D), (G*D, 1), (i_t, 0), (BT, BD), (1, 0))
            tl.store(p_dres_in, b_dx.to(p_dres_in.dtype.element_ty), boundary_check=(0, 1))

        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))

    # Collapse the BT accumulator rows into this program's partial dw/db row.
    if HAS_WEIGHT:
        tl.store(dw + i_s * D + o_d, tl.sum(b_dw, axis=0), mask=m_d)
    if HAS_BIAS:
        tl.store(db + i_s * D + o_d, tl.sum(b_db, axis=0), mask=m_d)
| |
|
| |
|
# Backward norm kernel used by layer_norm_bwd when D > 512: same program layout
# as layer_norm_bwd_kernel (group i_g = i_s // GS, chunk i_sg = i_s % GS of BS
# group-rows), but it walks the flat rows of group i_g one at a time.
@triton.heuristics({
    'RECOMPUTE_OUTPUT': lambda args: args['y'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['D', 'HAS_DRESIDUAL', 'STORE_DRESIDUAL', 'IS_RMS_NORM'],
)
@triton.jit
def layer_norm_bwd_kernel1(
    x,  # forward input (after residual add), T rows of D elements
    w,  # weight with G * D elements (read only if HAS_WEIGHT)
    b,  # bias with G * D elements (read only if RECOMPUTE_OUTPUT and HAS_BIAS)
    y,  # recomputed forward output (written only if RECOMPUTE_OUTPUT)
    dy,  # gradient w.r.t. the output
    dx,  # gradient w.r.t. x
    dw,  # partial weight gradients, one D-row per program
    db,  # partial bias gradients, one D-row per program
    dres,  # incoming residual gradient (read only if HAS_DRESIDUAL)
    dres_in,  # copy of dx (written only if STORE_DRESIDUAL)
    mean,  # per-row mean saved by the forward (LayerNorm only)
    rstd,  # per-row 1 / sqrt(var + eps) saved by the forward
    T,  # total number of rows
    G: tl.constexpr,  # number of groups; row r belongs to group r % G
    D: tl.constexpr,  # number of valid columns per row
    BS: tl.constexpr,  # group-rows per program
    BD: tl.constexpr,  # column block width, BD >= D
    GS: tl.constexpr,  # programs per group
    IS_RMS_NORM: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    i_s = tl.program_id(0)
    i_g, i_sg = i_s // GS, i_s % GS

    o_d = tl.arange(0, BD)
    mask = o_d < D

    if HAS_WEIGHT:
        b_w = tl.load(w + i_g * D + o_d, mask=mask).to(tl.float32)
        b_dw = tl.zeros((BD,), dtype=tl.float32)
    # The bias value is only needed to recompute y.
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b_b = tl.load(b + i_g * D + o_d, mask=mask, other=0.0).to(tl.float32)
    if HAS_BIAS:
        b_db = tl.zeros((BD,), dtype=tl.float32)

    # Flat row indices of group i_g in this program's chunk: step G so every
    # visited row r satisfies r % G == i_g; clamp to the total row count T.
    for i_t in range(i_sg * BS * G + i_g, min((i_sg * BS + BS) * G + i_g, T), G):
        b_x = tl.load(x + i_t * D + o_d, mask=mask, other=0).to(tl.float32)
        b_dy = tl.load(dy + i_t * D + o_d, mask=mask, other=0).to(tl.float32)

        if not IS_RMS_NORM:
            b_mean = tl.load(mean + i_t)
        b_rstd = tl.load(rstd + i_t)
        # Compute dx
        b_xhat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
        b_xhat = tl.where(mask, b_xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            b_y = b_xhat * b_w if HAS_WEIGHT else b_xhat
            if HAS_BIAS:
                b_y = b_y + b_b
            tl.store(y + i_t * D + o_d, b_y, mask=mask)
        # Gradient w.r.t. xhat: dy scaled by the weight when present.
        b_wdy = b_dy
        if HAS_WEIGHT:
            b_wdy = b_dy * b_w
            b_dw += b_dy * b_xhat
        if HAS_BIAS:
            b_db += b_dy
        if not IS_RMS_NORM:
            # LayerNorm: dx = rstd * (wdy - xhat * mean(xhat * wdy) - mean(wdy))
            b_c1 = tl.sum(b_xhat * b_wdy, axis=0) / D
            b_c2 = tl.sum(b_wdy, axis=0) / D
            b_dx = (b_wdy - (b_xhat * b_c1 + b_c2)) * b_rstd
        else:
            # RMSNorm: dx = rstd * (wdy - xhat * mean(xhat * wdy))
            b_c1 = tl.sum(b_xhat * b_wdy, axis=0) / D
            b_dx = (b_wdy - b_xhat * b_c1) * b_rstd
        if HAS_DRESIDUAL:
            b_dres = tl.load(dres + i_t * D + o_d, mask=mask, other=0).to(tl.float32)
            b_dx += b_dres
        # Write dx, cast once to dx's dtype with round-to-nearest-even.
        b_dx = tl.cast(b_dx, dtype=dx.dtype.element_ty, fp_downcast_rounding='rtne')
        if STORE_DRESIDUAL:
            tl.store(dres_in + i_t * D + o_d, b_dx, mask=mask)
        tl.store(dx + i_t * D + o_d, b_dx, mask=mask)

    # This program's partial dw/db; the host sums them per group.
    if HAS_WEIGHT:
        tl.store(dw + i_s * D + o_d, b_dw, mask=mask)
    if HAS_BIAS:
        tl.store(db + i_s * D + o_d, b_db, mask=mask)
| |
|
| |
|
def layer_norm_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
    residual: torch.Tensor = None,
    out_dtype: torch.dtype = None,
    residual_dtype: torch.dtype = None,
    is_rms_norm: bool = False,
    num_groups: int = 1,
):
    """Run the Triton forward norm kernel on a 2-D ``(T, D)`` input.

    Args:
        x: Input of shape ``(T, D)``; every row is normalized on its own.
        weight: Optional scale of shape ``(num_groups * D,)``; row ``t`` uses
            the slice belonging to group ``t % num_groups``.
        bias: Optional shift with the same shape and layout as ``weight``.
        eps: Value added to the variance for numerical stability.
        residual: Optional ``(T, D)`` tensor added to ``x`` before normalizing;
            when given, its dtype overrides ``residual_dtype``.
        out_dtype: dtype of the output (defaults to ``x.dtype``).
        residual_dtype: dtype used to materialize ``x + residual``.
        is_rms_norm: Use RMSNorm instead of LayerNorm.
        num_groups: Number of groups ``G`` sharing the affine parameters.

    Returns:
        ``(y, mean, rstd, residual_out)``. ``mean`` is ``None`` for RMSNorm and
        ``residual_out`` is ``x`` itself when no separate buffer was needed.

    Raises:
        RuntimeError: If one row does not fit in the 64KB fused block.
    """
    if residual is not None:
        residual_dtype = residual.dtype
    T, D = x.shape
    G = num_groups
    if residual is not None:
        assert residual.shape == (T, D)
    for param in (weight, bias):
        if param is not None:
            assert param.shape == (G * D,)

    y = torch.empty_like(x, dtype=out_dtype if out_dtype is not None else x.dtype)
    # A dedicated residual buffer is needed whenever there is a residual or the
    # requested residual dtype differs from x's dtype.
    needs_res_out = residual is not None or (residual_dtype is not None and residual_dtype != x.dtype)
    res_out = torch.empty(T, D, device=x.device, dtype=residual_dtype) if needs_res_out else None
    mean = None if is_rms_norm else torch.empty((T,), dtype=torch.float, device=x.device)
    rstd = torch.empty((T,), dtype=torch.float, device=x.device)

    # A whole row must fit into one block of at most 64KB.
    block_d = min(65536 // x.element_size(), triton.next_power_of_2(D))
    if D > block_d:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    launch_args = (x, y, weight, bias, residual, res_out, mean, rstd, eps)
    launch_flags = dict(
        IS_RMS_NORM=is_rms_norm,
        HAS_RESIDUAL=residual is not None,
        STORE_RESIDUAL_OUT=res_out is not None,
        HAS_WEIGHT=weight is not None,
        HAS_BIAS=bias is not None,
    )
    if D <= 512:
        # Narrow rows: each program handles a tile of BT rows.
        layer_norm_fwd_kernel[lambda meta: (triton.cdiv(T, meta['BT']),)](
            *launch_args,
            T=T,
            G=G,
            D=D,
            BD=block_d,
            NB=triton.cdiv(T, 2048),
            **launch_flags,
        )
    else:
        # Wide rows: one program per row.
        layer_norm_fwd_kernel1[(T,)](
            *launch_args,
            G=G,
            D=D,
            BD=block_d,
            **launch_flags,
        )

    return y, mean, rstd, (res_out if res_out is not None else x)
| |
|
| |
|
def layer_norm_bwd(
    dy: torch.Tensor,
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    mean: torch.Tensor = None,
    rstd: torch.Tensor = None,
    dres: torch.Tensor = None,
    has_residual: bool = False,
    is_rms_norm: bool = False,
    x_dtype: torch.dtype = None,
    recompute_output: bool = False,
    num_groups: int = 1,
):
    """Run the Triton backward norm kernel on a 2-D ``(T, D)`` input.

    Args:
        dy: Gradient of the output, shape ``(T, D)``.
        x: Forward input after the residual add, shape ``(T, D)``.
        weight: Optional scale of shape ``(num_groups * D,)``.
        bias: Optional shift of shape ``(num_groups * D,)``.
        mean: Per-row means from the forward pass (``None`` for RMSNorm).
        rstd: Per-row reciprocal standard deviations from the forward pass.
        dres: Optional extra gradient (``(T, D)``) added to ``dx``.
        has_residual: Whether the forward pass had a residual input.
        is_rms_norm: Use the RMSNorm gradient formula.
        x_dtype: dtype of ``dx`` (defaults to ``x.dtype``).
        recompute_output: Also recompute and return the forward output ``y``.
        num_groups: Number of groups ``G`` sharing the affine parameters.

    Returns:
        ``(dx, dw, db, dresidual_in)``, plus ``y`` when ``recompute_output``.

    Raises:
        RuntimeError: If one row does not fit in the 64KB fused block.
    """
    T, D = x.shape
    G = num_groups
    assert dy.shape == (T, D)
    if dres is not None:
        assert dres.shape == (T, D)
    for param in (weight, bias):
        if param is not None:
            assert param.shape == (G * D,)

    if x_dtype is None:
        dx = torch.empty_like(x)
    else:
        dx = torch.empty(T, D, dtype=x_dtype, device=x.device)
    # dx doubles as the residual gradient unless the dtypes differ.
    dres_in = torch.empty_like(x) if (has_residual and dx.dtype != x.dtype) else None
    y = torch.empty(T, D, dtype=dy.dtype, device=dy.device) if recompute_output else None

    block_d = min(65536 // x.element_size(), triton.next_power_of_2(D))
    if D > block_d:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    # NS programs in total (a multiple of G), GS of them per group, each covering
    # BS rows of its group so that every program reduces dw/db for one group only.
    num_programs = triton.cdiv(get_multiprocessor_count(x.device.index), G) * G
    rows_per_program = triton.cdiv(T, num_programs)
    programs_per_group = num_programs // G

    dw = None if weight is None else torch.empty((num_programs, D), dtype=torch.float, device=weight.device)
    db = None if bias is None else torch.empty((num_programs, D), dtype=torch.float, device=bias.device)

    launch_args = (x, weight, bias, y, dy, dx, dw, db, dres, dres_in, mean, rstd)
    launch_flags = dict(
        IS_RMS_NORM=is_rms_norm,
        HAS_DRESIDUAL=dres is not None,
        STORE_DRESIDUAL=dres_in is not None,
        HAS_WEIGHT=weight is not None,
        HAS_BIAS=bias is not None,
    )
    if D <= 512:
        layer_norm_bwd_kernel[(num_programs,)](
            *launch_args,
            T=T,
            G=G,
            D=D,
            BS=rows_per_program,
            BD=block_d,
            NB=triton.cdiv(T, 2048),
            GS=programs_per_group,
            **launch_flags,
        )
    else:
        layer_norm_bwd_kernel1[(num_programs,)](
            *launch_args,
            T=T,
            G=G,
            D=D,
            BS=rows_per_program,
            BD=block_d,
            GS=programs_per_group,
            **launch_flags,
        )

    # Sum the per-program partials: (NS, D) -> (G, GS, D) -> (G, D) -> (G * D,).
    if weight is not None:
        dw = dw.view(G, -1, D).sum(1).to(weight).view_as(weight)
    if bias is not None:
        db = db.view(G, -1, D).sum(1).to(bias).view_as(bias)

    if has_residual and dx.dtype == x.dtype:
        dres_in = dx
    if recompute_output:
        return dx, dw, db, dres_in, y
    return dx, dw, db, dres_in
| |
|
| |
|
class LayerNormFunction(torch.autograd.Function):
    """Autograd bridge to the Triton layer/RMS/group norm kernels.

    The last input dimension is split into ``num_groups`` groups, which are
    folded into the row dimension so that the kernels always see a 2-D
    ``(rows, hidden // num_groups)`` problem.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual: torch.Tensor = None,
        eps: float = 1e-5,
        prenorm: bool = False,
        residual_in_fp32: bool = False,
        is_rms_norm: bool = False,
        num_groups: int = 1
    ):
        original_shape = x.shape
        hidden = original_shape[-1]
        if hidden % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        # Each row of the 2-D view is one group of one token.
        x = x.reshape(-1, (hidden // num_groups))
        if residual is None:
            residual_dtype = torch.float32 if residual_in_fp32 else None
        else:
            assert residual.shape == original_shape
            residual = residual.reshape_as(x)
            residual_dtype = residual.dtype

        y, mean, rstd, res_out = layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            num_groups=num_groups
        )
        # res_out is the (residual-added) input that the backward pass needs.
        ctx.save_for_backward(res_out, weight, bias, mean, rstd)
        ctx.x_shape_og = original_shape
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.num_groups = num_groups
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype

        y = y.reshape(original_shape)
        if prenorm:
            return y, res_out.reshape(original_shape)
        return y

    @staticmethod
    @input_guard
    def backward(ctx, dy, *args):
        x, weight, bias, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups))
        assert dy.shape == x.shape

        dresidual = None
        if ctx.prenorm:
            # Gradient arriving through the second (pre-norm residual) output.
            dresidual = args[0].reshape(-1, x.shape[-1])
            assert dresidual.shape == x.shape

        dx, dw, db, dresidual_in = layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            mean,
            rstd,
            dresidual,
            ctx.has_residual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            num_groups=ctx.num_groups
        )
        dx = dx.reshape(ctx.x_shape_og)
        dresidual_in = dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None
        # eps, prenorm, residual_in_fp32, is_rms_norm and num_groups get no gradient.
        return (dx, dw, db, dresidual_in) + (None,) * 5
| |
|
| |
|
def layer_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    is_rms_norm: bool = False
):
    """Fused LayerNorm (or RMSNorm) over the last dimension.

    This is ``group_norm`` with a single group. Returns ``y``, or
    ``(y, residual_out)`` when ``prenorm`` is True.
    """
    return group_norm(
        x,
        weight,
        bias,
        residual=residual,
        eps=eps,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        is_rms_norm=is_rms_norm,
        num_groups=1,
    )
| |
|
| |
|
def group_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    is_rms_norm: bool = False,
    num_groups: int = 1
):
    """Fused GroupNorm over the last dimension, split into ``num_groups`` groups.

    ``residual`` (if any) is added before normalizing. Returns ``y``, or
    ``(y, residual_out)`` when ``prenorm`` is True.
    """
    apply_args = (x, weight, bias, residual, eps, prenorm, residual_in_fp32, is_rms_norm, num_groups)
    return LayerNormFunction.apply(*apply_args)
| |
|
| |
|
def rms_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    prenorm: bool = False,
    residual_in_fp32: bool = False
):
    """Fused RMSNorm over the last dimension (single group, no mean centering).

    Returns ``y``, or ``(y, residual_out)`` when ``prenorm`` is True.
    """
    return group_norm(
        x,
        weight,
        bias,
        residual=residual,
        eps=eps,
        prenorm=prenorm,
        residual_in_fp32=residual_in_fp32,
        is_rms_norm=True,
        num_groups=1,
    )
| |
|
| |
|
def layer_norm_linear(
    x: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    linear_weight: torch.Tensor,
    linear_bias: torch.Tensor,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    is_rms_norm: bool = False,
    num_groups: int = 1
):
    """Fused (group/RMS/layer) norm followed by a linear projection.

    Thin wrapper around ``LayerNormLinearFunction.apply``; arguments are
    forwarded positionally in the order that function expects.
    """
    apply_args = (
        x, norm_weight, norm_bias, linear_weight, linear_bias,
        residual, eps, prenorm, residual_in_fp32, is_rms_norm, num_groups,
    )
    return LayerNormLinearFunction.apply(*apply_args)
| |
|
| |
|
def rms_norm_linear(
    x: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    linear_weight: torch.Tensor,
    linear_bias: torch.Tensor,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    prenorm: bool = False,
    residual_in_fp32: bool = False
):
    """Fused RMSNorm (single group) followed by a linear projection."""
    return LayerNormLinearFunction.apply(
        x, norm_weight, norm_bias, linear_weight, linear_bias,
        residual, eps, prenorm, residual_in_fp32,
        True,  # is_rms_norm
        1,  # num_groups
    )
| |
|
| |
|
def group_norm_linear(
    x: torch.Tensor,
    norm_weight: torch.Tensor,
    norm_bias: torch.Tensor,
    linear_weight: torch.Tensor,
    linear_bias: torch.Tensor,
    residual: torch.Tensor = None,
    eps: float = 1e-5,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    is_rms_norm: bool = False,
    num_groups: int = 1
):
    """Fused GroupNorm (``num_groups`` groups) followed by a linear projection."""
    return LayerNormLinearFunction.apply(
        x, norm_weight, norm_bias, linear_weight, linear_bias,
        residual, eps, prenorm, residual_in_fp32, is_rms_norm, num_groups,
    )
| |
|
| |
|
class LayerNorm(nn.Module):
    """LayerNorm module over the last dimension, backed by the fused ``layer_norm``."""

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        bias: bool = False,
        eps: float = 1e-5
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps

        # Register as None first so the attributes always exist.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size))
            if bias:
                self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight to ones and bias (if any) to zeros."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.hidden_size}"
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        """Normalize ``x`` (plus optional ``residual``); see ``layer_norm``."""
        return layer_norm(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32
        )
| |
|
| |
|
class GroupNorm(nn.Module):
    """GroupNorm module backed by the fused ``group_norm``.

    The last dimension (``hidden_size``) is split into ``num_groups`` groups that
    are normalized independently (RMS-style when ``is_rms_norm``).
    """

    def __init__(
        self,
        num_groups: int,
        hidden_size: int,
        elementwise_affine: bool = True,
        bias: bool = False,
        eps: float = 1e-5,
        is_rms_norm: bool = False
    ) -> None:
        """Create the module.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_groups``.
        """
        super().__init__()

        if hidden_size % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        self.num_groups = num_groups
        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        self.is_rms_norm = is_rms_norm

        # Register as None first so the attributes always exist.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size))
            if bias:
                self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight to ones and bias (if any) to zeros."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        if self.is_rms_norm:
            s += f", is_rms_norm={self.is_rms_norm}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        """Normalize ``x`` (plus optional ``residual``); see ``group_norm``."""
        return group_norm(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=self.is_rms_norm,
            num_groups=self.num_groups
        )
| |
|
| |
|
class RMSNorm(nn.Module):
    """RMSNorm module over the last dimension, backed by the fused ``rms_norm``."""

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        bias: bool = False,
        eps: float = 1e-5
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps

        # Register as None first so the attributes always exist.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size))
            if bias:
                self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight to ones and bias (if any) to zeros."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            if self.bias is not None:
                nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.hidden_size}"
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        """Normalize ``x`` (plus optional ``residual``); see ``rms_norm``."""
        return rms_norm(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )
| |
|
| |
|
class LayerNormLinearFunction(torch.autograd.Function):
    """Autograd function: (group / RMS) layer norm followed by a linear projection.

    Forward normalises ``x`` with ``layer_norm_fwd`` and feeds the result ``y``
    to ``F.linear``. ``y`` itself is not saved for backward; backward asks
    ``layer_norm_bwd`` to recompute it (``recompute_output=True``) and uses it to
    form the linear-weight gradient.
    """

    @staticmethod
    @input_guard
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-5,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        num_groups=1
    ):
        """Normalise ``x`` (optionally together with ``residual``) and project it.

        Args:
            x: Input tensor; ``x.shape[-1]`` must be divisible by ``num_groups``.
            norm_weight, norm_bias: Normalisation parameters passed to
                ``layer_norm_fwd`` (presumably either may be ``None`` -- confirm).
            linear_weight, linear_bias: Weight / optional bias for ``F.linear``.
                They are cast to the autocast dtype when autocast is enabled,
                otherwise to the dtype of the normalised output.
            residual: Optional tensor with exactly ``x``'s shape, passed to
                ``layer_norm_fwd``.
            eps: Epsilon passed to ``layer_norm_fwd``.
            prenorm: If ``True``, also return ``res_out`` from ``layer_norm_fwd``.
            residual_in_fp32: When no ``residual`` is given, request a float32
                residual dtype from ``layer_norm_fwd``.
            is_rms_norm: Forwarded to ``layer_norm_fwd`` / ``layer_norm_bwd``.
            num_groups: Number of equal groups the last dimension is split into.

        Returns:
            ``out = F.linear(y, linear_weight, linear_bias)``, or
            ``(out, res_out)`` with ``res_out`` reshaped to ``x``'s original
            shape when ``prenorm`` is set.

        Raises:
            ValueError: If ``x.shape[-1]`` is not divisible by ``num_groups``.
        """
        x_shape_og = x.shape

        if x.shape[-1] % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        # Flatten to 2-D so each row is one contiguous group of
        # x.shape[-1] // num_groups channels.
        x = x.reshape(-1, (x.shape[-1] // num_groups))
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape_as(x)
        # Keep the given residual's dtype; otherwise fp32 if requested; otherwise
        # None (left to layer_norm_fwd).
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, mean, rstd, res_out = layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            # Under autocast, request the normalised output in the autocast dtype.
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            num_groups=num_groups
        )
        y = y.reshape(x_shape_og)
        # Run the projection in the autocast dtype if active, else in y's dtype.
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)

        # y is deliberately not saved; backward recomputes it (recompute_output=True).
        # res_out is saved in the first slot and serves as backward's `x`.
        ctx.save_for_backward(res_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps  # not read by backward below
        ctx.is_rms_norm = is_rms_norm
        ctx.num_groups = num_groups
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, res_out.reshape(x_shape_og))

    @staticmethod
    @input_guard
    def backward(ctx, dout, *args):
        """Back-propagate through the linear projection and the normalisation.

        Args:
            dout: Gradient w.r.t. ``out``.
            *args: When ``prenorm`` was set, ``args[0]`` is the gradient w.r.t.
                the returned ``res_out``.

        Returns:
            One entry per ``forward`` input, in order: gradients for ``x``,
            ``norm_weight``, ``norm_bias``, ``linear_weight``, ``linear_bias``
            (``None`` if it was ``None``), ``residual`` (``None`` if none was
            given), then ``None`` for ``eps``, ``prenorm``, ``residual_in_fp32``,
            ``is_rms_norm`` and ``num_groups``.
        """
        # NOTE: `x` here is the res_out tensor saved by forward.
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        # Gradient w.r.t. the normalised output y: dout @ W.
        dy = F.linear(dout, linear_weight.t())
        # Match forward's grouped 2-D layout.
        dy = dy.reshape(-1, (dy.shape[-1] // ctx.num_groups))
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, x.shape[-1])
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, y = layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            mean,
            rstd,
            dresidual,
            ctx.has_residual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
            num_groups=ctx.num_groups
        )
        # dL/dW = dout^T @ y, with y (recomputed above) flattened to in_features columns.
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y.view(-1, linear_weight.shape[-1]))
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
            None
        )
| |
|
| |
|
class LayerNormLinear(nn.Module):
    """Layer norm fused with a caller-supplied linear projection (``layer_norm_linear``).

    The module owns only the normalisation parameters; the linear ``weight`` and
    ``bias`` are passed to :meth:`forward` on every call.

    Args:
        hidden_size: Length of the normalisation ``weight`` / ``bias``.
        elementwise_affine: If ``True``, allocate a learnable ``weight``
            initialised to ones; otherwise ``weight`` is ``None``.
        bias: If ``True``, allocate a learnable ``bias`` initialised to zeros;
            otherwise ``bias`` is ``None``.
        eps: Epsilon forwarded to ``layer_norm_linear``.
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        bias: bool = False,
        eps: float = 1e-5
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps

        # Register both slots up front so ``self.weight`` / ``self.bias`` always
        # exist (as ``None``) even when the corresponding option is disabled.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set ``weight`` to ones (if affine) and ``bias`` to zeros (if present)."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.hidden_size}"
        # Only the non-default affine setting is shown.
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
        """Layer-normalise ``x`` and project it with the given linear ``weight``/``bias``.

        Args:
            x: Input tensor.
            weight: Linear-layer weight (used as ``linear_weight``).
            bias: Linear-layer bias or ``None`` (used as ``linear_bias``).
            residual, prenorm, residual_in_fp32: Forwarded to ``layer_norm_linear``.

        Returns:
            Whatever ``layer_norm_linear`` returns (run with ``is_rms_norm=False``).
        """
        return layer_norm_linear(
            x=x,
            norm_weight=self.weight,
            norm_bias=self.bias,
            linear_weight=weight,
            linear_bias=bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=False
        )
| |
|
| |
|
class GroupNormLinear(nn.Module):
    """Group norm fused with a caller-supplied linear projection (``layer_norm_linear``).

    The module owns only the normalisation parameters; the linear ``weight`` and
    ``bias`` are passed to :meth:`forward` on every call.

    Args:
        num_groups: Number of equal groups ``hidden_size`` is split into; must be
            a positive divisor of ``hidden_size``.
        hidden_size: Length of the normalisation ``weight`` / ``bias``.
        elementwise_affine: If ``True``, allocate a learnable ``weight``
            initialised to ones; otherwise ``weight`` is ``None``.
        bias: If ``True``, allocate a learnable ``bias`` initialised to zeros;
            otherwise ``bias`` is ``None``.
        eps: Epsilon forwarded to ``layer_norm_linear``.
        is_rms_norm: Forwarded to ``layer_norm_linear``.

    Raises:
        ValueError: If ``num_groups`` is not positive, or does not divide
            ``hidden_size``.
    """

    def __init__(
        self,
        num_groups: int,
        hidden_size: int,
        elementwise_affine: bool = True,
        bias: bool = False,
        eps: float = 1e-5,
        is_rms_norm: bool = False
    ) -> None:
        super().__init__()

        # Checked first: 0 would otherwise surface as a bare ZeroDivisionError
        # below, and a negative count could pass the divisibility test.
        if num_groups <= 0:
            raise ValueError('num_groups must be a positive integer')
        if hidden_size % num_groups != 0:
            raise ValueError('num_channels must be divisible by num_groups')

        self.num_groups = num_groups
        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        self.is_rms_norm = is_rms_norm

        # Register both slots up front so ``self.weight`` / ``self.bias`` always
        # exist (as ``None``) even when the corresponding option is disabled.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set ``weight`` to ones (if affine) and ``bias`` to zeros (if present)."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.num_groups}, {self.hidden_size}"
        # Only non-default flags are shown.
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        if self.is_rms_norm:
            s += f", is_rms_norm={self.is_rms_norm}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
        """Group-normalise ``x`` and project it with the given linear ``weight``/``bias``.

        Args:
            x: Input tensor.
            weight: Linear-layer weight (used as ``linear_weight``).
            bias: Linear-layer bias or ``None`` (used as ``linear_bias``).
            residual, prenorm, residual_in_fp32: Forwarded to ``layer_norm_linear``.

        Returns:
            Whatever ``layer_norm_linear`` returns with this module's
            ``num_groups`` and ``is_rms_norm`` settings.
        """
        return layer_norm_linear(
            x=x,
            norm_weight=self.weight,
            norm_bias=self.bias,
            linear_weight=weight,
            linear_bias=bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=self.is_rms_norm,
            num_groups=self.num_groups
        )
| |
|
| |
|
class RMSNormLinear(nn.Module):
    """RMS norm fused with a caller-supplied linear projection (``layer_norm_linear``).

    The module owns only the normalisation parameters; the linear ``weight`` and
    ``bias`` are passed to :meth:`forward` on every call.

    Args:
        hidden_size: Length of the normalisation ``weight`` / ``bias``.
        elementwise_affine: If ``True``, allocate a learnable ``weight``
            initialised to ones; otherwise ``weight`` is ``None``.
        bias: If ``True``, allocate a learnable ``bias`` initialised to zeros;
            otherwise ``bias`` is ``None``.
        eps: Epsilon forwarded to ``layer_norm_linear``.
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        bias: bool = False,
        eps: float = 1e-5
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps

        # Register both slots up front so ``self.weight`` / ``self.bias`` always
        # exist (as ``None``) even when the corresponding option is disabled.
        self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set ``weight`` to ones (if affine) and ``bias`` to zeros (if present)."""
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def __repr__(self) -> str:
        s = f"{self.__class__.__name__}({self.hidden_size}"
        # Only the non-default affine setting is shown.
        if not self.elementwise_affine:
            s += f", elementwise_affine={self.elementwise_affine}"
        s += f", eps={self.eps}"
        s += ")"
        return s

    def forward(self, x, weight, bias, residual=None, prenorm=False, residual_in_fp32=False):
        """RMS-normalise ``x`` and project it with the given linear ``weight``/``bias``.

        Args:
            x: Input tensor.
            weight: Linear-layer weight (used as ``linear_weight``).
            bias: Linear-layer bias or ``None`` (used as ``linear_bias``).
            residual, prenorm, residual_in_fp32: Forwarded to ``layer_norm_linear``.

        Returns:
            Whatever ``layer_norm_linear`` returns (run with ``is_rms_norm=True``).
        """
        return layer_norm_linear(
            x=x,
            norm_weight=self.weight,
            norm_bias=self.bias,
            linear_weight=weight,
            linear_bias=bias,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
            is_rms_norm=True
        )
| |
|
| |
|
| | class NormParallel(ParallelStyle): |
| |
|
| | def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False): |
| | super().__init__() |
| | self.sequence_sharding = (Shard(sequence_dim),) |
| | self.use_local_output = use_local_output |
| |
|
| | def _replicate_module_fn( |
| | self, name: str, module: nn.Module, device_mesh: DeviceMesh |
| | ): |
| | for p_name, param in module.named_parameters(): |
| | |
| | |
| | replicated_param = torch.nn.Parameter( |
| | DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) |
| | ) |
| | module.register_parameter(p_name, replicated_param) |
| |
|
| | @staticmethod |
| | def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): |
| | input_tensor = inputs[0] |
| | if isinstance(input_tensor, DTensor): |
| | |
| | if input_tensor.placements != sequence_sharding: |
| | input_tensor = input_tensor.redistribute( |
| | placements=sequence_sharding, async_op=True |
| | ) |
| | return input_tensor |
| | elif isinstance(input_tensor, torch.Tensor): |
| | |
| | return DTensor.from_local( |
| | input_tensor, device_mesh, sequence_sharding, run_check=False |
| | ) |
| | else: |
| | raise ValueError( |
| | f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" |
| | ) |
| |
|
| | @staticmethod |
| | def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): |
| | return outputs.to_local() if use_local_output else outputs |
| |
|
| | def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| | return distribute_module( |
| | module, |
| | device_mesh, |
| | self._replicate_module_fn, |
| | partial(self._prepare_input_fn, self.sequence_sharding), |
| | partial(self._prepare_output_fn, self.use_local_output), |
| | ) |
| |
|