| from __future__ import annotations | |
| import torch | |
def _with_explicit_zeros(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Return coalesced sparse COO ``tensor`` with an explicit zero added at each of ``indices``.

    Existing stored values are unchanged (``x + 0`` keeps x, NaN and inf; only the sign of
    a zero can flip), so the result is numerically the same tensor with a larger sparsity
    pattern that includes every column of ``indices``.
    """
    values = tensor.values()
    # One zero per padded index; trailing dims mirror the values of a hybrid tensor.
    zeros = values.new_zeros((indices.shape[1], *values.shape[1:]))
    padding = torch.sparse_coo_tensor(indices, zeros, tensor.shape)
    # Sparse addition concatenates entries; coalesce() sums duplicates and sorts indices.
    return (tensor + padding).coalesce()


def sparse_allclose(
    input: torch.Tensor, other: torch.Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False
) -> bool:
    """
    Check if two sparse COO tensors (e.g. sparse embeddings) are element-wise close.

    The result matches ``torch.allclose(input.to_dense(), other.to_dense(), ...)``, i.e.
    ``|input - other| <= atol + rtol * |other|`` at every position, but normally only the
    stored entries are compared. An index stored in just one tensor is compared against an
    implicit zero in the other. Tensors with different sparsity patterns (for example, one
    stores an explicit zero the other omits) can therefore still be close.

    Args:
        input: First sparse COO tensor. Strided (dense) tensors are rejected by ``coalesce()``.
        other: Second sparse COO tensor.
        rtol: Relative tolerance
        atol: Absolute tolerance
        equal_nan: If True, NaN values in the same locations are considered equal

    Returns:
        bool: True if both tensors have the same shape and are close according to the
        tolerances. Unlike ``torch.allclose``, shapes are not broadcast: differing shapes
        return False.
    """
    # Differing shapes are never close (no broadcasting).
    if input.shape != other.shape:
        return False

    # coalesce() sorts indices and merges duplicate entries, so equal patterns compare equal.
    input = input.coalesce()
    other = other.coalesce()
    input_indices = input.indices()
    other_indices = other.indices()

    # Fast path: identical sparsity patterns, so the stored values line up one-to-one.
    if torch.equal(input_indices, other_indices):
        return torch.allclose(input.values(), other.values(), rtol=rtol, atol=atol, equal_nan=equal_nan)

    if input.sparse_dim() != other.sparse_dim():
        # Hybrid tensors with different sparse/dense splits cannot be merged index-wise.
        return torch.allclose(input.to_dense(), other.to_dense(), rtol=rtol, atol=atol, equal_nan=equal_nan)

    # Patterns differ: store explicit zeros so both tensors cover the union of indices.
    # After coalescing, both have the same sorted indices, so their values line up.
    input_on_union = _with_explicit_zeros(input, other_indices)
    other_on_union = _with_explicit_zeros(other, input_indices)
    return torch.allclose(
        input_on_union.values(), other_on_union.values(), rtol=rtol, atol=atol, equal_nan=equal_nan
    )