splade-code-8B / splade.py
Tom Aarsen
Simplify and integrate with Sentence Transformers
6b5509e
raw
history blame
5.82 kB
"""
Unlike standard Qwen3, this model uses bidirectional (non-causal) attention; this is enabled by setting
`is_causal=False` in the config.
This file supports two loading paths:
1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via AutoModelForMaskedLM -> Qwen3ForCausalLM
2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)` -> Splade
The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B; `Qwen3ForCausalLM.from_pretrained`
loads the base model and applies the adapter.
"""
import torch
from transformers import Qwen3ForCausalLM as TransformersQwen3ForCausalLM
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig
from transformers.utils import is_flash_attn_2_available
from .utils import prepare_tokenizer, splade_max, similarity, encode
class Qwen3ForCausalLM(TransformersQwen3ForCausalLM):
    """Qwen3 causal LM that can also load a PEFT (LoRA) adapter checkpoint.

    ``from_pretrained`` checks whether the given path holds a PEFT adapter config.
    If so, it loads the base model named in that config and wraps it with the
    adapter via ``PeftModel``; otherwise it behaves like the upstream class.
    """

    def tie_weights(self, *args, **kwargs):
        """Explicitly re-tie lm_head to embed_tokens to hopefully avoid meta tensor errors.

        When ``config.tie_word_embeddings`` is set (and both ``lm_head`` and
        ``model`` exist), ``lm_head.weight`` is pointed at
        ``model.embed_tokens.weight``. If a ``missing_keys`` keyword was passed,
        ``"lm_head.weight"`` is discarded from it because the tie provides it.
        Otherwise the parent implementation is used.
        """
        if (
            self.config.tie_word_embeddings
            and hasattr(self, "lm_head")
            and hasattr(self, "model")
        ):
            self.lm_head.weight = self.model.embed_tokens.weight
            missing_keys = kwargs.get("missing_keys")
            if missing_keys is not None:
                # NOTE(review): assumes missing_keys is a set (uses .discard) — confirm
                # for the transformers versions this is used with.
                missing_keys.discard("lm_head.weight")
        else:
            super().tie_weights(*args, **kwargs)

    def _init_weights(self, module):
        """Skip lm_head init when it will be tied to embed_tokens later."""
        if module is getattr(self, "lm_head", None) and self.config.tie_word_embeddings:
            return
        super()._init_weights(module)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a plain Qwen3 checkpoint, or a LoRA adapter on top of its base model.

        Adapter case (``pretrained_model_name_or_path`` has a PEFT adapter config):
          * the config is the ``config`` keyword if it is a ``PretrainedConfig``
            (e.g. the SPLADE config with ``is_causal=False``), else the adapter
            repo's own config;
          * the base model ``peft_config.base_model_name_or_path`` is loaded with
            that config and the remaining keyword arguments;
          * the adapter is applied with ``PeftModel.from_pretrained``.

        Hub options (``token``, ``revision``, ``cache_dir``, ``local_files_only``,
        ``force_download``, ``proxies``) are forwarded to every adapter-side load.
        ``revision`` refers to the adapter repository; the base model is loaded at
        the revision recorded in the adapter config (if any), as PEFT's
        ``AutoPeftModel`` does.

        Returns:
            A ``PeftModel`` wrapping the base model in the adapter case, otherwise
            the result of the parent ``from_pretrained``.
        """
        from peft import PeftConfig, PeftModel

        # Options that only describe where/how to fetch files; safe to pass to the
        # PEFT/AutoConfig loaders (unlike model kwargs such as torch_dtype).
        hub_kwargs = {
            key: kwargs[key]
            for key in (
                "token",
                "revision",
                "cache_dir",
                "local_files_only",
                "force_download",
                "proxies",
            )
            if kwargs.get(key) is not None
        }

        try:
            peft_config = PeftConfig.from_pretrained(
                pretrained_model_name_or_path, **hub_kwargs
            )
        except Exception:
            # Deliberate best effort: if no adapter config can be read, treat the
            # path as a regular (non-adapter) checkpoint.
            peft_config = None
        if peft_config is None:
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

        # Use provided splade config (has is_causal=False) or load it from the adapter repo
        config = kwargs.pop("config", None)
        if not isinstance(config, PretrainedConfig):
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **hub_kwargs)

        # We apply the adapter manually below, so drop any auto-PEFT hints to avoid double loading
        kwargs.pop("adapter_kwargs", None)

        # A caller-supplied `revision` names a commit/branch of the adapter repo, which
        # need not exist in the base-model repo; use the base revision PEFT recorded.
        kwargs.pop("revision", None)
        base_revision = getattr(peft_config, "revision", None)
        if base_revision is not None:
            kwargs["revision"] = base_revision

        base_model = super().from_pretrained(
            peft_config.base_model_name_or_path,
            *model_args,
            config=config,
            **kwargs,
        )
        return PeftModel.from_pretrained(
            base_model, pretrained_model_name_or_path, **hub_kwargs
        )
class SpladeConfig(PretrainedConfig):
    """Configuration holding the SPLADE-specific loading options.

    Args:
        model_name_or_path: Stored as ``self.model_name_or_path``.
        attn_implementation: Requested attention backend, stored as-is.
        bidirectional: Stored flag (only relevant for decoder models).
        padding_side: Side used when padding tokenized inputs, stored as-is.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "qwen3"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen3-8B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only for decoder models
        padding_side: str = "left",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Assigned after the base initializer, in this order, so these values win
        # over anything the parent set.
        splade_options = {
            "model_name_or_path": model_name_or_path,
            "attn_implementation": attn_implementation,
            "bidirectional": bidirectional,
            "padding_side": padding_side,
        }
        for option_name, option_value in splade_options.items():
            setattr(self, option_name, option_value)
