import torch
from typing import Iterable, Union


def calc_diff(x: torch.Tensor, y: torch.Tensor) -> Union[float, torch.Tensor]:
    """Return a normalized dissimilarity score between two tensors.

    Computes ``1 - 2 * sum(x * y) / (sum(x * x) + sum(y * y))`` in float64.
    By Cauchy-Schwarz / AM-GM the result lies in ``[0, 2]``:

    * ``0`` when ``x == y`` (identical inputs),
    * ``1`` when ``x`` and ``y`` are orthogonal,
    * ``2`` when ``y == -x``.

    The inputs are combined elementwise, so normal torch broadcasting
    applies; callers are presumably expected to pass same-shaped tensors.

    Args:
        x: First tensor; upcast to float64 before any arithmetic.
        y: Second tensor; upcast to float64 before any arithmetic.

    Returns:
        The Python float ``0.0`` when the sum of squares is zero (both inputs
        all zeros, or empty); otherwise a 0-dim float64 tensor on the inputs'
        device.
    """
    # Upcast so low-precision inputs (e.g. fp16/bf16) do not lose accuracy
    # in the reductions below.
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    # A sum of squares is zero only when every element is zero (or there are
    # no elements): the inputs are then identical, so report no difference
    # instead of dividing 0 by 0. Evaluating this condition reads the value
    # back to the host.
    if denominator == 0:
        return 0.0
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def count_bytes(*tensors) -> int:
    """Return the total number of bytes occupied by the given tensors.

    Each tensor contributes ``numel() * element_size()``, i.e. the size of
    its own elements, not of any larger storage it may be a view into.

    Args:
        *tensors: Tensors, ``None`` values (ignored), or arbitrarily nested
            tuples/lists of these.

    Returns:
        The summed byte count; ``0`` when nothing (or only ``None``) is given.
    """
    total = 0
    for t in tensors:
        if isinstance(t, (tuple, list)):
            # Recurse into containers so nested groups are flattened.
            total += count_bytes(*t)
        elif t is not None:
            total += t.numel() * t.element_size()
    return total