Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import einsum, nn | |
| from einops import rearrange, repeat | |
| from einops.layers.torch import Rearrange | |
| from efficientnet_pytorch import EfficientNet | |
# Default hyper-parameters read by CrossEfficientViT.__init__ (everything lives
# under the "model" key).  Keys prefixed "sm-" configure the small branch and
# keys prefixed "lg-" the large branch.
CROSS_EFFICIENT_VIT_CONFIG = {
    "model": {
        # Shared settings.
        "image-size": 224,  # forwarded to both ImageEmbedder instances as image_size
        "num-classes": 1,  # output width of the two classification heads
        "depth": 4,  # number of MultiScaleEncoder stages
        # Small branch: embedding width, patch size and per-stage Transformer.
        "sm-dim": 192,
        "sm-patch-size": 7,
        "sm-enc-depth": 2,
        "sm-enc-dim-head": 64,
        "sm-enc-heads": 8,
        "sm-enc-mlp-dim": 2048,
        # Large branch: embedding width, patch size and per-stage Transformer.
        "lg-dim": 384,
        "lg-patch-size": 56,
        "lg-enc-depth": 3,
        "lg-enc-dim-head": 64,
        "lg-enc-heads": 8,
        "lg-enc-mlp-dim": 2048,
        # CrossTransformer that exchanges information between the branches.
        "cross-attn-depth": 2,
        "cross-attn-dim-head": 64,
        "cross-attn-heads": 8,
        # Channel counts of the backbone feature maps each branch patches up.
        "lg-channels": 24,
        "sm-channels": 1280,
        # Dropout inside the encoder, and on the embedded token sequence.
        "dropout": 0.15,
        "emb-dropout": 0.15,
    },
}
def default(value, fallback):
    """Return ``value`` unless it is ``None``, in which case return ``fallback``.

    Only ``None`` triggers the fallback; other falsy values such as ``0`` or
    an empty string are returned unchanged.
    """
    if value is None:
        return fallback
    return value
class EfficientNetBackbone(EfficientNet):
    """EfficientNet subclass that can return features from part-way through the network."""

    def delete_blocks(self, limit):
        """Do nothing: ``limit`` is accepted and ignored, so no block is removed."""
        return None

    def extract_features_at_block(self, inputs, selected_block):
        """Run the stem, then the blocks in order, stopping early if requested.

        Args:
            inputs: Batch handed to the stem convolution.
            selected_block: Block index that controls where the loop stops.
                When it is greater than or equal to the number of blocks,
                every block runs and the head (conv, batch-norm, swish) is
                applied as well.

        Returns:
            The feature map produced by the last layer that was executed.
        """
        features = self._swish(self._bn0(self._conv_stem(inputs)))

        block_count = len(self._blocks)
        base_rate = self._global_params.drop_connect_rate
        for position, stage in enumerate(self._blocks):
            # Drop-connect rate is scaled by the block's relative position;
            # a falsy rate (None or 0) is passed through untouched.
            rate = base_rate * (float(position) / block_count) if base_rate else base_rate
            features = stage(features, drop_connect_rate=rate)
            # NOTE(review): the check happens after the block has run and uses
            # ">", so the loop stops after block ``selected_block + 1`` (or the
            # last block if none follows), not after ``selected_block`` itself.
            # Confirm this is intended before changing it; trained weights may
            # depend on it.
            if position > selected_block:
                break

        if selected_block < block_count:
            return features
        return self._swish(self._bn1(self._conv_head(features)))
