import torch
import numpy as np
from typing import Literal

from sentence_transformers.models import Module


class Quantizer(torch.nn.Module):
    """Base class for quantizers trained with a straight-through estimator (STE).

    Subclasses implement ``_soft_quantize`` (a differentiable surrogate) and
    ``_hard_quantize`` (the true, non-differentiable quantization). In hard
    mode the forward pass returns the hard values while gradients flow through
    the soft surrogate.
    """

    def __init__(self, hard: bool = True):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
        """
        super().__init__()
        self._hard = hard

    def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, x, *args, **kwargs) -> torch.Tensor:
        """Quantize ``x``; hard mode uses an STE so the op stays differentiable."""
        soft = self._soft_quantize(x, *args, **kwargs)
        if not self._hard:
            result = soft
        else:
            # Straight-through estimator: the forward value is the hard
            # quantization, but the gradient is that of the soft surrogate
            # (hard.detach() + soft - soft.detach()).
            result = (
                self._hard_quantize(x, *args, **kwargs).detach()
                + soft
                - soft.detach()
            )
        return result


class Int8TanhQuantizer(Quantizer):
    """Quantizes values into the int8 range [-128, 127] using tanh squashing."""

    def __init__(
        self,
        hard: bool = True,
    ):
        super().__init__(hard=hard)
        self.qmin = -128  # int8 lower bound
        self.qmax = 127  # int8 upper bound

    def _soft_quantize(self, x, *args, **kwargs):
        # tanh squashes into (-1, 1); scaling by qmax happens in the hard path.
        return torch.tanh(x)

    def _hard_quantize(self, x, *args, **kwargs):
        soft = self._soft_quantize(x)
        int_x = torch.round(soft * self.qmax)
        # round(tanh * 127) is already within range, but clamp defensively.
        int_x = torch.clamp(int_x, self.qmin, self.qmax)
        return int_x


class BinaryTanhQuantizer(Quantizer):
    """Quantizes values to {-1.0, +1.0} with a scaled tanh as the soft surrogate."""

    def __init__(
        self,
        hard: bool = True,
        scale: float = 1.0,
    ):
        super().__init__(hard)
        self._scale = scale  # steepness of the tanh surrogate

    def _soft_quantize(self, x, *args, **kwargs):
        return torch.tanh(self._scale * x)

    def _hard_quantize(self, x, *args, **kwargs):
        # Sign threshold at zero; note 0 maps to +1.0.
        return torch.where(x >= 0, 1.0, -1.0)


class PackedBinaryQuantizer:
    """
    Packs binary embeddings into uint8 format for efficient storage.

    This quantizer applies a binary threshold (x >= 0) and packs 8 consecutive
    bits into a single uint8 byte using numpy.packbits. This reduces memory
    usage by 8x compared to float32 and by 4x compared to int8.

    IMPORTANT: This is an inference-only quantizer - it is not differentiable
    and should only be used for encoding/inference, not during training.

    Example:
        >>> quantizer = PackedBinaryQuantizer()
        >>> embeddings = torch.randn(2, 1024)  # float32
        >>> packed = quantizer(embeddings)  # uint8, shape (2, 128)
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor of any float dtype, shape (..., embedding_dim)

        Returns:
            Packed binary tensor of dtype uint8, shape (..., embedding_dim // 8)
        """
        # detach() so this also works on tensors that require grad
        # (.numpy() would otherwise raise RuntimeError).
        bits = x.detach().cpu().numpy() >= 0
        # packbits pads the last axis with zero bits if it is not a multiple of 8.
        packed = np.packbits(bits, axis=-1)
        return torch.from_numpy(packed).to(x.device)


class FlexibleQuantizer(Module):
    """Sentence-transformers module that applies a runtime-selectable quantization."""

    def __init__(self):
        super().__init__()
        self._int8_quantizer = Int8TanhQuantizer()
        self._binary_quantizer = BinaryTanhQuantizer()
        self._packed_binary_quantizer = PackedBinaryQuantizer()

    def forward(
        self,
        features: dict[str, torch.Tensor],
        quantization: Literal["int8", "binary", "ubinary"] = "int8",
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Quantize ``features["sentence_embedding"]`` in place and return ``features``.

        Args:
            features: Feature dict containing a "sentence_embedding" tensor.
            quantization: One of "int8", "binary", or "ubinary" (packed bits).

        Raises:
            ValueError: If ``quantization`` is not a supported type.
        """
        if quantization == "int8":
            features["sentence_embedding"] = self._int8_quantizer(
                features["sentence_embedding"]
            )
        elif quantization == "binary":
            features["sentence_embedding"] = self._binary_quantizer(
                features["sentence_embedding"]
            )
        elif quantization == "ubinary":
            features["sentence_embedding"] = self._packed_binary_quantizer(
                features["sentence_embedding"]
            )
        else:
            raise ValueError(
                f"Invalid quantization type: {quantization}. Must be 'binary', 'ubinary', or 'int8'."
            )
        return features

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        # Stateless module: nothing to load from disk.
        return cls()

    def save(self, output_path: str, *args, **kwargs) -> None:
        # Stateless module: nothing to persist.
        return