Aurora / util_functions.py
ccloud0525
feat: "first commit"
b11fb36
from typing import Tuple
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def resize(x_tensor, new_shape):
return F.interpolate(x_tensor.unsqueeze(0), size=new_shape, mode='linear').squeeze(0)
def resample(old: torch.Tensor, new_patch_len: int):
assert old.dim() == 2, "the size of input tensor should be (d_model, patch_size)"
if old.size(1) == new_patch_len:
return old
old = old.T
old_shape = old.size(0)
factor = new_patch_len / old_shape
basis_vectors = torch.eye(old_shape, dtype=torch.get_default_dtype(), device=old.device)
resize_mat = resize(basis_vectors, new_patch_len).T
resize_mat_pinv = torch.linalg.pinv(resize_mat.T)
resampled_kernels = resize_mat_pinv @ old * math.sqrt(factor)
return resampled_kernels.T
def RoPE(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply Rotary Position Embedding (RoPE) to the query and key tensors.
Args:
query (torch.Tensor): Query tensor with shape (bs, head, max_len, output_dim).
key (torch.Tensor): Key tensor with shape (bs, head, max_len, output_dim).
Returns:
Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
"""
# Get the shape information of the input tensors
batch_size, num_heads, max_len, output_dim = query.shape
# Generate sinusoidal position embeddings
pos_emb = sinusoidal_position_embedding(batch_size, num_heads, max_len, output_dim, query.device, factor=1)
# Extract cosine and sine position embeddings
cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)
# Apply RoPE to the query tensor
query_rot = torch.stack([-query[..., 1::2], query[..., ::2]], dim=-1).reshape(query.shape)
query = query * cos_pos + query_rot * sin_pos
# Apply RoPE to the key tensor
key_rot = torch.stack([-key[..., 1::2], key[..., ::2]], dim=-1).reshape(key.shape)
key = key * cos_pos + key_rot * sin_pos
return query, key
def RoPE_decoder(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply Rotary Position Embedding (RoPE) to the query and key tensors in the decoder.
Args:
query (torch.Tensor): Query tensor with shape (bs, head, q_max_len, output_dim).
key (torch.Tensor): Key tensor with shape (bs, head, k_max_len, output_dim).
Returns:
Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
"""
# Get the shape information of the input tensors
batch_size, num_heads, q_max_len, output_dim = query.shape
_, _, k_max_len, _ = key.shape
# Generate sinusoidal position embeddings
pos_emb = sinusoidal_position_embedding(batch_size, num_heads, k_max_len + q_max_len, output_dim, query.device,
factor=1)
# Extract cosine and sine position embeddings
cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)
# Apply RoPE to the query tensor
query_rot = torch.stack([-query[..., 1::2], query[..., ::2]], dim=-1).reshape(query.shape)
query = query * cos_pos[:, :, -q_max_len:, :] + query_rot * sin_pos[:, :, -q_max_len:, :]
# Apply RoPE to the key tensor
key_rot = torch.stack([-key[..., 1::2], key[..., ::2]], dim=-1).reshape(key.shape)
key = key * cos_pos[:, :, :k_max_len, :] + key_rot * sin_pos[:, :, :k_max_len, :]
return query, key
def sinusoidal_position_embedding(
batch_size: int,
num_heads: int,
max_len: int,
output_dim: int,
device: torch.device,
factor: float = 1.0
) -> torch.Tensor:
"""
Generate sinusoidal position embeddings.
Args:
batch_size (int): Batch size.
num_heads (int): Number of attention heads.
max_len (int): Maximum sequence length.
output_dim (int): Output dimension.
device (torch.device): Device type.
factor (float, optional): Scaling factor. Defaults to 1.0.
Returns:
torch.Tensor: Sinusoidal position embedding tensor with shape (bs, head, max_len, output_dim).
"""
# Generate position indices
position = torch.arange(0, max_len * factor, 1 / factor, dtype=torch.float).unsqueeze(-1)
# Generate frequency indices
ids = torch.arange(0, output_dim // 2, dtype=torch.float)
theta = torch.pow(10000, -2 * ids / output_dim)
# Calculate position embeddings
embeddings = position * theta
embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)
# Expand dimensions to match batch size and number of attention heads
embeddings = embeddings.repeat((batch_size, num_heads, *([1] * len(embeddings.shape))))
embeddings = torch.reshape(embeddings, (batch_size, num_heads, -1, output_dim))
embeddings = embeddings.to(device)
# If the factor is greater than 1, perform interpolation
if factor > 1.0:
interpolation_indices = torch.linspace(0, embeddings.shape[2] - 1, max_len).long()
embeddings = embeddings[:, :, interpolation_indices, :]
return embeddings
def causal_attention_mask(seq_length):
mask = torch.triu(torch.ones(seq_length, seq_length) * float('-inf'), diagonal=1)
return mask.unsqueeze(0).unsqueeze(0)
class Transpose(nn.Module):
def __init__(self, *dims, contiguous=False):
super().__init__()
self.dims, self.contiguous = dims, contiguous
def forward(self, x):
if self.contiguous:
return x.transpose(*self.dims).contiguous()
else:
return x.transpose(*self.dims)