| | import numpy as np |
| | from cloudpathlib import AnyPath |
| |
|
| | from src.data.esm.utils.types import PathLike |
| |
|
| |
|
| | class LSHTable: |
| | def __init__(self, n_bits: int, dim: int, hyperplanes: np.ndarray | None = None): |
| | if hyperplanes is None: |
| | hyperplanes = np.random.randn(n_bits, dim) |
| | hyperplanes = hyperplanes / np.linalg.norm( |
| | hyperplanes, axis=-1, keepdims=True |
| | ) |
| | else: |
| | assert hyperplanes.shape == (n_bits, dim), ( |
| | hyperplanes.shape, |
| | (n_bits, dim), |
| | ) |
| | assert hyperplanes is not None |
| | self.hyperplanes: np.ndarray = hyperplanes |
| | self.values = 1 << np.arange(n_bits) |
| |
|
| | def __call__(self, array, tokenize: bool = True): |
| | similarity = self.hyperplanes @ array.T |
| | bits = np.where(similarity >= 0, 1, 0) |
| | if tokenize: |
| | tokens = bits.T @ self.values |
| | return tokens |
| | else: |
| | return bits.T |
| |
|
| |
|
class LSHTokenized:
    """Bank of ``num_tables`` independent ``LSHTable`` hashers producing tokens.

    Hyperplanes are either loaded from an ``.npz`` archive whose entries are
    named ``"0"``, ``"1"``, ... (the layout written by ``write_hyperplanes``)
    or, when explicitly allowed, created fresh by each ``LSHTable``.
    """

    def __init__(
        self,
        n_bits: int,
        dim: int,
        num_tables: int = 1,
        filepath: PathLike | None = None,
        allow_create_hyperplanes: bool = False,
    ):
        """Build the tables.

        Args:
            n_bits: Bits (hyperplanes) per table.
            dim: Dimensionality of the vectors to hash.
            num_tables: Number of tables to create.
            filepath: Optional path to an ``.npz`` archive holding one entry
                per table, keyed by the table index as a string.
            allow_create_hyperplanes: Permit creating new hyperplanes when no
                ``filepath`` is given.

        Raises:
            FileNotFoundError: If ``filepath`` is given but does not exist.
            AssertionError: If the archive lacks an entry for some table, or
                an entry's shape is not ``(n_bits, dim)``.
            RuntimeError: If ``filepath`` is None and creation is not allowed.
        """
        table_hyperplanes: dict[str, np.ndarray] | None = None
        if filepath is not None:
            filepath = AnyPath(filepath)
            if not filepath.exists():
                raise FileNotFoundError(filepath)
            keys = [str(i) for i in range(num_tables)]
            # An .npz archive is read lazily through an open file handle, so
            # pull out the arrays we need and close it instead of leaking it.
            with np.load(filepath) as archive:
                for i, key in enumerate(keys):
                    assert key in archive, f"Missing hyperplane for table {i}"
                table_hyperplanes = {key: archive[key] for key in keys}
        elif not allow_create_hyperplanes:
            raise RuntimeError(
                "Not allowed to create hyperplanes but no filepath provided"
            )

        self.tables = [
            LSHTable(
                n_bits,
                dim,
                table_hyperplanes[str(i)] if table_hyperplanes is not None else None,
            )
            for i in range(num_tables)
        ]

    def write_hyperplanes(self, filepath: PathLike):
        """Save every table's hyperplanes to ``filepath`` with ``np.savez``.

        Entries are keyed by the table index as a string, matching what the
        constructor expects. Note that ``np.savez`` appends ``.npz`` to a
        string/path filename that does not already end with it.
        """
        hyperplanes: dict[str, np.ndarray] = {
            str(i): table.hyperplanes for i, table in enumerate(self.tables)
        }
        np.savez(filepath, **hyperplanes)

    def __call__(self, array):
        """Hash ``array`` with every table and stack the tokens.

        Args:
            array: Batch of shape ``(n, dim)`` or a single vector ``(dim,)``.

        Returns:
            Integer tokens of shape ``(n, num_tables)``, or ``(num_tables,)``
            for a single vector.
        """
        # Stacking on the last axis equals axis 1 for batched input and also
        # works for the scalar tokens produced from a single vector.
        return np.stack([table(array) for table in self.tables], axis=-1)
| |
|
| |
|
class LSHBitstream:
    """Single ``LSHTable`` wrapper that yields raw bits instead of tokens.

    The hyperplanes come from a file read with ``np.load`` or, when explicitly
    allowed, are created fresh by ``LSHTable``.
    """

    def __init__(
        self,
        n_bits: int,
        dim: int,
        filepath: PathLike | None = None,
        allow_create_hyperplanes: bool = False,
    ):
        """Build the underlying table.

        Args:
            n_bits: Number of bits (hyperplanes).
            dim: Dimensionality of the vectors to hash.
            filepath: Optional path to saved hyperplanes; the loaded object is
                handed to ``LSHTable``, which checks its shape against
                ``(n_bits, dim)``.
            allow_create_hyperplanes: Permit creating new hyperplanes when no
                ``filepath`` is given.

        Raises:
            FileNotFoundError: If ``filepath`` is given but does not exist.
            RuntimeError: If ``filepath`` is None and creation is not allowed.
        """
        if filepath is None:
            # Creating hyperplanes must be an explicit opt-in.
            if not allow_create_hyperplanes:
                raise RuntimeError(
                    "Not allowed to create hyperplanes but no filepath provided"
                )
            stored = None
        else:
            source = AnyPath(filepath)
            if not source.exists():
                raise FileNotFoundError(source)
            stored = np.load(source)

        self.table = LSHTable(n_bits, dim, stored)

    def write_hyperplanes(self, filepath: PathLike):
        """Save the table's hyperplanes to ``filepath`` with ``np.save``.

        Note that ``np.save`` appends ``.npy`` to a string/path filename that
        does not already end with it.
        """
        planes = self.table.hyperplanes
        np.save(filepath, planes)

    def __call__(self, array):
        """Return the untokenized bits the table computes for ``array``."""
        bit_rows = self.table(array, tokenize=False)
        return bit_rows
| |
|