class PreNorm(nn.Module):
    """Layer-normalise the input and then hand it to a wrapped callable.

    Args:
        dim: Size of the last dimension, passed to ``nn.LayerNorm``.
        fn: Callable (typically a module) invoked on the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``; keyword arguments are not normalised."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Feature size between the two linear layers.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        # Built in this order so the Sequential indices (0 and 3 for the
        # linear layers) and the weight-initialisation order stay as they were.
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product attention with an optional separate context.

    Args:
        dim: Feature size of the input tokens (and of the output).
        heads: Number of attention heads.
        dim_head: Feature size per head; the projections are ``heads * dim_head`` wide.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.scale = dim_head**-0.5
        self.attend = nn.Softmax(dim=-1)

        inner_dim = heads * dim_head
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Keys and values come out of one projection and are split in forward().
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, context=None, kv_include_self=False):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: Query tokens; must have exactly three dimensions.
            context: Tokens that supply keys and values. ``None`` means self-attention.
            kv_include_self: When true, ``x`` is concatenated in front of the
                context along dimension 1, so the queries can also attend to
                themselves.

        Returns:
            Tensor with the same token count as ``x`` and feature size ``dim``.
        """
        # Unpacking raises ValueError unless x has exactly three dimensions.
        _batch, _tokens, _width = x.shape
        heads = self.heads

        source = x if context is None else context
        if kv_include_self:
            source = torch.cat((x, source), dim=1)

        queries = self.to_q(x)
        keys, values = self.to_kv(source).chunk(2, dim=-1)
        queries, keys, values = (
            rearrange(t, "b n (h d) -> b h n d", h=heads)
            for t in (queries, keys, values)
        )

        scores = einsum("b h i d, b h j d -> b h i j", queries, keys) * self.scale
        weights = self.attend(scores)

        mixed = einsum("b h i j, b h j d -> b h i d", weights, values)
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.to_out(merged)
class Transformer(nn.Module):
    """Stack of pre-norm attention / feed-forward layers followed by a final LayerNorm.

    Args:
        dim: Token feature size.
        depth: Number of (attention, feed-forward) layer pairs.
        heads: Attention heads per layer.
        dim_head: Feature size per attention head.
        mlp_dim: Hidden size of each feed-forward block.
        dropout: Dropout probability passed to attention and feed-forward blocks.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
        super().__init__()

        def build_layer():
            # One residual layer: pre-normed attention, then pre-normed MLP.
            attention = PreNorm(
                dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
            )
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            return nn.ModuleList([attention, feed_forward])

        # "layers" is registered before "norm" to keep the state-dict order.
        self.layers = nn.ModuleList([build_layer() for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Apply every layer with residual connections, then the final norm."""
        for attention, feed_forward in self.layers:
            residual = x
            x = attention(residual) + residual
            residual = x
            x = feed_forward(residual) + residual
        return self.norm(x)
class ProjectInOut(nn.Module):
    """Wrap a callable so it runs in ``dim_out`` space while callers stay in ``dim_in`` space.

    When the two sizes match, both projections are ``nn.Identity``; otherwise
    the input is linearly projected to ``dim_out`` before ``fn`` and the
    result is projected back to ``dim_in`` afterwards.

    Args:
        dim_in: Feature size of the caller's tensors.
        dim_out: Feature size ``fn`` operates on.
        fn: Callable applied between the two projections.
    """

    def __init__(self, dim_in, dim_out, fn):
        super().__init__()
        self.fn = fn

        if dim_in == dim_out:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()
        else:
            self.project_in = nn.Linear(dim_in, dim_out)
            self.project_out = nn.Linear(dim_out, dim_in)

    def forward(self, x, *args, **kwargs):
        """Project ``x`` in, call ``fn``, project the result back out.

        Extra positional and keyword arguments go to ``fn`` unchanged; they
        are not projected.
        """
        projected = self.project_in(x)
        transformed = self.fn(projected, *args, **kwargs)
        return self.project_out(transformed)
class CrossTransformer(nn.Module):
    """Let each branch's first token attend to the other branch's remaining tokens.

    Args:
        sm_dim: Token feature size of the small branch.
        lg_dim: Token feature size of the large branch.
        depth: Number of cross-attention rounds.
        heads: Attention heads per cross-attention block.
        dim_head: Feature size per attention head.
        dropout: Dropout probability passed to the attention blocks.
    """

    def __init__(self, sm_dim, lg_dim, depth, heads, dim_head, dropout):
        super().__init__()

        def cls_attention(query_dim, context_dim):
            # Attention runs at the context's width; ProjectInOut maps the
            # query token into that width and back again.
            attention = Attention(
                context_dim, heads=heads, dim_head=dim_head, dropout=dropout
            )
            return ProjectInOut(query_dim, context_dim, PreNorm(context_dim, attention))

        rounds = []
        for _ in range(depth):
            sm_to_lg = cls_attention(sm_dim, lg_dim)
            lg_to_sm = cls_attention(lg_dim, sm_dim)
            rounds.append(nn.ModuleList([sm_to_lg, lg_to_sm]))
        self.layers = nn.ModuleList(rounds)

    def forward(self, sm_tokens, lg_tokens):
        """Update the first token of each sequence; the other tokens pass through unchanged.

        Args:
            sm_tokens: Small-branch tokens; index 0 along dimension 1 is the
                token that gets updated.
            lg_tokens: Large-branch tokens, laid out the same way.

        Returns:
            Tuple ``(sm_tokens, lg_tokens)`` with refreshed first tokens.
        """
        sm_cls, sm_patch_tokens = sm_tokens[:, :1], sm_tokens[:, 1:]
        lg_cls, lg_patch_tokens = lg_tokens[:, :1], lg_tokens[:, 1:]

        for sm_attend_lg, lg_attend_sm in self.layers:
            sm_update = sm_attend_lg(
                sm_cls, context=lg_patch_tokens, kv_include_self=True
            )
            sm_cls = sm_update + sm_cls
            lg_update = lg_attend_sm(
                lg_cls, context=sm_patch_tokens, kv_include_self=True
            )
            lg_cls = lg_update + lg_cls

        sm_out = torch.cat((sm_cls, sm_patch_tokens), dim=1)
        lg_out = torch.cat((lg_cls, lg_patch_tokens), dim=1)
        return sm_out, lg_out
