# splade-code-06B / modeling_splade.py
# Author: Tom Aarsen
# Commit c4c1b0e: attempt to fix meta tensor loading error
"""
This file exists solely to allow loading the Qwen3ForCausalLM via the AutoModelForMaskedLM class.
Compared to standard Qwen3, we use bidirectional attention instead of causal attention, which is
specified with `is_causal=False` in the config.
"""
from transformers import Qwen3ForCausalLM as _Qwen3ForCausalLM
class Qwen3ForCausalLM(_Qwen3ForCausalLM):
    """Qwen3 causal-LM subclass that hardens weight tying.

    Overrides two hooks from the base model:
    - ``tie_weights`` re-assigns ``lm_head.weight`` to ``model.embed_tokens.weight``
      after the standard tying pass, to avoid meta-tensor loading errors.
    - ``_init_weights`` skips initializing ``lm_head`` when it will be tied anyway.
    """

    def tie_weights(self, *args, **kwargs):
        """Run the standard tying, then explicitly re-tie lm_head to embed_tokens."""
        super().tie_weights(*args, **kwargs)
        # Only re-tie when the config asks for tied embeddings and both
        # attributes are present on this instance.
        should_retie = (
            self.config.tie_word_embeddings
            and hasattr(self, "lm_head")
            and hasattr(self, "model")
        )
        if should_retie:
            self.lm_head.weight = self.model.embed_tokens.weight

    def _init_weights(self, module):
        """Initialize *module*, except lm_head when it will be tied to embed_tokens."""
        is_tied_lm_head = (
            module is getattr(self, "lm_head", None) and self.config.tie_word_embeddings
        )
        if not is_tied_lm_head:
            super()._init_weights(module)
# Public API: export only the patched class so AutoModelForMaskedLM picks it up.
__all__ = ["Qwen3ForCausalLM"]