"""Utility function for loss implementations. For licensing see accompanying LICENSE file. Copyright (C) 2025 Apple Inc. All Rights Reserved. """ from __future__ import annotations from typing import Callable import torch def robust_where( condition: torch.Tensor, input: torch.Tensor, branch_true_func: Callable[[torch.Tensor], torch.Tensor], branch_false_func: Callable[[torch.Tensor], torch.Tensor], branch_true_safe_value: float | None = None, branch_false_safe_value: float | None = None, ) -> torch.Tensor: """Robust torch.where function to avoid NaN in backward pass. See https://github.com/pytorch/pytorch/issues/68425 Args: condition: When True (nonzero), yield branch_true_func(input), otherwise yield branch_false_func(input) input: The input tensor for torch.where branch_true_func: Callable for values at indices where condition is True. branch_false_func: Callable for values at indices where condition is False. branch_true_safe_value: Safe value to replace the true branch. branch_false_safe_value: Safe value to replace the false branch. """ input_1 = input input_2 = input if branch_true_safe_value is not None: input_1 = torch.where(condition, input_1, branch_true_safe_value) if branch_false_safe_value is not None: input_2 = torch.where(~condition, input_2, branch_false_safe_value) return torch.where( condition, branch_true_func(input_1), branch_false_func(input_2), )