# (removed: web-viewer scrape artifacts — "Spaces: Running on Zero" header,
#  file-size/commit banner, and line-number gutter; not part of the source)
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from typing import Tuple
import torch
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor, seq_dim: int):
    """
    View the rotation table so it broadcasts element-wise against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[seq_dim], x.shape[-3], 2, 2)``.
    The returned view keeps those two dimensions in place, inserts singleton
    dimensions everywhere else, and ends in the trailing ``(2, 2)``.

    Args:
        freqs_cis (torch.Tensor): Rotation table to reshape.
        x (torch.Tensor): Target tensor whose shape drives the broadcast.
        seq_dim (int): Index of the sequence dimension in ``x``.

    Returns:
        torch.Tensor: ``freqs_cis`` viewed with a broadcast-compatible shape.
    """
    ndim = x.ndim
    assert 0 <= seq_dim < ndim
    assert freqs_cis.shape == (
        x.shape[seq_dim],
        x.shape[-3],
        2,
        2,
    ), f"freqs_cis vs x: {(freqs_cis.shape, x.shape)}"
    # Keep the sequence dim and the pair-count dim (ndim - 3); flatten the
    # rest of the leading dims to 1 so multiplication broadcasts over them.
    broadcast_shape = []
    for i, d in enumerate(x.shape[:-2]):
        broadcast_shape.append(d if i in (seq_dim, ndim - 3) else 1)
    broadcast_shape += [2, 2]
    return freqs_cis.view(*broadcast_shape)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    seq_dim: int,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors by the precomputed 2x2 rotation matrices.

    Args:
        xq: Query tensor, laid out B S H D with D even.
        xk: Key tensor with the same trailing layout as ``xq``.
        seq_dim: Index of the sequence dimension (in the pair-split view).
        freqs_cis: Rotation table of shape (S, D/2, 2, 2).

    Returns:
        The rotated ``(xq, xk)``, each cast back to its input dtype.
    """
    # View the head dim as D/2 coordinate pairs: B S H D -> B S H D/2 1 2.
    q_pairs = xq.reshape(*xq.shape[:-1], -1, 1, 2)
    k_pairs = xk.reshape(*xk.shape[:-1], -1, 1, 2)
    # S D/2 2 2 -> 1 S 1 D/2 2 2 so the table broadcasts over batch/heads.
    rot = reshape_for_broadcast(freqs_cis, q_pairs, seq_dim).float()
    # Per-pair matrix-vector product, then merge the pairs back into D.
    q_rot = (q_pairs * rot).sum(5).flatten(3)
    k_rot = (k_pairs * rot).sum(5).flatten(3)
    return q_rot.type_as(xq), k_rot.type_as(xk)
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary positional embedding (RoPE) module.

    Precomputes one 2x2 rotation matrix per (position, frequency) pair and
    applies it to the last dimension of the input, treated as coordinate
    pairs. Call ``reset_parameters()`` after construction to build the
    ``freqs_cis`` buffer before the first ``forward``.
    """

    def __init__(
        self,
        theta: float,
        head_dim: int,
        max_seqlen: int = 1024,
        scale_factor: int = 1,
        low_freq_factor: int = 1,
        high_freq_factor: int = 32,
        old_context_len: int = 8192,
    ):
        super().__init__()
        self.theta = theta
        self.head_dim = head_dim
        self.max_seqlen = max_seqlen
        self.scale_factor = scale_factor
        self.low_freq_factor = low_freq_factor
        self.high_freq_factor = high_freq_factor
        self.old_context_len = old_context_len
        if scale_factor != 1:
            # Wavelength cutoffs delimiting the three scaling bands; only
            # needed when frequency scaling is actually enabled.
            self.low_freq_wavelen = old_context_len / low_freq_factor
            self.high_freq_wavelen = old_context_len / high_freq_factor
            assert self.low_freq_wavelen >= self.high_freq_wavelen

    def reset_parameters(self):
        """(Re)build the ``freqs_cis`` rotation buffer for all positions."""
        table = self.precompute_freqs_cis(
            dim=self.head_dim, end=self.max_seqlen, theta=self.theta
        )
        seq_len, pair_count, _, _ = table.shape
        # S D 2 2 -> 1 S 1 D 2 2 so it broadcasts over batch and heads.
        self.register_buffer(
            "freqs_cis",
            table.view(1, seq_len, 1, pair_count, 2, 2),
            persistent=False,
        )

    def apply_scaling(self, freqs):
        """Rescale base frequencies for long-context use (no-op if scale_factor == 1)."""
        if self.scale_factor == 1:
            return freqs
        rescaled = []
        for f in freqs:
            wavelen = 2 * math.pi / f
            if wavelen < self.high_freq_wavelen:
                # High-frequency band: keep unchanged.
                rescaled.append(f)
                continue
            if wavelen > self.low_freq_wavelen:
                # Low-frequency band: shrink uniformly by the scale factor.
                rescaled.append(f / self.scale_factor)
                continue
            # Middle band: blend smoothly between scaled and unscaled.
            assert self.low_freq_wavelen != self.high_freq_wavelen
            smooth = (self.old_context_len / wavelen - self.low_freq_factor) / (
                self.high_freq_factor - self.low_freq_factor
            )
            rescaled.append((1 - smooth) * f / self.scale_factor + smooth * f)
        return torch.tensor(rescaled, dtype=freqs.dtype, device=freqs.device)

    def precompute_freqs_cis(
        self,
        dim: int,
        end: int,
        theta: float = 10000.0,
    ):
        """
        Precompute the per-position rotation matrices.

        For each of the ``dim // 2`` frequencies and each position up to
        ``end``, builds the row-major 2x2 rotation ``[[cos, -sin], [sin, cos]]``.

        Args:
            dim (int): Head dimension (frequencies are derived per pair).
            end (int): Number of positions to precompute.
            theta (float, optional): Base for the geometric frequency decay.
                Defaults to 10000.0.

        Returns:
            torch.Tensor: Tensor of shape ``(end, dim // 2, 2, 2)``.
        """
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        base_freqs = self.apply_scaling(1.0 / (theta ** exponents))
        positions = torch.arange(end, device=base_freqs.device)
        angles = torch.outer(positions, base_freqs).float()
        cos, sin = angles.cos(), angles.sin()
        return torch.stack((cos, -sin, sin, cos), dim=-1).view(*angles.size(), 2, 2)

    def forward(self, x: torch.Tensor, bhle: bool = False, **kwargs):
        """Apply rotary embedding to ``x``; ``bhle=True`` means input is (B, H, L, E)."""
        if bhle:
            x = x.transpose(1, 2)  # (B H L E) -> (B L H E)
        seqlen = x.size(1)
        # B L H E -> B L H E/2 1 2: expose coordinate pairs for rotation.
        pairs = x.reshape(*x.shape[:-1], -1, 1, 2)
        rotated = (pairs * self.freqs_cis[:, :seqlen]).sum(5).flatten(3)
        if bhle:
            rotated = rotated.transpose(1, 2)
        return rotated.type_as(x)