import torch
from einops import rearrange


def extract_patches(images, patch_size):
    """Split a batch of images into flattened, non-overlapping square patches.

    The einops pattern below reads ``images`` as ``(b, c, H, W)`` and cuts the
    two trailing axes into a grid of ``patch_size x patch_size`` tiles.

    Args:
        images: 4-axis array/tensor laid out as ``(b, c, H, W)``. ``H`` and
            ``W`` have to be exact multiples of ``patch_size``; otherwise
            ``rearrange`` cannot factor those axes and raises its own error.
        patch_size: Side length of each square patch, in elements.

    Returns:
        An array/tensor of shape
        ``(b, (H // patch_size) * (W // patch_size), patch_size * patch_size * c)``.
        Patches are ordered row-major over the grid; inside a patch the
        values are flattened in ``(ph, pw, c)`` order, i.e. channel varies
        fastest.

    Raises:
        ValueError: If ``images`` does not have exactly four axes.
    """
    # Explicit rank check instead of unpacking the shape into unused locals:
    # same exception type as the failed unpacking, but a readable message.
    rank = len(images.shape)
    if rank != 4:
        raise ValueError(
            f"extract_patches expects a 4-D (b, c, h, w) input, got {rank}-D"
        )

    return rearrange(
        images,
        'b c (h ph) (w pw) -> b (h w) (ph pw c)',
        ph=patch_size,
        pw=patch_size,
    )


def cosine_loss(pred, target):
    """Return ``2 - 2 * mean(cosine_similarity)`` taken over the last axis.

    Both inputs are L2-normalised along their last dimension, so the
    element-wise product summed over that dimension is the cosine similarity
    of each pair of vectors. The similarities are averaged over all remaining
    positions into a single scalar.

    Args:
        pred: Tensor of predicted vectors; the last dimension is the feature
            axis.
        target: Tensor of reference vectors; must be broadcast-compatible
            with ``pred``.

    Returns:
        A scalar tensor. For non-zero vectors it is 0 when every pair points
        in the same direction and 4 when every pair points in opposite
        directions.
    """
    pred_unit = torch.nn.functional.normalize(pred, dim=-1)
    target_unit = torch.nn.functional.normalize(target, dim=-1)

    # Dot product of unit vectors == cosine similarity per vector pair.
    similarity = (pred_unit * target_unit).sum(dim=-1)
    return 2 - 2 * similarity.mean()