| # %% | |
| import os | |
| import torch | |
| from PIL import Image | |
| from einops import rearrange, repeat | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| # align_weights = torch.load("align_weights.pth") | |
| from torch import nn | |
| from alignedthreeattn_backbone import CLIPAttnNode, DiNOv2AttnNode, MAEAttnNode | |
class ThreeAttnNodes(nn.Module):
    """Run three frozen attention backbones and linearly align their per-layer features.

    The input image batch is resized and passed through ``CLIPAttnNode``,
    ``DiNOv2AttnNode`` and ``MAEAttnNode``. Their outputs are concatenated
    along dim 1 (in the order backbone1, backbone2, backbone3), each slice
    along dim 1 is projected with its own matrix from ``align_weights``, the
    first token is dropped, and the remaining tokens are folded into a square
    spatial grid.

    The backbones have ``requires_grad`` disabled and are kept in eval mode;
    ``align_weights`` is the only parameter registered by this class itself.

    NOTE: an earlier, disabled experiment resampled the CLIP/MAE positional
    embeddings to a 42x42 grid and fed 672x672 (CLIP/MAE) and 588x588
    (DINOv2) inputs instead of the resolutions used below.
    """

    # Resolution fed to backbone1 (CLIP) and backbone3 (MAE).
    _SIZE_CLIP_MAE = (224, 224)
    # Resolution fed to backbone2 (DINOv2); presumably chosen so that all
    # three backbones emit the same number of tokens -- TODO confirm.
    _SIZE_DINO = (196, 196)

    def __init__(self, align_weights: torch.Tensor) -> None:
        """Build the three backbones, freeze them, and register the alignment weights.

        Args:
            align_weights: Tensor indexed as ``align_weights[layer]`` giving an
                ``(out_dim, in_dim)`` projection matrix per concatenated layer.
        """
        super().__init__()
        self.backbone1 = CLIPAttnNode()
        self.backbone2 = DiNOv2AttnNode()
        self.backbone3 = MAEAttnNode()
        for backbone in self._frozen_backbones():
            backbone.requires_grad_(False)
            backbone.eval()
        self.align_weights = nn.Parameter(align_weights)

    def _frozen_backbones(self):
        """Return the three backbone modules in concatenation order."""
        return (self.backbone1, self.backbone2, self.backbone3)

    def train(self, mode: bool = True) -> "ThreeAttnNodes":
        """Set the training mode of this module while keeping the backbones in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        undo the ``eval()`` applied to the frozen backbones in ``__init__``.
        """
        super().train(mode)
        for backbone in self._frozen_backbones():
            backbone.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute aligned per-layer features on a square patch grid.

        Args:
            x: Image batch accepted by bilinear ``F.interpolate`` (4-D input).

        Returns:
            Tensor of shape ``(batch, layers, hw, hw, out_dim)`` where
            ``layers`` is the size of dim 1 of the concatenated backbone
            outputs and ``hw * hw`` is the number of tokens left after
            dropping token 0 (presumably the class token -- TODO confirm).

        Raises:
            RuntimeError: If the number of remaining tokens is not a perfect
                square.
            IndexError: If ``align_weights`` has fewer entries than there are
                concatenated layers.
        """
        # Local import: the exact integer square root needs nothing beyond
        # the standard library, so the file-level imports stay untouched.
        import math

        x = F.interpolate(x, size=self._SIZE_CLIP_MAE, mode="bilinear")
        feat1 = self.backbone1(x)
        feat3 = self.backbone3(x)
        # The second resize is deliberately applied to the already-resized
        # tensor, not to the caller's original input.
        x = F.interpolate(x, size=self._SIZE_DINO, mode="bilinear")
        feat2 = self.backbone2(x)
        feats = torch.cat([feat1, feat2, feat3], dim=1)

        # One projection per layer; the layer count comes from the features
        # themselves rather than from a hard-coded constant.
        num_layers = feats.shape[1]
        out = torch.stack(
            [
                torch.einsum(
                    "b p i, o i -> b p o",
                    feats[:, layer],
                    self.align_weights[layer],
                )
                for layer in range(num_layers)
            ],
            dim=1,
        )

        # Drop token 0 and fold the remaining tokens into an hw x hw grid.
        patches = out[:, :, 1:]
        batch, layers, num_patches, channels = patches.shape
        hw = math.isqrt(num_patches)
        return patches.reshape(batch, layers, hw, hw, channels)