kashif HF Staff commited on
Commit
dfa9ac6
·
verified ·
1 Parent(s): 408f916

fix: align RotaryEmbedding with Qwen2Moe pattern for transformers compat

Browse files

Replace custom rope_type handling with the standard Qwen2Moe pattern: use compute_default_rope_parameters for the "default" type, and only look up ROPE_INIT_FUNCTIONS for non-default types. Also add partial_rotary_factor support.

Files changed (1) hide show
  1. modeling_llada2_moe.py +27 -10
modeling_llada2_moe.py CHANGED
@@ -92,24 +92,41 @@ ALL_LAYERNORM_LAYERS.append(LLaDA2MoeRMSNorm)
92
 
93
 
94
  class LLaDA2MoeRotaryEmbedding(nn.Module):
 
 
95
  def __init__(self, config: LLaDA2MoeConfig, device=None):
96
  super().__init__()
97
- # BC: "rope_type" was originally "type"
98
- if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
99
- self.rope_type = config.rope_scaling.get(
100
- "rope_type", config.rope_scaling.get("type")
101
- )
102
- else:
103
- self.rope_type = "default"
104
  self.max_seq_len_cached = config.max_position_embeddings
105
  self.original_max_seq_len = config.max_position_embeddings
106
 
107
  self.config = config
108
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
109
 
110
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
 
 
 
 
 
111
  self.register_buffer("inv_freq", inv_freq, persistent=False)
112
- self.original_inv_freq = self.inv_freq
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  @torch.no_grad()
115
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
 
92
 
93
 
94
  class LLaDA2MoeRotaryEmbedding(nn.Module):
95
+ inv_freq: torch.Tensor # fix linting for register_buffer
96
+
97
  def __init__(self, config: LLaDA2MoeConfig, device=None):
98
  super().__init__()
 
 
 
 
 
 
 
99
  self.max_seq_len_cached = config.max_position_embeddings
100
  self.original_max_seq_len = config.max_position_embeddings
101
 
102
  self.config = config
 
103
 
104
+ self.rope_type = self.config.rope_parameters["rope_type"]
105
+ rope_init_fn: Callable = self.compute_default_rope_parameters
106
+ if self.rope_type != "default":
107
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
108
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
109
+
110
  self.register_buffer("inv_freq", inv_freq, persistent=False)
111
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
112
+
113
+ @staticmethod
114
+ def compute_default_rope_parameters(
115
+ config: LLaDA2MoeConfig = None,
116
+ device=None,
117
+ seq_len: int = None,
118
+ ):
119
+ base = config.rope_parameters["rope_theta"]
120
+ partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
121
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
122
+ dim = int(head_dim * partial_rotary_factor)
123
+
124
+ attention_factor = 1.0 # Unused in this type of RoPE
125
+
126
+ inv_freq = 1.0 / (
127
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
128
+ )
129
+ return inv_freq, attention_factor
130
 
131
  @torch.no_grad()
132
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)