| | from torch import nn, Tensor |
| | from torch.hub import load_state_dict_from_url |
| | from typing import Optional |
| |
|
| | from .utils import make_vgg_layers, vgg_cfgs, vgg_urls |
| | from ..utils import _init_weights, _get_norm_layer, _get_activation |
| | from ..utils import ConvDownsample, ConvUpsample |
| |
|
| |
|
# Names accepted by VGGEncoder / VGGEncoderDecoder; the "_bn" variants are the
# batch-normalized versions of each backbone.
vgg_models = [
    "vgg11", "vgg11_bn",
    "vgg13", "vgg13_bn",
    "vgg16", "vgg16_bn",
    "vgg19", "vgg19_bn",
]

# Channel widths of the extra conv stages appended as a decoder in
# VGGEncoderDecoder (consumed by make_vgg_layers).
decoder_cfg = [512, 256, 128]
| |
|
| |
|
class VGGEncoder(nn.Module):
    """VGG feature backbone followed by a conv refiner.

    The refiner rescales the backbone's output so the overall spatial
    reduction matches ``block_size``; the decoder stage is an identity and
    exists only to keep a uniform encode/refine/decode interface.

    Args:
        model_name: one of ``vgg_models``.
        block_size: target spatial reduction (8, 16 or 32); defaults to the
            backbone's own reduction when ``None``.
        norm: "bn", "ln", or anything else to derive the norm layer from the
            backbone via ``_get_norm_layer``.
        act: "relu", "gelu", or anything else to derive the activation from
            the backbone via ``_get_activation``.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
    ) -> None:
        super().__init__()
        assert model_name in vgg_models, f"Model name should be one of {vgg_models}, but got {model_name}."
        assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
        self.model_name = model_name

        # Dispatch table replaces the if/elif chain; the assert above
        # guarantees the key is present.
        factories = {
            "vgg11": vgg11, "vgg11_bn": vgg11_bn,
            "vgg13": vgg13, "vgg13_bn": vgg13_bn,
            "vgg16": vgg16, "vgg16_bn": vgg16_bn,
            "vgg19": vgg19, "vgg19_bn": vgg19_bn,
        }
        self.encoder = factories[model_name]()

        # Fixed properties of the chosen backbones (512 output channels;
        # presumably the final pool is dropped to give a 16x reduction —
        # see make_vgg_layers for confirmation).
        self.encoder_channels = 512
        self.encoder_reduction = 16
        self.block_size = self.encoder_reduction if block_size is None else block_size

        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(self.encoder)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(self.encoder)

        # Upsample when the backbone reduces at least as much as requested
        # (scale factor 1 when equal), otherwise downsample.
        if self.encoder_reduction < self.block_size:
            self.refiner = ConvDownsample(
                in_channels=self.encoder_channels,
                out_channels=self.encoder_channels,
                norm_layer=norm_layer,
                activation=activation,
            )
        else:
            self.refiner = ConvUpsample(
                in_channels=self.encoder_channels,
                out_channels=self.encoder_channels,
                scale_factor=self.encoder_reduction // self.block_size,
                norm_layer=norm_layer,
                activation=activation,
            )
        self.refiner_channels = self.encoder_channels
        self.refiner_reduction = self.block_size

        # No learned decoder in this variant.
        self.decoder = nn.Identity()
        self.decoder_channels = self.encoder_channels
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Run the VGG backbone."""
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        """Rescale features to the configured block size."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Identity pass-through (kept for interface symmetry)."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.refine(self.encode(x)))
| |
|
| |
|
class VGGEncoderDecoder(nn.Module):
    """VGG backbone plus extra conv decoder stages, then a conv refiner.

    Unlike ``VGGEncoder``, the "encoder" here is the backbone chained with
    freshly-initialized conv layers built from ``decoder_cfg``, which shrink
    the channel count to ``decoder_cfg[-1]``. The trailing ``decoder``
    attribute is an identity kept for interface symmetry.

    Args:
        model_name: one of ``vgg_models``.
        block_size: target spatial reduction (8, 16 or 32); defaults to the
            backbone's own reduction when ``None``.
        norm: "bn", "ln", or anything else to derive the norm layer from the
            backbone via ``_get_norm_layer``.
        act: "relu", "gelu", or anything else to derive the activation from
            the backbone via ``_get_activation``.
    """

    def __init__(
        self,
        model_name: str,
        block_size: Optional[int] = None,
        norm: str = "none",
        act: str = "none",
    ) -> None:
        super().__init__()
        assert model_name in vgg_models, f"Model name should be one of {vgg_models}, but got {model_name}."
        assert block_size is None or block_size in [8, 16, 32], f"Block size should be one of [8, 16, 32], but got {block_size}."
        self.model_name = model_name

        # Dispatch table replaces the if/elif chain; the assert above
        # guarantees the key is present.
        factories = {
            "vgg11": vgg11, "vgg11_bn": vgg11_bn,
            "vgg13": vgg13, "vgg13_bn": vgg13_bn,
            "vgg16": vgg16, "vgg16_bn": vgg16_bn,
            "vgg19": vgg19, "vgg19_bn": vgg19_bn,
        }
        backbone = factories[model_name]()

        backbone_channels = 512
        backbone_reduction = 16
        # Extra conv stages mirror the backbone's layer style ("_bn" models
        # get batch norm) and are randomly initialized.
        extra_stages = make_vgg_layers(decoder_cfg, in_channels=backbone_channels, batch_norm="bn" in model_name, dilation=1)
        extra_stages.apply(_init_weights)

        if norm == "bn":
            norm_layer = nn.BatchNorm2d
        elif norm == "ln":
            norm_layer = nn.LayerNorm
        else:
            norm_layer = _get_norm_layer(backbone)

        if act == "relu":
            activation = nn.ReLU(inplace=True)
        elif act == "gelu":
            activation = nn.GELU()
        else:
            activation = _get_activation(backbone)

        self.encoder = nn.Sequential(backbone, extra_stages)
        self.encoder_channels = decoder_cfg[-1]
        self.encoder_reduction = backbone_reduction
        self.block_size = self.encoder_reduction if block_size is None else block_size

        # Upsample when the encoder reduces at least as much as requested
        # (scale factor 1 when equal), otherwise downsample.
        if self.encoder_reduction < self.block_size:
            self.refiner = ConvDownsample(
                in_channels=self.encoder_channels,
                out_channels=self.encoder_channels,
                norm_layer=norm_layer,
                activation=activation,
            )
        else:
            self.refiner = ConvUpsample(
                in_channels=self.encoder_channels,
                out_channels=self.encoder_channels,
                scale_factor=self.encoder_reduction // self.block_size,
                norm_layer=norm_layer,
                activation=activation,
            )
        self.refiner_channels = self.encoder_channels
        self.refiner_reduction = self.block_size

        # No learned decoder stage beyond the extra_stages folded into
        # self.encoder above.
        self.decoder = nn.Identity()
        self.decoder_channels = self.refiner_channels
        self.decoder_reduction = self.refiner_reduction

    def encode(self, x: Tensor) -> Tensor:
        """Run backbone + extra conv stages."""
        return self.encoder(x)

    def refine(self, x: Tensor) -> Tensor:
        """Rescale features to the configured block size."""
        return self.refiner(x)

    def decode(self, x: Tensor) -> Tensor:
        """Identity pass-through (kept for interface symmetry)."""
        return self.decoder(x)

    def forward(self, x: Tensor) -> Tensor:
        return self.decode(self.refine(self.encode(x)))
| |
|
| |
|
class VGG(nn.Module):
    """Thin wrapper that applies a feature-extractor module.

    Holds only the convolutional ``features`` stack (no classifier head);
    ``forward`` is a straight pass-through.
    """

    def __init__(
        self,
        features: nn.Module,
    ) -> None:
        super().__init__()
        self.features = features

    def forward(self, x: Tensor) -> Tensor:
        return self.features(x)
| |
|
| |
|
def vgg11() -> VGG:
    """Build VGG-11 (config "A") features with torchvision pretrained weights.

    ``strict=False`` because the checkpoint also contains classifier keys
    this feature-only wrapper does not define.
    """
    net = VGG(make_vgg_layers(vgg_cfgs["A"]))
    weights = load_state_dict_from_url(vgg_urls["vgg11"])
    net.load_state_dict(state_dict=weights, strict=False)
    return net
| |
|
def vgg11_bn() -> VGG:
    """Build VGG-11-BN (config "A" + batch norm) features with pretrained weights.

    ``strict=False`` because the checkpoint also contains classifier keys
    this feature-only wrapper does not define.
    """
    net = VGG(make_vgg_layers(vgg_cfgs["A"], batch_norm=True))
    weights = load_state_dict_from_url(vgg_urls["vgg11_bn"])
    net.load_state_dict(state_dict=weights, strict=False)
    return net
| |
|
def vgg13() -> VGG:
    """Build VGG-13 (config "B") features with torchvision pretrained weights.

    ``strict=False`` because the checkpoint also contains classifier keys
    this feature-only wrapper does not define.
    """
    net = VGG(make_vgg_layers(vgg_cfgs["B"]))
    weights = load_state_dict_from_url(vgg_urls["vgg13"])
    net.load_state_dict(state_dict=weights, strict=False)
    return net
| |
|
def vgg13_bn() -> VGG:
    """Build VGG-13-BN (config "B" + batch norm) features with pretrained weights.

    ``strict=False`` because the checkpoint also contains classifier keys
    this feature-only wrapper does not define.
    """
    net = VGG(make_vgg_layers(vgg_cfgs["B"], batch_norm=True))
    weights = load_state_dict_from_url(vgg_urls["vgg13_bn"])
    net.load_state_dict(state_dict=weights, strict=False)
    return net
| |
|
def vgg16() -> VGG:
    """Build VGG-16 (config "D") features with torchvision pretrained weights.

    ``strict=False`` because the checkpoint also contains classifier keys
    this feature-only wrapper does not define.
    """
    net = VGG(make_vgg_layers(vgg_cfgs["D"]))
    weights = load_state_dict_from_url(vgg_urls["vgg16"])
    net.load_state_dict(state_dict=weights, strict=False)
    return net
| |
|
def vgg16_bn() -> VGG:
    """Build VGG-16-BN (config "D" + batch norm) features with pretrained weights.

    ``strict=False`` because the checkpoint also contains classifier keys
    this feature-only wrapper does not define.
    """
    net = VGG(make_vgg_layers(vgg_cfgs["D"], batch_norm=True))
    weights = load_state_dict_from_url(vgg_urls["vgg16_bn"])
    net.load_state_dict(state_dict=weights, strict=False)
    return net
| |
|
def vgg19() -> VGG:
    """Build VGG-19 (config "E") features with torchvision pretrained weights.

    ``strict=False`` because the checkpoint also contains classifier keys
    this feature-only wrapper does not define.
    """
    net = VGG(make_vgg_layers(vgg_cfgs["E"]))
    weights = load_state_dict_from_url(vgg_urls["vgg19"])
    net.load_state_dict(state_dict=weights, strict=False)
    return net
| |
|
def vgg19_bn() -> VGG:
    """Build VGG-19-BN (config "E" + batch norm) features with pretrained weights.

    ``strict=False`` because the checkpoint also contains classifier keys
    this feature-only wrapper does not define.
    """
    net = VGG(make_vgg_layers(vgg_cfgs["E"], batch_norm=True))
    weights = load_state_dict_from_url(vgg_urls["vgg19_bn"])
    net.load_state_dict(state_dict=weights, strict=False)
    return net
| |
|
def _vgg_encoder(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> VGGEncoder:
    """Factory wrapper: construct a ``VGGEncoder`` from plain arguments."""
    return VGGEncoder(model_name=model_name, block_size=block_size, norm=norm, act=act)
| |
|
def _vgg_encoder_decoder(model_name: str, block_size: Optional[int] = None, norm: str = "none", act: str = "none") -> VGGEncoderDecoder:
    """Factory wrapper: construct a ``VGGEncoderDecoder`` from plain arguments."""
    return VGGEncoderDecoder(model_name=model_name, block_size=block_size, norm=norm, act=act)
| |
|