mkrimmel-pplx committed on
Commit
39de4f2
·
verified ·
1 Parent(s): 727dfda

feat: new model implementation (#1)

Browse files

- fix: added new implementation (ff0893c1cd4f25a76e5392e7995d9219a9aed482)
- feat: updated context model (8d01c688fc64bae72a41575570dd514b7454a033)
- refactor: new modeling code (12fc1ef7c6890644f5fc6a691fc24bd001442d95)

Files changed (6) hide show
  1. config.json +6 -8
  2. configuration.py +5 -0
  3. configuration_qwen3.py +0 -206
  4. modeling.py +83 -0
  5. modules.json +2 -1
  6. st_quantize.py +50 -62
config.json CHANGED
@@ -1,13 +1,12 @@
1
  {
2
  "architectures": [
3
- "Qwen3Model"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
  "auto_map": {
8
- "AutoConfig": "configuration_qwen3.Qwen3Config",
9
- "AutoModel": "perplexity-ai/bidirectional-qwen3-implementation--modeling_qwen3.Qwen3Model",
10
- "AutoModelForMaskedLM": "modeling_qwen3.Qwen3ForMaskedLM"
11
  },
12
  "bos_token_id": 151643,
13
  "dtype": "float32",
@@ -57,8 +56,7 @@
57
  ],
58
  "max_position_embeddings": 32768,
59
  "max_window_layers": 36,
60
- "mlm_loss_variant": "elbo_normalize",
61
- "model_type": "qwen3",
62
  "num_attention_heads": 32,
63
  "num_hidden_layers": 36,
64
  "num_key_value_heads": 8,
@@ -73,6 +71,6 @@
73
  "transformers_version": "5.0.0.dev0",
74
  "use_cache": false,
75
  "use_sliding_window": false,
76
- "variant": "bidirectional",
77
- "vocab_size": 151936
78
  }
 
1
  {
2
  "architectures": [
3
+ "PPLXQwen3Model"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
  "auto_map": {
8
+ "AutoConfig": "configuration.PPLXQwen3Config",
9
+ "AutoModel": "modeling.PPLXQwen3Model"
 
10
  },
11
  "bos_token_id": 151643,
12
  "dtype": "float32",
 
56
  ],
57
  "max_position_embeddings": 32768,
58
  "max_window_layers": 36,
59
+ "model_type": "bidirectional_pplx_qwen3",
 
60
  "num_attention_heads": 32,
61
  "num_hidden_layers": 36,
62
  "num_key_value_heads": 8,
 
71
  "transformers_version": "5.0.0.dev0",
72
  "use_cache": false,
73
  "use_sliding_window": false,
74
+ "vocab_size": 151936,
75
+ "attn_implementation": "sdpa"
76
  }
configuration.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
"""Configuration for the bidirectional PPLX Qwen3 model.

Subclasses the stock Qwen3 configuration, overriding only ``model_type``
so that ``AutoConfig``/``AutoModel`` dispatch (see ``auto_map`` in
config.json) resolves to this custom bidirectional implementation.
"""

from transformers.models.qwen3.configuration_qwen3 import Qwen3Config


class PPLXQwen3Config(Qwen3Config):
    # Distinct model_type keeps this config from colliding with the
    # registered "qwen3" type in transformers' model registry.
    model_type = "bidirectional_pplx_qwen3"
configuration_qwen3.py DELETED
@@ -1,206 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Qwen3 model configuration"""
16
-
17
- from typing import Optional, Literal
18
-
19
- import warnings
20
-
21
- from transformers.configuration_utils import PreTrainedConfig, layer_type_validation
22
- from transformers.modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
23
- from transformers.utils import logging
24
-
25
-
26
- logger = logging.get_logger(__name__)
27
-
28
-
29
- class Qwen3Config(PreTrainedConfig):
30
- r"""
31
- This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a
32
- Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
- with the defaults will yield a similar configuration to that of
34
- Qwen3-8B [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B).
35
-
36
- Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
37
- documentation from [`PreTrainedConfig`] for more information.
38
-
39
-
40
- Args:
41
- vocab_size (`int`, *optional*, defaults to 151936):
42
- Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the
43
- `inputs_ids` passed when calling [`Qwen3Model`]
44
- hidden_size (`int`, *optional*, defaults to 4096):
45
- Dimension of the hidden representations.
46
- intermediate_size (`int`, *optional*, defaults to 22016):
47
- Dimension of the MLP representations.
48
- num_hidden_layers (`int`, *optional*, defaults to 32):
49
- Number of hidden layers in the Transformer encoder.
50
- num_attention_heads (`int`, *optional*, defaults to 32):
51
- Number of attention heads for each attention layer in the Transformer encoder.
52
- num_key_value_heads (`int`, *optional*, defaults to 32):
53
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
54
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
55
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
56
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
57
- by meanpooling all the original heads within that group. For more details, check out [this
58
- paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
59
- head_dim (`int`, *optional*, defaults to 128):
60
- The attention head dimension.
61
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62
- The non-linear activation function (function or string) in the decoder.
63
- max_position_embeddings (`int`, *optional*, defaults to 32768):
64
- The maximum sequence length that this model might ever be used with.
65
- initializer_range (`float`, *optional*, defaults to 0.02):
66
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
68
- The epsilon used by the rms normalization layers.
69
- use_cache (`bool`, *optional*, defaults to `True`):
70
- Whether or not the model should return the last key/values attentions (not used by all models). Only
71
- relevant if `config.is_decoder=True`.
72
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
73
- Whether the model's input and output word embeddings should be tied.
74
- rope_parameters (`RopeParameters`, *optional*):
75
- Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain
76
- a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
77
- with longer `max_position_embeddings`.
78
- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
79
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
80
- use_sliding_window (`bool`, *optional*, defaults to `False`):
81
- Whether to use sliding window attention.
82
- sliding_window (`int`, *optional*, defaults to 4096):
83
- Sliding window attention (SWA) window size. If not specified, will default to `4096`.
84
- max_window_layers (`int`, *optional*, defaults to 28):
85
- The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
86
- additional layer afterwards will use SWA (Sliding Window Attention).
87
- layer_types (`list`, *optional*):
88
- Attention pattern for each layer.
89
- attention_dropout (`float`, *optional*, defaults to 0.0):
90
- The dropout ratio for the attention probabilities.
91
-
92
- ```python
93
- >>> from transformers import Qwen3Model, Qwen3Config
94
-
95
- >>> # Initializing a Qwen3 style configuration
96
- >>> configuration = Qwen3Config()
97
-
98
- >>> # Initializing a model from the Qwen3-8B style configuration
99
- >>> model = Qwen3Model(configuration)
100
-
101
- >>> # Accessing the model configuration
102
- >>> configuration = model.config
103
- ```"""
104
-
105
- model_type = "qwen3"
106
- keys_to_ignore_at_inference = ["past_key_values"]
107
-
108
- # Default tensor parallel plan for base model `Qwen3`
109
- base_model_tp_plan = {
110
- "layers.*.self_attn.q_proj": "colwise",
111
- "layers.*.self_attn.k_proj": "colwise",
112
- "layers.*.self_attn.v_proj": "colwise",
113
- "layers.*.self_attn.o_proj": "rowwise",
114
- "layers.*.mlp.gate_proj": "colwise",
115
- "layers.*.mlp.up_proj": "colwise",
116
- "layers.*.mlp.down_proj": "rowwise",
117
- }
118
- base_model_pp_plan = {
119
- "embed_tokens": (["input_ids"], ["inputs_embeds"]),
120
- "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
121
- "norm": (["hidden_states"], ["hidden_states"]),
122
- }
123
-
124
- def __init__(
125
- self,
126
- vocab_size: Optional[int] = 151936,
127
- hidden_size: Optional[int] = 4096,
128
- intermediate_size: Optional[int] = 22016,
129
- num_hidden_layers: Optional[int] = 32,
130
- num_attention_heads: Optional[int] = 32,
131
- num_key_value_heads: Optional[int] = 32,
132
- head_dim: Optional[int] = 128,
133
- hidden_act: Optional[str] = "silu",
134
- max_position_embeddings: Optional[int] = 32768,
135
- initializer_range: Optional[float] = 0.02,
136
- rms_norm_eps: Optional[int] = 1e-6,
137
- use_cache: Optional[bool] = True,
138
- tie_word_embeddings: Optional[bool] = False,
139
- rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
140
- attention_bias: Optional[bool] = False,
141
- use_sliding_window: Optional[bool] = False,
142
- sliding_window: Optional[int] = 4096,
143
- max_window_layers: Optional[int] = 28,
144
- layer_types: Optional[list[str]] = None,
145
- attention_dropout: Optional[float] = 0.0,
146
- variant: Literal["causal", "bidirectional", "causal_dropout"] = "causal",
147
- mlm_loss_variant: Literal["simple", "masked_normalize", "elbo_normalize", "flat_cart"] = "simple",
148
- **kwargs,
149
- ):
150
- self.vocab_size = vocab_size
151
- self.max_position_embeddings = max_position_embeddings
152
- self.hidden_size = hidden_size
153
- self.intermediate_size = intermediate_size
154
- self.num_hidden_layers = num_hidden_layers
155
- self.num_attention_heads = num_attention_heads
156
- self.use_sliding_window = use_sliding_window
157
- self.sliding_window = sliding_window if self.use_sliding_window else None
158
- self.max_window_layers = max_window_layers
159
-
160
- # for backward compatibility
161
- if num_key_value_heads is None:
162
- num_key_value_heads = num_attention_heads
163
-
164
- self.num_key_value_heads = num_key_value_heads
165
- self.head_dim = head_dim
166
- self.hidden_act = hidden_act
167
- self.initializer_range = initializer_range
168
- self.rms_norm_eps = rms_norm_eps
169
- self.use_cache = use_cache
170
- self.attention_bias = attention_bias
171
- self.attention_dropout = attention_dropout
172
- # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
173
- rope_scaling = kwargs.pop("rope_scaling", None)
174
- self.rope_parameters = rope_scaling or rope_parameters
175
-
176
- self.layer_types = layer_types
177
- if self.layer_types is None:
178
- self.layer_types = [
179
- "sliding_attention"
180
- if self.sliding_window is not None and i >= self.max_window_layers
181
- else "full_attention"
182
- for i in range(self.num_hidden_layers)
183
- ]
184
- layer_type_validation(self.layer_types, self.num_hidden_layers)
185
-
186
- # Validate the correctness of rotary position embeddings parameters
187
- rope_theta = kwargs.get("rope_theta", 10000.0)
188
- standardize_rope_params(self, rope_theta=rope_theta)
189
- rope_config_validation(self)
190
-
191
- self.variant = variant
192
- self.mlm_loss_variant = mlm_loss_variant
193
-
194
- if mlm_loss_variant not in ["simple", "masked_normalize", "elbo_normalize", "flat_cart"]:
195
- raise NotImplementedError(f"Loss variant {mlm_loss_variant} unknown")
196
-
197
- if variant != "causal" and use_cache:
198
- warnings.warn("Cannot use cache (use_cache) and bidirectional attention (is_causal=False)")
199
-
200
- super().__init__(
201
- tie_word_embeddings=tie_word_embeddings,
202
- **kwargs,
203
- )
204
-
205
-
206
- __all__ = ["Qwen3Config"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+ import torch
3
+ from transformers import Qwen3Model
4
+ from transformers.cache_utils import Cache
5
+ from transformers.masking_utils import create_causal_mask
6
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
7
+ from transformers.processing_utils import Unpack
8
+ from transformers.utils import TransformersKwargs
9
+ from .configuration import PPLXQwen3Config
10
+
11
+
12
+ # From modeling_t5gemma.py
13
+ def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
14
+ """
15
+ This creates bidirectional attention mask.
16
+ """
17
+
18
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
19
+ if attention_mask is None:
20
+ return torch.ones((), dtype=torch.bool)
21
+ return attention_mask[batch_idx, kv_idx].to(torch.bool)
22
+
23
+ return inner_mask
24
+
25
+
26
class PPLXQwen3Model(Qwen3Model):
    """Qwen3 variant with fully bidirectional self-attention.

    Reuses the stock Qwen3 stack but disables causal masking on every
    layer, turning the decoder into a text encoder (e.g. for embeddings).
    """

    _supports_flash_attn = True
    _supports_sdpa = True

    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)
        self.post_init()

    def post_init(self):
        super().post_init()
        # Override to set all layers to non-causal attention. This'll work
        # with attn_implementation="flash_attention_2" or "sdpa".
        for layer in self.layers:
            layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        """Encode the input with bidirectional attention.

        Mirrors ``Qwen3Model.forward`` but replaces the causal mask with a
        full (bidirectional) mask that still hides padded key positions
        given by ``attention_mask``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                # Fail fast with a clear message instead of the opaque
                # TypeError raised by embed_tokens(None).
                raise ValueError("You must specify either input_ids or inputs_embeds")
            inputs_embeds = self.embed_tokens(input_ids)
            input_ids = None

        # We construct a dummy tensor imitating initial (cache-free)
        # positions; mask construction deliberately ignores any KV cache.
        dummy_cache_position = torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
        )
        # Mask dict is keyed by layer type. The shipped config disables
        # sliding windows, so only "full_attention" is provided.
        # NOTE(review): a config with "sliding_attention" layer types would
        # KeyError here — confirm if sliding-window support is ever needed.
        attention_mask = {
            "full_attention": create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=dummy_cache_position,
                past_key_values=None,
                position_ids=position_ids,
                # OR-ing in the bidirectional mask lifts the causal
                # restriction while keeping padded keys hidden.
                or_mask_function=bidirectional_mask_function(attention_mask),
            )
        }

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
modules.json CHANGED
@@ -15,6 +15,7 @@
15
  "idx": 2,
