# AstroM3 / ExampleCode / example1 / model / DesiEncoder.py
# Uploaded by lvjiameng ("Upload 21 files"), commit d24fe95 (verified).
import torch
import torch.nn as nn
from functools import partial
from timm.models.vision_transformer import PatchEmbed, Block
import torch.nn.functional as F
import numpy as np
import torch
import numpy as np
import skimage
import numpy as np
import cv2 as cv
def sdss_rgb(imgs, bands, scales=None,
             m=0.02):
    """
    Convert per-band images (nanomaggies) to the RGB rendering shown by the
    Legacy Surveys viewer https://www.legacysurvey.org/viewer

    Adapted from
    https://github.com/legacysurvey/imagine/blob/master/map/views.py

    Args:
        imgs: iterable of 2-D arrays, one per entry of ``bands``, expected to
            share a common (H, W) shape. It is iterated twice, so pass a
            list/array rather than a one-shot generator.
        bands: band letters paired with ``imgs`` (defaults exist for
            'u', 'g', 'r', 'i', 'z').
        scales: optional ``{band: (plane, scale)}`` overrides of the defaults.
        m: offset added to every scaled band before stretching.

    Returns:
        float32 array of shape (H, W, 3) with values clipped to [0, 1].
    """
    # band -> (output RGB plane index, multiplicative scale)
    band_table = {
        'u': (2, 1.5),  # 1.0,
        'g': (2, 2.5),
        'r': (1, 1.5),
        'i': (0, 1.0),
        'z': (0, 0.4),  # 0.3
    }
    if scales is not None:
        band_table.update(scales)

    # Average of the scaled, offset and zero-clipped bands.
    intensity = 0
    for frame, name in zip(imgs, bands):
        _, gain = band_table[name]
        intensity = intensity + np.maximum(0, frame * gain + m)
    intensity /= len(bands)

    # arcsinh stretch of the mean intensity (softening parameter Q = 20).
    softening = 20
    stretched = np.arcsinh(softening * intensity) / np.sqrt(softening)
    # Guard the division below; deliberately applied after the stretch.
    intensity += (intensity == 0.) * 1e-6

    height, width = intensity.shape
    rgb = np.zeros((height, width, 3), np.float32)
    for frame, name in zip(imgs, bands):
        channel, gain = band_table[name]
        # Each band keeps its ratio to the mean but takes the stretched
        # brightness; bands sharing a plane overwrite it (last one wins).
        rgb[:, :, channel] = (frame * gain + m) * stretched / intensity
    return np.clip(rgb, 0, 1)
