File size: 4,462 Bytes
ce78e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from einops import rearrange
import cv2
import numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification

def RMSNorm(x, eps=1e-6):
    """
    Parameter-free RMS normalisation over the last dimension.

    - x: Input tensor; normalised along its last axis.

    - eps: Added to the mean square before the square root so all-zero rows do not divide by zero.

    Returns x / sqrt(mean(x**2, dim=-1) + eps), with the same shape as x.
    """
    return x / x.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()

def build_padding_mask(x, L, context, mask_value=float('inf')):
    """
    Build a per-sample padding mask of width `context`.

    - x: Batched tensor; only its first dimension (batch size B) and its device are used.

    - L: B valid lengths (a sequence of ints or a 1-D integer tensor).

    - context: Width of the mask (number of positions).

    - mask_value: Value written at padded positions.

    Returns a (B, context) float tensor on x's device that is 0 at positions < L[i]
    and `mask_value` everywhere else in row i.
    """
    # NOTE(review): the default is +inf. If this mask is *added* to attention logits,
    # -inf is what removes padded keys (+inf makes softmax produce NaN) — confirm at call sites.
    batch = x.shape[0]
    padding_mask = torch.full((batch, context), mask_value, device=x.device)
    for i in range(batch):
        padding_mask[i, :L[i]] = 0
    return padding_mask


def trunc_normal_(shape, mean=0, std=1, upper=2, lower=-2, device="cpu"):
    """
    Sample a truncated standard normal, rescale it to a target std and shift it by `mean`.

    - shape: Shape of the returned tensor.

    - mean: Value added to every sample after rescaling.

    - std: Target (population) standard deviation of the returned batch.

    - upper, lower: Truncation bounds in units of the *standard* normal (i.e. before rescaling).

    - device: Device the tensor is created on.

    Samples are drawn by inverse-CDF sampling, so they follow a genuine truncated normal
    on [lower, upper] instead of a clipped one (clipping piles probability mass onto the
    bounds). The batch is then rescaled so its empirical std equals `std`, and `mean` is added.

    Raises ValueError if lower >= upper.
    """
    import math

    if lower >= upper:
        raise ValueError(f"lower ({lower}) must be smaller than upper ({upper})")
    # Standard normal CDF evaluated at the truncation bounds.
    cdf_lo = 0.5 * (1.0 + math.erf(lower / math.sqrt(2.0)))
    cdf_hi = 0.5 * (1.0 + math.erf(upper / math.sqrt(2.0)))
    u = torch.empty(shape, device=device).uniform_(cdf_lo, cdf_hi)
    # Inverse standard normal CDF: Phi^-1(u) = sqrt(2) * erfinv(2u - 1).
    x = torch.erfinv(2.0 * u - 1.0) * math.sqrt(2.0)
    # erfinv loses precision near +/-1; keep samples inside the bounds.
    x.clamp_(lower, upper)
    spread = x.std(unbiased=False)
    # A single-element (or otherwise constant) batch has zero spread; don't divide by it.
    x *= std / spread if spread > 0 else std
    x += mean
    return x

def preprocess(img, size=(512, 512)):
    """
    Resize an image and convert it to a channels-first float tensor divided by 255.

    - img: Anything np.array can turn into an H x W or H x W x C array (e.g. a PIL image).
      The division by 255 assumes 8-bit pixel values — TODO confirm for other dtypes.

    - size: Target size forwarded to cv2.resize as `dsize`, i.e. (width, height).

    Returns a float32 tensor of shape (C, height, width); 2-D (grayscale) input yields C = 1
    instead of failing in permute.
    """
    arr = cv2.resize(np.array(img), size)
    # cv2.resize returns a 2-D array for 2-D input (and, in OpenCV, typically for H x W x 1
    # input too); give it a channel axis so the CHW permute below works.
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]
    return torch.tensor(arr).permute(2, 0, 1).float() / 255.0

class ViT_Sequence(torch.nn.Module):
    """
    Run the patch embeddings and the first `depth` encoder layers of a pretrained HF ViT.

    - depth: Number of encoder layers to apply; clamped to [0, number of layers in the model].

    - preprocess: If True, forward() first runs its input through the checkpoint's image
      processor; otherwise the input is passed straight to the embeddings.

    - model_name: Hugging Face checkpoint used for both the processor and the model.

    forward() returns the hidden states after the last applied layer (no final layer norm,
    no classification head), computed under torch.no_grad().
    """

    def __init__(self, depth, preprocess=False, model_name="WinKawaks/vit-tiny-patch16-224"):
        super().__init__()
        self.preprocess = preprocess
        self.processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
        self.model = AutoModelForImageClassification.from_pretrained(model_name).eval()
        # Clamp to the layers the checkpoint actually has instead of a hard-coded 12.
        self.depth = max(0, min(depth, len(self.model.vit.encoder.layer)))

    def forward(self, x):
        with torch.no_grad():
            if self.preprocess:
                # The processor returns a dict-like batch whose entries are passed as keyword arguments.
                hidden = self.model.vit.embeddings(**self.processor(x, return_tensors="pt"))
            else:
                hidden = self.model.vit.embeddings(x)
            for i in range(self.depth):
                out = self.model.vit.encoder.layer[i](hidden)
                # NOTE: older transformers releases return a tuple (hidden_states, ...) from each
                # ViT layer, newer ones may return the tensor itself — accept both.
                hidden = out[0] if isinstance(out, tuple) else out
        return hidden