class MultiScaleEncoder(nn.Module):
    """Repeated stages of per-branch Transformers followed by a CrossTransformer.

    Args:
        depth: Number of stages.
        sm_dim: Token feature size of the small branch.
        lg_dim: Token feature size of the large branch.
        sm_enc_params: Extra keyword arguments for the small-branch ``Transformer``.
        lg_enc_params: Extra keyword arguments for the large-branch ``Transformer``.
        cross_attn_heads: Attention heads in each ``CrossTransformer``.
        cross_attn_depth: Rounds inside each ``CrossTransformer``.
        cross_attn_dim_head: Feature size per cross-attention head.
        dropout: Dropout probability shared by all sub-modules.
    """

    def __init__(
        self,
        *,
        depth,
        sm_dim,
        lg_dim,
        sm_enc_params,
        lg_enc_params,
        cross_attn_heads,
        cross_attn_depth,
        cross_attn_dim_head=64,
        dropout=0.0
    ):
        super().__init__()
        stages = []
        for _ in range(depth):
            sm_encoder = Transformer(dim=sm_dim, dropout=dropout, **sm_enc_params)
            lg_encoder = Transformer(dim=lg_dim, dropout=dropout, **lg_enc_params)
            cross = CrossTransformer(
                sm_dim=sm_dim,
                lg_dim=lg_dim,
                depth=cross_attn_depth,
                heads=cross_attn_heads,
                dim_head=cross_attn_dim_head,
                dropout=dropout,
            )
            stages.append(nn.ModuleList([sm_encoder, lg_encoder, cross]))
        self.layers = nn.ModuleList(stages)

    def forward(self, sm_tokens, lg_tokens):
        """Encode each branch on its own, then cross-attend, once per stage."""
        for sm_enc, lg_enc, cross_attend in self.layers:
            sm_tokens = sm_enc(sm_tokens)
            lg_tokens = lg_enc(lg_tokens)
            sm_tokens, lg_tokens = cross_attend(sm_tokens, lg_tokens)
        return sm_tokens, lg_tokens
class ImageEmbedder(nn.Module):
    """Turn an image into a token sequence using EfficientNet features.

    The image goes through an ``efficientnet-b0`` backbone, the resulting
    feature map is cut into ``patch_size`` x ``patch_size`` patches, each patch
    is linearly projected to ``dim``, a learned extra token is prepended and
    learned position embeddings are added.

    Args:
        dim: Token feature size.
        image_size: Used only to size the position-embedding table.
        patch_size: Side length of the patches cut from the feature map.
        dropout: Dropout probability applied to the final token sequence.
        efficient_block: Block index passed to the backbone's
            ``extract_features_at_block``.
        channels: Channel count of the backbone feature map.
    """

    def __init__(
        self,
        *,
        dim,
        image_size,
        patch_size,
        dropout=0.0,
        efficient_block=8,
        channels
    ):
        super().__init__()
        backbone = EfficientNetBackbone.from_name("efficientnet-b0")
        backbone.delete_blocks(efficient_block)
        self.efficient_net = backbone
        self.efficient_block = efficient_block

        # The table is sized from the image resolution, not from the
        # feature-map resolution, so forward() slices it to the number of
        # tokens actually produced.
        patches_per_side = image_size // patch_size
        num_patches = patches_per_side**2
        patch_dim = channels * patch_size**2

        split_into_patches = Rearrange(
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
            p1=patch_size,
            p2=patch_size,
        )
        self.to_patch_embedding = nn.Sequential(
            split_into_patches,
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)

    def forward(self, img):
        """Return the embedded token sequence for ``img``, extra token first."""
        feature_map = self.efficient_net.extract_features_at_block(
            img, self.efficient_block
        )
        tokens = self.to_patch_embedding(feature_map)
        batch_size, num_tokens, _ = tokens.shape

        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=batch_size)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        # Only the first num_tokens + 1 position embeddings are used.
        tokens += self.pos_embedding[:, : num_tokens + 1]
        return self.dropout(tokens)
