|
|
import torch
|
|
|
from einops import rearrange
|
|
|
import cv2
|
|
|
import numpy as np
|
|
|
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
|
|
|
|
|
def RMSNorm(x, eps=1e-6):
    """Normalize ``x`` by its root-mean-square over the last dimension.

    Computes ``x / sqrt(mean(x ** 2, dim=-1, keepdim=True) + eps)``. There is
    no learnable gain; this is a plain function, not a module.

    Args:
        x: Input tensor; the statistic is taken over its last dimension.
        eps: Constant added to the mean square before the square root so an
            all-zero row does not divide by zero.

    Returns:
        A tensor with the same shape as ``x``.
    """
    mean_square = x.pow(2).mean(dim=-1, keepdim=True)
    return x / mean_square.add(eps).sqrt()
|
|
|
|
|
|
def build_padding_mask(x, L, context, mask_value=float('inf')):
    """Build a per-sample padding mask of width ``context``.

    Row ``i`` holds ``0`` in its first ``L[i]`` positions and ``mask_value``
    everywhere else.

    Args:
        x: 3-D tensor; only its batch size (first dimension) and device are
            used.
        L: Indexable collection of valid lengths, one entry per batch element.
        context: Number of columns in the mask.
        mask_value: Fill value for padded positions.
            NOTE(review): the default is +inf; additive attention masks
            usually use -inf. The default is kept for backward compatibility,
            confirm against the caller.

    Returns:
        Tensor of shape ``(batch, context)`` on the same device as ``x``.
    """
    # Unpacking three names keeps the original requirement that x is 3-D.
    batch_size, _, _ = x.shape
    # Allocate on x's device so the mask can be combined with x without a copy.
    padding_mask = torch.full((batch_size, context), mask_value, device=x.device)
    for i in range(batch_size):
        padding_mask[i, :L[i]] = 0
    return padding_mask
|
|
|
|
|
|
|
|
|
def trunc_normal_(shape, mean=0, std=1, upper=2, lower=-2, device="cpu"):
    """Sample a clamped normal tensor rescaled to the requested mean and std.

    Standard-normal samples are clamped to ``[lower, upper]``, rescaled so the
    population standard deviation of the result equals ``std``, and finally
    shifted by ``mean``.

    Note that ``lower``/``upper`` apply to the standard-normal draw *before*
    rescaling and shifting, so the returned values are not confined to that
    interval. Clamping (rather than resampling) also piles probability mass
    onto the two bounds.

    Args:
        shape: Shape of the tensor to create.
        mean: Offset added to the rescaled samples.
        std: Target population standard deviation of the result.
        upper: Upper clamp bound for the standard-normal draw.
        lower: Lower clamp bound for the standard-normal draw.
        device: Device on which the tensor is created.

    Returns:
        A new tensor of the given shape.
    """
    x = torch.randn(shape, device=device)
    x.clamp_(lower, upper)
    sample_std = x.std(unbiased=False)
    # A single-element (or empty) draw has zero/NaN spread; rescaling would
    # yield inf/NaN, so leave such samples unscaled.
    if sample_std > 0:
        x *= std / sample_std
    # Previously `mean` was accepted but silently ignored.
    x += mean
    return x
|
|
|
|
|
|
def preprocess(img, size=(512, 512)):
    """Convert an image to a resized channels-first float tensor.

    Args:
        img: Anything ``np.array`` can convert to an image array.
            Assumes height-width(-channel) layout with values in 0..255;
            TODO confirm against callers.
        size: Target ``(width, height)`` handed to ``cv2.resize``. Defaults to
            the previously hard-coded ``(512, 512)``.

    Returns:
        Float tensor of shape ``(channels, height, width)`` whose values are
        the resized pixel values divided by 255.
    """
    arr = np.array(img)
    resized = cv2.resize(arr, tuple(size))
    # cv2.resize returns a 2-D array for single-channel input; restore the
    # channel axis so the permute below does not fail on grayscale images.
    if resized.ndim == 2:
        resized = resized[:, :, None]
    # from_numpy avoids the extra copy torch.tensor() makes; the division
    # below always allocates a fresh tensor, so no memory is shared on return.
    return torch.from_numpy(resized).permute(2, 0, 1).float() / 255.0
