|
|
|
|
|
|
| try:
|
| from timm.layers import resample_abs_pos_embed
|
| except ImportError as err:
|
| print("ImportError: {0}".format(err))
|
| import torch
|
| import torch.nn as nn
|
| from torch.utils.checkpoint import checkpoint
|
|
|
|
|
def make_vit_b16_backbone(
    model,
    encoder_feature_dims,
    encoder_feature_layer_ids,
    vit_features,
    start_index=1,
    use_grad_checkpointing=False,
) -> nn.Module:
    """Make a ViTb16 backbone for the DPT model.

    Args:
        model: The timm ViT model to wrap.
        encoder_feature_dims: Channel dimensions of the hooked encoder features.
        encoder_feature_layer_ids: Indices of the transformer blocks to hook.
        vit_features: Embedding dimension of the ViT trunk.
        start_index: Index of the first patch token (tokens before it are
            prefix/class tokens).
        use_grad_checkpointing: Enable gradient checkpointing on the trunk.

    Returns:
        An ``nn.Module`` wrapper exposing ``hooks``, ``model``, ``features``
        and ``vit_features`` attributes.
    """
    if use_grad_checkpointing:
        model.set_grad_checkpointing()

    backbone = nn.Module()
    backbone.hooks = encoder_feature_layer_ids
    backbone.model = model
    backbone.features = encoder_feature_dims
    backbone.vit_features = vit_features

    # Annotate the wrapped trunk and route its forward() through
    # forward_features so hooks see the token sequence, not the head output.
    trunk = backbone.model
    trunk.start_index = start_index
    trunk.patch_size = trunk.patch_embed.patch_size
    trunk.is_vit = True
    trunk.forward = trunk.forward_features

    return backbone
|
|
|
|
|
def forward_features_eva_fixed(self, x):
    """Encode features.

    Patched forward_features for timm EVA models: passes the rotary position
    embedding through to each block (also when gradient checkpointing).
    """
    tokens = self.patch_embed(x)
    tokens, rope = self._pos_embed(tokens)

    use_checkpoint = self.grad_checkpointing
    for block in self.blocks:
        if use_checkpoint:
            tokens = checkpoint(block, tokens, rope)
        else:
            tokens = block(tokens, rope)

    return self.norm(tokens)
|
|
|
|
|
def resize_vit(model: nn.Module, img_size) -> nn.Module:
    """Resample the ViT module to the given size."""
    patch = model.patch_embed.patch_size
    model.patch_embed.img_size = img_size
    new_grid = tuple(side // p for side, p in zip(img_size, patch))
    model.patch_embed.grid_size = new_grid

    # Prefix (class) tokens are excluded from the spatial resample unless the
    # model folds them into the patch grid (timm `no_embed_class` flag).
    prefix_tokens = (
        0 if getattr(model, "no_embed_class", False) else model.num_prefix_tokens
    )
    resampled = resample_abs_pos_embed(
        model.pos_embed,
        new_grid,
        num_prefix_tokens=prefix_tokens,
    )
    model.pos_embed = torch.nn.Parameter(resampled)

    return model
|
|
|
|
|
def resize_patch_embed(model: nn.Module, new_patch_size=(16, 16)) -> nn.Module:
    """Resample the ViT patch size to the given one.

    Bicubic-resamples the patch projection kernel to ``new_patch_size``,
    rebuilds ``model.patch_embed.proj`` with the matching kernel/stride, and
    rescales ``model.patch_embed.img_size`` proportionally so the token grid
    is unchanged.

    Args:
        model: A ViT-style module with a ``patch_embed`` (no-op otherwise).
        new_patch_size: Target (height, width) patch size.

    Returns:
        The same ``model``, modified in place.
    """
    # Guard clauses: nothing to do without a patch_embed or when the size
    # already matches.
    if not hasattr(model, "patch_embed"):
        return model
    old_patch_size = model.patch_embed.patch_size
    if (
        new_patch_size[0] == old_patch_size[0]
        and new_patch_size[1] == old_patch_size[1]
    ):
        return model

    patch_embed_proj = model.patch_embed.proj.weight
    patch_embed_proj_bias = model.patch_embed.proj.bias
    use_bias = patch_embed_proj_bias is not None
    _, _, h, w = patch_embed_proj.shape

    new_patch_embed_proj = torch.nn.functional.interpolate(
        patch_embed_proj,
        size=[new_patch_size[0], new_patch_size[1]],
        mode="bicubic",
        align_corners=False,
    )
    # Rescale so the response to a constant input is preserved after the
    # kernel area changed.
    new_patch_embed_proj = (
        new_patch_embed_proj * (h / new_patch_size[0]) * (w / new_patch_size[1])
    )

    model.patch_embed.proj = nn.Conv2d(
        in_channels=model.patch_embed.proj.in_channels,
        out_channels=model.patch_embed.proj.out_channels,
        kernel_size=new_patch_size,
        stride=new_patch_size,
        bias=use_bias,
    )

    if use_bias:
        # The bias is patch-size independent; reuse the original parameter.
        model.patch_embed.proj.bias = patch_embed_proj_bias

    model.patch_embed.proj.weight = torch.nn.Parameter(new_patch_embed_proj)

    model.patch_size = new_patch_size
    model.patch_embed.patch_size = new_patch_size
    # Shrink/grow the expected image size so grid_size (img / patch) stays
    # constant.
    model.patch_embed.img_size = (
        int(
            model.patch_embed.img_size[0]
            * new_patch_size[0]
            / old_patch_size[0]
        ),
        int(
            model.patch_embed.img_size[1]
            * new_patch_size[1]
            / old_patch_size[1]
        ),
    )

    return model
|
|
|