File size: 3,547 Bytes
4251f22 ffd62d4 4251f22 39de4f2 4251f22 39de4f2 4251f22 39de4f2 4251f22 39de4f2 4251f22 39de4f2 4251f22 39de4f2 ffd62d4 4251f22 39de4f2 4251f22 39de4f2 ffd62d4 4251f22 39de4f2 ffd62d4 39de4f2 ffd62d4 39de4f2 ffd62d4 39de4f2 4251f22 39de4f2 4251f22 39de4f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import torch
import numpy as np
from typing import Literal
from sentence_transformers.models import Module
class Quantizer(torch.nn.Module):
    """Base class for quantizers trained with a straight-through estimator.

    Subclasses provide ``_soft_quantize`` (the differentiable surrogate) and
    ``_hard_quantize`` (the discrete values). ``forward`` combines the two.
    """

    def __init__(self, hard: bool = True):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
        """
        super().__init__()
        self._hard = hard

    def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        """Return the discrete quantization of ``x``. Must be overridden."""
        raise NotImplementedError

    def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        """Return the differentiable surrogate of ``x``. Must be overridden."""
        raise NotImplementedError

    def forward(self, x, *args, **kwargs) -> torch.Tensor:
        """Quantize ``x``.

        Args:
            x: Input tensor; extra positional/keyword arguments are forwarded
                unchanged to ``_soft_quantize`` and ``_hard_quantize``.

        Returns:
            The soft quantization when ``hard`` is False. Otherwise a tensor
            whose value is the hard quantization and whose gradient is the
            gradient of the soft quantization.
        """
        soft = self._soft_quantize(x, *args, **kwargs)
        if not self._hard:
            return soft

        hard = self._hard_quantize(x, *args, **kwargs).detach()
        # Group the subtraction first: ``soft - soft.detach()`` is exactly
        # zero for finite values, so the forward value stays exactly ``hard``.
        # Evaluating ``(hard + soft) - soft`` instead rounds the intermediate
        # sum and can leave e.g. 126.999996 where 127 was intended.
        return hard + (soft - soft.detach())
class Int8TanhQuantizer(Quantizer):
    """Quantizer that maps ``tanh(x)`` onto the signed 8-bit integer range."""

    def __init__(self, hard: bool = True):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
        """
        super().__init__(hard=hard)
        # Bounds of the signed 8-bit range used when clamping.
        self.qmin = -128
        self.qmax = 127

    def _soft_quantize(self, x, *args, **kwargs):
        """Return ``tanh(x)``."""
        return torch.tanh(x)

    def _hard_quantize(self, x, *args, **kwargs):
        """Scale the soft value by ``qmax``, round, and clamp to the bounds."""
        scaled = self._soft_quantize(x) * self.qmax
        return scaled.round().clamp(self.qmin, self.qmax)
class BinaryTanhQuantizer(Quantizer):
    """Quantizer that binarizes values to +1 / -1 with a tanh surrogate."""

    def __init__(
        self,
        hard: bool = True,
        scale: float = 1.0,
    ):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
            scale: Multiplier applied to the input before ``tanh`` in the
                soft quantization. Defaults to 1.0.
        """
        super().__init__(hard)
        self._scale = scale

    def _soft_quantize(self, x, *args, **kwargs):
        """Return ``tanh(scale * x)``."""
        return torch.tanh(self._scale * x)

    def _hard_quantize(self, x, *args, **kwargs):
        """Return +1 where ``x >= 0`` and -1 elsewhere, in ``x``'s dtype."""
        # Build the two candidates from ``x`` so the result keeps its dtype
        # and device. Passing bare Python floats to ``torch.where`` produces a
        # default-dtype tensor, which upcasts half/bfloat16 embeddings.
        ones = torch.ones_like(x)
        return torch.where(x >= 0, ones, -ones)
class PackedBinaryQuantizer:
    """Binarize a tensor by sign and pack the bits into ``uint8`` values."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Pack the sign bits of ``x`` along its last axis.

        Args:
            x: Input tensor. Elements ``>= 0`` become bit 1, others bit 0.

        Returns:
            A ``uint8`` tensor on ``x``'s device, packed with
            ``np.packbits(..., axis=-1)``: eight bits per byte, with the last
            byte zero-padded when the axis length is not a multiple of 8.
        """
        # Compare in torch on a detached tensor before converting to NumPy.
        # ``Tensor.numpy()`` raises for tensors that require grad and for
        # dtypes NumPy lacks (e.g. bfloat16); a bool mask has neither problem.
        bits = (x.detach() >= 0).cpu().numpy()
        packed = np.packbits(bits, axis=-1)
        return torch.from_numpy(packed).to(x.device)
class FlexibleQuantizer(Module):
def __init__(self):
super().__init__()
self._int8_quantizer = Int8TanhQuantizer()
self._binary_quantizer = BinaryTanhQuantizer()
self._packed_binary_quantizer = PackedBinaryQuantizer()
def forward(
self,
features: dict[str, torch.Tensor],
quantization: Literal["int8", "binary", "ubinary"] = "int8",
**kwargs
) -> dict[str, torch.Tensor]:
if quantization == "int8":
features["sentence_embedding"] = self._int8_quantizer(
features["sentence_embedding"]
)
elif quantization == "binary":
features["sentence_embedding"] = self._binary_quantizer(
features["sentence_embedding"]
)
elif quantization == "ubinary":
features["sentence_embedding"] = self._packed_binary_quantizer(
features["sentence_embedding"]
)
else:
raise ValueError(
f"Invalid quantization type: {quantization}. Must be 'binary', 'ubinary', or 'int8'."
)
return features
@classmethod
def load(
cls,
model_name_or_path: str,
subfolder: str = "",
token: bool | str | None = None,
cache_folder: str | None = None,
revision: str | None = None,
local_files_only: bool = False,
**kwargs,
):
return cls()
def save(self, output_path: str, *args, **kwargs) -> None:
return
|