# dofa-dinov3 / dofa_dinov3.py
# Uploaded to the Hugging Face Hub by earthflow with huggingface_hub (commit d64d4bd, verified).
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------
from functools import partial
import math
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # Add this import for F.pad
from timm.models.vision_transformer import VisionTransformer
from util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid_torch
import pdb
import timm
class TransformerWeightGenerator(nn.Module):
    """Generate per-channel convolution weights and a shared bias from channel embeddings.

    A transformer encoder runs over the sequence
    ``[weight tokens | channel embeddings | bias token]``. For each channel, the
    encoder output at its position plus the input embedding (a residual) is
    projected to ``output_dim`` weights. The encoder output at the bias-token
    position is projected to an ``embed_dim`` bias vector.

    Args:
        input_dim: Width of the channel embeddings; also the encoder ``d_model``.
        output_dim: Number of weights generated per channel.
        embed_dim: Length of the generated bias vector.
        num_heads: Attention heads per encoder layer (must divide ``input_dim``).
        num_layers: Number of stacked encoder layers.
        num_weight_tokens: Number of learned weight tokens prepended to the
            sequence. Defaults to 128, the value that used to be hard-coded.
    """

    def __init__(self, input_dim, output_dim, embed_dim, num_heads=4, num_layers=1,
                 num_weight_tokens=128):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            activation='gelu',
            norm_first=False,
            batch_first=False,
            dropout=0.0,  # a probability; previously passed as the bool ``False``
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False
        )
        # Linear heads mapping encoder outputs to weights / bias.
        self.fc_weight = nn.Linear(input_dim, output_dim)
        self.fc_bias = nn.Linear(input_dim, embed_dim)
        self.wt_num = num_weight_tokens
        self.weight_tokens = nn.Parameter(torch.empty([self.wt_num, input_dim]))
        self.bias_token = nn.Parameter(torch.empty([1, input_dim]))
        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.weight_tokens, std=.02)
        torch.nn.init.normal_(self.bias_token, std=.02)

    def forward(self, x):
        """Generate weights and bias.

        Args:
            x: 2-D tensor ``(num_channels, input_dim)``. It must be 2-D because it
                is concatenated along dim 0 with the 2-D learned tokens. The
                encoder therefore sees one unbatched sequence.

        Returns:
            ``(weights, bias)``, where ``weights`` has shape
            ``(num_channels, output_dim)`` and ``bias`` has shape ``(embed_dim,)``.
        """
        pos_wave = x
        tokens = torch.cat([self.weight_tokens, pos_wave, self.bias_token], dim=0)
        transformer_output = self.transformer_encoder(tokens)
        # Outputs at the channel positions sit between the weight tokens and the
        # trailing bias token; the input embeddings are added back as a residual.
        weights = self.fc_weight(transformer_output[self.wt_num:-1] + pos_wave)
        # The last position is the bias token.
        bias = self.fc_bias(transformer_output[-1])
        return weights, bias
class Basic1d(nn.Module):
    """Linear projection, an optional LayerNorm, then a ReLU.

    With ``bias=False`` the linear layer has no bias term, and a LayerNorm over
    ``out_channels`` is inserted before the activation. The stack is stored as
    ``self.conv`` with children named ``'0'`` (linear), ``'ln'`` (optional) and
    ``'relu'``.
    """

    def __init__(self, in_channels, out_channels, bias=True):
        super().__init__()
        stack = nn.Sequential(nn.Linear(in_channels, out_channels, bias=bias))
        if not bias:
            stack.add_module('ln', nn.LayerNorm(out_channels))
        stack.add_module('relu', nn.ReLU(inplace=True))
        self.conv = stack

    def forward(self, x):
        return self.conv(x)
class FCResLayer(nn.Module):
    """Residual MLP block: ``x + relu(w2(relu(w1(x))))``.

    Both linear layers are ``linear_size -> linear_size``, so the input and the
    output share the same last-dimension width.
    """

    def __init__(self, linear_size=128):
        super().__init__()
        self.l_size = linear_size
        # Registration order (nonlin1, nonlin2, w1, w2) is kept as-is: it
        # fixes the RNG order of the Linear inits and the state_dict layout.
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)
        self.w1 = nn.Linear(linear_size, linear_size)
        self.w2 = nn.Linear(linear_size, linear_size)

    def forward(self, x):
        hidden = self.nonlin1(self.w1(x))
        update = self.nonlin2(self.w2(hidden))
        return x + update
