from functools import partial

import cv2 as cv  # noqa: F401 -- not used in this module; kept from the original imports
import numpy as np
import skimage  # noqa: F401 -- not used in this module; kept from the original imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import PatchEmbed, Block


def sdss_rgb(imgs, bands, scales=None, m=0.02):
    """Map raw image data (nanomaggies) to the RGB values shown in the Legacy Survey viewer.

    https://www.legacysurvey.org/viewer
    Code copied from https://github.com/legacysurvey/imagine/blob/master/map/views.py

    Args:
        imgs: sequence of 2-D arrays, one per band, all with the same (H, W) shape.
        bands: sequence of band names (keys of the scale table below), parallel to ``imgs``.
        scales: optional dict of ``{band: (plane, scale)}`` entries overriding the defaults.
        m: additive offset applied to every scaled band image.

    Returns:
        float32 array of shape (H, W, 3) clipped to [0, 1].
    """
    # band -> (output RGB plane index, flux scale)
    rgbscales = {
        'u': (2, 1.5),  # 1.0,
        'g': (2, 2.5),
        'r': (1, 1.5),
        'i': (0, 1.0),
        'z': (0, 0.4),  # 0.3
    }
    if scales is not None:
        rgbscales.update(scales)

    # Overall intensity: mean of the scaled, non-negative band images.
    I = 0
    for img, band in zip(imgs, bands):
        _, scale = rgbscales[band]
        img = np.maximum(0, img * scale + m)
        I = I + img
    I /= len(bands)

    # asinh stretch of the intensity.
    Q = 20
    fI = np.arcsinh(Q * I) / np.sqrt(Q)
    # Guard the division below; fI (= arcsinh(0) = 0) is already zero where I == 0.
    I += (I == 0.) * 1e-6

    H, W = I.shape
    rgb = np.zeros((H, W, 3), np.float32)
    for img, band in zip(imgs, bands):
        plane, scale = rgbscales[band]
        rgb[:, :, plane] = (img * scale + m) * fI / I

    rgb = np.clip(rgb, 0, 1)
    return rgb


