Tom Aarsen committed on
Commit
c4c1b0e
·
1 Parent(s): d71d855

Attempt to fix meta tensor loading error

Browse files
Files changed (1) hide show
  1. modeling_splade.py +22 -2
modeling_splade.py CHANGED
@@ -3,6 +3,26 @@ This file exists solely to allow loading the Qwen3ForCausalLM via the AutoModelF
3
  Compared to standard Qwen3, we're using bidirectional attention and not causal attention, but it's specified
4
  with `is_causal=False` in the config.
5
  """
6
- from transformers import Qwen3ForCausalLM
7
 
8
- __all__ = ["Qwen3ForCausalLM"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  Compared to standard Qwen3, we're using bidirectional attention and not causal attention, but it's specified
4
  with `is_causal=False` in the config.
5
  """
 
6
 
7
+ from transformers import Qwen3ForCausalLM as _Qwen3ForCausalLM
8
+
9
+
10
class Qwen3ForCausalLM(_Qwen3ForCausalLM):
    """Qwen3 causal-LM subclass that defensively re-ties the output head.

    Intended to work around meta-tensor loading errors: after the standard
    weight-tying pass it re-points ``lm_head.weight`` at
    ``model.embed_tokens.weight``, and it skips weight initialization for
    ``lm_head`` whenever that head will be tied anyway.
    """

    def tie_weights(self, *args, **kwargs):
        """Run the stock tying logic, then explicitly re-share the embedding weight.

        The explicit assignment is a belt-and-braces step meant to avoid
        meta-tensor errors when embeddings are tied.
        """
        super().tie_weights(*args, **kwargs)
        if not self.config.tie_word_embeddings:
            return
        # Only re-tie when both submodules actually exist on this instance.
        if hasattr(self, "lm_head") and hasattr(self, "model"):
            self.lm_head.weight = self.model.embed_tokens.weight

    def _init_weights(self, module):
        """Skip initializing ``lm_head`` when it will later be tied to the embeddings."""
        head = getattr(self, "lm_head", None)
        if self.config.tie_word_embeddings and module is head:
            return
        super()._init_weights(module)


__all__ = ["Qwen3ForCausalLM"]