ml-sharp / src /sharp /utils /robust.py
amael-apple's picture
Initial commit
c20d7cc
"""Utility function for loss implementations.
For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""
from __future__ import annotations
from typing import Callable
import torch
def robust_where(
    condition: torch.Tensor,
    input: torch.Tensor,
    branch_true_func: Callable[[torch.Tensor], torch.Tensor],
    branch_false_func: Callable[[torch.Tensor], torch.Tensor],
    branch_true_safe_value: float | None = None,
    branch_false_safe_value: float | None = None,
) -> torch.Tensor:
    """Select between two functions of ``input`` without poisoning gradients.

    Both callables are evaluated on every element and the results are merged
    with ``torch.where``. When a safe value is supplied for a branch, the
    elements that this branch does *not* contribute to the output are replaced
    by that value before the branch callable is applied. This is meant to keep
    the unused branch from producing NaN in the backward pass, see
    https://github.com/pytorch/pytorch/issues/68425

    Args:
        condition: Mask choosing the branch; where it holds the result comes
            from ``branch_true_func``, elsewhere from ``branch_false_func``.
        input: Tensor handed to both branch callables.
        branch_true_func: Callable producing the values used where
            ``condition`` holds.
        branch_false_func: Callable producing the values used where
            ``condition`` does not hold.
        branch_true_safe_value: If given, substituted into ``input`` wherever
            ``condition`` does not hold before calling ``branch_true_func``.
        branch_false_safe_value: If given, substituted into ``input`` wherever
            ``condition`` holds before calling ``branch_false_func``.

    Returns:
        ``torch.where(condition, branch_true_func(...), branch_false_func(...))``
        computed on the (optionally sanitized) inputs.
    """
    # Input seen by the true branch: original values where the branch is
    # selected, the safe value everywhere else.
    true_branch_input = (
        input
        if branch_true_safe_value is None
        else torch.where(condition, input, branch_true_safe_value)
    )

    # Input seen by the false branch: original values where the condition does
    # not hold, the safe value where the true branch is selected instead.
    false_branch_input = (
        input
        if branch_false_safe_value is None
        else torch.where(~condition, input, branch_false_safe_value)
    )

    # Evaluate the true branch first, then the false branch, and merge.
    true_values = branch_true_func(true_branch_input)
    false_values = branch_false_func(false_branch_input)
    return torch.where(condition, true_values, false_values)