| import torch | |
| from einops import rearrange | |
def extract_patches(images, patch_size):
    """Split a batch of images into flattened, non-overlapping patches.

    Patches are ordered row-major over the patch grid. Each patch is
    flattened as (patch_row, patch_col, channel), with the channel varying
    fastest, following the einops pattern
    ``'b c (h ph) (w pw) -> b (h w) (ph pw c)'``.

    Args:
        images: Array of shape (B, C, H, W), of any type einops' ``rearrange``
            accepts (e.g. a torch tensor).
        patch_size: Patch side length as a scalar (square patches), or a
            ``(patch_h, patch_w)`` tuple/list for rectangular patches.

    Returns:
        Array of shape ``(B, (H // patch_h) * (W // patch_w),
        patch_h * patch_w * C)``.

    Raises:
        ValueError: if ``images`` is not 4-D.
        Any error einops raises if H / W is not divisible by the patch size.
    """
    # Explicit rank check. The original surfaced the same ValueError type
    # implicitly, by unpacking ``images.shape`` into four otherwise-unused names.
    if len(images.shape) != 4:
        raise ValueError(
            f"expected a 4-D (B, C, H, W) input, got shape {tuple(images.shape)}"
        )
    # Only an explicit pair means rectangular patches. Any other value
    # (int, numpy int, ...) keeps the original square-patch behaviour.
    if isinstance(patch_size, (tuple, list)):
        ph, pw = patch_size
    else:
        ph = pw = patch_size
    return rearrange(images, 'b c (h ph) (w pw) -> b (h w) (ph pw c)', ph=ph, pw=pw)
def cosine_loss(pred, target):
    """Return the negative-cosine regression loss ``2 - 2 * mean(cos(pred, target))``.

    How it is computed:

    - Both inputs are L2-normalized along their last dimension.
    - The per-row cosine similarity is the dot product of the normalized rows.
    - The mean of those cosines, over every remaining position, is mapped into
      [0, 4]: 0 when all rows point the same way, 4 when all are opposite.

    For non-zero rows this equals the mean squared Euclidean distance between
    the unit vectors, since ||a - b||^2 = 2 - 2 a.b for unit a, b. An all-zero
    row normalizes to zeros (``normalize`` clamps the norm with a small eps),
    so it contributes a cosine of 0.

    Args:
        pred: Prediction tensor; the last dimension is the feature axis.
        target: Target tensor, broadcast-compatible with ``pred``.

    Returns:
        A 0-dim tensor holding the loss.
    """
    functional = torch.nn.functional
    unit_pred = functional.normalize(pred, dim=-1)
    unit_target = functional.normalize(target, dim=-1)
    # The row-wise dot product of unit vectors is their cosine similarity.
    cosine = (unit_pred * unit_target).sum(dim=-1)
    return 2 - 2 * cosine.mean()