|
|
|
|
|
|
class ViT_Sequence(torch.nn.Module):
    """Pretrained ViT used as a gradient-free token-sequence extractor.

    Runs the embedding layer and the first ``depth`` encoder layers of the
    ``WinKawaks/vit-tiny-patch16-224`` checkpoint and returns the resulting
    hidden states. The classification head of the loaded model is not used.
    """

    # Single source of truth for the checkpoint used by processor and model.
    MODEL_NAME = "WinKawaks/vit-tiny-patch16-224"

    def __init__(self, depth, preprocess=False):
        """Load the image processor and the pretrained model.

        Args:
            depth: Number of encoder layers to run; capped at the number of
                layers the checkpoint provides.
            preprocess: When true, ``forward`` first passes its input through
                the Hugging Face image processor; otherwise the input is fed
                to the embedding layer as-is.
        """
        super().__init__()
        self.preprocess = preprocess
        self.processor = AutoImageProcessor.from_pretrained(self.MODEL_NAME, use_fast=True)
        self.model = AutoModelForImageClassification.from_pretrained(self.MODEL_NAME).eval()
        # Cap against the real layer count instead of a hard-coded 12.
        self.depth = min(depth, len(self.model.vit.encoder.layer))

    def forward(self, x):
        """Return the hidden states after ``self.depth`` encoder layers.

        Args:
            x: Raw images accepted by the image processor when
                ``self.preprocess`` is true, otherwise a pixel tensor
                (presumably ``(batch, 3, 224, 224)``, as in the smoke test at
                the bottom of this file; confirm for other callers).

        Returns:
            The hidden-state tensor produced by the last executed layer, or by
            the embedding layer when the depth is zero or negative.
        """
        with torch.no_grad():
            if self.preprocess:
                inputs = self.processor(x, return_tensors="pt")
                hidden = self.model.vit.embeddings(**inputs)
            else:
                hidden = self.model.vit.embeddings(x)

            # max(..., 0) keeps a negative depth meaning "no layers" rather
            # than letting it act as a from-the-end slice bound.
            for layer in self.model.vit.encoder.layer[: max(self.depth, 0)]:
                out = layer(hidden)
                # Some Transformers releases return a tuple whose first item
                # is the hidden state; others return the tensor directly.
                hidden = out[0] if isinstance(out, (tuple, list)) else out
        return hidden
|
|
|
|
|
|
def generate_angles_2d(H,W,D, device='cpu', freq=None):
    """
    Build the rotation-angle table for 2D Rotary Positional Embeddings.

    The positional scalar of cell (h, w) is the product of its row and column
    coordinates, each taken from an even grid over [-1, 1]. That scalar is
    multiplied by every per-pair frequency.

    - H: Height of the feature map.
    - W: Width of the feature map.
    - D: Embedding dimension (must be even).
    - device: Device for the tensors created here.
    - freq: Optional precomputed frequency tensor; when omitted, D/2
      frequencies 10000**(-2*i/D) are generated.

    Returns a tensor of shape (H, W, len(freq)).
    """
    assert D % 2 == 0, "Embedding Dimension must be even!"

    if freq is None:
        half_dim = int(D/2)
        freq = torch.tensor([10000**(-2*i/D) for i in range(half_dim)], device=device)

    row_coords = torch.linspace(-1, 1, steps=H, device=device)
    col_coords = torch.linspace(-1, 1, steps=W, device=device)
    grid = torch.outer(row_coords, col_coords)

    # Outer product of the (H, W) grid with the frequency vector.
    return torch.einsum("hw,k->hwk", grid, freq)
