Spaces:
Runtime error
Runtime error
| from models import register | |
| from einops import rearrange | |
| from mmcv.cnn import build_norm_layer | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint as cp | |
class SwinTransformerBackbone(nn.Module):
    """Single-stage Swin Transformer backbone that returns a dense feature map.

    Pipeline:
    1. Embed the input with a stride-1 convolutional ``PatchEmbed`` using
       ``'same'`` padding.
    2. Run the tokens through one ``SwinBlockSequence`` with no downsampling.
    3. Layer-normalise the stage output.
    4. Reshape it back to a channels-first feature map.

    Args:
        in_channels: Number of channels of the input tensor.
        embed_dims: Token width and number of output channels. Must be
            divisible by ``num_heads``.
        depth: Number of Swin blocks in the single stage.
        drop_path_rate: Largest drop-path rate. Per-block rates increase
            linearly from 0 to this value across the ``depth`` blocks.
        num_heads: Attention heads per Swin block.
        window_size: Attention window size of each Swin block.
        kernel_size: Kernel size of the patch-embedding convolution.
        mlp_ratio: Feed-forward hidden width as a multiple of ``embed_dims``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dims``.
    """

    def __init__(self, in_channels=3, embed_dims=256, depth=4, drop_path_rate=0.1,
                 *, num_heads=8, window_size=4, kernel_size=5, mlp_ratio=2):
        super().__init__()
        # Fail fast with a clear message instead of a reshape error deep
        # inside the attention layers at forward time.
        if num_heads <= 0 or embed_dims % num_heads != 0:
            raise ValueError(
                f"embed_dims ({embed_dims}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            embed_dims=embed_dims,
            conv_type='Conv2d',
            kernel_size=kernel_size,
            stride=1,
            padding='same',
            norm_cfg=dict(type='LN'),
        )
        # One drop-path rate per block, rising linearly from 0.
        dpr = torch.linspace(0, drop_path_rate, depth).tolist()
        self.stage = SwinBlockSequence(
            embed_dims=embed_dims,
            num_heads=num_heads,
            feedforward_channels=int(embed_dims * mlp_ratio),
            depth=depth,
            window_size=window_size,
            drop_path_rate=dpr,
            downsample=None,
        )
        # build_norm_layer returns (name, layer); keep only the layer.
        self.norm_layer = build_norm_layer(dict(type='LN'), embed_dims)[1]
        self.out_dim = embed_dims

    def forward(self, x):
        """Encode ``x`` into a channels-first feature map.

        Args:
            x: Input tensor. Presumably ``(B, in_channels, H, W)``, since the
                patch embedding is a Conv2d (TODO confirm with callers).

        Returns:
            Contiguous tensor of shape ``(B, embed_dims, H_out, W_out)``, where
            ``(H_out, W_out)`` is the spatial shape reported by the stage.
        """
        tokens, hw_shape = self.patch_embed(x)
        # Only the last two outputs of the stage are used: the stage output
        # and its (H, W) shape.
        _, _, out, out_hw_shape = self.stage(tokens, hw_shape)
        out = self.norm_layer(out)
        # Use the explicit batch size rather than -1. view(-1, ...) cannot
        # infer a dimension on a 0-element tensor, so an empty batch would
        # raise here.
        batch = out.shape[0]
        feat = out.view(batch, *out_hw_shape, self.out_dim)
        return feat.permute(0, 3, 1, 2).contiguous()