Add new modeling with contextual encoding

#2
Files changed (6) hide show
  1. config.json +5 -7
  2. configuration.py +5 -0
  3. configuration_qwen3.py +0 -206
  4. modeling.py +336 -0
  5. modules.json +3 -2
  6. st_quantize.py +50 -62
config.json CHANGED
@@ -1,13 +1,12 @@
1
  {
2
  "architectures": [
3
- "Qwen3Model"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
  "auto_map": {
8
- "AutoConfig": "configuration_qwen3.Qwen3Config",
9
- "AutoModel": "perplexity-ai/bidirectional-qwen3-implementation--modeling_qwen3.Qwen3Model",
10
- "AutoModelForMaskedLM": "modeling_qwen3.Qwen3ForMaskedLM"
11
  },
12
  "bos_token_id": 151643,
13
  "dtype": "float32",
@@ -49,8 +48,7 @@
49
  ],
50
  "max_position_embeddings": 32768,
51
  "max_window_layers": 28,
52
- "mlm_loss_variant": "elbo_normalize",
53
- "model_type": "qwen3",
54
  "num_attention_heads": 16,
55
  "num_hidden_layers": 28,
56
  "num_key_value_heads": 8,
@@ -65,6 +63,6 @@
65
  "transformers_version": "5.0.0.dev0",
66
  "use_cache": false,
67
  "use_sliding_window": false,
68
- "variant": "bidirectional",
69
  "vocab_size": 151936
70
  }
 
1
  {
2
  "architectures": [
3
+ "PPLXQwen3Model"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
  "auto_map": {
8
+ "AutoConfig": "configuration.PPLXQwen3Config",
9
+ "AutoModel": "modeling.PPLXQwen3ContextualModel"
 
10
  },
11
  "bos_token_id": 151643,
12
  "dtype": "float32",
 
48
  ],
49
  "max_position_embeddings": 32768,
50
  "max_window_layers": 28,
51
+ "model_type": "bidirectional_pplx_qwen3",
 
52
  "num_attention_heads": 16,
53
  "num_hidden_layers": 28,
54
  "num_key_value_heads": 8,
 
63
  "transformers_version": "5.0.0.dev0",
64
  "use_cache": false,
65
  "use_sliding_window": false,
66
+ "attn_implementation": "sdpa",
67
  "vocab_size": 151936
68
  }
configuration.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
2
+
3
+
4
class PPLXQwen3Config(Qwen3Config):
    """Configuration for the bidirectional PPLX Qwen3 encoder.

    Identical to ``Qwen3Config`` except for ``model_type``, which is renamed
    so that ``AutoConfig``/``AutoModel`` resolve to the custom classes listed
    in this repository's ``auto_map``.
    """

    # Must match the "model_type" value in config.json.
    model_type = "bidirectional_pplx_qwen3"
