KitsuVp commited on
Commit
bfcc64d
·
verified ·
1 Parent(s): f4a1c0d

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +12 -2
modeling_neollm.py CHANGED
@@ -1406,18 +1406,28 @@ class NeoLLMRotaryEmbedding(nn.Module):
1406
  B = x.shape[0]
1407
  if position_ids.shape[0] != B:
1408
  position_ids = position_ids.expand(B, -1)
1409
-
1410
  device_type = (x.device.type
1411
  if isinstance(x.device.type, str) and x.device.type != "mps"
1412
  else "cpu")
 
 
 
 
 
 
 
 
1413
  inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
1414
-
1415
  with torch.autocast(device_type=device_type, enabled=False):
1416
  freqs = (position_ids.to(dtype=torch.float32).unsqueeze(-1)
1417
  * inv_freq.unsqueeze(0).unsqueeze(0))
1418
  emb = torch.cat((freqs, freqs), dim=-1)
1419
  cos = emb.cos() * self.attention_scaling
1420
  sin = emb.sin() * self.attention_scaling
 
 
1421
 
1422
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1423
 
 
1406
  B = x.shape[0]
1407
  if position_ids.shape[0] != B:
1408
  position_ids = position_ids.expand(B, -1)
1409
+
1410
  device_type = (x.device.type
1411
  if isinstance(x.device.type, str) and x.device.type != "mps"
1412
  else "cpu")
1413
+
1414
+ if self.inv_freq.device.type == "meta":
1415
+ inv_freq_data, _ = self.compute_default_rope_parameters(
1416
+ self.config, device=x.device
1417
+ )
1418
+ self.register_buffer("inv_freq", inv_freq_data, persistent=False)
1419
+ self.register_buffer("original_inv_freq", inv_freq_data.clone(), persistent=False)
1420
+
1421
  inv_freq = self.inv_freq.to(device=x.device, dtype=torch.float32)
1422
+
1423
  with torch.autocast(device_type=device_type, enabled=False):
1424
  freqs = (position_ids.to(dtype=torch.float32).unsqueeze(-1)
1425
  * inv_freq.unsqueeze(0).unsqueeze(0))
1426
  emb = torch.cat((freqs, freqs), dim=-1)
1427
  cos = emb.cos() * self.attention_scaling
1428
  sin = emb.sin() * self.attention_scaling
1429
+
1430
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1431
 
1432
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
1433