mkrimmel-pplx committed on
Commit
4fdc645
·
1 Parent(s): 581ccac

refactor: new modeling code

Files changed (3)
  1. configuration.py +2 -125
  2. modeling.py +52 -774
  3. st_quantize.py +16 -2
configuration.py CHANGED
@@ -1,128 +1,5 @@
1
- # coding=utf-8
2
- # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This file has been modified from the original Qwen3 implementation.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
 
18
- from typing import Optional
19
- from transformers import PretrainedConfig
20
-
21
-
22
- class PPLXQwen3Config(PretrainedConfig):
23
- """
24
- PPLX configuration class for Qwen3Model compatible with transformers < 5.X.
25
- This implementation only supports bidirectional attention (no causal or dropout variants).
26
-
27
- Args:
28
- vocab_size (int, optional, defaults to 151936):
29
- Vocabulary size of the Qwen3 model.
30
- hidden_size (int, optional, defaults to 4096):
31
- Dimension of the hidden representations.
32
- intermediate_size (int, optional, defaults to 22016):
33
- Dimension of the MLP representations.
34
- num_hidden_layers (int, optional, defaults to 32):
35
- Number of hidden layers in the Transformer encoder.
36
- num_attention_heads (int, optional, defaults to 32):
37
- Number of attention heads for each attention layer.
38
- num_key_value_heads (int, optional, defaults to 32):
39
- Number of key_value heads for Grouped Query Attention.
40
- head_dim (int, optional, defaults to 128):
41
- The attention head dimension.
42
- hidden_act (str, optional, defaults to "silu"):
43
- The non-linear activation function.
44
- max_position_embeddings (int, optional, defaults to 32768):
45
- The maximum sequence length.
46
- initializer_range (float, optional, defaults to 0.02):
47
- The standard deviation for weight initialization.
48
- rms_norm_eps (float, optional, defaults to 1e-06):
49
- The epsilon for rms normalization layers.
50
- attention_bias (bool, optional, defaults to False):
51
- Whether to use bias in attention projection layers.
52
- attention_dropout (float, optional, defaults to 0.0):
53
- The dropout ratio for attention probabilities.
54
- rope_theta (float, optional, defaults to 10000.0):
55
- The base period of the RoPE embeddings.
56
- pad_token_id (int, optional):
57
- The id of the padding token.
58
- bos_token_id (int, optional):
59
- The id of the beginning-of-sequence token.
60
- eos_token_id (int, optional):
61
- The id of the end-of-sequence token.
62
- attn_implementation (str, optional):
63
- The attention implementation to use. Options: "eager", "sdpa".
64
- If None, will auto-select based on availability.
65
- """
66
 
 
67
  model_type = "bidirectional_pplx_qwen3"
68
-
69
- def __init__(
70
- self,
71
- vocab_size: Optional[int] = 151936,
72
- hidden_size: Optional[int] = 4096,
73
- intermediate_size: Optional[int] = 22016,
74
- num_hidden_layers: Optional[int] = 32,
75
- num_attention_heads: Optional[int] = 32,
76
- num_key_value_heads: Optional[int] = 32,
77
- head_dim: Optional[int] = 128,
78
- hidden_act: Optional[str] = "silu",
79
- max_position_embeddings: Optional[int] = 32768,
80
- initializer_range: Optional[float] = 0.02,
81
- rms_norm_eps: Optional[float] = 1e-6,
82
- attention_bias: Optional[bool] = False,
83
- attention_dropout: Optional[float] = 0.0,
84
- rope_theta: Optional[float] = 10000.0,
85
- pad_token_id: Optional[int] = None,
86
- bos_token_id: Optional[int] = None,
87
- eos_token_id: Optional[int] = None,
88
- attn_implementation: Optional[str] = None,
89
- **kwargs,
90
- ):
91
- # Extract attn_implementation from kwargs if not explicitly provided
92
- if attn_implementation is None and 'attn_implementation' in kwargs:
93
- attn_implementation = kwargs.pop('attn_implementation')
94
-
95
- self.vocab_size = vocab_size
96
- self.max_position_embeddings = max_position_embeddings
97
- self.hidden_size = hidden_size
98
- self.intermediate_size = intermediate_size
99
- self.num_hidden_layers = num_hidden_layers
100
- self.num_attention_heads = num_attention_heads
101
- self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
102
- self.head_dim = head_dim
103
- self.hidden_act = hidden_act
104
- self.initializer_range = initializer_range
105
- self.rms_norm_eps = rms_norm_eps
106
- self.attention_bias = attention_bias
107
- self.attention_dropout = attention_dropout
108
- self.rope_theta = rope_theta
109
-
110
- # Legacy: only bidirectional attention supported
111
- self.is_causal = False
112
-
113
- # Initialize parent class with token IDs
114
- super().__init__(
115
- pad_token_id=pad_token_id,
116
- bos_token_id=bos_token_id,
117
- eos_token_id=eos_token_id,
118
- **kwargs,
119
- )
120
-
121
- # Store attn_implementation as a regular attribute AFTER super().__init__() (will be serialized)
122
- self.attn_implementation = attn_implementation
123
-
124
-
125
- # Register for AutoConfig
126
- PPLXQwen3Config.register_for_auto_class()
127
-
128
- __all__ = ["PPLXQwen3Config"]
 
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config


+ class PPLXQwen3Config(Qwen3Config):
      model_type = "bidirectional_pplx_qwen3"
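With the bespoke configuration class removed, PPLXQwen3Config is now a thin subclass of the stock Qwen3Config, so every Qwen3 field keeps its upstream default and only model_type changes. A minimal usage sketch (the argument values and the local import path are illustrative, not taken from this repo):

# Hypothetical sketch: any Qwen3Config keyword argument still round-trips
# through PPLXQwen3Config; only model_type differs.
from configuration import PPLXQwen3Config  # assumes this repo's module layout

cfg = PPLXQwen3Config(hidden_size=256, num_hidden_layers=2, num_attention_heads=4)
assert cfg.model_type == "bidirectional_pplx_qwen3"
cfg.save_pretrained("./tmp-config")  # serializes like any PretrainedConfig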
modeling.py CHANGED
@@ -1,805 +1,83 @@
1
- # coding=utf-8
2
- # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
- #
4
- # This file has been modified from the original Qwen3 implementation.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
-
18
- from typing import Optional, Tuple, Literal
19
-
20
- import numpy as np
21
  import torch
22
- from torch import nn
23
- import torch.nn.functional as F
24
- from transformers import AutoTokenizer
25
- from transformers.modeling_utils import PreTrainedModel
26
- from transformers.modeling_outputs import BaseModelOutputWithPast
27
-
28
  from .configuration import PPLXQwen3Config
29
- from .st_quantize import FlexibleQuantizer
30
-
31
-
32
- # Activation functions mapping
33
- ACT2FN = {
34
- "silu": nn.functional.silu,
35
- "gelu": nn.functional.gelu,
36
- "relu": nn.functional.relu,
37
- }
38
-
39
-
40
- class PPLXQwen3RMSNorm(nn.Module):
41
- """RMSNorm implementation compatible with transformers < 5.X"""
42
-
43
- def __init__(self, hidden_size, eps: float = 1e-6) -> None:
44
- super().__init__()
45
- self.weight = nn.Parameter(torch.ones(hidden_size))
46
- self.variance_epsilon = eps
47
-
48
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
49
- input_dtype = hidden_states.dtype
50
- hidden_states = hidden_states.to(torch.float32)
51
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
52
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
53
- return self.weight * hidden_states.to(input_dtype)
54
-
55
-
56
- class PPLXQwen3MLP(nn.Module):
57
- """MLP implementation compatible with transformers < 5.X"""
58
-
59
- def __init__(self, config):
60
- super().__init__()
61
- self.config = config
62
- self.hidden_size = config.hidden_size
63
- self.intermediate_size = config.intermediate_size
64
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
65
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
66
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
67
- self.act_fn = ACT2FN[config.hidden_act]
68
-
69
- def forward(self, x):
70
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
71
- return down_proj
72
-
73
-
74
- class PPLXQwen3RotaryEmbedding(nn.Module):
75
- """Rotary Position Embedding implementation compatible with transformers < 5.X"""
76
-
77
- def __init__(self, config, device=None):
78
- super().__init__()
79
- self.max_seq_len_cached = config.max_position_embeddings
80
- self.original_max_seq_len = config.max_position_embeddings
81
- self.config = config
82
-
83
- # Check rope type and raise if not default
84
- self.rope_type = self.config.rope_parameters["rope_type"]
85
- if self.rope_type != "default":
86
- raise NotImplementedError("Only default RoPE implemented")
87
-
88
- # Compute inverse frequencies using the static method
89
- inv_freq, self.attention_scaling = self.compute_default_rope_parameters(
90
- config, device
91
- )
92
- self.register_buffer("inv_freq", inv_freq, persistent=False)
93
- self.original_inv_freq = inv_freq
94
-
95
- @staticmethod
96
- def compute_default_rope_parameters(
97
- config: Optional["PPLXQwen3Config"] = None,
98
- device: Optional[torch.device] = None,
99
- ) -> Tuple[torch.Tensor, float]:
100
- """
101
- Computes the inverse frequencies according to the original RoPE implementation
102
-
103
- Args:
104
- config: The model configuration.
105
- device: The device to use for initialization of the inverse frequencies.
106
-
107
- Returns:
108
- Tuple of (inv_freq, attention_scaling), containing the inverse frequencies
109
- for the RoPE embeddings and the post-processing scaling factor applied to
110
- the computed cos/sin.
111
- """
112
- base = config.rope_parameters["rope_theta"]
113
- dim = config.head_dim
114
 
115
- attention_factor = 1.0 # Unused in default RoPE
116
 
117
- # Compute the inverse frequencies
118
- inv_freq = 1.0 / (
119
- base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
120
- )
121
- return inv_freq, attention_factor
122
-
123
- def forward(self, x, position_ids):
124
- # Expand inv_freq to match batch size
125
- inv_freq_expanded = (
126
- self.inv_freq[None, :, None]
127
- .float()
128
- .expand(position_ids.shape[0], -1, 1)
129
- .to(x.device)
130
- )
131
- position_ids_expanded = position_ids[:, None, :].float()
132
-
133
- # Compute frequencies
134
- device_type = (
135
- x.device.type
136
- if isinstance(x.device.type, str) and x.device.type != "mps"
137
- else "cpu"
138
- )
139
- with torch.autocast(device_type=device_type, enabled=False):
140
- freqs = (
141
- inv_freq_expanded.float() @ position_ids_expanded.float()
142
- ).transpose(1, 2)
143
- emb = torch.cat((freqs, freqs), dim=-1)
144
- cos = emb.cos() * self.attention_scaling
145
- sin = emb.sin() * self.attention_scaling
146
-
147
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
148
-
149
-
150
- def rotate_half(x):
151
- """Rotates half the hidden dims of the input."""
152
- x1 = x[..., : x.shape[-1] // 2]
153
- x2 = x[..., x.shape[-1] // 2 :]
154
- return torch.cat((-x2, x1), dim=-1)
155
-
156
-
157
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
158
- """Applies Rotary Position Embedding to the query and key tensors."""
159
- cos = cos.unsqueeze(unsqueeze_dim)
160
- sin = sin.unsqueeze(unsqueeze_dim)
161
- q_embed = (q * cos) + (rotate_half(q) * sin)
162
- k_embed = (k * cos) + (rotate_half(k) * sin)
163
- return q_embed, k_embed
164
-
165
-
166
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
167
  """
168
- Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
169
- Hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
170
- to (batch, num_attention_heads, seqlen, head_dim)
171
  """
172
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
173
- if n_rep == 1:
174
- return hidden_states
175
- hidden_states = hidden_states[:, :, None, :, :].expand(
176
- batch, num_key_value_heads, n_rep, slen, head_dim
177
- )
178
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
179
-
180
-
181
- def eager_attention_forward(
182
- query: torch.Tensor,
183
- key: torch.Tensor,
184
- value: torch.Tensor,
185
- attention_mask: Optional[torch.Tensor],
186
- scaling: float,
187
- dropout: float = 0.0,
188
- training: bool = False,
189
- num_key_value_groups: int = 1,
190
- **kwargs,
191
- ) -> Tuple[torch.Tensor, torch.Tensor]:
192
- """
193
- Eager (vanilla) attention implementation.
194
-
195
- Args:
196
- query: (batch, num_heads, seq_len, head_dim)
197
- key: (batch, num_kv_heads, seq_len, head_dim)
198
- value: (batch, num_kv_heads, seq_len, head_dim)
199
- attention_mask: (batch, 1, seq_len, seq_len)
200
- scaling: attention scaling factor
201
- dropout: dropout probability
202
- training: whether in training mode
203
- num_key_value_groups: number of query heads per key/value head (for GQA)
204
- """
205
- # Repeat k/v heads if using GQA
206
- key_states = repeat_kv(key, num_key_value_groups)
207
- value_states = repeat_kv(value, num_key_value_groups)
208
-
209
- # Compute attention scores
210
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
211
-
212
- # Apply attention mask
213
- if attention_mask is not None:
214
- attn_weights = attn_weights + attention_mask
215
-
216
- # Softmax and dropout
217
- attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
218
- attn_weights = F.dropout(attn_weights, p=dropout, training=training)
219
-
220
- # Compute output
221
- attn_output = torch.matmul(attn_weights, value_states)
222
-
223
- return attn_output, attn_weights
224
-
225
-
226
- def sdpa_attention_forward(
227
- query: torch.Tensor,
228
- key: torch.Tensor,
229
- value: torch.Tensor,
230
- attention_mask: Optional[torch.Tensor],
231
- scaling: float,
232
- dropout: float = 0.0,
233
- training: bool = False,
234
- num_key_value_groups: int = 1,
235
- **kwargs,
236
- ) -> Tuple[torch.Tensor, None]:
237
- """
238
- Scaled Dot Product Attention using PyTorch's native implementation.
239
-
240
- Args:
241
- query: (batch, num_heads, seq_len, head_dim)
242
- key: (batch, num_kv_heads, seq_len, head_dim)
243
- value: (batch, num_kv_heads, seq_len, head_dim)
244
- attention_mask: (batch, 1, seq_len, seq_len) or None
245
- scaling: attention scaling factor (handled internally by SDPA)
246
- dropout: dropout probability
247
- training: whether in training mode
248
- num_key_value_groups: number of query heads per key/value head (for GQA)
249
- """
250
- # Repeat k/v heads if using GQA
251
- key = repeat_kv(key, num_key_value_groups)
252
- value = repeat_kv(value, num_key_value_groups)
253
-
254
- # Convert attention mask for SDPA
255
- # SDPA expects additive mask in shape (batch, num_heads, seq_len, seq_len) or broadcastable
256
- attn_mask = None
257
- if attention_mask is not None:
258
- # attention_mask is (batch, 1, seq_len, seq_len)
259
- # Broadcast to (batch, num_heads, seq_len, seq_len) by repeating
260
- batch_size, _, seq_len, _ = attention_mask.shape
261
- num_heads = query.shape[1]
262
- # Expand to match num_heads
263
- attn_mask = attention_mask.expand(batch_size, num_heads, seq_len, seq_len)
264
-
265
- # PyTorch SDPA
266
- attn_output = F.scaled_dot_product_attention(
267
- query,
268
- key,
269
- value,
270
- attn_mask=attn_mask,
271
- dropout_p=dropout if training else 0.0,
272
- is_causal=False, # We handle masking explicitly for bidirectional
273
- scale=scaling,
274
- )
275
-
276
- return attn_output, None
277
-
278
-
279
- # Mapping of attention implementation names to functions
280
- ATTENTION_IMPLEMENTATIONS = {
281
- "eager": eager_attention_forward,
282
- "sdpa": sdpa_attention_forward,
283
- }
284
-
285
-
286
- class PPLXQwen3Attention(nn.Module):
287
- """
288
- Multi-headed attention implementation compatible with transformers < 5.X
289
- Supports multiple attention backends: eager, sdpa
290
- """
291
-
292
- def __init__(self, config, layer_idx: int):
293
- super().__init__()
294
- self.config = config
295
- self.layer_idx = layer_idx
296
- self.head_dim = config.head_dim
297
- self.num_attention_heads = config.num_attention_heads
298
- self.num_key_value_heads = config.num_key_value_heads
299
- self.num_key_value_groups = (
300
- config.num_attention_heads // config.num_key_value_heads
301
- )
302
- self.scaling = self.head_dim**-0.5
303
- self.attention_dropout = config.attention_dropout
304
-
305
- self.q_proj = nn.Linear(
306
- config.hidden_size,
307
- config.num_attention_heads * self.head_dim,
308
- bias=config.attention_bias,
309
- )
310
- self.k_proj = nn.Linear(
311
- config.hidden_size,
312
- config.num_key_value_heads * self.head_dim,
313
- bias=config.attention_bias,
314
- )
315
- self.v_proj = nn.Linear(
316
- config.hidden_size,
317
- config.num_key_value_heads * self.head_dim,
318
- bias=config.attention_bias,
319
- )
320
- self.o_proj = nn.Linear(
321
- config.num_attention_heads * self.head_dim,
322
- config.hidden_size,
323
- bias=config.attention_bias,
324
- )
325
- self.q_norm = PPLXQwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
326
- self.k_norm = PPLXQwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
327
-
328
- # Select attention implementation
329
- self._select_attention_implementation(config)
330
-
331
- def _select_attention_implementation(self, config):
332
- """Select the attention implementation based on config or availability."""
333
- attn_impl = getattr(config, "attn_implementation", None)
334
-
335
- if attn_impl is None:
336
- # Auto-select: prefer faster implementations
337
- if hasattr(F, "scaled_dot_product_attention"):
338
- attn_impl = "sdpa"
339
- else:
340
- attn_impl = "eager"
341
-
342
- if attn_impl not in ATTENTION_IMPLEMENTATIONS:
343
- raise ValueError(
344
- f"Unknown attention implementation: {attn_impl}. "
345
- f"Available: {list(ATTENTION_IMPLEMENTATIONS.keys())}"
346
- )
347
-
348
- # Check availability
349
- if attn_impl == "sdpa" and not hasattr(F, "scaled_dot_product_attention"):
350
- raise ImportError(
351
- "sdpa requested but not available. Please use PyTorch >= 2.0"
352
- )
353
-
354
- self.attn_implementation = attn_impl
355
- self.attn_function = ATTENTION_IMPLEMENTATIONS[attn_impl]
356
-
357
- def forward(
358
- self,
359
- hidden_states: torch.Tensor,
360
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
361
- attention_mask: Optional[torch.Tensor] = None,
362
- **kwargs,
363
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
364
- input_shape = hidden_states.shape[:-1]
365
- hidden_shape = (*input_shape, -1, self.head_dim)
366
-
367
- # Project and reshape
368
- query_states = self.q_norm(
369
- self.q_proj(hidden_states).view(hidden_shape)
370
- ).transpose(1, 2)
371
- key_states = self.k_norm(
372
- self.k_proj(hidden_states).view(hidden_shape)
373
- ).transpose(1, 2)
374
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
375
-
376
- # Apply rotary embeddings
377
- cos, sin = position_embeddings
378
- query_states, key_states = apply_rotary_pos_emb(
379
- query_states, key_states, cos, sin
380
- )
381
-
382
- # Call the selected attention implementation
383
- attn_output, attn_weights = self.attn_function(
384
- query=query_states,
385
- key=key_states,
386
- value=value_states,
387
- attention_mask=attention_mask,
388
- scaling=self.scaling,
389
- dropout=self.attention_dropout,
390
- training=self.training,
391
- num_key_value_groups=self.num_key_value_groups,
392
- )
393
-
394
- # Reshape and project output
395
- attn_output = attn_output.transpose(1, 2).contiguous()
396
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
397
- attn_output = self.o_proj(attn_output)
398
-
399
- return attn_output, attn_weights
400
-
401
-
402
- class PPLXQwen3DecoderLayer(nn.Module):
403
- """Decoder layer implementation compatible with transformers < 5.X"""
404
-
405
- def __init__(self, config, layer_idx: int):
406
- super().__init__()
407
- self.hidden_size = config.hidden_size
408
- self.self_attn = PPLXQwen3Attention(config=config, layer_idx=layer_idx)
409
- self.mlp = PPLXQwen3MLP(config)
410
- self.input_layernorm = PPLXQwen3RMSNorm(
411
- config.hidden_size, eps=config.rms_norm_eps
412
- )
413
- self.post_attention_layernorm = PPLXQwen3RMSNorm(
414
- config.hidden_size, eps=config.rms_norm_eps
415
- )
416
-
417
- def forward(
418
- self,
419
- hidden_states: torch.Tensor,
420
- attention_mask: Optional[torch.Tensor] = None,
421
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
422
- **kwargs,
423
- ) -> torch.Tensor:
424
- # Self Attention
425
- residual = hidden_states
426
- hidden_states = self.input_layernorm(hidden_states)
427
- hidden_states, _ = self.self_attn(
428
- hidden_states=hidden_states,
429
- attention_mask=attention_mask,
430
- position_embeddings=position_embeddings,
431
- )
432
- hidden_states = residual + hidden_states
433
 
434
- # MLP
435
- residual = hidden_states
436
- hidden_states = self.post_attention_layernorm(hidden_states)
437
- hidden_states = self.mlp(hidden_states)
438
- hidden_states = residual + hidden_states
439
 
440
- return hidden_states
441
 
442
 
443
- class PPLXQwen3PreTrainedModel(PreTrainedModel):
444
- """
445
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
446
- models.
447
- """
448
 
449
  config_class = PPLXQwen3Config
450
- base_model_prefix = "model"
451
- supports_gradient_checkpointing = False
452
- _no_split_modules = ["PPLXQwen3DecoderLayer"]
453
- _skip_keys_device_placement = ["past_key_values"]
454
-
455
-
456
- class PPLXQwen3Model(PPLXQwen3PreTrainedModel):
457
- """
458
- Qwen3 Model implementation compatible with transformers < 5.X.
459
- Only supports bidirectional attention (no causal masking or caching).
460
- """
461
 
462
  def __init__(self, config):
463
  super().__init__(config)
464
- self.padding_idx = config.pad_token_id
465
- self.vocab_size = config.vocab_size
466
-
467
- self.embed_tokens = nn.Embedding(
468
- config.vocab_size, config.hidden_size, self.padding_idx
469
- )
470
- self.layers = nn.ModuleList(
471
- [
472
- PPLXQwen3DecoderLayer(config, layer_idx)
473
- for layer_idx in range(config.num_hidden_layers)
474
- ]
475
- )
476
- self.norm = PPLXQwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
477
- self.rotary_emb = PPLXQwen3RotaryEmbedding(config=config)
478
-
479
- # Initialize weights and apply final processing
480
  self.post_init()
481
 
 
 
 
 
 
 
482
  def forward(
483
  self,
484
- input_ids: Optional[torch.LongTensor] = None,
485
- attention_mask: Optional[torch.Tensor] = None,
486
- position_ids: Optional[torch.LongTensor] = None,
487
- inputs_embeds: Optional[torch.FloatTensor] = None,
488
- **kwargs,
489
- ) -> BaseModelOutputWithPast:
490
- # Get embeddings
 
 
491
  if inputs_embeds is None:
492
  inputs_embeds = self.embed_tokens(input_ids)
 
493
 
494
- batch_size, seq_length = inputs_embeds.shape[:2]
495
-
496
- # Create position IDs if not provided
497
- if position_ids is None:
498
- position_ids = (
499
- torch.arange(seq_length, device=inputs_embeds.device)
500
- .unsqueeze(0)
501
- .expand(batch_size, -1)
502
- )
503
-
504
- # Create bidirectional attention mask
505
- # Transform from (batch_size, seq_length) to (batch_size, 1, seq_length, seq_length)
506
- if attention_mask is not None:
507
- # Expand attention mask to 4D
508
- attention_mask = attention_mask[:, None, None, :].to(
509
- dtype=inputs_embeds.dtype
510
- )
511
- attention_mask = (1.0 - attention_mask) * torch.finfo(
512
- inputs_embeds.dtype
513
- ).min
514
- # Broadcast to full attention shape
515
- attention_mask = attention_mask.expand(
516
- batch_size, 1, seq_length, seq_length
517
- )
518
- else:
519
- # No masking needed for bidirectional attention with no padding
520
- attention_mask = torch.zeros(
521
- (batch_size, 1, seq_length, seq_length),
522
- dtype=inputs_embeds.dtype,
523
- device=inputs_embeds.device,
524
- )
525
-
526
- # Get rotary embeddings
527
- position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
528
-
529
- # Pass through decoder layers
530
- hidden_states = inputs_embeds
531
- for decoder_layer in self.layers:
532
- hidden_states = decoder_layer(
533
- hidden_states,
534
  attention_mask=attention_mask,
535
- position_embeddings=position_embeddings,
 
 
 
536
  )
 
537
 
538
- # Final norm
539
- hidden_states = self.norm(hidden_states)
540
-
541
- return BaseModelOutputWithPast(last_hidden_state=hidden_states)
542
-
543
-
544
- class PPLXQwen3ContextualModel(PPLXQwen3PreTrainedModel):
545
- """
546
- Qwen3 model with contextual encoding support for late chunking.
547
-
548
- This model extends PPLXQwen3Model with an encode() method that supports both
549
- standard encoding (list[str]) and contextual encoding (list[list[str]]) with late chunking.
550
- """
551
-
552
- def __init__(self, config):
553
- super().__init__(config)
554
- self.model = PPLXQwen3Model(config)
555
- self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
556
- self._flexible_quantizer = FlexibleQuantizer()
557
- self.post_init()
558
-
559
- def forward(
560
- self,
561
- input_ids: Optional[torch.LongTensor] = None,
562
- attention_mask: Optional[torch.Tensor] = None,
563
- position_ids: Optional[torch.LongTensor] = None,
564
- inputs_embeds: Optional[torch.FloatTensor] = None,
565
- **kwargs,
566
- ) -> BaseModelOutputWithPast:
567
- """Forward pass through the model."""
568
- return self.model(
569
  input_ids=input_ids,
570
  attention_mask=attention_mask,
571
  position_ids=position_ids,
 
572
  inputs_embeds=inputs_embeds,
 
 
573
  **kwargs,
574
  )
575
-
576
- @staticmethod
577
- def mean_pooling(
578
- token_embeddings: torch.Tensor, attention_mask: torch.Tensor
579
- ) -> torch.Tensor:
580
- """Apply mean pooling to token embeddings."""
581
- input_mask_expanded = (
582
- attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
583
- )
584
- return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
585
- input_mask_expanded.sum(1), min=1e-9
586
- )
587
-
588
- @torch.inference_mode()
589
- def encode(
590
- self,
591
- documents: list[list[str]],
592
- batch_size: int = 32,
593
- show_progress_bar: bool = False,
594
- device: str | torch.device | None = None,
595
- normalize_embeddings: bool = False,
596
- convert_to_numpy: bool = True,
597
- quantization: Literal["int8", "binary"] = "int8",
598
- ) -> list[np.ndarray] | list[torch.Tensor]:
599
- """
600
- Encode documents with late chunking (contextual embeddings).
601
-
602
- This model is designed specifically for contextual encoding and always expects
603
- documents as nested lists where each document is a list of text chunks.
604
-
605
- The encoding process:
606
- 1. Concatenate chunks with separator tokens
607
- 2. Run forward pass to get token embeddings
608
- 3. Extract and pool individual chunk embeddings (late chunking)
609
- 4. Apply quantization (Int8 or binary, always enabled)
610
- 5. Normalize embeddings if requested (applied after quantization)
611
- 6. Convert to numpy or return as tensors
612
-
613
- Args:
614
- documents: List of documents, where each document is a list of text chunks.
615
- Example: [["chunk1", "chunk2"], ["chunk1", "chunk2", "chunk3"]]
616
- batch_size: Batch size for encoding
617
- show_progress_bar: Show progress bar during encoding
618
- device: Device to use for computation (defaults to model's device)
619
- normalize_embeddings: Normalize embeddings to unit length (applied after quantization)
620
- convert_to_numpy: If True, returns list[np.ndarray], otherwise list[torch.Tensor]
621
- quantization: Quantization type to apply. Options:
622
- - "int8": Int8 tanh quantization (default)
623
- - "binary": Binary tanh quantization
624
-
625
- Returns:
626
- List of numpy arrays or tensors (preserves document structure).
627
- Each element has shape (n_chunks, hidden_dim).
628
- embeddings[0].shape = (2, 1024), embeddings[1].shape = (3, 1024)
629
- Output type depends on quantization method:
630
- - Int8: int8 values in range [-128, 127]
631
- - Binary: float values -1.0 or 1.0
632
- """
633
-
634
- if not isinstance(documents, list) or not all(
635
- isinstance(doc, list) for doc in documents
636
- ):
637
- raise TypeError(
638
- "Input 'documents' must be a list of lists of strings for contextual encoding."
639
- )
640
-
641
- if quantization not in ["int8", "binary"]:
642
- raise ValueError(
643
- f"Unsupported quantization type: '{quantization}'. "
644
- f"Supported types are: 'int8', 'binary'. "
645
- f"Got: {type(quantization).__name__} = '{quantization}'"
646
- )
647
-
648
- self.eval()
649
-
650
- if device is None:
651
- device = next(self.parameters()).device
652
-
653
- all_embeddings = []
654
-
655
- range_iter = range(0, len(documents), batch_size)
656
- if show_progress_bar:
657
- try:
658
- from tqdm import tqdm
659
-
660
- range_iter = tqdm(range_iter, desc="Encoding documents")
661
- except ImportError:
662
- pass
663
-
664
- for i in range_iter:
665
- batch_docs = documents[i : i + batch_size]
666
-
667
- doc_strings = [
668
- self.tokenizer.sep_token.join(chunks) for chunks in batch_docs
669
- ]
670
-
671
- inputs = self.tokenizer(
672
- doc_strings,
673
- padding=True,
674
- truncation=True,
675
- return_tensors="pt",
676
- )
677
- inputs = {k: v.to(device) for k, v in inputs.items()}
678
-
679
- outputs = self.forward(**inputs)
680
- token_embeddings = outputs.last_hidden_state
681
-
682
- batch_chunk_embeddings = self._extract_chunks_from_concatenated(
683
- input_ids=inputs["input_ids"],
684
- token_embeddings=token_embeddings,
685
- attention_mask=inputs["attention_mask"],
686
- )
687
-
688
- batch_chunk_embeddings = [
689
- torch.stack([chunk for chunk in doc_chunks], dim=0)
690
- for doc_chunks in batch_chunk_embeddings
691
- ]
692
-
693
- batch_chunk_embeddings = [
694
- self._flexible_quantizer(
695
- {"sentence_embedding": emb}, quantization=quantization
696
- )["sentence_embedding"]
697
- for emb in batch_chunk_embeddings
698
- ]
699
-
700
- if normalize_embeddings:
701
- batch_chunk_embeddings = [
702
- torch.nn.functional.normalize(emb, p=2, dim=-1)
703
- for emb in batch_chunk_embeddings
704
- ]
705
-
706
- batch_chunk_embeddings = [emb.cpu() for emb in batch_chunk_embeddings]
707
-
708
- all_embeddings.extend(batch_chunk_embeddings)
709
-
710
- if convert_to_numpy:
711
- all_embeddings = [emb.numpy() for emb in all_embeddings]
712
-
713
- return all_embeddings
714
-
715
- def _extract_chunks_from_concatenated(
716
- self,
717
- input_ids: torch.Tensor,
718
- token_embeddings: torch.Tensor,
719
- attention_mask: torch.Tensor,
720
- ) -> list[list[torch.Tensor]]:
721
- """
722
- Extract individual chunk embeddings from concatenated sequence using late chunking.
723
-
724
- This method splits concatenated sequences like "[chunk1][SEP][chunk2][SEP]..."
725
- back into individual chunk embeddings by finding SEP token positions.
726
-
727
- Args:
728
- input_ids: Token IDs (batch_size, seq_len)
729
- token_embeddings: Token embeddings (batch_size, seq_len, hidden_dim)
730
- attention_mask: Attention mask (batch_size, seq_len)
731
-
732
- Returns:
733
- list[list[torch.Tensor]]: List of documents, each containing list of chunk embeddings
734
-
735
- Note:
736
- The sep_token_id is retrieved from self.tokenizer.sep_token_id.
737
- Common values: Qwen2=151643, BERT=102, varies by tokenizer.
738
- """
739
- sep_token_id = self.tokenizer.sep_token_id
740
- batch_size = input_ids.shape[0]
741
-
742
- all_doc_chunks = []
743
-
744
- for batch_idx in range(batch_size):
745
- # non-pad sep tokens
746
- valid_positions = attention_mask[batch_idx].bool()
747
- sep_positions = (
748
- (input_ids[batch_idx] == sep_token_id) & valid_positions
749
- ).nonzero(as_tuple=True)[0]
750
-
751
- chunk_embeddings = []
752
- start_pos = 0
753
-
754
- for sep_pos in sep_positions:
755
- chunk_tokens = token_embeddings[batch_idx, start_pos:sep_pos]
756
- chunk_mask = attention_mask[batch_idx, start_pos:sep_pos]
757
-
758
- chunk_emb = self.mean_pooling(
759
- chunk_tokens.unsqueeze(0), chunk_mask.unsqueeze(0)
760
- ).squeeze(0)
761
-
762
- chunk_embeddings.append(chunk_emb)
763
-
764
- start_pos = sep_pos + 1
765
-
766
- # Handle the last chunk (after the last SEP token)
767
- last_valid_pos = attention_mask[batch_idx].sum().item()
768
-
769
- chunk_tokens = token_embeddings[batch_idx, start_pos:last_valid_pos]
770
- chunk_mask = attention_mask[batch_idx, start_pos:last_valid_pos]
771
-
772
- if chunk_mask.sum() > 0:
773
- chunk_emb = self.mean_pooling(
774
- chunk_tokens.unsqueeze(0), chunk_mask.unsqueeze(0)
775
- ).squeeze(0)
776
- else:
777
- # Empty chunk - create zero embedding
778
- chunk_emb = torch.zeros(
779
- token_embeddings.shape[-1],
780
- device=token_embeddings.device,
781
- dtype=token_embeddings.dtype,
782
- )
783
-
784
- chunk_embeddings.append(chunk_emb)
785
-
786
- all_doc_chunks.append(chunk_embeddings)
787
-
788
- return all_doc_chunks
789
-
790
-
791
- # Register for AutoModel
792
- PPLXQwen3Model.register_for_auto_class("AutoModel")
793
- PPLXQwen3ContextualModel.register_for_auto_class("AutoModel")
794
-
795
- __all__ = [
796
- "PPLXQwen3Config",
797
- "PPLXQwen3Model",
798
- "PPLXQwen3PreTrainedModel",
799
- "PPLXQwen3ContextualModel",
800
- "PPLXQwen3RMSNorm",
801
- "PPLXQwen3MLP",
802
- "PPLXQwen3RotaryEmbedding",
803
- "PPLXQwen3Attention",
804
- "PPLXQwen3DecoderLayer",
805
- ]
 
+ from typing import Callable
  import torch
+ from transformers import Qwen3Model
+ from transformers.cache_utils import Cache
+ from transformers.masking_utils import create_causal_mask
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.processing_utils import Unpack
+ from transformers.utils import TransformersKwargs
  from .configuration import PPLXQwen3Config


+ # From modeling_t5gemma.py
+ def bidirectional_mask_function(attention_mask: torch.Tensor | None) -> Callable:
      """
+     This creates a bidirectional attention mask.
      """

+     def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+         if attention_mask is None:
+             return torch.ones((), dtype=torch.bool)
+         return attention_mask[batch_idx, kv_idx].to(torch.bool)

+     return inner_mask


+ class PPLXQwen3Model(Qwen3Model):
+     _supports_flash_attn = True
+     _supports_sdpa = True

      config_class = PPLXQwen3Config

      def __init__(self, config):
          super().__init__(config)
          self.post_init()

+     def post_init(self):
+         super().post_init()
+         # Override to set all layers to non-causal attention. This works with attn_implementation="flash_attention_2" or "sdpa".
+         for layer in self.layers:
+             layer.self_attn.is_causal = False
+
      def forward(
          self,
+         input_ids: torch.LongTensor | None = None,
+         attention_mask: torch.Tensor | None = None,
+         position_ids: torch.LongTensor | None = None,
+         past_key_values: Cache | None = None,
+         inputs_embeds: torch.FloatTensor | None = None,
+         use_cache: bool | None = None,
+         cache_position: torch.LongTensor | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> BaseModelOutputWithPooling:
          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)
+             input_ids = None

+         # We construct a dummy tensor imitating initial positions
+         dummy_cache_position = torch.arange(
+             inputs_embeds.shape[1], device=inputs_embeds.device, dtype=torch.long
+         )
+         attention_mask = {
+             "full_attention": create_causal_mask(
+                 config=self.config,
+                 input_embeds=inputs_embeds,
                  attention_mask=attention_mask,
+                 cache_position=dummy_cache_position,
+                 past_key_values=None,
+                 position_ids=position_ids,
+                 or_mask_function=bidirectional_mask_function(attention_mask),
+             )
+         }

+         outputs = super().forward(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
+             past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             cache_position=cache_position,
              **kwargs,
          )
+         return outputs
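For reference, a small sketch of how the or-mask behaves (function name reused from the diff above; the 1x4 padding mask is made up): inner_mask ignores q_idx entirely and only checks whether kv_idx is a real, non-padding key, so passing it as or_mask_function to create_causal_mask lifts the causal constraint and yields full bidirectional attention over non-padded tokens.

import torch

# Hypothetical 1x4 padding mask: the last position is padding.
attention_mask = torch.tensor([[1, 1, 1, 0]])
inner = bidirectional_mask_function(attention_mask)

allowed = [[bool(inner(0, 0, q, kv)) for kv in range(4)] for q in range(4)]
# Every query position may attend to kv 0..2; kv 3 (padding) is never allowed,
# regardless of whether q comes before or after it.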
st_quantize.py CHANGED
@@ -1,5 +1,6 @@
  import torch
  from typing import Literal
+ from sentence_transformers.models import Module


  class Quantizer(torch.nn.Module):
@@ -65,7 +66,7 @@ class BinaryTanhQuantizer(Quantizer):
          return torch.where(x >= 0, 1.0, -1.0)


- class FlexibleQuantizer(torch.nn.Module):
+ class FlexibleQuantizer(Module):
      def __init__(self):
          super().__init__()
          self._int8_quantizer = Int8TanhQuantizer()
@@ -75,6 +76,7 @@ class FlexibleQuantizer(torch.nn.Module):
          self,
          features: dict[str, torch.Tensor],
          quantization: Literal["binary", "int8"] = "int8",
+         **kwargs
      ) -> dict[str, torch.Tensor]:
          if quantization == "int8":
              features["sentence_embedding"] = self._int8_quantizer(
@@ -91,5 +93,17 @@ class FlexibleQuantizer(torch.nn.Module):
          return features

      @classmethod
-     def load(cls, input_path: str):
+     def load(
+         cls,
+         model_name_or_path: str,
+         subfolder: str = "",
+         token: bool | str | None = None,
+         cache_folder: str | None = None,
+         revision: str | None = None,
+         local_files_only: bool = False,
+         **kwargs,
+     ):
          return cls()
+
+     def save(self, output_path: str, *args, **kwargs) -> None:
+         return
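Since FlexibleQuantizer holds no learned state, load() can ignore everything except the class and save() is a no-op. A rough usage sketch (the import path and the 1024-dim random embeddings are assumptions, not part of the diff):

import torch
from st_quantize import FlexibleQuantizer  # assumes the file is importable as-is

quantizer = FlexibleQuantizer.load("any/model/path")  # stateless: just returns cls()
features = {"sentence_embedding": torch.randn(2, 1024)}

int8_out = quantizer(features.copy(), quantization="int8")["sentence_embedding"]
bin_out = quantizer(features.copy(), quantization="binary")["sentence_embedding"]
# int8_out holds int8 values in [-128, 127]; bin_out holds floats of -1.0 or 1.0,
# matching the Int8Tanh/BinaryTanh quantizers defined earlier in this file.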