def dr2_rgb(rimgs, bands, **ignored):
    """
    Render band images through ``sdss_rgb`` with the DR2 colour settings.

    The 'g', 'r' and 'z' entries get the (plane, scale) overrides below and the
    softening offset is ``m=0.03``; ``bands`` is forwarded unchanged. Any extra
    keyword arguments are accepted and ignored.
    """
    dr2_scales = {'g': (2, 6.0), 'r': (1, 3.4), 'z': (0, 2.2)}
    return sdss_rgb(rimgs, bands, scales=dr2_scales, m=0.03)
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build fixed 2-D sine-cosine position embeddings for a square patch grid.

    Args:
        embed_dim: embedding width; must be divisible by 4 (half of it goes to
            each grid axis, and each half is split again into sin/cos).
        grid_size: number of patches along the grid height and width.
        cls_token: if True, prepend an all-zero row for the class token.

    Returns:
        array of shape [grid_size*grid_size, embed_dim], or
        [1+grid_size*grid_size, embed_dim] when ``cls_token`` is True.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # Height and width share the same coordinate range; meshgrid puts the
    # width coordinate first.
    grid = np.stack(np.meshgrid(coords, coords), axis=0).reshape([2, 1, grid_size, grid_size])
    embedding = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return embedding
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, embedding], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Concatenate 1-D sine-cosine embeddings of the two grid coordinate planes.

    Args:
        embed_dim: total output width; must be even here, and
            ``embed_dim // 2`` must also be even for the 1-D helper.
        grid: array whose ``grid[0]`` and ``grid[1]`` hold the two coordinates
            of every position (e.g. shape (2, 1, H, W)).

    Returns:
        (H*W, embed_dim) array: the first half of the channels encodes
        ``grid[0]``, the second half encodes ``grid[1]``.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[k]) for k in (0, 1)]
    return np.concatenate(per_axis, axis=1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine-cosine embedding of a set of scalar positions.

    Args:
        embed_dim: output width D for each position; must be even.
        pos: array of positions, flattened to shape (M,).

    Returns:
        (M, D) array whose columns [0, D/2) are sin(pos * omega) and columns
        [D/2, D) are cos(pos * omega), with omega_k = 1 / 10000**(2k / D).
    """
    assert embed_dim % 2 == 0
    # Geometric ladder of frequencies, one per sin/cos pair.
    exponents = np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2.)
    freqs = 1. / 10000 ** exponents  # (D/2,)
    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    """
    Resize a checkpoint's ``pos_embed`` in place to fit ``model``'s patch grid.

    The leading extra tokens (the entries of ``model.pos_embed`` beyond
    ``model.patch_embed.num_patches``, e.g. class/dist tokens) are copied
    unchanged. The remaining patch embeddings are treated as a square grid and
    resized with bicubic interpolation. Nothing happens when the key is absent
    or when the grid side lengths already match.

    Args:
        model: object exposing ``patch_embed.num_patches`` and ``pos_embed``.
        checkpoint_model: state-dict-like mapping; updated in place.
    """
    if 'pos_embed' not in checkpoint_model:
        return
    ckpt_pos = checkpoint_model['pos_embed']
    width = ckpt_pos.shape[-1]
    num_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - num_patches
    # Side lengths of the (assumed square) checkpoint and model patch grids.
    old_side = int((ckpt_pos.shape[-2] - n_extra) ** 0.5)
    new_side = int(num_patches ** 0.5)
    if old_side == new_side:
        return

    print("Position interpolate from %dx%d to %dx%d" % (old_side, old_side, new_side, new_side))
    extra_tokens = ckpt_pos[:, :n_extra]
    # (B, S*S, C) -> (B, C, S, S) so interpolate sees a channels-first image.
    grid = ckpt_pos[:, n_extra:].reshape(-1, old_side, old_side, width).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid, size=(new_side, new_side), mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, grid), dim=1)
class MaskedAutoEncoderViT(nn.Module):
    """Masked Autoencoder with a VisionTransformer backbone.

    ``forward`` encodes each batch twice with independent random masks. The
    loss is the mean of the two MAE reconstruction losses plus
    ``lambda_consistency`` times a cosine consistency loss between the two
    encoder outputs.

    ``patchify``/``unpatchify`` take the channel count from the tensors they
    receive, so inputs with any ``in_chans`` are supported (the decoder head
    predicts ``patch_size**2 * in_chans`` values per patch).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4.0, norm_layer=nn.LayerNorm, norm_pix_loss=False, lambda_consistency=1.0):
        """
        Args:
            img_size, patch_size, in_chans: input geometry for the patch embedding.
            embed_dim, depth, num_heads: encoder width, number of blocks, heads.
            decoder_embed_dim, decoder_depth, decoder_num_heads: same for the decoder.
            mlp_ratio: hidden-size ratio of the transformer MLPs.
            norm_layer: normalization layer factory.
            norm_pix_loss: if True, ``forward_loss`` normalizes targets per patch.
            lambda_consistency: weight of the consistency term in ``forward``.
        """
        super().__init__()

        # ----------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed sin-cos embedding with one extra slot for the cls token;
        # filled in initialize_weights and never trained.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)

        # ----------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # Fixed sin-cos embedding for the decoder, same layout as pos_embed.
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        # Decoder -> per-patch pixel predictions.
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)
        # ----------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss
        self.lambda_consistency = lambda_consistency

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the fixed sin-cos position embeddings and initialize parameters."""
        grid_size = int(self.patch_embed.num_patches ** .5)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], grid_size, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # Initialize the patch projection like an nn.Linear: xavier on the
        # weight flattened to (out_features, -1).
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        # Linear and LayerNorm layers.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm reset to identity."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort of random noise.

        Args:
            x: [N, L, D] token sequence.
            mask_ratio: fraction of tokens to drop.

        Returns:
            x_masked: [N, len_keep, D] kept tokens, len_keep = int(L * (1 - mask_ratio)).
            mask: [N, L] with 0 for kept and 1 for removed tokens (original order).
            ids_restore: [N, L] permutation that undoes the shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # Ascending sort: small noise is kept, large noise is removed.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Binary mask in shuffled order, then unshuffled back to token order.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def patchify(self, imgs):
        """
        Split images into flattened non-overlapping patches.

        Args:
            imgs: (N, C, H, W) with H == W and H divisible by the patch size.

        Returns:
            (N, L, patch_size**2 * C) with L = (H // patch_size)**2; each patch
            is flattened in (row, col, channel) order.
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        c = imgs.shape[1]
        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * c))
        return x

    def unpatchify(self, x):
        """
        Inverse of ``patchify``.

        Args:
            x: (N, L, patch_size**2 * C) with L a perfect square.

        Returns:
            imgs: (N, C, H, W) with H = W = sqrt(L) * patch_size.
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        c = x.shape[2] // (p * p)
        assert c * p * p == x.shape[2]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward_encoder(self, x, mask_ratio):
        """
        Embed, randomly mask and encode a batch of images.

        Returns:
            x: [N, 1 + len_keep, embed_dim] encoder output (cls token first).
            mask, ids_restore: as returned by ``random_masking``.
        """
        # Embed patches.
        x = self.patch_embed(x)

        # Add position embedding without the cls slot.
        x = x + self.pos_embed[:, 1:, :]

        # Masking: length -> length * (1 - mask_ratio).
        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # Prepend the cls token (with its own position embedding).
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Transformer blocks.
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_decoder(self, x, ids_restore):
        """
        Reconstruct every patch from the encoding of the visible ones.

        Args:
            x: encoder output [N, 1 + len_keep, embed_dim] (cls token first).
            ids_restore: [N, L] permutation that undoes the masking shuffle.

        Returns:
            [N, L, patch_size**2 * in_chans] per-patch predictions (cls removed).
        """
        # Embed tokens.
        x = self.decoder_embed(x)

        # Pad the visible tokens with mask tokens up to L, then unshuffle them
        # back into the original patch order.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # prepend cls token

        # Add position embedding.
        x = x + self.decoder_pos_embed

        # Transformer blocks.
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # Predictor projection.
        x = self.decoder_pred(x)

        # Remove the cls token.
        x = x[:, 1:, :]

        return x

    def forward_loss(self, imgs, pred, mask):
        """
        Mean squared reconstruction error over the removed patches.

        Args:
            imgs: [N, C, H, W]
            pred: [N, L, p*p*C]
            mask: [N, L], 0 is keep, 1 is remove.

        Returns:
            Scalar loss: per-patch MSE averaged over patches where mask == 1.
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def consistency_loss(self, latent1, latent2, use_cls=False):
        """
        Cosine consistency loss between two encoder outputs.

        Args:
            latent1, latent2: [N, 1 + len_keep, D] encoder outputs (cls token at
                index 0).
            use_cls: if True, compare the cls tokens; if False, compare the mean
                of the remaining (patch) token features.

        Returns:
            Scalar ``2 - 2 * mean cosine similarity`` (0 when identical).
        """
        if use_cls:
            z1 = latent1[:, 0]  # [N, D] cls token
            z2 = latent2[:, 0]
        else:
            z1 = latent1[:, 1:].mean(dim=1)  # [N, D] mean over all patch features
            z2 = latent2[:, 1:].mean(dim=1)

        z1 = F.normalize(z1, dim=-1)
        z2 = F.normalize(z2, dim=-1)

        loss = 2 - 2 * (z1 * z2).sum(dim=-1).mean()
        return loss

    def forward(self, imgs, mask_ratio=0.75):
        """
        Two independently masked MAE passes plus a consistency term.

        Returns:
            (loss_total, pred1, mask1): the combined loss, and the prediction and
            mask of the first pass.
        """
        # --- first masking ---
        latent1, mask1, ids_restore1 = self.forward_encoder(imgs, mask_ratio)
        pred1 = self.forward_decoder(latent1, ids_restore1)
        loss_recon1 = self.forward_loss(imgs, pred1, mask1)

        # --- second masking ---
        latent2, mask2, ids_restore2 = self.forward_encoder(imgs, mask_ratio)
        pred2 = self.forward_decoder(latent2, ids_restore2)
        loss_recon2 = self.forward_loss(imgs, pred2, mask2)

        # --- consistency loss between the two views ---
        loss_cons = self.consistency_loss(latent1, latent2)

        # --- total loss ---
        loss_total = (loss_recon1 + loss_recon2) / 2 + self.lambda_consistency * loss_cons

        return loss_total, pred1, mask1

    def forward_encoder_with_given_mask(self, x, given_patch_mask):
        """
        Forward the encoder using a given patch-level mask.

        Args:
            x: (N, C, H, W)
            given_patch_mask: (N, L), 1 for masked patches, 0 for kept.

        Returns:
            x: encoded tokens with cls token (N, len_keep + 1, embed_dim)
            mask: (N, L), the input mask returned unchanged
            ids_restore: (N, L), mapping for unshuffling

        NOTE: every sample keeps the same number of tokens,
        len_keep = L - (largest masked count in the batch). For samples with
        fewer masked patches, some patches marked 0 are also dropped (chosen by
        the random noise). The decoder fills them with mask tokens, but since
        the returned mask is the input mask, they are excluded from
        ``forward_loss``.
        """
        x = self.patch_embed(x)  # (N, L, D)
        x = x + self.pos_embed[:, 1:, :]  # (N, L, D)
        N, L, D = x.shape

        # Sort key: with a 0/1 mask, masked patches get a key >= noise.max() + 1,
        # above every kept patch's key, so kept patches come first. The noise
        # randomizes the order within each group.
        noise = torch.rand(N, L, device=x.device)
        mask_float = given_patch_mask.float()
        ids_shuffle = torch.argsort(mask_float * (noise.max() + 1) + noise, dim=1)  # (N, L)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        len_keep = L - given_patch_mask.sum(dim=1).max().int().item()
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x_masked), dim=1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, given_patch_mask, ids_restore

    def forward_with_given_mask(self, imgs, given_patch_mask):
        """
        Single MAE pass using a caller-provided patch mask (1 = masked).

        Returns:
            (loss, pred, mask) as in ``forward_loss`` / ``forward_decoder``.
        """
        latent, mask, ids_restore = self.forward_encoder_with_given_mask(imgs, given_patch_mask)
        pred = self.forward_decoder(latent, ids_restore)
        loss = self.forward_loss(imgs, pred, mask)
        return loss, pred, mask
def mae_vit_base_patch16(**kwargs):
    """
    Build the base/16 MAE configuration.

    The encoder is 768-d with 12 blocks and 12 heads. The decoder is 512-d with
    8 blocks and 16 heads. The MLP ratio is 4 and layers use LayerNorm(eps=1e-6).

    Other keyword arguments (e.g. ``img_size``, ``in_chans``, ``norm_pix_loss``,
    ``lambda_consistency``) are forwarded to ``MaskedAutoEncoderViT``. Passing
    one of the fixed settings above raises ``TypeError`` (duplicate keyword).
    """
    layer_norm_eps6 = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoEncoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm_eps6,
        **kwargs,
    )