configuration_qwen3.py DELETED
@@ -1,206 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Qwen3 model configuration"""
16
-
17
- from typing import Optional, Literal
18
-
19
- import warnings
20
-
21
- from transformers.configuration_utils import PreTrainedConfig, layer_type_validation
22
- from transformers.modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
23
- from transformers.utils import logging
24
-
25
-
26
- logger = logging.get_logger(__name__)
27
-
28
-
29
- class Qwen3Config(PreTrainedConfig):
30
- r"""
31
- This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a
32
- Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
- with the defaults will yield a similar configuration to that of
34
- Qwen3-8B [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B).
35
-
36
- Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
37
- documentation from [`PreTrainedConfig`] for more information.
38
-
39
-
40
- Args:
41
- vocab_size (`int`, *optional*, defaults to 151936):
42
- Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the
43
- `inputs_ids` passed when calling [`Qwen3Model`]
44
- hidden_size (`int`, *optional*, defaults to 4096):
45
- Dimension of the hidden representations.
46
- intermediate_size (`int`, *optional*, defaults to 22016):
47
- Dimension of the MLP representations.
48
- num_hidden_layers (`int`, *optional*, defaults to 32):
49
- Number of hidden layers in the Transformer encoder.
50
- num_attention_heads (`int`, *optional*, defaults to 32):
51
- Number of attention heads for each attention layer in the Transformer encoder.
52
- num_key_value_heads (`int`, *optional*, defaults to 32):
53
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
54
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
55
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
56
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
57
- by meanpooling all the original heads within that group. For more details, check out [this
58
- paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
59
- head_dim (`int`, *optional*, defaults to 128):
60
- The attention head dimension.
61
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62
- The non-linear activation function (function or string) in the decoder.
63
- max_position_embeddings (`int`, *optional*, defaults to 32768):
64
- The maximum sequence length that this model might ever be used with.
65
- initializer_range (`float`, *optional*, defaults to 0.02):
66
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
68
- The epsilon used by the rms normalization layers.
69
- use_cache (`bool`, *optional*, defaults to `True`):
70
- Whether or not the model should return the last key/values attentions (not used by all models). Only
71
- relevant if `config.is_decoder=True`.
72
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
73
- Whether the model's input and output word embeddings should be tied.
74
- rope_parameters (`RopeParameters`, *optional*):
75
- Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain
76
- a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
77
- with longer `max_position_embeddings`.
78
- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
79
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
80
- use_sliding_window (`bool`, *optional*, defaults to `False`):
81
- Whether to use sliding window attention.
82
- sliding_window (`int`, *optional*, defaults to 4096):
83
- Sliding window attention (SWA) window size. If not specified, will default to `4096`.
84
- max_window_layers (`int`, *optional*, defaults to 28):
85
- The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
86
- additional layer afterwards will use SWA (Sliding Window Attention).
87
- layer_types (`list`, *optional*):
88
- Attention pattern for each layer.
89
- attention_dropout (`float`, *optional*, defaults to 0.0):
90
- The dropout ratio for the attention probabilities.
91
-
92
- ```python
93
- >>> from transformers import Qwen3Model, Qwen3Config
94
-
95
- >>> # Initializing a Qwen3 style configuration
96
- >>> configuration = Qwen3Config()
97
-
98
- >>> # Initializing a model from the Qwen3-8B style configuration
99
- >>> model = Qwen3Model(configuration)
100
-
101
- >>> # Accessing the model configuration
102
- >>> configuration = model.config
103
- ```"""
104
-
105
- model_type = "qwen3"
106
- keys_to_ignore_at_inference = ["past_key_values"]
107
-
108
- # Default tensor parallel plan for base model `Qwen3`
109
- base_model_tp_plan = {
110
- "layers.*.self_attn.q_proj": "colwise",
111
- "layers.*.self_attn.k_proj": "colwise",
112
- "layers.*.self_attn.v_proj": "colwise",
113
- "layers.*.self_attn.o_proj": "rowwise",
114
- "layers.*.mlp.gate_proj": "colwise",
115
- "layers.*.mlp.up_proj": "colwise",
116
- "layers.*.mlp.down_proj": "rowwise",
117
- }
118
- base_model_pp_plan = {
119
- "embed_tokens": (["input_ids"], ["inputs_embeds"]),
120
- "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
121
- "norm": (["hidden_states"], ["hidden_states"]),
122
- }
123
-
124
- def __init__(
125
- self,
126
- vocab_size: Optional[int] = 151936,
127
- hidden_size: Optional[int] = 4096,
128
- intermediate_size: Optional[int] = 22016,
129
- num_hidden_layers: Optional[int] = 32,
130
- num_attention_heads: Optional[int] = 32,
131
- num_key_value_heads: Optional[int] = 32,
132
- head_dim: Optional[int] = 128,
133
- hidden_act: Optional[str] = "silu",
134
- max_position_embeddings: Optional[int] = 32768,
135
- initializer_range: Optional[float] = 0.02,
136
- rms_norm_eps: Optional[int] = 1e-6,
137
- use_cache: Optional[bool] = True,
138
- tie_word_embeddings: Optional[bool] = False,
139
- rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
140
- attention_bias: Optional[bool] = False,
141
- use_sliding_window: Optional[bool] = False,
142
- sliding_window: Optional[int] = 4096,
143
- max_window_layers: Optional[int] = 28,
144
- layer_types: Optional[list[str]] = None,
145
- attention_dropout: Optional[float] = 0.0,
146
- variant: Literal["causal", "bidirectional", "causal_dropout"] = "causal",
147
- mlm_loss_variant: Literal["simple", "masked_normalize", "elbo_normalize", "flat_cart"] = "simple",
148
- **kwargs,
149
- ):
150
- self.vocab_size = vocab_size
151
- self.max_position_embeddings = max_position_embeddings
152
- self.hidden_size = hidden_size
153
- self.intermediate_size = intermediate_size
154
- self.num_hidden_layers = num_hidden_layers
155
- self.num_attention_heads = num_attention_heads
156
- self.use_sliding_window = use_sliding_window
157
- self.sliding_window = sliding_window if self.use_sliding_window else None
158
- self.max_window_layers = max_window_layers
159
-
160
- # for backward compatibility
161
- if num_key_value_heads is None:
162
- num_key_value_heads = num_attention_heads
163
-
164
- self.num_key_value_heads = num_key_value_heads
165
- self.head_dim = head_dim
166
- self.hidden_act = hidden_act
167
- self.initializer_range = initializer_range
168
- self.rms_norm_eps = rms_norm_eps
169
- self.use_cache = use_cache
170
- self.attention_bias = attention_bias
171
- self.attention_dropout = attention_dropout
172
- # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
173
- rope_scaling = kwargs.pop("rope_scaling", None)
174
- self.rope_parameters = rope_scaling or rope_parameters
175
-
176
- self.layer_types = layer_types
177
- if self.layer_types is None:
178
- self.layer_types = [
179
- "sliding_attention"
180
- if self.sliding_window is not None and i >= self.max_window_layers
181
- else "full_attention"
182
- for i in range(self.num_hidden_layers)
183
- ]
184
- layer_type_validation(self.layer_types, self.num_hidden_layers)
185
-
186
- # Validate the correctness of rotary position embeddings parameters
187
- rope_theta = kwargs.get("rope_theta", 10000.0)
188
- standardize_rope_params(self, rope_theta=rope_theta)
189
- rope_config_validation(self)
190
-
191
- self.variant = variant
192
- self.mlm_loss_variant = mlm_loss_variant
193
-
194
- if mlm_loss_variant not in ["simple", "masked_normalize", "elbo_normalize", "flat_cart"]:
195
- raise NotImplementedError(f"Loss variant {mlm_loss_variant} unknown")
196
-
197
- if variant != "causal" and use_cache:
198
- warnings.warn("Cannot use cache (use_cache) and bidirectional attention (is_causal=False)")
199
-
200
- super().__init__(
201
- tie_word_embeddings=tie_word_embeddings,
202
- **kwargs,
203
- )
204
-
205
-
206
- __all__ = ["Qwen3Config"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Literal
2
+ import numpy as np
3
+ import torch
4
+ from transformers import Qwen3Model
5
+ from transformers.cache_utils import Cache
6
+ from transformers.masking_utils import create_causal_mask
7
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
8
+ from transformers.processing_utils import Unpack
9
+ from transformers.utils import TransformersKwargs
10
+ from .configuration import PPLXQwen3Config
11
+ from transformers import AutoTokenizer
12
+ from .st_quantize import FlexibleQuantizer
13
+
14
+
15
+ # From modeling_t5gemma.py
16
+ def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
17
+ """
18
+ This creates bidirectional attention mask.
19
+ """
20
+
21
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
22
+ if attention_mask is None:
23
+ return torch.ones((), dtype=torch.bool)
24
+ return attention_mask[batch_idx, kv_idx].to(torch.bool)
25
+
26
+ return inner_mask
27
+
28
+
29
class PPLXQwen3Model(Qwen3Model):
    """Qwen3 encoder variant whose self-attention is fully bidirectional."""

    _supports_flash_attn = True
    _supports_sdpa = True

    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)
        self.post_init()

    def post_init(self):
        super().post_init()
        # Disable causal masking on every layer. This works together with
        # attn_implementation="flash_attention_2" or "sdpa".
        for decoder_layer in self.layers:
            decoder_layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        """Run the encoder with a bidirectional (non-causal) attention mask."""
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            input_ids = None

        # Dummy positions imitating a fresh, cache-free forward pass.
        dummy_cache_position = torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
        )
        # OR-ing the bidirectional callback into the causal mask lifts the
        # causality constraint while still hiding padding key positions.
        full_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=dummy_cache_position,
            past_key_values=None,
            position_ids=position_ids,
            or_mask_function=bidirectional_mask_function(attention_mask),
        )
        attention_mask = {"full_attention": full_mask}

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
87
+
88
+
89
class PPLXQwen3ContextualModel(PPLXQwen3Model):
    """
    Qwen3 model with contextual encoding support for late chunking.

    This model extends PPLXQwen3Model with an encode() method that performs
    contextual encoding of list[list[str]] documents via late chunking.

    IMPORTANT: This model MUST be loaded with trust_remote_code=True:

        from transformers import AutoModel

        model = AutoModel.from_pretrained(
            "path/to/model",
            trust_remote_code=True  # REQUIRED!
        )

        embeddings = model.encode([["chunk1", "chunk2"]])

    Loading without trust_remote_code=True will fail to load this custom model class.
    """

    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)

        if not isinstance(config, PPLXQwen3Config):
            raise TypeError(
                f"PPLXQwen3ContextualModel requires PPLXQwen3Config, got {type(config).__name__}. "
                f"Did you forget to load with trust_remote_code=True?"
            )

        # Tokenizer is loaded from the same checkpoint directory as the model.
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        self._flexible_quantizer = FlexibleQuantizer()

    @staticmethod
    def mean_pooling(
        token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Mean-pool token embeddings over valid (mask == 1) positions.

        Args:
            token_embeddings: (batch, seq_len, hidden_dim) token vectors.
            attention_mask: (batch, seq_len) 0/1 validity mask.

        Returns:
            (batch, hidden_dim) pooled embeddings. The denominator is clamped
            so fully-masked rows do not divide by zero.
        """
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    @torch.inference_mode()
    def encode(
        self,
        documents: list[list[str]],
        batch_size: int = 32,
        show_progress_bar: bool = False,
        device: str | torch.device | None = None,
        normalize_embeddings: bool = False,
        convert_to_numpy: bool = True,
        quantization: Literal["int8", "binary"] = "int8",
    ) -> list[np.ndarray] | list[torch.Tensor]:
        """
        Encode documents with late chunking (contextual embeddings).

        This model is designed specifically for contextual encoding and always
        expects documents as nested lists where each document is a list of
        text chunks.

        The encoding process:
        1. Concatenate chunks with separator tokens
        2. Run forward pass to get token embeddings
        3. Extract and pool individual chunk embeddings (late chunking)
        4. Apply quantization (int8 or binary, always enabled)
        5. Normalize embeddings if requested (applied after quantization)
        6. Convert to numpy or return as tensors

        Args:
            documents: List of documents, where each document is a list of
                text chunks, e.g. [["chunk1", "chunk2"], ["c1", "c2", "c3"]].
            batch_size: Number of documents encoded per forward pass.
            show_progress_bar: Show a tqdm progress bar if tqdm is installed.
            device: Device for computation (defaults to the model's device).
            normalize_embeddings: L2-normalize embeddings after quantization.
            convert_to_numpy: If True return list[np.ndarray], else tensors.
            quantization: "int8" (tanh int8, default) or "binary" (tanh sign).

        Returns:
            List of arrays/tensors preserving document structure; element i has
            shape (n_chunks_i, hidden_dim). Int8 output holds integer values in
            [-128, 127]; binary output holds -1.0 / 1.0.

        Raises:
            TypeError: If ``documents`` is not a list of lists.
            ValueError: If ``quantization`` is unsupported, or the tokenizer
                defines no separator token.
        """
        if not isinstance(documents, list) or not all(
            isinstance(doc, list) for doc in documents
        ):
            raise TypeError(
                "Input 'documents' must be a list of lists of strings for contextual encoding."
            )

        if quantization not in ["int8", "binary"]:
            raise ValueError(
                f"Unsupported quantization type: '{quantization}'. "
                f"Supported types are: 'int8', 'binary'. "
                f"Got: {type(quantization).__name__} = '{quantization}'"
            )

        # Fail fast with an actionable message instead of the opaque
        # "'NoneType' object has no attribute 'join'" raised below when the
        # tokenizer defines no separator token (Qwen tokenizers often don't).
        if self.tokenizer.sep_token is None or self.tokenizer.sep_token_id is None:
            raise ValueError(
                "Contextual encoding requires the tokenizer to define a sep_token. "
                "Set `tokenizer.sep_token` before calling encode()."
            )

        self.eval()

        if device is None:
            device = next(self.parameters()).device

        all_embeddings = []

        range_iter = range(0, len(documents), batch_size)
        if show_progress_bar:
            try:
                from tqdm import tqdm

                range_iter = tqdm(range_iter, desc="Encoding documents")
            except ImportError:
                # Progress bar is best-effort; encoding proceeds without it.
                pass

        for i in range_iter:
            batch_docs = documents[i : i + batch_size]

            # 1. Concatenate each document's chunks with the separator token.
            doc_strings = [
                self.tokenizer.sep_token.join(chunks) for chunks in batch_docs
            ]

            inputs = self.tokenizer(
                doc_strings,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # 2. One forward pass over the whole concatenated document, so
            #    every token is contextualized by the full document.
            outputs = self.forward(**inputs)
            token_embeddings = outputs.last_hidden_state

            # 3. Late chunking: split token embeddings back into per-chunk
            #    pooled vectors using the SEP token positions.
            batch_chunk_embeddings = self._extract_chunks_from_concatenated(
                input_ids=inputs["input_ids"],
                token_embeddings=token_embeddings,
                attention_mask=inputs["attention_mask"],
            )

            batch_chunk_embeddings = [
                torch.stack(doc_chunks, dim=0)
                for doc_chunks in batch_chunk_embeddings
            ]

            # 4. Quantization is always applied.
            batch_chunk_embeddings = [
                self._flexible_quantizer(
                    {"sentence_embedding": emb}, quantization=quantization
                )["sentence_embedding"]
                for emb in batch_chunk_embeddings
            ]

            # 5. Optional L2 normalization, deliberately after quantization.
            if normalize_embeddings:
                batch_chunk_embeddings = [
                    torch.nn.functional.normalize(emb, p=2, dim=-1)
                    for emb in batch_chunk_embeddings
                ]

            batch_chunk_embeddings = [emb.cpu() for emb in batch_chunk_embeddings]

            all_embeddings.extend(batch_chunk_embeddings)

        # 6. Convert to numpy if requested.
        if convert_to_numpy:
            all_embeddings = [emb.numpy() for emb in all_embeddings]

        return all_embeddings

    def _extract_chunks_from_concatenated(
        self,
        input_ids: torch.Tensor,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> list[list[torch.Tensor]]:
        """
        Extract individual chunk embeddings from a concatenated sequence.

        Splits sequences of the form "[chunk1][SEP][chunk2][SEP]..." back into
        per-chunk mean-pooled embeddings by locating SEP token positions.

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            token_embeddings: Token embeddings (batch_size, seq_len, hidden_dim)
            attention_mask: Attention mask (batch_size, seq_len)

        Returns:
            list[list[torch.Tensor]]: One inner list per document, containing
            one (hidden_dim,) embedding per chunk.

        Note:
            The sep_token_id comes from self.tokenizer.sep_token_id and varies
            by tokenizer. If the SEP text occurs naturally inside a chunk it
            will be treated as a boundary — a known limitation of this scheme.
        """
        sep_token_id = self.tokenizer.sep_token_id
        batch_size = input_ids.shape[0]

        all_doc_chunks = []

        for batch_idx in range(batch_size):
            # SEP tokens at padded positions must be ignored.
            valid_positions = attention_mask[batch_idx].bool()
            sep_positions = (
                (input_ids[batch_idx] == sep_token_id) & valid_positions
            ).nonzero(as_tuple=True)[0]

            chunk_embeddings = []
            start_pos = 0

            for sep_pos in sep_positions.tolist():
                chunk_tokens = token_embeddings[batch_idx, start_pos:sep_pos]
                chunk_mask = attention_mask[batch_idx, start_pos:sep_pos]

                chunk_emb = self.mean_pooling(
                    chunk_tokens.unsqueeze(0), chunk_mask.unsqueeze(0)
                ).squeeze(0)

                chunk_embeddings.append(chunk_emb)

                start_pos = sep_pos + 1

            # Handle the last chunk (after the last SEP token).
            last_valid_pos = attention_mask[batch_idx].sum().item()

            chunk_tokens = token_embeddings[batch_idx, start_pos:last_valid_pos]
            chunk_mask = attention_mask[batch_idx, start_pos:last_valid_pos]

            if chunk_mask.sum() > 0:
                chunk_emb = self.mean_pooling(
                    chunk_tokens.unsqueeze(0), chunk_mask.unsqueeze(0)
                ).squeeze(0)
            else:
                # Empty trailing chunk (e.g. document ended with SEP):
                # represent it with a zero embedding.
                chunk_emb = torch.zeros(
                    token_embeddings.shape[-1],
                    device=token_embeddings.device,
                    dtype=token_embeddings.dtype,
                )

            chunk_embeddings.append(chunk_emb)

            all_doc_chunks.append(chunk_embeddings)

        return all_doc_chunks
modules.json CHANGED
@@ -15,6 +15,7 @@
15
  "idx": 2,
16
  "name": "2",
17
  "path": "",
18
- "type": "st_quantize.UnnormalizedInt8TanhQuantizer"
 
19
  }
20
- ]
 
15
  "idx": 2,
16
  "name": "2",
17
  "path": "",
18
+ "type": "st_quantize.FlexibleQuantizer",
19
+ "kwargs": ["quantization"]
20
  }
