File size: 5,747 Bytes
d5cfa8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
from typing import Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def resize(x_tensor, new_shape):
    """Linearly resample each row of a 2-D tensor to length ``new_shape``.

    ``F.interpolate`` with ``mode='linear'`` expects a 3-D (batch, channel,
    length) input, so a batch axis is added before the call and stripped
    off again afterwards.
    """
    batched = x_tensor.unsqueeze(0)
    resized = F.interpolate(batched, size=new_shape, mode='linear')
    return resized.squeeze(0)
def resample(old: torch.Tensor, new_patch_len: int):
    """Resample a (d_model, patch_size) kernel matrix to a new patch length.

    Builds the linear-interpolation operator by resizing an identity basis,
    takes its Moore-Penrose pseudo-inverse, and applies it to the kernels.
    The result is rescaled by sqrt(new_len / old_len) so kernel magnitudes
    stay comparable across patch sizes. If the patch length already matches,
    the input is returned unchanged.
    """
    assert old.dim() == 2, "the size of input tensor should be (d_model, patch_size)"
    old_patch_len = old.size(1)
    if old_patch_len == new_patch_len:
        return old
    kernels = old.T  # (patch_size, d_model)
    scale = math.sqrt(new_patch_len / old_patch_len)
    basis = torch.eye(old_patch_len, dtype=torch.get_default_dtype(), device=old.device)
    # (new_patch_len, old_patch_len): each column is one resized basis vector
    resize_mat = resize(basis, new_patch_len).T
    # pseudo-inverse of the (old -> new) interpolation operator
    resize_mat_pinv = torch.linalg.pinv(resize_mat.T)
    resampled = (resize_mat_pinv @ kernels) * scale
    return resampled.T  # back to (d_model, new_patch_len)
def RoPE(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Position Embedding (RoPE) to the query and key tensors.

    Args:
        query (torch.Tensor): Query tensor with shape (bs, head, max_len, output_dim).
        key (torch.Tensor): Key tensor with shape (bs, head, max_len, output_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
    """
    bs, n_heads, seq_len, dim = query.shape
    # Shared sin/cos table; even slots hold sin, odd slots hold cos.
    table = sinusoidal_position_embedding(bs, n_heads, seq_len, dim, query.device, factor=1)
    cos_part = table[..., 1::2].repeat_interleave(2, dim=-1)
    sin_part = table[..., ::2].repeat_interleave(2, dim=-1)

    def _rotate_half(t: torch.Tensor) -> torch.Tensor:
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        return torch.stack((-t[..., 1::2], t[..., ::2]), dim=-1).reshape(t.shape)

    rotated_q = query * cos_part + _rotate_half(query) * sin_part
    rotated_k = key * cos_part + _rotate_half(key) * sin_part
    return rotated_q, rotated_k
def RoPE_decoder(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Position Embedding (RoPE) to decoder query and key tensors.

    A single position table of length k_max_len + q_max_len is built; the key
    takes the first k_max_len positions and the query takes the last
    q_max_len, so query positions continue directly after the key positions.

    Args:
        query (torch.Tensor): Query tensor with shape (bs, head, q_max_len, output_dim).
        key (torch.Tensor): Key tensor with shape (bs, head, k_max_len, output_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
    """
    bs, n_heads, q_len, dim = query.shape
    k_len = key.shape[2]
    # Shared sin/cos table; even slots hold sin, odd slots hold cos.
    table = sinusoidal_position_embedding(bs, n_heads, k_len + q_len, dim, query.device,
                                          factor=1)
    cos_part = table[..., 1::2].repeat_interleave(2, dim=-1)
    sin_part = table[..., ::2].repeat_interleave(2, dim=-1)

    def _rotate_half(t: torch.Tensor) -> torch.Tensor:
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        return torch.stack((-t[..., 1::2], t[..., ::2]), dim=-1).reshape(t.shape)

    rotated_q = query * cos_part[:, :, -q_len:, :] + _rotate_half(query) * sin_part[:, :, -q_len:, :]
    rotated_k = key * cos_part[:, :, :k_len, :] + _rotate_half(key) * sin_part[:, :, :k_len, :]
    return rotated_q, rotated_k
def sinusoidal_position_embedding(
        batch_size: int,
        num_heads: int,
        max_len: int,
        output_dim: int,
        device: torch.device,
        factor: float = 1.0
) -> torch.Tensor:
    """
    Generate sinusoidal position embeddings with interleaved sin/cos values
    (even feature slots hold sin, odd slots hold cos).

    Args:
        batch_size (int): Batch size.
        num_heads (int): Number of attention heads.
        max_len (int): Maximum sequence length.
        output_dim (int): Output dimension.
        device (torch.device): Device type.
        factor (float, optional): Scaling factor. Defaults to 1.0.

    Returns:
        torch.Tensor: Sinusoidal position embedding tensor with shape (bs, head, max_len, output_dim).
    """
    # Fractional position grid; with factor == 1 this is just 0..max_len-1.
    positions = torch.arange(0, max_len * factor, 1 / factor, dtype=torch.float).unsqueeze(-1)
    # Inverse frequencies 10000^(-2i/d) for each feature pair.
    pair_ids = torch.arange(0, output_dim // 2, dtype=torch.float)
    inv_freq = torch.pow(10000, -2 * pair_ids / output_dim)
    angles = positions * inv_freq
    # Interleave sin and cos along the last axis.
    table = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)
    # Broadcast to (batch, heads, len, dim) by prepending repeated dims.
    table = table.repeat((batch_size, num_heads, *([1] * table.dim())))
    table = table.reshape(batch_size, num_heads, -1, output_dim).to(device)
    # For factor > 1, subsample the oversampled table back to max_len rows.
    if factor > 1.0:
        keep = torch.linspace(0, table.shape[2] - 1, max_len).long()
        table = table[:, :, keep, :]
    return table
def causal_attention_mask(seq_length, device=None, dtype=None):
    """
    Build an additive causal (look-ahead) attention mask.

    Entries strictly above the diagonal are -inf (future positions are
    blocked); everything on or below the diagonal is 0, so the mask can be
    added to attention scores before the softmax.

    Args:
        seq_length (int): Sequence length L.
        device (torch.device, optional): Device to allocate the mask on.
            Defaults to the current default device (backward compatible).
        dtype (torch.dtype, optional): Mask dtype. Defaults to the default
            floating-point dtype (backward compatible).

    Returns:
        torch.Tensor: Mask with shape (1, 1, seq_length, seq_length).
    """
    # torch.full avoids the ones * -inf intermediate and lets us pass
    # device/dtype directly.
    blocked = torch.full((seq_length, seq_length), float('-inf'), device=device, dtype=dtype)
    mask = torch.triu(blocked, diagonal=1)
    # Add broadcast dims for (batch, head).
    return mask.unsqueeze(0).unsqueeze(0)
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` for use inside ``nn.Sequential``.

    Args:
        *dims: The two dimensions to swap, forwarded to ``Tensor.transpose``.
        contiguous: When True, also make the transposed result contiguous
            in memory.
    """
    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous
    def forward(self, x):
        out = x.transpose(*self.dims)
        return out.contiguous() if self.contiguous else out
|