import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import PatchEmbed, Block
|
|
def sdss_rgb(imgs, bands, scales=None, m=0.02):
    """
    Transformation from raw image data (nanomaggies) to the RGB values displayed
    at the legacy viewer https://www.legacysurvey.org/viewer

    Code copied from
    https://github.com/legacysurvey/imagine/blob/master/map/views.py
    """
    # Per-band (output RGB plane index, multiplicative scale) pairs.
    rgbscales = {'u': (2, 1.5),
                 'g': (2, 2.5),
                 'r': (1, 1.5),
                 'i': (0, 1.0),
                 'z': (0, 0.4),
                 }
    if scales is not None:
        rgbscales.update(scales)

    # Mean scaled intensity across bands, clipped at zero.
    I = 0
    for img, band in zip(imgs, bands):
        plane, scale = rgbscales[band]
        img = np.maximum(0, img * scale + m)
        I = I + img
    I /= len(bands)

    # Arcsinh (Lupton-style) stretch of the total intensity; the small epsilon
    # guards against division by zero in the per-band modulation below.
    Q = 20
    fI = np.arcsinh(Q * I) / np.sqrt(Q)
    I += (I == 0.) * 1e-6
    H, W = I.shape
    rgb = np.zeros((H, W, 3), np.float32)

    # Each band fills its RGB plane, modulated by the stretched intensity.
    for img, band in zip(imgs, bands):
        plane, scale = rgbscales[band]
        rgb[:, :, plane] = (img * scale + m) * fI / I

    rgb = np.clip(rgb, 0, 1)
    return rgb
|
|
def dr2_rgb(rimgs, bands, **ignored):
    # grz scaling used for DECaLS / Legacy Surveys DR2-style image cutouts
    return sdss_rgb(rimgs, bands, scales=dict(g=(2, 6.0), r=(1, 3.4), z=(0, 2.2)), m=0.03)
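

# Illustrative usage sketch (not part of the original module): turn a stack of
# raw g/r/z cutouts into a displayable RGB array. The random draws below are a
# stand-in for real survey data in nanomaggies.
def _demo_dr2_rgb():
    rng = np.random.default_rng(0)
    g, r, z = rng.exponential(scale=0.05, size=(3, 152, 152)).astype(np.float32)
    rgb = dr2_rgb([g, r, z], ['g', 'r', 'z'])
    assert rgb.shape == (152, 152, 3)
    assert 0.0 <= rgb.min() and rgb.max() <= 1.0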
|
|
|
|
| |
| |
| |
| |
| |
| |
| def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): |
| """ |
| grid_size: int of the grid height and width |
| return: |
| pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) |
| """ |
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
| pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) |
| if cls_token: |
| pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) |
| return pos_embed |
|
|
|
|
| def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): |
| assert embed_dim % 2 == 0 |
|
|
| |
    # use half of the dimensions to encode grid_h, the other half for grid_w
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb
|
|
|
|
| def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): |
| """ |
| embed_dim: output dimension for each position |
| pos: a list of positions to be encoded: size (M,) |
| out: (M, D) |
| """ |
| assert embed_dim % 2 == 0 |
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,) geometric frequency ladder

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
| return emb |
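

# Shape sanity check (illustrative, not part of the original module): a 14x14
# patch grid with a cls slot gives a (1 + 196, 768) table, matching the frozen
# `pos_embed` buffer built in `MaskedAutoEncoderViT.initialize_weights` below.
def _demo_pos_embed_shapes():
    pe = get_2d_sincos_pos_embed(768, 14, cls_token=True)
    assert pe.shape == (1 + 14 * 14, 768)
    assert np.allclose(pe[0], 0.0)  # the cls slot is left at zero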
|
|
|
|
| |
| |
| |
| |
| |
| def interpolate_pos_embed(model, checkpoint_model): |
| if 'pos_embed' in checkpoint_model: |
| pos_embed_checkpoint = checkpoint_model['pos_embed'] |
| embedding_size = pos_embed_checkpoint.shape[-1] |
| num_patches = model.patch_embed.num_patches |
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) of the checkpoint's position-embedding grid
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) of the new position-embedding grid
        new_size = int(num_patches ** 0.5)
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            # the class (and any extra) tokens are kept unchanged
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are bicubically interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
| pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) |
| pos_tokens = torch.nn.functional.interpolate( |
| pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) |
| pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) |
| new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) |
| checkpoint_model['pos_embed'] = new_pos_embed |
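

# Illustrative fine-tuning sketch (the checkpoint path and 'model' key are
# assumptions, not from the original file): resize a pretrained checkpoint's
# position embeddings before loading it into a model built for a different
# input resolution.
def _demo_load_with_interpolation(model, ckpt_path='mae_pretrain.pth'):
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    checkpoint_model = checkpoint.get('model', checkpoint)
    interpolate_pos_embed(model, checkpoint_model)
    msg = model.load_state_dict(checkpoint_model, strict=False)
    print(msg)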
|
|
class MaskedAutoEncoderViT(nn.Module):
    """ Masked Autoencoder with a VisionTransformer backbone, extended with a
    dual-view latent consistency loss (see `forward`).
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4.0, norm_layer=nn.LayerNorm, norm_pix_loss=False,
                 lambda_consistency=1.0):
        super().__init__()

        # MAE encoder specifics
        self.in_chans = in_chans
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches
|
|
| self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
|
|
| self.blocks = nn.ModuleList([ |
| Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) |
| for i in range(depth)]) |
| self.norm = norm_layer(embed_dim) |
| |
|
|
| |
| |
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
|
|
| self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) |
|
|
| self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) |
|
|
| self.decoder_blocks = nn.ModuleList([ |
| Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) |
| for i in range(decoder_depth) |
| ]) |
|
|
| self.decoder_norm = norm_layer(decoder_embed_dim) |
| self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) |
| |
|
|
| self.norm_pix_loss = norm_pix_loss |
| self.lambda_consistency = lambda_consistency |
|
|
| self.initialize_weights() |
| |
    def initialize_weights(self):
        # initialize (and freeze) pos_embed with fixed 2D sin-cos embeddings
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)
| |
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # xavier_uniform following the official JAX ViT
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
| |
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort of random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1)

        # sort noise for each sample: ascending, so small values are kept
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0

        # unshuffle to get the binary mask in the original patch order
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
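    # Worked example of the argsort trick (hypothetical numbers, added for
    # illustration): with L=4, noise=[0.9, 0.1, 0.4, 0.7] gives
    # ids_shuffle=[1, 2, 3, 0] and ids_restore=[3, 0, 1, 2]; at mask_ratio=0.5
    # patches 1 and 2 are kept, and gathering [0, 0, 1, 1] with ids_restore
    # yields the mask [1, 0, 0, 1] in the original patch order.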
| |
    def patchify(self, imgs):
        """
        imgs: (N, C, H, W)
        x: (N, L, patch_size**2 * C)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        c = self.in_chans
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * c))
        return x
| |
    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        p = self.patch_embed.patch_size[0]
        c = self.in_chans
        h = w = int(x.shape[1]**0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
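    # Round-trip sanity check (illustrative, not in the original file):
    #   imgs = torch.randn(2, 3, 224, 224)
    #   assert torch.allclose(model.unpatchify(model.patchify(imgs)), imgs)
    # holds for any square input whose side is a multiple of the patch size.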
| |
    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * (1 - mask_ratio)
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore
| |
    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens and unshuffle back to the original patch order
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # re-append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        # remove cls token
        x = x[:, 1:, :]

        return x
| |
    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            # normalize each target patch to zero mean and unit variance
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches only
        return loss
|
|
|
|
    def consistency_loss(self, latent1, latent2, use_cls=False):
        """
        latent1, latent2: [N, 1 + len_keep, D] encoder outputs from two masked views
        use_cls: whether to use the cls token; if False, use the mean of the
            patch features instead.
        """
        if use_cls:
            z1 = latent1[:, 0]
            z2 = latent2[:, 0]
        else:
            z1 = latent1[:, 1:].mean(dim=1)
            z2 = latent2[:, 1:].mean(dim=1)

        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)

        loss = 2 - 2 * (z1 * z2).sum(dim=-1).mean()
        return loss
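    # Note, added for clarity: for L2-normalized z1 and z2,
    # 2 - 2 * (z1 * z2).sum(-1) == ||z1 - z2||^2, so this is the mean squared
    # distance between the two views' normalized features (a BYOL-style
    # objective); it is 0 for identical features and 4 for opposite ones.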
|
|
    def forward(self, imgs, mask_ratio=0.75):
        # view 1: random mask, encode, reconstruct
        latent1, mask1, ids_restore1 = self.forward_encoder(imgs, mask_ratio)
        pred1 = self.forward_decoder(latent1, ids_restore1)
        loss_recon1 = self.forward_loss(imgs, pred1, mask1)

        # view 2: an independent random mask of the same images
        latent2, mask2, ids_restore2 = self.forward_encoder(imgs, mask_ratio)
        pred2 = self.forward_decoder(latent2, ids_restore2)
        loss_recon2 = self.forward_loss(imgs, pred2, mask2)

        # pull the two views' global features together
        loss_cons = self.consistency_loss(latent1, latent2)

        loss_total = (loss_recon1 + loss_recon2) / 2 + self.lambda_consistency * loss_cons

        return loss_total, pred1, mask1
| |
|
|
    def forward_encoder_with_given_mask(self, x, given_patch_mask):
        """
        Forward encoder using a given patch-level mask.

        Args:
            x: (N, 3, H, W)
            given_patch_mask: (N, L), 1 for masked patches, 0 for kept

        Returns:
            x: encoded tokens with cls token (N, len_keep + 1, embed_dim)
            mask: (N, L), same as input
            ids_restore: (N, L), mapping for unshuffling
        """
        x = self.patch_embed(x)
        x = x + self.pos_embed[:, 1:, :]

        N, L, D = x.shape

        # Sort key: masked patches get an offset larger than any noise value,
        # so argsort places all kept patches first; the noise only breaks ties
        # within the kept and masked groups.
        noise = torch.rand(N, L, device=x.device)
        mask_float = given_patch_mask.float()
        ids_shuffle = torch.argsort(mask_float * (noise.max() + 1) + noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the same number of patches in every sample; with per-sample
        # mask counts this is set by the most heavily masked sample.
        len_keep = L - given_patch_mask.sum(dim=1).max().int().item()
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # append cls token and run the encoder as usual
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x_masked), dim=1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, given_patch_mask, ids_restore
| |
| def forward_with_given_mask(self, imgs, given_patch_mask): |
|
|
| latent, mask, ids_restore = self.forward_encoder_with_given_mask(imgs, given_patch_mask) |
| pred = self.forward_decoder(latent, ids_restore) |
| loss = self.forward_loss(imgs, pred, mask) |
| return loss, pred, mask |
| |
| |
|
|
|
|
| def mae_vit_base_patch16(**kwargs): |
| model = MaskedAutoEncoderViT( |
| patch_size=16, embed_dim=768, depth=12, num_heads=12, |
| decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, |
| mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) |
| return model |
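

# Illustrative end-to-end sketch (not part of the original file): one training
# step of the dual-view MAE on random images, plus a pass with a hand-built
# deterministic patch mask. Shapes assume the base/16 config at 224x224
# (14 * 14 = 196 patches).
def _demo_training_step():
    model = mae_vit_base_patch16(norm_pix_loss=True, lambda_consistency=1.0)
    imgs = torch.randn(2, 3, 224, 224)

    # two independently masked views; loss mixes reconstruction + consistency
    loss, pred, mask = model(imgs, mask_ratio=0.75)
    assert pred.shape == (2, 196, 16 * 16 * 3)
    loss.backward()

    # deterministic mask: hide the first half of the patches in every sample
    given = torch.zeros(2, 196)
    given[:, :98] = 1
    loss_fixed, pred_fixed, _ = model.forward_with_given_mask(imgs, given)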
|
|
|
|