File size: 1,564 Bytes
c20d7cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Utility function for loss implementations.

For licensing see accompanying LICENSE file.
Copyright (C) 2025 Apple Inc. All Rights Reserved.
"""

from __future__ import annotations

from typing import Callable

import torch


def robust_where(
    condition: torch.Tensor,
    input: torch.Tensor,
    branch_true_func: Callable[[torch.Tensor], torch.Tensor],
    branch_false_func: Callable[[torch.Tensor], torch.Tensor],
    branch_true_safe_value: float | None = None,
    branch_false_safe_value: float | None = None,
) -> torch.Tensor:
    """Robust torch.where function to avoid NaN in backward pass.

    ``torch.where(c, f(x), g(x))`` evaluates both ``f`` and ``g`` on every
    element of ``x``. The outputs a branch does not select are discarded in
    the forward pass, but in the backward pass their local gradients are
    still multiplied by zero, and ``0 * inf`` or ``0 * nan`` yields NaN.
    Replacing the input with a well-behaved "safe" value at the positions a
    branch does not select keeps that branch's gradient finite there.

    See https://github.com/pytorch/pytorch/issues/68425

    Args:
        condition: When True (nonzero), yield branch_true_func(input),
            otherwise yield branch_false_func(input). Non-boolean tensors
            are converted to bool, so any nonzero value counts as True.
        input: The input tensor for torch.where.
        branch_true_func: Callable for values at indices where condition is True.
        branch_false_func: Callable for values at indices where condition is False.
        branch_true_safe_value: If given, ``input`` is replaced by this value
            wherever ``condition`` is False before calling
            ``branch_true_func``. Those outputs are discarded anyway.
        branch_false_safe_value: If given, ``input`` is replaced by this value
            wherever ``condition`` is True before calling
            ``branch_false_func``. Those outputs are discarded anyway.

    Returns:
        ``torch.where(condition, branch_true_func(...), branch_false_func(...))``
        evaluated on the (possibly safe-substituted) inputs.
    """
    # torch.where requires a boolean mask. Also, ``~`` on an integer tensor is a
    # bitwise NOT (~1 == -2, still "nonzero"), so normalize the mask once.
    condition = condition.bool()

    input_true = input
    input_false = input
    if branch_true_safe_value is not None:
        # Where the true branch is discarded, feed it the safe value instead.
        input_true = torch.where(condition, input, branch_true_safe_value)
    if branch_false_safe_value is not None:
        # Where the false branch is discarded, feed it the safe value instead.
        input_false = torch.where(condition, branch_false_safe_value, input)
    return torch.where(
        condition,
        branch_true_func(input_true),
        branch_false_func(input_false),
    )