"""
Compared to standard Qwen3, this model uses bidirectional attention rather than causal attention;
this is specified with `is_causal=False` in the config.

This file supports two loading paths:
1. Sentence Transformers: `SparseEncoder("naver/splade-code-06B", trust_remote_code=True)` via AutoModelForMaskedLM -> Qwen3ForCausalLM
2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-06B", trust_remote_code=True)` -> Splade
"""

import torch
import os
from transformers import Qwen3ForCausalLM as TransformersQwen3ForCausalLM
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig
from transformers.utils import is_flash_attn_2_available
from .utils import prepare_tokenizer, splade_max, similarity, encode


class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
    def tie_weights(self, *args, **kwargs):
        """Explicitly re-tie lm_head to embed_tokens to hopefully avoid meta tensor errors."""
        if (
            self.config.tie_word_embeddings
            and hasattr(self, "lm_head")
            and hasattr(self, "model")
        ):
            self.lm_head.weight = self.model.embed_tokens.weight
            missing_keys = kwargs.get("missing_keys")
            if missing_keys is not None:
                missing_keys.discard("lm_head.weight")
        else:
            super().tie_weights(*args, **kwargs)

    def _init_weights(self, module):
        """Skip lm_head init when it will be tied to embed_tokens later."""
        if module is getattr(self, "lm_head", None) and self.config.tie_word_embeddings:
            return
        super()._init_weights(module)


class SpladeConfig(PretrainedConfig):
    model_type = "qwen3"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen3-0.6B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only for decoder models
        padding_side: str = "left",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name_or_path = model_name_or_path
        self.attn_implementation = attn_implementation
        self.bidirectional = bidirectional
        self.padding_side = padding_side


class Splade(PreTrainedModel):
    config_class = SpladeConfig

    # methods for MTEB's interface
    similarity = similarity
    encode = encode

    def __init__(self, config, weights_path=None, token=None):
        super().__init__(config)
        self.name = "splade"

        # Fall back to SDPA when FlashAttention-2 is not installed, resolved before
        # building the base config so both use the same attention implementation.
        if is_flash_attn_2_available():
            config.attn_implementation = "flash_attention_2"
        else:
            config.attn_implementation = "sdpa"

        base_cfg = AutoConfig.from_pretrained(
            weights_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
        )

        self.tokenizer = prepare_tokenizer(
            weights_path, padding_side=config.padding_side
        )

        self.model = Qwen3ForCausalLM.from_pretrained(
            weights_path,
            config=base_cfg,
            torch_dtype=torch.bfloat16,
            attn_implementation=config.attn_implementation,
            token=token,
        )

    def save_pretrained(self, save_directory, *args, **kwargs):
        self.model.save_pretrained(os.path.join(save_directory, "lora"))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        token = kwargs.get("token", None)

        config = SpladeConfig.from_pretrained(
            model_name_or_path,
            token=token,
        )

        model = cls(config, weights_path=model_name_or_path, token=token)

        model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
        return model

    def forward(self, **tokens):
        output = self.model(**tokens)
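        # Collapse token-level vocab logits into one sparse lexical vector per input
        # (max pooling over positions, masked by attention_mask) via splade_max.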
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )


__all__ = ["Qwen3ForCausalLM", "Splade"]