medmekk's picture
Upload folder using huggingface_hub
c67ae40 verified
raw
history blame contribute delete
561 Bytes
import torch
from typing import Iterable
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global dissimilarity score between two tensors.

    Both inputs are converted to float64 and the score is computed as
    ``1 - 2 * sum(x * y) / sum(x * x + y * y)``. Identical inputs give 0,
    and ``y == -x`` gives 2.

    Args:
        x: First tensor.
        y: Second tensor. It is combined element-wise with ``x``, so the
            two shapes must be equal or broadcastable.

    Returns:
        The Python float ``0.0`` when the denominator is zero; otherwise a
        0-dim float64 tensor holding the score. Both forms support ordinary
        numeric comparison (e.g. ``calc_diff(a, b) < 1e-3``).
    """
    # Work in float64 so low-precision inputs do not lose accuracy in the sums.
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    # A zero sum of squares means every element of x and y is zero; the
    # inputs are then identical, and dividing would produce NaN.
    if denominator == 0:
        return 0.0
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
def count_bytes(*tensors):
    """Return the total number of bytes taken by the elements of the inputs.

    Args:
        *tensors: Items to measure. Each item may be an object exposing
            ``numel()`` and ``element_size()`` (such as a ``torch.Tensor``),
            ``None`` (skipped), or a tuple/list of such items, nested to
            any depth.

    Returns:
        The sum of ``numel() * element_size()`` over every non-``None``
        item found; ``0`` when there is nothing to count.
    """
    total = 0
    for t in tensors:
        if isinstance(t, (tuple, list)):
            # Containers are flattened by recursing into their members.
            total += count_bytes(*t)
        elif t is not None:
            total += t.numel() * t.element_size()
    return total