Support Sentence Transformers via SparseEncoder

#1
by tomaarsen HF Staff - opened
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
1_SpladePooling/config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "pooling_strategy": "max",
3
+ "activation_function": "relu"
4
+ }
README.md CHANGED
@@ -1,46 +1,53 @@
1
  ---
2
  license: cc-by-nc-sa-4.0
 
 
 
 
 
 
3
  ---
4
 
5
  SPLADE-Code-06B is a sparse retrieval model designed for code retrieval tasks. It is the top-performing model on MTEB among models below 1B parameters (at time of writing, Feb 2026).
6
 
 
7
 
8
- ```python
9
- from transformers import AutoModelForCausalLM, AutoModel
10
- import os
11
- import torch
12
-
13
- splade = AutoModelForCausalLM.from_pretrained("naver/splade-code-06B", trust_remote_code=True)
14
- device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
15
- splade.to(device)
16
- splade.eval()
17
- queries = ["SELECT *\nFROM Student\nWHERE Age = (\nSELECT MAX(Age)\nFROM Student\nWHERE Group = 'specific_group'\n)\nAND Group = 'specific_group';"]
18
- bow_dict = splade.encode(queries, prompt_type="query", top_k_q=10, return_dict=True, print_dict=True)
19
- ```
20
 
 
 
 
21
  ```
22
- +--------------------------------------------------------------------+
23
- | TOP ACTIVATED WORDS |
24
- +--------------------------------------------------------------------+
25
-
26
-
27
- * INPUT: SELECT *
28
- FROM Student
29
- WHERE Age = (
30
- SELECT MAX(Age)
31
- FROM Student
32
- WHERE Group = 'specific_group'
33
- )
34
- AND Group = 'specific_group';
35
-
36
- Ġgroup | ████████████████████ 2.34
37
- ĠAge | ████████████████████ 2.34
38
- Ġage | ███████████████████ 2.33
39
- Ġspecific | ███████████████████ 2.30
40
- _group | ███████████████████ 2.30
41
- ĠStudent | ███████████████████ 2.30
42
- Ġmax | ██████████████████ 2.22
43
- ĠMax | ██████████████████ 2.22
44
- Ġstudent | ██████████████████ 2.20
45
- ĠGroup | ██████████████████ 2.20
 
 
 
 
 
 
 
 
46
  ```
 
1
  ---
2
  license: cc-by-nc-sa-4.0
3
+ tags:
4
+ - sentence-transformers
5
+ - splade
6
+ - sparse-encoder
7
+ - code
8
+ pipeline_tag: feature-extraction
9
  ---
10
 
11
  SPLADE-Code-06B is a sparse retrieval model designed for code retrieval tasks. It is the top-performing model on MTEB among models below 1B parameters (at time of writing, Feb 2026).
12
 
13
+ ## Usage
14
 
15
+ ### Using Sentence Transformers
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ Install Sentence Transformers:
18
+ ```bash
19
+ pip install sentence_transformers
20
  ```
21
+
22
+ ```python
23
+ from sentence_transformers import SparseEncoder
24
+
25
+ model = SparseEncoder("naver/splade-code-06B", trust_remote_code=True)
26
+
27
+ queries = [
28
+ "SELECT *\nFROM Student\nWHERE Age = (\nSELECT MAX(Age)\nFROM Student\nWHERE Group = 'specific_group'\n)\nAND Group = 'specific_group';"
29
+ ]
30
+
31
+ query_embeddings = model.encode(queries)
32
+ print(query_embeddings.shape)
33
+ # torch.Size([1, 151936])
34
+
35
+ sparsity = model.sparsity(query_embeddings)
36
+ print(sparsity)
37
+ # {'active_dims': 1231.0, 'sparsity_ratio': 0.991897904380792}
38
+
39
+ decoded = model.decode(query_embeddings, top_k=10)
40
+ print(decoded)
41
+ # [[
42
+ # ("Ġgroup", 2.34375),
43
+ # ("Ġage", 2.34375),
44
+ # ("ĠAge", 2.34375),
45
+ # ("ĠStudent", 2.296875),
46
+ # ("Ġspecific", 2.296875),
47
+ # ("_group", 2.296875),
48
+ # ("ĠMax", 2.21875),
49
+ # ("Ġmax", 2.21875),
50
+ # ("Ġstudent", 2.203125),
51
+ # ("ĠGroup", 2.1875),
52
+ # ]]
53
  ```
config.json CHANGED
@@ -1,22 +1,34 @@
1
  {
2
- "archi_type": "decoder",
3
  "architectures": [
4
- "Splade"
5
  ],
6
- "attn_implementation": "flash_attention_2",
7
- "attn_type": "causal",
8
- "bidirectional": true,
9
- "lora": false,
10
- "lora_r": 0,
11
- "model_name_or_path": "Qwen/Qwen3-0.6B",
12
- "model_type": "splade",
13
- "n_layers": null,
14
- "padding_side": "left",
 
 
 
 
 
 
 
 
 
 
 
 
15
  "torch_dtype": "bfloat16",
16
- "train_head": false,
17
- "transformers_version": "4.53.3",
 
 
18
  "auto_map": {
19
- "AutoConfig": "splade.SpladeConfig",
20
- "AutoModelForCausalLM": "splade.Splade"
21
  }
22
- }
 
