File size: 4,354 Bytes
1820cb3 169a600 1820cb3 e12840f 1820cb3 e12840f 1820cb3 e12840f 1820cb3 e12840f 1820cb3 e12840f 1820cb3 e12840f 1820cb3 169a600 e12840f 1820cb3 e12840f 169a600 1820cb3 e12840f 169a600 e12840f 169a600 e12840f 169a600 e12840f 1820cb3 e12840f 1820cb3 169a600 e12840f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import torch
import numpy as np
from typing import Literal
from sentence_transformers.models import Module
class Quantizer(torch.nn.Module):
    """Base class for differentiable quantization via the straight-through estimator.

    Subclasses provide ``_soft_quantize`` (a differentiable relaxation) and
    ``_hard_quantize`` (the true, non-differentiable discretization). In hard
    mode the forward value equals the hard quantization while gradients flow
    through the soft relaxation.
    """

    def __init__(self, hard: bool = True):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
        """
        super().__init__()
        self._hard = hard

    def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def forward(self, x, *args, **kwargs) -> torch.Tensor:
        soft = self._soft_quantize(x, *args, **kwargs)
        if not self._hard:
            return soft
        # Straight-through estimator: the forward value is the hard quantization,
        # but (soft - soft.detach()) is numerically zero while carrying soft's
        # gradient, so backprop sees only the differentiable path.
        hard = self._hard_quantize(x, *args, **kwargs)
        return hard.detach() + (soft - soft.detach())
class Int8TanhQuantizer(Quantizer):
    """Quantizer mapping embeddings into the signed int8 range via a tanh relaxation."""

    def __init__(
        self,
        hard: bool = True,
    ):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
        """
        super().__init__(hard=hard)
        # Signed 8-bit integer value range.
        self.qmin = -128
        self.qmax = 127

    def _soft_quantize(self, x, *args, **kwargs):
        # Differentiable surrogate squashing inputs into (-1, 1).
        return torch.tanh(x)

    def _hard_quantize(self, x, *args, **kwargs):
        # Scale the tanh output up to the int8 range, round to integers,
        # then clamp to guard the [qmin, qmax] bounds.
        scaled = self._soft_quantize(x) * self.qmax
        return torch.round(scaled).clamp(self.qmin, self.qmax)
class BinaryTanhQuantizer(Quantizer):
    """Quantizer mapping embeddings to {-1.0, +1.0} with a scaled-tanh soft path."""

    def __init__(
        self,
        hard: bool = True,
        scale: float = 1.0,
    ):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
            scale: Multiplier applied to the input before the tanh relaxation.
        """
        super().__init__(hard)
        self._scale = scale

    def _soft_quantize(self, x, *args, **kwargs):
        # Steeper tanh (larger scale) gives a sharper approximation of sign().
        return torch.tanh(x * self._scale)

    def _hard_quantize(self, x, *args, **kwargs):
        # Threshold at zero: non-negative maps to +1.0, negative to -1.0.
        return torch.where(x >= 0, 1.0, -1.0)
class PackedBinaryQuantizer:
    """
    Packs binary embeddings into uint8 format for efficient storage.
    This quantizer applies a binary threshold (x >= 0) and packs 8 consecutive
    bits into a single uint8 byte using numpy.packbits. This reduces memory
    usage by 8x compared to float32 and by 4x compared to int8.
    IMPORTANT: This is an inference-only quantizer - it is not differentiable
    and should only be used for encoding/inference, not during training.
    Args:
        x: Input tensor of any float dtype, shape (..., embedding_dim)
    Returns:
        Packed binary tensor of dtype uint8, shape (..., embedding_dim // 8)
    Example:
        >>> quantizer = PackedBinaryQuantizer()
        >>> embeddings = torch.randn(2, 1024) # float32
        >>> packed = quantizer(embeddings) # uint8, shape (2, 128)
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # detach() so tensors that require grad can be converted to numpy
        # (x.cpu().numpy() raises RuntimeError on grad-tracking tensors);
        # this quantizer is non-differentiable by design, so nothing is lost.
        # The comparison already yields the boolean bit array directly.
        bits = x.detach().cpu().numpy() >= 0
        packed = np.packbits(bits, axis=-1)
        return torch.from_numpy(packed).to(x.device)
class FlexibleQuantizer(Module):
    """Sentence-transformers module that applies a selectable quantization
    scheme to the pooled sentence embedding."""

    def __init__(self):
        super().__init__()
        self._int8_quantizer = Int8TanhQuantizer()
        self._binary_quantizer = BinaryTanhQuantizer()
        self._packed_binary_quantizer = PackedBinaryQuantizer()

    def forward(
        self,
        features: dict[str, torch.Tensor],
        quantization: Literal["int8", "binary", "ubinary"] = "int8",
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Quantize ``features["sentence_embedding"]`` in place and return ``features``.

        Raises:
            ValueError: If ``quantization`` is not one of the supported schemes.
        """
        dispatch = {
            "int8": self._int8_quantizer,
            "binary": self._binary_quantizer,
            "ubinary": self._packed_binary_quantizer,
        }
        # Validate before touching the features dict so an unknown scheme
        # raises without reading or mutating the embedding.
        if quantization not in dispatch:
            raise ValueError(
                f"Invalid quantization type: {quantization}. Must be 'binary', 'ubinary', or 'int8'."
            )
        features["sentence_embedding"] = dispatch[quantization](
            features["sentence_embedding"]
        )
        return features

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        # Stateless module: ignore all location arguments and build fresh.
        return cls()

    def save(self, output_path: str, *args, **kwargs) -> None:
        # Stateless module: nothing to persist on disk.
        return
|