class Dynamic_MLP_OFA(nn.Module):
    """Wavelength-conditioned patch embedding.

    Each input channel is described by one wavelength. The wavelengths are
    embedded with a 1-D sin-cos embedding and refined by an ``FCResLayer``. A
    ``TransformerWeightGenerator`` then produces one
    ``kernel_size x kernel_size x embed_dim`` filter per channel, plus an
    ``embed_dim`` bias. Both are scaled by ``self.scaler`` and applied with
    ``F.conv2d`` using ``stride == kernel_size`` and no padding.

    Args:
        wv_planes: Width of the wavelength embedding.
        inter_dim: Stored as ``self.inter_dim``; not otherwise used by this class.
        kernel_size: Convolution kernel size, also used as the stride.
        embed_dim: Number of output channels.
    """

    def __init__(self, wv_planes, inter_dim=128, kernel_size=3, embed_dim=1024):
        super().__init__()
        self.kernel_size = kernel_size
        self.wv_planes = wv_planes
        self.embed_dim = embed_dim
        self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
        self.inter_dim = inter_dim
        self.patch_size = (kernel_size, kernel_size)
        self.weight_generator = TransformerWeightGenerator(wv_planes, self._num_kernel, embed_dim)
        self.scaler = 0.1
        self.fclayer = FCResLayer(wv_planes)
        self._init_weights()

    def _get_weights(self, waves):
        """Return ``(weights, bias)`` generated from the wavelength embeddings."""
        return self.weight_generator(waves)

    def weight_init(self, m):
        """Xavier-uniform weights and a 0.01 bias for plain ``nn.Linear`` modules.

        The exact type check means subclasses of ``nn.Linear`` are left untouched.
        """
        if type(m) is nn.Linear:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:  # bias-less Linear layers have ``bias = None``
                m.bias.data.fill_(0.01)

    def _init_weights(self):
        """Initialize the weight generator and the wavelength FC layer."""
        self.weight_generator.apply(self.weight_init)
        self.fclayer.apply(self.weight_init)

    def forward(self, img_feat, wvs):
        """Embed ``img_feat`` with a kernel generated from ``wvs``.

        Args:
            img_feat: Input to ``F.conv2d``. Presumably ``(B, C, H, W)`` with
                ``C == wvs.size(0)``; TODO confirm with callers.
            wvs: 1-D tensor with one wavelength per channel. It is multiplied by
                1000 before embedding (presumably micrometres -> nanometres;
                confirm).

        Returns:
            ``(x, waves)``: the convolution output with ``embed_dim`` channels,
            and the refined wavelength embeddings of shape ``(C, wv_planes)``.
        """
        inplanes = wvs.size(0)
        waves = get_1d_sincos_pos_embed_from_grid_torch(self.wv_planes, wvs * 1000)
        waves = self.fclayer(waves)
        weight, bias = self._get_weights(waves)

        # (C, k*k*embed_dim) -> (C, k, k, embed_dim) -> (embed_dim, C, k, k),
        # i.e. the (out_channels, in_channels, kH, kW) layout F.conv2d expects.
        dynamic_weight = weight.view(inplanes, self.kernel_size, self.kernel_size, self.embed_dim)
        dynamic_weight = dynamic_weight.permute([3, 0, 1, 2])
        if bias is not None:
            bias = bias.view([self.embed_dim]) * self.scaler
        weights = dynamic_weight * self.scaler

        x = F.conv2d(img_feat, weights, bias=bias, stride=self.kernel_size)
        return x, waves