1
  {
 
2
  "architectures": [
3
+ "Qwen3ForCausalLM"
4
  ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "eos_token_id": 151645,
9
+ "head_dim": 128,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "is_causal": false,
15
+ "max_position_embeddings": 40960,
16
+ "max_window_layers": 28,
17
+ "model_type": "qwen3",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 28,
20
+ "num_key_value_heads": 8,
21
+ "rms_norm_eps": 1e-06,
22
+ "rope_scaling": null,
23
+ "rope_theta": 1000000,
24
+ "sliding_window": null,
25
+ "tie_word_embeddings": true,
26
  "torch_dtype": "bfloat16",
27
+ "transformers_version": "4.51.0",
28
+ "use_cache": true,
29
+ "use_sliding_window": false,
30
+ "vocab_size": 151936,
31
  "auto_map": {
32
+ "AutoModelForMaskedLM": "modeling_splade.Qwen3ForCausalLM"
 
33
  }
34
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "SparseEncoder",
3
+ "prompts": {},
4
+ "default_prompt_name": null,
5
+ "similarity_fn_name": "dot"
6
+ }
modeling_qwen3_bidir.py DELETED
@@ -1,960 +0,0 @@
1
- ###
2
- # Adapted from https://github.com/huggingface/transformers/blob/v4.51.2/src/transformers/models/qwen3/modeling_qwen3.py
3
- ###
4
-
5
- from functools import partial
6
- from typing import Callable, Optional, Tuple, Union
7
-
8
- import torch
9
- from torch import nn
10
-
11
- from transformers.activations import ACT2FN
12
- from transformers.cache_utils import (
13
- Cache,
14
- DynamicCache,
15
- SlidingWindowCache,
16
- StaticCache,
17
- )
18
- from transformers.generation import GenerationMixin
19
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
20
- from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
21
- from transformers.modeling_outputs import (
22
- BaseModelOutputWithPast,
23
- CausalLMOutputWithPast,
24
- )
25
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
26
- from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
27
- from transformers.processing_utils import Unpack
28
- from transformers.utils import (
29
- LossKwargs,
30
- add_start_docstrings,
31
- add_start_docstrings_to_model_forward,
32
- can_return_tuple,
33
- logging,
34
- replace_return_docstrings,
35
- )
36
- from transformers.utils.deprecation import deprecate_kwarg
37
- from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
38
-
39
-
40
- logger = logging.get_logger(__name__)
41
-
42
- _CHECKPOINT_FOR_DOC = "Qwen/Qwen3-8B"
43
- _CONFIG_FOR_DOC = "Qwen3Config"
44
-
45
-
46
- class Qwen3RMSNorm(nn.Module):
47
- def __init__(self, hidden_size, eps=1e-6):
48
- """
49
- Qwen3RMSNorm is equivalent to T5LayerNorm
50
- """
51
- super().__init__()
52
- self.weight = nn.Parameter(torch.ones(hidden_size))
53
- self.variance_epsilon = eps
54
-
55
- def forward(self, hidden_states):
56
- input_dtype = hidden_states.dtype
57
- hidden_states = hidden_states.to(torch.float32)
58
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
59
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
60
- return self.weight * hidden_states.to(input_dtype)
61
-
62
- def extra_repr(self):
63
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
64
-
65
-
66
- class Qwen3MLP(nn.Module):
67
- def __init__(self, config):
68
- super().__init__()
69
- self.config = config
70
- self.hidden_size = config.hidden_size
71
- self.intermediate_size = config.intermediate_size
72
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
73
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
74
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
75
- self.act_fn = ACT2FN[config.hidden_act]
76
-
77
- def forward(self, x):
78
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
79
- return down_proj
80
-
81
-
82
- def rotate_half(x):
83
- """Rotates half the hidden dims of the input."""
84
- x1 = x[..., : x.shape[-1] // 2]
85
- x2 = x[..., x.shape[-1] // 2 :]
86
- return torch.cat((-x2, x1), dim=-1)
87
-
88
-
89
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
90
- """Applies Rotary Position Embedding to the query and key tensors.
91
- Args:
92
- q (`torch.Tensor`): The query tensor.
93
- k (`torch.Tensor`): The key tensor.
94
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
95
- sin (`torch.Tensor`): The sine part of the rotary embedding.
96
- position_ids (`torch.Tensor`, *optional*):
97
- Deprecated and unused.
98
- unsqueeze_dim (`int`, *optional*, defaults to 1):
99
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
100
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
101
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
102
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
103
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
104
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
105
- Returns:
106
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
107
- """
108
- cos = cos.unsqueeze(unsqueeze_dim)
109
- sin = sin.unsqueeze(unsqueeze_dim)
110
- q_embed = (q * cos) + (rotate_half(q) * sin)
111
- k_embed = (k * cos) + (rotate_half(k) * sin)
112
- return q_embed, k_embed
113
-
114
-
115
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
116
- """
117
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
118
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
119
- """
120
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
121
- if n_rep == 1:
122
- return hidden_states
123
- hidden_states = hidden_states[:, :, None, :, :].expand(
124
- batch, num_key_value_heads, n_rep, slen, head_dim
125
- )
126
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
127
-
128
-
129
- def eager_attention_forward(
130
- module: nn.Module,
131
- query: torch.Tensor,
132
- key: torch.Tensor,
133
- value: torch.Tensor,
134
- attention_mask: Optional[torch.Tensor],
135
- scaling: float,
136
- dropout: float = 0.0,
137
- **kwargs,
138
- ):
139
- key_states = repeat_kv(key, module.num_key_value_groups)
140
- value_states = repeat_kv(value, module.num_key_value_groups)
141
-
142
- attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
143
- if attention_mask is not None:
144
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
145
- attn_weights = attn_weights + causal_mask
146
-
147
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
148
- query.dtype
149
- )
150
- attn_weights = nn.functional.dropout(
151
- attn_weights, p=dropout, training=module.training
152
- )
153
- attn_output = torch.matmul(attn_weights, value_states)
154
- attn_output = attn_output.transpose(1, 2).contiguous()
155
-
156
- return attn_output, attn_weights
157
-
158
-
159
- class Qwen3BidirAttention(nn.Module):
160
- """Multi-headed attention from 'Attention Is All You Need' paper"""
161
-
162
- def __init__(self, config: Qwen3Config, layer_idx: int):
163
- super().__init__()
164
- self.config = config
165
- self.layer_idx = layer_idx
166
- self.head_dim = getattr(
167
- config, "head_dim", config.hidden_size // config.num_attention_heads
168
- )
169
- self.num_key_value_groups = (
170
- config.num_attention_heads // config.num_key_value_heads
171
- )
172
- self.scaling = self.head_dim**-0.5
173
- self.attention_dropout = config.attention_dropout
174
- self.is_causal = False
175
-
176
- self.q_proj = nn.Linear(
177
- config.hidden_size,
178
- config.num_attention_heads * self.head_dim,
179
- bias=config.attention_bias,
180
- )
181
- self.k_proj = nn.Linear(
182
- config.hidden_size,
183
- config.num_key_value_heads * self.head_dim,
184
- bias=config.attention_bias,
185
- )
186
- self.v_proj = nn.Linear(
187
- config.hidden_size,
188
- config.num_key_value_heads * self.head_dim,
189
- bias=config.attention_bias,
190
- )
191
- self.o_proj = nn.Linear(
192
- config.num_attention_heads * self.head_dim,
193
- config.hidden_size,
194
- bias=config.attention_bias,
195
- )
196
- self.q_norm = Qwen3RMSNorm(
197
- self.head_dim, eps=config.rms_norm_eps
198
- ) # unlike olmo, only on the head dim!
199
- self.k_norm = Qwen3RMSNorm(
200
- self.head_dim, eps=config.rms_norm_eps
201
- ) # thus post q_norm does not need reshape
202
- self.sliding_window = config.sliding_window
203
- if not (
204
- self.config.use_sliding_window
205
- and getattr(self.config, "sliding_window", None) is not None
206
- and self.layer_idx >= self.config.max_window_layers
207
- ):
208
- self.sliding_window = None
209
-
210
- def forward(
211
- self,
212
- hidden_states: torch.Tensor,
213
- position_embeddings: Tuple[torch.Tensor, torch.Tensor],
214
- attention_mask: Optional[torch.Tensor],
215
- past_key_value: Optional[Cache] = None,
216
- cache_position: Optional[torch.LongTensor] = None,
217
- **kwargs: Unpack[FlashAttentionKwargs],
218
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
219
- input_shape = hidden_states.shape[:-1]
220
- hidden_shape = (*input_shape, -1, self.head_dim)
221
-
222
- query_states = self.q_norm(
223
- self.q_proj(hidden_states).view(hidden_shape)
224
- ).transpose(1, 2)
225
- key_states = self.k_norm(
226
- self.k_proj(hidden_states).view(hidden_shape)
227
- ).transpose(1, 2)
228
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
229
-
230
- cos, sin = position_embeddings
231
- query_states, key_states = apply_rotary_pos_emb(
232
- query_states, key_states, cos, sin
233
- )
234
-
235
- if past_key_value is not None:
236
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
237
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
238
- key_states, value_states = past_key_value.update(
239
- key_states, value_states, self.layer_idx, cache_kwargs
240
- )
241
-
242
- attention_interface: Callable = eager_attention_forward
243
- if self.config._attn_implementation != "eager":
244
- if self.config._attn_implementation == "sdpa" and kwargs.get(
245
- "output_attentions", False
246
- ):
247
- logger.warning_once(
248
- "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
249
- 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
250
- )
251
- else:
252
- attention_interface = ALL_ATTENTION_FUNCTIONS[
253
- self.config._attn_implementation
254
- ]
255
-
256
- attn_output, attn_weights = attention_interface(
257
- self,
258
- query_states,
259
- key_states,
260
- value_states,
261
- attention_mask,
262
- dropout=0.0 if not self.training else self.attention_dropout,
263
- scaling=self.scaling,
264
- sliding_window=self.sliding_window, # diff with Llama
265
- **kwargs,
266
- )
267
-
268
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
269
- attn_output = self.o_proj(attn_output)
270
- return attn_output, attn_weights
271
-
272
-
273
- class Qwen3BidirDecoderLayer(nn.Module):
274
- def __init__(self, config: Qwen3Config, layer_idx: int):
275
- super().__init__()
276
- self.hidden_size = config.hidden_size
277
- self.self_attn = Qwen3BidirAttention(config=config, layer_idx=layer_idx)
278
- self.mlp = Qwen3MLP(config)
279
- self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
280
- self.post_attention_layernorm = Qwen3RMSNorm(
281
- config.hidden_size, eps=config.rms_norm_eps
282
- )
283
- if (
284
- config.sliding_window and config._attn_implementation != "flash_attention_2"
285
- ): # diff with Llama is this warning
286
- logger.warning_once(
287
- f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
288
- "unexpected results may be encountered."
289
- )
290
-
291
- def forward(
292
- self,
293
- hidden_states: torch.Tensor,
294
- attention_mask: Optional[torch.Tensor] = None,
295
- position_ids: Optional[torch.LongTensor] = None,
296
- past_key_value: Optional[Cache] = None,
297
- output_attentions: Optional[bool] = False,
298
- use_cache: Optional[bool] = False,
299
- cache_position: Optional[torch.LongTensor] = None,
300
- position_embeddings: Optional[
301
- Tuple[torch.Tensor, torch.Tensor]
302
- ] = None, # necessary, but kept here for BC
303
- **kwargs: Unpack[FlashAttentionKwargs],
304
- ) -> Tuple[
305
- torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
306
- ]:
307
- residual = hidden_states
308
-
309
- hidden_states = self.input_layernorm(hidden_states)
310
-
311
- # Self Attention
312
- hidden_states, self_attn_weights = self.self_attn(
313
- hidden_states=hidden_states,
314
- attention_mask=attention_mask,
315
- position_ids=position_ids,
316
- past_key_value=past_key_value,
317
- output_attentions=output_attentions,
318
- use_cache=use_cache,
319
- cache_position=cache_position,
320
- position_embeddings=position_embeddings,
321
- **kwargs,
322
- )
323
- hidden_states = residual + hidden_states
324
-
325
- # Fully Connected
326
- residual = hidden_states
327
- hidden_states = self.post_attention_layernorm(hidden_states)
328
- hidden_states = self.mlp(hidden_states)
329
- hidden_states = residual + hidden_states
330
-
331
- outputs = (hidden_states,)
332
- if output_attentions:
333
- outputs += (self_attn_weights,)
334
-
335
- return outputs
336
-
337
-
338
- class Qwen3RotaryEmbedding(nn.Module):
339
- def __init__(self, config: Qwen3Config, device=None):
340
- super().__init__()
341
- # BC: "rope_type" was originally "type"
342
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
343
- self.rope_type = config.rope_scaling.get(
344
- "rope_type", config.rope_scaling.get("type")
345
- )
346
- else:
347
- self.rope_type = "default"
348
- self.max_seq_len_cached = config.max_position_embeddings
349
- self.original_max_seq_len = config.max_position_embeddings
350
-
351
- self.config = config
352
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
353
-
354
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
355
- self.register_buffer("inv_freq", inv_freq, persistent=False)
356
- self.original_inv_freq = self.inv_freq
357
-
358
- @torch.no_grad()
359
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
360
- def forward(self, x, position_ids):
361
- inv_freq_expanded = (
362
- self.inv_freq[None, :, None]
363
- .float()
364
- .expand(position_ids.shape[0], -1, 1)
365
- .to(x.device)
366
- )
367
- position_ids_expanded = position_ids[:, None, :].float()
368
-
369
- device_type = (
370
- x.device.type
371
- if isinstance(x.device.type, str) and x.device.type != "mps"
372
- else "cpu"
373
- )
374
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
375
- freqs = (
376
- inv_freq_expanded.float() @ position_ids_expanded.float()
377
- ).transpose(1, 2)
378
- emb = torch.cat((freqs, freqs), dim=-1)
379
- cos = emb.cos() * self.attention_scaling
380
- sin = emb.sin() * self.attention_scaling
381
-
382
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
383
-
384
-
385
- QWEN3_START_DOCSTRING = r"""
386
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
387
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
388
- etc.)
389
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
390
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
391
- and behavior.
392
- Parameters:
393
- config ([`Qwen3Config`]):
394
- Model configuration class with all the parameters of the model. Initializing with a config file does not
395
- load the weights associated with the model, only the configuration. Check out the
396
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
397
- """
398
-
399
-
400
- @add_start_docstrings(
401
- "The bare Qwen3 Model outputting raw hidden-states without any specific head on top.",
402
- QWEN3_START_DOCSTRING,
403
- )
404
- class Qwen3PreTrainedModel(PreTrainedModel):
405
- config_class = Qwen3Config
406
- base_model_prefix = "model"
407
- supports_gradient_checkpointing = True
408
- _no_split_modules = ["Qwen3DecoderLayer"]
409
- _skip_keys_device_placement = ["past_key_values"]
410
- _supports_flash_attn_2 = True
411
- _supports_sdpa = True
412
- _supports_flex_attn = True
413
- _supports_cache_class = True
414
- _supports_quantized_cache = True
415
- _supports_static_cache = True
416
- _supports_attention_backend = True
417
-
418
- def _init_weights(self, module):
419
- std = self.config.initializer_range
420
- if isinstance(module, nn.Linear):
421
- module.weight.data.normal_(mean=0.0, std=std)
422
- if module.bias is not None:
423
- module.bias.data.zero_()
424
- elif isinstance(module, nn.Embedding):
425
- module.weight.data.normal_(mean=0.0, std=std)
426
- if module.padding_idx is not None:
427
- module.weight.data[module.padding_idx].zero_()
428
-
429
-
430
- QWEN3_INPUTS_DOCSTRING = r"""
431
- Args:
432
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
433
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
434
- it.
435
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
436
- [`PreTrainedTokenizer.__call__`] for details.
437
- [What are input IDs?](../glossary#input-ids)
438
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
439
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
440
- - 1 for tokens that are **not masked**,
441
- - 0 for tokens that are **masked**.
442
- [What are attention masks?](../glossary#attention-mask)
443
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
444
- [`PreTrainedTokenizer.__call__`] for details.
445
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
446
- `past_key_values`).
447
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
448
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
449
- information on the default strategy.
450
- - 1 indicates the head is **not masked**,
451
- - 0 indicates the head is **masked**.
452
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
453
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
454
- config.n_positions - 1]`.
455
- [What are position IDs?](../glossary#position-ids)
456
- past_key_values (`Cache`, *optional*):
457
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
458
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
459
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
460
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
461
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
462
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
463
- of shape `(batch_size, sequence_length)`.
464
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
465
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
466
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
467
- model's internal embedding lookup matrix.
468
- use_cache (`bool`, *optional*):
469
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
470
- `past_key_values`).
471
- output_attentions (`bool`, *optional*):
472
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
473
- tensors for more detail.
474
- output_hidden_states (`bool`, *optional*):
475
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
476
- more detail.
477
- return_dict (`bool`, *optional*):
478
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
479
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
480
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
481
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
482
- the complete sequence length.
483
- """
484
-
485
-
486
- @add_start_docstrings(
487
- "The bare Qwen3 Model outputting raw hidden-states without any specific head on top.",
488
- QWEN3_START_DOCSTRING,
489
- )
490
- class Qwen3BidirModel(Qwen3PreTrainedModel):
491
- """
492
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen3DecoderLayer`]
493
- Args:
494
- config: Qwen3Config
495
- """
496
-
497
- def __init__(self, config: Qwen3Config):
498
- super().__init__(config)
499
- self.padding_idx = config.pad_token_id
500
- self.vocab_size = config.vocab_size
501
-
502
- self.embed_tokens = nn.Embedding(
503
- config.vocab_size, config.hidden_size, self.padding_idx
504
- )
505
- self.layers = nn.ModuleList(
506
- [
507
- Qwen3BidirDecoderLayer(config, layer_idx)
508
- for layer_idx in range(config.num_hidden_layers)
509
- ]
510
- )
511
- self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
512
- self.rotary_emb = Qwen3RotaryEmbedding(config=config)
513
- self.gradient_checkpointing = False
514
-
515
- # Initialize weights and apply final processing
516
- self.post_init()
517
-
518
- def get_input_embeddings(self):
519
- return self.embed_tokens
520
-
521
- def set_input_embeddings(self, value):
522
- self.embed_tokens = value
523
-
524
- @can_return_tuple
525
- @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
526
- def forward(
527
- self,
528
- input_ids: Optional[torch.LongTensor] = None,
529
- attention_mask: Optional[torch.Tensor] = None,
530
- position_ids: Optional[torch.LongTensor] = None,
531
- past_key_values: Optional[Cache] = None,
532
- inputs_embeds: Optional[torch.FloatTensor] = None,
533
- use_cache: Optional[bool] = None,
534
- output_attentions: Optional[bool] = None,
535
- output_hidden_states: Optional[bool] = None,
536
- cache_position: Optional[torch.LongTensor] = None,
537
- **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
538
- ) -> BaseModelOutputWithPast:
539
- output_attentions = (
540
- output_attentions
541
- if output_attentions is not None
542
- else self.config.output_attentions
543
- )
544
- output_hidden_states = (
545
- output_hidden_states
546
- if output_hidden_states is not None
547
- else self.config.output_hidden_states
548
- )
549
- use_cache = use_cache if use_cache is not None else self.config.use_cache
550
-
551
- if (input_ids is None) ^ (inputs_embeds is not None):
552
- raise ValueError(
553
- "You must specify exactly one of input_ids or inputs_embeds"
554
- )
555
-
556
- if self.gradient_checkpointing and self.training and use_cache:
557
- logger.warning_once(
558
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
559
- )
560
- use_cache = False
561
-
562
- # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
563
- if not isinstance(past_key_values, (type(None), Cache)):
564
- raise ValueError(
565
- "The `past_key_values` should be either a `Cache` object or `None`."
566
- )
567
-
568
- if inputs_embeds is None:
569
- inputs_embeds = self.embed_tokens(input_ids)
570
-
571
- if use_cache and past_key_values is None:
572
- past_key_values = DynamicCache()
573
-
574
- if cache_position is None:
575
- past_seen_tokens = (
576
- past_key_values.get_seq_length() if past_key_values is not None else 0
577
- )
578
- cache_position = torch.arange(
579
- past_seen_tokens,
580
- past_seen_tokens + inputs_embeds.shape[1],
581
- device=inputs_embeds.device,
582
- )
583
-
584
- if position_ids is None:
585
- position_ids = cache_position.unsqueeze(0)
586
-
587
- causal_mask = self._update_causal_mask(
588
- attention_mask,
589
- inputs_embeds,
590
- cache_position,
591
- past_key_values,
592
- output_attentions,
593
- )
594
-
595
- hidden_states = inputs_embeds
596
-
597
- # create position embeddings to be shared across the decoder layers
598
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
599
-
600
- # decoder layers
601
- all_hidden_states = () if output_hidden_states else None
602
- all_self_attns = () if output_attentions else None
603
-
604
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
605
- if output_hidden_states:
606
- all_hidden_states += (hidden_states,)
607
-
608
- if self.gradient_checkpointing and self.training:
609
- layer_outputs = self._gradient_checkpointing_func(
610
- partial(decoder_layer.__call__, **flash_attn_kwargs),
611
- hidden_states,
612
- causal_mask,
613
- position_ids,
614
- past_key_values,
615
- output_attentions,
616
- use_cache,
617
- cache_position,
618
- position_embeddings,
619
- )
620
- else:
621
- layer_outputs = decoder_layer(
622
- hidden_states,
623
- attention_mask=causal_mask,
624
- position_ids=position_ids,
625
- past_key_value=past_key_values,
626
- output_attentions=output_attentions,
627
- use_cache=use_cache,
628
- cache_position=cache_position,
629
- position_embeddings=position_embeddings,
630
- **flash_attn_kwargs,
631
- )
632
-
633
- hidden_states = layer_outputs[0]
634
-
635
- if output_attentions:
636
- all_self_attns += (layer_outputs[1],)
637
-
638
- hidden_states = self.norm(hidden_states)
639
-
640
- # add hidden states from the last decoder layer
641
- if output_hidden_states:
642
- all_hidden_states += (hidden_states,)
643
-
644
- return BaseModelOutputWithPast(
645
- last_hidden_state=hidden_states,
646
- past_key_values=past_key_values if use_cache else None,
647
- hidden_states=all_hidden_states,
648
- attentions=all_self_attns,
649
- )
650
-
651
- def _update_causal_mask(
652
- self,
653
- attention_mask: torch.Tensor,
654
- input_tensor: torch.Tensor,
655
- cache_position: torch.Tensor,
656
- past_key_values: Cache,
657
- output_attentions: bool = False,
658
- ):
659
- if self.config._attn_implementation == "flash_attention_2":
660
- if attention_mask is not None and past_key_values is not None:
661
- valid_rows = attention_mask.sum(dim=1) > 0
662
-
663
- if valid_rows.any():
664
- # Only check right-padding on non-empty rows
665
- right_padded_rows = attention_mask[valid_rows, -1] == 0
666
- is_padding_right = right_padded_rows.any().item()
667
- if is_padding_right:
668
- raise ValueError(
669
- "You are attempting to perform batched generation with padding_side='right'. "
670
- "This may lead to unexpected behaviour for Flash Attention version of Qwen3. "
671
- "Make sure to call `tokenizer.padding_side = 'left'` before tokenizing the input."
672
- )
673
- # is_padding_right = (
674
- # attention_mask[:, -1].sum().item() != input_tensor.size()[0]
675
- # )
676
- # if is_padding_right:
677
- # raise ValueError(
678
- # "You are attempting to perform batched generation with padding_side='right'"
679
- # " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
680
- # " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
681
- # )
682
- if attention_mask is not None and 0.0 in attention_mask:
683
- return attention_mask
684
- return None
685
-
686
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
687
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
688
- # to infer the attention mask.
689
- past_seen_tokens = (
690
- past_key_values.get_seq_length() if past_key_values is not None else 0
691
- )
692
- using_static_cache = isinstance(past_key_values, StaticCache)
693
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
694
-
695
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
696
- if (
697
- self.config._attn_implementation == "sdpa"
698
- and not (using_static_cache or using_sliding_window_cache)
699
- and not output_attentions
700
- ):
701
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
702
- attention_mask,
703
- inputs_embeds=input_tensor,
704
- past_key_values_length=past_seen_tokens,
705
- sliding_window=self.config.sliding_window,
706
- is_training=self.training,
707
- ):
708
- return None
709
-
710
- dtype, device = input_tensor.dtype, input_tensor.device
711
- min_dtype = torch.finfo(dtype).min
712
- sequence_length = input_tensor.shape[1]
713
- # SlidingWindowCache or StaticCache
714
- if using_sliding_window_cache or using_static_cache:
715
- target_length = past_key_values.get_max_cache_shape()
716
- # DynamicCache or no cache
717
- else:
718
- target_length = (
719
- attention_mask.shape[-1]
720
- if isinstance(attention_mask, torch.Tensor)
721
- else past_seen_tokens + sequence_length + 1
722
- )
723
-
724
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
725
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
726
- attention_mask,
727
- sequence_length=sequence_length,
728
- target_length=target_length,
729
- dtype=dtype,
730
- device=device,
731
- cache_position=cache_position,
732
- batch_size=input_tensor.shape[0],
733
- config=self.config,
734
- past_key_values=past_key_values,
735
- )
736
-
737
- if (
738
- self.config._attn_implementation == "sdpa"
739
- and attention_mask is not None
740
- and attention_mask.device.type in ["cuda", "xpu"]
741
- and not output_attentions
742
- ):
743
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
744
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
745
- # Details: https://github.com/pytorch/pytorch/issues/110213
746
- causal_mask = AttentionMaskConverter._unmask_unattended(
747
- causal_mask, min_dtype
748
- )
749
-
750
- return causal_mask
751
-
752
- @staticmethod
753
- def _prepare_4d_causal_attention_mask_with_cache_position(
754
- attention_mask: torch.Tensor,
755
- sequence_length: int,
756
- target_length: int,
757
- dtype: torch.dtype,
758
- device: torch.device,
759
- cache_position: torch.Tensor,
760
- batch_size: int,
761
- config: Qwen3Config,
762
- past_key_values: Cache,
763
- ):
764
- """
765
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
766
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
767
- Args:
768
- attention_mask (`torch.Tensor`):
769
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
770
- sequence_length (`int`):
771
- The sequence length being processed.
772
- target_length (`int`):
773
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
774
- dtype (`torch.dtype`):
775
- The dtype to use for the 4D attention mask.
776
- device (`torch.device`):
777
- The device to place the 4D attention mask on.
778
- cache_position (`torch.Tensor`):
779
- Indices depicting the position of the input sequence tokens in the sequence.
780
- batch_size (`torch.Tensor`):
781
- Batch size.
782
- config (`Qwen3Config`):
783
- The model's configuration class
784
- past_key_values (`Cache`):
785
- The cache class that is being used currently to generate
786
- """
787
- if attention_mask is not None and attention_mask.dim() == 4:
788
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
789
- causal_mask = attention_mask
790
- else:
791
- min_dtype = torch.finfo(dtype).min
792
- causal_mask = torch.full(
793
- (sequence_length, target_length),
794
- fill_value=min_dtype,
795
- dtype=dtype,
796
- device=device,
797
- )
798
- diagonal_attend_mask = torch.arange(
799
- target_length, device=device
800
- ) > cache_position.reshape(-1, 1)
801
- if config.sliding_window is not None:
802
- # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
803
- # the check is needed to verify is current checkpoint was trained with sliding window or not
804
- if (
805
- not isinstance(past_key_values, SlidingWindowCache)
806
- or sequence_length > target_length
807
- ):
808
- sliding_attend_mask = torch.arange(
809
- target_length, device=device
810
- ) <= (cache_position.reshape(-1, 1) - config.sliding_window)
811
- diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
812
- causal_mask *= diagonal_attend_mask
813
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
814
- if attention_mask is not None:
815
- causal_mask = (
816
- causal_mask.clone()
817
- ) # copy to contiguous memory for in-place edit
818
- if attention_mask.shape[-1] > target_length:
819
- attention_mask = attention_mask[:, :target_length]
820
- mask_length = attention_mask.shape[-1]
821
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[
822
- :, None, None, :
823
- ].to(causal_mask.device)
824
- padding_mask = padding_mask == 0
825
- causal_mask[:, :, :, :mask_length] = causal_mask[
826
- :, :, :, :mask_length
827
- ].masked_fill(padding_mask, min_dtype)
828
- return causal_mask
829
-
830
-
831
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
832
-
833
-
834
- class Qwen3BidirForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
835
- _tied_weights_keys = ["lm_head.weight"]
836
- _tp_plan = {"lm_head": "colwise_rep"}
837
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
838
-
839
- def __init__(self, config):
840
- super().__init__(config)
841
- self.model = Qwen3BidirModel(config)
842
- self.vocab_size = config.vocab_size
843
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
844
-
845
- # Initialize weights and apply final processing
846
- self.post_init()
847
-
848
- def get_input_embeddings(self):
849
- return self.model.embed_tokens
850
-
851
- def set_input_embeddings(self, value):
852
- self.model.embed_tokens = value
853
-
854
- def get_output_embeddings(self):
855
- return self.lm_head
856
-
857
- def set_output_embeddings(self, new_embeddings):
858
- self.lm_head = new_embeddings
859
-
860
- def set_decoder(self, decoder):
861
- self.model = decoder
862
-
863
- def get_decoder(self):
864
- return self.model
865
-
866
- @can_return_tuple
867
- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
868
- @add_start_docstrings_to_model_forward(QWEN3_INPUTS_DOCSTRING)
869
- @replace_return_docstrings(
870
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
871
- )
872
- def forward(
873
- self,
874
- input_ids: Optional[torch.LongTensor] = None,
875
- attention_mask: Optional[torch.Tensor] = None,
876
- position_ids: Optional[torch.LongTensor] = None,
877
- past_key_values: Optional[Cache] = None,
878
- inputs_embeds: Optional[torch.FloatTensor] = None,
879
- labels: Optional[torch.LongTensor] = None,
880
- use_cache: Optional[bool] = None,
881
- output_attentions: Optional[bool] = None,
882
- output_hidden_states: Optional[bool] = None,
883
- cache_position: Optional[torch.LongTensor] = None,
884
- logits_to_keep: Union[int, torch.Tensor] = 0,
885
- **kwargs: Unpack[KwargsForCausalLM],
886
- ) -> CausalLMOutputWithPast:
887
- r"""
888
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
889
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
890
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
891
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
892
- logits_to_keep (`int` or `torch.Tensor`, *optional*):
893
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
894
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
895
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
896
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
897
- This is useful when using packed tensor format (single dimension for batch and sequence length).
898
- Returns:
899
- Example:
900
- ```python
901
- >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
902
- >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
903
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
904
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
905
- >>> inputs = tokenizer(prompt, return_tensors="pt")
906
- >>> # Generate
907
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
908
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
909
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
910
- ```"""
911
- output_attentions = (
912
- output_attentions
913
- if output_attentions is not None
914
- else self.config.output_attentions
915
- )
916
- output_hidden_states = (
917
- output_hidden_states
918
- if output_hidden_states is not None
919
- else self.config.output_hidden_states
920
- )
921
-
922
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
923
- outputs: BaseModelOutputWithPast = self.model(
924
- input_ids=input_ids,
925
- attention_mask=attention_mask,
926
- position_ids=position_ids,
927
- past_key_values=past_key_values,
928
- inputs_embeds=inputs_embeds,
929
- use_cache=use_cache,
930
- output_attentions=output_attentions,
931
- output_hidden_states=output_hidden_states,
932
- cache_position=cache_position,
933
- **kwargs,
934
- )
935
-
936
- hidden_states = outputs.last_hidden_state
937
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
938
- slice_indices = (
939
- slice(-logits_to_keep, None)
940
- if isinstance(logits_to_keep, int)
941
- else logits_to_keep
942
- )
943
- logits = self.lm_head(hidden_states[:, slice_indices, :])
944
-
945
- loss = None
946
- if labels is not None:
947
- loss = self.loss_function(
948
- logits=logits,
949
- labels=labels,
950
- vocab_size=self.config.vocab_size,
951
- **kwargs,
952
- )
953
-
954
- return CausalLMOutputWithPast(
955
- loss=loss,
956
- logits=logits,
957
- past_key_values=outputs.past_key_values,
958
- hidden_states=outputs.hidden_states,
959
- attentions=outputs.attentions,
960
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modeling_splade.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file exists solely to allow loading the Qwen3ForCausalLM via the AutoModelForMaskedLM class.
3
+ Compared to standard Qwen3, we're using bidirectional attention and not causal attention, but it's specified
4
+ with `is_causal=False` in the config.
5
+ """
6
+
7
+ from transformers import Qwen3ForCausalLM as _Qwen3ForCausalLM
8
+
9
+
10
+ class Qwen3ForCausalLM(_Qwen3ForCausalLM):
11
+ def tie_weights(self, *args, **kwargs):
12
+ """Explicitly re-tie lm_head to embed_tokens to hopefully avoid meta tensor errors."""
13
+ super().tie_weights(*args, **kwargs)
14
+ if (
15
+ self.config.tie_word_embeddings
16
+ and hasattr(self, "lm_head")
17
+ and hasattr(self, "model")
18
+ ):
19
+ self.lm_head.weight = self.model.embed_tokens.weight
20
+
21
+ def _init_weights(self, module):
22
+ """Skip lm_head init when it will be tied to embed_tokens later."""
23
+ if module is getattr(self, "lm_head", None) and self.config.tie_word_embeddings:
24
+ return
25
+ super()._init_weights(module)
26
+
27
+
28
+ __all__ = ["Qwen3ForCausalLM"]
modules.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "sentence_transformers.sparse_encoder.models.MLMTransformer"
7
+ },
8
+ {
9
+ "idx": 1,
10
+ "name": "1",
11
+ "path": "1_SpladePooling",
12
+ "type": "sentence_transformers.sparse_encoder.models.SpladePooling"
13
+ }
14
+ ]
splade.py DELETED
@@ -1,109 +0,0 @@
1
- import os
2
- from transformers import (
3
- PretrainedConfig,
4
- PreTrainedModel,
5
- AutoConfig,
6
- )
7
- from huggingface_hub import snapshot_download
8
- from typing import Optional
9
- from transformers.utils import is_flash_attn_2_available
10
- from .utils import (
11
- get_decoder_model,
12
- prepare_tokenizer,
13
- splade_max,
14
- similarity,
15
- encode,
16
- )
17
- from peft import PeftModel
18
-
19
-
20
- class SpladeConfig(PretrainedConfig):
21
- model_type = "splade"
22
-
23
- def __init__(
24
- self,
25
- model_name_or_path: str = "Qwen/Qwen3-0.6B",
26
- attn_implementation: str = "flash_attention_2",
27
- bidirectional: bool = True, # only for decoder models
28
- padding_side: str = "left",
29
- **kwargs,
30
- ):
31
- super().__init__(**kwargs)
32
- self.model_name_or_path = model_name_or_path
33
- self.attn_implementation = attn_implementation
34
- self.bidirectional = bidirectional
35
- self.padding_side = padding_side
36
-
37
-
38
- class Splade(PreTrainedModel):
39
- config_class = SpladeConfig
40
-
41
- # methods for MTEB's interface
42
- similarity = similarity
43
- encode = encode
44
-
45
- def __init__(self, config, weights_path=None, token=None):
46
- super().__init__(config)
47
- self.name = "splade"
48
-
49
- base_cfg = AutoConfig.from_pretrained(
50
- config.model_name_or_path,
51
- attn_implementation=config.attn_implementation,
52
- torch_dtype="auto",
53
- )
54
-
55
- self.tokenizer = prepare_tokenizer(
56
- config.model_name_or_path, padding_side=config.padding_side
57
- )
58
-
59
- if is_flash_attn_2_available():
60
- config.attn_implementation = "flash_attention_2"
61
- else:
62
- config.attn_implementation = "sdpa"
63
-
64
- source = weights_path or config.model_name_or_path
65
-
66
- self.model = get_decoder_model(
67
- model_name_or_path=source,
68
- attn_implementation=config.attn_implementation,
69
- bidirectional=getattr(config, "bidirectional", False),
70
- base_cfg=base_cfg,
71
- token=token
72
- )
73
-
74
- def save_pretrained(self, save_directory, *args, **kwargs):
75
- self.model.save_pretrained(os.path.join(save_directory, "lora"))
76
- self.config.save_pretrained(save_directory)
77
-
78
- @classmethod
79
- def from_pretrained(cls, model_name_or_path, *args, **kwargs):
80
- token = kwargs.get("token", None)
81
-
82
- config = SpladeConfig.from_pretrained(
83
- model_name_or_path,
84
- token=token,
85
- )
86
-
87
- model = cls(config, weights_path=model_name_or_path, token=token)
88
-
89
- model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
90
- return model
91
-
92
- def forward(self, **tokens):
93
- output = self.model(**tokens)
94
- splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
95
- return (splade_reps,)
96
-
97
- def get_width(self):
98
- return self.model.config.vocab_size
99
-
100
- def create_batch_dict(self, input_texts, max_length):
101
- return self.tokenizer(
102
- input_texts,
103
- add_special_tokens=True,
104
- padding="longest",
105
- truncation=True,
106
- max_length=max_length,
107
- return_attention_mask=True,
108
- return_tensors="pt",
109
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
3
+ size 11422650
tokenizer_config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": null,
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<|im_start|>",
10
+ "<|im_end|>",
11
+ "<|object_ref_start|>",
12
+ "<|object_ref_end|>",
13
+ "<|box_start|>",
14
+ "<|box_end|>",
15
+ "<|quad_start|>",
16
+ "<|quad_end|>",
17
+ "<|vision_start|>",
18
+ "<|vision_end|>",
19
+ "<|vision_pad|>",
20
+ "<|image_pad|>",
21
+ "<|video_pad|>"
22
+ ],
23
+ "is_local": false,
24
+ "model_max_length": 131072,
25
+ "padding_side": "left",
26
+ "pad_token": "<|endoftext|>",
27
+ "chat_template": null,
28
+ "split_special_tokens": false,
29
+ "tokenizer_class": "Qwen2Tokenizer",
30
+ "unk_token": null
31
+ }
utils.py DELETED
@@ -1,152 +0,0 @@
1
- import numpy as np
2
- import torch
3
-
4
- from typing import Any
5
- from transformers import AutoTokenizer
6
-
7
-
8
- def splade_max(features, attention_mask):
9
- """
10
- SPLADE pooling operation
11
- """
12
- relu = torch.nn.ReLU(inplace=False)
13
- values, ids_ = torch.max(
14
- torch.log(1 + relu(features)) * attention_mask.unsqueeze(-1), dim=1
15
- )
16
- return values, ids_
17
-
18
-
19
- def encode(
20
- self,
21
- sentences: list[str],
22
- max_length: int = 1024,
23
- prompt_type: str = "document",
24
- return_dict: bool = False,
25
- print_dict: bool = False,
26
- batch_size: int = 8,
27
- top_k_q: int = -1,
28
- top_k_d: int = -1,
29
- **kwargs: Any,
30
- ) -> np.ndarray:
31
- all_embeddings = []
32
- for i in range(0, len(sentences), batch_size):
33
- batch_texts = sentences[i : i + batch_size]
34
- batch_dict = self.create_batch_dict(batch_texts, max_length)
35
- batch_dict = {
36
- key: value.to(self.model.device) for key, value in batch_dict.items()
37
- }
38
- with torch.no_grad():
39
- splare_reps = self(**batch_dict)[0]
40
- if prompt_type == "query" and top_k_q > 0:
41
- splare_reps = top_k(splare_reps, top_k_q)
42
- if prompt_type == "document" and top_k_d > 0:
43
- splare_reps = top_k(splare_reps, top_k_d)
44
- all_embeddings.append(splare_reps.cpu().float().numpy())
45
- if return_dict:
46
- d = bow_dict(self, np.concatenate(all_embeddings, axis=0))
47
- if print_dict:
48
- print_bow_bars(sentences, d)
49
- return d
50
- else:
51
- return np.concatenate(all_embeddings, axis=0)
52
-
53
-
54
- def bow_dict(self, embeddings):
55
- out = []
56
- for vector in embeddings:
57
- idx = np.nonzero(vector)[0]
58
- weights = vector[idx]
59
- d = {k: v for k, v in zip(idx.tolist(), weights.tolist())}
60
- sorted_d = {
61
- self.reverse_voc[k]: float(v)
62
- for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)
63
- }
64
- out.append(sorted_d)
65
- return out
66
-
67
-
68
- def print_bow_bars(sentences, bow_list, width=20):
69
- ascii_header("TOP ACTIVATED WORDS")
70
- for sent, bow in zip(sentences, bow_list):
71
- print(f"* INPUT: {sent}\n")
72
- max_w = max(bow.values())
73
- for k, v in sorted(bow.items(), key=lambda x: x[1], reverse=True):
74
- bar = "█" * int(v / max_w * width)
75
- print(f"{k[:25]:25} | {bar} {v:.2f}")
76
- print("\n")
77
-
78
-
79
- def ascii_header(title, width=70):
80
- title = f" {title} "
81
- print("+" + "-" * (width - 2) + "+")
82
- print("|" + title.center(width - 2) + "|")
83
- print("+" + "-" * (width - 2) + "+")
84
- print("\n")
85
-
86
-
87
- def similarity(self, a, b) -> torch.Tensor:
88
- """
89
- MTEB eval requires this
90
- """
91
- if not isinstance(a, torch.Tensor):
92
- a = torch.tensor(a)
93
- if not isinstance(b, torch.Tensor):
94
- b = torch.tensor(b)
95
-
96
- def _dot_score_core(a_tensor, b_tensor):
97
- if len(a_tensor.shape) == 1:
98
- a_tensor = a_tensor.unsqueeze(0)
99
- if len(b_tensor.shape) == 1:
100
- b_tensor = b_tensor.unsqueeze(0)
101
- return a_tensor @ b_tensor.transpose(0, 1)
102
-
103
- return _dot_score_core(a, b)
104
-
105
-
106
- def prepare_tokenizer(tokenizer_name: str, padding_side="right"):
107
- """
108
- loads and prepares tokenizer
109
- """
110
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
111
- tokenizer.pad_token = (
112
- tokenizer.bos_token or tokenizer.pad_token or tokenizer.eos_token
113
- )
114
- tokenizer.padding_side = padding_side
115
- return tokenizer
116
-
117
-
118
- def get_decoder_model(
119
- model_name_or_path: str, attn_implementation: str, bidirectional: bool, base_cfg, token=None
120
- ):
121
- """
122
- base_cfg is the pretrained config of the underlying model
123
- """
124
- print("WARNING: bidirectional only tested for transformer 4.51.2")
125
- assert (
126
- bidirectional is True
127
- ), "the model has been trained with bi-directional attention!"
128
- assert (
129
- attn_implementation == "flash_attention_2"
130
- ), f"bidir models only support flash_attention_2 for now, not {attn_implementation}!"
131
- from .modeling_qwen3_bidir import Qwen3BidirForCausalLM
132
-
133
- return Qwen3BidirForCausalLM.from_pretrained(
134
- model_name_or_path,
135
- config=base_cfg,
136
- torch_dtype=torch.bfloat16,
137
- attn_implementation=attn_implementation,
138
- token=token,
139
- )
140
-
141
-
142
- def top_k(x: torch.Tensor, k: int) -> torch.Tensor:
143
- """
144
- zeroes out all but the top-k values in the last dimension of x
145
- """
146
- _, topk_indices = x.topk(k, dim=-1)
147
- # create a zero tensor of the same shape as x
148
- mask = torch.zeros_like(x, dtype=torch.bool)
149
- # use scatter along the last dimension
150
- mask.scatter_(-1, topk_indices, True)
151
- # zero out all but the top-k
152
- return x * mask