"""
Compared to standard Qwen3, we use bidirectional attention rather than causal attention; this is specified
with `is_causal=False` in the config.

This file supports two loading paths (see the usage sketch at the end of this docstring):
1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via AutoModelForMaskedLM -> Qwen3ForCausalLM
2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)` -> Splade

The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B in the `lora/` subfolder;
`Qwen3ForCausalLM.from_pretrained` loads the base model and applies the adapter.
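
Usage sketch for both paths (illustrative only; the model id comes from the two
paths above, and the `encode` call follows Sentence Transformers' SparseEncoder
API rather than anything defined in this file):

    # 1. Sentence Transformers
    from sentence_transformers import SparseEncoder

    encoder = SparseEncoder("naver/splade-code-8B", trust_remote_code=True)
    sparse_embeddings = encoder.encode(["def quicksort(xs): ..."])

    # 2. Transformers (routed to the `Splade` class below via `auto_map`)
    from transformers import AutoModelForCausalLM

    splade = AutoModelForCausalLM.from_pretrained(
        "naver/splade-code-8B", trust_remote_code=True
    )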
"""

import os

import torch
from transformers import Qwen3ForCausalLM as TransformersQwen3ForCausalLM
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig
from transformers.utils import is_flash_attn_2_available
from .utils import prepare_tokenizer, splade_max, similarity, encode

# The adapter lives in this subfolder rather than at the repo root so that
# `find_adapter_config_file` doesn't trigger transformers' auto-PEFT path,
# which would otherwise redirect hub loads to `Qwen/Qwen3-8B` and lose the
# `auto_map` routing to the classes in this file.
ADAPTER_SUBFOLDER = "lora"
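# Expected repository layout (a sketch; only `adapter_config.json` is checked
# explicitly in `Qwen3ForCausalLM.from_pretrained` below, while the remaining
# adapter file names simply follow standard PEFT conventions):
#
#   config.json                # SpladeConfig fields + `auto_map` for this file
#   lora/
#       adapter_config.json
#       adapter_model.safetensors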


class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
    def tie_weights(self, *args, **kwargs):
        """Explicitly re-tie lm_head to embed_tokens to hopefully avoid meta tensor errors."""
        if (
            self.config.tie_word_embeddings
            and hasattr(self, "lm_head")
            and hasattr(self, "model")
        ):
            self.lm_head.weight = self.model.embed_tokens.weight
            missing_keys = kwargs.get("missing_keys")
            if missing_keys is not None:
                missing_keys.discard("lm_head.weight")
        else:
            super().tie_weights(*args, **kwargs)

    def _init_weights(self, module):
        """Skip lm_head init when it will be tied to embed_tokens later."""
        if module is getattr(self, "lm_head", None) and self.config.tie_word_embeddings:
            return
        super()._init_weights(module)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        from huggingface_hub import snapshot_download
        from peft import PeftConfig, PeftModel

        token = kwargs.get("token")

        # Resolve the adapter to a local path before handing it to PEFT.
        # PEFT's `subfolder=` kwarg uses `os.path.join` on Windows, producing
        # backslashed hub paths that break the safetensors-vs-bin fallback.
        if os.path.isdir(pretrained_model_name_or_path):
            adapter_path = os.path.join(pretrained_model_name_or_path, ADAPTER_SUBFOLDER)
        else:
            local_repo = snapshot_download(
                pretrained_model_name_or_path,
                allow_patterns=[f"{ADAPTER_SUBFOLDER}/*"],
                token=token,
            )
            adapter_path = os.path.join(local_repo, ADAPTER_SUBFOLDER)

        if not os.path.isfile(os.path.join(adapter_path, "adapter_config.json")):
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        peft_config = PeftConfig.from_pretrained(adapter_path, token=token)

        # Use provided splade config (has is_causal=False) or load it from the adapter repo
        config = kwargs.pop("config", None)
        if config is None or not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, token=token)

        base_model = super().from_pretrained(
            peft_config.base_model_name_or_path,
            *model_args,
            config=config,
            **kwargs,
        )

        return PeftModel.from_pretrained(base_model, adapter_path, token=token)


class SpladeConfig(PretrainedConfig):
    model_type = "qwen3"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen3-8B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only for decoder models
        padding_side: str = "left",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name_or_path = model_name_or_path
        self.attn_implementation = attn_implementation
        self.bidirectional = bidirectional
        self.padding_side = padding_side
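
# For reference, the `config.json` driving the Transformers loading path might
# look roughly like this (illustrative values only; the module file name
# `modeling_splade.py` and the exact `auto_map` entries are assumptions, not
# read from the actual checkpoint):
#
#   {
#     "model_type": "qwen3",
#     "model_name_or_path": "Qwen/Qwen3-8B",
#     "attn_implementation": "flash_attention_2",
#     "bidirectional": true,
#     "padding_side": "left",
#     "auto_map": {
#       "AutoConfig": "modeling_splade.SpladeConfig",
#       "AutoModelForCausalLM": "modeling_splade.Splade",
#       "AutoModelForMaskedLM": "modeling_splade.Qwen3ForCausalLM"
#     }
#   }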


class Splade(PreTrainedModel):
    config_class = SpladeConfig

    # methods for MTEB's interface
    similarity = similarity
    encode = encode

    def __init__(self, config, weights_path=None, token=None):
        super().__init__(config)
        self.name = "splade"

        # Fall back to SDPA when flash-attention 2 is not installed, so the base
        # config and the model below agree on the attention implementation.
        if is_flash_attn_2_available():
            config.attn_implementation = "flash_attention_2"
        else:
            config.attn_implementation = "sdpa"

        base_cfg = AutoConfig.from_pretrained(
            weights_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
            token=token,
        )

        self.tokenizer = prepare_tokenizer(
            weights_path, padding_side=config.padding_side
        )

        self.model = Qwen3ForCausalLM.from_pretrained(
            weights_path,
            config=base_cfg,
            torch_dtype=torch.bfloat16,
            attn_implementation=config.attn_implementation,
            token=token,
        )

    def save_pretrained(self, save_directory, *args, **kwargs):
        self.model.save_pretrained(os.path.join(save_directory, ADAPTER_SUBFOLDER))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        token = kwargs.get("token", None)

        config = SpladeConfig.from_pretrained(
            model_name_or_path,
            token=token,
        )

        model = cls(config, weights_path=model_name_or_path, token=token)

        model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
        return model

    def forward(self, **tokens):
        output = self.model(**tokens)
        # Pool the token-level vocabulary logits into one sparse, vocab-sized
        # representation per input (SPLADE max pooling over attended positions).
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )


__all__ = ["Qwen3ForCausalLM", "Splade"]