Spaces:
Running
Running
| """ | |
| This is an implementation of efficientvit, with some modifications (decode head, etc). | |
| Original paper at https://arxiv.org/abs/2205.14756 | |
| Code adapted from timm, https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientvit_mit.py | |
| Original code (that timm adapted from) at https://github.com/mit-han-lab/efficientvit | |
| """ | |
| from typing import Optional, Union, Tuple | |
| from functools import partial | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import SemanticSegmenterOutput | |
| from surya.model.detection.config import EfficientViTConfig | |
| from surya.model.detection.processor import SegformerImageProcessor | |
| from surya.settings import settings | |
def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_MODEL, dtype=settings.MODEL_DTYPE):
    """Load the pretrained detection model and put it in inference mode.

    Args:
        checkpoint: Identifier handed to ``from_pretrained`` for both the
            config and the weights.
        device: Target passed to ``model.to``.
        dtype: Forwarded as ``torch_dtype`` when the weights are loaded.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    detector = EfficientViTForSemanticSegmentation.from_pretrained(
        checkpoint,
        torch_dtype=dtype,
        config=EfficientViTConfig.from_pretrained(checkpoint),
        ignore_mismatched_sizes=True,
    )
    detector = detector.to(device).eval()
    print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
    return detector
def load_processor(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT):
    """Return the image processor stored alongside ``checkpoint``."""
    return SegformerImageProcessor.from_pretrained(checkpoint)
def val2list(x: list or tuple or any, repeat_time=1):
    """Coerce ``x`` to a list.

    A list or tuple is copied into a new list (``repeat_time`` is ignored);
    any other value is repeated ``repeat_time`` times.
    """
    if not isinstance(x, (list, tuple)):
        return [x] * repeat_time
    return list(x)
def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1):
    """Coerce ``x`` to a tuple holding at least ``min_len`` items.

    When the sequence is non-empty but too short, copies of the element at
    ``idx_repeat`` are inserted at that position until the length reaches
    ``min_len``. An empty sequence is returned as an empty tuple.
    """
    items = val2list(x)
    if len(items) == 0:
        return tuple(items)
    shortfall = min_len - len(items)
    # Slice assignment inserts the copies in front of position ``idx_repeat``.
    items[idx_repeat:idx_repeat] = [items[idx_repeat] for _ in range(shortfall)]
    return tuple(items)
def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
    """Return ``kernel_size // 2`` for an odd kernel size.

    A tuple of kernel sizes is handled element-wise and yields a tuple.
    An even kernel size trips the assertion.
    """
    if not isinstance(kernel_size, tuple):
        assert kernel_size % 2 > 0, "kernel size should be odd number"
        return kernel_size // 2
    return tuple(get_same_padding(ks) for ks in kernel_size)
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1) -> int:
    """Compute conv padding as ``((stride - 1) + dilation * (kernel_size - 1)) // 2``."""
    dilated_extent = dilation * (kernel_size - 1)
    return (stride - 1 + dilated_extent) // 2
