ray-006's picture
Upload 43 files
fc605f9 verified
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from typing import Tuple
import torch
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
    """
    View a rotation table so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[seq_dim], x.shape[-3], 2, 2)``.
    The result is a view with the same rank as ``x``: the sequence axis and
    the third-from-last axis keep their sizes, the last two axes are ``2, 2``
    and every other axis is collapsed to size 1.

    Args:
        freqs_cis (torch.Tensor): Rotation table to be reshaped.
        x (torch.Tensor): Tensor the table will be multiplied with.
        seq_dim (int): Non-negative index of the sequence axis of ``x``.

    Returns:
        torch.Tensor: View of ``freqs_cis`` that broadcasts against ``x``.
    """
    rank = x.ndim
    assert 0 <= seq_dim < rank
    assert freqs_cis.shape == (
        x.shape[seq_dim],
        x.shape[-3],
        2,
        2,
    ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"

    # Only the sequence axis and the frequency axis (third from last) keep
    # their extent; all other leading axes become broadcast axes of size 1.
    kept_axes = (seq_dim, rank - 3)
    target = []
    for axis, size in enumerate(x.shape[:-2]):
        target.append(size if axis in kept_axes else 1)
    target.extend((2, 2))
    return freqs_cis.view(*target)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors with a table of 2x2 rotation matrices.

    The last axis of each input is split into consecutive pairs, every pair
    is multiplied by the rotation matrix for its position and frequency, and
    the pairs are flattened back into the original last axis.

    Args:
        xq (torch.Tensor): Query tensor, laid out as ``B S H D`` per the
            shape comments below.
        xk (torch.Tensor): Key tensor with the same layout as ``xq``.
        seq_dim (int): Index of the sequence axis.
        freqs_cis (torch.Tensor): Rotation table of shape ``S D/2 2 2``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated ``xq`` and ``xk``, each
        cast back to the dtype of its input.
    """

    def _as_pairs(t: torch.Tensor) -> torch.Tensor:
        # B S H D -> B S H D/2 1 2
        return t.reshape(*t.shape[:-1], -1, 1, 2)

    q_pairs = _as_pairs(xq)
    k_pairs = _as_pairs(xk)

    # S D/2 2 2 -> 1 S 1 D/2 2 2, with the table held in float32.
    rotation = reshape_for_broadcast(freqs_cis, q_pairs, seq_dim).float()

    # Broadcasting (1 x 2) pairs against (2 x 2) matrices and summing over
    # the last axis is a per-pair matrix-vector product.
    q_rot, k_rot = ((pairs * rotation).sum(5).flatten(3) for pairs in (q_pairs, k_pairs))
    return q_rot.type_as(xq), k_rot.type_as(xk)
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embedding applied through cached 2x2 rotation matrices.

    The rotation table lives in the non-persistent buffer ``freqs_cis`` with
    shape ``(1, max_seqlen, 1, head_dim // 2, 2, 2)``. It is (re)built by
    :meth:`reset_parameters`; if that method has not been called before the
    first :meth:`forward`, the table is built on demand.

    Args:
        theta (float): Base of the geometric progression of frequencies.
        head_dim (int): Size of the feature axis that gets rotated; it yields
            ``head_dim // 2`` frequencies.
        max_seqlen (int): Number of positions stored in the table.
        scale_factor (int): Divisor applied to long-wavelength frequencies.
            ``1`` disables frequency scaling altogether.
        low_freq_factor (int): ``old_context_len / low_freq_factor`` is the
            wavelength above which frequencies are fully scaled.
        high_freq_factor (int): ``old_context_len / high_freq_factor`` is the
            wavelength below which frequencies are left untouched.
        old_context_len (int): Reference context length used to derive the
            two wavelength thresholds.
    """

    def __init__(
        self,
        theta: float,
        head_dim: int,
        max_seqlen: int = 1024,
        scale_factor: int = 1,
        low_freq_factor: int = 1,
        high_freq_factor: int = 32,
        old_context_len: int = 8192,
    ):
        super().__init__()
        self.theta = theta
        self.head_dim = head_dim
        self.max_seqlen = max_seqlen
        self.scale_factor = scale_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.old_context_len = old_context_len
        if scale_factor != 1:
            # Thresholds are only needed (and only defined) when scaling is on.
            self.low_freq_wavelen = old_context_len / low_freq_factor
            self.high_freq_wavelen = old_context_len / high_freq_factor
            assert self.low_freq_wavelen >= self.high_freq_wavelen

    def reset_parameters(self):
        """Build the rotation table and (re)register it as ``freqs_cis``."""
        freqs_cis = self.precompute_freqs_cis(
            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
        )
        S, D, _, _ = freqs_cis.shape
        # S D 2 2 -> 1 S 1 D 2 2
        freqs_cis = freqs_cis.view(1, S, 1, D, 2, 2)
        self.register_buffer(
            "freqs_cis",
            freqs_cis,
            persistent=False,
        )

    def apply_scaling(self, freqs):
        """
        Rescale frequencies per wavelength band.

        With ``wavelen = 2 * pi / freq``:

        * ``wavelen < high_freq_wavelen``: frequency is kept unchanged;
        * ``wavelen > low_freq_wavelen``: frequency is divided by
          ``scale_factor``;
        * otherwise: linear blend between the two, weighted by ``smooth``.

        Args:
            freqs (torch.Tensor): 1-D tensor of frequencies.

        Returns:
            torch.Tensor: Frequencies with the same dtype and device as
            ``freqs``. ``freqs`` itself is returned when ``scale_factor == 1``.
        """
        if self.scale_factor == 1:
            return freqs

        wavelen = 2 * math.pi / freqs
        keep = wavelen < self.high_freq_wavelen
        shrink = wavelen > self.low_freq_wavelen
        shrunk = freqs / self.scale_factor

        in_between = ~(keep | shrink)
        if not bool(in_between.any()):
            # Every element that is not kept falls in the fully-scaled band.
            return torch.where(keep, freqs, shrunk)

        # The blend divides by (high_freq_factor - low_freq_factor); equal
        # thresholds would make that a division by zero.
        assert self.low_freq_wavelen != self.high_freq_wavelen
        smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
            self.high_freq_factor - self.low_freq_factor
        )
        blended = (1 - smooth) * freqs / self.scale_factor + smooth * freqs
        # Same priority as an if / elif / else chain: keep, then shrink.
        return torch.where(keep, freqs, torch.where(shrink, shrunk, blended))

    def precompute_freqs_cis(
        self,
        dim: int,
        end: int,
        theta: float = 10000.0,
    ):
        """
        Precompute the table of 2x2 rotation matrices.

        Frequencies are ``1 / theta ** (2k / dim)`` for ``k`` in
        ``range(dim // 2)``, passed through :meth:`apply_scaling`. For each
        position ``p`` in ``range(end)`` and each frequency ``f`` the entry is
        the real matrix ``[[cos(p*f), -sin(p*f)], [sin(p*f), cos(p*f)]]``.

        Args:
            dim (int): Size of the feature axis to rotate.
            end (int): Number of positions to precompute.
            theta (float, optional): Base of the frequency progression.
                Defaults to 10000.0.

        Returns:
            torch.Tensor: Real float tensor of shape ``(end, dim // 2, 2, 2)``.
        """
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        freqs = self.apply_scaling(freqs)
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        cos, sin = freqs.cos(), freqs.sin()
        return torch.stack((cos, -sin, sin, cos), dim=-1).view(*freqs.size(), 2, 2)

    def forward(self, x: torch.Tensor, bhle: bool = False, **kwargs):
        """
        Rotate consecutive feature pairs of ``x`` by their position.

        Args:
            x (torch.Tensor): Input laid out as ``(B, L, H, E)``, or
                ``(B, H, L, E)`` when ``bhle`` is true.
            bhle (bool): Whether ``x`` carries heads before the sequence
                axis. The output uses the same layout as the input.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            torch.Tensor: Rotated tensor with the shape and dtype of ``x``.

        Raises:
            ValueError: If the sequence is longer than the cached table.
        """
        if bhle:
            x = x.transpose(1, 2)  # (B H L E) -> (B L H E)
        seqlen = x.size(1)

        if getattr(self, "freqs_cis", None) is None:
            # reset_parameters() was never called: build the table now and
            # place it next to the input, since the module was not moved with it.
            self.reset_parameters()
            self.freqs_cis = self.freqs_cis.to(x.device)

        cached_len = self.freqs_cis.size(1)
        if seqlen > cached_len:
            # Without this check a table of length 1 would silently broadcast
            # the position-0 rotation over the whole sequence.
            raise ValueError(
                f"sequence length {seqlen} exceeds the {cached_len} positions "
                "cached by RotaryEmbedding; increase max_seqlen"
            )

        x_ = x.reshape(*x.shape[:-1], -1, 1, 2)  # B L H E -> B L H E/2 1 2
        x_out = (x_ * self.freqs_cis[:, :seqlen]).sum(5).flatten(3)
        if bhle:
            x_out = x_out.transpose(1, 2)
        return x_out.type_as(x)