import torch


def _iter_tensors(obj, path="output"):
    """Yield ``(path, tensor)`` for every tensor nested inside *obj*.

    Tuples and lists are walked by position and dicts by key; any other
    non-tensor value (``None``, ints, ...) is skipped. *path* is a
    human-readable location such as ``output[1][0]`` used in error messages.
    """
    if isinstance(obj, torch.Tensor):
        yield path, obj
    elif isinstance(obj, (tuple, list)):
        for i, item in enumerate(obj):
            yield from _iter_tensors(item, f"{path}[{i}]")
    elif isinstance(obj, dict):
        for key, item in obj.items():
            yield from _iter_tensors(item, f"{path}[{key!r}]")


def nan_hook(self, inp, output):
    """Forward hook that raises as soon as a module produces NaN or Inf values.

    Intended for ``module.register_forward_hook(nan_hook)``. The module is
    passed first (hence ``self``), then its inputs, then its output.

    Args:
        self: The module the hook is attached to; its class name is included
            in the error message.
        inp: The module's inputs; not inspected.
        output: Whatever the module's ``forward`` returned. A tensor, or a
            (possibly nested) tuple/list/dict containing tensors. Non-tensor
            leaves are ignored.

    Returns:
        None, so the module's output is left unchanged.

    Raises:
        RuntimeError: If any floating-point or complex output tensor contains
            NaN (checked first) or Inf. The message names the module, the
            offending output path and the indices found.
    """
    module_name = self.__class__.__name__
    for path, tensor in _iter_tensors(output):
        # Integer/bool tensors cannot represent NaN or Inf; skip the work.
        if not (tensor.is_floating_point() or tensor.is_complex()):
            continue
        nan_mask = torch.isnan(tensor)
        if nan_mask.any():
            raise RuntimeError(
                f"In {module_name}: Found NAN in {path} at indices: "
                f"{nan_mask.nonzero()}"
            )
        inf_mask = torch.isinf(tensor)
        if inf_mask.any():
            raise RuntimeError(
                f"In {module_name}: Found INF in {path} at indices: "
                f"{inf_mask.nonzero()}"
            )