ishanjmukherjee commited on
Commit
305b72a
·
1 Parent(s): 43539ed

Copy rotary from vortex; drop-in replace vortex.ops apply_rotary with flash_attn's apply_rotary

Browse files
Files changed (1) hide show
  1. rotary.py +547 -0
rotary.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied verbatim from vortex
2
+ # Copyright (c) 2023, Tri Dao.
3
+
4
+ from typing import Optional, Tuple, Union
5
+
6
+ import torch
7
+ from einops import rearrange, repeat
8
+
9
+ # Commenting this out from the original file in vortex; this is guaranteed to
10
+ # fail since we're not shipping ops
11
+ # from vortex.ops.embedding.rotary import apply_rotary
12
+ # Instead, we use flash_attn ops completely following https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py
13
+ from flash_attn.ops.triton.rotary import apply_rotary
14
+
15
def rotate_half(x, interleaved=False):
    """Rotate feature pairs of x by 90 degrees in the last dimension.

    GPT-NeoX style (interleaved=False) pairs feature i with feature
    i + d/2; GPT-J style (interleaved=True) pairs even with odd features.
    """
    if interleaved:
        even, odd = x[..., ::2], x[..., 1::2]
        # Re-interleave (-odd, even) back into the original layout.
        return torch.stack((-odd, even), dim=-1).flatten(-2)
    first, second = x.chunk(2, dim=-1)
    return torch.cat((-second, first), dim=-1)
22
+
23
+
24
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """Pure-PyTorch reference implementation of rotary position embedding.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first rotary_dim features of x are rotated; the remainder is
    passed through unchanged.
    """
    ro_dim = 2 * cos.shape[-1]
    assert ro_dim <= x.shape[-1]
    # Expand the half-width cos/sin tables to rotary_dim and insert a
    # broadcast axis for the heads dimension.
    if interleaved:
        cos = torch.repeat_interleave(cos, 2, dim=-1).unsqueeze(-2)
        sin = torch.repeat_interleave(sin, 2, dim=-1).unsqueeze(-2)
    else:
        cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
        sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
    rotated = x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin
    return torch.cat((rotated, x[..., ro_dim:]), dim=-1)