21
+ ]
st_quantize.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
- from torch import nn
3
- from typing import Optional
4
  from typing import Literal
 
5
 
6
 
7
  class Quantizer(torch.nn.Module):
@@ -26,9 +25,7 @@ class Quantizer(torch.nn.Module):
26
  result = soft
27
  else:
28
  result = (
29
- self._hard_quantize(x, *args, **kwargs).detach()
30
- + soft
31
- - soft.detach()
32
  )
33
 
34
  return result
@@ -37,85 +34,76 @@ class Quantizer(torch.nn.Module):
37
  class Int8TanhQuantizer(Quantizer):
38
  def __init__(
39
  self,
40
- normalize: bool = False,
41
  hard: bool = True,
42
  ):
43
  super().__init__(hard=hard)
44
  self.qmin = -128
45
  self.qmax = 127
46
- self._normalize = normalize
47
 
48
  def _soft_quantize(self, x, *args, **kwargs):
49
- if self._normalize:
50
- x = (x - x.mean(dim=-1, keepdim=True)) / (
51
- x.std(dim=-1, keepdim=True) + 1e-8
52
- )
53
-
54
  return torch.tanh(x)
55
 
56
  def _hard_quantize(self, x, *args, **kwargs):
57
  soft = self._soft_quantize(x)
