| from __future__ import annotations |
|
|
| import torch |
|
|
|
|
def _to_dense(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` in strided (dense) layout, returning it unchanged if it already is."""
    # Guard explicitly: older PyTorch releases reject ``to_dense()`` on strided tensors.
    return tensor if tensor.layout == torch.strided else tensor.to_dense()


def sparse_allclose(
    input: torch.Tensor, other: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False
) -> bool:
    """
    Check whether two (typically sparse) embedding tensors are element-wise close.

    The result matches ``torch.allclose(input.to_dense(), other.to_dense(), ...)``,
    except that tensors of different shapes are never considered close (no
    broadcasting). Positions not stored in a sparse tensor count as zero. So an
    explicitly stored zero (or a value within tolerance of zero) in one tensor and
    a missing entry in the other are treated as close.

    When both tensors are sparse COO and share the same sparsity pattern after
    coalescing, only the stored values are compared, so no dense copy is made.

    Args:
        input: First embedding tensor (sparse COO, or any layout that supports
            conversion to dense).
        other: Second embedding tensor, compared against ``input``.
        rtol: Relative tolerance, as in ``torch.allclose``.
        atol: Absolute tolerance, as in ``torch.allclose``.
        equal_nan: If True, NaN values in the same locations are considered equal.

    Returns:
        bool: True if the tensors have the same shape and are close according to
        the tolerances, False otherwise.
    """
    # Unlike torch.allclose, do not broadcast: differently-shaped embeddings are never close.
    if input.shape != other.shape:
        return False

    if input.layout == torch.sparse_coo and other.layout == torch.sparse_coo:
        # Coalescing sorts indices and sums duplicate entries, giving a canonical form.
        input = input.coalesce()
        other = other.coalesce()

        # Fast path: identical index sets mean the stored values line up one-to-one,
        # and every unstored position is zero in both tensors.
        if torch.equal(input.indices(), other.indices()):
            return torch.allclose(input.values(), other.values(), rtol=rtol, atol=atol, equal_nan=equal_nan)

    # The sparsity patterns differ (e.g. explicit zeros in one tensor), or a tensor is
    # not sparse COO: compare densely so that absent entries are treated as zeros.
    return torch.allclose(_to_dense(input), _to_dense(other), rtol=rtol, atol=atol, equal_nan=equal_nan)
|
|