| """ 'Fast' Normalization Functions |
| |
| For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. |
| |
| Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) |
| |
| Hacked together by / Copyright 2022 Ross Wightman |
| """ |
| from typing import List, Optional |
|
|
| import torch |
| from torch.nn import functional as F |

try:
    from apex.normalization.fused_layer_norm import fused_layer_norm_affine
    has_apex = True
except ImportError:
    has_apex = False

try:
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm
    has_apex_rmsnorm = True
except ImportError:
    has_apex_rmsnorm = False

has_torch_rms_norm = hasattr(F, 'rms_norm')

# module-level switch for the fast (i.e. lower precision) norm paths, toggled via set_fast_norm()
_USE_FAST_NORM = False


def get_autocast_dtype(device: str = 'cuda'):
    try:
        return torch.get_autocast_dtype(device)
    except (AttributeError, TypeError):
        # dispatch to older, device-specific fns; only cpu / cuda devices are covered here
        if device == 'cpu':
            return torch.get_autocast_cpu_dtype()
        else:
            assert device == 'cuda'
            return torch.get_autocast_gpu_dtype()


def is_autocast_enabled(device: str = 'cuda'):
    try:
        return torch.is_autocast_enabled(device)
    except TypeError:
        # dispatch to older, device-specific fns; only cpu / cuda devices are covered here
        if device == 'cpu':
            return torch.is_autocast_cpu_enabled()
        else:
            assert device == 'cuda'
            return torch.is_autocast_enabled()  # older pytorch: no device arg, checks cuda autocast state


def is_fast_norm():
    return _USE_FAST_NORM


def set_fast_norm(enable=True):
    global _USE_FAST_NORM
    _USE_FAST_NORM = enable
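
# Minimal usage sketch (illustrative only, not executed here): flip the global default before
# constructing models so that norm layers which consult is_fast_norm() at init time pick the
# low-precision path, e.g.
#
#   set_fast_norm(True)
#   assert is_fast_norm()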


def fast_group_norm(
    x: torch.Tensor,
    num_groups: int,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # cannot use is_autocast_enabled / autocast helpers within torchscript, use standard path
        return F.group_norm(x, num_groups, weight, bias, eps)

    if is_autocast_enabled(x.device.type):
        # native AMP would run group_norm in float32; instead, cast the inputs to the active
        # autocast dtype and run the norm with autocast disabled below so it stays in low precision
        dt = get_autocast_dtype(x.device.type)
        x = x.to(dt)
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None

    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        return F.group_norm(x, num_groups, weight, bias, eps)
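
# Sketch of how fast_group_norm is typically consumed (hypothetical wrapper, not part of this
# module): a GroupNorm layer captures is_fast_norm() at construction and dispatches in forward(), e.g.
#
#   class GroupNormFast(torch.nn.GroupNorm):  # hypothetical name
#       def __init__(self, num_groups, num_channels, eps=1e-5, affine=True):
#           super().__init__(num_groups, num_channels, eps=eps, affine=affine)
#           self._fast_norm = is_fast_norm()
#
#       def forward(self, x):
#           if self._fast_norm:
#               return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
#           return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)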


def fast_layer_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # cannot use is_autocast_enabled / autocast helpers within torchscript, use standard path
        return F.layer_norm(x, normalized_shape, weight, bias, eps)

    if has_apex:
        return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)

    if is_autocast_enabled(x.device.type):
        # native AMP would run layer_norm in float32; APEX LN does not upcast, so to match that
        # behaviour cast the inputs to the autocast dtype and run with autocast disabled below
        dt = get_autocast_dtype(x.device.type)
        x = x.to(dt)
        weight = weight.to(dt) if weight is not None else None
        bias = bias.to(dt) if bias is not None else None

    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        return F.layer_norm(x, normalized_shape, weight, bias, eps)
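
# Note on dtype behaviour (a sketch, assuming APEX is not installed): under bfloat16 autocast the
# stock F.layer_norm runs (and returns) in float32, while fast_layer_norm keeps the computation in
# the autocast dtype, e.g. on a CUDA device:
#
#   with torch.autocast('cuda', dtype=torch.bfloat16):
#       y_ref = F.layer_norm(x, [x.shape[-1]], w, b, 1e-6)       # y_ref.dtype == torch.float32
#       y_fast = fast_layer_norm(x, [x.shape[-1]], w, b, 1e-6)   # y_fast.dtype == torch.bfloat16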


def rms_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
):
    norm_ndim = len(normalized_shape)
    v = x.pow(2)
    if torch.jit.is_scripting():
        # torchscript has issues with the negative-dim tuple + keepdim reduction used below,
        # so only the common last-dim case is supported when scripting
        assert norm_ndim == 1
        v = torch.mean(v, dim=-1).unsqueeze(-1)
    else:
        dims = tuple(range(-1, -norm_ndim - 1, -1))
        v = torch.mean(v, dim=dims, keepdim=True)
    x = x * torch.rsqrt(v + eps)
    if weight is not None:
        x = x * weight
    return x
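
# rms_norm above implements the usual RMSNorm, i.e. y = x / sqrt(mean(x**2) + eps) * weight,
# reducing over the trailing len(normalized_shape) dimensions; unlike LayerNorm it does not
# subtract the mean or add a bias.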


def fast_rms_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # torchscript cannot use the apex / autocast branches below, fall back to reference impl
        return rms_norm(x, normalized_shape, weight, eps)

    if has_apex_rmsnorm:
        if weight is None:
            return fused_rms_norm(x, normalized_shape, eps)
        else:
            return fused_rms_norm_affine(x, weight, normalized_shape, eps)

    if is_autocast_enabled(x.device.type):
        # match APEX behaviour: keep the norm in the autocast dtype rather than letting AMP
        # upcast it to float32
        dt = get_autocast_dtype(x.device.type)
        x = x.to(dt)
        weight = weight.to(dt) if weight is not None else None

    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        if has_torch_rms_norm:
            x = F.rms_norm(x, normalized_shape, weight, eps)
        else:
            x = rms_norm(x, normalized_shape, weight, eps)

    return x
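
# Resolution order in fast_rms_norm: APEX fused kernels if importable, otherwise torch's own
# F.rms_norm when present (newer PyTorch), otherwise the python rms_norm above. Weightless calls
# are also supported, e.g. (illustrative): fast_rms_norm(x, [x.shape[-1]])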


def simple_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
):
    norm_ndim = len(normalized_shape)
    if torch.jit.is_scripting():
        # torchscript path, restrict to the common last-dim case (see rms_norm above)
        assert norm_ndim == 1
        v = torch.var(x, dim=-1).unsqueeze(-1)
    else:
        dims = tuple(range(-1, -norm_ndim - 1, -1))
        v = torch.var(x, dim=dims, keepdim=True)
    x = x * torch.rsqrt(v + eps)
    if weight is not None:
        x = x * weight
    return x
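
# simple_norm differs from rms_norm only in the statistic used: it divides by the standard
# deviation (torch.var with its defaults, so the mean is removed from the second moment) rather
# than the root mean square, but like rms_norm it does not subtract the mean from x itself,
# i.e. y = x / sqrt(var(x) + eps) * weight.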


def fast_simple_norm(
    x: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
) -> torch.Tensor:
    if torch.jit.is_scripting():
        # torchscript cannot use the autocast helpers below, fall back to reference impl
        return simple_norm(x, normalized_shape, weight, eps)

    if is_autocast_enabled(x.device.type):
        # keep the norm in the autocast dtype rather than letting AMP upcast it to float32
        dt = get_autocast_dtype(x.device.type)
        x = x.to(dt)
        weight = weight.to(dt) if weight is not None else None

    with torch.amp.autocast(device_type=x.device.type, enabled=False):
        x = simple_norm(x, normalized_shape, weight, eps)
    return x
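

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. It assumes APEX is not
    # installed (the fused kernels expect CUDA tensors) and runs on CPU outside of autocast,
    # where the fast paths should agree with the reference ops.
    torch.manual_seed(0)
    x = torch.randn(2, 8, 16)
    w = torch.ones(16)
    b = torch.zeros(16)

    y_fast = fast_layer_norm(x, [16], w, b, 1e-5)
    y_ref = F.layer_norm(x, [16], w, b, 1e-5)
    print('layer_norm max abs diff:', (y_fast - y_ref).abs().max().item())

    y_rms_fast = fast_rms_norm(x, [16], w, 1e-5)
    y_rms_ref = rms_norm(x, [16], w, 1e-5)
    print('rms_norm max abs diff:', (y_rms_fast - y_rms_ref).abs().max().item())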