Antreas commited on
Commit
08ab6b6
·
verified ·
1 Parent(s): 6422108

Enable AutoModel loading

Browse files
Files changed (1) hide show
  1. embeddings.py +2 -2
embeddings.py CHANGED
@@ -105,8 +105,8 @@ class RotaryPositionalEncoding(nn.Module):
105
  seq_len = x.shape[1]
106
  cos: torch.Tensor = self.cos_cached # type: ignore[assignment]
107
  sin: torch.Tensor = self.sin_cached # type: ignore[assignment]
108
- if seq_len > cos.shape[0]:
109
- self._build_cache(seq_len)
110
  cos = self.cos_cached # type: ignore[assignment]
111
  sin = self.sin_cached # type: ignore[assignment]
112
  return cos[:seq_len], sin[:seq_len]
 
105
  seq_len = x.shape[1]
106
  cos: torch.Tensor = self.cos_cached # type: ignore[assignment]
107
  sin: torch.Tensor = self.sin_cached # type: ignore[assignment]
108
+ if seq_len > cos.shape[0] or not torch.isfinite(cos[:seq_len]).all():
109
+ self._build_cache(max(seq_len, cos.shape[0]))
110
  cos = self.cos_cached # type: ignore[assignment]
111
  sin = self.sin_cached # type: ignore[assignment]
112
  return cos[:seq_len], sin[:seq_len]