40
+
41
+
42
class ApplyRotaryEmb(torch.autograd.Function):
    """Autograd wrapper around the fused Triton ``apply_rotary`` kernel.

    Forward rotates the first rotary_dim features of x; backward applies the
    inverse rotation (conjugate=True) to the incoming gradient.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        out = apply_rotary(
            x,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=interleaved,
            inplace=inplace,
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            # Tensor offsets can (and must) go through save_for_backward.
            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        ctx.max_seqlen = max_seqlen
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            # seqlen_offsets was a tensor, saved as the last entry in forward.
            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cu_seqlens = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            # Rotate by -theta to invert the forward rotation.
            conjugate=True,
        )
        # One gradient slot per forward argument; only x is differentiable.
        return dx, None, None, None, None, None, None, None
99
+
100
+
101
def apply_rotary_emb(
    x,
    cos,
    sin,
    interleaved=False,
    inplace=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """Apply rotary embedding to the first rotary_dim features of x.

    Arguments:
        x: (batch_size, seqlen, nheads, headdim) when cu_seqlens is None,
            otherwise (total_seqlen, nheads, headdim) for varlen batches.
        cos, sin: (seqlen_rotary, rotary_dim / 2) precomputed tables.
        interleaved: if True, rotate pairs of even and odd dimensions
            (GPT-J style); if False, rotate 1st half vs 2nd half
            (GPT-NeoX style).
        inplace: if True, write the rotated values back into x.
        seqlen_offsets: int or (batch_size,) tensor; per-sequence position
            shift, most commonly the KV-cache length during inference.
        cu_seqlens: (batch + 1,) cumulative sequence lengths, or None.
        max_seqlen: maximum sequence length for the varlen case.
    Return:
        Tensor of the same shape as x; rotary_dim must be <= headdim.
    """
    return ApplyRotaryEmb.apply(
        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
    )
130
+
131
+
132
# For backward compatibility: older callers import the function under this name.
apply_rotary_emb_func = apply_rotary_emb
134
+
135
+
136
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """In-place rotary embedding on a packed QKV tensor.

    Forward rotates Q and K (V is untouched); backward applies the conjugate
    rotation to dQ and dK. When cos_k/sin_k are None and qkv is contiguous,
    Q and K are rotated with a single fused kernel call instead of two.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
        # BUGFIX(annotation): was the malformed `Union[int] = None`; Optional[int]
        # matches the wrapper apply_rotary_emb_qkv_.
        num_heads_q: Optional[int] = None,
    ):
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            if qkv.dim() == 5:
                batch, seqlen, three, nheads, headdim = qkv.shape
                assert three == 3
                # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
                qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim)
            else:
                # Flattened MQA / GQA layout: heads are [q..., k..., v...].
                assert qkv.dim() == 4
                assert num_heads_q is not None
                num_heads_k = (qkv.shape[2] - num_heads_q) // 2
                assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
                qk = qkv[:, :, : num_heads_q + num_heads_k]
            apply_rotary(
                qk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=interleaved,
                inplace=True,
            )
        else:
            # Distinct K tables (XPos) or non-contiguous qkv: rotate Q and K separately.
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            if qkv.dim() == 5:
                q, k = qkv[:, :, 0], qkv[:, :, 1]
            else:
                assert qkv.dim() == 4
                assert num_heads_q is not None
                num_heads_k = (qkv.shape[2] - num_heads_q) // 2
                assert qkv.shape[2] == num_heads_q + 2 * num_heads_k
                q, k = (
                    qkv[:, :, :num_heads_q],
                    qkv[:, :, num_heads_q : num_heads_q + num_heads_k],
                )
            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
        # BUGFIX: removed a redundant unconditional ctx.save_for_backward(cos, sin,
        # cos_k, sin_k) that preceded this if/else — save_for_backward replaces the
        # saved set rather than appending, so that call was dead (and absent from
        # upstream flash_attn).
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cos_k, sin_k)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.num_heads_q = num_heads_q
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            # seqlen_offsets was a tensor, saved as the last entry in forward.
            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            if dqkv.dim() == 5:
                dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
            else:
                assert dqkv.dim() == 4
                assert ctx.num_heads_q is not None
                num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
                assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
                dqk = dqkv[:, :, : ctx.num_heads_q + num_heads_k]
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            if dqkv.dim() == 5:
                dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
            else:
                assert dqkv.dim() == 4
                assert ctx.num_heads_q is not None
                num_heads_k = (dqkv.shape[2] - ctx.num_heads_q) // 2
                assert dqkv.shape[2] == ctx.num_heads_q + 2 * num_heads_k
                dq = dqkv[:, :, : ctx.num_heads_q]
                dk = dqkv[:, :, ctx.num_heads_q : ctx.num_heads_q + num_heads_k]
            apply_rotary(
                dq,
                cos,
                sin,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        # One gradient slot per forward argument; only qkv is differentiable.
        return dqkv, None, None, None, None, None, None, None
258
+
259
+
260
def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    num_heads_q: Optional[int] = None,
):
    """Apply rotary embedding *inplace* to the Q and K slices of packed QKV.

    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim), or
            (batch_size, seqlen, num_heads_q + 2 * num_heads_k, headdim) for
            MQA / GQA — in the latter case num_heads_q must be provided.
        cos, sin: (seqlen, rotary_dim / 2) tables used for Q (and for K when
            cos_k / sin_k are omitted).
        cos_k, sin_k: optional (seqlen, rotary_dim / 2) tables for K.
        interleaved: if True, rotate pairs of even and odd dimensions
            (GPT-J style); if False, 1st half vs 2nd half (GPT-NeoX style).
        seqlen_offsets: int or (batch_size,) tensor; per-sequence position
            shift, most commonly the KV-cache length during inference.
    Return:
        The same qkv tensor, modified in place; rotary_dim must be <= headdim.
    """
    return ApplyRotaryEmbQKV_.apply(
        qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets, num_heads_q
    )
