Resolving inference compatibility issues in the KORMo model for Transformers 5.2

#3
Files changed (1) hide show
  1. _modeling_kormo.py +30 -7
_modeling_kormo.py CHANGED
@@ -95,6 +95,12 @@ def rotate_half(x):
95
  x2 = x[..., x.shape[-1] // 2 :]
96
  return torch.cat((-x2, x1), dim=-1)
97
 
 
 
 
 
 
 
98
  class Attention(nn.Module):
99
  """Multi-headed attention from 'Attention Is All You Need' paper"""
100
 
@@ -237,12 +243,25 @@ class RotaryEmbedding(nn.Module):
237
  self.original_max_seq_len = config.max_position_embeddings
238
 
239
  self.config = config
240
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
 
241
 
242
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
243
  self.register_buffer("inv_freq", inv_freq, persistent=False)
244
- self.original_inv_freq = self.inv_freq
 
 
 
 
 
245
 
 
 
 
 
 
 
246
  @torch.no_grad()
247
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
248
  def forward(self, x, position_ids):
@@ -258,10 +277,6 @@ class RotaryEmbedding(nn.Module):
258
  return cos, sin
259
 
260
 
261
-
262
-
263
-
264
-
265
  class KORMoPreTrainedModel(PreTrainedModel):
266
  config_class = KORMoConfig
267
  base_model_prefix = "model"
@@ -289,7 +304,15 @@ class KORMoPreTrainedModel(PreTrainedModel):
289
  module.weight.data[module.padding_idx].zero_()
290
  elif isinstance(module, RMSNorm):
291
  module.weight.data.fill_(1.0)
292
-
 
 
 
 
 
 
 
 
293
 
294
  class KORMoModel(KORMoPreTrainedModel):
295
  def __init__(self, config: KORMoConfig):
 
95
  x2 = x[..., x.shape[-1] // 2 :]
96
  return torch.cat((-x2, x1), dim=-1)
97
 
98
+ def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
99
+ if not getattr(tensor, "_is_hf_initialized", False):
100
+ with torch.no_grad():
101
+ return tensor.copy_(other)
102
+ return tensor
103
+
104
  class Attention(nn.Module):
105
  """Multi-headed attention from 'Attention Is All You Need' paper"""
106
 
 
243
  self.original_max_seq_len = config.max_position_embeddings
244
 
245
  self.config = config
246
+ rope_init_fn = self.compute_default_rope_parameters
247
+ rope_init_fn = ROPE_INIT_FUNCTIONS.get(self.rope_type, rope_init_fn)
248
+ self.rope_init_fn = rope_init_fn
249
 
250
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
251
  self.register_buffer("inv_freq", inv_freq, persistent=False)
252
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
253
+
254
+ @staticmethod
255
+ def compute_default_rope_parameters(config: KORMoConfig, device=None, seq_len =None):
256
+ base = config.rope_theta
257
+ dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
258
 
259
+ attention_factor = 1.0
260
+ inv_freq = 1.0 / (
261
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
262
+ )
263
+ return inv_freq, attention_factor
264
+
265
  @torch.no_grad()
266
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
267
  def forward(self, x, position_ids):
 
277
  return cos, sin
278
 
279
 
 
 
 
 
280
  class KORMoPreTrainedModel(PreTrainedModel):
281
  config_class = KORMoConfig
282
  base_model_prefix = "model"
 
304
  module.weight.data[module.padding_idx].zero_()
305
  elif isinstance(module, RMSNorm):
306
  module.weight.data.fill_(1.0)
307
+ elif "RotaryEmbedding" in module.__class__.__name__ and hasattr(module, "original_inv_freq"):
308
+ rope_fn = (
309
+ ROPE_INIT_FUNCTIONS[module.rope_type]
310
+ if module.rope_type != "default"
311
+ else module.compute_default_rope_parameters
312
+ )
313
+ buffer_value, _ = rope_fn(module.config)
314
+ copy_(module.inv_freq, buffer_value)
315
+ copy_(module.original_inv_freq, buffer_value)
316
 
317
  class KORMoModel(KORMoPreTrainedModel):
318
  def __init__(self, config: KORMoConfig):