Commit
·
2958ec7
1
Parent(s):
ab9dcdc
feat: add quantization
Browse files — st_quantize.py: +14 −1
st_quantize.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import torch
|
|
|
|
| 2 |
from typing import Literal
|
| 3 |
from sentence_transformers.models import Module
|
| 4 |
|
|
@@ -64,6 +65,13 @@ class BinaryTanhQuantizer(Quantizer):
|
|
| 64 |
|
| 65 |
def _hard_quantize(self, x, *args, **kwargs):
|
| 66 |
return torch.where(x >= 0, 1.0, -1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
class FlexibleQuantizer(Module):
|
|
@@ -71,11 +79,12 @@ class FlexibleQuantizer(Module):
|
|
| 71 |
super().__init__()
|
| 72 |
self._int8_quantizer = Int8TanhQuantizer()
|
| 73 |
self._binary_quantizer = BinaryTanhQuantizer()
|
|
|
|
| 74 |
|
| 75 |
def forward(
|
| 76 |
self,
|
| 77 |
features: dict[str, torch.Tensor],
|
| 78 |
-
quantization: Literal["binary", "
|
| 79 |
**kwargs
|
| 80 |
) -> dict[str, torch.Tensor]:
|
| 81 |
if quantization == "int8":
|
|
@@ -86,6 +95,10 @@ class FlexibleQuantizer(Module):
|
|
| 86 |
features["sentence_embedding"] = self._binary_quantizer(
|
| 87 |
features["sentence_embedding"]
|
| 88 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
else:
|
| 90 |
raise ValueError(
|
| 91 |
f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
|
|
|
|
| 1 |
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
from typing import Literal
|
| 4 |
from sentence_transformers.models import Module
|
| 5 |
|
|
|
|
| 65 |
|
| 66 |
def _hard_quantize(self, x, *args, **kwargs):
|
| 67 |
return torch.where(x >= 0, 1.0, -1.0)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class PackedBinaryQuantizer:
    """Quantize embeddings to packed unsigned-binary (8 dimensions per byte).

    Each dimension is thresholded at zero (non-negative -> bit 1, negative ->
    bit 0) and the bits are packed along the last axis with ``np.packbits``
    (big-endian within each byte), so an embedding of width ``d`` becomes
    ``ceil(d / 8)`` uint8 values.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return *x* binarized and bit-packed as a uint8 tensor on ``x.device``.

        NOTE: the round trip through NumPy is not differentiable; no
        gradients flow through this operation.
        """
        # .detach() is required: calling .numpy() on a tensor that requires
        # grad raises RuntimeError, so the original x.cpu().numpy() broke
        # whenever this ran inside a training graph.
        bits = x.detach().cpu().numpy() >= 0  # bool mask; np.where(..., True, False) was a redundant identity
        packed = np.packbits(bits, axis=-1)
        return torch.from_numpy(packed).to(x.device)
|
| 75 |
|
| 76 |
|
| 77 |
class FlexibleQuantizer(Module):
|
|
|
|
| 79 |
super().__init__()
|
| 80 |
self._int8_quantizer = Int8TanhQuantizer()
|
| 81 |
self._binary_quantizer = BinaryTanhQuantizer()
|
| 82 |
+
self._packed_binary_quantizer = PackedBinaryQuantizer()
|
| 83 |
|
| 84 |
def forward(
|
| 85 |
self,
|
| 86 |
features: dict[str, torch.Tensor],
|
| 87 |
+
quantization: Literal["int8", "binary", "ubinary"] = "int8",
|
| 88 |
**kwargs
|
| 89 |
) -> dict[str, torch.Tensor]:
|
| 90 |
if quantization == "int8":
|
|
|
|
| 95 |
features["sentence_embedding"] = self._binary_quantizer(
|
| 96 |
features["sentence_embedding"]
|
| 97 |
)
|
| 98 |
+
elif quantization == "ubinary":
|
| 99 |
+
features["sentence_embedding"] = self._packed_binary_quantizer(
|
| 100 |
+
features["sentence_embedding"]
|
| 101 |
+
)
|
| 102 |
else:
|
| 103 |
raise ValueError(
|
| 104 |
f"Invalid quantization type: {quantization}. Must be 'int8', 'binary', or 'ubinary'."
|