Update st_quantize.py
Browse files- st_quantize.py +28 -0
st_quantize.py
CHANGED
|
@@ -90,4 +90,32 @@ class NormalizedInt8TanhQuantizer(Int8TanhQuantizer):
|
|
| 90 |
@classmethod
def load(cls, input_path: str) -> "PoolAndQuantize":
    """Alternate constructor used by the module-loading machinery.

    NOTE(review): ``input_path`` is ignored — no configuration is read
    back from disk, so a loaded instance always has default settings.
    Also, the ``"PoolAndQuantize"`` return annotation does not match the
    enclosing class named in the hunk header
    (``NormalizedInt8TanhQuantizer``) — confirm both are intended.
    """
    return cls()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
|
|
|
| 90 |
@classmethod
def load(cls, input_path: str) -> "PoolAndQuantize":
    """Reconstruct the module; *input_path* holds no persisted state.

    The argument is accepted only for interface compatibility with the
    loading machinery and is not read.
    """
    instance = cls()
    return instance
+
|
| 94 |
+
|
| 95 |
+
class Binarizer(Quantizer):
    """Quantizer mapping embedding values to {-1.0, +1.0}.

    The hard path thresholds at zero; the soft path is a scaled ``tanh``
    relaxation that keeps gradients flowing during training.
    """

    def __init__(self, tanh_scale: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        # Slope of the tanh surrogate; larger values make the soft
        # quantization approximate the hard sign more closely.
        self._tanh_scale = tanh_scale

    def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        # Strictly positive entries become +1.0; everything else
        # (zeros included) becomes -1.0.
        positive = x > 0
        return torch.where(positive, 1.0, -1.0)

    def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        # Differentiable stand-in for the hard sign above.
        return torch.tanh(self._tanh_scale * x)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class UnnormalizedBinarizer(nn.Module):
    """Module that binarizes the pooled sentence embedding in a feature dict.

    Thin wrapper around :class:`Binarizer`; no re-normalization of the
    input is performed before quantization.
    """

    def __init__(self, tanh_scale: float = 1.0, hard: bool = True):
        super().__init__()
        self.quantizer = Binarizer(tanh_scale=tanh_scale, hard=hard)

    def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Quantize the embedding and hand the (mutated) dict to the next
        # module in the pipeline.
        embedding = features["sentence_embedding"]
        features["sentence_embedding"] = self.quantizer(embedding)
        return features

    @classmethod
    def load(cls, input_path: str) -> "UnnormalizedBinarizer":
        # NOTE(review): nothing is read from input_path, so a loaded module
        # always comes back with default tanh_scale/hard settings — confirm
        # this is intended (no save() is visible in this chunk).
        return cls()
|
| 121 |
|