287
+
288
+
289
class ApplyRotaryEmbKV_(torch.autograd.Function):
    """Autograd function that applies rotary embedding *inplace* to K only.

    kv is packed as (batch, seqlen, 2, nheads, headdim) with K at index 0 and
    V at index 1; V is left untouched. Backward applies the conjugate
    rotation to dK.
    """

    @staticmethod
    def forward(
        ctx,
        kv,
        cos,
        sin,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        k = kv[:, :, 0]
        # Rotate K in place; the view writes through to kv.
        apply_rotary(
            k,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=interleaved,
            inplace=True,
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            # seqlen_offsets was a tensor, saved as the last entry in forward.
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        # Invert the forward rotation on the K gradient only.
        apply_rotary(
            dkv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,
        )
        # One gradient slot per forward argument; only kv is differentiable.
        return dkv, None, None, None, None
336
+
337
+
338
# NOTE(review): this alias is immediately shadowed by the `def apply_rotary_emb_kv_`
# defined right after it, so it has no effect at import time; kept verbatim to
# match the upstream file.
apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
339
+
340
+
341
def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """Apply rotary embedding *inplace* to the K slice of a packed KV tensor.

    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim); index 0 is K, index 1 is V.
        cos, sin: (seqlen, rotary_dim / 2) tables.
        interleaved: if True, rotate pairs of even and odd dimensions
            (GPT-J style); if False, 1st half vs 2nd half (GPT-NeoX style).
        seqlen_offsets: int or (batch_size,) tensor; per-sequence position
            shift, most commonly the KV-cache length during inference.
    Return:
        The same kv tensor, modified in place; rotary_dim must be <= headdim.
    """
    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
362
+
363
+
364
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_, GPT-NeoX was an inspiration

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
            otherwise they might be in lower precision.
            This option was added because previously (before 2023-07-02), when we construct
            the position indices, we use the dtype of self.inv_freq. In most cases this would
            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
            self.inv_freq would be bf16, and the position indices are also in bf16.
            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
            embeddings for some positions will coincide.
            To maintain compatibility with models previously trained in pure bf16,
            we add this option.
        """
        super().__init__()
        self.dim = dim
        self.base = float(base)
        self.pos_idx_in_fp32 = pos_idx_in_fp32
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = self._compute_inv_freq(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        # XPos per-dimension decay factors; None disables XPos scaling entirely.
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        # Lazily-built cos/sin caches, grown on demand by _update_cos_sin_cache.
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        # *_k caches are only populated when scale (XPos) is in use.
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        # Standard RoPE frequencies: 1 / base^(2i/dim) for i in [0, dim/2).
        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
            # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
            if self.pos_idx_in_fp32:
                t = torch.arange(seqlen, device=device, dtype=torch.float32)
                # We want fp32 here as well since inv_freq will be multiplied with t, and the output
                # will be large. Having it in bf16 will lose a lot of precision and cause the
                # cos & sin output to change significantly.
                # We want to recompute self.inv_freq if it was not loaded in fp32
                if self.inv_freq.dtype != torch.float32:
                    inv_freq = self._compute_inv_freq(device=device)
                else:
                    inv_freq = self.inv_freq
            else:
                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                inv_freq = self.inv_freq
            # Don't do einsum, it converts fp32 to fp16 under AMP
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            freqs = torch.outer(t, inv_freq)
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                # XPos: scale^((position - seqlen/2) / scale_base); queries use
                # scale and keys use 1/scale so their dot product depends only
                # on relative position.
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
        num_heads_q: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) or (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim)
            if kv is none, else it's just q of shape (batch, seqlen, nheads, headdim).
            If qkv has shape (batch, seqlen, num_heads_q + 2 * num_heads_k, headdim) (e.g. MQA / GQA),
            then num_heads_q must be provided.
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
        # With a tensor seqlen_offset and no max_seqlen, the cache is assumed
        # to already cover the needed positions (no update performed here).
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                    num_heads_q=num_heads_q,
                )
            else:
                # XPos: K uses the separate 1/scale tables.
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                    num_heads_q=num_heads_q,
                )
        else:
            # Separate Q and KV tensors: rotate Q, then K inside kv.
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv