Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| from einops import rearrange | |
| from src.models.vit.utils import init_weights | |
class DecoderLinear(nn.Module):
    """Linear decoder head: projects encoder tokens to per-class logits,
    then upsamples the token sequence into a 2-D class-score map.

    Args:
        n_cls: number of output classes (channel dim of the decoded map).
        d_encoder: feature width of the incoming encoder tokens.
        scale_factor: spatial upscaling factor for the token grid.
        dropout_rate: dropout probability applied after the LayerNorm.
    """

    def __init__(
        self,
        n_cls,
        d_encoder,
        scale_factor,
        dropout_rate=0.3,
    ):
        super().__init__()
        self.scale_factor = scale_factor
        # Token-wise projection from encoder width down to class logits.
        self.head = nn.Linear(d_encoder, n_cls)
        # 1-D linear interpolation along the flattened (h*w) token axis;
        # scale_factor**2 grows the token count to match the upscaled grid.
        self.upsampling = nn.Upsample(scale_factor=scale_factor**2, mode="linear")
        # NOTE(review): normalized shape hard-codes a 24x24 base grid
        # (i.e. 384-px input with patch size 16) — confirm with callers.
        self.norm = nn.LayerNorm((n_cls, 24 * scale_factor, 24 * scale_factor))
        self.dropout = nn.Dropout(dropout_rate)
        self.gelu = nn.GELU()
        self.apply(init_weights)

    def forward(self, x, img_size):
        """Decode token features into a (B, n_cls, H', W') map.

        Args:
            x: token tensor of shape (B, N, d_encoder) — assumed; confirm
                against the encoder producing it.
            img_size: (height, width) of the input image; only height is used.
        """
        H, _ = img_size
        logits = self.head(x)            # (B, N, n_cls)
        logits = logits.transpose(2, 1)  # (B, n_cls, N)
        logits = self.upsampling(logits)  # (B, n_cls, N * sf**2)
        logits = logits.transpose(2, 1)  # (B, N * sf**2, n_cls)
        # Recover the 2-D token grid; grid height derives from the patch
        # size (16) shrunk by the upscaling factor.
        grid_h = H // (16 // self.scale_factor)
        logits = rearrange(logits, "b (h w) c -> b c h w", h=grid_h)
        logits = self.norm(logits)
        logits = self.dropout(logits)
        return self.gelu(logits)