58
  int_x = torch.round(soft * self.qmax)
59
  int_x = torch.clamp(int_x, self.qmin, self.qmax)
60
- return int_x / self.qmax
61
-
62
-
63
- class UnnormalizedInt8TanhQuantizer(Int8TanhQuantizer):
64
- def __init__(self):
65
- super().__init__()
66
- self.quantizer = Int8TanhQuantizer(normalize=False)
67
-
68
- def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
69
- features["sentence_embedding"] = self.quantizer(
70
- features["sentence_embedding"]
71
- )
72
- return features
73
-
74
- @classmethod
75
- def load(cls, input_path: str) -> "PoolAndQuantize":
76
- return cls()
77
-
78
-
79
- class NormalizedInt8TanhQuantizer(Int8TanhQuantizer):
80
- def __init__(self):
81
- super().__init__()
82
- self.quantizer = Int8TanhQuantizer(normalize=True)
83
-
84
- def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
85
- features["sentence_embedding"] = self.quantizer(
86
- features["sentence_embedding"]
87
- )
88
- return features
89
-
90
- @classmethod
91
- def load(cls, input_path: str) -> "PoolAndQuantize":
92
- return cls()
93
 
94
 
95
- class Binarizer(Quantizer):
96
- def __init__(self, tanh_scale: float = 1.0, **kwargs):
97
- super().__init__(**kwargs)
98
- self._tanh_scale = tanh_scale
 
 
 
 
99
 
