jungsin3 commited on
Commit
5b4219a
·
verified ·
1 Parent(s): a8e6fcb

I propose modifying the KORMo modeling code to ensure compatibility with both Transformers 4.57.1 and 5.2.

Browse files

In RotaryEmbedding, the inv_freq value is computed once in __init__ and then reused.
In Transformers 5.2, the model is loaded on the meta device, so this computation does not take place at construction time. To compensate, 5.2 added logic to the _init_weights function that restores inv_freq via an else branch. Because KORMo uses a custom _init_weights function, this logic was never applied to it, so the RoPE values were not used during inference.
The following changes have been made to the code:

Added logic to restore inv_freq in KORMoPreTrainedModel._init_weights.
Added the copy_ helper used by _init_weights to the top of the file.
Fixed an issue where the original_inv_freq buffer was never registered: we now register it as a clone of self.inv_freq, whereas previously the attribute was None because inv_freq had not been computed. (RotaryEmbedding)
Added the compute_default_rope_parameters function, which is no longer available in version 5.2. (RotaryEmbedding)
Compatible with both version 4.57.1 and version 5.2
Thank you.

Files changed (1) hide show
  1. _modeling_kormo.py +31 -3
_modeling_kormo.py CHANGED
@@ -94,7 +94,13 @@ def rotate_half(x):
94
  x1 = x[..., : x.shape[-1] // 2]
95
  x2 = x[..., x.shape[-1] // 2 :]
96
  return torch.cat((-x2, x1), dim=-1)
97
-
 
 
 
 
 
 
98
  class Attention(nn.Module):
99
  """Multi-headed attention from 'Attention Is All You Need' paper"""
100
 
@@ -237,11 +243,24 @@ class RotaryEmbedding(nn.Module):
237
  self.original_max_seq_len = config.max_position_embeddings
238
 
239
  self.config = config
240
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
 
241
 
242
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
243
  self.register_buffer("inv_freq", inv_freq, persistent=False)
244
- self.original_inv_freq = self.inv_freq
 
 
 
 
 
 
 
 
 
 
 
245
 
246
  @torch.no_grad()
247
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -289,6 +308,15 @@ class KORMoPreTrainedModel(PreTrainedModel):
289
  module.weight.data[module.padding_idx].zero_()
290
  elif isinstance(module, RMSNorm):
291
  module.weight.data.fill_(1.0)
 
 
 
 
 
 
 
 
 
292
 
293
 
294
  class KORMoModel(KORMoPreTrainedModel):
 
94
  x1 = x[..., : x.shape[-1] // 2]
95
  x2 = x[..., x.shape[-1] // 2 :]
96
  return torch.cat((-x2, x1), dim=-1)
97
+
98
def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """Copy ``other`` into ``tensor`` unless the tensor is flagged as initialized.

    Mirrors the helper Transformers uses during ``_init_weights``: a tensor
    carrying ``_is_hf_initialized = True`` (i.e. already filled from a
    checkpoint) is returned untouched so loaded weights are not clobbered.
    The copy runs under ``torch.no_grad()`` so it is not tracked by autograd.
    """
    if getattr(tensor, "_is_hf_initialized", False):
        return tensor
    with torch.no_grad():
        return tensor.copy_(other)
103
+
104
  class Attention(nn.Module):
105
  """Multi-headed attention from 'Attention Is All You Need' paper"""
106
 
 
243
  self.original_max_seq_len = config.max_position_embeddings
244
 
245
  self.config = config
246
+ rope_init_fn = self.compute_default_rope_parameters
247
+ rope_init_fn = ROPE_INIT_FUNCTIONS.get(self.rope_type, rope_init_fn)
248
+ self.rope_init_fn = rope_init_fn
249
 
250
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
251
  self.register_buffer("inv_freq", inv_freq, persistent=False)
252
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
253
+
254
+ @staticmethod
255
+ def compute_default_rope_parameters(config: KORMoConfig, device=None, seq_len =None):
256
+ base = config.rope_theta
257
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
258
+
259
+ attention_factor = 1.0
260
+ inv_freq = 1.0 / (
261
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
262
+ )
263
+ return inv_freq, attention_factor
264
 
265
  @torch.no_grad()
266
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
 
308
  module.weight.data[module.padding_idx].zero_()
309
  elif isinstance(module, RMSNorm):
310
  module.weight.data.fill_(1.0)
311
+ elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
312
+ rope_fn = (
313
+ ROPE_INIT_FUNCTIONS[module.rope_type]
314
+ if module.rope_type != "default"
315
+ else module.compute_default_rope_parameters
316
+ )
317
+ buffer_value, _ = rope_fn(module.config)
318
+ copy_(module.inv_freq, buffer_value)
319
+ copy_(module.original_inv_freq, buffer_value)
320
 
321
 
322
  class KORMoModel(KORMoPreTrainedModel):