File size: 1,357 Bytes
bd33eac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
from __future__ import annotations
import torch
def _values_on_union(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Return ``tensor``'s values laid out over the union of both tensors' stored indices.

    Indices stored only in ``reference`` receive explicit zeros. Both arguments must be
    coalesced sparse COO tensors with the same shape and sparse dimension. Because
    ``coalesce`` sorts indices, calling this with the arguments swapped yields values in
    the same index order, so the two results can be compared element-wise.
    """
    values = tensor.values()
    # Explicit zeros (in ``tensor``'s dtype/device) at every index stored by ``reference``;
    # values may carry trailing dense dimensions for hybrid tensors, so mirror them.
    padding = torch.sparse_coo_tensor(
        reference.indices(),
        values.new_zeros((reference.values().shape[0], *values.shape[1:])),
        tensor.shape,
    )
    # Adding zeros leaves every stored value unchanged (NaN/inf survive), and coalesce
    # merges the duplicate indices without dropping explicit zeros.
    return (tensor + padding).coalesce().values()


def sparse_allclose(
    input: torch.Tensor, other: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False
) -> bool:
    """
    Check if two sparse tensors (e.g. sparse embeddings) are element-wise close.

    Closeness follows :func:`torch.allclose` applied to the dense equivalents: an index
    stored in only one tensor is compared against an implicit zero in the other, so an
    explicitly stored zero does not make otherwise-equal tensors differ.

    Args:
        input: First sparse embedding (sparse COO tensor).
        other: Second sparse embedding (sparse COO tensor).
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        equal_nan: If True, NaN values in the same locations are considered equal.

    Returns:
        bool: True if both tensors have the same shape and all elements are close
        according to the tolerances, False otherwise.

    Errors raised by torch itself (e.g. ``coalesce`` on a non-sparse tensor, or
    ``torch.allclose`` on values of different dtypes) propagate unchanged.
    """
    if input.shape != other.shape:
        return False

    # Coalescing sums duplicate indices and sorts them, giving a canonical layout.
    input = input.coalesce()
    other = other.coalesce()

    # Same shape but a different sparse/dense split means values of different shapes
    # that cannot be aligned; treat them as not close.
    if input.sparse_dim() != other.sparse_dim():
        return False

    if torch.equal(input.indices(), other.indices()):
        # Fast path: identical sparsity patterns, so values line up one-to-one.
        input_values = input.values()
        other_values = other.values()
    else:
        # Different patterns: compare both over the union of stored indices.
        input_values = _values_on_union(input, other)
        other_values = _values_on_union(other, input)

    return torch.allclose(input_values, other_values, rtol=rtol, atol=atol, equal_nan=equal_nan)
|