class ConvNormAct(nn.Module):
    """Dropout -> Conv2d -> optional normalization -> optional activation.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
        groups: Convolution groups.
        bias: Whether the convolution has a bias term.
        dropout: Dropout probability applied to the block input. The default
            of ``0.`` leaves the input untouched.
        norm_layer: Callable building the norm from ``num_features``; a falsy
            value replaces the norm with ``nn.Identity``.
        act_layer: Callable building the activation (called with
            ``inplace=True``); ``None`` replaces it with ``nn.Identity``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size=3,
            stride=1,
            dilation=1,
            groups=1,
            bias=False,
            dropout=0.,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.ReLU,
    ):
        super(ConvNormAct, self).__init__()
        self.dropout = nn.Dropout(dropout, inplace=False)
        padding = get_padding(kernel_size, stride, dilation)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=padding,
        )
        self.norm = norm_layer(num_features=out_channels) if norm_layer else nn.Identity()
        self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()

    def forward(self, x):
        # The dropout module used to be constructed but never called, so the
        # ``dropout`` argument had no effect. With p=0 or in eval mode this
        # call is the identity, so existing behaviour is preserved.
        x = self.dropout(x)
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x
class DSConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    ``use_bias``, ``norm_layer`` and ``act_layer`` may each be a single value
    or a pair; they are expanded with ``val2tuple`` so that index 0 configures
    the depthwise stage and index 1 the pointwise stage.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size=3,
            stride=1,
            use_bias=False,
            norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
            act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)

        # groups == in_channels makes this a per-channel (depthwise) conv.
        self.depth_conv = ConvNormAct(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=in_channels,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.point_conv = ConvNormAct(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        return self.point_conv(self.depth_conv(x))
class ConvBlock(nn.Module):
    """Two stacked ``ConvNormAct`` layers sharing one kernel size.

    The first layer applies ``stride`` and maps to ``mid_channels`` (which
    defaults to ``round(in_channels * expand_ratio)``); the second layer has
    stride 1 and maps to ``out_channels``. ``use_bias``, ``norm_layer`` and
    ``act_layer`` are expanded to pairs, one entry per layer.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size=3,
            stride=1,
            mid_channels=None,
            expand_ratio=1,
            use_bias=False,
            norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
            act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden = mid_channels or round(in_channels * expand_ratio)

        self.conv1 = ConvNormAct(
            in_channels,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.conv2 = ConvNormAct(
            hidden,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
class MBConv(nn.Module):
    """1x1 expansion conv -> depthwise conv -> 1x1 projection conv.

    ``mid_channels`` defaults to ``round(in_channels * expand_ratio)``.
    ``use_bias``, ``norm_layer`` and ``act_layer`` are expanded to triples,
    one entry per stage in the order listed above. Only the depthwise stage
    applies ``stride``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size=3,
            stride=1,
            mid_channels=None,
            expand_ratio=6,
            use_bias=False,
            norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d, nn.BatchNorm2d),
            act_layer=(nn.ReLU6, nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 3)
        norms = val2tuple(norm_layer, 3)
        acts = val2tuple(act_layer, 3)
        hidden = mid_channels or round(in_channels * expand_ratio)

        self.inverted_conv = ConvNormAct(
            in_channels,
            hidden,
            kernel_size=1,
            stride=1,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        # groups == hidden makes this a per-channel (depthwise) conv.
        self.depth_conv = ConvNormAct(
            hidden,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            groups=hidden,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )
        self.point_conv = ConvNormAct(
            hidden,
            out_channels,
            kernel_size=1,
            bias=biases[2],
            norm_layer=norms[2],
            act_layer=acts[2],
        )

    def forward(self, x):
        expanded = self.inverted_conv(x)
        filtered = self.depth_conv(expanded)
        return self.point_conv(filtered)
class FusedMBConv(nn.Module):
    """Spatial conv (expanding to ``mid_channels``) followed by a 1x1 conv.

    ``mid_channels`` defaults to ``round(in_channels * expand_ratio)``.
    ``stride`` and ``groups`` apply to the spatial conv only. ``use_bias``,
    ``norm_layer`` and ``act_layer`` are expanded to pairs, one entry per conv.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size=3,
            stride=1,
            mid_channels=None,
            expand_ratio=6,
            groups=1,
            use_bias=False,
            norm_layer=(nn.BatchNorm2d, nn.BatchNorm2d),
            act_layer=(nn.ReLU6, None),
    ):
        super().__init__()
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)
        hidden = mid_channels or round(in_channels * expand_ratio)

        self.spatial_conv = ConvNormAct(
            in_channels,
            hidden,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )
        self.point_conv = ConvNormAct(
            hidden,
            out_channels,
            kernel_size=1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def forward(self, x):
        return self.point_conv(self.spatial_conv(x))
class LiteMLA(nn.Module):
    """Lightweight multi-scale linear attention.

    A 1x1 conv produces q/k/v; one extra q/k/v set is derived per entry in
    ``scales`` by a depthwise conv of that kernel size followed by a grouped
    1x1 conv. Attention is computed without a softmax: ``kernel_func`` is
    applied to q and k, and normalization comes from an extra all-ones column
    appended to v. The result is projected to ``out_channels``.

    ``heads`` defaults to ``int(in_channels // dim * heads_ratio)``.
    ``use_bias``, ``norm_layer`` and ``act_layer`` are expanded to pairs:
    index 0 for the qkv stage (its bias flag is also used by the aggregation
    convs), index 1 for the output projection.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            heads: int or None = None,
            heads_ratio: float = 1.0,
            dim=8,
            use_bias=False,
            norm_layer=(None, nn.BatchNorm2d),
            act_layer=(None, None),
            kernel_func=nn.ReLU,
            scales=(5,),
            eps=1e-5,
    ):
        super().__init__()
        self.eps = eps
        heads = heads or int(in_channels // dim * heads_ratio)
        total_dim = heads * dim
        qkv_channels = 3 * total_dim
        biases = val2tuple(use_bias, 2)
        norms = val2tuple(norm_layer, 2)
        acts = val2tuple(act_layer, 2)

        self.dim = dim
        self.qkv = ConvNormAct(
            in_channels,
            qkv_channels,
            1,
            bias=biases[0],
            norm_layer=norms[0],
            act_layer=acts[0],
        )

        branches = []
        for scale in scales:
            depthwise = nn.Conv2d(
                qkv_channels,
                qkv_channels,
                scale,
                padding=get_same_padding(scale),
                groups=qkv_channels,
                bias=biases[0],
            )
            grouped_pointwise = nn.Conv2d(
                qkv_channels, qkv_channels, 1, groups=3 * heads, bias=biases[0]
            )
            branches.append(nn.Sequential(depthwise, grouped_pointwise))
        self.aggreg = nn.ModuleList(branches)

        self.kernel_func = kernel_func(inplace=False)
        self.proj = ConvNormAct(
            total_dim * (1 + len(scales)),
            out_channels,
            1,
            bias=biases[1],
            norm_layer=norms[1],
            act_layer=acts[1],
        )

    def _attn(self, q, k, v):
        """Linear attention in float32; the result is cast back to ``v``'s dtype.

        The last channel of ``q @ (k^T @ v)`` is used as the normalizer for
        the remaining channels (``forward`` pads ``v`` with ones to create it).
        """
        out_dtype = v.dtype
        q32, k32, v32 = q.float(), k.float(), v.float()
        context = torch.matmul(k32.transpose(-1, -2), v32)
        attended = torch.matmul(q32, context)
        numerator = attended[..., :-1]
        denominator = attended[..., -1:] + self.eps
        return (numerator / denominator).to(out_dtype)

    def forward(self, x):
        # x is unpacked as (B, C, H, W)
        B, _, H, W = x.shape

        # Base q/k/v plus one aggregated copy per scale, stacked on channels.
        base_qkv = self.qkv(x)
        stacked = torch.cat([base_qkv] + [branch(base_qkv) for branch in self.aggreg], dim=1)

        # (B, groups, 3 * dim, H*W) -> (B, groups, H*W, 3 * dim)
        tokens = stacked.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
        # Each of q, k, v is (B, groups, H*W, dim)
        q, k, v = tokens.chunk(3, dim=-1)

        # lightweight global attention
        q = self.kernel_func(q)
        k = self.kernel_func(k)
        # Extra all-ones column on v yields the normalizer inside _attn.
        v = F.pad(v, (0, 1), mode="constant", value=1.)
        attended = self._attn(q, k, v)

        # Back to (B, channels, H, W), then final projection.
        attended = attended.transpose(-1, -2).reshape(B, -1, H, W)
        return self.proj(attended)
class EfficientVitBlock(nn.Module):
    """Residual ``LiteMLA`` attention followed by a residual ``MBConv``.

    Both sub-modules keep the channel count at ``in_channels`` and use an
    identity shortcut. In the ``MBConv`` only the last stage is normalized
    and only the first two stages have bias and activation.
    """

    def __init__(
            self,
            in_channels,
            heads_ratio=1.0,
            head_dim=32,
            expand_ratio=4,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.Hardswish,
    ):
        super().__init__()

        attention = LiteMLA(
            in_channels=in_channels,
            out_channels=in_channels,
            heads_ratio=heads_ratio,
            dim=head_dim,
            norm_layer=(None, norm_layer),
        )
        self.context_module = ResidualBlock(attention, nn.Identity())

        feed_forward = MBConv(
            in_channels=in_channels,
            out_channels=in_channels,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False),
            norm_layer=(None, None, norm_layer),
            act_layer=(act_layer, act_layer, None),
        )
        self.local_module = ResidualBlock(feed_forward, nn.Identity())

    def forward(self, x):
        return self.local_module(self.context_module(x))
class ResidualBlock(nn.Module):
    """Wrap ``main`` with an optional pre-norm and an optional shortcut.

    Args:
        main: Module applied to the (pre-normed) input. ``None`` turns the
            block into a pass-through that returns its input unchanged.
        shortcut: Module applied to the raw input and added to the output of
            ``main``. ``None`` disables the residual connection.
        pre_norm: Module applied to the input before ``main``; defaults to
            ``nn.Identity``.
    """

    def __init__(
            self,
            main: Optional[nn.Module],
            shortcut: Optional[nn.Module] = None,
            pre_norm: Optional[nn.Module] = None,
    ):
        super(ResidualBlock, self).__init__()
        self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
        self.main = main
        self.shortcut = shortcut

    def forward(self, x):
        # ``main`` is declared Optional; without this guard a ``None`` main
        # would fail with "'NoneType' object is not callable".
        if self.main is None:
            return x
        res = self.main(self.pre_norm(x))
        if self.shortcut is not None:
            res = res + self.shortcut(x)
        return res
def build_local_block(
        in_channels: int,
        out_channels: int,
        stride: int,
        kernel_size: int,
        expand_ratio: float,
        norm_layer: str,
        act_layer: str,
        fewer_norm: bool = False,
        block_type: str = "default",
):
    """Build the convolutional block selected by ``expand_ratio`` and ``block_type``.

    Selection:
        * ``expand_ratio == 1``: ``DSConv`` for ``"default"``, else ``ConvBlock``.
        * otherwise: ``MBConv`` for ``"default"``, else ``FusedMBConv``.

    With ``fewer_norm`` set, only the last stage of the block is normalized
    and the earlier stages get a bias instead. The last stage never has an
    activation.

    ``norm_layer`` and ``act_layer`` are forwarded unchanged to the block
    constructors; despite the ``str`` annotations, the callers in this file
    pass layer constructors.
    """
    assert block_type in ["default", "large", "fused"]

    is_default = block_type == "default"

    # Arguments shared by the three two-stage blocks.
    two_stage_kwargs = dict(
        in_channels=in_channels,
        out_channels=out_channels,
        stride=stride,
        kernel_size=kernel_size,
        use_bias=(True, False) if fewer_norm else False,
        norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
        act_layer=(act_layer, None),
    )

    if expand_ratio == 1:
        block_cls = DSConv if is_default else ConvBlock
        return block_cls(**two_stage_kwargs)

    if is_default:
        return MBConv(
            in_channels=in_channels,
            out_channels=out_channels,
            stride=stride,
            kernel_size=kernel_size,
            expand_ratio=expand_ratio,
            use_bias=(True, True, False) if fewer_norm else False,
            norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
            act_layer=(act_layer, act_layer, None),
        )

    return FusedMBConv(expand_ratio=expand_ratio, **two_stage_kwargs)
class Stem(nn.Sequential):
    """Input stem: a strided ``ConvNormAct`` followed by ``depth`` residual blocks.

    The input conv uses kernel size ``stride + 1``. The residual blocks are
    registered as ``res0``, ``res1``, ... and keep ``out_chs`` channels with
    stride 1, kernel size 3 and expand ratio 1.
    """

    def __init__(self, in_chs, out_chs, depth, stride, norm_layer, act_layer, block_type='default'):
        super().__init__()
        self.stride = stride

        in_conv = ConvNormAct(
            in_chs,
            out_chs,
            kernel_size=stride + 1,
            stride=stride,
            norm_layer=norm_layer,
            act_layer=act_layer,
        )
        self.add_module('in_conv', in_conv)

        for idx in range(depth):
            local_block = build_local_block(
                in_channels=out_chs,
                out_channels=out_chs,
                stride=1,
                kernel_size=3,
                expand_ratio=1,
                norm_layer=norm_layer,
                act_layer=act_layer,
                block_type=block_type,
            )
            self.add_module(f'res{idx}', ResidualBlock(local_block, nn.Identity()))
class EfficientVitLargeStage(nn.Module):
    """One backbone stage: a strided entry block followed by ``depth`` blocks.

    The entry block changes ``in_chs`` to ``out_chs`` with the given
    ``stride`` and has no shortcut. The following blocks are
    ``EfficientVitBlock`` modules when ``vit_stage`` is set, otherwise
    residual local conv blocks with an identity shortcut.
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            depth,
            stride,
            norm_layer,
            act_layer,
            head_dim,
            vit_stage=False,
            fewer_norm=False,
    ):
        super().__init__()
        local_block_type = 'default' if fewer_norm else 'fused'

        entry = build_local_block(
            in_channels=in_chs,
            out_channels=out_chs,
            stride=stride,
            kernel_size=stride + 1,
            expand_ratio=24 if vit_stage else 16,
            norm_layer=norm_layer,
            act_layer=act_layer,
            fewer_norm=vit_stage or fewer_norm,
            block_type=local_block_type,
        )
        # No shortcut on the entry block.
        layers = [ResidualBlock(entry, None)]

        if vit_stage:
            # for stage 4
            layers.extend(
                EfficientVitBlock(
                    in_channels=out_chs,
                    head_dim=head_dim,
                    expand_ratio=6,
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                )
                for _ in range(depth)
            )
        else:
            # for stage 1, 2, 3
            layers.extend(
                ResidualBlock(
                    build_local_block(
                        in_channels=out_chs,
                        out_channels=out_chs,
                        stride=1,
                        kernel_size=3,
                        expand_ratio=4,
                        norm_layer=norm_layer,
                        act_layer=act_layer,
                        fewer_norm=fewer_norm,
                        block_type=local_block_type,
                    ),
                    nn.Identity(),
                )
                for _ in range(depth)
            )

        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)
