| """Utility function for loss implementations. | |
| For licensing see accompanying LICENSE file. | |
| Copyright (C) 2025 Apple Inc. All Rights Reserved. | |
| """ | |
| from __future__ import annotations | |
| from typing import Callable | |
| import torch | |
def robust_where(
    condition: torch.Tensor,
    input: torch.Tensor,
    branch_true_func: Callable[[torch.Tensor], torch.Tensor],
    branch_false_func: Callable[[torch.Tensor], torch.Tensor],
    branch_true_safe_value: float | None = None,
    branch_false_safe_value: float | None = None,
) -> torch.Tensor:
    """Robust torch.where function to avoid NaN in backward pass.

    See https://github.com/pytorch/pytorch/issues/68425

    Both branch functions are evaluated on every element, and ``torch.where``
    only masks their outputs. Entries that are not selected still take part in
    the backward pass, so a non-finite value or gradient there can leak into
    the result's gradient. Supplying a safe value replaces the unselected
    entries of ``input`` *before* the corresponding branch function runs.

    Args:
        condition: When True (nonzero), yield branch_true_func(input),
            otherwise yield branch_false_func(input). Non-bool tensors are
            interpreted as "nonzero means True".
        input: The input tensor for torch.where.
        branch_true_func: Callable for values at indices where condition is True.
        branch_false_func: Callable for values at indices where condition is False.
        branch_true_safe_value: Safe value fed to ``branch_true_func`` at
            indices where condition is False (those outputs are discarded).
        branch_false_safe_value: Safe value fed to ``branch_false_func`` at
            indices where condition is True (those outputs are discarded).

    Returns:
        ``torch.where(condition, branch_true_func(...), branch_false_func(...))``
        evaluated on the (optionally) safe-substituted inputs.
    """
    # torch.where needs a boolean mask, and ``~`` on an integer tensor is a
    # bitwise NOT (~1 == -2, still "nonzero"), so normalize to bool first.
    if condition.dtype != torch.bool:
        condition = condition != 0

    true_branch_input = input
    if branch_true_safe_value is not None:
        # Only entries where condition is True are kept from the true branch;
        # neutralize the rest so they cannot produce NaN/inf gradients.
        true_branch_input = torch.where(condition, input, branch_true_safe_value)

    false_branch_input = input
    if branch_false_safe_value is not None:
        # Mirror image: neutralize entries the false branch will not select.
        false_branch_input = torch.where(~condition, input, branch_false_safe_value)

    return torch.where(
        condition,
        branch_true_func(true_branch_input),
        branch_false_func(false_branch_input),
    )