File size: 3,015 Bytes
4251f22
 
12fc1ef
4251f22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d01c68
4251f22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff0893c
4251f22
 
ff0893c
 
8d01c68
ff0893c
 
 
 
 
8d01c68
ff0893c
 
4251f22
ff0893c
 
4251f22
 
12fc1ef
ff0893c
4251f22
ff0893c
 
4251f22
8d01c68
 
 
 
12fc1ef
8d01c68
ff0893c
 
 
 
 
 
 
 
 
8d01c68
 
 
4251f22
8d01c68
4251f22
12fc1ef
 
 
 
 
 
 
 
 
 
4251f22
12fc1ef
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from typing import Literal
from sentence_transformers.models import Module


class Quantizer(torch.nn.Module):
    """Base class for straight-through-estimator (STE) quantizers.

    Subclasses implement ``_soft_quantize`` (a differentiable surrogate) and
    ``_hard_quantize`` (the discrete target). In hard mode, ``forward`` returns
    the hard values while routing gradients through the soft surrogate.
    """

    def __init__(self, hard: bool = True):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
        """
        super().__init__()
        self._hard = hard

    def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        """Return the discrete quantized values of ``x`` (implemented by subclasses)."""
        raise NotImplementedError

    def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        """Return a differentiable surrogate of the quantization (implemented by subclasses)."""
        raise NotImplementedError

    def forward(self, x, *args, **kwargs) -> torch.Tensor:
        """Quantize ``x``.

        Returns:
            The soft surrogate when ``hard`` is False. Otherwise a tensor whose
            values equal the hard quantization (up to floating-point rounding)
            and whose gradient is that of the soft surrogate.

        Raises:
            NotImplementedError: if a subclass does not implement the
                quantization hooks.
        """
        soft = self._soft_quantize(x, *args, **kwargs)
        if not self._hard:
            return soft
        hard = self._hard_quantize(x, *args, **kwargs).detach()
        # Forward value: (hard + soft) - soft == hard. Backward: only the
        # non-detached `soft` term contributes, giving the STE gradient.
        return hard + soft - soft.detach()


class Int8TanhQuantizer(Quantizer):
    """Quantize to signed 8-bit levels via ``round(tanh(x) * 127)``.

    The soft surrogate is ``tanh(x)`` in [-1, 1]. The hard output lies on the
    integer grid [-127, 127] but is kept in ``x``'s floating dtype.

    NOTE: soft and hard live on different scales, so the STE gradient is the
    gradient of ``tanh(x)``, not of ``127 * tanh(x)``.
    """

    def __init__(self, hard: bool = True):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
        """
        super().__init__(hard=hard)
        self.qmin = -128
        self.qmax = 127

    def _soft_quantize(self, x, *args, **kwargs):
        return torch.tanh(x)

    def _hard_quantize(self, x, *args, **kwargs):
        # Forward the extra arguments so subclasses overriding the soft hook
        # see the same inputs as in Quantizer.forward.
        soft = self._soft_quantize(x, *args, **kwargs)
        int_x = torch.round(soft * self.qmax)
        # tanh bounds `soft` to [-1, 1], so values stay within [-127, 127];
        # the clamp is a safety net and qmin is not reached with tanh.
        return torch.clamp(int_x, self.qmin, self.qmax)


class BinaryTanhQuantizer(Quantizer):
    """Quantize to {-1, +1} by sign, using ``tanh(scale * x)`` as the soft surrogate."""

    def __init__(
        self,
        hard: bool = True,
        scale: float = 1.0,
    ):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
            scale: Sharpness of the tanh surrogate. The hard values use the
                sign of ``x`` itself, so soft and hard agree in sign only for
                ``scale > 0``.
        """
        super().__init__(hard=hard)
        self._scale = scale

    def _soft_quantize(self, x, *args, **kwargs):
        return torch.tanh(self._scale * x)

    def _hard_quantize(self, x, *args, **kwargs):
        # Build the +/-1 values from `x` so the result keeps x's dtype and
        # device. `torch.where(cond, 1.0, -1.0)` with two Python scalars
        # returns the default dtype and would upcast half-precision outputs
        # in forward(). Zeros map to +1; NaN maps to -1 (NaN >= 0 is False).
        one = torch.ones_like(x)
        return torch.where(x >= 0, one, -one)


class FlexibleQuantizer(Module):
    def __init__(self):
        super().__init__()
        self._int8_quantizer = Int8TanhQuantizer()
        self._binary_quantizer = BinaryTanhQuantizer()

    def forward(
        self,
        features: dict[str, torch.Tensor],
        quantization: Literal["binary", "int8"] = "int8",
        **kwargs
    ) -> dict[str, torch.Tensor]:
        if quantization == "int8":
            features["sentence_embedding"] = self._int8_quantizer(
                features["sentence_embedding"]
            )
        elif quantization == "binary":
            features["sentence_embedding"] = self._binary_quantizer(
                features["sentence_embedding"]
            )
        else:
            raise ValueError(
                f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
            )
        return features

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        return cls()
        
    def save(self, output_path: str, *args, **kwargs) -> None: 
        return