"""Sign-packed StaticEmbedding for sentence-transformers. Compact storage: every embedding row is represented as a sign bitmask (one bit per dimension, packed into uint8 bytes) plus a per-row L2 norm. At load time the module reconstructs a float ``EmbeddingBag`` lookup table identical to what a trained ``norm * sign(unit) / sqrt(dim)`` projection would produce, so inference behaves like a regular :class:`StaticEmbedding`. On disk the model is ~30x smaller than the fp32 form. To use it via ``SentenceTransformer``, pass ``trust_remote_code=True``:: from sentence_transformers import SentenceTransformer model = SentenceTransformer("BorisTM/starse-512", trust_remote_code=True) embeddings = model.encode(["пример"]) """ from __future__ import annotations import math import os from pathlib import Path from typing import Any try: from typing import Self except ImportError: from typing_extensions import Self import numpy as np import torch from safetensors.torch import load_file as load_safetensors_file from safetensors.torch import save_file as save_safetensors_file from tokenizers import Tokenizer from torch import nn from transformers import PreTrainedTokenizerFast from sentence_transformers.base.modules.input_module import InputModule class BinaryStaticEmbedding(InputModule): """1-bit sign + per-row L2 norm StaticEmbedding.""" modalities: list[str] = ["text"] config_keys: list[str] = ["embedding_dim", "vocab_size"] config_file_name: str = "binary_static_embedding_config.json" weights_file_name: str = "model.safetensors" tokenizer_file_name: str = "tokenizer.json" def __init__( self, tokenizer: Tokenizer | PreTrainedTokenizerFast, embedding_dim: int, vocab_size: int, packed_signs: torch.Tensor | np.ndarray | None = None, norms: torch.Tensor | np.ndarray | None = None, embedding_weights: torch.Tensor | np.ndarray | None = None, **kwargs, ) -> None: super().__init__() if isinstance(tokenizer, PreTrainedTokenizerFast): tokenizer = tokenizer._tokenizer elif not isinstance(tokenizer, Tokenizer): raise ValueError("tokenizer must be a fast tokenizer (Tokenizer or PreTrainedTokenizerFast)") self.tokenizer: Tokenizer = tokenizer self.tokenizer.no_padding() self.embedding_dim = int(embedding_dim) self.vocab_size = int(vocab_size) if embedding_weights is not None: weight_tensor = _as_float_tensor(embedding_weights) if weight_tensor.shape != (self.vocab_size, self.embedding_dim): raise ValueError( f"embedding_weights shape {tuple(weight_tensor.shape)} does not match " f"(vocab_size={self.vocab_size}, embedding_dim={self.embedding_dim})" ) elif packed_signs is not None and norms is not None: weight_tensor = self._unpack_to_lookup( packed_signs=_as_uint8_tensor(packed_signs), norms=_as_float_tensor(norms), embedding_dim=self.embedding_dim, ) else: weight_tensor = torch.zeros((self.vocab_size, self.embedding_dim), dtype=torch.float32) self.embedding = nn.EmbeddingBag.from_pretrained(weight_tensor, freeze=True) self.num_embeddings = self.embedding.num_embeddings # For the model card self.base_model = kwargs.get("base_model", None) # ------------------------------------------------------------------ utils @staticmethod def _unpack_to_lookup(packed_signs: torch.Tensor, norms: torch.Tensor, embedding_dim: int) -> torch.Tensor: """Reconstruct a float ``[vocab, dim]`` lookup from packed sign bits and per-row norms.""" if packed_signs.dtype != torch.uint8: raise TypeError(f"packed_signs must be uint8, got {packed_signs.dtype}") expected_packed_dim = (embedding_dim + 7) // 8 if packed_signs.dim() != 2 or packed_signs.shape[1] != expected_packed_dim: raise ValueError( f"packed_signs shape {tuple(packed_signs.shape)} does not match (vocab, ceil(dim/8)={expected_packed_dim})" ) bits = np.unpackbits(packed_signs.cpu().numpy(), axis=1, bitorder="big")[:, :embedding_dim] signs = bits.astype(np.float32) * 2.0 - 1.0 # 0 -> -1, 1 -> +1 scale = norms.detach().to(torch.float32).cpu().unsqueeze(1) / math.sqrt(embedding_dim) return (torch.from_numpy(signs) * scale).contiguous() @staticmethod def _pack_from_weight(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Inverse of ``_unpack_to_lookup``: extract packed sign bits and per-row norms.""" weight = weight.detach().float().cpu() norms = torch.linalg.vector_norm(weight, dim=1).clamp_min(1e-12) signs = (weight >= 0).to(torch.uint8).numpy() packed = np.packbits(signs, axis=1, bitorder="big") return torch.from_numpy(packed), norms # ------------------------------------------------------------------ forward def preprocess(self, inputs: list[str], prompt: str | None = None, **kwargs) -> dict[str, torch.Tensor]: if prompt: inputs = self._prepend_prompt(inputs, prompt) encodings = self.tokenizer.encode_batch(inputs, add_special_tokens=False) encodings_ids = [encoding.ids for encoding in encodings] offsets = torch.from_numpy( np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]]) ) input_ids = torch.tensor( [token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long, ) return {"input_ids": input_ids, "offsets": offsets} def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]: features["sentence_embedding"] = self.embedding(features["input_ids"], features["offsets"]) return features @property def max_seq_length(self) -> int: return math.inf def get_embedding_dimension(self) -> int: return self.embedding_dim # ------------------------------------------------------------------ persistence def save(self, output_path: str, *args, safe_serialization: bool = True, **kwargs) -> None: output_path = Path(output_path) output_path.mkdir(parents=True, exist_ok=True) packed_signs, norms = self._pack_from_weight(self.embedding.weight) save_safetensors_file( {"packed_signs": packed_signs, "norms": norms}, str(output_path / self.weights_file_name), ) self.save_config(str(output_path)) self.tokenizer.save(str(output_path / self.tokenizer_file_name)) def save_config(self, output_path: str) -> None: import json payload = { "embedding_dim": self.embedding_dim, "vocab_size": self.vocab_size, "packed_bit_order": "big", "scale": "norm / sqrt(embedding_dim)", } with open(Path(output_path) / self.config_file_name, "w", encoding="utf-8") as handle: json.dump(payload, handle, ensure_ascii=False, indent=2) @classmethod def load( cls, model_name_or_path: str, subfolder: str = "", token: bool | str | None = None, cache_folder: str | None = None, revision: str | None = None, local_files_only: bool = False, **kwargs, ) -> Self: hub_kwargs = { "subfolder": subfolder, "token": token, "cache_folder": cache_folder, "revision": revision, "local_files_only": local_files_only, } config_path = cls.load_file_path(model_name_or_path, filename=cls.config_file_name, **hub_kwargs) if config_path is None: raise FileNotFoundError(f"{cls.config_file_name} not found at {model_name_or_path}") import json with open(config_path, "r", encoding="utf-8") as handle: config = json.load(handle) tokenizer_path = cls.load_file_path(model_name_or_path, filename=cls.tokenizer_file_name, **hub_kwargs) tokenizer = Tokenizer.from_file(tokenizer_path) weights_path = cls.load_file_path(model_name_or_path, filename=cls.weights_file_name, **hub_kwargs) if weights_path is None: raise FileNotFoundError(f"{cls.weights_file_name} not found at {model_name_or_path}") state = load_safetensors_file(weights_path) packed_signs = state["packed_signs"] norms = state["norms"] return cls( tokenizer=tokenizer, embedding_dim=int(config["embedding_dim"]), vocab_size=int(config["vocab_size"]), packed_signs=packed_signs, norms=norms, ) def _as_float_tensor(value: torch.Tensor | np.ndarray) -> torch.Tensor: if isinstance(value, np.ndarray): value = torch.from_numpy(value) return value.detach().to(torch.float32) def _as_uint8_tensor(value: torch.Tensor | np.ndarray) -> torch.Tensor: if isinstance(value, np.ndarray): value = torch.from_numpy(value) if value.dtype != torch.uint8: value = value.to(torch.uint8) return value