|
|
|
|
|
|
def generate_angles_1d(N, D, device='cpu', freq=None):
    """
    Build the rotation-angle table for 1D Rotary Positional Embeddings.

    Positions are N evenly spaced values over [-1, 1]; each is multiplied by
    every per-pair frequency.

    - N: Number of positions.
    - D: Embedding dimension (must be even).
    - device: Device for the tensors created here.
    - freq: Optional precomputed frequency tensor; when omitted, D/2
      frequencies 10000**(-2*i/D) are generated.

    Returns a tensor of shape (N, len(freq)).
    """
    assert D % 2 == 0, "Embedding Dimension must be even!"

    if freq is None:
        half_dim = int(D/2)
        freq = torch.tensor([10000**(-2*i/D) for i in range(half_dim)], device=device)

    positions = torch.linspace(-1, 1, steps=N, device=device)

    # Outer product of positions with the frequency vector.
    return torch.einsum("n,k->nk", positions, freq)
|
|
|
|
|
|
def apply_angles_2d(x, f):
    """
    Applies the 2D Rotary Positional Embeddings to the input tensor.

    - x: Input tensor of shape (..., H, W, D) with D even. Any number of
      leading dimensions is accepted, matching apply_angles_1d; the previous
      (B, H, W, D) layout is a special case.
    - f: Angle tensor of shape (H, W, D/2), or anything broadcastable against
      (..., H, W, D/2).

    Consecutive channel pairs (2k, 2k+1) are treated as the real and imaginary
    parts of a complex number and rotated by the angle f[h, w, k], i.e. each
    pair is multiplied by a 2x2 rotation matrix.

    Returns a tensor with the same shape as x.
    """
    # Split the channel axis into (D/2, 2). The explicit pair count makes an
    # odd D fail loudly instead of being silently reinterpreted.
    pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    real, imag = pairs.unbind(dim=-1)

    cosines, sines = f.cos(), f.sin()

    rot_real = real * cosines - imag * sines
    rot_imag = real * sines + imag * cosines

    # Re-interleave the rotated components back into the channel axis.
    return torch.stack((rot_real, rot_imag), dim=-1).flatten(start_dim=-2)
|
|
|
|
|
|
def apply_angles_1d(x, f):
    """
    Applies the 1D Rotary Positional Embeddings to the input tensor.

    - x: Input tensor of shape (..., N, D) with D even.
    - f: 2-D angle tensor of shape (M, D/2). Only its first N rows are used,
      so a table built for a longer sequence can be reused.

    Consecutive channel pairs (2k, 2k+1) are treated as the real and imaginary
    parts of a complex number and rotated by the angle f[n, k].

    Returns a tensor with the same shape as x.
    """
    # Crop the table to the sequence length *before* the trig calls so cos/sin
    # are not evaluated for positions that are never used.
    seq_len = x.shape[-2]
    angles = f[:seq_len, :]
    cosines, sines = angles.cos(), angles.sin()

    # Split the channel axis into (D/2, 2). The explicit pair count makes an
    # odd D fail loudly instead of being silently reinterpreted.
    pairs = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    real, imag = pairs.unbind(dim=-1)

    rot_real = real * cosines - imag * sines
    rot_imag = real * sines + imag * cosines

    # Re-interleave the rotated components back into the channel axis.
    return torch.stack((rot_real, rot_imag), dim=-1).flatten(start_dim=-2)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: print the output shape of each component on random input.
    vit_tokens = ViT_Sequence(12)(torch.randn(1, 3, 224, 224))
    print(vit_tokens.shape)

    # The 1-D angle table is built for 64 positions but applied to 43.
    rotated_seq = apply_angles_1d(torch.randn(1, 4, 43, 768), generate_angles_1d(64, 768))
    print(rotated_seq.shape)

    rotated_grid = apply_angles_2d(torch.randn(1, 64, 64, 768), generate_angles_2d(64, 64, 768))
    print(rotated_grid.shape)