Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from timm.models.vision_transformer import _load_weights | |
| from timm.models.layers import trunc_normal_ | |
| from typing import List | |
| from src.models.vit.utils import init_weights, resize_pos_embed | |
| from src.models.vit.blocks import Block | |
| from src.models.vit.decoder import DecoderLinear | |
class PatchEmbedding(nn.Module):
    """Split images into non-overlapping square patches and linearly embed each patch.

    The projection is a ``Conv2d`` whose kernel size and stride both equal
    ``patch_size``, so every output position corresponds to exactly one patch.

    Args:
        image_size: ``(height, width)`` the module is configured for, or a single
            int as shorthand for a square image.
        patch_size: Side length of each square patch, in pixels.
        embed_dim: Embedding dimension produced for every patch.
        channels: Number of input image channels.

    Raises:
        ValueError: If either image dimension is not divisible by ``patch_size``.
    """

    def __init__(self, image_size, patch_size, embed_dim, channels):
        super().__init__()
        # A bare int used to fail on image_size[0]; treat it as a square size.
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size
        if image_size[0] % patch_size != 0 or image_size[1] % patch_size != 0:
            raise ValueError("image dimensions must be divisible by the patch size")
        self.grid_size = image_size[0] // patch_size, image_size[1] // patch_size
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.patch_size = patch_size
        self.proj = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, im):
        """Embed a batch of images.

        Args:
            im: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, (H // patch_size) * (W // patch_size), embed_dim)``.

        Raises:
            ValueError: If ``im`` is not 4-dimensional.
        """
        # The original unpacked im.shape into four (unused) names, which
        # implicitly rejected non-4D input with a ValueError; keep that contract.
        if im.dim() != 4:
            raise ValueError(f"expected a 4D (B, C, H, W) input, got shape {tuple(im.shape)}")
        # (B, E, H/P, W/P) -> (B, E, N) -> (B, N, E)
        return self.proj(im).flatten(2).transpose(1, 2)
