# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import math
from typing import Tuple

import torch


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    The result is a view of ``freqs_cis`` that broadcasts against ``x`` during
    element-wise operations (it is not reshaped to the full shape of ``x``).

    Args:
        freqs_cis (torch.Tensor): Frequency tensor of shape
            ``(x.shape[seq_dim], x.shape[-3], 2, 2)``.
        x (torch.Tensor): Target tensor for broadcasting compatibility; its last
            two dimensions are not used to build the output shape.
        seq_dim (int): Non-negative sequence dimension index of ``x``.

    Returns:
        torch.Tensor: View of ``freqs_cis`` whose shape is 1 everywhere except at
        ``seq_dim`` and ``x.ndim - 3``, followed by the trailing ``(2, 2)``.
    """
    ndim = x.ndim
    assert 0 <= seq_dim < ndim
    assert freqs_cis.shape == (
        x.shape[seq_dim],
        x.shape[-3],
        2,
        2,
    ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    shape = [
        d if i == seq_dim or i == ndim - 3 else 1
        for i, d in enumerate(x.shape[:-2])
    ] + [2, 2]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors with precomputed 2x2 rotation matrices.

    Consecutive pairs along the last dimension of ``xq``/``xk`` are treated as
    2-vectors and multiplied by the matching rotation matrix in ``freqs_cis``.

    Args:
        xq (torch.Tensor): Query tensor, e.g. ``(B, S, H, D)``; ``D`` must be even.
        xk (torch.Tensor): Key tensor with the same sequence and ``D`` sizes as ``xq``.
        seq_dim (int): Index of the sequence dimension in ``xq``/``xk``.
        freqs_cis (torch.Tensor): Rotation matrices of shape ``(S, D // 2, 2, 2)``,
            e.g. a slice of ``RotaryEmbedding.precompute_freqs_cis`` output.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated query and key, cast back to the
        dtypes of ``xq`` and ``xk``.
    """
    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)  # B S H D -> B S H D/2 1 2
    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)  # B S H D -> B S H D/2 1 2
    freqs_cis = reshape_for_broadcast(
        freqs_cis, xq_, seq_dim
    ).float()  # S D/2 2 2 -> 1 S 1 D/2 2 2
    # Summing the last axis performs the per-pair matrix-vector product; indexing
    # from the end keeps this valid for any number of leading dimensions.
    xq_out = (xq_ * freqs_cis).sum(-1).flatten(-2)
    xk_out = (xk_ * freqs_cis).sum(-1).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RotaryEmbedding(torch.nn.Module):
    """
    RotaryEmbedding Module.

    Applies rotary position embeddings using real-valued 2x2 rotation matrices,
    with optional wavelength-band frequency rescaling (see ``apply_scaling``).

    Args:
        theta (float): Base of the frequency progression.
        head_dim (int): Per-head feature dimension; ``head_dim // 2`` frequencies.
        max_seqlen (int): Number of positions precomputed by ``reset_parameters``.
        scale_factor (int): Divisor applied to long-wavelength frequencies;
            ``1`` disables rescaling entirely.
        low_freq_factor (int): ``old_context_len / low_freq_factor`` is the
            wavelength above which frequencies are fully scaled.
        high_freq_factor (int): ``old_context_len / high_freq_factor`` is the
            wavelength below which frequencies are left unchanged.
        old_context_len (int): Context length the band edges are derived from.

    Note:
        The ``freqs_cis`` buffer read by ``forward`` is registered by
        ``reset_parameters``, not by ``__init__``, so ``reset_parameters`` must be
        called before the first forward pass.
    """

    def __init__(
        self,
        theta: float,
        head_dim: int,
        max_seqlen: int = 1024,
        scale_factor: int = 1,
        low_freq_factor: int = 1,
        high_freq_factor: int = 32,
        old_context_len: int = 8192,
    ):
        super().__init__()
        self.theta = theta
        self.head_dim = head_dim
        self.max_seqlen = max_seqlen
        self.scale_factor = scale_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.old_context_len = old_context_len
        if scale_factor != 1:
            # Band edges are only needed (and only defined) when scaling is enabled.
            self.low_freq_wavelen = old_context_len / low_freq_factor
            self.high_freq_wavelen = old_context_len / high_freq_factor
            assert self.low_freq_wavelen >= self.high_freq_wavelen

    def reset_parameters(self):
        """
        (Re)compute rotation matrices and register them as the ``freqs_cis`` buffer.

        The buffer has shape ``(1, max_seqlen, 1, head_dim // 2, 2, 2)`` so it
        broadcasts against ``(B, L, H, head_dim // 2, 1, 2)`` inputs in ``forward``.
        """
        freqs_cis = self.precompute_freqs_cis(
            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
        )
        S, D, _, _ = freqs_cis.shape
        # S D 2 2 -> 1 S 1 D 2 2
        freqs_cis = freqs_cis.view(1, S, 1, D, 2, 2)
        # persistent=False: the buffer is not included in the state_dict.
        self.register_buffer(
            "freqs_cis",
            freqs_cis,
            persistent=False,
        )

    def apply_scaling(self, freqs):
        """
        Rescale rotary frequencies according to their wavelength band.

        With ``scale_factor == 1`` ``freqs`` is returned unchanged (same object).
        Otherwise, element-wise with ``wavelen = 2 * pi / freq``:

        * ``wavelen < high_freq_wavelen``: the frequency is kept;
        * ``wavelen > low_freq_wavelen``: it is divided by ``scale_factor``;
        * in between: linear interpolation between the two, weighted by
          ``smooth = (old_context_len / wavelen - low_freq_factor)
          / (high_freq_factor - low_freq_factor)``.

        Args:
            freqs (torch.Tensor): Tensor of positive frequencies.

        Returns:
            torch.Tensor: New tensor with the rescaled frequencies, with the same
            shape, dtype and device as ``freqs``.
        """
        if self.scale_factor == 1:
            return freqs
        wavelen = 2 * math.pi / freqs
        keep = wavelen < self.high_freq_wavelen
        fully_scaled = wavelen > self.low_freq_wavelen
        mid = ~(keep | fully_scaled)
        if bool(mid.any()):
            # Interpolation divides by (high - low); mirror the original guard.
            assert self.low_freq_wavelen != self.high_freq_wavelen
        # Computed for every element, but only selected where `mid` holds, so any
        # inf/nan from a zero denominator outside that band is discarded.
        smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
            self.high_freq_factor - self.low_freq_factor
        )
        interpolated = (1 - smooth) * freqs / self.scale_factor + smooth * freqs
        return torch.where(
            keep,
            freqs,
            torch.where(fully_scaled, freqs / self.scale_factor, interpolated),
        )

    def precompute_freqs_cis(
        self,
        dim: int,
        end: int,
        theta: float = 10000.0,
    ):
        """
        Precompute per-position 2x2 rotation matrices for rotary embeddings.

        Frequencies ``1 / theta ** (2 * i / dim)`` for ``i in range(dim // 2)`` are
        passed through ``apply_scaling`` and multiplied by every position in
        ``range(end)``. For each resulting angle ``a`` the matrix
        ``[[cos a, -sin a], [sin a, cos a]]`` is stored. Despite the ``cis`` name,
        the result is a real float32 tensor, not a complex one.

        Args:
            dim (int): Feature dimension; ``dim // 2`` frequencies are produced.
            end (int): Number of positions to precompute (``0 .. end - 1``).
            theta (float, optional): Base of the frequency progression.
                Defaults to 10000.0.

        Returns:
            torch.Tensor: Rotation matrices of shape ``(end, dim // 2, 2, 2)``.
        """
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        freqs = self.apply_scaling(freqs)
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        cos, sin = freqs.cos(), freqs.sin()
        return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)

    def forward(self, x: torch.Tensor, bhle: bool = False, **kwargs):
        """
        Apply the rotary embedding to ``x``.

        Args:
            x (torch.Tensor): Input of shape ``(B, L, H, E)``, or ``(B, H, L, E)``
                when ``bhle`` is True. ``E`` must equal ``head_dim`` and ``L`` must
                not exceed ``max_seqlen``.
            bhle (bool): Whether ``x`` uses the ``(B, H, L, E)`` layout.
            **kwargs: Accepted and ignored.

        Returns:
            torch.Tensor: Rotated tensor with the same shape, layout and dtype as ``x``.

        Raises:
            ValueError: If the sequence length exceeds the precomputed table.
        """
        if bhle:
            x = x.transpose(1, 2)  # (B H L E) -> (B L H E)
        seqlen = x.size(1)
        table_len = self.freqs_cis.size(1)
        if seqlen > table_len:
            # Slicing would silently truncate the table, then either fail with an
            # opaque broadcast error or (table_len == 1) reuse position 0 for all.
            raise ValueError(
                f"sequence length {seqlen} exceeds precomputed max_seqlen {table_len}"
            )
        x_ = x.reshape(*x.shape[:-1], -1, 1, 2)  # B L H E -> B L H E/2 1 2
        x_out = (x_ * self.freqs_cis[:, :seqlen]).sum(-1).flatten(-2)
        if bhle:
            x_out = x_out.transpose(1, 2)
        return x_out.type_as(x)