16
  "name": "2",
17
  "path": "",
18
- "type": "st_quantize.UnnormalizedInt8TanhQuantizer"
 
19
  }
20
  ]
 
15
  "idx": 2,
16
  "name": "2",
17
  "path": "",
18
+ "type": "st_quantize.FlexibleQuantizer",
19
+ "kwargs": ["quantization"]
20
  }
21
  ]
st_quantize.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
- from torch import nn
3
- from typing import Optional
4
  from typing import Literal
 
5
 
6
 
7
  class Quantizer(torch.nn.Module):
@@ -26,9 +25,7 @@ class Quantizer(torch.nn.Module):
26
  result = soft
27
  else:
28
  result = (
29
- self._hard_quantize(x, *args, **kwargs).detach()
30
- + soft
31
- - soft.detach()
32
  )
33
 
34
  return result
@@ -37,85 +34,76 @@ class Quantizer(torch.nn.Module):
37
  class Int8TanhQuantizer(Quantizer):
38
  def __init__(
39
  self,
40
- normalize: bool = False,
41
  hard: bool = True,
42
  ):
43
  super().__init__(hard=hard)
44
  self.qmin = -128
45
  self.qmax = 127
46
- self._normalize = normalize
47
 
48
  def _soft_quantize(self, x, *args, **kwargs):