class VisionTransformer(nn.Module):
    """Vision Transformer encoder with an optional distillation token.

    Images are split into patches by ``PatchEmbedding``, prefixed with a
    learnable class token (plus a distillation token when ``distilled``), summed
    with learned positional embeddings and passed through ``n_layers`` ``Block``
    modules. ``forward`` returns either classification logits or the outputs of
    selected intermediate blocks.

    Args:
        image_size: ``(height, width)`` the positional embedding table is built
            for; other input sizes are handled by resizing it at run time.
        patch_size: Patch side length in pixels.
        n_layers: Number of transformer blocks.
        d_model: Token embedding dimension.
        d_ff: Hidden size passed to every ``Block`` (presumably its MLP width).
        n_heads: Attention heads passed to every ``Block``.
        n_cls: Output size of the classification head(s).
        dropout: Dropout applied after adding positional embeddings in
            ``forward``; also passed to every ``Block``.
        drop_path_rate: Upper end of the per-block rates, which are spaced
            linearly from 0 to this value across the blocks.
        distilled: If True, add a distillation token and a second head whose
            logits are averaged with the class-token head's.
        channels: Number of input image channels.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        n_layers,
        d_model,
        d_ff,
        n_heads,
        n_cls,
        dropout=0.1,
        drop_path_rate=0.0,
        distilled=False,
        channels=3,
    ):
        super().__init__()
        self.patch_embed = PatchEmbedding(
            image_size,
            patch_size,
            d_model,
            channels,
        )
        self.patch_size = patch_size
        self.n_layers = n_layers
        self.d_model = d_model
        self.d_ff = d_ff
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.n_cls = n_cls

        # cls (and optional distillation) tokens; the positional table has one
        # row per patch plus one per extra token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        self.distilled = distilled
        if self.distilled:
            self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
            self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 2, d_model))
            self.head_dist = nn.Linear(d_model, n_cls)
        else:
            self.pos_embed = nn.Parameter(torch.randn(1, self.patch_embed.num_patches + 1, d_model))

        # transformer blocks; the 5th Block argument grows linearly with depth
        # (presumably a stochastic-depth / drop-path rate).
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff, dropout, dpr[i]) for i in range(n_layers)])

        # output head
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_cls)

        trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        if self.distilled:
            trunc_normal_(self.dist_token, std=0.02)
        self.pre_logits = nn.Identity()

        self.apply(init_weights)

    def no_weight_decay(self):
        """Return the names of parameters that should be excluded from weight decay."""
        return {"pos_embed", "cls_token", "dist_token"}

    def load_pretrained(self, checkpoint_path, prefix=""):
        """Load weights from ``checkpoint_path`` using timm's ``_load_weights``."""
        _load_weights(self, checkpoint_path, prefix)

    def _embed_tokens(self, im):
        """Patch-embed ``im``, prepend the extra tokens and add positional embeddings.

        Shared by ``forward`` and ``get_attention_map``. Dropout is NOT applied here.

        Args:
            im: Image batch of shape ``(B, C, H, W)``.

        Returns:
            Tuple ``(x, num_extra_tokens)``: ``x`` has shape
            ``(B, num_extra_tokens + (H // patch_size) * (W // patch_size), d_model)``
            and ``num_extra_tokens`` is 2 when distilled, otherwise 1.
        """
        B, _, H, W = im.shape
        ps = self.patch_size
        x = self.patch_embed(im)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.distilled:
            dist_tokens = self.dist_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, dist_tokens, x), dim=1)
        else:
            x = torch.cat((cls_tokens, x), dim=1)

        num_extra_tokens = 1 + int(self.distilled)
        pos_embed = self.pos_embed
        if x.shape[1] != pos_embed.shape[1]:
            # Input resolution differs from the configured one: adapt the table.
            pos_embed = resize_pos_embed(
                pos_embed,
                self.patch_embed.grid_size,
                (H // ps, W // ps),
                num_extra_tokens,
            )
        return x + pos_embed, num_extra_tokens

    def forward(self, im, head_out_idx: List[int], n_dim_output=3, return_features=False):
        """Run the encoder and return logits or selected intermediate block outputs.

        Args:
            im: Image batch of shape ``(B, C, H, W)``.
            head_out_idx: Indices of blocks whose outputs are collected. Result
                slots are filled in increasing layer order (not in list order);
                an index that matches no block leaves its slot filled with zeros.
            n_dim_output: 3 to collect full token sequences
                ``(B, num_tokens, d_model)`` including the extra tokens; 4 to
                collect only the patch tokens folded into
                ``(B, d_model, H // patch_size, W // patch_size)`` maps.
            return_features: If True, return the collected block outputs
                instead of logits.

        Returns:
            If ``return_features``: tensor of shape ``(len(head_out_idx), B, ...)``
            as described for ``n_dim_output``. Otherwise logits of shape
            ``(B, n_cls)``; when distilled, the mean of both heads' logits.
        """
        assert n_dim_output == 3 or n_dim_output == 4, "n_dim_output must be 3 or 4"
        B, _, H, W = im.shape
        grid_h, grid_w = H // self.patch_size, W // self.patch_size

        x, num_extra_tokens = self._embed_tokens(im)
        x = self.dropout(x)

        # Size the buffer from the real token count / grid. The previous code
        # assumed square inputs (H used twice), a single extra token and a
        # fixed 24x24 grid.
        if n_dim_output == 3:
            out_shape = (len(head_out_idx), B, x.shape[1], self.d_model)
        else:
            out_shape = (len(head_out_idx), B, self.d_model, grid_h, grid_w)
        heads_out = torch.zeros(out_shape, device=x.device, dtype=x.dtype)
        # Still exposed as ``self.heads_out`` for existing callers. It is
        # non-persistent so this per-call activation does not end up in
        # state_dict() and break strict checkpoint loading.
        self.register_buffer("heads_out", heads_out, persistent=False)

        wanted = set(head_out_idx)
        slot = 0
        for idx_layer, blk in enumerate(self.blocks):
            x = blk(x)
            if idx_layer in wanted:
                if n_dim_output == 3:
                    heads_out[slot] = x
                else:
                    # Drop class/distillation tokens, then fold patch tokens
                    # back into a channel-first spatial map.
                    patch_tokens = x[:, num_extra_tokens:, :]
                    heads_out[slot] = patch_tokens.reshape(-1, grid_h, grid_w, self.d_model).permute(0, 3, 1, 2)
                slot += 1

        if return_features:
            # The final norm and heads only feed the logits path.
            return heads_out

        x = self.norm(x)
        if self.distilled:
            logits = self.head(x[:, 0])
            logits_dist = self.head_dist(x[:, 1])
            return (logits + logits_dist) / 2
        return self.head(x[:, 0])

    def get_attention_map(self, im, layer_id):
        """Return the output of block ``layer_id`` called with ``return_attention=True``.

        Blocks before ``layer_id`` run normally. No dropout is applied to the
        embedded tokens here, unlike in ``forward``.

        Raises:
            ValueError: If ``layer_id`` is outside ``[0, n_layers)``.
        """
        if layer_id >= self.n_layers or layer_id < 0:
            raise ValueError(f"Provided layer_id: {layer_id} is not valid. 0 <= {layer_id} < {self.n_layers}.")
        x, _ = self._embed_tokens(im)
        for blk in self.blocks[:layer_id]:
            x = blk(x)
        return self.blocks[layer_id](x, return_attention=True)
class FeatureTransform(nn.Module):
    """Turn four encoder token sequences into feature maps with four ``DecoderLinear`` heads.

    Args:
        img_size: Passed unchanged as the second argument of every decoder call.
        d_encoder: Token dimension of the incoming features.
        nls_list: ``n_cls`` (output channel count) for decoders 0..3.
        scale_factor_list: ``scale_factor`` for decoders 0..3.
    """

    def __init__(self, img_size, d_encoder, nls_list=(128, 256, 512, 512), scale_factor_list=(8, 4, 2, 1)):
        # Tuple defaults replace the previous mutable list defaults; both are
        # only indexed, so list arguments from callers still work.
        super().__init__()
        self.img_size = img_size
        # Explicit attributes (not a ModuleList) keep the state_dict keys
        # decoder_0 .. decoder_3 unchanged.
        self.decoder_0 = DecoderLinear(n_cls=nls_list[0], d_encoder=d_encoder, scale_factor=scale_factor_list[0])
        self.decoder_1 = DecoderLinear(n_cls=nls_list[1], d_encoder=d_encoder, scale_factor=scale_factor_list[1])
        self.decoder_2 = DecoderLinear(n_cls=nls_list[2], d_encoder=d_encoder, scale_factor=scale_factor_list[2])
        self.decoder_3 = DecoderLinear(n_cls=nls_list[3], d_encoder=d_encoder, scale_factor=scale_factor_list[3])

    def forward(self, x_list):
        """Decode ``x_list[i]`` with ``decoder_i`` for i in 0..3.

        Each input is sliced as ``[:, 1:, :]`` to drop its first token. This
        assumes ``(B, 1 + N, d_encoder)`` inputs with exactly one extra (class)
        token — TODO confirm for distilled encoders.

        Returns:
            Tuple ``(feat_0, feat_1, feat_2, feat_3)``. The example shapes below
            come from the original comments (batch of 2, default settings).
        """
        feat_3 = self.decoder_3(x_list[3][:, 1:, :], self.img_size)  # e.g. (2, 512, 24, 24)
        feat_2 = self.decoder_2(x_list[2][:, 1:, :], self.img_size)  # e.g. (2, 512, 48, 48)
        feat_1 = self.decoder_1(x_list[1][:, 1:, :], self.img_size)  # e.g. (2, 256, 96, 96)
        feat_0 = self.decoder_0(x_list[0][:, 1:, :], self.img_size)  # e.g. (2, 128, 192, 192)
        return feat_0, feat_1, feat_2, feat_3