100
- def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
101
- return torch.where(x > 0, 1.0, -1.0)
102
 
103
- def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
104
- return torch.tanh(x * self._tanh_scale)
105
 
106
 
107
- class UnnormalizedBinarizer(nn.Module):
108
- def __init__(self, tanh_scale: float = 1.0, hard: bool = True):
109
  super().__init__()
110
- self.quantizer = Binarizer(tanh_scale=tanh_scale, hard=hard)
 
111
 
112
- def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
113
- features["sentence_embedding"] = self.quantizer(
114
- features["sentence_embedding"]
115
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  return features
117
 
118
  @classmethod
119
- def load(cls, input_path: str) -> "UnnormalizedBinarizer":
 
 
 
 
 
 
 
 
 
120
  return cls()
121
-
 
 
 
1
  import torch
 
 
2
  from typing import Literal
3
+ from sentence_transformers.models import Module
4
 
5
 
6
  class Quantizer(torch.nn.Module):
 
25
  result = soft
26
  else:
27
  result = (
28
+ self._hard_quantize(x, *args, **kwargs).detach() + soft - soft.detach()
 
 
29
  )
30
 
31
  return result
 
34
class Int8TanhQuantizer(Quantizer):
    """Quantize embeddings toward the signed 8-bit range via tanh squashing.

    The soft path maps values into (-1, 1) with tanh (differentiable); the
    hard path scales those values to [-128, 127] and snaps them to whole
    numbers.
    """

    def __init__(
        self,
        hard: bool = True,
    ):
        super().__init__(hard=hard)
        # Bounds of a signed 8-bit integer.
        self.qmin = -128
        self.qmax = 127

    def _soft_quantize(self, x, *args, **kwargs):
        # Differentiable surrogate: squash values into (-1, 1).
        return torch.tanh(x)

    def _hard_quantize(self, x, *args, **kwargs):
        # Scale the squashed values to int8 range, snap to integers, and
        # clamp defensively to the representable bounds.
        scaled = self._soft_quantize(x) * self.qmax
        return torch.round(scaled).clamp(self.qmin, self.qmax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
53
class BinaryTanhQuantizer(Quantizer):
    """Quantize embeddings to ±1.0 with a scaled tanh as the soft surrogate."""

    def __init__(
        self,
        hard: bool = True,
        scale: float = 1.0,
    ):
        super().__init__(hard)
        # Sharpness of the tanh surrogate; larger values approach sign().
        self._scale = scale

    def _soft_quantize(self, x, *args, **kwargs):
        # Differentiable approximation of the sign function.
        return torch.tanh(x * self._scale)

    def _hard_quantize(self, x, *args, **kwargs):
        # Non-negative inputs map to +1.0, negative inputs to -1.0.
        return torch.where(x >= 0, 1.0, -1.0)
67
 
68
 
69
class FlexibleQuantizer(Module):
    """Sentence-transformers module applying int8 or binary quantization.

    The quantization scheme is chosen at call time via the ``quantization``
    keyword, so one saved module supports both output formats. The module is
    stateless: ``load`` just constructs a fresh instance and ``save`` writes
    nothing.
    """

    def __init__(self):
        super().__init__()
        self._int8_quantizer = Int8TanhQuantizer()
        self._binary_quantizer = BinaryTanhQuantizer()

    def forward(
        self,
        features: dict[str, torch.Tensor],
        quantization: Literal["binary", "int8"] = "int8",
        **kwargs
    ) -> dict[str, torch.Tensor]:
        """Quantize ``features["sentence_embedding"]`` in place and return features."""
        if quantization == "int8":
            quantizer = self._int8_quantizer
        elif quantization == "binary":
            quantizer = self._binary_quantizer
        else:
            raise ValueError(
                f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
            )
        features["sentence_embedding"] = quantizer(features["sentence_embedding"])
        return features

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        # Stateless: no files to read, just build a new instance.
        return cls()

    def save(self, output_path: str, *args, **kwargs) -> None:
        # Stateless: nothing to persist.
        return