class DOFAViT(nn.Module):
    """DOFA classifier: a wavelength-conditioned patch embedding in front of a timm DINOv3 ViT.

    The image is embedded by ``Dynamic_MLP_OFA``, whose kernel is generated from
    the per-channel wavelengths. The tokens are then run through all but the
    last transformer block of a timm DINOv3 backbone, pooled, and classified by
    a linear head.

    Args:
        img_size: Input size (an int, or a tuple whose first element is used).
            Only used to compute ``self.num_patches``.
        patch_size: Kernel size and stride of the dynamic patch embedding.
        drop_rate: Dropout probability applied before the classification head.
        out_indices: Stored as ``self.out_indices``; not used by this class.
        drop_path_rate, depth, num_heads, mlp_ratio: Accepted for interface
            compatibility. They do not configure the timm backbone.
        embed_dim: Token width. Also selects the default backbone.
        wv_planes: Width of the wavelength embedding in the patch embedding.
        num_classes: Head output size. ``<= 0`` gives an identity head.
        global_pool: If True, average the non-prefix tokens and apply
            ``fc_norm``. Otherwise return the first token (presumably the class
            token).
        norm_layer: Factory used to build ``fc_norm`` and ``norm``.
        model_name: timm model name for the backbone. When None, a DINOv3 ViT
            whose width matches ``embed_dim`` is chosen from
            ``_DINOV3_BACKBONES``. Unlisted widths fall back to ViT-L/16, the
            previous hard-coded choice.
    """

    # Default timm backbones keyed by token width, so the patch-embedding width
    # (embed_dim) matches the transformer blocks it feeds.
    _DINOV3_BACKBONES = {
        768: 'vit_base_patch16_dinov3.lvd1689m',
        1024: 'vit_large_patch16_dinov3.lvd1689m',
    }

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        drop_rate=0.0,
        out_indices=None,
        drop_path_rate=0.0,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        wv_planes=128,
        num_classes=45,
        global_pool=True,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        model_name=None,
    ):
        super().__init__()
        self.wv_planes = wv_planes
        self.out_indices = out_indices
        # Previously forced to True regardless of the argument.
        self.global_pool = global_pool
        # Built unconditionally so the parameter set (and checkpoint keys) does
        # not depend on global_pool.
        self.fc_norm = norm_layer(embed_dim)

        self.img_size = img_size[0] if isinstance(img_size, tuple) else img_size
        self.num_patches = (self.img_size // patch_size) ** 2
        # Previously hard-coded to wv_planes=128 / kernel_size=16.
        self.patch_embed = Dynamic_MLP_OFA(
            wv_planes=wv_planes, inter_dim=128, kernel_size=patch_size, embed_dim=embed_dim
        )
        if model_name is None:
            model_name = self._DINOV3_BACKBONES.get(embed_dim, 'vit_large_patch16_dinov3.lvd1689m')
        self.model = timm.create_model(model_name, pretrained=False)
        self.dynamic_img_size = True
        self.waves = None
        # Not used in forward; presumably kept for checkpoint compatibility — confirm.
        self.norm = norm_layer(embed_dim)
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, wave_list):
        """Return pooled features for image ``x`` with per-channel ``wave_list`` wavelengths."""
        # Autocast is disabled around the dynamic patch embedding.
        with torch.autocast("cuda", enabled=False):
            waves = torch.as_tensor(wave_list, dtype=torch.float32, device=x.device)
            x, _ = self.patch_embed(x, waves)

        # The backbone's _pos_embed is fed a channels-last patch grid. The grid
        # size is no longer pinned to 14x14, so it follows the input resolution.
        x = einops.rearrange(x, 'b c h w -> b h w c')
        x, rot_pos_embed = self.model._pos_embed(x)
        x = self.model.norm_pre(x)
        # Only blocks[:-1] are run: features come from the penultimate block.
        for blk in self.model.blocks[:-1]:
            x = blk(x, rope=rot_pos_embed)

        x = self.model.norm(x)
        if self.global_pool:
            # Global average pool that excludes the prefix (cls/register) tokens.
            x = x[:, self.model.num_prefix_tokens:, :].mean(dim=1)
            return self.fc_norm(x)
        return x[:, 0]

    def forward_head(self, x, pre_logits=False):
        """Apply this model's head dropout, then the linear head unless ``pre_logits``."""
        # Uses this model's head_drop (built from drop_rate) instead of the backbone's.
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x, wave_list):
        x = self.forward_features(x, wave_list)
        x = self.forward_head(x)
        return x
def vit_base_patch16(**kwargs):
    """Build a ``DOFAViT`` with ViT-Base/16 hyper-parameters.

    Extra keyword arguments are forwarded to ``DOFAViT``. Repeating one of the
    fixed arguments below raises ``TypeError``.

    NOTE(review): the timm backbone is created inside ``DOFAViT``, not from these
    values; confirm its width matches ``embed_dim=768``.
    """
    base_config = dict(
        out_indices=[4, 6, 10, 11],
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return DOFAViT(**base_config, **kwargs)
def vit_large_patch16(**kwargs):
    """Build a ``DOFAViT`` with ViT-Large/16 hyper-parameters.

    Extra keyword arguments are forwarded to ``DOFAViT``. Repeating one of the
    fixed arguments below raises ``TypeError``.
    """
    large_config = dict(
        out_indices=[5, 11, 17, 23],
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    return DOFAViT(**large_config, **kwargs)