# pplx-embed-context-v1-0.6b / st_quantize.py
# History: "support ubinary (#5)" by seslami-pplx, commit d1bcaee
import torch
import numpy as np
from typing import Literal
from sentence_transformers.models import Module
class Quantizer(torch.nn.Module):
    """Base class for differentiable quantizers with a straight-through estimator.

    Subclasses provide ``_soft_quantize`` (a differentiable surrogate) and
    ``_hard_quantize`` (the actual discretization).

    Args:
        hard: If True, the forward pass returns hard-quantized values while
            gradients flow through the soft surrogate. If False, only the
            soft quantization is applied. Defaults to True.
    """

    def __init__(self, hard: bool = True):
        super().__init__()
        self._hard = hard

    def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        # Discretization step; must be implemented by subclasses.
        raise NotImplementedError

    def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        # Differentiable surrogate; must be implemented by subclasses.
        raise NotImplementedError

    def forward(self, x, *args, **kwargs) -> torch.Tensor:
        soft = self._soft_quantize(x, *args, **kwargs)
        if not self._hard:
            return soft
        # Straight-through estimator: the returned values come from the hard
        # quantizer, but the gradient comes from the soft surrogate because
        # the hard term and the subtracted soft term are both detached.
        hard = self._hard_quantize(x, *args, **kwargs)
        return hard.detach() + soft - soft.detach()
class Int8TanhQuantizer(Quantizer):
    """Quantizer targeting the int8 value range via a tanh squashing.

    The soft path is ``tanh(x)``; the hard path rounds ``tanh(x) * 127`` and
    clamps the result to ``[-128, 127]``.

    Args:
        hard: Whether to use hard or soft quantization. Defaults to True.
    """

    def __init__(self, hard: bool = True):
        super().__init__(hard=hard)
        # int8 value range used for rounding/clamping in the hard path.
        self.qmin = -128
        self.qmax = 127

    def _soft_quantize(self, x, *args, **kwargs):
        return x.tanh()

    def _hard_quantize(self, x, *args, **kwargs):
        scaled = self._soft_quantize(x) * self.qmax
        return scaled.round().clamp(self.qmin, self.qmax)
class BinaryTanhQuantizer(Quantizer):
    """Quantizer producing {-1.0, +1.0} values with a scaled-tanh surrogate.

    Args:
        hard: Whether to use hard or soft quantization. Defaults to True.
        scale: Multiplier applied to the input before the tanh in the soft
            path. Defaults to 1.0.
    """

    def __init__(self, hard: bool = True, scale: float = 1.0):
        super().__init__(hard)
        self._scale = scale

    def _soft_quantize(self, x, *args, **kwargs):
        return (self._scale * x).tanh()

    def _hard_quantize(self, x, *args, **kwargs):
        # Sign threshold at zero; zero itself maps to +1.0.
        return torch.where(x >= 0, 1.0, -1.0)
class PackedBinaryQuantizer:
    """
    Packs binary embeddings into uint8 format for efficient storage.

    This quantizer applies a binary threshold (x >= 0) and packs 8 consecutive
    bits into a single uint8 byte using numpy.packbits. This reduces memory
    usage by 8x compared to float32 and by 4x compared to int8.

    IMPORTANT: This is an inference-only quantizer - it is not differentiable
    and should only be used for encoding/inference, not during training.

    NOTE: numpy.packbits zero-pads the last axis when the embedding dimension
    is not a multiple of 8, so the output width is ceil(dim / 8).

    Args:
        x: Input tensor of any float dtype, shape (..., embedding_dim)
    Returns:
        Packed binary tensor of dtype uint8, shape (..., embedding_dim // 8)
    Example:
        >>> quantizer = PackedBinaryQuantizer()
        >>> embeddings = torch.randn(2, 1024)  # float32
        >>> packed = quantizer(embeddings)  # uint8, shape (2, 128)
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Threshold in torch first: the comparison yields a grad-free bool
        # tensor, so .cpu().numpy() is safe even when ``x`` requires grad
        # (calling .numpy() directly on a grad-attached tensor raises).
        bits = (x >= 0).cpu().numpy()
        packed = np.packbits(bits, axis=-1)
        return torch.from_numpy(packed).to(x.device)
class FlexibleQuantizer(Module):
    """Sentence-Transformers module that applies one of several quantization
    schemes to the pooled sentence embedding.

    The module is stateless: it holds one instance of each supported
    quantizer and selects among them per forward call.
    """

    def __init__(self):
        super().__init__()
        self._int8_quantizer = Int8TanhQuantizer()
        self._binary_quantizer = BinaryTanhQuantizer()
        self._packed_binary_quantizer = PackedBinaryQuantizer()

    def forward(
        self,
        features: dict[str, torch.Tensor],
        quantization: Literal["int8", "binary", "ubinary"] = "int8",
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Quantize ``features["sentence_embedding"]`` in place and return ``features``.

        Raises:
            ValueError: If ``quantization`` is not one of the supported schemes.
        """
        quantizers = {
            "int8": self._int8_quantizer,
            "binary": self._binary_quantizer,
            "ubinary": self._packed_binary_quantizer,
        }
        if quantization not in quantizers:
            raise ValueError(
                f"Invalid quantization type: {quantization}. Must be 'binary', 'ubinary', or 'int8'."
            )
        features["sentence_embedding"] = quantizers[quantization](
            features["sentence_embedding"]
        )
        return features

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        # Stateless module: loading ignores the checkpoint location and
        # simply constructs a fresh instance.
        return cls()

    def save(self, output_path: str, *args, **kwargs) -> None:
        # Stateless module: nothing to persist.
        return