| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| from modules.basic_layers import GroupNorm | |
| class Extractor(nn.Module): | |
| def __init__(self, channels: list[int], num_groups: int = 32, use_residual: bool = True): | |
| super().__init__() | |
| self.use_residual = use_residual | |
| self.layers = nn.ModuleList([ | |
| nn.Sequential( | |
| nn.Conv2d(in_channels=channels[i], out_channels=channels[i + 1], kernel_size=3, stride=2, padding=1), | |
| GroupNorm(channels[i + 1], num_groups = num_groups), | |
| nn.SiLU(), | |
| nn.Conv2d(in_channels=channels[i + 1], out_channels=channels[i + 1], kernel_size=3, stride=1, padding=1), | |
| GroupNorm(channels[i + 1], num_groups = num_groups), | |
| nn.SiLU() | |
| ) for i in range(len(channels) - 1) | |
| ]) | |
| if self.use_residual: | |
| self.residual = nn.ModuleList([ | |
| nn.Sequential( | |
| nn.Conv2d(in_channels=channels[i], out_channels=channels[i + 1], kernel_size=3, stride=2, padding=1), | |
| ) for i in range(len(channels) - 1) | |
| ]) | |
| def forward(self, x: torch.Tensor) -> list[torch.Tensor]: | |
| features = [] | |
| for residual, layer in zip(self.residual, self.layers): | |
| if self.use_residual: | |
| x = layer(x) + residual(x) | |
| else: | |
| x = layer(x) | |
| features.append(x) | |
| return features | |
| class ResNetExtractor(nn.Module): | |
| def __init__(self, pretrained: bool = True, layers_to_extract: list[str] = ["layer1", "layer2", "layer3"]): | |
| super(ResNetExtractor, self).__init__() | |
| resnet = models.resnet18(pretrained=pretrained) | |
| self.initial_layers = nn.Sequential( | |
| resnet.conv1, | |
| resnet.bn1, | |
| resnet.relu | |
| ) | |
| self.layers = nn.ModuleDict({ | |
| "layer1": resnet.layer1, | |
| "layer2": resnet.layer2, | |
| "layer3": resnet.layer3, | |
| }) | |
| self.layers_to_extract = layers_to_extract | |
| def forward(self, x: torch.Tensor) -> list[torch.Tensor]: | |
| features = [] | |
| x = self.initial_layers(x) | |
| for name, layer in self.layers.items(): | |
| x = layer(x) | |
| if name in self.layers_to_extract: | |
| features.append(x) | |
| return features | |
| class VGGExtractor(nn.Module): | |
| def __init__(self, layers_to_extract: list[int] = [8, 15, 22, 29]): | |
| super(VGGExtractor, self).__init__() | |
| self.vgg = models.vgg16(pretrained=True).features | |
| self.layers_to_extract = layers_to_extract | |
| self.selected_layers = [self.vgg[i] for i in layers_to_extract] | |
| def forward(self, x: torch.Tensor) -> list[torch.Tensor]: | |
| features = [] | |
| for i, layer in enumerate(self.vgg): | |
| x = layer(x) | |
| if i in self.layers_to_extract: | |
| features.append(x) | |
| return features | |