[Bug] RoPE initialization for SWA layers modifies shared config object, causing incorrect rope_theta for non-SWA layers

#32
by sirusGray - opened

In MiMoV2Model's __init__ method, two separate MiMoV2FlashRotaryEmbedding instances are created: one for standard attention (rotary_emb) and one for Sliding Window Attention
(swa_rotary_emb). Both instances receive the same config object by reference.

MiMoV2FlashRotaryEmbedding's __init__ method contains the following code:

class MiMoV2FlashRotaryEmbedding(nn.Module):
    def __init__(self, config: MiMoV2FlashConfig, is_swa, device=None):
        # ...
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        if is_swa:
            self.config.rope_theta = config.swa_rope_theta
            self.config.head_dim = config.swa_head_dim
        # ...
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

When swa_rotary_emb is initialized (is_swa=True), it assigns self.config = config and then writes to self.config.rope_theta and self.config.head_dim. Because self.config is the same
object that was passed in, not a copy, these assignments overwrite the original rope_theta and head_dim on the shared config, and any code that reads them afterwards sees the SWA values.
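
To make the aliasing concrete, here is a minimal sketch that reproduces the pattern without torch or transformers. ToyConfig, MutatingRotary and CopyingRotary are hypothetical stand-ins, and the numeric values are made up for illustration; they are not the model's real defaults. The second class shows one possible fix, which is to copy the config before overriding anything. This is only a suggestion, not necessarily how the maintainers would resolve it.

```python
import copy
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for MiMoV2FlashConfig; values are illustrative only.
    rope_theta: float = 10000.0
    swa_rope_theta: float = 5000.0
    head_dim: int = 128
    swa_head_dim: int = 64


class MutatingRotary:
    """Mirrors the reported pattern: writes SWA values into the shared config."""

    def __init__(self, config, is_swa):
        self.config = config  # same object as the caller's config, not a copy
        if is_swa:
            self.config.rope_theta = config.swa_rope_theta
            self.config.head_dim = config.swa_head_dim


class CopyingRotary:
    """Possible fix (an assumption): override values on a private copy only."""

    def __init__(self, config, is_swa):
        # A shallow copy is enough here because only top-level scalars change;
        # copy.deepcopy would be the more conservative choice for a real config.
        self.config = copy.copy(config) if is_swa else config
        if is_swa:
            self.config.rope_theta = config.swa_rope_theta
            self.config.head_dim = config.swa_head_dim


# Reported behaviour: the SWA instance leaks its values into the shared config.
shared = ToyConfig()
rotary_emb = MutatingRotary(shared, is_swa=False)
swa_rotary_emb = MutatingRotary(shared, is_swa=True)
leaked_theta = shared.rope_theta
leaked_head_dim = shared.head_dim

# With a copy, the caller's config keeps its original values.
fixed = ToyConfig()
fixed_swa = CopyingRotary(fixed, is_swa=True)
```

After the mutating version runs, shared.rope_theta and shared.head_dim hold the SWA values, and rotary_emb.config shows them too because it is the same object. With the copying version, only fixed_swa.config carries the overrides.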
