BucketOfFish committed
Commit 3c52426 · 1 Parent(s): 5e8c4af

Simplified rotary embedding

Files changed (1):
  1. modeling_phi.py +88 -206
modeling_phi.py CHANGED
@@ -72,211 +72,106 @@ class Embedding(nn.Module):
          return hidden_states
 
 
- def _apply_rotary_emb(
-     x: torch.FloatTensor,
-     cos: torch.FloatTensor,
-     sin: torch.FloatTensor,
- ) -> torch.FloatTensor:
-     _, seqlen, _, _ = x.shape
-     _, rotary_dim = cos.shape
-     rotary_dim *= 2
-
-     x_rot = x[:, :, :, :rotary_dim]
-     x_pass = x[:, :, :, rotary_dim:]
-
-     x1, x2 = x_rot.chunk(2, dim=-1)
-     c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
-     x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]
-
-     x_rot = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], axis=-1).to(x.dtype)
-
-     return torch.cat([x_rot, x_pass], axis=-1)
-
-
- def _apply_rotary_emb_kv(
-     kv: torch.FloatTensor,
-     cos: torch.FloatTensor,
-     sin: torch.FloatTensor,
-     cos_k: Optional[torch.FloatTensor] = None,
-     sin_k: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
-     _, seqlen, _, _, _ = kv.shape
-     _, rotary_dim = cos.shape
-     rotary_dim *= 2
-
-     k_rot = kv[:, :, 0, :, :rotary_dim]
-     k_pass = kv[:, :, 0, :, rotary_dim:]
-
-     k1, k2 = k_rot.chunk(2, dim=-1)
-     c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
-     k1, k2, c, s = [t.to(dtype=torch.float32) for t in [k1, k2, c, s]]
-
-     k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(kv.dtype)
-
-     return torch.cat(
-         [
-             torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
-             kv[:, :, 1:2, :, :],
-         ],
-         axis=2,
-     )
-
-
- def _apply_rotary_emb_qkv(
-     qkv: torch.FloatTensor,
-     cos: torch.FloatTensor,
-     sin: torch.FloatTensor,
-     cos_k: Optional[torch.FloatTensor] = None,
-     sin_k: Optional[torch.FloatTensor] = None,
- ) -> torch.FloatTensor:
-     _, seqlen, _, _, _ = qkv.shape
-     _, rotary_dim = cos.shape
-     rotary_dim *= 2
-
-     q_rot = qkv[:, :, 0, :, :rotary_dim]
-     q_pass = qkv[:, :, 0, :, rotary_dim:]
-
-     k_rot = qkv[:, :, 1, :, :rotary_dim]
-     k_pass = qkv[:, :, 1, :, rotary_dim:]
-
-     q1, q2 = q_rot.chunk(2, dim=-1)
-     k1, k2 = k_rot.chunk(2, dim=-1)
-     c, s = rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d")
-     q1, q2, k1, k2, c, s = [t.to(dtype=torch.float32) for t in [q1, q2, k1, k2, c, s]]
-
-     q_rot = torch.cat([q1 * c - q2 * s, q1 * s + q2 * c], axis=-1).to(qkv.dtype)
-     k_rot = torch.cat([k1 * c - k2 * s, k1 * s + k2 * c], axis=-1).to(qkv.dtype)
-
-     return torch.cat(
-         [
-             torch.cat([q_rot, q_pass], axis=-1).unsqueeze(2),
-             torch.cat([k_rot, k_pass], axis=-1).unsqueeze(2),
-             qkv[:, :, 2:3, :, :],
-         ],
-         axis=2,
-     )
-
-
  class RotaryEmbedding(nn.Module):
-     """Rotary positional embedding (RoPE).
-
-     Reference:
-         RoFormer: Enhanced Transformer with Rotary Position Embedding.
-         https://arxiv.org/pdf/2104.09864.pdf.
-
+     """Rotary positional embedding (RoPE) from Phi2.
+     See https://www.youtube.com/watch?v=C6rV8BsrrCc
      """
 
      def __init__(
          self,
-         dim: int,
-         base: int = 10000,
-         scale_base: Optional[float] = None,
-         pos_idx_in_fp32: bool = True,
-         max_position_embeddings: int = 2048,
-         device: Optional[str] = None,
-         **kwargs,
+         d_rotary: int,
+         rotary_base: float = 10000.0,
+         initial_cos_sin_cache_len: int = 2048,
+         device: torch.device | None = None,
      ) -> None:
          super().__init__()
-
-         if scale_base is not None:
-             raise NotImplementedError
-
-         self.dim = dim
-         self.base = float(base)
-         self.scale_base = scale_base
-         self.pos_idx_in_fp32 = pos_idx_in_fp32
-         self.max_position_embeddings = max_position_embeddings
+         self.d_rotary = d_rotary
+         self.rotary_base = rotary_base
          self.device = device
-
-         # Generate and save the inverse frequency buffer (non-trainable)
-         inv_freq = self._compute_inv_freq(device)
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-         # Generate and save the scale buffer (non-trainable)
-         scale = (
-             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
-             if scale_base is not None
-             else None
+         self.dtype = torch.float32
+         self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len)
+
+     def _update_cos_sin_cache(self, seqlen: int) -> None:
+         # only call this function when seqlen is larger than _max_seqlen
+         self._max_seqlen = seqlen
+
+         # m * theta_i = m * base^(-2i/d) = m * (1 / base^(2i/d)), where i in [1, d/2]
+         m = torch.arange(
+             seqlen,
+             device=self.device,
+             dtype=self.dtype,
          )
-         self.register_buffer("scale", scale, persistent=False)
-
-         # Initialize cached attributes since ONNX can't rely on dynamic initialization
-         self._update_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.float32)
-
-     def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
-         return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
-
-     def _update_cos_sin_cache(
-         self,
-         seqlen: int,
-         device: Optional[str] = None,
-         dtype: Optional[torch.dtype] = None,
-     ) -> None:
-         self._seq_len_cached = seqlen
-
-         # fp32 is preferred since the output of `torch.arange` can be quite large
-         # and bf16 would lose a lot of precision
-         if self.pos_idx_in_fp32:
-             t = torch.arange(seqlen, device=device, dtype=torch.float32)
-             if self.inv_freq.dtype != torch.float32:
-                 inv_freq = self._compute_inv_freq(device=device)
-             else:
-                 inv_freq = self.inv_freq
-         else:
-             t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
-             inv_freq = self.inv_freq
-
-         # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
-         freqs = torch.outer(t, inv_freq)
-         if self.scale is None:
-             self._cos_cached = torch.cos(freqs).to(dtype)
-             self._sin_cached = torch.sin(freqs).to(dtype)
-         else:
+         theta_i = 1.0 / (
+             self.rotary_base ** (
+                 torch.arange(
+                     start=0,
+                     end=self.d_rotary,
+                     step=2,
+                     device=self.device,
+                     dtype=self.dtype,
+                 ) / self.d_rotary
+             )
+         )
+         # torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
+         # TODO: does this matter if I'm disabling torch.autocast?
+         m_theta_i = torch.outer(m, theta_i)
+         self._cos_cached = torch.cos(m_theta_i).to(self.dtype)
+         self._sin_cached = torch.sin(m_theta_i).to(self.dtype)
+
+         # TODO: scale_base caching is labelled as not yet done in Phi2
+         """
+         if scale_base is not None:
+             scale = (
+                 torch.arange(
+                     start=0,
+                     end=self.d_rotary,
+                     step=2,
+                     device=self.device,
+                     dtype=torch.float32,
+                 ) + 0.4 * self.d_rotary
+             ) / (1.4 * self.d_rotary)
              power = (
-                 torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
-             ) / self.scale_base
-             scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
-
-             # Force the scale multiplication to happen in fp32
-             self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
-             self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
-             self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
-             self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
+                 torch.arange(seqlen, dtype=scale.dtype, device=scale.device) - seqlen // 2
+             ) / scale_base
+             scale = scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+             self._cos_cached = (torch.cos(m_theta_i) * scale).to(dtype)
+             self._sin_cached = (torch.sin(m_theta_i) * scale).to(dtype)
+         """
+
+     def _apply_rotary_emb_qkv(
+         self,
+         x: torch.FloatTensor,  # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
+         cos: torch.FloatTensor,  # dim: (_max_seqlen, d_rotary)
+         sin: torch.FloatTensor,  # dim: (_max_seqlen, d_rotary)
+     ) -> torch.FloatTensor:
+         seqlen = x.shape[1]
+         x1, x2 = x.chunk(2, dim=-1)  # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head/2)
+         broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d"
+         c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange)
+         x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]]  # make sure rotary embedding is in float32
+         return cast(
+             torch.FloatTensor,
+             torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype)
+         )
 
      def forward(
          self,
-         qkv: torch.Tensor,
-         kv: Optional[torch.Tensor] = None,
-         seqlen_offset: int = 0,
-         **kwargs,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         x: torch.FloatTensor,  # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_head)
+         seqlen_offset: int = 0,  # each sequence is shifted by this amount - used in inference with KV cache
+     ) -> torch.FloatTensor:
          if (
-             self._seq_len_cached < qkv.shape[1] + seqlen_offset
-             or self._cos_cached.device != qkv.device
-             or self._cos_cached.dtype != qkv.dtype
+             not self._max_seqlen
+             or self._max_seqlen < x.shape[1] + seqlen_offset
+             or self._cos_cached.device != x.device
+             or self._cos_cached.dtype != x.dtype
              or (self.training and self._cos_cached.is_inference())
          ):
-             self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-
-         if kv is None:
-             return _apply_rotary_emb_qkv(
-                 qkv,
-                 self._cos_cached[seqlen_offset:],
-                 self._sin_cached[seqlen_offset:],
-             )
-         else:
-             q = _apply_rotary_emb(
-                 qkv,
-                 self._cos_cached[seqlen_offset:],
-                 self._sin_cached[seqlen_offset:],
-             )
-             kv = _apply_rotary_emb_kv(
-                 kv,
-                 self._cos_cached[seqlen_offset:],
-                 self._sin_cached[seqlen_offset:],
-             )
-
-             return q, kv
+             self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset)
+         return self._apply_rotary_emb_qkv(
+             x,
+             cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]),
+             cast(torch.FloatTensor, self._sin_cached[seqlen_offset:]),
+         )
 
 
  class MLP(nn.Module):
@@ -519,23 +414,10 @@ class MHA(nn.Module):
          super().__init__()
 
          # Rotary embedding
-         self.rotary_dim = rotary_dim if rotary_dim is not None else getattr(config, "rotary_dim", 0)
-         if self.rotary_dim > 0:
-             rotary_cls = RotaryEmbedding
-             if rotary_cls is None:
-                 rotary_cls = RotaryEmbedding
-
-             rotary_kwargs = {}
-             if rotary_cls is RotaryEmbedding:
-                 rotary_kwargs["max_position_embeddings"] = config.n_positions
-
-             self.rotary_emb = rotary_cls(
-                 self.rotary_dim,
-                 base=rotary_base,
-                 scale_base=rotary_scale_base,
-                 device=device,
-                 **rotary_kwargs,
-             )
+         self.rotary_emb = RotaryEmbedding(
+             d_rotary=math.ceil((rotary_dim // n_head) / 2),  # d_rotary is half of d_head
+             initial_cos_sin_cache_len=config.n_positions,
+         )
 
          # MLP
          self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(
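For reference, the rotation that the new `_apply_rotary_emb_qkv` performs is the standard half-split RoPE: position m rotates each pair (x1_i, x2_i) by angle m * theta_i, computed as `[x1 * c - x2 * s, x1 * s + x2 * c]`. Below is a minimal, self-contained sketch of that math (illustrative names, not part of the commit; it assumes the cos/sin tables cover the full head dimension), cross-checked against the equivalent complex-number rotation (x1 + i*x2) * e^(i*m*theta):

import torch

def rope_cache(seqlen: int, d_head: int, base: float = 10000.0):
    # theta_i = base^(-2i/d_head); one angle per rotated pair, so d_head/2 columns
    theta_i = 1.0 / (base ** (torch.arange(0, d_head, 2, dtype=torch.float32) / d_head))
    m = torch.arange(seqlen, dtype=torch.float32)  # position indices
    m_theta_i = torch.outer(m, theta_i)            # (seqlen, d_head/2)
    return torch.cos(m_theta_i), torch.sin(m_theta_i), m_theta_i

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, seqlen, n_heads, d_head), split into halves as in the commit
    x1, x2 = x.chunk(2, dim=-1)
    c = cos[: x.shape[1]].unsqueeze(1)  # (seqlen, 1, d_head/2), broadcasts over heads
    s = sin[: x.shape[1]].unsqueeze(1)
    return torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)

batch, seqlen, n_heads, d_head = 2, 16, 4, 8
x = torch.randn(batch, seqlen, n_heads, d_head)
cos, sin, m_theta_i = rope_cache(seqlen, d_head)
out = apply_rope(x, cos, sin)

# Same rotation written as (x1 + i*x2) * e^(i*m*theta_i)
x1, x2 = x.chunk(2, dim=-1)
z = torch.complex(x1, x2) * torch.polar(torch.ones_like(m_theta_i), m_theta_i).unsqueeze(1)
assert torch.allclose(out, torch.cat([z.real, z.imag], dim=-1), atol=1e-5)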
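Continuing the sketch above: the `seqlen_offset` argument in the new `forward` exists so that a token decoded incrementally against a KV cache is still rotated at its absolute position, which is what slicing the cached tables with `self._cos_cached[seqlen_offset:]` achieves. A quick check with the hypothetical helpers from the previous sketch:

# One decode step: the token at absolute position seqlen-1, rotated via an offset slice,
# matches the same token rotated inside the full sequence
prefix_len = seqlen - 1
new_tok = x[:, prefix_len:]  # (batch, 1, n_heads, d_head)
stepwise = apply_rope(new_tok, cos[prefix_len:], sin[prefix_len:])
assert torch.allclose(out[:, prefix_len:], stepwise, atol=1e-6)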