File size: 5,747 Bytes
d5cfa8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from typing import Tuple

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def resize(x_tensor, new_shape):
    """Linearly interpolate a 2-D tensor along its last dimension to length `new_shape`.

    `x_tensor` is treated as (channels, length); a leading batch dim is added
    so `F.interpolate` sees the 3-D (N, C, L) layout it requires, then removed.
    """
    batched = x_tensor.unsqueeze(0)
    resized = F.interpolate(batched, size=new_shape, mode='linear')
    return resized.squeeze(0)


def resample(old: torch.Tensor, new_patch_len: int):
    """Resample a (d_model, patch_size) kernel bank to a new patch length.

    Uses a pseudo-inverse of the linear-interpolation resize operator
    (FlexiViT-style PI-resize), scaled by sqrt(new/old) to preserve
    activation magnitude. Returns `old` unchanged when the length already
    matches.
    """
    assert old.dim() == 2, "the size of input tensor should be (d_model, patch_size)"
    if old.size(1) == new_patch_len:
        return old

    # Work in (patch_size, d_model) orientation.
    kernels = old.T
    src_len = kernels.size(0)
    scale = math.sqrt(new_patch_len / src_len)

    # Resize each basis vector to obtain the interpolation matrix, then
    # apply its pseudo-inverse to the kernels.
    eye = torch.eye(src_len, dtype=torch.get_default_dtype(), device=kernels.device)
    interp_mat = resize(eye, new_patch_len).T
    interp_pinv = torch.linalg.pinv(interp_mat.T)

    result = interp_pinv @ kernels * scale
    return result.T


def RoPE(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Position Embedding (RoPE) to the query and key tensors.

    Args:
        query (torch.Tensor): Query tensor with shape (bs, head, max_len, output_dim).
        key (torch.Tensor): Key tensor with shape (bs, head, max_len, output_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
    """
    bs, n_heads, seq_len, dim = query.shape
    pos_emb = sinusoidal_position_embedding(bs, n_heads, seq_len, dim, query.device, factor=1)

    # The generator interleaves (sin, cos) pairs, so cos sits at odd slots
    # and sin at even slots; duplicate each to cover both halves of a pair.
    cos_part = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_part = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    def _rotate_half(t: torch.Tensor) -> torch.Tensor:
        # Pairwise rotation: (x0, x1) -> (-x1, x0), interleaved back in place.
        return torch.stack([-t[..., 1::2], t[..., ::2]], dim=-1).reshape(t.shape)

    rotated_query = query * cos_part + _rotate_half(query) * sin_part
    rotated_key = key * cos_part + _rotate_half(key) * sin_part
    return rotated_query, rotated_key


def RoPE_decoder(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Position Embedding (RoPE) to the query and key tensors in the decoder.

    The query is positioned AFTER the key sequence: embeddings are generated
    for k_max_len + q_max_len positions, with the key taking the first
    k_max_len slots and the query the last q_max_len slots.

    Args:
        query (torch.Tensor): Query tensor with shape (bs, head, q_max_len, output_dim).
        key (torch.Tensor): Key tensor with shape (bs, head, k_max_len, output_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
    """
    bs, n_heads, q_len, dim = query.shape
    k_len = key.shape[2]

    pos_emb = sinusoidal_position_embedding(bs, n_heads, k_len + q_len, dim, query.device,
                                            factor=1)

    # cos lives at odd slots and sin at even slots of the interleaved table.
    cos_part = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_part = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    def _rotate_half(t: torch.Tensor) -> torch.Tensor:
        # Pairwise rotation: (x0, x1) -> (-x1, x0), interleaved back in place.
        return torch.stack([-t[..., 1::2], t[..., ::2]], dim=-1).reshape(t.shape)

    # Query uses the trailing positions; key uses the leading ones.
    rotated_query = (query * cos_part[:, :, -q_len:, :]
                     + _rotate_half(query) * sin_part[:, :, -q_len:, :])
    rotated_key = (key * cos_part[:, :, :k_len, :]
                   + _rotate_half(key) * sin_part[:, :, :k_len, :])
    return rotated_query, rotated_key


def sinusoidal_position_embedding(
        batch_size: int,
        num_heads: int,
        max_len: int,
        output_dim: int,
        device: torch.device,
        factor: float = 1.0
) -> torch.Tensor:
    """
    Generate sinusoidal position embeddings.

    Pairs of (sin, cos) values are interleaved along the last dimension:
    even slots hold sin(pos * theta_i), odd slots hold cos(pos * theta_i),
    with theta_i = 10000^(-2i / output_dim).

    Args:
        batch_size (int): Batch size.
        num_heads (int): Number of attention heads.
        max_len (int): Maximum sequence length.
        output_dim (int): Output dimension (assumed even so sin/cos pair up —
            TODO confirm callers always pass an even dim).
        device (torch.device): Device type.
        factor (float, optional): Scaling factor. Defaults to 1.0. When > 1,
            a denser position grid is generated and then subsampled back to
            max_len entries.

    Returns:
        torch.Tensor: Sinusoidal position embedding tensor with shape (bs, head, max_len, output_dim).
    """
    # Build everything directly on the target device instead of constructing
    # on CPU and copying at the end (avoids a host->device transfer).
    position = torch.arange(0, max_len * factor, 1 / factor, dtype=torch.float,
                            device=device).unsqueeze(-1)
    # Frequency indices: one theta per (sin, cos) pair.
    ids = torch.arange(0, output_dim // 2, dtype=torch.float, device=device)
    theta = torch.pow(10000, -2 * ids / output_dim)

    # Outer product of positions and frequencies, then interleave sin/cos.
    embeddings = position * theta
    embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1)

    # Broadcast over batch and head dimensions, flattening the trailing
    # (dim/2, 2) pair back into output_dim.
    embeddings = embeddings.repeat((batch_size, num_heads, *([1] * len(embeddings.shape))))
    embeddings = torch.reshape(embeddings, (batch_size, num_heads, -1, output_dim))

    # If the factor is greater than 1, subsample the denser grid back down to
    # max_len evenly spaced positions.
    if factor > 1.0:
        interpolation_indices = torch.linspace(0, embeddings.shape[2] - 1, max_len,
                                               device=device).long()
        embeddings = embeddings[:, :, interpolation_indices, :]

    return embeddings


def causal_attention_mask(seq_length):
    """Build an additive causal mask of shape (1, 1, seq_length, seq_length).

    Positions strictly above the diagonal are -inf (blocked); the diagonal
    and below are 0 (visible). Broadcastable over batch and head dims.
    """
    neg_inf = torch.full((seq_length, seq_length), float('-inf'))
    mask = torch.triu(neg_inf, diagonal=1)
    return mask[None, None, :, :]


class Transpose(nn.Module):
    """nn.Module wrapper around Tensor.transpose for use in Sequential stacks.

    Args:
        *dims: the two dimension indices to swap.
        contiguous: if True, force a contiguous copy after transposing.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        swapped = x.transpose(*self.dims)
        return swapped.contiguous() if self.contiguous else swapped