bowang0911 commited on
Commit
30aec64
·
verified ·
1 Parent(s): c8bdc84

Update st_quantize.py

Browse files
Files changed (1) hide show
  1. st_quantize.py +28 -0
st_quantize.py CHANGED
@@ -90,4 +90,32 @@ class NormalizedInt8TanhQuantizer(Int8TanhQuantizer):
90
  @classmethod
91
  def load(cls, input_path: str) -> "PoolAndQuantize":
92
  return cls()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
 
90
  @classmethod
91
  def load(cls, input_path: str) -> "PoolAndQuantize":
92
  return cls()
93
+
94
+
95
+ class Binarizer(Quantizer):
96
+ def __init__(self, tanh_scale: float = 1.0, **kwargs):
97
+ super().__init__(**kwargs)
98
+ self._tanh_scale = tanh_scale
99
+
100
+ def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
101
+ return torch.where(x > 0, 1.0, -1.0)
102
+
103
+ def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
104
+ return torch.tanh(x * self._tanh_scale)
105
+
106
+
107
+ class UnnormalizedBinarizer(nn.Module):
108
+ def __init__(self, tanh_scale: float = 1.0, hard: bool = True):
109
+ super().__init__()
110
+ self.quantizer = Binarizer(tanh_scale=tanh_scale, hard=hard)
111
+
112
+ def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
113
+ features["sentence_embedding"] = self.quantizer(
114
+ features["sentence_embedding"]
115
+ )
116
+ return features
117
+
118
+ @classmethod
119
+ def load(cls, input_path: str) -> "UnnormalizedBinarizer":
120
+ return cls()
121