class Splade(PreTrainedModel):
    """SPLADE sparse encoder on top of a bidirectional Qwen3 with a LoRA adapter.

    The wrapped ``Qwen3ForCausalLM`` produces vocabulary logits, which ``forward``
    pools with ``splade_max``. ``similarity``/``encode`` implement MTEB's interface.
    """

    config_class = SpladeConfig

    # methods for MTEB's interface
    similarity = similarity
    encode = encode

    def __init__(self, config, weights_path=None, token=None):
        """Build the tokenizer and load the underlying model from ``weights_path``.

        Args:
            config: ``SpladeConfig``. Its ``attn_implementation`` is resolved in
                place: if it is ``"flash_attention_2"`` or unset, it becomes
                ``"flash_attention_2"`` when available and ``"sdpa"`` otherwise;
                any other explicit choice is kept. ``padding_side`` configures
                the tokenizer.
            weights_path: Repository id or local path used for the base config,
                the tokenizer and the weights. NOTE(review): the default ``None``
                cannot actually be loaded.
            token: Hugging Face Hub token forwarded to the loaders.
        """
        super().__init__(config)
        self.name = "splade"

        # Resolve the attention backend first so the base config records the value
        # that is actually used.
        if config.attn_implementation in (None, "flash_attention_2"):
            config.attn_implementation = (
                "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
            )

        base_cfg = AutoConfig.from_pretrained(
            weights_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
            token=token,
        )
        self.tokenizer = prepare_tokenizer(
            weights_path, padding_side=config.padding_side
        )
        self.model = Qwen3ForCausalLM.from_pretrained(
            weights_path,
            config=base_cfg,
            torch_dtype=torch.bfloat16,
            attn_implementation=config.attn_implementation,
            token=token,
        )
        # Inverse of the tokenizer vocabulary (token id -> token string); built here
        # so direct construction and from_pretrained yield the same object.
        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Save the wrapped model and this config into ``save_directory``.

        Extra positional/keyword arguments are accepted for signature
        compatibility but are not used.
        """
        self.model.save_pretrained(save_directory)
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        """Load the ``SpladeConfig`` and the model from ``model_name_or_path``.

        Only the ``token`` keyword is used; other arguments are accepted for
        signature compatibility and ignored.
        """
        token = kwargs.get("token", None)
        config = SpladeConfig.from_pretrained(
            model_name_or_path,
            token=token,
        )
        return cls(config, weights_path=model_name_or_path, token=token)

    def forward(self, **tokens):
        """Encode tokenized inputs into SPLADE representations.

        ``tokens`` are passed straight to the wrapped model and must contain
        ``attention_mask``. Returns a 1-tuple with the first output of
        ``splade_max`` (presumably one sparse vector per input — TODO confirm).
        """
        output = self.model(**tokens)
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
        """Return the representation width: the wrapped model's vocabulary size."""
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        """Tokenize ``input_texts`` into padded (to longest), truncated PyTorch tensors."""
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )
# Public API of this module.
__all__ = [
    "Qwen3ForCausalLM",
    "Splade",
]