| |
|
|
| import torch |
|
|
|
|
def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the cos/sin lookup table used for rotary position embeddings.

    Args:
        dim: Number of channels to be rotated; ``dim // 2`` frequencies are
            generated, one per channel pair.
        end: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.
        use_scaled: Accepted for interface compatibility; this implementation
            does not read it.
        dtype: Floating dtype used for the intermediate position/frequency
            tensors.

    Returns:
        Tensor of shape ``(end, dim // 2, 2)`` whose last axis holds
        ``(cos(angle), sin(angle))`` with ``angle = position * inv_freq``.
    """
    # Exponents 0/dim, 2/dim, ... -- the slice trims the extra entry that
    # ``arange`` yields when ``dim`` is odd.
    exponents = torch.arange(0, dim, 2, dtype=dtype)[: dim // 2] / dim
    inv_freq = 1.0 / (theta ** exponents)

    # Outer product position x frequency -> (end, dim // 2) rotation angles.
    positions = torch.arange(end, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]

    # exp(i * angle) = cos(angle) + i * sin(angle); split into a real pair.
    rotations = torch.exp(1j * angles)
    return torch.stack((rotations.real, rotations.imag), dim=-1)
|
|
| |
| import torch |
|
|
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: torch.Tensor,
    num_heads: int,
    rot_dim: int = 32,
    interleave: bool = False,
) -> torch.Tensor:
    """Apply rotary position embeddings to the leading channels of ``x``.

    RoPE as used in the original moondream2 text stack.

    Args:
        x: Tensor of shape ``(B, H, T, D)``.
        freqs_cis: Table of shape ``(max_T, rot_dim // 2, 2)`` where
            ``[..., 0]`` is cos and ``[..., 1]`` is sin.
        position_ids: Index tensor into the first axis of ``freqs_cis``.
            May be a scalar, ``(T,)``, ``(1, T)`` or ``(B, T)``; anything
            with a leading size of 1 (or no batch axis) is broadcast over
            the batch.
        num_heads: Must equal ``x.shape[1]``.
        rot_dim: Number of leading channels to rotate; must equal
            ``2 * freqs_cis.shape[-2]``. Clamped to ``D`` when ``D`` is
            smaller.
        interleave: If True, each rotation pair is made of adjacent channels
            ``(2k, 2k + 1)``; otherwise channel ``k`` is paired with channel
            ``k + rd // 2`` (split halves).

    Returns:
        Tensor with the same shape and dtype as ``x``: the rotated channels
        followed by the untouched channels ``x[..., rd:]``.

    Raises:
        AssertionError: If ``rot_dim`` or ``num_heads`` disagree with the
            shapes of ``freqs_cis`` / ``x``.
        ValueError: If ``x`` is not 4-D or ``position_ids`` has more than
            two dimensions.
    """
    assert rot_dim == freqs_cis.shape[-2] * 2
    assert num_heads == x.shape[1]
    if position_ids.dim() > 2:
        raise ValueError(
            f"position_ids must have at most 2 dimensions, got {position_ids.dim()}"
        )

    # Unpacking also enforces that x is 4-D.
    _, _, _, D = x.shape
    rd = min(rot_dim, D)
    half = rd // 2
    x_rot, x_pass = x[..., :rd], x[..., rd:]

    if interleave:
        # Adjacent channels form a (real, imag) pair; the rotation is done in
        # float32 and cast back to x.dtype at the end.
        pairs = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)
        xr, xi = pairs[..., 0], pairs[..., 1]
    else:
        # First half of the rotated channels is "real", second half "imag".
        xr, xi = x_rot[..., :half], x_rot[..., half:]

    # Gather the table rows, then left-pad to (B or 1, T, rot_dim // 2, 2) so
    # that scalar, (T,), (1, T) and (B, T) ids all broadcast against x.
    freq = freqs_cis[position_ids]
    while freq.dim() < 4:
        freq = freq.unsqueeze(0)

    # (B or 1, 1, T, half): the singleton axis broadcasts over heads.
    cos = freq[..., :half, 0].unsqueeze(1).to(x.dtype)
    sin = freq[..., :half, 1].unsqueeze(1).to(x.dtype)

    # Complex multiply (xr + i*xi) * (cos + i*sin).
    yr = xr * cos - xi * sin
    yi = xr * sin + xi * cos

    # NOTE(review): the rotated pairs are written back interleaved
    # (r0, i0, r1, i1, ...) in BOTH modes, including split-half input. This
    # is kept as-is since it appears to mirror the reference implementation;
    # confirm before "fixing".
    # Cast back so the interleave path (computed in float32) does not upcast
    # the result when x is half/bfloat16.
    y = torch.stack((yr, yi), dim=-1).flatten(-2).to(x.dtype)

    return torch.cat([y, x_pass], dim=-1)
|
|
|
|
|
|