from __future__ import annotations

import math
from dataclasses import dataclass

from .corpus import build_cooccurrence_matrix, build_vocabulary, tokenize
from .linalg import Matrix, Vector, mean, np, top_k_eigenpairs_symmetric, zeros

try:
    from scipy import sparse as scipy_sparse
    from scipy.sparse.linalg import svds as scipy_svds
except (ImportError, ModuleNotFoundError, OSError):
    scipy_sparse = None
    scipy_svds = None

# Vocabularies at or above this size skip the dense eigendecomposition and use
# the hashed sketch path instead.
SKETCHED_EMBEDDING_VOCAB_THRESHOLD = 2048


def _remove_common_embedding_axis(
    embeddings: object,
    row_strength: object | None = None,
) -> object:
    """Normalise rows, subtract the shared mean direction, and re-normalise.

    When ``row_strength`` is given, rows are re-weighted by ``log1p`` of their
    co-occurrence mass so tokens with more evidence keep longer vectors.
    """
    if np is None:
        return embeddings
    values = np.asarray(embeddings, dtype=np.float64)
    if values.size == 0 or len(values.shape) != 2:
        return values
    norms = np.linalg.norm(values, axis=1)
    nonzero = norms > 1e-12
    values[nonzero] /= norms[nonzero, None]
    if row_strength is not None:
        strength = np.asarray(row_strength, dtype=np.float64)
        if strength.shape[0] == values.shape[0]:
            values[nonzero] *= np.log1p(strength[nonzero])[:, None]
    common_axis = values.mean(axis=0, keepdims=True)
    values = values - common_axis
    norms = np.linalg.norm(values, axis=1)
    nonzero = norms > 1e-12
    values[nonzero] /= norms[nonzero, None]
    if row_strength is not None:
        strength = np.asarray(row_strength, dtype=np.float64)
        if strength.shape[0] == values.shape[0]:
            values[nonzero] *= np.log1p(strength[nonzero])[:, None]
    return values


def _sketched_sparse_ppmi_embedding(ppmi: object, embedding_dim: int) -> object:
    """Project a sparse PPMI matrix into ``embedding_dim`` hashed buckets."""
    coo = ppmi.tocoo()
    rows = coo.row.astype(np.int64, copy=False)
    cols = coo.col.astype(np.int64, copy=False)
    values = coo.data.astype(np.float64, copy=False)
    embeddings = np.zeros((ppmi.shape[0], embedding_dim), dtype=np.float64)
    if embedding_dim <= 0 or values.size == 0:
        return embeddings
    # Fold each column into a pseudo-random bucket with a pseudo-random sign,
    # so every row becomes a signed histogram of its positive PPMI values.
    buckets = ((cols * 1103515245 + 12345) % embedding_dim).astype(np.int64, copy=False)
    signs = np.where(((cols * 214013 + 2531011) & 1) == 0, 1.0, -1.0)
    np.add.at(embeddings, (rows, buckets), values * signs)
    row_strength = np.sqrt(np.asarray(ppmi.sum(axis=1)).ravel())
    return _remove_common_embedding_axis(embeddings, row_strength)
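# The two integer constants above are fixed LCG multipliers: each PPMI column
# ``c`` is folded into bucket ``h(c)`` with sign ``s(c)``, a count-sketch style
# projection that preserves inner products in expectation while keeping the
# mapping deterministic across runs. A minimal stand-alone sketch of the same
# mapping (``_bucket_and_sign`` is a hypothetical helper, shown only to make
# the arithmetic concrete):
#
#     def _bucket_and_sign(col: int, dim: int) -> tuple[int, float]:
#         bucket = (col * 1103515245 + 12345) % dim
#         sign = 1.0 if ((col * 214013 + 2531011) & 1) == 0 else -1.0
#         return bucket, sign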
def fit_sketched_ppmi_embedding_from_counts(
    id_to_token: list[str],
    rows: dict[int, dict[int, float]],
    *,
    embedding_dim: int,
) -> EmbeddingModel:
    """Build hashed PPMI embeddings directly from co-occurrence counts."""
    if not id_to_token:
        raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
    if embedding_dim <= 0:
        raise ValueError("Embedding dimension must be positive.")
    size = len(id_to_token)
    token_to_id = {token: index for index, token in enumerate(id_to_token)}
    if np is None:
        # Pure-Python fallback: accumulate signed PPMI values bucket by bucket.
        embeddings = zeros(size, embedding_dim)
        row_sums = [0.0 for _ in range(size)]
        for row, columns in rows.items():
            row_sums[row] = sum(columns.values())
        total = sum(row_sums)
        if total <= 0.0:
            return EmbeddingModel(
                token_to_id=token_to_id,
                id_to_token=id_to_token,
                embeddings=embeddings,
                ppmi_matrix=[],
            )
        for row, columns in rows.items():
            for col, count in columns.items():
                denominator = row_sums[row] * row_sums[col]
                if count <= 0.0 or denominator <= 0.0:
                    continue
                value = math.log((count * total) / denominator)
                if value <= 0.0:
                    continue
                bucket = (col * 1103515245 + 12345) % embedding_dim
                sign = 1.0 if ((col * 214013 + 2531011) & 1) == 0 else -1.0
                embeddings[row][bucket] += value * sign
        return EmbeddingModel(
            token_to_id=token_to_id,
            id_to_token=id_to_token,
            embeddings=embeddings,
            ppmi_matrix=[],
        )
    embeddings = np.zeros((size, embedding_dim), dtype=np.float64)
    row_sums = np.zeros(size, dtype=np.float64)
    for row, columns in rows.items():
        row_sums[row] = sum(columns.values())
    total = float(row_sums.sum())
    if total <= 0.0:
        return EmbeddingModel(
            token_to_id=token_to_id,
            id_to_token=id_to_token,
            embeddings=embeddings,
            ppmi_matrix=[],
        )
    for row, columns in rows.items():
        if not columns or row_sums[row] <= 0.0:
            continue
        cols = np.fromiter(columns.keys(), dtype=np.int64)
        counts = np.fromiter(columns.values(), dtype=np.float64)
        denominators = row_sums[row] * row_sums[cols]
        valid = (counts > 0.0) & (denominators > 0.0)
        if not np.any(valid):
            continue
        cols = cols[valid]
        values = np.log((counts[valid] * total) / denominators[valid])
        positive = values > 0.0
        if not np.any(positive):
            continue
        cols = cols[positive]
        values = values[positive]
        buckets = ((cols * 1103515245 + 12345) % embedding_dim).astype(np.int64, copy=False)
        signs = np.where(((cols * 214013 + 2531011) & 1) == 0, 1.0, -1.0)
        np.add.at(embeddings[row], buckets, values * signs)
    embeddings = _remove_common_embedding_axis(embeddings, row_sums)
    return EmbeddingModel(
        token_to_id=token_to_id,
        id_to_token=id_to_token,
        embeddings=embeddings,
        ppmi_matrix=[],
    )


def _positive_ppmi_values(
    *,
    row: int,
    columns: dict[int, float],
    row_sums: object,
    total: float,
) -> tuple[object, object]:
    """Return the column ids and PPMI values of one row that are strictly positive."""
    cols = np.fromiter(columns.keys(), dtype=np.int64)
    counts = np.fromiter(columns.values(), dtype=np.float64)
    if cols.size == 0:
        return cols, counts
    denominators = float(row_sums[row]) * row_sums[cols]
    valid = (counts > 0.0) & (denominators > 0.0)
    if not np.any(valid):
        return cols[:0], counts[:0]
    cols = cols[valid]
    values = np.log((counts[valid] * total) / denominators[valid])
    positive = values > 0.0
    return cols[positive], values[positive]
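# The next routine is a randomized truncated SVD (range-finder style) of the
# implicit PPMI matrix P, which is never materialised. In outline:
#
#     Y = P @ Omega          # Gaussian sketch; Omega: (size, dim + oversampling)
#     Q, _ = qr(Y)           # orthonormal basis for the range of P
#     B = Q.T @ P            # small (width, size) compression of P
#     U_B, S, _ = svd(B)     # exact SVD of the compressed matrix
#     U ~= Q @ U_B           # approximate left singular vectors of P
#     embeddings = U[:, :dim] * sqrt(S[:dim])
#
# Both products with P are accumulated row by row from the cached positive
# PPMI entries, so memory stays proportional to the non-zero count.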
def fit_randomized_ppmi_embedding_from_counts(
    id_to_token: list[str],
    rows: dict[int, dict[int, float]],
    *,
    embedding_dim: int,
    oversampling: int = 32,
) -> EmbeddingModel:
    """Factor the implicit PPMI matrix with a randomized truncated SVD."""
    if np is None:
        return fit_sketched_ppmi_embedding_from_counts(
            id_to_token,
            rows,
            embedding_dim=embedding_dim,
        )
    if not id_to_token:
        raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
    if embedding_dim <= 0:
        raise ValueError("Embedding dimension must be positive.")
    size = len(id_to_token)
    token_to_id = {token: index for index, token in enumerate(id_to_token)}
    row_sums = np.zeros(size, dtype=np.float64)
    for row, columns in rows.items():
        row_sums[row] = sum(columns.values())
    total = float(row_sums.sum())
    if total <= 0.0:
        return EmbeddingModel(
            token_to_id=token_to_id,
            id_to_token=id_to_token,
            embeddings=np.zeros((size, embedding_dim), dtype=np.float64),
            ppmi_matrix=[],
        )
    width = min(size, max(embedding_dim, embedding_dim + oversampling))
    rng = np.random.default_rng(1729 + size * 31 + embedding_dim)
    omega = rng.standard_normal((size, width)).astype(np.float64, copy=False)
    sketch = np.zeros((size, width), dtype=np.float64)
    ppmi_cache: dict[int, tuple[object, object]] = {}
    for row, columns in rows.items():
        if not columns or row_sums[row] <= 0.0:
            continue
        cols, values = _positive_ppmi_values(
            row=row,
            columns=columns,
            row_sums=row_sums,
            total=total,
        )
        if values.size == 0:
            continue
        ppmi_cache[row] = (cols, values)
        # One row of the sketch Y = P @ omega, without materialising P.
        sketch[row] = values @ omega[cols]
    if not ppmi_cache:
        return EmbeddingModel(
            token_to_id=token_to_id,
            id_to_token=id_to_token,
            embeddings=np.zeros((size, embedding_dim), dtype=np.float64),
            ppmi_matrix=[],
        )
    basis, _ = np.linalg.qr(sketch, mode="reduced")
    # Compress: B = Q.T @ P, again touching only the stored non-zero entries.
    compressed = np.zeros((basis.shape[1], size), dtype=np.float64)
    for row, (cols, values) in ppmi_cache.items():
        compressed[:, cols] += basis[row, :, None] * values[None, :]
    left_small, singular_values, _ = np.linalg.svd(compressed, full_matrices=False)
    left = basis @ left_small
    width = min(embedding_dim, left.shape[1], singular_values.shape[0])
    embeddings = np.zeros((size, embedding_dim), dtype=np.float64)
    if width > 0:
        embeddings[:, :width] = left[:, :width] * np.sqrt(np.maximum(singular_values[:width], 0.0))[None, :]
    embeddings = _remove_common_embedding_axis(embeddings, np.sqrt(row_sums))
    return EmbeddingModel(
        token_to_id=token_to_id,
        id_to_token=id_to_token,
        embeddings=embeddings,
        ppmi_matrix=[],
    )


def positive_pointwise_mutual_information(matrix: Matrix) -> Matrix:
    """Compute PPMI for sparse (scipy), dense (numpy), or list-of-list counts."""
    if scipy_sparse is not None and scipy_sparse.issparse(matrix):
        counts = matrix.tocoo()
        if counts.nnz == 0:
            return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
        row_sums = np.asarray(matrix.sum(axis=1)).ravel()
        total = float(row_sums.sum())
        if total == 0.0:
            return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
        denominators = row_sums[counts.row] * row_sums[counts.col]
        valid = (counts.data > 0.0) & (denominators > 0.0)
        if not np.any(valid):
            return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
        ratios = (counts.data[valid] * total) / denominators[valid]
        data = np.maximum(np.log(ratios), 0.0)
        keep = data > 0.0
        if not np.any(keep):
            return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
        return scipy_sparse.coo_matrix(
            (
                data[keep],
                (counts.row[valid][keep], counts.col[valid][keep]),
            ),
            shape=counts.shape,
            dtype=np.float64,
        ).tocsr()
    if not matrix:
        return []
    if np is not None:
        counts = np.asarray(matrix, dtype=np.float64)
        row_sums = counts.sum(axis=1)
        total = float(row_sums.sum())
        if total == 0.0:
            return np.zeros_like(counts).tolist()
        denominator = np.outer(row_sums, row_sums)
        valid = (counts > 0.0) & (denominator > 0.0)
        ppmi = np.zeros_like(counts)
        with np.errstate(divide="ignore", invalid="ignore"):
            ratios = np.divide(
                counts * total,
                denominator,
                out=np.ones_like(counts),
                where=valid,
            )
        ppmi[valid] = np.maximum(np.log(ratios[valid]), 0.0)
        return ppmi.tolist()
    row_sums = [sum(row) for row in matrix]
    total = sum(row_sums)
    if total == 0.0:
        return zeros(len(matrix), len(matrix))
    ppmi = zeros(len(matrix), len(matrix))
    for row in range(len(matrix)):
        for col in range(len(matrix[row])):
            count = matrix[row][col]
            if count <= 0.0 or row_sums[row] == 0.0 or row_sums[col] == 0.0:
                continue
            p_ij = count / total
            p_i = row_sums[row] / total
            p_j = row_sums[col] / total
            value = math.log(p_ij / (p_i * p_j))
            ppmi[row][col] = max(0.0, value)
    return ppmi
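# Worked example of the PPMI formula used throughout this module: with
# counts = [[0.0, 2.0], [2.0, 0.0]], the row sums are [2, 2] and total = 4,
# so for cell (0, 1):
#
#     ratio = (2.0 * 4.0) / (2.0 * 2.0) = 2.0
#     ppmi  = max(0.0, log(2.0)) ~= 0.693
#
# i.e. positive_pointwise_mutual_information([[0.0, 2.0], [2.0, 0.0]]) returns
# approximately [[0.0, 0.693], [0.693, 0.0]].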
@dataclass(slots=True)
class EmbeddingModel:
    token_to_id: dict[str, int]
    id_to_token: list[str]
    embeddings: Matrix
    ppmi_matrix: Matrix

    def vector(self, token: str) -> Vector:
        index = self.token_to_id.get(token)
        if index is None and token.lower() != token:
            index = self.token_to_id.get(token.lower())
        if index is None:
            return [0.0 for _ in range(self.dimension)]
        row = self.embeddings[index]
        return row.astype(float).tolist() if hasattr(row, "tolist") else row[:]

    @property
    def dimension(self) -> int:
        if hasattr(self.embeddings, "shape"):
            return int(self.embeddings.shape[1]) if len(self.embeddings.shape) > 1 else 0
        return len(self.embeddings[0]) if self.embeddings else 0

    @property
    def projection_axis(self) -> Vector:
        if hasattr(self.embeddings, "shape"):
            if int(self.embeddings.shape[0]) == 0:
                return []
            return self.embeddings.mean(axis=0).astype(float).tolist()
        if not self.embeddings:
            return []
        return [
            mean([row[column] for row in self.embeddings])
            for column in range(self.dimension)
        ]


def complete_id_to_token(
    id_to_token: list[str],
    required_tokens: list[str] | tuple[str, ...] | set[str] | None,
) -> list[str]:
    if not required_tokens:
        return id_to_token
    completed = list(id_to_token)
    seen = set(completed)
    for token in required_tokens:
        if token not in seen:
            completed.append(token)
            seen.add(token)
    return completed


def extend_embedding_model_vocabulary(
    model: EmbeddingModel,
    required_tokens: list[str] | tuple[str, ...] | set[str] | None,
) -> EmbeddingModel:
    id_to_token = complete_id_to_token(model.id_to_token, required_tokens)
    missing_count = len(id_to_token) - len(model.id_to_token)
    if missing_count <= 0:
        return model
    dimension = model.dimension
    if np is not None and hasattr(model.embeddings, "shape"):
        existing = np.asarray(model.embeddings, dtype=np.float64)
        missing = np.zeros((missing_count, dimension), dtype=existing.dtype)
        embeddings = np.vstack([existing, missing])
    else:
        embeddings = [
            row.astype(float).tolist() if hasattr(row, "tolist") else list(row)
            for row in model.embeddings
        ]
        embeddings.extend([[0.0 for _ in range(dimension)] for _ in range(missing_count)])
    return EmbeddingModel(
        token_to_id={token: index for index, token in enumerate(id_to_token)},
        id_to_token=id_to_token,
        embeddings=embeddings,
        ppmi_matrix=[],
    )


def fit_ppmi_embedding(
    text: str,
    *,
    embedding_dim: int,
    window_size: int,
    min_frequency: int = 1,
    max_vocab: int | None = None,
) -> EmbeddingModel:
    tokens = tokenize(text)
    if not tokens:
        raise ValueError("Cannot fit REFRAMR embeddings on empty text.")
    return fit_ppmi_embedding_from_tokens(
        tokens,
        embedding_dim=embedding_dim,
        window_size=window_size,
        min_frequency=min_frequency,
        max_vocab=max_vocab,
    )


def fit_ppmi_embedding_from_tokens(
    tokens: list[str],
    *,
    embedding_dim: int,
    window_size: int,
    min_frequency: int = 1,
    max_vocab: int | None = None,
    required_tokens: list[str] | tuple[str, ...] | set[str] | None = None,
) -> EmbeddingModel:
    if not tokens:
        raise ValueError("Cannot fit REFRAMR embeddings on an empty token stream.")
    token_to_id, id_to_token = build_vocabulary(tokens, min_frequency, max_vocab)
    cooccurrence = build_cooccurrence_matrix(tokens, token_to_id, window_size)
    ppmi = positive_pointwise_mutual_information(cooccurrence)
    eigenpairs = top_k_eigenpairs_symmetric(ppmi, embedding_dim)
    embeddings = zeros(len(id_to_token), embedding_dim)
    for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
        scale = math.sqrt(max(eigenvalue, 0.0))
        for row in range(len(id_to_token)):
            embeddings[row][component] = eigenvector[row] * scale
    if np is not None:
        embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
    model = EmbeddingModel(
        token_to_id=token_to_id,
        id_to_token=id_to_token,
        embeddings=embeddings,
        ppmi_matrix=ppmi,
    )
    return extend_embedding_model_vocabulary(model, required_tokens)
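# Usage sketch for the vocabulary-completion path above (the token list is
# made up, and the exact vectors depend on the corpus helpers):
#
#     model = fit_ppmi_embedding_from_tokens(
#         ["red", "green", "red", "blue"],
#         embedding_dim=4,
#         window_size=2,
#         required_tokens={"yellow"},
#     )
#     # "yellow" never co-occurs, so it is appended with an all-zero vector:
#     model.vector("yellow")  # [0.0, 0.0, 0.0, 0.0]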
def fit_ppmi_embedding_from_cooccurrence(
    id_to_token: list[str],
    cooccurrence: Matrix,
    *,
    embedding_dim: int,
) -> EmbeddingModel:
    if not id_to_token:
        raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
    ppmi = positive_pointwise_mutual_information(cooccurrence)
    if scipy_sparse is not None and scipy_sparse.issparse(ppmi):
        embedding_width = min(embedding_dim, len(id_to_token))
        # Large vocabularies and wide embeddings go through the hashed sketch.
        if len(id_to_token) >= SKETCHED_EMBEDDING_VOCAB_THRESHOLD or embedding_width >= 128:
            embeddings = _sketched_sparse_ppmi_embedding(ppmi, embedding_dim)
            return EmbeddingModel(
                token_to_id={token: index for index, token in enumerate(id_to_token)},
                id_to_token=id_to_token,
                embeddings=embeddings,
                ppmi_matrix=[],
            )
        embeddings = zeros(len(id_to_token), embedding_dim)
        if embedding_width <= 0 or ppmi.nnz == 0:
            return EmbeddingModel(
                token_to_id={token: index for index, token in enumerate(id_to_token)},
                id_to_token=id_to_token,
                embeddings=embeddings,
                ppmi_matrix=[],
            )
        if embedding_width < min(ppmi.shape) and scipy_svds is not None:
            left, values, _ = scipy_svds(ppmi.asfptype(), k=embedding_width, which="LM")
            order = np.argsort(values)[::-1]
            for component, source_index in enumerate(order):
                scale = math.sqrt(max(float(values[source_index]), 0.0))
                column = left[:, source_index]
                for row, value in enumerate(column):
                    embeddings[row][component] = float(value) * scale
        else:
            dense = ppmi.toarray().tolist()
            eigenpairs = top_k_eigenpairs_symmetric(dense, embedding_width)
            for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
                scale = math.sqrt(max(eigenvalue, 0.0))
                for row in range(len(id_to_token)):
                    embeddings[row][component] = eigenvector[row] * scale
        if np is not None:
            embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
        return EmbeddingModel(
            token_to_id={token: index for index, token in enumerate(id_to_token)},
            id_to_token=id_to_token,
            embeddings=embeddings,
            ppmi_matrix=[],
        )
    eigenpairs = top_k_eigenpairs_symmetric(ppmi, embedding_dim)
    embeddings = zeros(len(id_to_token), embedding_dim)
    for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
        scale = math.sqrt(max(eigenvalue, 0.0))
        for row in range(len(id_to_token)):
            embeddings[row][component] = eigenvector[row] * scale
    if np is not None:
        embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
    return EmbeddingModel(
        token_to_id={token: index for index, token in enumerate(id_to_token)},
        id_to_token=id_to_token,
        embeddings=embeddings,
        ppmi_matrix=ppmi,
    )
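# A minimal smoke test, illustrative only: the corpus text is made up, and
# because of the relative imports this module must be run as part of its
# package (e.g. ``python -m <package>.<module>``, whatever path this file
# actually lives under).
if __name__ == "__main__":
    demo_model = fit_ppmi_embedding(
        "the cat sat on the mat the dog sat on the rug",
        embedding_dim=4,
        window_size=2,
    )
    print("vocabulary:", demo_model.id_to_token)
    print("dimension:", demo_model.dimension)
    print("vector('cat'):", demo_model.vector("cat"))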