# Snapshot of commit 42f26af (427 bytes).
import torch

# Upper bound on how many offending indices are listed in the error message,
# so a tensor full of NaNs does not produce a multi-megabyte exception string.
_MAX_REPORTED_INDICES = 20


def _iter_tensors(obj):
    """Yield every tensor in ``obj``, descending into nested tuples/lists."""
    if isinstance(obj, torch.Tensor):
        yield obj
    elif isinstance(obj, (tuple, list)):
        for item in obj:
            yield from _iter_tensors(item)
    # Any other value (None, ints, custom objects) carries nothing to check.


def nan_hook(self, inp, output):
    """Forward hook that raises as soon as a module emits NaN or Inf values.

    Signature matches ``hook(module, input, output)`` as used by
    ``torch.nn.Module.register_forward_hook``; ``self`` is the module that
    produced ``output``.

    Args:
        self: The module (or any object) whose forward produced ``output``;
            only its class name is used, for the error message.
        inp: The forward inputs; ignored.
        output: A tensor, or a (possibly nested) tuple/list of tensors.
            Non-tensor entries are skipped.

    Raises:
        RuntimeError: If any output tensor contains NaN or +/-Inf. For each
            tensor, NaN is checked before Inf. The message names the module
            class and lists up to ``_MAX_REPORTED_INDICES`` offending indices.
    """
    module_name = self.__class__.__name__
    for tensor in _iter_tensors(output):
        for label, is_bad in (("NAN", torch.isnan), ("INF", torch.isinf)):
            bad_mask = is_bad(tensor)
            if not bad_mask.any():
                continue
            indices = bad_mask.nonzero()
            shown = indices[:_MAX_REPORTED_INDICES].tolist()
            remaining = len(indices) - len(shown)
            suffix = f" (and {remaining} more)" if remaining > 0 else ""
            # Build one formatted message; passing several args to
            # RuntimeError would render as a tuple repr instead.
            raise RuntimeError(
                f"Found {label} in output of {module_name} "
                f"at indices: {shown}{suffix}"
            )