// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <algorithm>
#include <cmath>
#include <limits>

// NOTE(review): <cmath> may also declare ::float_t (via <math.h>); both
// declarations must name the same type on the target platform.
typedef float float_t;

// Declared here and defined in another translation unit. Presumably holds
// negative infinity, judging by the name; confirm against the definition.
extern const float_t NEG_INF;

/// @brief Base case of max_val: the maximum of a single value is that value.
template <typename T>
inline T max_val(T v) { return v; }

/// @brief Returns the largest of one or more values.
///
/// All arguments are expected to share the type T. std::max requires both
/// operands to have the same type, so mixing types fails to compile.
/// std::max returns its first operand when the comparison is unordered, so the
/// result for NaN inputs depends on where the NaN appears.
template <typename T, typename... Args>
inline T max_val(T v, Args... rest) {
  auto restMax = max_val(rest...);
  return std::max(v, restMax);
}

/// @brief Base case of sum_exp: returns exp(v - maxVal).
template <typename T>
inline T sum_exp(T maxVal, T v) { return std::exp(v - maxVal); }

/// @brief Returns the sum over all values x of exp(x - maxVal).
///
/// The terms are accumulated from the last argument backward.
template <typename T, typename... Args>
inline T sum_exp(T maxVal, T v, Args... rest) {
  auto restSum = sum_exp(maxVal, rest...);
  return sum_exp(maxVal, v) + restSum;
}

/// @brief Numerically stable log(exp(v) + exp(args)...).
///
/// Shifts every term by the maximum m before exponentiating:
/// log_sum_exp = m + log(sum of exp(x - m)). This avoids overflow for large
/// inputs.
///
/// @param v     First value (also fixes the type T).
/// @param args  Remaining values, expected to be of type T.
/// @return m + log(sum of exp(x - m)) over all inputs.
///   Returns -inf if every input is -inf.
///   Returns +inf if any input is +inf.
template <typename T, typename... Args>
inline T log_sum_exp(T v, Args... args) {
  auto maxVal = max_val(v, args...);
  // An infinite maximum must short-circuit, because the shift would compute
  // inf - inf = NaN. If the max is -inf, all terms are -inf and the answer
  // is -inf. If the max is +inf, the sum is unbounded and the answer is +inf.
  if (std::isinf(maxVal)) {
    return maxVal;
  }
  auto sumExp = sum_exp(maxVal, v, args...);
  return maxVal + std::log(sumExp);
}