| |
| |
| |
| |
| |
| |
|
|
| from functools import partial |
| import math |
| import einops |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from timm.models.vision_transformer import VisionTransformer |
| from util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid_torch |
| import pdb |
|
|
| import timm |
|
|
class TransformerWeightGenerator(nn.Module):
    """Generate convolution weights and a bias vector from wavelength embeddings.

    A small Transformer encoder attends over learnable weight/bias query tokens
    concatenated with the wavelength embeddings. The encoder outputs at the
    wavelength positions are projected to per-channel kernel weights; the
    output at the trailing bias token is projected to the bias.

    Args:
        input_dim: dimensionality of each wavelength embedding (d_model).
        output_dim: size of the generated weight vector per wavelength.
        embed_dim: size of the generated bias vector.
        num_heads: attention heads of the encoder layer.
        num_layers: number of stacked encoder layers.
    """

    def __init__(self, input_dim, output_dim, embed_dim, num_heads=4, num_layers=1):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            activation='gelu',
            norm_first=False,
            batch_first=False,
            # fix: `dropout` expects a float probability; the original passed
            # the bool False, which only worked via implicit bool->0 coercion.
            dropout=0.0,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
        )

        # Projections from encoder outputs to the generated parameters.
        self.fc_weight = nn.Linear(input_dim, output_dim)
        self.fc_bias = nn.Linear(input_dim, embed_dim)
        self.wt_num = 128  # number of learnable weight query tokens
        self.weight_tokens = nn.Parameter(torch.empty([self.wt_num, input_dim]))
        self.bias_token = nn.Parameter(torch.empty([1, input_dim]))

        # std=.02 matches common ViT token initialization.
        torch.nn.init.normal_(self.weight_tokens, std=.02)
        torch.nn.init.normal_(self.bias_token, std=.02)

    def forward(self, x):
        """Map wavelength embeddings to (weights, bias).

        Args:
            x: (num_wavelengths, input_dim) wavelength embeddings.

        Returns:
            weights: (num_wavelengths, output_dim) generated weights.
            bias: (embed_dim,) generated bias.
        """
        pos_wave = x
        x = torch.cat([self.weight_tokens, pos_wave], dim=0)
        x = torch.cat([x, self.bias_token], dim=0)
        transformer_output = self.transformer_encoder(x)
        # Residual connection with the wavelength embeddings before projecting.
        weights = self.fc_weight(transformer_output[self.wt_num:-1] + pos_wave)
        bias = self.fc_bias(transformer_output[-1])
        return weights, bias
