File size: 4,354 Bytes
1820cb3
169a600
1820cb3
e12840f
1820cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e12840f
1820cb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e12840f
1820cb3
 
e12840f
 
 
 
 
 
 
 
1820cb3
e12840f
 
1820cb3
e12840f
 
1820cb3
 
169a600
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e12840f
 
1820cb3
e12840f
 
169a600
1820cb3
e12840f
 
 
169a600
 
e12840f
 
 
 
 
 
 
 
 
169a600
 
 
 
e12840f
 
169a600
e12840f
1820cb3
 
 
e12840f
 
 
 
 
 
 
 
 
 
1820cb3
169a600
 
e12840f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import numpy as np
from typing import Literal
from sentence_transformers.models import Module


class Quantizer(torch.nn.Module):
    """Base class pairing a soft quantization with a hard one.

    Subclasses implement ``_soft_quantize`` and ``_hard_quantize``; both
    raise ``NotImplementedError`` in this base class.
    """

    def __init__(self, hard: bool = True):
        """
        Args:
            hard: When True, ``forward`` emits the hard-quantized values;
                when False it returns the soft quantization unchanged.
                Defaults to True.
        """
        super().__init__()
        self._hard = hard

    def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        """Return the hard quantization of ``x``. Must be overridden."""
        raise NotImplementedError

    def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
        """Return the soft quantization of ``x``. Must be overridden."""
        raise NotImplementedError

    def forward(self, x, *args, **kwargs) -> torch.Tensor:
        """Quantize ``x``, forwarding extra arguments to both quantizers.

        The soft quantization is always computed. In hard mode the result
        is built so its value comes from the hard quantization while the
        only gradient path runs through the soft quantization.
        """
        relaxed = self._soft_quantize(x, *args, **kwargs)
        if not self._hard:
            return relaxed

        # Straight-through combination: the hard term is detached, and the
        # soft term cancels in value but still carries the gradient.
        discrete = self._hard_quantize(x, *args, **kwargs).detach()
        return discrete + relaxed - relaxed.detach()


class Int8TanhQuantizer(Quantizer):
    """Quantizer whose soft form is ``tanh(x)`` and whose hard form is that
    value scaled by ``qmax``, rounded, and clamped to ``[qmin, qmax]``."""

    def __init__(
        self,
        hard: bool = True,
    ):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
        """
        super().__init__(hard=hard)
        # Bounds of the signed 8-bit integer range.
        self.qmin, self.qmax = -128, 127

    def _soft_quantize(self, x, *args, **kwargs):
        """Squash ``x`` with tanh; extra arguments are ignored."""
        return torch.tanh(x)

    def _hard_quantize(self, x, *args, **kwargs):
        """Scale the tanh output by ``qmax``, round, and clamp to the range."""
        scaled = self._soft_quantize(x) * self.qmax
        return torch.clamp(torch.round(scaled), self.qmin, self.qmax)


class BinaryTanhQuantizer(Quantizer):
    """Quantizer whose soft form is ``tanh(scale * x)`` and whose hard form
    maps every element to ``+1`` (``x >= 0``) or ``-1`` (otherwise)."""

    def __init__(
        self,
        hard: bool = True,
        scale: float = 1.0,
    ):
        """
        Args:
            hard: Whether to use hard or soft quantization. Defaults to True.
            scale: Multiplier applied to the input before the tanh in the
                soft quantization. Defaults to 1.0.
        """
        super().__init__(hard)
        self._scale = scale

    def _soft_quantize(self, x, *args, **kwargs):
        """Return ``tanh(scale * x)``; extra arguments are ignored."""
        return torch.tanh(self._scale * x)

    def _hard_quantize(self, x, *args, **kwargs):
        """Return ``+1`` where ``x >= 0`` and ``-1`` elsewhere.

        For floating-point inputs the result keeps the dtype of ``x``.
        """
        signs = torch.where(x >= 0, 1.0, -1.0)
        # Two Python scalars make torch.where emit the default dtype
        # (float32); cast back so half/bfloat16 inputs are not silently
        # promoted when combined with the soft output in forward().
        if torch.is_floating_point(x):
            signs = signs.to(dtype=x.dtype)
        return signs


class PackedBinaryQuantizer:
    """
    Packs binary embeddings into uint8 format for efficient storage.

    This quantizer applies a binary threshold (x >= 0) and packs 8 consecutive
    bits into a single uint8 byte using numpy.packbits (first element in the
    most significant bit). This reduces memory usage by 32x compared to
    float32 and by 8x compared to int8.

    IMPORTANT: This is an inference-only quantizer - it is not differentiable
    and should only be used for encoding/inference, not during training.

    Args:
        x: Input tensor of any float dtype, shape (..., embedding_dim)

    Returns:
        Packed binary tensor of dtype uint8, shape
        (..., ceil(embedding_dim / 8)). When embedding_dim is not a multiple
        of 8, the last byte is padded with zero bits.

    Example:
        >>> quantizer = PackedBinaryQuantizer()
        >>> embeddings = torch.randn(2, 1024)  # float32
        >>> packed = quantizer(embeddings)     # uint8, shape (2, 128)
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Threshold ``x`` at zero and pack the sign bits along the last axis.

        The result is placed on the same device as ``x``.
        """
        # Threshold in torch and convert only the detached boolean mask:
        # calling .numpy() on the float tensor itself fails for tensors that
        # require grad and for dtypes NumPy lacks (e.g. bfloat16).
        bits = (x.detach() >= 0).cpu().numpy()
        packed = np.packbits(bits, axis=-1)
        return torch.from_numpy(packed).to(x.device)


class FlexibleQuantizer(Module):
    """Module that quantizes ``features["sentence_embedding"]`` with one of
    three quantizers, chosen per call through ``quantization``."""

    def __init__(self):
        super().__init__()
        self._int8_quantizer = Int8TanhQuantizer()
        self._binary_quantizer = BinaryTanhQuantizer()
        self._packed_binary_quantizer = PackedBinaryQuantizer()

    def forward(
        self,
        features: dict[str, torch.Tensor],
        quantization: Literal["int8", "binary", "ubinary"] = "int8",
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Replace the sentence embedding with its quantized form.

        Args:
            features: Feature dict; its ``"sentence_embedding"`` entry is
                overwritten in place.
            quantization: ``"int8"``, ``"binary"`` or ``"ubinary"`` (packed
                bits). Defaults to ``"int8"``.
            **kwargs: Accepted and ignored.

        Returns:
            The same ``features`` dict, with the embedding quantized.

        Raises:
            ValueError: If ``quantization`` is not one of the three modes.
        """
        # Pick the quantizer first so an invalid mode fails before the
        # features dict is read or modified.
        if quantization == "int8":
            quantize = self._int8_quantizer
        elif quantization == "binary":
            quantize = self._binary_quantizer
        elif quantization == "ubinary":
            quantize = self._packed_binary_quantizer
        else:
            raise ValueError(
                f"Invalid quantization type: {quantization}. Must be 'binary', 'ubinary', or 'int8'."
            )

        features["sentence_embedding"] = quantize(features["sentence_embedding"])
        return features

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        """Return a fresh instance; every argument is accepted but unused."""
        return cls()

    def save(self, output_path: str, *args, **kwargs) -> None:
        """Do nothing; this method writes no files to ``output_path``."""
        return None