# Jepa / utils.py
# Uploaded by Ananthusajeev190 ("Upload 5 files"), commit 046e256 (verified); 438 bytes.
import torch
from einops import rearrange
def extract_patches(images, patch_size):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        images: Batch of shape (B, C, H, W), channels-first, as implied by the
            rearrange pattern below.
        patch_size: Side length of each square patch, in pixels. Must be a
            positive integer that divides both H and W.

    Returns:
        Array/tensor of shape
        (B, (H // patch_size) * (W // patch_size), patch_size * patch_size * C).
        Patches are enumerated row by row over the patch grid. Within each
        flattened patch the layout is (patch-row, patch-col, channel), so the
        channel index varies fastest.

    Raises:
        ValueError: if ``images`` is not 4-D, ``patch_size`` is not positive,
            or H or W is not divisible by ``patch_size``.
    """
    if len(images.shape) != 4:
        raise ValueError(
            "extract_patches expects a 4-D (B, C, H, W) input, "
            f"got shape {tuple(images.shape)}"
        )
    if patch_size <= 0:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")

    _, _, height, width = images.shape
    # Validate here so callers get a clear message rather than an error
    # raised from inside einops' axis decomposition.
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"image size ({height}, {width}) is not divisible by "
            f"patch_size={patch_size}"
        )

    # (h ph) / (w pw): split H and W into a grid of patch_size x patch_size tiles,
    # then flatten the grid into one sequence axis and each tile into one vector.
    return rearrange(
        images,
        "b c (h ph) (w pw) -> b (h w) (ph pw c)",
        ph=patch_size,
        pw=patch_size,
    )
def cosine_loss(pred, target):
    """Return ``2 - 2 * mean cosine similarity`` between ``pred`` and ``target``.

    Both inputs are L2-normalized along their last axis, and the cosine
    similarity of each pair of vectors is averaged over all remaining
    positions. The result is a scalar that is 0 when the directions match,
    2 when they are orthogonal, and 4 when they are opposite.
    """
    normalize = torch.nn.functional.normalize
    unit_pred, unit_target = normalize(pred, dim=-1), normalize(target, dim=-1)

    # Per-vector dot product of unit vectors == cosine similarity.
    cosine = torch.sum(unit_pred * unit_target, dim=-1)
    return 2 - 2 * cosine.mean()