import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Reference implementation used by the tests below for cross-checking.
# Assumption: Meta's llama repo (llama/model.py) is importable here; when it
# is not, the tests skip the reference comparisons.
try:
    import llama.model as llama
except ImportError:
    llama = None

def compute_theta(dim: int, base: float = 10000.0, device: torch.device = torch.device('cpu')) -> torch.Tensor:
    """
    Compute the theta angle values used by rotary position embedding (RoPE).

    Args:
    - dim (int): Dimension of the embedding vectors (must be even).
    - base (float): Base frequency parameter, defaults to 10000.0.
    - device (torch.device): Device to compute on, defaults to CPU.

    Returns:
    - torch.Tensor: 1D tensor of theta values, shape [dim / 2].
    """
    if dim % 2 != 0:
        raise ValueError("Embedding dimension dim must be even")
    # i = 1, 2, ..., dim/2 indexes the dim/2 two-dimensional rotation planes
    i = torch.arange(1, (dim // 2) + 1, dtype=torch.float32, device=device)
    # theta_i = base^(-2(i-1)/dim): the rotation frequency of plane i
    theta_i = base ** (-2 * (i - 1) / dim)
    return theta_i
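
# Worked example (illustrative): with dim = 8 and base = 10000, compute_theta
# returns [10000^0, 10000^(-1/4), 10000^(-1/2), 10000^(-3/4)]
# = [1.0, 0.1, 0.01, 0.001]; the first rotation plane has the highest
# frequency and the last the lowest.
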
def precompute_freqs_cis(dim: int, seq_len: int, base: float = 10000.0, device: torch.device = torch.device('cpu')):
    theta = compute_theta(dim, base, device)
    # m: position indices 0, 1, ..., seq_len - 1
    m = torch.arange(seq_len, device=device)
    # m_theta[m, i] = m * theta_i, the rotation angle of plane i at position m
    m_theta = torch.outer(m, theta)
    # polar(1, angle) = cos(angle) + j*sin(angle) = e^(j*angle)
    freqs_cis = torch.polar(torch.ones_like(m_theta), m_theta)
    return freqs_cis
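
# Sanity check (illustrative): row 0 of freqs_cis is all 1 + 0j, since
# e^(j*0) = 1 (no rotation at position 0), and freqs_cis[1, 0] equals
# cos(1) + j*sin(1) because theta_1 = base^0 = 1.0 for any base.
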
def reshape_for_broadcast(freqs_cis, x):
    ndim = x.ndim
    assert ndim >= 2
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), \
        "freqs_cis must match the seq_len (dim 1) and last dimension of x"
    # Keep the seq_len axis (dim 1) and the last axis; set every other axis to 1
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
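
# Shape example (illustrative): for xq_complex of shape [2, 5, 4, 8] and
# freqs_cis of shape [5, 8], this view yields [1, 5, 1, 8], which then
# broadcasts over the batch and head axes in the complex multiply.
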
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """
    Args:
    - xq (torch.Tensor): Queries, effectively W_q times the token embeddings,
      i.e. the output of the previous linear layer, shape [batch_size, seq_len, n_heads, head_dim]
    - xk (torch.Tensor): Keys, effectively W_k times the token embeddings,
      i.e. the output of the previous linear layer, shape [batch_size, seq_len, n_heads, head_dim]
    - freqs_cis (torch.Tensor): Complex frequency tensor, shape [seq_len, head_dim // 2]

    Returns:
    - Tuple[torch.Tensor, torch.Tensor]: The rotary-encoded query and key tensors
    """
    # Group the last dimension into head_dim // 2 pairs and view each pair
    # (x1, x2) as the complex number x1 + j*x2 (view_as_complex needs float32)
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Reshape freqs_cis to [1, seq_len, 1, head_dim // 2] for broadcasting
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex)

    # The complex multiply rotates each pair by its angle m * theta_i;
    # view_as_real + flatten(3) restores [bs, seq_len, n_heads, head_dim]
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)
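
# Why the complex multiply is a rotation: (x1 + j*x2) * e^(j*m*theta)
# = (x1*cos(m*theta) - x2*sin(m*theta)) + j*(x1*sin(m*theta) + x2*cos(m*theta)),
# i.e. the 2D rotation matrix [[cos, -sin], [sin, cos]] applied to the pair
# (x1, x2), which is exactly the RoPE rotation for one channel pair.
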
class Attention(nn.Module):
    """Compute scaled dot-product attention.

    Query : the tokens we are currently focusing on (decoder)
    Key   : every token, checked for its relationship with the Query (encoder)
    Value : every token, paired with the Keys (encoder)
    """

    def __init__(self, dim, max_seq_len, n_heads):
        super().__init__()
        # Query, key, value and output projections, all dim -> dim
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

        self.dim = dim
        self.max_seq_len = max_seq_len
        self.n_heads = n_heads
        self.head_dim = dim // n_heads

    def forward(self, x: torch.Tensor, start_pos=0, inference=True, mask=None):
        bs, seq_len, dim = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # Split the model dimension into n_heads heads of size head_dim
        xq = xq.view(bs, seq_len, self.n_heads, self.head_dim)
        xk = xk.view(bs, seq_len, self.n_heads, self.head_dim)
        xv = xv.view(bs, seq_len, self.n_heads, self.head_dim)

        if inference:
            # Rotate queries and keys by their absolute positions; start_pos
            # offsets into the table during incremental decoding (the table is
            # rebuilt on every call here, but could be cached in __init__)
            freqs_cis = precompute_freqs_cis(self.head_dim, self.max_seq_len * 2)
            freqs_cis = freqs_cis[start_pos : start_pos + seq_len]
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

        # [bs, seq_len, n_heads, head_dim] -> [bs, n_heads, seq_len, head_dim]
        queries = xq.transpose(1, 2)
        keys = xk.transpose(1, 2)
        values = xv.transpose(1, 2)

        # Scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V
        scores = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(queries)
        attn_output = torch.matmul(attn_weights, values)

        # Merge the heads back to [bs, seq_len, dim] and apply the output projection
        output = attn_output.transpose(1, 2).contiguous().view(bs, seq_len, -1)
        return self.out(output)
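
# Usage sketch (hypothetical caller, not part of the class): for causal
# self-attention, pass a lower-triangular mask so each position can only
# attend to itself and earlier positions, e.g.
#   attn = Attention(dim=512, max_seq_len=128, n_heads=8)
#   mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))  # broadcasts over bs, heads
#   out = attn(x, mask=mask)  # positions where mask == 0 get score -1e9
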
def test_compute_theta():
    dim = 8
    theta = compute_theta(dim)
    assert theta.shape == (dim // 2,), f"Expected shape {(dim // 2,)}, got {theta.shape}"
    print("test_compute_theta passed.")

def test_precompute_freqs_cis():
    dim = 8
    seq_len = 10
    freqs_cis = precompute_freqs_cis(dim, seq_len)
    assert freqs_cis.shape == (seq_len, dim // 2), f"Expected shape {(seq_len, dim // 2)}, got {freqs_cis.shape}"

    # Cross-check against the reference implementation when it is importable
    if llama is not None:
        freqs_cis_llama = llama.precompute_freqs_cis(dim, seq_len)
        assert torch.allclose(freqs_cis, freqs_cis_llama, atol=1e-4)
    print("test_precompute_freqs_cis passed.")

def test_apply_rotary_emb():
    batch_size = 2
    seq_len = 5
    n_heads = 2
    head_dim = 4

    xq = torch.randn(batch_size, seq_len, n_heads, head_dim)
    xk = torch.randn(batch_size, seq_len, n_heads, head_dim)
    freqs_cis = precompute_freqs_cis(head_dim, seq_len)

    xq_out, xk_out = apply_rotary_emb(xq, xk, freqs_cis)

    assert xq_out.shape == xq.shape, f"Expected xq_out shape {xq.shape}, got {xq_out.shape}"
    assert xk_out.shape == xk.shape, f"Expected xk_out shape {xk.shape}, got {xk_out.shape}"

    # Compare values against the reference implementation when it is importable
    if llama is not None:
        xq_ref, xk_ref = llama.apply_rotary_emb(xq, xk, freqs_cis)
        assert torch.allclose(xq_out, xq_ref, atol=1e-4)
        assert torch.allclose(xk_out, xk_ref, atol=1e-4)
        print(f"test_apply_rotary_emb passed, xq_out and xq_ref [0,0,0,0]: {xq_out[0,0,0,0]} {xq_ref[0,0,0,0]}.")
    else:
        print("test_apply_rotary_emb passed (reference comparison skipped).")

def test_attention():
    batch_size = 2
    seq_len = 5
    dim = 8
    n_heads = 2
    x = torch.randn(batch_size, seq_len, dim)
    attn = Attention(dim=dim, max_seq_len=seq_len, n_heads=n_heads)
    output = attn(x)
    assert output.shape == (batch_size, seq_len, dim), f"Expected output shape {(batch_size, seq_len, dim)}, got {output.shape}"
    print("test_attention passed.")

if __name__ == "__main__":
    test_compute_theta()
    test_precompute_freqs_cis()
    test_apply_rotary_emb()
    test_attention()