class EfficientVitLarge(nn.Module):
    """Backbone: a ``Stem`` followed by a sequence of ``EfficientVitLargeStage``.

    Index 0 of ``config.widths`` / ``config.depths`` / ``config.strides``
    configures the stem; the remaining entries configure one stage each.
    Stages with index >= 2 use fewer norms and stages with index >= 3 are
    attention stages. ``forward`` returns the output of every stage.
    """

    def __init__(
            self,
            config: EfficientViTConfig,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.Hardswish,
    ):
        super().__init__()
        # NOTE(review): this flag is stored but not read by forward().
        self.grad_checkpointing = False
        self.num_classes = config.num_classes
        self.norm_eps = config.layer_norm_eps
        norm_layer = partial(norm_layer, eps=self.norm_eps)

        widths, depths, strides = config.widths, config.depths, config.strides

        # input stem
        self.stem = Stem(
            config.num_channels,
            widths[0],
            depths[0],
            strides[0],
            norm_layer,
            act_layer,
            block_type='large',
        )
        total_stride = strides[0]

        # stages
        self.feature_info = []
        self.stages = nn.Sequential()
        prev_channels = widths[0]
        for i, (w, d, s) in enumerate(zip(widths[1:], depths[1:], strides[1:])):
            stage = EfficientVitLargeStage(
                prev_channels,
                w,
                depth=d,
                stride=s,
                norm_layer=norm_layer,
                act_layer=act_layer,
                head_dim=config.head_dim,
                vit_stage=i >= 3,
                fewer_norm=i >= 2,
            )
            self.stages.append(stage)
            total_stride *= s
            prev_channels = w
            self.feature_info.append(
                dict(num_chs=prev_channels, reduction=total_stride, module=f'stages.{i}')
            )

        self.num_features = prev_channels

    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x):
        """Run the stem and every stage; return the list of stage outputs."""
        x = self.stem(x)
        encoder_hidden_states = []
        for stage in self.stages:
            x = stage(x)
            encoder_hidden_states.append(x)
        return encoder_hidden_states
