Kernels:
Trusted publisher
| # SPDX-License-Identifier: MIT | |
| # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. | |
| import triton | |
| import triton.language as tl | |
def _compute_fp8_scaling_factors(x, fp8_max: tl.constexpr):
    """Derive the fp8 scale and descale factors for one block of values.

    The factors are based on the largest absolute value found in ``x``,
    floored at 1e-9.

    Args:
        x: Block of values to inspect.
        fp8_max: Compile-time constant used as the numerator of the scale
            and the denominator of the descale (presumably the largest
            magnitude representable by the target fp8 format -- confirm
            against the caller).

    Returns:
        A ``(scale, descale)`` pair where ``scale = fp8_max / amax`` and
        ``descale = amax / fp8_max``.
    """
    # Take the magnitude first so negative entries count towards the peak.
    abs_peak = tl.max(tl.abs(x))
    # Floor the peak at 1e-9 so an all-zero block cannot divide by zero.
    abs_peak = tl.where(abs_peak <= 1e-9, 1e-9, abs_peak)
    return fp8_max / abs_peak, abs_peak / fp8_max