def generate_angles_2d(H,W,D, device='cpu', freq=None, base=10000):
    """

    Generates a 3D frequency field for 2D (axial) Rotary Positional Embeddings.

    - H: Height of the feature map.

    - W: Width of the feature map.

    - D: Embedding Dimension (must be even).

    - device: Device for the default frequencies.

    - freq: Optional precomputed frequency tensor for the embedding dimension (length D/2).

    - base: Base of the default frequency schedule base**(-2i/D); ignored when freq is given.

    Row and column coordinates are each spaced evenly over [-1, 1]. Even-indexed rotation
    pairs are driven by the row coordinate and odd-indexed pairs by the column coordinate,
    so distinct (row, col) positions get distinct angle vectors and both axes see
    frequencies from across the whole schedule. (With D == 2 there is a single pair,
    which encodes rows only.)

    Returns a tensor of shape (H, W, D/2).

    """
    assert D % 2 == 0, "Embedding Dimension must be even!"
    freq = torch.tensor([base**(-2*i/D) for i in range(int(D/2))], device=device) if freq is None else freq
    # Build coordinates next to the frequencies so a caller-supplied freq on another device works.
    rows = torch.linspace(-1, 1, steps=H, device=freq.device)
    cols = torch.linspace(-1, 1, steps=W, device=freq.device)
    # Keep the two axes on separate channels: a single coordinate such as rows*cols would map
    # (r, c) and (-r, -c) -- and (r, c) and (c, r) when H == W -- onto identical angles.
    freq_tensor = torch.empty((H, W, freq.shape[-1]), dtype=torch.result_type(rows, freq), device=freq.device)
    freq_tensor[..., 0::2] = rows[:, None, None] * freq[0::2]
    freq_tensor[..., 1::2] = cols[None, :, None] * freq[1::2]
    return freq_tensor

def generate_angles_1d(N, D, device='cpu', freq=None, base=10000):
    """

    1D variation of generate_angles_2d: rotation angles for N positions.

    - N: Number of positions; their coordinates are spaced evenly over [-1, 1].

    - D: Embedding Dimension (must be even).

    - device: Device for the default frequencies.

    - freq: Optional precomputed frequency tensor (length D/2).

    - base: Base of the default frequency schedule base**(-2i/D); ignored when freq is given.

    Returns a tensor of shape (N, D/2) holding position * frequency.

    """
    assert D % 2 == 0, "Embedding Dimension must be even!"
    freq = torch.tensor([base**(-2*i/D) for i in range(int(D/2))], device=device) if freq is None else freq
    # Build positions next to the frequencies so a caller-supplied freq on another device works.
    pos = torch.linspace(-1, 1, steps=N, device=freq.device)
    freq_tensor = torch.einsum("i,j->ij", pos, freq) # outer product
    return freq_tensor

def apply_angles_2d(x, f):
    """

    Applies the 2D Rotary Positional Embeddings to the input tensor.

    - x: Input tensor of shape (..., H, W, D) with D even; any leading dims
      (batch, heads, ...) are allowed.

    - f: Frequency tensor broadcastable to (..., H, W, D/2), typically (H, W, D/2).

    Rotates each adjacent pair (x[..., 2k], x[..., 2k+1]) by the angle f[..., k]:
    (r, i) -> (r cos - i sin, r sin + i cos). Returns a tensor with the same shape as x.

    """
    # Split the last dim into (D/2, 2) pairs; equivalent to "... (D p) -> ... D p" with p=2.
    pairs = x.reshape(*x.shape[:-1], -1, 2)
    real, imag = pairs[..., 0], pairs[..., 1]
    cosines, sines = f.cos(), f.sin()
    # r , i -> rcos-isin , rsin icos
    rot_real = real * cosines - imag * sines
    rot_imag = real * sines + imag * cosines
    # Re-interleave the pairs back into a single last dimension.
    return torch.stack((rot_real, rot_imag), dim=-1).flatten(-2)

def apply_angles_1d(x, f):
    """

    Rotate adjacent channel pairs of `x` by per-position angles (1D RoPE).

    - x: Tensor of shape (..., N, D); the last dim is split into D/2 pairs.

    - f: Angle tensor of shape (M, D/2); only its first N rows (N = x.shape[-2]) are used.

    Pair k at position n, (a, b), becomes (a cos t - b sin t, a sin t + b cos t) with
    t = f[n, k]. Returns a tensor with the same shape as `x`.

    """
    pairs = rearrange(x, "... (D p) -> ... D p", p=2)
    first, second = pairs.unbind(dim=-1)
    seq_len = first.shape[-2]
    # Trim the angle table to the sequence length actually present in x.
    cos_t = f.cos()[:seq_len, :]
    sin_t = f.sin()[:seq_len, :]
    rotated = torch.stack(
        (first * cos_t - second * sin_t, first * sin_t + second * cos_t),
        dim=-1,
    )
    return rearrange(rotated, "... D p -> ... (D p)", p=2)

# Sanity check: push dummy inputs through the ViT trunk and both RoPE helpers and print the output shapes.
if __name__ == "__main__":
    encoder = ViT_Sequence(12)
    image_batch = torch.randn(1,3,224,224)
    print(encoder(image_batch).shape)

    seq_inputs = torch.randn(1,4,43,768)
    seq_angles = generate_angles_1d(64,768)
    print(apply_angles_1d(seq_inputs, seq_angles).shape)

    grid_inputs = torch.randn(1,64,64,768)
    grid_angles = generate_angles_2d(64,64,768)
    print(apply_angles_2d(grid_inputs, grid_angles).shape)