class EfficientViTPreTrainedModel(PreTrainedModel):
    """
    Abstract base that supplies weight initialization and the pretrained-model
    loading interface for the EfficientViT models.
    """

    config_class = EfficientViTConfig
    base_model_prefix = "efficientvit"
    main_input_name = "pixel_values"

    def _init_weights(self, module):
        """Initialize the weights of a single ``module``.

        Linear and Conv2d weights are drawn from a normal distribution with
        std ``config.initializer_range`` and their biases are zeroed.
        Embedding weights use the same distribution with the padding row
        zeroed. LayerNorm is reset to weight 1 and bias 0. Other module types
        are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
class DecodeMLP(nn.Module):
    """Linear projection applied to every spatial position of a feature map."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_states: torch.Tensor):
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
        tokens = torch.flatten(hidden_states, start_dim=2)
        tokens = tokens.transpose(1, 2)
        # Project the channel dimension: (B, H*W, output_dim)
        return self.proj(tokens)
class DecodeHead(EfficientViTPreTrainedModel):
    """Decode head that fuses the per-stage encoder features into logits.

    Each encoder feature map is projected to ``decoder_layer_hidden_size``
    channels, upsampled to the resolution of the first feature map,
    concatenated (last stage first), fused by a 1x1 conv + batch norm + ReLU
    and classified by a final 1x1 conv.
    """

    def __init__(self, config: EfficientViTConfig):
        super().__init__(config)

        # linear layers which will unify the channel dimension of each of the encoder blocks
        self.linear_c = nn.ModuleList([
            DecodeMLP(input_dim=width, output_dim=config.decoder_layer_hidden_size)
            for width in config.widths[1:]
        ])

        # the following 3 layers implement the ConvModule of the original implementation
        self.linear_fuse = nn.Conv2d(
            in_channels=config.decoder_layer_hidden_size * config.num_stages,
            out_channels=config.decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
        self.activation = nn.ReLU()

        # NOTE(review): this dropout is constructed but forward() never applies it.
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

        self.config = config

    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
        batch_size = encoder_hidden_states[-1].shape[0]
        # Every feature map is resized to the spatial size of the first one.
        target_size = encoder_hidden_states[0].size()[2:]

        upsampled = []
        for feature_map, mlp in zip(encoder_hidden_states, self.linear_c):
            height = feature_map.shape[2]
            width = feature_map.shape[3]

            # mlp output is (B, HW, C); move channels forward and restore H, W.
            projected = mlp(feature_map).permute(0, 2, 1)
            projected = projected.reshape(batch_size, -1, height, width)

            # upsample
            projected = nn.functional.interpolate(
                projected, size=target_size, mode="bilinear", align_corners=False
            )
            upsampled.append(projected)

        # Concatenate on channels with the last stage first.
        fused = self.linear_fuse(torch.cat(upsampled[::-1], dim=1))
        fused = self.activation(self.batch_norm(fused))

        # logits keep the spatial size of the first encoder feature map
        return self.classifier(fused)
class EfficientViTForSemanticSegmentation(EfficientViTPreTrainedModel):
    """EfficientViT backbone plus decode head, with sigmoid-activated output."""

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.vit = EfficientVitLarge(config)
        self.decode_head = DecodeHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        """Run the model on ``pixel_values`` (expected as B, C, H, W).

        Returns a ``SemanticSegmenterOutput`` whose ``logits`` field already
        has the sigmoid applied (values in 0-1) and whose ``hidden_states``
        field holds the list of encoder stage outputs. ``loss`` is always
        ``None``.
        """
        stage_features = self.vit(pixel_values)

        # Sigmoid maps the decode-head output to the 0-1 range.
        probabilities = torch.special.expit(self.decode_head(stage_features))

        return SemanticSegmenterOutput(
            loss=None,
            logits=probabilities,
            hidden_states=stage_features,
        )