| import torch | |
| from typing import Iterable | |
def calc_diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return a normalized dissimilarity score between two tensors.

    Computes ``1 - 2 * sum(x * y) / (sum(x**2) + sum(y**2))`` in float64.
    The score is 0 when ``x`` and ``y`` are identical. Because
    ``x_i**2 + y_i**2 >= 2 * |x_i * y_i|`` elementwise, the score always
    lies in ``[0, 2]`` (2 when ``y == -x``).

    Args:
        x: First tensor; upcast to float64 before any arithmetic.
        y: Second tensor; upcast to float64. ``x`` and ``y`` are combined
            with ordinary elementwise broadcasting (no shape check is done).

    Returns:
        A 0-dim float64 tensor. When both inputs are entirely zero the
        score is defined as 0.
    """
    # Upcast so the reductions below do not lose precision for low-precision
    # inputs (e.g. fp16/bf16).
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # Every element of x and y is zero: treat them as identical. Return a
        # tensor (not a Python float) so the return type never depends on the
        # input values, e.g. ``calc_diff(a, b).item()`` always works.
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
def count_bytes(*tensors):
    """Return the combined size in bytes of the given tensors.

    Each argument may be a tensor, ``None`` (contributes nothing), or a
    tuple/list of such values, nested to any depth; containers are
    flattened recursively. A tensor's size is taken as
    ``numel() * element_size()``.
    """

    def _size_of(item):
        # Containers are expanded by recursing into the public entry point.
        if isinstance(item, (tuple, list)):
            return count_bytes(*item)
        if item is None:
            return 0
        return item.numel() * item.element_size()

    return sum(_size_of(item) for item in tensors)