mkrimmel-pplx committed on
Commit
2958ec7
·
1 Parent(s): ab9dcdc

feat: add quantization

Browse files
Files changed (1) hide show
  1. st_quantize.py +14 -1
st_quantize.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from typing import Literal
3
  from sentence_transformers.models import Module
4
 
@@ -64,6 +65,13 @@ class BinaryTanhQuantizer(Quantizer):
64
 
65
  def _hard_quantize(self, x, *args, **kwargs):
66
  return torch.where(x >= 0, 1.0, -1.0)
 
 
 
 
 
 
 
67
 
68
 
69
  class FlexibleQuantizer(Module):
@@ -71,11 +79,12 @@ class FlexibleQuantizer(Module):
71
  super().__init__()
72
  self._int8_quantizer = Int8TanhQuantizer()
73
  self._binary_quantizer = BinaryTanhQuantizer()
 
74
 
75
  def forward(
76
  self,
77
  features: dict[str, torch.Tensor],
78
- quantization: Literal["binary", "int8"] = "int8",
79
  **kwargs
80
  ) -> dict[str, torch.Tensor]:
81
  if quantization == "int8":
@@ -86,6 +95,10 @@ class FlexibleQuantizer(Module):
86
  features["sentence_embedding"] = self._binary_quantizer(
87
  features["sentence_embedding"]
88
  )
 
 
 
 
89
  else:
90
  raise ValueError(
91
  f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
 
1
  import torch
2
+ import numpy as np
3
  from typing import Literal
4
  from sentence_transformers.models import Module
5
 
 
65
 
66
  def _hard_quantize(self, x, *args, **kwargs):
67
  return torch.where(x >= 0, 1.0, -1.0)
68
+
69
+
70
class PackedBinaryQuantizer:
    """Binary-quantize an embedding tensor and pack the sign bits into uint8 bytes.

    Each element maps to a single bit (1 where the value is >= 0, 0
    otherwise); bits along the last axis are packed 8-per-byte with
    ``np.packbits``, so the last dimension shrinks to ``ceil(dim / 8)``
    and the result dtype is ``torch.uint8`` (8x smaller than binary
    quantization stored as floats).
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # detach() first: Tensor.numpy() raises a RuntimeError on tensors
        # that require grad, and bit-packing is not differentiable anyway.
        # The comparison itself already yields the bool bit array —
        # no np.where(..., True, False) wrapper needed.
        bits = x.detach().cpu().numpy() >= 0
        # packbits uses big-endian bit order within each byte (MSB first).
        packed = np.packbits(bits, axis=-1)
        return torch.from_numpy(packed).to(x.device)
75
 
76
 
77
  class FlexibleQuantizer(Module):
 
79
  super().__init__()
80
  self._int8_quantizer = Int8TanhQuantizer()
81
  self._binary_quantizer = BinaryTanhQuantizer()
82
+ self._packed_binary_quantizer = PackedBinaryQuantizer()
83
 
84
  def forward(
85
  self,
86
  features: dict[str, torch.Tensor],
87
+ quantization: Literal["int8", "binary", "ubinary"] = "int8",
88
  **kwargs
89
  ) -> dict[str, torch.Tensor]:
90
  if quantization == "int8":
 
95
  features["sentence_embedding"] = self._binary_quantizer(
96
  features["sentence_embedding"]
97
  )
98
+ elif quantization == "ubinary":
99
+ features["sentence_embedding"] = self._packed_binary_quantizer(
100
+ features["sentence_embedding"]
101
+ )
102
  else:
103
  raise ValueError(
104
  f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."