class CrossEfficientViT(nn.Module):
    """Two-branch vision transformer fed by EfficientNet features.

    A small and a large ``ImageEmbedder`` tokenise the same image, a
    ``MultiScaleEncoder`` processes and mixes the two sequences, and the first
    token of each branch goes through its own LayerNorm + Linear head. The two
    head outputs are summed.

    Args:
        config: Mapping with a ``"model"`` entry holding the hyper-parameters;
            defaults to ``CROSS_EFFICIENT_VIT_CONFIG``.
    """

    def __init__(self, *, config=CROSS_EFFICIENT_VIT_CONFIG):
        super().__init__()
        settings = config["model"]

        image_size = settings["image-size"]
        num_classes = settings["num-classes"]
        sm_dim = settings["sm-dim"]
        sm_channels = settings["sm-channels"]
        lg_dim = settings["lg-dim"]
        lg_channels = settings["lg-channels"]
        sm_patch_size = settings["sm-patch-size"]
        lg_patch_size = settings["lg-patch-size"]

        # The small branch asks the backbone for block 16, the large branch
        # for block 1.
        self.sm_image_embedder = ImageEmbedder(
            dim=sm_dim,
            image_size=image_size,
            patch_size=sm_patch_size,
            dropout=settings["emb-dropout"],
            efficient_block=16,
            channels=sm_channels,
        )
        self.lg_image_embedder = ImageEmbedder(
            dim=lg_dim,
            image_size=image_size,
            patch_size=lg_patch_size,
            dropout=settings["emb-dropout"],
            efficient_block=1,
            channels=lg_channels,
        )

        encoder_kwargs = {
            "depth": settings["depth"],
            "sm_dim": sm_dim,
            "lg_dim": lg_dim,
            "cross_attn_heads": settings["cross-attn-heads"],
            "cross_attn_dim_head": settings["cross-attn-dim-head"],
            "cross_attn_depth": settings["cross-attn-depth"],
            "sm_enc_params": {
                "depth": settings["sm-enc-depth"],
                "heads": settings["sm-enc-heads"],
                "mlp_dim": settings["sm-enc-mlp-dim"],
                "dim_head": settings["sm-enc-dim-head"],
            },
            "lg_enc_params": {
                "depth": settings["lg-enc-depth"],
                "heads": settings["lg-enc-heads"],
                "mlp_dim": settings["lg-enc-mlp-dim"],
                "dim_head": settings["lg-enc-dim-head"],
            },
            "dropout": settings["dropout"],
        }
        self.multi_scale_encoder = MultiScaleEncoder(**encoder_kwargs)

        self.sm_mlp_head = nn.Sequential(
            nn.LayerNorm(sm_dim), nn.Linear(sm_dim, num_classes)
        )
        self.lg_mlp_head = nn.Sequential(
            nn.LayerNorm(lg_dim), nn.Linear(lg_dim, num_classes)
        )

    def forward(self, img):
        """Return the sum of the small-branch and large-branch head outputs for ``img``."""
        sm_tokens, lg_tokens = self.multi_scale_encoder(
            self.sm_image_embedder(img),
            self.lg_image_embedder(img),
        )
        # Each head only sees the first token of its branch.
        sm_cls = sm_tokens[:, 0]
        lg_cls = lg_tokens[:, 0]
        sm_logits = self.sm_mlp_head(sm_cls)
        lg_logits = self.lg_mlp_head(lg_cls)
        return sm_logits + lg_logits