|
|
class Basic1d(nn.Module):
    """Linear layer; when ``bias`` is False, followed by LayerNorm and ReLU.

    The submodules live in ``self.conv`` (a Sequential) so checkpoints keep
    the keys '0', 'ln', 'relu'.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        linear = nn.Linear(in_channels, out_channels, bias)
        self.conv = nn.Sequential(linear)
        if bias:
            return
        # Bias-free variant gets normalization + activation appended.
        self.conv.add_module('ln', nn.LayerNorm(out_channels))
        self.conv.add_module('relu', nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)
|
|
class FCResLayer(nn.Module):
    """Fully-connected residual block: ``x + relu(w2(relu(w1(x))))``."""

    def __init__(self, linear_size=128):
        super().__init__()
        self.l_size = linear_size
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)
        self.w1 = nn.Linear(self.l_size, self.l_size)
        self.w2 = nn.Linear(self.l_size, self.l_size)

    def forward(self, x):
        # Two Linear+ReLU stages, then the skip connection.
        branch = self.nonlin2(self.w2(self.nonlin1(self.w1(x))))
        return x + branch
|
|
|
|
class Dynamic_MLP_OFA(nn.Module):
    """Patch-embedding layer whose conv kernel is generated from wavelengths.

    For each input spectral band, a Transformer weight generator conditioned on
    the band's (normalized) wavelength produces the patch-embedding convolution
    weights, so a single module can embed imagery with any set or number of
    bands.

    Args:
        wv_planes: dimensionality of the wavelength positional embedding.
        inter_dim: intermediate dim (stored; unused by the current forward path).
        kernel_size: patch size; the conv uses stride == kernel_size.
        embed_dim: output embedding dimension.
    """

    def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
        super().__init__()
        self.kernel_size = kernel_size
        self.wv_planes = wv_planes
        self.embed_dim = embed_dim
        # Number of generated kernel parameters per input channel.
        self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
        self.inter_dim = inter_dim
        self.patch_size = (kernel_size, kernel_size)

        self.weight_generator = TransformerWeightGenerator(wv_planes, self._num_kernel, embed_dim)
        self.scaler = 0.1  # damps generated weights/bias for stable training

        self.fclayer = FCResLayer(wv_planes)

        self._init_weights()

    def _get_weights(self, waves):
        """Run the weight generator on the wavelength embeddings.

        Returns the generator's (weights, bias) tuple unchanged.
        """
        return self.weight_generator(waves)

    def weight_init(self, m):
        # Xavier init for every Linear layer; small positive bias.
        if type(m) == nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)

    def _init_weights(self):
        """Initialize the weight generator and the wavelength MLP."""
        self.weight_generator.apply(self.weight_init)
        self.fclayer.apply(self.weight_init)

    def forward(self, img_feat, wvs):
        """Embed ``img_feat`` using kernels generated from wavelengths ``wvs``.

        Args:
            img_feat: (B, C, H, W) image batch.
            wvs: (C,) wavelengths, one per input channel.

        Returns:
            x: (B, embed_dim, H // kernel_size, W // kernel_size) embeddings.
            waves: (C, wv_planes) processed wavelength embeddings.
        """
        inplanes = wvs.size(0)
        # Wavelengths are scaled by 1000 before the sincos embedding --
        # presumably a um -> nm conversion; TODO confirm units with callers.
        waves = get_1d_sincos_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000)
        waves = self.fclayer(waves)
        weight, bias = self._get_weights(waves)

        # (C, k, k, embed_dim) -> (embed_dim, C, k, k) conv-weight layout.
        dynamic_weight = weight.view(inplanes, self.kernel_size, self.kernel_size, self.embed_dim)
        dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])

        if bias is not None:
            bias = bias.view([self.embed_dim]) * self.scaler

        weights = dynamic_weight * self.scaler

        # Non-overlapping patches: stride equals the kernel size.
        x = F.conv2d(img_feat, weights, bias=bias, stride=self.kernel_size)

        return x, waves
|
|
class DOFAViT(nn.Module):
    """DOFA classifier: dynamic wavelength-conditioned patch embedding feeding
    a timm DINOv3 ViT-L backbone.

    NOTE(review): ``depth``, ``num_heads``, ``mlp_ratio`` and ``out_indices``
    are accepted for interface compatibility but the backbone is always
    'vit_large_patch16_dinov3.lvd1689m' from timm -- when building the base
    variant, ``embed_dim=768`` will not match the ViT-L backbone; confirm the
    intended model name before use.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        drop_rate=0.0,
        out_indices=None,
        drop_path_rate=0.0,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        wv_planes=128,
        num_classes=45,
        global_pool=True,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        self.wv_planes = wv_planes
        self.out_indices = out_indices
        # NOTE(review): global pooling is forced on; the `global_pool`
        # argument is ignored, matching the original behavior.
        self.global_pool = True
        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim)

        self.img_size = img_size
        if isinstance(img_size, tuple):
            self.img_size = self.img_size[0]

        self.patch_size = patch_size
        self.num_patches = (self.img_size // patch_size) ** 2
        # fix: use `patch_size` for the dynamic embed kernel instead of a
        # hard-coded 16 (both factories pass 16, so defaults are unchanged).
        self.patch_embed = Dynamic_MLP_OFA(wv_planes=128, inter_dim=128, kernel_size=patch_size, embed_dim=embed_dim)
        self.model = timm.create_model('vit_large_patch16_dinov3.lvd1689m', pretrained=False)

        self.dynamic_img_size = True
        self.waves = None
        self.norm = norm_layer(embed_dim)

        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, wave_list):
        """Encode a batch with per-band wavelengths into a pooled feature.

        Args:
            x: (B, C, H, W) image batch.
            wave_list: list of C wavelengths matching the input channels.

        Returns:
            (B, embed_dim) pooled feature.
        """
        # The dynamic patch embed is kept in fp32 even under autocast.
        with torch.autocast("cuda", enabled=False):
            waves = torch.tensor(wave_list, device=x.device).float()
            x, _ = self.patch_embed(x, waves)
        # fix: derive the token grid from img_size/patch_size instead of the
        # hard-coded 14x14 (valid only for 224/16).
        grid = self.img_size // self.patch_size
        x = einops.rearrange(x, 'b c h w -> b h w c', h=grid, w=grid)
        x, rot_pos_embed = self.model._pos_embed(x)

        x = self.model.norm_pre(x)
        # The backbone's last transformer block is intentionally skipped;
        # features come from the block before it.
        for blk in self.model.blocks[:-1]:
            x = blk(x, rope=rot_pos_embed)

        if self.global_pool:
            x = self.model.norm(x)
            # Mean over patch tokens, excluding prefix (cls/register) tokens.
            x = x[:, self.model.num_prefix_tokens:, :].mean(dim=1)
            outcome = self.fc_norm(x)
        else:
            x = self.model.norm(x)
            outcome = x[:, 0]
        return outcome

    def forward_head(self, x, pre_logits=False):
        # fix: use this module's own head_drop (built from `drop_rate`) rather
        # than the timm backbone's dropout, which ignored the argument.
        # Identical at the default drop_rate=0.0.
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x, wave_list):
        x = self.forward_features(x, wave_list)
        x = self.forward_head(x)
        return x
|
|
|
|
def vit_base_patch16(**kwargs):
    """Build a DOFA ViT-Base/16 model (embed_dim 768, depth 12, 12 heads)."""
    return DOFAViT(
        out_indices=[4, 6, 10, 11],
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
|
|
|
|
def vit_large_patch16(**kwargs):
    """Build a DOFA ViT-Large/16 model (embed_dim 1024, depth 24, 16 heads)."""
    return DOFAViT(
        out_indices=[5, 11, 17, 23],
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
|
|