Add new model implementation and rest

#1
Files changed (7) hide show
  1. chat_template.jinja +0 -85
  2. config.json +6 -8
  3. configuration.py +5 -0
  4. configuration_qwen3.py +0 -206
  5. modeling.py +335 -0
  6. modules.json +3 -2
  7. st_quantize.py +50 -62
chat_template.jinja DELETED
@@ -1,85 +0,0 @@
1
- {%- if tools %}
2
- {{- '<|im_start|>system\n' }}
3
- {%- if messages[0].role == 'system' %}
4
- {{- messages[0].content + '\n\n' }}
5
- {%- endif %}
6
- {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
- {%- for tool in tools %}
8
- {{- "\n" }}
9
- {{- tool | tojson }}
10
- {%- endfor %}
11
- {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
- {%- else %}
13
- {%- if messages[0].role == 'system' %}
14
- {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
- {%- endif %}
16
- {%- endif %}
17
- {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
- {%- for message in messages[::-1] %}
19
- {%- set index = (messages|length - 1) - loop.index0 %}
20
- {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
- {%- set ns.multi_step_tool = false %}
22
- {%- set ns.last_query_index = index %}
23
- {%- endif %}
24
- {%- endfor %}
25
- {%- for message in messages %}
26
- {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
27
- {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
28
- {%- elif message.role == "assistant" %}
29
- {%- set content = message.content %}
30
- {%- set reasoning_content = '' %}
31
- {%- if message.reasoning_content is defined and message.reasoning_content is not none %}
32
- {%- set reasoning_content = message.reasoning_content %}
33
- {%- else %}
34
- {%- if '</think>' in message.content %}
35
- {%- set content = message.content.split('</think>')[-1].lstrip('\n') %}
36
- {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
37
- {%- endif %}
38
- {%- endif %}
39
- {%- if loop.index0 > ns.last_query_index %}
40
- {%- if loop.last or (not loop.last and reasoning_content) %}
41
- {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
42
- {%- else %}
43
- {{- '<|im_start|>' + message.role + '\n' + content }}
44
- {%- endif %}
45
- {%- else %}
46
- {{- '<|im_start|>' + message.role + '\n' + content }}
47
- {%- endif %}
48
- {%- if message.tool_calls %}
49
- {%- for tool_call in message.tool_calls %}
50
- {%- if (loop.first and content) or (not loop.first) %}
51
- {{- '\n' }}
52
- {%- endif %}
53
- {%- if tool_call.function %}
54
- {%- set tool_call = tool_call.function %}
55
- {%- endif %}
56
- {{- '<tool_call>\n{"name": "' }}
57
- {{- tool_call.name }}
58
- {{- '", "arguments": ' }}
59
- {%- if tool_call.arguments is string %}
60
- {{- tool_call.arguments }}
61
- {%- else %}
62
- {{- tool_call.arguments | tojson }}
63
- {%- endif %}
64
- {{- '}\n</tool_call>' }}
65
- {%- endfor %}
66
- {%- endif %}
67
- {{- '<|im_end|>\n' }}
68
- {%- elif message.role == "tool" %}
69
- {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
70
- {{- '<|im_start|>user' }}
71
- {%- endif %}
72
- {{- '\n<tool_response>\n' }}
73
- {{- message.content }}
74
- {{- '\n</tool_response>' }}
75
- {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
76
- {{- '<|im_end|>\n' }}
77
- {%- endif %}
78
- {%- endif %}
79
- {%- endfor %}
80
- {%- if add_generation_prompt %}
81
- {{- '<|im_start|>assistant\n' }}
82
- {%- if enable_thinking is defined and enable_thinking is false %}
83
- {{- '<think>\n\n</think>\n\n' }}
84
- {%- endif %}
85
- {%- endif %}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.json CHANGED
@@ -1,13 +1,12 @@
1
  {
2
- "architectures": [
3
- "Qwen3Model"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
  "auto_map": {
8
- "AutoConfig": "configuration_qwen3.Qwen3Config",
9
- "AutoModel": "perplexity-ai/bidirectional-qwen3-implementation--modeling_qwen3.Qwen3Model",
10
- "AutoModelForMaskedLM": "modeling_qwen3.Qwen3ForMaskedLM"
11
  },
12
  "bos_token_id": 151643,
13
  "dtype": "float32",
@@ -57,8 +56,7 @@
57
  ],
58
  "max_position_embeddings": 32768,
59
  "max_window_layers": 36,
60
- "mlm_loss_variant": "elbo_normalize",
61
- "model_type": "qwen3",
62
  "num_attention_heads": 32,
63
  "num_hidden_layers": 36,
64
  "num_key_value_heads": 8,
@@ -73,6 +71,6 @@
73
  "transformers_version": "5.0.0.dev0",
74
  "use_cache": false,
75
  "use_sliding_window": false,
76
- "variant": "bidirectional",
77
  "vocab_size": 151936
78
  }
 
1
  {
2
+ "architectures": [
3
+ "PPLXQwen3Model"
4
  ],
5
  "attention_bias": false,
6
  "attention_dropout": 0.0,
7
  "auto_map": {
8
+ "AutoConfig": "configuration.PPLXQwen3Config",
9
+ "AutoModel": "modeling.PPLXQwen3ContextualModel"
 
10
  },
11
  "bos_token_id": 151643,
12
  "dtype": "float32",
 
56
  ],
57
  "max_position_embeddings": 32768,
58
  "max_window_layers": 36,
59
+ "model_type": "bidirectional_pplx_qwen3",
 
60
  "num_attention_heads": 32,
61
  "num_hidden_layers": 36,
62
  "num_key_value_heads": 8,
 
71
  "transformers_version": "5.0.0.dev0",
72
  "use_cache": false,
73
  "use_sliding_window": false,
74
+ "attn_implementation": "sdpa",
75
  "vocab_size": 151936
76
  }
configuration.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
2
+
3
+
4
+ class PPLXQwen3Config(Qwen3Config):
5
+ model_type = "bidirectional_pplx_qwen3"
configuration_qwen3.py DELETED
@@ -1,206 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Qwen3 model configuration"""
16
-
17
- from typing import Optional, Literal
18
-
19
- import warnings
20
-
21
- from transformers.configuration_utils import PreTrainedConfig, layer_type_validation
22
- from transformers.modeling_rope_utils import RopeParameters, rope_config_validation, standardize_rope_params
23
- from transformers.utils import logging
24
-
25
-
26
- logger = logging.get_logger(__name__)
27
-
28
-
29
- class Qwen3Config(PreTrainedConfig):
30
- r"""
31
- This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a
32
- Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
33
- with the defaults will yield a similar configuration to that of
34
- Qwen3-8B [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B).
35
-
36
- Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
37
- documentation from [`PreTrainedConfig`] for more information.
38
-
39
-
40
- Args:
41
- vocab_size (`int`, *optional*, defaults to 151936):
42
- Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the
43
- `inputs_ids` passed when calling [`Qwen3Model`]
44
- hidden_size (`int`, *optional*, defaults to 4096):
45
- Dimension of the hidden representations.
46
- intermediate_size (`int`, *optional*, defaults to 22016):
47
- Dimension of the MLP representations.
48
- num_hidden_layers (`int`, *optional*, defaults to 32):
49
- Number of hidden layers in the Transformer encoder.
50
- num_attention_heads (`int`, *optional*, defaults to 32):
51
- Number of attention heads for each attention layer in the Transformer encoder.
52
- num_key_value_heads (`int`, *optional*, defaults to 32):
53
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
54
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
55
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
56
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
57
- by meanpooling all the original heads within that group. For more details, check out [this
58
- paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
59
- head_dim (`int`, *optional*, defaults to 128):
60
- The attention head dimension.
61
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
62
- The non-linear activation function (function or string) in the decoder.
63
- max_position_embeddings (`int`, *optional*, defaults to 32768):
64
- The maximum sequence length that this model might ever be used with.
65
- initializer_range (`float`, *optional*, defaults to 0.02):
66
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
67
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
68
- The epsilon used by the rms normalization layers.
69
- use_cache (`bool`, *optional*, defaults to `True`):
70
- Whether or not the model should return the last key/values attentions (not used by all models). Only
71
- relevant if `config.is_decoder=True`.
72
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
73
- Whether the model's input and output word embeddings should be tied.
74
- rope_parameters (`RopeParameters`, *optional*):
75
- Dictionary containing the configuration parameters for the RoPE embeddings. The dictionaty should contain
76
- a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
77
- with longer `max_position_embeddings`.
78
- attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
79
- Whether to use a bias in the query, key, value and output projection layers during self-attention.
80
- use_sliding_window (`bool`, *optional*, defaults to `False`):
81
- Whether to use sliding window attention.
82
- sliding_window (`int`, *optional*, defaults to 4096):
83
- Sliding window attention (SWA) window size. If not specified, will default to `4096`.
84
- max_window_layers (`int`, *optional*, defaults to 28):
85
- The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
86
- additional layer afterwards will use SWA (Sliding Window Attention).
87
- layer_types (`list`, *optional*):
88
- Attention pattern for each layer.
89
- attention_dropout (`float`, *optional*, defaults to 0.0):
90
- The dropout ratio for the attention probabilities.
91
-
92
- ```python
93
- >>> from transformers import Qwen3Model, Qwen3Config
94
-
95
- >>> # Initializing a Qwen3 style configuration
96
- >>> configuration = Qwen3Config()
97
-
98
- >>> # Initializing a model from the Qwen3-8B style configuration
99
- >>> model = Qwen3Model(configuration)
100
-
101
- >>> # Accessing the model configuration
102
- >>> configuration = model.config
103
- ```"""
104
-
105
- model_type = "qwen3"
106
- keys_to_ignore_at_inference = ["past_key_values"]
107
-
108
- # Default tensor parallel plan for base model `Qwen3`
109
- base_model_tp_plan = {
110
- "layers.*.self_attn.q_proj": "colwise",
111
- "layers.*.self_attn.k_proj": "colwise",
112
- "layers.*.self_attn.v_proj": "colwise",
113
- "layers.*.self_attn.o_proj": "rowwise",
114
- "layers.*.mlp.gate_proj": "colwise",
115
- "layers.*.mlp.up_proj": "colwise",
116
- "layers.*.mlp.down_proj": "rowwise",
117
- }
118
- base_model_pp_plan = {
119
- "embed_tokens": (["input_ids"], ["inputs_embeds"]),
120
- "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
121
- "norm": (["hidden_states"], ["hidden_states"]),
122
- }
123
-
124
- def __init__(
125
- self,
126
- vocab_size: Optional[int] = 151936,
127
- hidden_size: Optional[int] = 4096,
128
- intermediate_size: Optional[int] = 22016,
129
- num_hidden_layers: Optional[int] = 32,
130
- num_attention_heads: Optional[int] = 32,
131
- num_key_value_heads: Optional[int] = 32,
132
- head_dim: Optional[int] = 128,
133
- hidden_act: Optional[str] = "silu",
134
- max_position_embeddings: Optional[int] = 32768,
135
- initializer_range: Optional[float] = 0.02,
136
- rms_norm_eps: Optional[int] = 1e-6,
137
- use_cache: Optional[bool] = True,
138
- tie_word_embeddings: Optional[bool] = False,
139
- rope_parameters: Optional[RopeParameters | dict[RopeParameters]] = None,
140
- attention_bias: Optional[bool] = False,
141
- use_sliding_window: Optional[bool] = False,
142
- sliding_window: Optional[int] = 4096,
143
- max_window_layers: Optional[int] = 28,
144
- layer_types: Optional[list[str]] = None,
145
- attention_dropout: Optional[float] = 0.0,
146
- variant: Literal["causal", "bidirectional", "causal_dropout"] = "causal",
147
- mlm_loss_variant: Literal["simple", "masked_normalize", "elbo_normalize", "flat_cart"] = "simple",
148
- **kwargs,
149
- ):
150
- self.vocab_size = vocab_size
151
- self.max_position_embeddings = max_position_embeddings
152
- self.hidden_size = hidden_size
153
- self.intermediate_size = intermediate_size
154
- self.num_hidden_layers = num_hidden_layers
155
- self.num_attention_heads = num_attention_heads
156
- self.use_sliding_window = use_sliding_window
157
- self.sliding_window = sliding_window if self.use_sliding_window else None
158
- self.max_window_layers = max_window_layers
159
-
160
- # for backward compatibility
161
- if num_key_value_heads is None:
162
- num_key_value_heads = num_attention_heads
163
-
164
- self.num_key_value_heads = num_key_value_heads
165
- self.head_dim = head_dim
166
- self.hidden_act = hidden_act
167
- self.initializer_range = initializer_range
168
- self.rms_norm_eps = rms_norm_eps
169
- self.use_cache = use_cache
170
- self.attention_bias = attention_bias
171
- self.attention_dropout = attention_dropout
172
- # Try to set `rope_scaling` if available, otherwise use `rope_parameters`
173
- rope_scaling = kwargs.pop("rope_scaling", None)
174
- self.rope_parameters = rope_scaling or rope_parameters
175
-
176
- self.layer_types = layer_types
177
- if self.layer_types is None:
178
- self.layer_types = [
179
- "sliding_attention"
180
- if self.sliding_window is not None and i >= self.max_window_layers
181
- else "full_attention"
182
- for i in range(self.num_hidden_layers)
183
- ]
184
- layer_type_validation(self.layer_types, self.num_hidden_layers)
185
-
186
- # Validate the correctness of rotary position embeddings parameters
187
- rope_theta = kwargs.get("rope_theta", 10000.0)
188
- standardize_rope_params(self, rope_theta=rope_theta)
189
- rope_config_validation(self)
190
-
191
- self.variant = variant
192
- self.mlm_loss_variant = mlm_loss_variant
193
-
194
- if mlm_loss_variant not in ["simple", "masked_normalize", "elbo_normalize", "flat_cart"]:
195
- raise NotImplementedError(f"Loss variant {mlm_loss_variant} unknown")
196
-
197
- if variant != "causal" and use_cache:
198
- warnings.warn("Cannot use cache (use_cache) and bidirectional attention (is_causal=False)")
199
-
200
- super().__init__(
201
- tie_word_embeddings=tie_word_embeddings,
202
- **kwargs,
203
- )
204
-
205
-
206
- __all__ = ["Qwen3Config"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Literal
2
+ import numpy as np
3
+ import torch
4
+ from transformers import Qwen3Model
5
+ from transformers.cache_utils import Cache
6
+ from transformers.masking_utils import create_causal_mask
7
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
8
+ from transformers.processing_utils import Unpack
9
+ from transformers.utils import TransformersKwargs
10
+ from .configuration import PPLXQwen3Config
11
+ from transformers import AutoTokenizer
12
+ from .st_quantize import FlexibleQuantizer
13
+
14
+
15
+ def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
16
+ """
17
+ This creates bidirectional attention mask.
18
+ """
19
+
20
+ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
21
+ if attention_mask is None:
22
+ return torch.ones((), dtype=torch.bool)
23
+ return attention_mask[batch_idx, kv_idx].to(torch.bool)
24
+
25
+ return inner_mask
26
+
27
+
class PPLXQwen3Model(Qwen3Model):
    """Qwen3 backbone with causal masking disabled (bidirectional attention).

    Every self-attention layer is switched to non-causal in ``post_init``, and
    ``forward`` rebuilds the attention mask so all non-padding tokens attend
    to each other.
    """

    _supports_flash_attn = True
    _supports_sdpa = True

    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)
        self.post_init()

    def post_init(self):
        super().post_init()
        # Override to set all layers to non-causal attention. Honored by both
        # attn_implementation="flash_attention_2" and "sdpa".
        for layer in self.layers:
            layer.self_attn.is_causal = False

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPooling:
        """Run the backbone under a bidirectional attention mask.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is
                provided (previously this crashed opaquely inside
                ``embed_tokens(None)``).
        """
        if input_ids is None and inputs_embeds is None:
            raise ValueError("Either input_ids or inputs_embeds must be provided")
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            input_ids = None

        # We construct a dummy tensor imitating initial positions (no cache),
        # used only to size the mask.
        dummy_cache_position = torch.arange(
            inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
        )
        # create_causal_mask alone would build a lower-triangular mask; OR-ing
        # in the bidirectional mask function makes every valid kv position
        # visible, removing causality while still masking padding.
        attention_mask = {
            "full_attention": create_causal_mask(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=dummy_cache_position,
                past_key_values=None,
                position_ids=position_ids,
                or_mask_function=bidirectional_mask_function(attention_mask),
            )
        }

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
86
+
87
+
class PPLXQwen3ContextualModel(PPLXQwen3Model):
    """
    Qwen3 model with contextual encoding support for late chunking.

    This model extends PPLXQwen3Model with an encode() method that performs
    contextual encoding of documents given as nested lists of text chunks
    (list[list[str]]), using late chunking.

    IMPORTANT: This model MUST be loaded with trust_remote_code=True:

        from transformers import AutoModel

        model = AutoModel.from_pretrained(
            "path/to/model",
            trust_remote_code=True  # REQUIRED!
        )

        embeddings = model.encode([["chunk1", "chunk2"]])

    Loading without trust_remote_code=True will fail to load this custom model class.
    """

    config_class = PPLXQwen3Config

    def __init__(self, config):
        super().__init__(config)

        if not isinstance(config, PPLXQwen3Config):
            raise TypeError(
                f"PPLXQwen3ContextualModel requires PPLXQwen3Config, got {type(config).__name__}. "
                f"Did you forget to load with trust_remote_code=True?"
            )

        # The tokenizer is loaded from the same checkpoint directory as the model.
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
        self._flexible_quantizer = FlexibleQuantizer()

    @staticmethod
    def mean_pooling(
        token_embeddings: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Mean-pool token embeddings over valid (unmasked) positions."""
        expanded_mask = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # Clamp avoids division by zero for fully-masked spans.
        return torch.sum(token_embeddings * expanded_mask, 1) / torch.clamp(
            expanded_mask.sum(1), min=1e-9
        )

    @torch.inference_mode()
    def encode(
        self,
        documents: list[list[str]],
        batch_size: int = 32,
        show_progress_bar: bool = False,
        device: str | torch.device | None = None,
        normalize_embeddings: bool = False,
        convert_to_numpy: bool = True,
        quantization: Literal["int8", "binary"] = "int8",
    ) -> list[np.ndarray] | list[torch.Tensor]:
        """
        Encode documents with late chunking (contextual embeddings).

        The encoding process:
        1. Concatenate each document's chunks with the tokenizer's SEP token
        2. Run a forward pass to get contextualized token embeddings
        3. Split on SEP positions and mean-pool each chunk (late chunking)
        4. Apply quantization (int8 or binary, always enabled)
        5. Normalize embeddings if requested (applied after quantization)
        6. Convert to numpy or return as tensors

        Args:
            documents: List of documents, each a list of text chunks,
                e.g. [["chunk1", "chunk2"], ["chunk1", "chunk2", "chunk3"]]
            batch_size: Batch size for encoding
            show_progress_bar: Show a tqdm progress bar if tqdm is installed
            device: Device for computation (defaults to the model's device)
            normalize_embeddings: L2-normalize embeddings after quantization
            convert_to_numpy: Return list[np.ndarray] instead of list[torch.Tensor]
            quantization: "int8" (tanh int8, default) or "binary" (sign of tanh)

        Returns:
            One array/tensor per document of shape (n_chunks, hidden_dim);
            int8 values in [-128, 127] or binary values -1.0 / 1.0.

        Raises:
            TypeError: if ``documents`` is not a list of lists.
            ValueError: for an unsupported ``quantization`` value, or when the
                tokenizer defines no separator token.
        """

        if not isinstance(documents, list) or not all(
            isinstance(doc, list) for doc in documents
        ):
            raise TypeError(
                "Input 'documents' must be a list of lists of strings for contextual encoding."
            )

        if quantization not in ["int8", "binary"]:
            raise ValueError(
                f"Unsupported quantization type: '{quantization}'. "
                f"Supported types are: 'int8', 'binary'. "
                f"Got: {type(quantization).__name__} = '{quantization}'"
            )

        # Fail fast with a clear message: stock Qwen tokenizers often define
        # no sep_token, which previously surfaced as an opaque TypeError
        # inside str.join below.
        if self.tokenizer.sep_token is None or self.tokenizer.sep_token_id is None:
            raise ValueError(
                "Tokenizer does not define a sep_token; late chunking requires "
                "a separator token to delimit chunks."
            )

        self.eval()

        if device is None:
            device = next(self.parameters()).device

        all_embeddings = []

        range_iter = range(0, len(documents), batch_size)
        if show_progress_bar:
            try:
                from tqdm import tqdm

                range_iter = tqdm(range_iter, desc="Encoding documents")
            except ImportError:
                pass  # best-effort: progress bar is optional

        for i in range_iter:
            batch_docs = documents[i : i + batch_size]

            doc_strings = [
                self.tokenizer.sep_token.join(chunks) for chunks in batch_docs
            ]

            inputs = self.tokenizer(
                doc_strings,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = self.forward(**inputs)
            token_embeddings = outputs.last_hidden_state

            batch_chunk_embeddings = self._extract_chunks_from_concatenated(
                input_ids=inputs["input_ids"],
                token_embeddings=token_embeddings,
                attention_mask=inputs["attention_mask"],
            )

            # Stack each document's chunk embeddings into (n_chunks, hidden_dim).
            batch_chunk_embeddings = [
                torch.stack(doc_chunks, dim=0)
                for doc_chunks in batch_chunk_embeddings
            ]

            batch_chunk_embeddings = [
                self._flexible_quantizer(
                    {"sentence_embedding": emb}, quantization=quantization
                )["sentence_embedding"]
                for emb in batch_chunk_embeddings
            ]

            if normalize_embeddings:
                batch_chunk_embeddings = [
                    torch.nn.functional.normalize(emb, p=2, dim=-1)
                    for emb in batch_chunk_embeddings
                ]

            batch_chunk_embeddings = [emb.cpu() for emb in batch_chunk_embeddings]

            all_embeddings.extend(batch_chunk_embeddings)

        if convert_to_numpy:
            all_embeddings = [emb.numpy() for emb in all_embeddings]

        return all_embeddings

    def _extract_chunks_from_concatenated(
        self,
        input_ids: torch.Tensor,
        token_embeddings: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> list[list[torch.Tensor]]:
        """
        Extract individual chunk embeddings from a concatenated sequence
        using late chunking.

        Splits sequences of the form "[chunk1][SEP][chunk2][SEP]..." back
        into per-chunk embeddings by locating non-padding SEP positions and
        mean-pooling each span.

        Args:
            input_ids: Token IDs (batch_size, seq_len)
            token_embeddings: Token embeddings (batch_size, seq_len, hidden_dim)
            attention_mask: Attention mask (batch_size, seq_len)

        Returns:
            list[list[torch.Tensor]]: per-document lists of chunk embeddings.

        Note:
            The sep_token_id is retrieved from self.tokenizer.sep_token_id.
            NOTE(review): the trailing-chunk logic below uses
            attention_mask.sum() as the end index, which assumes RIGHT
            padding — confirm the tokenizer's padding side.
        """
        sep_token_id = self.tokenizer.sep_token_id
        batch_size = input_ids.shape[0]

        all_doc_chunks = []

        for batch_idx in range(batch_size):
            # Only SEP tokens in non-padding positions count as separators.
            valid_positions = attention_mask[batch_idx].bool()
            sep_positions = (
                (input_ids[batch_idx] == sep_token_id) & valid_positions
            ).nonzero(as_tuple=True)[0]

            chunk_embeddings = []
            start_pos = 0

            for sep_pos in sep_positions:
                chunk_tokens = token_embeddings[batch_idx, start_pos:sep_pos]
                chunk_mask = attention_mask[batch_idx, start_pos:sep_pos]

                chunk_emb = self.mean_pooling(
                    chunk_tokens.unsqueeze(0), chunk_mask.unsqueeze(0)
                ).squeeze(0)

                chunk_embeddings.append(chunk_emb)

                start_pos = sep_pos + 1

            # Handle the last chunk (after the last SEP token).
            last_valid_pos = attention_mask[batch_idx].sum().item()

            chunk_tokens = token_embeddings[batch_idx, start_pos:last_valid_pos]
            chunk_mask = attention_mask[batch_idx, start_pos:last_valid_pos]

            if chunk_mask.sum() > 0:
                chunk_emb = self.mean_pooling(
                    chunk_tokens.unsqueeze(0), chunk_mask.unsqueeze(0)
                ).squeeze(0)
            else:
                # Empty trailing chunk — emit a zero embedding to keep the
                # chunk count consistent with the separator count.
                chunk_emb = torch.zeros(
                    token_embeddings.shape[-1],
                    device=token_embeddings.device,
                    dtype=token_embeddings.dtype,
                )

            chunk_embeddings.append(chunk_emb)

            all_doc_chunks.append(chunk_embeddings)

        return all_doc_chunks
modules.json CHANGED
@@ -15,6 +15,7 @@
15
  "idx": 2,
16
  "name": "2",
17
  "path": "",
18
- "type": "st_quantize.UnnormalizedInt8TanhQuantizer"
 
19
  }
20
- ]
 
15
  "idx": 2,
16
  "name": "2",
17
  "path": "",
18
+ "type": "st_quantize.FlexibleQuantizer",
19
+ "kwargs": ["quantization"]
20
  }
21
+ ]
st_quantize.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
- from torch import nn
3
- from typing import Optional
4
  from typing import Literal
 
5
 
6
 
7
  class Quantizer(torch.nn.Module):
@@ -26,9 +25,7 @@ class Quantizer(torch.nn.Module):
26
  result = soft
27
  else:
28
  result = (
29
- self._hard_quantize(x, *args, **kwargs).detach()
30
- + soft
31
- - soft.detach()
32
  )
33
 
34
  return result
@@ -37,85 +34,76 @@ class Quantizer(torch.nn.Module):
37
  class Int8TanhQuantizer(Quantizer):
38
  def __init__(
39
  self,
40
- normalize: bool = False,
41
  hard: bool = True,
42
  ):
43
  super().__init__(hard=hard)
44
  self.qmin = -128
45
  self.qmax = 127
46
- self._normalize = normalize
47
 
48
  def _soft_quantize(self, x, *args, **kwargs):
49
- if self._normalize:
50
- x = (x - x.mean(dim=-1, keepdim=True)) / (
51
- x.std(dim=-1, keepdim=True) + 1e-8
52
- )
53
-
54
  return torch.tanh(x)
55
 
56
  def _hard_quantize(self, x, *args, **kwargs):
57
  soft = self._soft_quantize(x)
58
  int_x = torch.round(soft * self.qmax)
59
  int_x = torch.clamp(int_x, self.qmin, self.qmax)
60
- return int_x / self.qmax
61
-
62
-
63
- class UnnormalizedInt8TanhQuantizer(Int8TanhQuantizer):
64
- def __init__(self):
65
- super().__init__()
66
- self.quantizer = Int8TanhQuantizer(normalize=False)
67
-
68
- def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
69
- features["sentence_embedding"] = self.quantizer(
70
- features["sentence_embedding"]
71
- )
72
- return features
73
-
74
- @classmethod
75
- def load(cls, input_path: str) -> "PoolAndQuantize":
76
- return cls()
77
-
78
-
79
- class NormalizedInt8TanhQuantizer(Int8TanhQuantizer):
80
- def __init__(self):
81
- super().__init__()
82
- self.quantizer = Int8TanhQuantizer(normalize=True)
83
-
84
- def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
85
- features["sentence_embedding"] = self.quantizer(
86
- features["sentence_embedding"]
87
- )
88
- return features
89
-
90
- @classmethod
91
- def load(cls, input_path: str) -> "PoolAndQuantize":
92
- return cls()
93
 
94
 
95
- class Binarizer(Quantizer):
96
- def __init__(self, tanh_scale: float = 1.0, **kwargs):
97
- super().__init__(**kwargs)
98
- self._tanh_scale = tanh_scale
 
 
 
 
99
 
100
- def _hard_quantize(self, x, *args, **kwargs) -> torch.Tensor:
101
- return torch.where(x > 0, 1.0, -1.0)
102
 
103
- def _soft_quantize(self, x, *args, **kwargs) -> torch.Tensor:
104
- return torch.tanh(x * self._tanh_scale)
105
 
106
 
107
- class UnnormalizedBinarizer(nn.Module):
108
- def __init__(self, tanh_scale: float = 1.0, hard: bool = True):
109
  super().__init__()
110
- self.quantizer = Binarizer(tanh_scale=tanh_scale, hard=hard)
 
111
 
112
- def forward(self, features: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
113
- features["sentence_embedding"] = self.quantizer(
114
- features["sentence_embedding"]
115
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  return features
117
 
118
  @classmethod
119
- def load(cls, input_path: str) -> "UnnormalizedBinarizer":
 
 
 
 
 
 
 
 
 
120
  return cls()
121
-
 
 
 
1
  import torch
 
 
2
  from typing import Literal
3
+ from sentence_transformers.models import Module
4
 
5
 
6
  class Quantizer(torch.nn.Module):
 
25
  result = soft
26
  else:
27
  result = (
28
+ self._hard_quantize(x, *args, **kwargs).detach() + soft - soft.detach()
 
 
29
  )
30
 
31
  return result
 
class Int8TanhQuantizer(Quantizer):
    """Quantizer mapping embeddings onto the int8 range [-128, 127].

    The soft (differentiable) path squashes activations with ``tanh``; the
    hard path scales the squashed values by 127, rounds, and clamps to the
    int8 range. Values are returned in the input's float tensor, not cast
    to ``torch.int8``.
    """

    def __init__(
        self,
        hard: bool = True,
    ):
        super().__init__(hard=hard)
        self.qmin = -128
        self.qmax = 127

    def _soft_quantize(self, x, *args, **kwargs):
        # Differentiable surrogate in (-1, 1).
        return torch.tanh(x)

    def _hard_quantize(self, x, *args, **kwargs):
        squashed = self._soft_quantize(x)
        quantized = torch.round(squashed * self.qmax)
        return torch.clamp(quantized, self.qmin, self.qmax)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
 
class BinaryTanhQuantizer(Quantizer):
    """Quantizer that binarizes embeddings to -1.0 / +1.0.

    The soft (differentiable) path is a scaled ``tanh``; the hard path maps
    non-negative values to +1.0 and negative values to -1.0.
    """

    def __init__(
        self,
        hard: bool = True,
        scale: float = 1.0,
    ):
        super().__init__(hard)
        self._scale = scale

    def _soft_quantize(self, x, *args, **kwargs):
        # Scale controls how sharply tanh approximates the sign function.
        return torch.tanh(x * self._scale)

    def _hard_quantize(self, x, *args, **kwargs):
        return torch.where(x < 0, -1.0, 1.0)
67
 
68
 
class FlexibleQuantizer(Module):
    """Sentence-transformers module applying int8 or binary quantization.

    Selects between an ``Int8TanhQuantizer`` and a ``BinaryTanhQuantizer`` at
    call time via the ``quantization`` keyword and rewrites the
    ``"sentence_embedding"`` entry of the feature dict in place.
    """

    def __init__(self):
        super().__init__()
        self._int8_quantizer = Int8TanhQuantizer()
        self._binary_quantizer = BinaryTanhQuantizer()

    def forward(
        self,
        features: dict[str, torch.Tensor],
        quantization: Literal["binary", "int8"] = "int8",
        **kwargs
    ) -> dict[str, torch.Tensor]:
        if quantization == "int8":
            quantizer = self._int8_quantizer
        elif quantization == "binary":
            quantizer = self._binary_quantizer
        else:
            raise ValueError(
                f"Invalid quantization type: {quantization}. Must be 'binary' or 'int8'."
            )
        features["sentence_embedding"] = quantizer(features["sentence_embedding"])
        return features

    @classmethod
    def load(
        cls,
        model_name_or_path: str,
        subfolder: str = "",
        token: bool | str | None = None,
        cache_folder: str | None = None,
        revision: str | None = None,
        local_files_only: bool = False,
        **kwargs,
    ):
        # Stateless module: loading is just construction.
        return cls()

    def save(self, output_path: str, *args, **kwargs) -> None:
        # Nothing to persist; all behavior is defined in code.
        return
+ return