"""
Compared to standard Qwen3, this model uses bidirectional attention rather than causal attention;
this is specified with `is_causal=False` in the config.
This file supports two loading paths:
1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via AutoModelForMaskedLM -> Qwen3ForCausalLM
2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)` -> Splade
The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B in the `lora/` subfolder;
`Qwen3ForCausalLM.from_pretrained` loads the base model and applies the adapter.
"""
import os

import torch
from transformers import Qwen3ForCausalLM as TransformersQwen3ForCausalLM
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig
from transformers.utils import is_flash_attn_2_available

from .utils import prepare_tokenizer, splade_max, similarity, encode

# The adapter lives in this subfolder rather than at the repo root so that
# `find_adapter_config_file` doesn't trigger transformers' auto-PEFT path,
# which would otherwise redirect hub loads to `Qwen/Qwen3-8B` and lose the
# `auto_map` routing to the classes in this file.
ADAPTER_SUBFOLDER = "lora"


class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
    def tie_weights(self, *args, **kwargs):
        """Explicitly re-tie lm_head to embed_tokens to hopefully avoid meta tensor errors."""
        if (
            self.config.tie_word_embeddings
            and hasattr(self, "lm_head")
            and hasattr(self, "model")
        ):
            self.lm_head.weight = self.model.embed_tokens.weight
            missing_keys = kwargs.get("missing_keys")
            if missing_keys is not None:
                missing_keys.discard("lm_head.weight")
        else:
            super().tie_weights(*args, **kwargs)

    def _init_weights(self, module):
        """Skip lm_head init when it will be tied to embed_tokens later."""
        if module is getattr(self, "lm_head", None) and self.config.tie_word_embeddings:
            return
        super()._init_weights(module)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        from huggingface_hub import snapshot_download
        from peft import PeftConfig, PeftModel

        token = kwargs.get("token")
        # Resolve the adapter to a local path before handing it to PEFT.
        # PEFT's `subfolder=` kwarg uses `os.path.join` on Windows, producing
        # backslashed hub paths that break the safetensors-vs-bin fallback.
        if os.path.isdir(pretrained_model_name_or_path):
            adapter_path = os.path.join(pretrained_model_name_or_path, ADAPTER_SUBFOLDER)
        else:
            local_repo = snapshot_download(
                pretrained_model_name_or_path,
                allow_patterns=[f"{ADAPTER_SUBFOLDER}/*"],
                token=token,
            )
            adapter_path = os.path.join(local_repo, ADAPTER_SUBFOLDER)
        if not os.path.isfile(os.path.join(adapter_path, "adapter_config.json")):
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        peft_config = PeftConfig.from_pretrained(adapter_path, token=token)
        # Use provided splade config (has is_causal=False) or load it from the adapter repo
        config = kwargs.pop("config", None)
        if config is None or not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, token=token)
        base_model = super().from_pretrained(
            peft_config.base_model_name_or_path,
            *model_args,
            config=config,
            **kwargs,
        )
        return PeftModel.from_pretrained(base_model, adapter_path, token=token)


class SpladeConfig(PretrainedConfig):
    model_type = "qwen3"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen3-8B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only for decoder models
        padding_side: str = "left",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_name_or_path = model_name_or_path
        self.attn_implementation = attn_implementation
        self.bidirectional = bidirectional
        self.padding_side = padding_side


class Splade(PreTrainedModel):
    config_class = SpladeConfig

    # methods for MTEB's interface
    similarity = similarity
    encode = encode

    def __init__(self, config, weights_path=None, token=None):
        super().__init__(config)
        self.name = "splade"
        base_cfg = AutoConfig.from_pretrained(
            weights_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
            token=token,
        )
        self.tokenizer = prepare_tokenizer(
            weights_path, padding_side=config.padding_side
        )
        if is_flash_attn_2_available():
            config.attn_implementation = "flash_attention_2"
        else:
            config.attn_implementation = "sdpa"
        self.model = Qwen3ForCausalLM.from_pretrained(
            weights_path,
            config=base_cfg,
            torch_dtype=torch.bfloat16,
            attn_implementation=config.attn_implementation,
            token=token,
        )

    def save_pretrained(self, save_directory, *args, **kwargs):
        self.model.save_pretrained(os.path.join(save_directory, ADAPTER_SUBFOLDER))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        token = kwargs.get("token", None)
        config = SpladeConfig.from_pretrained(
            model_name_or_path,
            token=token,
        )
        model = cls(config, weights_path=model_name_or_path, token=token)
        model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
        return model

    def forward(self, **tokens):
        output = self.model(**tokens)
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)
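
    # `splade_max` is imported from `.utils` and not defined in this file. As a rough
    # point of reference (an assumption about its behaviour, not the actual
    # implementation), SPLADE-style max pooling over the logits is usually computed as:
    #
    #     activations = torch.log1p(torch.relu(logits))             # (batch, seq, vocab)
    #     activations = activations * attention_mask.unsqueeze(-1)  # mask out padding
    #     splade_reps = activations.max(dim=1).values               # (batch, vocab)
    #
    # i.e. one sparse, vocabulary-sized vector per input sequence.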

    def get_width(self):
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )

__all__ = ["Qwen3ForCausalLM", "Splade"]
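
# Hedged smoke test of the Transformers-style loading path (a sketch, not part of the
# shipped model code): the repo id below and the example inputs are assumptions, and
# running it needs the `peft` package plus enough memory for the Qwen3-8B base weights.
if __name__ == "__main__":
    model = Splade.from_pretrained("naver/splade-code-8B")
    batch = model.create_batch_dict(
        ["def binary_search(items, target):", "how to search a sorted list"],
        max_length=256,
    )
    with torch.no_grad():
        (reps,) = model(**batch)
    # Each row of `reps` is a vocabulary-sized sparse vector; show the strongest
    # activations for the first input via the tokenizer's reverse vocabulary.
    values, indices = reps[0].float().topk(10)
    print([(model.reverse_voc[int(i)], round(float(v), 3)) for i, v in zip(indices, values)])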