49
- if self._normalize:
50
- x = (x - x.mean(dim=-1, keepdim=True)) / (
51
- x.std(dim=-1, keepdim=True) + 1e-8
52
- )
53
-
54
  return torch.tanh(x)
55
 
56
  def _hard_quantize(self, x, *args, **kwargs):
57
  soft = self._soft_quantize(x)
58
  int_x = torch.round(soft * self.qmax)
59
  int_x = torch.clamp(int_x, self.qmin, self.qmax)
60
- return int_x / self.qmax
61
-
62
-
63
- class UnnormalizedInt8TanhQuantizer(Int8TanhQuantizer):
64
- def __init__(self):
65
- super().__init__()
66
- self.quantizer = Int8TanhQuantizer(normalize=False)
67
-
68
- def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
69
- features["sentence_embedding"] = self.quantizer(
70
- features["sentence_embedding"]
71
- )
72
- return features
73
-
74
- @classmethod
75
- def load(cls, input_path: str) -> "PoolAndQuantize":
76
- return cls()
77
-
78
-
79
- class NormalizedInt8TanhQuantizer(Int8TanhQuantizer):
80
- def __init__(self):
81
- super().__init__()
82
- self.quantizer = Int8TanhQuantizer(normalize=True)
83
-
84
- def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
85
- features["sentence_embedding"] = self.quantizer(
86
- features["sentence_embedding"]
87
- )
88
- return features
89
-
90
- @classmethod
91
- def load(cls, input_path: str) -> "PoolAndQuantize":
92
- return cls()
93
 