def dr2_rgb(rimgs, bands, **ignored):
    """``sdss_rgb`` with the g/r/z scales and offset used for DR2 images; extra kwargs are ignored."""
    return sdss_rgb(rimgs, bands, scales=dict(g=(2, 6.0), r=(1, 3.4), z=(0, 2.2)), m=0.03)


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a fixed 2-D sine-cosine position embedding for a square patch grid.

    Args:
        embed_dim: embedding dimension (must be even).
        grid_size: int, height and width of the patch grid.
        cls_token: if True, prepend an all-zero row for the class token.

    Returns:
        array of shape [grid_size*grid_size, embed_dim], or
        [1 + grid_size*grid_size, embed_dim] when ``cls_token`` is True.
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a stacked 2-D coordinate grid: half of the channels per axis.

    ``grid[0]`` and ``grid[1]`` are the two coordinate planes produced by
    ``get_2d_sincos_pos_embed``. Returns an array of shape (H*W, embed_dim).
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine-cosine encode a set of scalar positions.

    Args:
        embed_dim: output dimension for each position (must be even).
        pos: array of positions to be encoded; flattened to size (M,).

    Returns:
        array of shape (M, embed_dim): sines in the first half, cosines in the second.
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's ``pos_embed`` in place to match ``model``'s patch grid.

    Extra (non-patch) tokens such as the class token are copied unchanged; the
    square grid of patch position embeddings is resampled with bicubic
    interpolation. Nothing happens if the checkpoint has no ``pos_embed`` or
    the grid sizes already match.
    """
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed


class MaskedAutoEncoderViT(nn.Module):
    """Masked Autoencoder with VisionTransformer backbone.

    ``forward`` encodes two independently masked views of the same batch,
    reconstructs each, and adds a cosine consistency term between the two
    pooled latents (weighted by ``lambda_consistency``).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4.0, norm_layer=nn.LayerNorm, norm_pix_loss=False, lambda_consistency=1.0):
        super().__init__()

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss
        self.lambda_consistency = lambda_consistency

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the fixed sin-cos position embeddings and initialize learnable weights."""
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # Xavier-init the patch projection on its weight flattened to 2-D
        # (out_channels, everything else), as for a linear layer.
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init: Xavier-uniform Linear weights, zero biases, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsort of random noise.

        Args:
            x: [N, L, D], sequence.
            mask_ratio: fraction of the L tokens to remove.

        Returns:
            x_masked: [N, len_keep, D], the kept tokens.
            mask: [N, L], 0 for kept and 1 for removed, in original token order.
            ids_restore: [N, L], indices that undo the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def patchify(self, imgs):
        """Split square images into flattened patches.

        Args:
            imgs: (N, C, H, W) with H == W and H divisible by the patch size.
                Any channel count C is accepted (previously only C == 3 worked).

        Returns:
            x: (N, L, patch_size**2 * C), pixels ordered (p, q, c) within a patch.
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        c = imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * c))
        return x

    def unpatchify(self, x):
        """Inverse of ``patchify``.

        Args:
            x: (N, L, patch_size**2 * C) with L a perfect square.

        Returns:
            imgs: (N, C, H, W).
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**0.5)
        assert h * w == x.shape[1]
        # Channel count is recovered from the patch vector length.
        c = x.shape[2] // (p * p)
        assert c * p * p == x.shape[2]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward_encoder(self, x, mask_ratio):
        """Embed patches, randomly mask them, and run the encoder.

        Returns:
            x: (N, 1 + len_keep, embed_dim), cls token followed by kept tokens.
            mask: (N, L), 0 keep / 1 remove.
            ids_restore: (N, L), indices that undo the shuffle.
        """
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """Decode encoder output back to per-patch pixel predictions.

        Args:
            x: encoder output, cls token first, then the kept tokens.
            ids_restore: (N, L) indices from masking, used to unshuffle.

        Returns:
            (N, L, patch_size**2 * in_chans) predictions, cls token removed.
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """Mean squared reconstruction error over the removed patches only.

        Args:
            imgs: [N, C, H, W]
            pred: [N, L, p*p*C]
            mask: [N, L], 0 is keep, 1 is remove.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # Normalize each target patch to zero mean / unit variance.
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def consistency_loss(self, latent1, latent2, use_cls=False):
        """Cosine consistency between two encoder outputs: 2 - 2 * mean cosine similarity.

        Args:
            latent1, latent2: encoder outputs [N, 1 + len_keep, D] (cls token first).
            use_cls: whether to compare the cls tokens; if False, compare the
                mean of the (visible) patch features instead.
        """
        if use_cls:
            z1 = latent1[:, 0]  # [N, D] cls token
            z2 = latent2[:, 0]
        else:
            z1 = latent1[:, 1:].mean(dim=1)  # [N, D] average over all patch features
            z2 = latent2[:, 1:].mean(dim=1)

        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)

        loss = 2 - 2 * (z1 * z2).sum(dim=-1).mean()
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        """Two independently masked passes plus a latent consistency term.

        Returns:
            (loss_total, pred1, mask1): the combined loss and the first view's
            prediction and mask.
        """
        # --- first masking ---
        latent1, mask1, ids_restore1 = self.forward_encoder(imgs, mask_ratio)
        pred1 = self.forward_decoder(latent1, ids_restore1)
        loss_recon1 = self.forward_loss(imgs, pred1, mask1)

        # --- second masking ---
        latent2, mask2, ids_restore2 = self.forward_encoder(imgs, mask_ratio)
        pred2 = self.forward_decoder(latent2, ids_restore2)
        loss_recon2 = self.forward_loss(imgs, pred2, mask2)

        # --- consistency loss ---
        loss_cons = self.consistency_loss(latent1, latent2)

        # --- total loss ---
        loss_total = (loss_recon1 + loss_recon2) / 2 + self.lambda_consistency * loss_cons

        return loss_total, pred1, mask1

    def forward_encoder_with_given_mask(self, x, given_patch_mask):
        """Forward encoder using a given patch-level mask.

        Args:
            x: (N, C, H, W)
            given_patch_mask: (N, L), 1 for masked patches, 0 for kept.

        Returns:
            x: encoded tokens with cls token (N, len_keep + 1, embed_dim)
            mask: (N, L), same as input
            ids_restore: (N, L), mapping for unshuffling

        NOTE(review): len_keep is derived from the sample with the MOST masked
        patches. If samples have different masked counts, some patches marked 0
        in the mask are also dropped from the encoder, yet the returned mask
        still reports them as kept.
        """
        x = self.patch_embed(x)  # (N, L, D)
        x = x + self.pos_embed[:, 1:, :]  # (N, L, D)
        N, L, D = x.shape

        noise = torch.rand(N, L, device=x.device)
        mask_float = given_patch_mask.float()
        # Offsetting masked entries by more than any noise value sorts them
        # after all kept entries; the noise randomizes order within each group.
        ids_shuffle = torch.argsort(mask_float * (noise.max() + 1) + noise, dim=1)  # (N, L)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        len_keep = L - given_patch_mask.sum(dim=1).max().int().item()
        ids_keep = ids_shuffle[:, :len_keep]

        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x_masked), dim=1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, given_patch_mask, ids_restore

    def forward_with_given_mask(self, imgs, given_patch_mask):
        """Encode, decode and score reconstruction with a caller-supplied mask.

        Returns:
            (loss, pred, mask)
        """
        latent, mask, ids_restore = self.forward_encoder_with_given_mask(imgs, given_patch_mask)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask


def mae_vit_base_patch16(**kwargs):
    """ViT-Base/16 MAE: 768-dim, 12-layer encoder; 512-dim, 8-layer decoder."""
    model = MaskedAutoEncoderViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
        mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model