Update rope.py
Browse filesfix: apply rotatory embed.
rope.py
CHANGED
|
@@ -16,37 +16,54 @@ def precompute_freqs_cis(
|
|
| 16 |
freqs = torch.exp(1j * freqs)
|
| 17 |
return torch.stack([freqs.real, freqs.imag], dim=-1)
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
"""
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
| 27 |
"""
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
# cos/sin expected as (L, D) or broadcastable to x
|
| 40 |
-
cos_ = cos.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
|
| 41 |
-
sin_ = sin.index_select(0, position_ids.view(-1)).view(*position_ids.shape, -1)
|
| 42 |
-
# reshape to broadcast over heads
|
| 43 |
-
while cos_.dim() < x_q.dim():
|
| 44 |
-
cos_ = cos_.unsqueeze(1)
|
| 45 |
-
sin_ = sin_.unsqueeze(1)
|
| 46 |
else:
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
-
q = (x_q * cos_) + (_rotate_half(x_q) * sin_)
|
| 50 |
-
k = (x_k * cos_) + (_rotate_half(x_k) * sin_)
|
| 51 |
-
return q, k
|
| 52 |
|
|
|
|
| 16 |
freqs = torch.exp(1j * freqs)
|
| 17 |
return torch.stack([freqs.real, freqs.imag], dim=-1)
|
| 18 |
|
| 19 |
+
# rope.py
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
def apply_rotary_emb(
|
| 23 |
+
x: torch.Tensor,
|
| 24 |
+
freqs_cis: torch.Tensor,
|
| 25 |
+
position_ids: torch.Tensor,
|
| 26 |
+
num_heads: int,
|
| 27 |
+
rot_dim: int = 32,
|
| 28 |
+
interleave: bool = False,
|
| 29 |
+
) -> torch.Tensor:
|
| 30 |
"""
|
| 31 |
+
RoPE as used in the original moondream2 text stack:
|
| 32 |
+
x: (B, H, T, D)
|
| 33 |
+
freqs_cis: (max_T, rot_dim//2, 2) where [...,0]=cos, [...,1]=sin
|
| 34 |
+
position_ids: (T,) or (B,T)
|
| 35 |
+
returns x with first rot_dim dims rotated.
|
| 36 |
"""
|
| 37 |
+
assert rot_dim == freqs_cis.shape[-2] * 2
|
| 38 |
+
assert num_heads == x.shape[1]
|
| 39 |
+
|
| 40 |
+
B, H, T, D = x.shape
|
| 41 |
+
rd = min(rot_dim, D)
|
| 42 |
+
x_rot, x_pass = x[..., :rd], x[..., rd:]
|
| 43 |
+
|
| 44 |
+
# split real/imag parts depending on layout
|
| 45 |
+
if interleave:
|
| 46 |
+
xr = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 0]
|
| 47 |
+
xi = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)[..., 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
else:
|
| 49 |
+
d = x_rot.shape[-1] // 2
|
| 50 |
+
xr, xi = x_rot[..., :d], x_rot[..., d:]
|
| 51 |
+
|
| 52 |
+
# gather cos/sin for these positions
|
| 53 |
+
if position_ids.dim() == 2 and position_ids.size(0) == B:
|
| 54 |
+
freq = freqs_cis[position_ids] # (B, T, rd//2, 2)
|
| 55 |
+
else: # (T,) or scalar
|
| 56 |
+
freq = freqs_cis[position_ids].unsqueeze(0).expand(B, -1, -1, -1)
|
| 57 |
+
|
| 58 |
+
rot_half = rd // 2
|
| 59 |
+
cos = freq[..., 0][..., :rot_half].unsqueeze(1).to(x.dtype) # (B,1,T,rot_half)
|
| 60 |
+
sin = freq[..., 1][..., :rot_half].unsqueeze(1).to(x.dtype)
|
| 61 |
+
|
| 62 |
+
# complex multiply
|
| 63 |
+
yr = xr * cos - xi * sin
|
| 64 |
+
yi = xr * sin + xi * cos
|
| 65 |
+
y = torch.stack((yr, yi), dim=-1).flatten(-2) # (B,H,T,rd)
|
| 66 |
+
|
| 67 |
+
return torch.cat([y, x_pass], dim=-1)
|
| 68 |
|
|
|
|
|
|
|
|
|
|
| 69 |
|