Spaces:
Running
on
Zero
Running
on
Zero
| # MIT License | |
| # | |
| # Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab) | |
| # | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE. | |
| # | |
| # Based on code from https://github.com/isl-org/DPT | |
| """Flexible configuration and feature extraction of timm VisionTransformers.""" | |
| import types | |
| import math | |
| from typing import Callable | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class AddReadout(nn.Module):
    """Fold the readout token(s) into every remaining token.

    The first ``start_index`` tokens along dim 1 are treated as readout tokens.
    When ``start_index == 2`` the readout vector is the mean of tokens 0 and 1;
    otherwise it is token 0 alone. The readout is broadcast-added to each token
    from ``start_index`` onward, and the leading readout tokens are dropped.

    Args:
        start_index: Index of the first non-readout token.
    """

    def __init__(self, start_index: int = 1):
        super().__init__()
        self.start_index = start_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x[:, start_index:]`` plus the readout vector.

        Presumably ``x`` is (batch, tokens, channels) — TODO confirm; the code
        itself only indexes dims 0 and 1.
        """
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        # unsqueeze(1) re-inserts the token axis so the readout broadcasts over tokens.
        return x[:, self.start_index:] + readout.unsqueeze(1)
class Transpose(nn.Module):
    """Module wrapper that swaps two dimensions and returns a contiguous tensor.

    Args:
        dim0: First dimension to swap.
        dim1: Second dimension to swap.
    """

    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # transpose only returns a strided view; contiguous() copies it into a dense layout.
        return torch.transpose(x, self.dim0, self.dim1).contiguous()
def forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict:
    """Run the wrapped ViT and return its hooked, rearranged activations.

    ``pretrained`` is expected to be the wrapper built by ``make_vit_backbone``
    or ``make_sd_backbone``. ``pretrained.model`` must provide ``forward_flex``,
    and ``pretrained.rearrange`` is applied to every captured activation.

    Args:
        pretrained: Backbone wrapper holding ``model`` and ``rearrange``.
        x: Input batch, passed straight to ``pretrained.model.forward_flex``.

    Returns:
        Mapping from hook name to the rearranged activation. The names are
        ``'0'``-``'3'``, plus ``'4'`` when the patch hook was registered.
    """
    # The forward hooks write into the module-level ``activations`` dict, which is
    # shared by every backbone built in this module. Clear it first so entries left
    # over from an earlier forward of another backbone are not returned here. The
    # clear is done in place, so the hooks keep writing into the same dict.
    activations.clear()
    # Only the hooks' side effects matter; the final normalized tokens are unused here.
    pretrained.model.forward_flex(x)
    return {k: pretrained.rearrange(v) for k, v in activations.items()}
| def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor: | |
| posemb_tok, posemb_grid = ( | |
| posemb[:, : self.start_index], | |
| posemb[0, self.start_index :], | |
| ) | |
| gs_old = int(math.sqrt(len(posemb_grid))) | |
| posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) | |
| posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False) | |
| posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) | |
| posemb = torch.cat([posemb_tok, posemb_grid], dim=1) | |
| return posemb | |
def forward_flex(self, x: torch.Tensor) -> torch.Tensor:
    """Run the ViT on an input of arbitrary spatial size.

    The steps are:

    1. Embed patches with ``self.patch_embed.proj``.
    2. Resize the position embedding to the resulting token grid via
       ``self._resize_pos_embed``.
    3. Prepend the class token.
    4. Pass the sequence through ``pos_drop``, every module in ``self.blocks``,
       and finally ``self.norm``.

    Args:
        self: The ViT instance this function is bound to via ``types.MethodType``.
        x: Image batch of shape ``(B, C, H, W)``.

    Returns:
        The normalized token sequence, with the class token first.

    Raises:
        ValueError: If ``x`` is not 4-dimensional.
    """
    if x.dim() != 4:
        raise ValueError(
            f"forward_flex expects a 4D (B, C, H, W) input, got shape {tuple(x.shape)}"
        )

    # Assumes proj is a 2D patch conv producing (B, D, grid_h, grid_w) — TODO confirm.
    # The grid size is read from the projection itself rather than recomputed from
    # patch_size. That way the resized position embedding always matches the real
    # token layout, including for non-square patches.
    x = self.patch_embed.proj(x)
    grid_h, grid_w = x.shape[-2:]
    x = x.flatten(2).transpose(1, 2)
    pos_embed = self._resize_pos_embed(self.pos_embed, grid_h, grid_w)

    # Prepend one class token per batch element.
    cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)

    x = self.pos_drop(x + pos_embed)
    for blk in self.blocks:
        x = blk(x)
    return self.norm(x)
# Shared store filled by the forward hooks that get_activation builds, keyed by hook name.
activations: dict = {}


def get_activation(name: str) -> Callable:
    """Build a forward hook that records a module's output under ``name``.

    Args:
        name: Key under which the output is stored in the module-level
            ``activations`` dict.

    Returns:
        A hook taking ``(module, inputs, output)``, suitable for
        ``register_forward_hook``. It returns ``None``, so the module's output
        is left unchanged.
    """

    def _store_output(module, inputs, output):
        activations[name] = output

    return _store_output
def make_sd_backbone(
    model: nn.Module,
    hooks: list[int] = [2, 5, 8, 11],
    hook_patch: bool = True,
    start_index: int = 1,
    patch_size: tuple[int, int] = (16, 16),
):
    """Wrap ``model`` for multi-level feature extraction, like ``make_vit_backbone``.

    The wiring is identical to ``make_vit_backbone``:

    - forward hooks on four of ``model.blocks``;
    - optionally a hook on ``model.pos_drop``;
    - the ``AddReadout`` + ``Transpose`` rearrange head;
    - the injected ``forward_flex`` / ``_resize_pos_embed`` methods.

    Only the parameter order differs. ``patch_size`` is a trailing keyword
    parameter so that existing positional calls keep working.

    Args:
        model: ViT-like module to wrap.
        hooks: Exactly four indices into ``model.blocks`` to tap.
        hook_patch: Also capture the output of ``model.pos_drop``.
        start_index: Number of leading readout tokens (see ``AddReadout``).
        patch_size: Stored on the model as ``model.patch_size``.

    Returns:
        The wrapper module produced by ``make_vit_backbone``.
    """
    return make_vit_backbone(
        model,
        patch_size=patch_size,
        hooks=hooks,
        hook_patch=hook_patch,
        start_index=start_index,
    )
def make_vit_backbone(
    model: nn.Module,
    patch_size: list[int] = [16, 16],
    hooks: list[int] = [2, 5, 8, 11],
    hook_patch: bool = True,
    start_index: int = 1,
):
    """Wrap a timm-style VisionTransformer for multi-level feature extraction.

    Captured activations go into the module-level ``activations`` dict:

    - the outputs of ``model.blocks[hooks[i]]`` are stored as ``str(i)``
      (``'0'``-``'3'``);
    - when ``hook_patch`` is set, the output of ``model.pos_drop`` is stored
      as ``'4'``.

    The function also attaches an ``AddReadout`` + ``Transpose(1, 2)``
    ``rearrange`` head. It binds ``forward_flex`` and ``_resize_pos_embed``
    onto this model instance so that it accepts inputs of arbitrary spatial size.

    Args:
        model: ViT exposing ``blocks``, ``pos_drop``, ``patch_embed``, ``pos_embed``,
            ``cls_token`` and ``norm``.
        patch_size: Stored on the model as ``model.patch_size``.
        hooks: Exactly four indices into ``model.blocks`` to tap.
        hook_patch: Also capture the output of ``model.pos_drop``.
        start_index: Number of leading readout tokens (see ``AddReadout``).

    Returns:
        An ``nn.Module`` with ``model`` (the patched ViT) and ``rearrange`` attributes.
    """
    assert len(hooks) == 4
    pretrained = nn.Module()
    pretrained.model = model

    # Register the hooks in order, so hooks[i] records under the key str(i).
    for slot, block_idx in enumerate(hooks):
        pretrained.model.blocks[block_idx].register_forward_hook(get_activation(str(slot)))
    if hook_patch:
        pretrained.model.pos_drop.register_forward_hook(get_activation('4'))

    # Configure the readout head and the attributes read by the injected methods.
    pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))
    pretrained.model.start_index = start_index
    pretrained.model.patch_size = patch_size

    # Bind the flexible forward onto this instance only. This gives interpolated
    # position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )
    return pretrained