trm-vit / utils.py
detectivejoewest's picture
Upload 4 files
ce78e68 verified
import torch
from einops import rearrange
import cv2
import numpy as np
from transformers import AutoImageProcessor, AutoModelForImageClassification
RMSNorm = lambda x : x / x.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
def build_padding_mask(x, L, context, mask_value=float('inf')):
B, N, D = x.shape
padding_mask = torch.full((len(x), context), mask_value)
for i in range(B):
padding_mask[i, :L[i]] = 0
padding_mask[i, i:]
return padding_mask
def trunc_normal_(shape, mean=0, std=1, upper=2, lower=-2, device="cpu"):
x = torch.randn(shape, device=device)
x.clamp_(lower, upper)
x *= std / x.std(unbiased=False)
return x
def preprocess(img):
arr = np.array(img)
arr = cv2.resize(arr, (512, 512))
return torch.tensor(arr).permute(2, 0, 1).float() / 255.0
class ViT_Sequence(torch.nn.Module):
def __init__(self, depth, preprocess=False):
super().__init__()
self.preprocess = preprocess
self.depth = min(depth, 12)
self.processor = AutoImageProcessor.from_pretrained("WinKawaks/vit-tiny-patch16-224", use_fast=True)
self.model = AutoModelForImageClassification.from_pretrained("WinKawaks/vit-tiny-patch16-224").eval()
def forward(self, x):
with torch.no_grad():
x = self.processor(x, return_tensors="pt") if self.preprocess else x
hidden = self.model.vit.embeddings(**x) if self.preprocess else self.model.vit.embeddings(x)
for i in range(self.depth):
hidden = self.model.vit.encoder.layer[i](hidden)
return hidden
def generate_angles_2d(H,W,D, device='cpu', freq=None):
"""
Generates a 3D frequency field for 2D Rotary Positional Embeddings.
- H: Height of the feature map.
- W: Width of the feature map.
- D: Embedding Dimension (must be even).
- freq: Optional precomputed frequency tensor for the embedding dimension.
"""
assert D % 2 == 0, "Embedding Dimension must be even!"
freq = torch.tensor([10000**(-2*i/D) for i in range(int(D/2))], device=device) if freq is None else freq
pos = torch.outer(torch.linspace(-1, 1, steps=H, device=device),torch.linspace(-1, 1, steps=W, device=device))
freq_tensor = torch.einsum("ij,k->ijk", pos, freq) # outer product
return freq_tensor
def generate_angles_1d(N, D, device='cpu', freq=None):
"""
1d variation of generate_angles_2d
"""
assert D % 2 == 0, "Embedding Dimension must be even!"
freq = torch.tensor([10000**(-2*i/D) for i in range(int(D/2))], device=device) if freq is None else freq
pos = torch.linspace(-1, 1, steps=N, device=device)
freq_tensor = torch.einsum("i,j->ij", pos, freq) # outer product
return freq_tensor
def apply_angles_2d(x, f):
"""
Applies the 2D Rotary Positional Embeddings to the input tensor.
- x: Input tensor of shape (B, H, W, D)
- f: Frequency tensor of shape (H, W, D/2)
Rotates each pair of dimensions in the last dimension via orthogonal 2D matrix multiplication.
"""
x_reshaped = rearrange(x, "B H W (D p) -> B H W D p", p=2)
real = x_reshaped[..., 0]
imag = x_reshaped[..., 1]
cosines, sines = f.cos(), f.sin()
# r , i -> rcos-isin , rsin icos
rot_real = real * cosines - imag * sines
rot_imag = real * sines + imag * cosines
rot_full = torch.concat((rot_real.unsqueeze(-1), rot_imag.unsqueeze(-1)), dim=-1)
return rearrange(rot_full, "B H W D p -> B H W (D p)", p=2)
def apply_angles_1d(x, f):
"""
1d variation of apply_angles_2d
"""
x_reshaped = rearrange(x, "... (D p) -> ... D p", p=2)
real = x_reshaped[..., 0]
imag = x_reshaped[..., 1]
cosines, sines = f.cos(), f.sin()
# r , i -> rcos-isin , rsin icos
rot_real = real * cosines[:real.shape[-2], :] - imag * sines[:real.shape[-2], :]
rot_imag = real * sines[:real.shape[-2], :] + imag * cosines[:real.shape[-2], :]
rot_full = torch.concat((rot_real.unsqueeze(-1), rot_imag.unsqueeze(-1)), dim=-1)
return rearrange(rot_full, "... D p -> ... (D p)", p=2)
# Sanity Check :)
if __name__ == "__main__":
print(ViT_Sequence(12)(torch.randn(1,3,224,224)).shape)
print(apply_angles_1d(torch.randn(1,4,43,768), generate_angles_1d(64,768)).shape)
print(apply_angles_2d(torch.randn(1,64,64,768), generate_angles_2d(64,64,768)).shape)