import torch
import torch.nn as nn
import os

# Route all model-weight downloads into a local ./cache directory so runs
# don't touch the user's global cache (offline/reproducibility friendly).
torch.hub.set_dir('./cache')  # torch.hub checkpoint cache
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"  # Hugging Face Hub download cache
| |
|
class HybridEmbed(nn.Module):
    """CNN Feature Map Embedding.

    Extract a feature map from a CNN backbone, flatten it spatially, and
    project it to the transformer embedding dimension with a
    ``patch_size x patch_size`` convolution (1x1 by default).
    """

    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
        """
        Args:
            backbone (nn.Module): CNN mapping (B, in_chans, H, W) to a feature
                map, or to a list/tuple whose LAST element is the feature map.
            img_size (int or 2-tuple): input image size; an int means square.
            patch_size (int or 2-tuple): patch size over the CNN feature map.
            feature_size (int, 2-tuple, or None): spatial size of the
                backbone's output feature map. If None it is inferred by
                running a dummy forward pass through the backbone.
            in_chans (int): number of input channels.
            embed_dim (int): output embedding dimension.
        """
        super().__init__()
        assert isinstance(backbone, nn.Module)
        # Accept either an int or an explicit (H, W) pair for all size args.
        # (The original code did `(img_size, img_size)` unconditionally, which
        # produced malformed shapes when a tuple was passed in.)
        img_size = self._to_2tuple(img_size)
        patch_size = self._to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            # Infer feature map size/channels by running a dummy forward pass.
            # Temporarily switch to eval mode so e.g. BatchNorm with batch
            # size 1 doesn't error, then restore the original training mode.
            with torch.no_grad():
                training = backbone.training
                if training:
                    backbone.eval()
                o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature map if backbone returns several
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                backbone.train(training)
        else:
            feature_size = self._to_2tuple(feature_size)
            if hasattr(self.backbone, 'feature_info'):
                # timm-style backbone: ask its feature_info for the channels
                # of the deepest feature map.
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        # The feature map must tile evenly into patches.
        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
        self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

    @staticmethod
    def _to_2tuple(x):
        """Expand an int to (x, x); pass 2-element tuples/lists through."""
        if isinstance(x, (tuple, list)):
            assert len(x) == 2, "size arguments must be ints or 2-tuples"
            return tuple(x)
        return (x, x)

    def forward(self, x):
        """Embed images: (B, C, H, W) -> (B, num_patches, embed_dim)."""
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # use the last feature map
        # (B, D, H', W') -> (B, H'*W', embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x