94
 
95
- class Binarizer(Quantizer):
96
- def __init__(self, tanh_scale: float = 1.0, **kwargs):
97
- super().__init__(**kwargs)
98
- self._tanh_scale = tanh_scale
 
 
 
 
99
 
100
- def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
101
- return torch.where(x > 0, 1.0, -1.0)
102
 
103
- def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
104
- return torch.tanh(x * self._tanh_scale)
105
 
106
 
107
- class UnnormalizedBinarizer(nn.Module):
108
- def __init__(self, tanh_scale: float = 1.0, hard: bool = True):
109
  super().__init__()
110
- self.quantizer = Binarizer(tanh_scale=tanh_scale, hard=hard)
 
111
 
112
- def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
113
- features["sentence_embedding"] = self.quantizer(
114
- features["sentence_embedding"]
115
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  return features
117
 
118
  @classmethod
119
- def load(cls, input_path: str) -> "UnnormalizedBinarizer":
 
 
 
 
 
 
 
 
 
120
  return cls()
121
-
 
 
 
1
  import torch
 
 
2
  from typing import Literal
3
+ from sentence_transformers.models import Module
4
 
5
 
6
  class Quantizer(torch.nn.Module):
 
25
  result = soft
26
  else:
27
  result = (
28
+ self._hard_quantize(x, *args, **kwargs).detach() + soft - soft.detach()
 
 
29
  )
30
 
31
  return result
 
34
class Int8TanhQuantizer(Quantizer):
    """Squash activations with tanh and quantize them to signed 8-bit.

    Soft path: plain tanh into [-1, 1]. Hard path: the tanh output scaled
    to the int8 range, rounded and clamped to integer values in [-128, 127].
    """

    def __init__(self, hard: bool = True):
        super().__init__(hard=hard)
        # Signed 8-bit integer range.
        self.qmin = -128
        self.qmax = 127

    def _soft_quantize(self, x, *args, **kwargs):
        # Differentiable surrogate used on the gradient path.
        return torch.tanh(x)

    def _hard_quantize(self, x, *args, **kwargs):
        # Scale tanh(x) into the int8 range, round, and clamp. The result
        # is left in integer units (not rescaled back to [-1, 1]).
        scaled = torch.round(self._soft_quantize(x) * self.qmax)
        return scaled.clamp(self.qmin, self.qmax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
class BinaryTanhQuantizer(Quantizer):
    """Binarize activations to ±1 via a scaled tanh surrogate.

    Soft path: ``tanh(scale * x)``, a smooth approximation of sign(x)
    whose sharpness grows with ``scale``. Hard path: exact sign, mapping
    x >= 0 to +1.0 and x < 0 to -1.0.
    """

    def __init__(self, hard: bool = True, scale: float = 1.0):
        super().__init__(hard)
        # Steepness of the tanh surrogate; larger values approach sign().
        self._scale = scale

    def _soft_quantize(self, x, *args, **kwargs):
        return torch.tanh(x * self._scale)

    def _hard_quantize(self, x, *args, **kwargs):
        # Exact sign; zero is mapped to +1.
        return torch.where(x >= 0, 1.0, -1.0)
67
 
68
 
69
class FlexibleQuantizer(Module):
    """Sentence-transformers module applying int8 or binary quantization.

    The scheme is chosen per forward call via the ``quantization`` kwarg
    (wired through ``modules.json``), so one saved model can serve both
    embedding formats. The module itself holds no learned state.
    """

    def __init__(self):
        super().__init__()
        self._int8_quantizer = Int8TanhQuantizer()
        self._binary_quantizer = BinaryTanhQuantizer()

    def forward(
        self,
        features: dict[str, torch.Tensor],
        quantization: Literal["binary", "int8"] = "int8",
        **kwargs
    ) -> dict[str, torch.Tensor]:
        """Quantize ``features["sentence_embedding"]`` in place and return features."""
        if quantization == "int8":
            quantizer = self._int8_quantizer
        elif quantization == "binary":
            quantizer = self._binary_quantizer
        else:
            raise ValueError(
                f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
            )
        features["sentence_embedding"] = quantizer(features["sentence_embedding"])
        return features

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        # Stateless module: all load arguments are accepted for interface
        # compatibility but nothing is read from disk.
        return cls()

    def save(self, output_path: str, *args, **kwargs) -> None:
        # Stateless module: nothing to persist.
        return