# reframr/embeddings.py (Reframr-RFM-v1-Base public checkpoint, commit 2147ce8)
from __future__ import annotations
import math
from dataclasses import dataclass
from .corpus import build_cooccurrence_matrix, build_vocabulary, tokenize
from .linalg import Matrix, Vector, mean, np, top_k_eigenpairs_symmetric, zeros
# SciPy is optional: when it is missing, the sparse code paths below fall back
# to dense NumPy or pure-Python implementations.
try:
    from scipy import sparse as scipy_sparse
    from scipy.sparse.linalg import svds as scipy_svds
except (ImportError, ModuleNotFoundError, OSError):
    scipy_sparse = None
    scipy_svds = None
# Vocabularies at or above this size use the hashed count-sketch embedding
# path instead of an exact sparse SVD.
SKETCHED_EMBEDDING_VOCAB_THRESHOLD = 2048
def _remove_common_embedding_axis(embeddings: object, row_strength: object | None = None) -> object:
    """L2-normalise rows, subtract their shared mean direction, and renormalise.

    Removing the mean embedding (the "common axis") is the mean-removal step of
    all-but-the-top style post-processing: it stops a single dominant direction
    from swamping cosine similarities. When ``row_strength`` is given, rows are
    reweighted by ``log1p(strength)`` both before the mean is taken and again
    after the axis is removed, so high-count tokens keep slightly longer vectors.
    """
    if np is None:
        return embeddings
    # Copy so the caller's array is never mutated by the in-place operations below.
    values = np.array(embeddings, dtype=np.float64)
    if values.size == 0 or len(values.shape) != 2:
        return values
    norms = np.linalg.norm(values, axis=1)
    nonzero = norms > 1e-12
    values[nonzero] /= norms[nonzero, None]
    if row_strength is not None:
        strength = np.asarray(row_strength, dtype=np.float64)
        if strength.shape[0] == values.shape[0]:
            values[nonzero] *= np.log1p(strength[nonzero])[:, None]
    # Subtract the mean row: the direction shared by (almost) every vector.
    common_axis = values.mean(axis=0, keepdims=True)
    values = values - common_axis
    # Renormalise and reweight now that the common axis is gone.
    norms = np.linalg.norm(values, axis=1)
    nonzero = norms > 1e-12
    values[nonzero] /= norms[nonzero, None]
    if row_strength is not None:
        strength = np.asarray(row_strength, dtype=np.float64)
        if strength.shape[0] == values.shape[0]:
            values[nonzero] *= np.log1p(strength[nonzero])[:, None]
    return values
def _sketched_sparse_ppmi_embedding(ppmi: object, embedding_dim: int) -> object:
    """Project a sparse PPMI matrix into ``embedding_dim`` columns with a count-sketch.

    Each column id is hashed to a bucket and a pseudo-random sign, and the PPMI
    values are scatter-added into the buckets. Signed hashing approximately
    preserves inner products between rows without materialising a dense SVD.
    """
    coo = ppmi.tocoo()
    rows = coo.row.astype(np.int64, copy=False)
    cols = coo.col.astype(np.int64, copy=False)
    values = coo.data.astype(np.float64, copy=False)
    embeddings = np.zeros((ppmi.shape[0], embedding_dim), dtype=np.float64)
    if embedding_dim <= 0 or values.size == 0:
        return embeddings
    # Deterministic bucket hash (glibc LCG constants) and sign hash (MSVC LCG constants).
    buckets = ((cols * 1103515245 + 12345) % embedding_dim).astype(np.int64, copy=False)
    signs = np.where(((cols * 214013 + 2531011) & 1) == 0, 1.0, -1.0)
    np.add.at(embeddings, (rows, buckets), values * signs)
    # Row strength: square root of each row's total PPMI mass, used as a frequency weight.
    row_strength = np.sqrt(np.asarray(ppmi.sum(axis=1)).ravel())
    return _remove_common_embedding_axis(embeddings, row_strength)
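# A small illustration (hypothetical helper, not used anywhere in the library)
# of the signed bucket hash used above: a column id maps deterministically to a
# bucket and a sign, which is what makes the count-sketch reproducible across
# fits without storing a projection matrix.
def _example_signed_bucket(col: int, embedding_dim: int) -> tuple[int, float]:
    bucket = (col * 1103515245 + 12345) % embedding_dim
    sign = 1.0 if ((col * 214013 + 2531011) & 1) == 0 else -1.0
    return bucket, sign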
def fit_sketched_ppmi_embedding_from_counts(
    id_to_token: list[str],
    rows: dict[int, dict[int, float]],
    *,
    embedding_dim: int,
) -> EmbeddingModel:
    """Fit hashed PPMI embeddings straight from sparse co-occurrence counts.

    ``rows`` maps a row id to a ``{column id: count}`` dict. PPMI values are
    computed on the fly and scatter-added into ``embedding_dim`` buckets with
    the same signed hash as ``_sketched_sparse_ppmi_embedding``, so the full
    PPMI matrix is never materialised. A pure-Python loop handles the case
    where NumPy is unavailable.
    """
    if not id_to_token:
        raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
    if embedding_dim <= 0:
        raise ValueError("Embedding dimension must be positive.")
size = len(id_to_token)
token_to_id = {token: index for index, token in enumerate(id_to_token)}
if np is None:
embeddings = zeros(size, embedding_dim)
row_sums = [0.0 for _ in range(size)]
for row, columns in rows.items():
row_sums[row] = sum(columns.values())
total = sum(row_sums)
if total <= 0.0:
return EmbeddingModel(token_to_id=token_to_id, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[])
        for row, columns in rows.items():
            for col, count in columns.items():
                denominator = row_sums[row] * row_sums[col]
                if count <= 0.0 or denominator <= 0.0:
                    continue
                # PMI = log(count * total / (row_sum_i * row_sum_j)); keep only positive values.
                value = math.log((count * total) / denominator)
                if value <= 0.0:
                    continue
                # Same signed bucket hash as the NumPy path below.
                bucket = (col * 1103515245 + 12345) % embedding_dim
                sign = 1.0 if ((col * 214013 + 2531011) & 1) == 0 else -1.0
                embeddings[row][bucket] += value * sign
return EmbeddingModel(token_to_id=token_to_id, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[])
embeddings = np.zeros((size, embedding_dim), dtype=np.float64)
row_sums = np.zeros(size, dtype=np.float64)
for row, columns in rows.items():
row_sums[row] = sum(columns.values())
total = float(row_sums.sum())
if total <= 0.0:
return EmbeddingModel(token_to_id=token_to_id, id_to_token=id_to_token, embeddings=embeddings, ppmi_matrix=[])
for row, columns in rows.items():
if not columns or row_sums[row] <= 0.0:
continue
cols = np.fromiter(columns.keys(), dtype=np.int64)
counts = np.fromiter(columns.values(), dtype=np.float64)
denominators = row_sums[row] * row_sums[cols]
valid = (counts > 0.0) & (denominators > 0.0)
if not np.any(valid):
continue
cols = cols[valid]
        # Vectorised PMI for this row; only positive entries contribute.
        values = np.log((counts[valid] * total) / denominators[valid])
        positive = values > 0.0
if not np.any(positive):
continue
cols = cols[positive]
values = values[positive]
        buckets = ((cols * 1103515245 + 12345) % embedding_dim).astype(np.int64, copy=False)
        signs = np.where(((cols * 214013 + 2531011) & 1) == 0, 1.0, -1.0)
        # `np.add.at` is an unbuffered scatter-add, so repeated buckets accumulate correctly.
        np.add.at(embeddings[row], buckets, values * signs)
embeddings = _remove_common_embedding_axis(embeddings, row_sums)
return EmbeddingModel(
token_to_id=token_to_id,
id_to_token=id_to_token,
embeddings=embeddings,
ppmi_matrix=[],
)
def _positive_ppmi_values(
    *,
    row: int,
    columns: dict[int, float],
    row_sums: object,
    total: float,
) -> tuple[object, object]:
    """Return (column ids, PPMI values) for one sparse row, dropping non-positive entries."""
cols = np.fromiter(columns.keys(), dtype=np.int64)
counts = np.fromiter(columns.values(), dtype=np.float64)
if cols.size == 0:
return cols, counts
denominators = float(row_sums[row]) * row_sums[cols]
valid = (counts > 0.0) & (denominators > 0.0)
if not np.any(valid):
return cols[:0], counts[:0]
cols = cols[valid]
values = np.log((counts[valid] * total) / denominators[valid])
positive = values > 0.0
return cols[positive], values[positive]
def fit_randomized_ppmi_embedding_from_counts(
    id_to_token: list[str],
    rows: dict[int, dict[int, float]],
    *,
    embedding_dim: int,
    oversampling: int = 32,
) -> EmbeddingModel:
    """Fit PPMI embeddings with a Halko-style randomized SVD over sparse rows.

    The PPMI rows are sketched against a Gaussian test matrix, a QR pass turns
    the sketch into an orthonormal basis, the matrix is compressed into that
    basis, and an exact SVD of the small compressed matrix gives the factors.
    ``oversampling`` extra sketch columns improve how well the basis captures
    the top singular subspace. Falls back to the hashed count-sketch fit when
    NumPy is unavailable.
    """
if np is None:
return fit_sketched_ppmi_embedding_from_counts(
id_to_token,
rows,
embedding_dim=embedding_dim,
)
if not id_to_token:
raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
if embedding_dim <= 0:
raise ValueError("Embedding dimension must be positive.")
size = len(id_to_token)
token_to_id = {token: index for index, token in enumerate(id_to_token)}
row_sums = np.zeros(size, dtype=np.float64)
for row, columns in rows.items():
row_sums[row] = sum(columns.values())
total = float(row_sums.sum())
if total <= 0.0:
return EmbeddingModel(
token_to_id=token_to_id,
id_to_token=id_to_token,
embeddings=np.zeros((size, embedding_dim), dtype=np.float64),
ppmi_matrix=[],
)
    # Sketch width: the target dimension plus oversampling, capped at the vocabulary size.
    width = min(size, embedding_dim + max(oversampling, 0))
    # Fixed seed derived from the problem size, so refits are reproducible.
    rng = np.random.default_rng(1729 + size * 31 + embedding_dim)
    omega = rng.standard_normal((size, width)).astype(np.float64, copy=False)
    sketch = np.zeros((size, width), dtype=np.float64)
ppmi_cache: dict[int, tuple[object, object]] = {}
for row, columns in rows.items():
if not columns or row_sums[row] <= 0.0:
continue
cols, values = _positive_ppmi_values(
row=row,
columns=columns,
row_sums=row_sums,
total=total,
)
if values.size == 0:
continue
        ppmi_cache[row] = (cols, values)
        # One row of (PPMI @ omega) without materialising the dense row.
        sketch[row] = values @ omega[cols]
if not ppmi_cache:
return EmbeddingModel(
token_to_id=token_to_id,
id_to_token=id_to_token,
embeddings=np.zeros((size, embedding_dim), dtype=np.float64),
ppmi_matrix=[],
)
    # Orthonormal basis for the range of the sketch.
    basis, _ = np.linalg.qr(sketch, mode="reduced")
    # Compress the PPMI matrix into the basis: compressed = basis.T @ PPMI.
    compressed = np.zeros((basis.shape[1], size), dtype=np.float64)
    for row, (cols, values) in ppmi_cache.items():
        compressed[:, cols] += basis[row, :, None] * values[None, :]
    # Exact SVD of the small compressed matrix, lifted back through the basis.
    left_small, singular_values, _ = np.linalg.svd(compressed, full_matrices=False)
    left = basis @ left_small
    width = min(embedding_dim, left.shape[1], singular_values.shape[0])
    embeddings = np.zeros((size, embedding_dim), dtype=np.float64)
    if width > 0:
        # Scale singular vectors by sqrt(singular value): the symmetric split of the SVD.
        embeddings[:, :width] = left[:, :width] * np.sqrt(np.maximum(singular_values[:width], 0.0))[None, :]
embeddings = _remove_common_embedding_axis(embeddings, np.sqrt(row_sums))
return EmbeddingModel(
token_to_id=token_to_id,
id_to_token=id_to_token,
embeddings=embeddings,
ppmi_matrix=[],
)
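# A dense reference sketch (illustration only, assumes NumPy is available) of
# the same randomized-SVD recipe used above, written against a plain matrix so
# the three steps (sketch, QR, small SVD) line up with the sparse row-wise
# code; `_example_randomized_svd_dense` is a hypothetical helper.
def _example_randomized_svd_dense(matrix: object, rank: int, oversampling: int = 8) -> object:
    rng = np.random.default_rng(0)
    # Gaussian test matrix, then an orthonormal basis for the sampled range.
    omega = rng.standard_normal((matrix.shape[1], min(matrix.shape[1], rank + oversampling)))
    basis, _ = np.linalg.qr(matrix @ omega, mode="reduced")
    # Exact SVD of the compressed matrix, lifted back through the basis.
    left_small, singular_values, _ = np.linalg.svd(basis.T @ matrix, full_matrices=False)
    keep = min(rank, left_small.shape[1])
    left = (basis @ left_small)[:, :keep]
    return left * np.sqrt(np.maximum(singular_values[:keep], 0.0))[None, :]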
def positive_pointwise_mutual_information(matrix: Matrix) -> Matrix:
    """Map co-occurrence counts to PPMI: ``max(0, log(p(i, j) / (p(i) * p(j))))``.

    Marginals come from row sums, so the input is assumed to be a symmetric
    co-occurrence matrix. Accepts a SciPy sparse matrix (returns CSR) or a
    dense nested-list matrix (returns nested lists), with or without NumPy.
    """
if scipy_sparse is not None and scipy_sparse.issparse(matrix):
counts = matrix.tocoo()
if counts.nnz == 0:
return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
row_sums = np.asarray(matrix.sum(axis=1)).ravel()
total = float(row_sums.sum())
if total == 0.0:
return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
denominators = row_sums[counts.row] * row_sums[counts.col]
valid = (counts.data > 0.0) & (denominators > 0.0)
if not np.any(valid):
return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
ratios = (counts.data[valid] * total) / denominators[valid]
data = np.maximum(np.log(ratios), 0.0)
keep = data > 0.0
if not np.any(keep):
return scipy_sparse.csr_matrix(counts.shape, dtype=np.float64)
return scipy_sparse.coo_matrix(
(
data[keep],
(counts.row[valid][keep], counts.col[valid][keep]),
),
shape=counts.shape,
dtype=np.float64,
).tocsr()
if not matrix:
return []
if np is not None:
counts = np.asarray(matrix, dtype=np.float64)
row_sums = counts.sum(axis=1)
total = float(row_sums.sum())
if total == 0.0:
return np.zeros_like(counts).tolist()
denominator = np.outer(row_sums, row_sums)
valid = (counts > 0.0) & (denominator > 0.0)
ppmi = np.zeros_like(counts)
with np.errstate(divide="ignore", invalid="ignore"):
ratios = np.divide(
counts * total,
denominator,
out=np.ones_like(counts),
where=valid,
)
ppmi[valid] = np.maximum(np.log(ratios[valid]), 0.0)
return ppmi.tolist()
row_sums = [sum(row) for row in matrix]
total = sum(row_sums)
if total == 0.0:
return zeros(len(matrix), len(matrix))
ppmi = zeros(len(matrix), len(matrix))
for row in range(len(matrix)):
for col in range(len(matrix[row])):
count = matrix[row][col]
if count <= 0.0 or row_sums[row] == 0.0 or row_sums[col] == 0.0:
continue
p_ij = count / total
p_i = row_sums[row] / total
p_j = row_sums[col] / total
value = math.log(p_ij / (p_i * p_j))
ppmi[row][col] = max(0.0, value)
return ppmi
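# A minimal worked example (hypothetical helper, illustration only) of the PPMI
# transform above on a tiny symmetric count matrix, checked against the formula
# by hand.
def _example_ppmi_by_hand() -> None:
    counts = [
        [0.0, 2.0, 1.0],
        [2.0, 0.0, 1.0],
        [1.0, 1.0, 0.0],
    ]
    ppmi = positive_pointwise_mutual_information(counts)
    # Entry (0, 1): total = 8, row sums = (3, 3, 2), so
    # PPMI = max(0, log(2 * 8 / (3 * 3))) = log(16 / 9).
    assert abs(ppmi[0][1] - math.log(16.0 / 9.0)) < 1e-9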
@dataclass(slots=True)
class EmbeddingModel:
    """Fitted token embeddings together with the vocabulary that indexes them."""
    token_to_id: dict[str, int]
    id_to_token: list[str]
    embeddings: Matrix
    ppmi_matrix: Matrix
    def vector(self, token: str) -> Vector:
        """Return the embedding for ``token``, falling back to a lower-cased lookup, else zeros."""
        index = self.token_to_id.get(token)
        if index is None and token.lower() != token:
            index = self.token_to_id.get(token.lower())
        if index is None:
            return [0.0 for _ in range(self.dimension)]
        row = self.embeddings[index]
        return row.astype(float).tolist() if hasattr(row, "tolist") else row[:]
    @property
    def dimension(self) -> int:
        """Number of embedding columns (0 for an empty model)."""
        if hasattr(self.embeddings, "shape"):
            return int(self.embeddings.shape[1]) if len(self.embeddings.shape) > 1 else 0
        return len(self.embeddings[0]) if self.embeddings else 0
    @property
    def projection_axis(self) -> Vector:
        """Mean embedding across the whole vocabulary."""
        if hasattr(self.embeddings, "shape"):
if int(self.embeddings.shape[0]) == 0:
return []
return self.embeddings.mean(axis=0).astype(float).tolist()
if not self.embeddings:
return []
return [
mean([row[column] for row in self.embeddings])
for column in range(self.dimension)
]
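# A minimal lookup sketch (hypothetical helper with toy values, illustration
# only) of the dataclass above: unknown tokens fall back to a lower-cased
# lookup and then to a zero vector.
def _example_lookup() -> None:
    model = EmbeddingModel(
        token_to_id={"cat": 0, "dog": 1},
        id_to_token=["cat", "dog"],
        embeddings=[[1.0, 0.0], [0.0, 1.0]],
        ppmi_matrix=[],
    )
    assert model.vector("Cat") == [1.0, 0.0]   # case-folded hit
    assert model.vector("fish") == [0.0, 0.0]  # unknown token -> zeros
    assert model.dimension == 2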
def fit_ppmi_embedding(
    text: str,
    *,
    embedding_dim: int,
    window_size: int,
    min_frequency: int = 1,
    max_vocab: int | None = None,
) -> EmbeddingModel:
    """Tokenize ``text`` and fit PPMI embeddings; see ``fit_ppmi_embedding_from_tokens``."""
tokens = tokenize(text)
if not tokens:
raise ValueError("Cannot fit REFRAMR embeddings on empty text.")
return fit_ppmi_embedding_from_tokens(
tokens,
embedding_dim=embedding_dim,
window_size=window_size,
min_frequency=min_frequency,
max_vocab=max_vocab,
)
def fit_ppmi_embedding_from_tokens(
    tokens: list[str],
    *,
    embedding_dim: int,
    window_size: int,
    min_frequency: int = 1,
    max_vocab: int | None = None,
) -> EmbeddingModel:
    """Fit embeddings from tokens: vocabulary, windowed co-occurrence, PPMI, truncated eigendecomposition."""
if not tokens:
raise ValueError("Cannot fit REFRAMR embeddings on an empty token stream.")
token_to_id, id_to_token = build_vocabulary(tokens, min_frequency, max_vocab)
cooccurrence = build_cooccurrence_matrix(tokens, token_to_id, window_size)
ppmi = positive_pointwise_mutual_information(cooccurrence)
    # Each kept eigenpair of the symmetric PPMI matrix contributes one embedding
    # column, scaled by sqrt(eigenvalue).
    eigenpairs = top_k_eigenpairs_symmetric(ppmi, embedding_dim)
    embeddings = zeros(len(id_to_token), embedding_dim)
    for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
        scale = math.sqrt(max(eigenvalue, 0.0))
        for row in range(len(id_to_token)):
            embeddings[row][component] = eigenvector[row] * scale
if np is not None:
embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
return EmbeddingModel(
token_to_id=token_to_id,
id_to_token=id_to_token,
embeddings=embeddings,
ppmi_matrix=ppmi,
)
def fit_ppmi_embedding_from_cooccurrence(
    id_to_token: list[str],
    cooccurrence: Matrix,
    *,
    embedding_dim: int,
) -> EmbeddingModel:
    """Fit embeddings from a prebuilt co-occurrence matrix.

    Sparse inputs pick a factorisation by size: the hashed count-sketch for
    large vocabularies or wide embeddings, a truncated sparse SVD when one is
    available, and a dense eigendecomposition otherwise. Dense inputs always
    use the dense eigendecomposition.
    """
    if not id_to_token:
        raise ValueError("Cannot fit REFRAMR embeddings without a vocabulary.")
    ppmi = positive_pointwise_mutual_information(cooccurrence)
    if scipy_sparse is not None and scipy_sparse.issparse(ppmi):
        embedding_width = min(embedding_dim, len(id_to_token))
        # Large vocabularies or wide target widths route to the hashed sketch instead of an SVD.
        if len(id_to_token) >= SKETCHED_EMBEDDING_VOCAB_THRESHOLD or embedding_width >= 128:
embeddings = _sketched_sparse_ppmi_embedding(ppmi, embedding_dim)
return EmbeddingModel(
token_to_id={token: index for index, token in enumerate(id_to_token)},
id_to_token=id_to_token,
embeddings=embeddings,
ppmi_matrix=[],
)
embeddings = zeros(len(id_to_token), embedding_dim)
if embedding_width <= 0 or ppmi.nnz == 0:
return EmbeddingModel(
token_to_id={token: index for index, token in enumerate(id_to_token)},
id_to_token=id_to_token,
embeddings=embeddings,
ppmi_matrix=[],
)
        if embedding_width < min(ppmi.shape) and scipy_svds is not None:
            # Truncated sparse SVD; order components by decreasing singular value.
            left, values, _ = scipy_svds(ppmi.asfptype(), k=embedding_width, which="LM")
            order = np.argsort(values)[::-1]
for component, source_index in enumerate(order):
scale = math.sqrt(max(float(values[source_index]), 0.0))
column = left[:, source_index]
for row, value in enumerate(column):
embeddings[row][component] = float(value) * scale
        else:
            # No truncated sparse SVD available (or full width requested): densify
            # and fall back to the symmetric eigendecomposition.
            dense = ppmi.toarray().tolist()
            eigenpairs = top_k_eigenpairs_symmetric(dense, embedding_width)
for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
scale = math.sqrt(max(eigenvalue, 0.0))
for row in range(len(id_to_token)):
embeddings[row][component] = eigenvector[row] * scale
if np is not None:
embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
return EmbeddingModel(
token_to_id={token: index for index, token in enumerate(id_to_token)},
id_to_token=id_to_token,
embeddings=embeddings,
ppmi_matrix=[],
)
    # Dense (non-SciPy) path: symmetric eigendecomposition of the full PPMI matrix.
    eigenpairs = top_k_eigenpairs_symmetric(ppmi, embedding_dim)
embeddings = zeros(len(id_to_token), embedding_dim)
for component, (eigenvalue, eigenvector) in enumerate(eigenpairs):
scale = math.sqrt(max(eigenvalue, 0.0))
for row in range(len(id_to_token)):
embeddings[row][component] = eigenvector[row] * scale
if np is not None:
embeddings = _remove_common_embedding_axis(np.asarray(embeddings, dtype=np.float64))
return EmbeddingModel(
token_to_id={token: index for index, token in enumerate(id_to_token)},
id_to_token=id_to_token,
embeddings=embeddings,
ppmi_matrix=ppmi,
)
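# A minimal end-to-end usage sketch (illustration only; assumes the toy corpus
# below survives `tokenize` and `build_vocabulary` with a non-empty vocabulary).
# It runs only when the module is executed directly, never on import.
if __name__ == "__main__":
    demo_model = fit_ppmi_embedding(
        "cats chase mice and dogs chase cats and mice hide",
        embedding_dim=4,
        window_size=2,
    )
    print(f"vocabulary={len(demo_model.id_to_token)} dimension={demo_model.dimension}")
    print("vector('cats') =", demo_model.vector("cats"))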