# starse/binary_static_embedding.py
# Origin: BorisTM, commit 8ae7fd7 ("Clean StaRSE-512 repository state"), 9.37 kB.
"""Sign-packed StaticEmbedding for sentence-transformers.

Compact storage: every embedding row is represented as a sign bitmask (one bit
per dimension, packed into uint8 bytes) plus a per-row L2 norm. At load time the
module reconstructs a float ``EmbeddingBag`` lookup table identical to what a
trained ``norm * sign(unit) / sqrt(dim)`` projection would produce, so inference
behaves like a regular :class:`StaticEmbedding`.

On disk the model is ~30x smaller than the fp32 form. To use it via
``SentenceTransformer``, pass ``trust_remote_code=True``::

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer("BorisTM/starse-512", trust_remote_code=True)
    embeddings = model.encode(["пример"])
"""
from __future__ import annotations
import math
import os
from pathlib import Path
from typing import Any
try:
from typing import Self
except ImportError:
from typing_extensions import Self
import numpy as np
import torch
from safetensors.torch import load_file as load_safetensors_file
from safetensors.torch import save_file as save_safetensors_file
from tokenizers import Tokenizer
from torch import nn
from transformers import PreTrainedTokenizerFast
from sentence_transformers.base.modules.input_module import InputModule
class BinaryStaticEmbedding(InputModule):
    """1-bit sign + per-row L2 norm StaticEmbedding.

    Each row of the lookup table is persisted as a sign bitmask (one bit per
    dimension, packed big-endian into ``uint8`` bytes) together with a single
    L2 norm. In memory the module holds a frozen float32 ``nn.EmbeddingBag``
    whose rows are ``norm * sign / sqrt(embedding_dim)``.
    """

    modalities: list[str] = ["text"]
    config_keys: list[str] = ["embedding_dim", "vocab_size"]
    config_file_name: str = "binary_static_embedding_config.json"
    weights_file_name: str = "model.safetensors"
    tokenizer_file_name: str = "tokenizer.json"

    def __init__(
        self,
        tokenizer: Tokenizer | PreTrainedTokenizerFast,
        embedding_dim: int,
        vocab_size: int,
        packed_signs: torch.Tensor | np.ndarray | None = None,
        norms: torch.Tensor | np.ndarray | None = None,
        embedding_weights: torch.Tensor | np.ndarray | None = None,
        **kwargs,
    ) -> None:
        """Build the module from float weights, from packed signs + norms, or empty.

        Args:
            tokenizer: A ``tokenizers.Tokenizer`` or a ``PreTrainedTokenizerFast``
                (its underlying fast tokenizer is used). Padding is disabled on it.
            embedding_dim: Number of dimensions per embedding row.
            vocab_size: Number of rows in the lookup table.
            packed_signs: ``[vocab_size, ceil(embedding_dim / 8)]`` packed sign bits.
                Must be given together with ``norms``.
            norms: ``[vocab_size]`` per-row L2 norms. Must be given together with
                ``packed_signs``.
            embedding_weights: ``[vocab_size, embedding_dim]`` float table. Takes
                precedence over ``packed_signs`` / ``norms`` when provided.
            **kwargs: Only ``base_model`` is read (kept for the model card).

        Raises:
            ValueError: If the tokenizer is not a fast tokenizer, if a shape does
                not match ``vocab_size`` / ``embedding_dim``, or if only one of
                ``packed_signs`` and ``norms`` is supplied.
        """
        super().__init__()

        if isinstance(tokenizer, PreTrainedTokenizerFast):
            tokenizer = tokenizer._tokenizer
        elif not isinstance(tokenizer, Tokenizer):
            raise ValueError("tokenizer must be a fast tokenizer (Tokenizer or PreTrainedTokenizerFast)")

        self.tokenizer: Tokenizer = tokenizer
        self.tokenizer.no_padding()

        self.embedding_dim = int(embedding_dim)
        self.vocab_size = int(vocab_size)

        if embedding_weights is not None:
            weight_tensor = _as_float_tensor(embedding_weights)
            if weight_tensor.shape != (self.vocab_size, self.embedding_dim):
                raise ValueError(
                    f"embedding_weights shape {tuple(weight_tensor.shape)} does not match "
                    f"(vocab_size={self.vocab_size}, embedding_dim={self.embedding_dim})"
                )
        elif packed_signs is not None and norms is not None:
            weight_tensor = self._unpack_to_lookup(
                packed_signs=_as_uint8_tensor(packed_signs),
                norms=_as_float_tensor(norms),
                embedding_dim=self.embedding_dim,
            )
            # Same invariant the embedding_weights branch enforces: the table
            # must have exactly vocab_size rows.
            if weight_tensor.shape[0] != self.vocab_size:
                raise ValueError(
                    f"packed_signs has {weight_tensor.shape[0]} rows, which does not match "
                    f"vocab_size={self.vocab_size}"
                )
        elif packed_signs is not None or norms is not None:
            # Half of the packed representation cannot be decoded; failing here
            # avoids silently producing an all-zero table.
            raise ValueError("packed_signs and norms must be provided together")
        else:
            # No weights supplied: start from an all-zero table of the declared size.
            weight_tensor = torch.zeros((self.vocab_size, self.embedding_dim), dtype=torch.float32)

        self.embedding = nn.EmbeddingBag.from_pretrained(weight_tensor, freeze=True)
        self.num_embeddings = self.embedding.num_embeddings

        # For the model card
        self.base_model = kwargs.get("base_model", None)

    # ------------------------------------------------------------------ utils
    @staticmethod
    def _unpack_to_lookup(packed_signs: torch.Tensor, norms: torch.Tensor, embedding_dim: int) -> torch.Tensor:
        """Reconstruct a float ``[vocab, dim]`` lookup from packed sign bits and per-row norms.

        Args:
            packed_signs: ``uint8`` tensor of shape ``[vocab, ceil(embedding_dim / 8)]``,
                big-endian bit order; bit 1 means ``+1``, bit 0 means ``-1``.
            norms: 1-D tensor with one norm per row of ``packed_signs``.
            embedding_dim: Number of sign bits to keep per row.

        Returns:
            A contiguous float32 CPU tensor with rows ``norm * sign / sqrt(embedding_dim)``.

        Raises:
            TypeError: If ``packed_signs`` is not ``uint8``.
            ValueError: If ``packed_signs`` or ``norms`` has an unexpected shape.
        """
        if packed_signs.dtype != torch.uint8:
            raise TypeError(f"packed_signs must be uint8, got {packed_signs.dtype}")
        expected_packed_dim = (embedding_dim + 7) // 8
        if packed_signs.dim() != 2 or packed_signs.shape[1] != expected_packed_dim:
            raise ValueError(
                f"packed_signs shape {tuple(packed_signs.shape)} does not match (vocab, ceil(dim/8)={expected_packed_dim})"
            )
        # Without this check a length-1 ``norms`` would broadcast over every row.
        if norms.dim() != 1 or norms.shape[0] != packed_signs.shape[0]:
            raise ValueError(
                f"norms shape {tuple(norms.shape)} does not match (vocab={packed_signs.shape[0]},)"
            )

        # Drop the padding bits that round each row up to a whole byte.
        bits = np.unpackbits(packed_signs.cpu().numpy(), axis=1, bitorder="big")[:, :embedding_dim]
        signs = bits.astype(np.float32) * 2.0 - 1.0  # 0 -> -1, 1 -> +1
        # Dividing by sqrt(dim) makes each reconstructed row's L2 norm equal its stored norm.
        scale = norms.detach().to(torch.float32).cpu().unsqueeze(1) / math.sqrt(embedding_dim)
        return (torch.from_numpy(signs) * scale).contiguous()

    @staticmethod
    def _pack_from_weight(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Inverse of ``_unpack_to_lookup``: extract packed sign bits and per-row norms.

        Only the sign of each entry and the L2 norm of each row are kept, so the
        conversion is lossy for arbitrary float tables.

        Args:
            weight: ``[vocab, dim]`` tensor.

        Returns:
            ``(packed_signs, norms)``: a ``uint8`` tensor of big-endian packed sign
            bits (entries ``>= 0`` map to bit 1) and a float32 tensor of row norms
            clamped to at least ``1e-12``.
        """
        weight = weight.detach().float().cpu()
        norms = torch.linalg.vector_norm(weight, dim=1).clamp_min(1e-12)
        signs = (weight >= 0).to(torch.uint8).numpy()
        packed = np.packbits(signs, axis=1, bitorder="big")
        return torch.from_numpy(packed), norms

    # ------------------------------------------------------------------ forward
    def preprocess(self, inputs: list[str], prompt: str | None = None, **kwargs) -> dict[str, torch.Tensor]:
        """Tokenize ``inputs`` into the flat ``input_ids`` / ``offsets`` form used by ``EmbeddingBag``.

        Args:
            inputs: Texts to encode.
            prompt: Optional prompt; when truthy it is applied via ``self._prepend_prompt``
                (not defined in this class, presumably inherited from the base module).
            **kwargs: Ignored.

        Returns:
            ``{"input_ids": ..., "offsets": ...}`` where ``input_ids`` is the
            concatenation of all token ids and ``offsets[i]`` is the index at
            which the tokens of text ``i`` start.
        """
        if prompt:
            inputs = self._prepend_prompt(inputs, prompt)
        encodings = self.tokenizer.encode_batch(inputs, add_special_tokens=False)
        encodings_ids = [encoding.ids for encoding in encodings]

        # Each bag starts where the previous ones end. int64 is forced so the
        # offsets dtype matches ``input_ids`` even where NumPy's default
        # integer is 32-bit.
        bag_starts = np.cumsum(
            [0] + [len(token_ids) for token_ids in encodings_ids[:-1]],
            dtype=np.int64,
        )
        offsets = torch.from_numpy(bag_starts)
        input_ids = torch.tensor(
            [token_id for token_ids in encodings_ids for token_id in token_ids],
            dtype=torch.long,
        )
        return {"input_ids": input_ids, "offsets": offsets}

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Pool token embeddings per text and store them under ``"sentence_embedding"``.

        ``features`` is updated in place and returned.
        """
        features["sentence_embedding"] = self.embedding(features["input_ids"], features["offsets"])
        return features

    @property
    def max_seq_length(self) -> int:
        """Return ``math.inf``: this module imposes no length limit.

        NOTE: the returned value is a float, although the annotation says ``int``.
        """
        return math.inf

    def get_embedding_dimension(self) -> int:
        """Return the number of dimensions of the produced embeddings."""
        return self.embedding_dim

    # ------------------------------------------------------------------ persistence
    def save(self, output_path: str, *args, safe_serialization: bool = True, **kwargs) -> None:
        """Write packed weights, config and tokenizer into ``output_path``.

        The directory is created if needed. Weights are always written as
        safetensors; ``safe_serialization`` and the extra arguments are accepted
        but not used.
        """
        output_path = Path(output_path)
        output_path.mkdir(parents=True, exist_ok=True)

        packed_signs, norms = self._pack_from_weight(self.embedding.weight)
        save_safetensors_file(
            {"packed_signs": packed_signs, "norms": norms},
            str(output_path / self.weights_file_name),
        )
        self.save_config(str(output_path))
        self.tokenizer.save(str(output_path / self.tokenizer_file_name))

    def save_config(self, output_path: str) -> None:
        """Write the JSON config (dimensions plus a description of the packing scheme)."""
        import json

        payload = {
            "embedding_dim": self.embedding_dim,
            "vocab_size": self.vocab_size,
            "packed_bit_order": "big",
            "scale": "norm / sqrt(embedding_dim)",
        }
        with open(Path(output_path) / self.config_file_name, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ) -> Self:
        """Load config, tokenizer and packed weights and rebuild the module.

        Files are resolved through ``cls.load_file_path`` with the given hub
        arguments. Extra ``**kwargs`` are ignored.

        Raises:
            FileNotFoundError: If the config, tokenizer or weights file cannot be resolved.
        """
        import json

        hub_kwargs = {
            "subfolder": subfolder,
            "token": token,
            "cache_folder": cache_folder,
            "revision": revision,
            "local_files_only": local_files_only,
        }

        config_path = cls.load_file_path(model_name_or_path, filename=cls.config_file_name, **hub_kwargs)
        if config_path is None:
            raise FileNotFoundError(f"{cls.config_file_name} not found at {model_name_or_path}")
        with open(config_path, "r", encoding="utf-8") as handle:
            config = json.load(handle)

        tokenizer_path = cls.load_file_path(model_name_or_path, filename=cls.tokenizer_file_name, **hub_kwargs)
        # Checked like the other two files so a missing tokenizer is reported
        # clearly instead of failing inside Tokenizer.from_file.
        if tokenizer_path is None:
            raise FileNotFoundError(f"{cls.tokenizer_file_name} not found at {model_name_or_path}")
        tokenizer = Tokenizer.from_file(tokenizer_path)

        weights_path = cls.load_file_path(model_name_or_path, filename=cls.weights_file_name, **hub_kwargs)
        if weights_path is None:
            raise FileNotFoundError(f"{cls.weights_file_name} not found at {model_name_or_path}")
        state = load_safetensors_file(weights_path)

        return cls(
            tokenizer=tokenizer,
            embedding_dim=int(config["embedding_dim"]),
            vocab_size=int(config["vocab_size"]),
            packed_signs=state["packed_signs"],
            norms=state["norms"],
        )
def _as_float_tensor(value: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Return ``value`` as a detached float32 tensor.

    NumPy arrays are wrapped with ``torch.from_numpy`` first; tensors are used
    as given.
    """
    tensor = torch.from_numpy(value) if isinstance(value, np.ndarray) else value
    return tensor.detach().to(torch.float32)
def _as_uint8_tensor(value: torch.Tensor | np.ndarray) -> torch.Tensor:
    """Return ``value`` as a uint8 tensor.

    NumPy arrays are wrapped with ``torch.from_numpy``. A tensor that is
    already uint8 is returned unchanged; any other dtype is cast with
    ``Tensor.to(torch.uint8)``.
    """
    tensor = torch.from_numpy(value) if isinstance(value, np.ndarray) else value
    return tensor if tensor.dtype == torch.uint8 else tensor.to(torch.uint8)