| | import logging
|
| |
|
| | import numpy as np
|
| | import torch
|
| | import torch.nn as nn
|
| | import torchvision
|
| | from torchvision.models.feature_extraction import create_feature_extractor
|
| |
|
| | from .base import BaseModel
|
| | from .schema import ResNetConfiguration
|
| |
|
| | logger = logging.getLogger(__name__)
|
| |
|
| |
|
class DecoderBlock(nn.Module):
    """One top-down decoder stage.

    Upsamples the previous (deeper) feature map by a power-of-two factor,
    crops or zero-pads the skip connection so the spatial sizes agree, then
    adds the conv-processed skip to the upsampled features.

    Note: the first conv consumes the *skip* tensor, so ``previous`` here is
    the skip connection's channel count and the incoming deeper feature map
    must already have ``out`` channels for the residual add to work.
    """

    def __init__(
        self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()
        modules = []
        in_channels = previous
        for _ in range(num_convs):
            modules.append(
                nn.Conv2d(
                    in_channels,
                    out,
                    kernel_size=ksize,
                    padding=ksize // 2,
                    bias=norm is None,  # norm layers make the conv bias redundant
                    padding_mode=padding,
                )
            )
            if norm is not None:
                modules.append(norm(out))
            modules.append(nn.ReLU(inplace=True))
            in_channels = out
        self.layers = nn.Sequential(*modules)

    def forward(self, previous, skip):
        prev_h, prev_w = previous.shape[-2:]
        skip_h, skip_w = skip.shape[-2:]
        # Snap the per-axis upsampling factor to the nearest power of two.
        factors = 2 ** np.round(np.log2(np.array([skip_h / prev_h, skip_w / prev_w])))
        upsampled = nn.functional.interpolate(
            previous, scale_factor=factors.tolist(), mode="bilinear", align_corners=False
        )

        # Reconcile any residual size mismatch between skip and upsampled:
        # crop the skip when it is larger, zero-pad it when it is smaller.
        up_h, up_w = upsampled.shape[-2:]
        skip_h, skip_w = skip.shape[-2:]
        if up_h <= skip_h and up_w <= skip_w:
            skip = skip[:, :, :up_h, :up_w]
        elif up_h >= skip_h and up_w >= skip_w:
            skip = nn.functional.pad(skip, [0, up_w - skip_w, 0, up_h - skip_h])
        else:
            raise ValueError(
                f"Inconsistent skip vs upsampled shapes: {(skip_h, skip_w)}, {(up_h, up_w)}"
            )

        return self.layers(skip) + upsampled
|
| |
|
| |
|
class FPN(nn.Module):
    """Top-down feature pyramid decoder.

    Projects the deepest encoder level with a 1x1 conv, then walks toward the
    shallowest level, fusing each skip connection via a DecoderBlock, and
    finishes with a conv-norm-ReLU head at the shallowest resolution.
    """

    def __init__(self, in_channels_list, out_channels, **kw):
        super().__init__()
        # 1x1 projection of the deepest (last-listed) encoder level.
        self.first = nn.Conv2d(
            in_channels_list[-1], out_channels, 1, padding=0, bias=True
        )
        # One fusion stage per remaining level, deepest-first.
        stages = []
        for skip_channels in reversed(in_channels_list[:-1]):
            stages.append(DecoderBlock(skip_channels, out_channels, ksize=1, **kw))
        self.blocks = nn.ModuleList(stages)
        self.out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, layers):
        maps = list(layers.values())
        # Start from the deepest map, then fuse the shallower skips in order.
        feats = self.first(maps[-1])
        for block, skip in zip(self.blocks, reversed(maps[:-1])):
            feats = block(feats, skip)
        return self.out(feats)
|
| |
|
| |
|
def remove_conv_stride(conv):
    """Return a copy of ``conv`` with stride 1 that shares its parameters.

    The copy reuses (does not clone) the original ``weight`` and ``bias``
    tensors, so the two convolutions stay in sync under training.

    Bug fixed: the previous version silently dropped ``dilation``,
    ``groups`` and ``padding_mode``, so any non-default conv (e.g. grouped
    or reflect-padded) came back subtly different. These attributes are now
    carried over; only the stride changes.

    Args:
        conv: the ``nn.Conv2d`` whose stride should be removed.

    Returns:
        A new ``nn.Conv2d`` identical to ``conv`` except ``stride=1``.
    """
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        bias=conv.bias is not None,
        stride=1,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        padding_mode=conv.padding_mode,
    )
    # Share the parameter tensors instead of copying their values.
    conv_new.weight = conv.weight
    conv_new.bias = conv.bias
    return conv_new
|
| |
|
| |
|
class FeatureExtractor(BaseModel):
    """Dense feature extractor: a torchvision backbone (ResNet or VGG) whose
    intermediate activations are fused by an FPN decoder into one feature map.

    NOTE(review): `_init` / `_forward` follow the BaseModel hook convention —
    presumably BaseModel drives construction via `_init(conf)` and inference
    via `_forward(data)`; confirm against the BaseModel implementation.
    """

    # Defaults for the model configuration; see ResNetConfiguration for the
    # expected schema (presumably merged with user conf by BaseModel).
    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_dim": 128,
        "encoder": "resnet50",
        "remove_stride_from_first_conv": False,
        "num_downsample": None,
        "decoder_norm": "nn.BatchNorm2d",
        "do_average_pooling": False,
        "checkpointed": False,
    }
    # ImageNet RGB normalization statistics, matching torchvision pretrained
    # weights.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def freeze_encoder(self):
        """
        Freeze the encoder part of the model, i.e., set requires_grad = False
        for all parameters in the encoder.
        """
        for param in self.encoder.parameters():
            param.requires_grad = False
        logger.debug("Encoder has been frozen.")

    def unfreeze_encoder(self):
        """
        Unfreeze the encoder part of the model, i.e., set requires_grad = True
        for all parameters in the encoder.
        """
        for param in self.encoder.parameters():
            param.requires_grad = True
        logger.debug("Encoder has been unfrozen.")

    def build_encoder(self, conf: ResNetConfiguration):
        """Instantiate the torchvision backbone and wrap it so it returns the
        intermediate activations named in ``layers``.

        Returns:
            (encoder, layers): the feature-extraction module and the list of
            graph-node names whose outputs it yields.
        """
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            # Pretrained torchvision weights expect 3-channel RGB input.
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)

        kw = {}
        if conf.encoder.startswith("resnet"):
            layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
            kw["replace_stride_with_dilation"] = [False, False, False]
        elif conf.encoder == "vgg13":
            # Graph-node names of the chosen intermediate activations.
            layers = [
                "features.3",
                "features.8",
                "features.13",
                "features.18",
                "features.23",
            ]
        elif conf.encoder == "vgg16":
            layers = [
                "features.3",
                "features.8",
                "features.15",
                "features.22",
                "features.29",
            ]
        else:
            raise NotImplementedError(conf.encoder)

        # Optionally keep only the first num_downsample levels (shallower,
        # higher-resolution features).
        if conf.num_downsample is not None:
            layers = layers[: conf.num_downsample]
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
        encoder = create_feature_extractor(encoder, return_nodes=layers)
        if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
            # Doubles the resolution of every feature map; weights are shared
            # with the original conv (see remove_conv_stride).
            encoder.conv1 = remove_conv_stride(encoder.conv1)

        if conf.do_average_pooling:
            raise NotImplementedError
        if conf.checkpointed:
            raise NotImplementedError

        return encoder, layers

    def _init(self, conf):
        # Normalization buffers; persistent=False keeps them out of the
        # state dict.
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        self.encoder, self.layers = self.build_encoder(conf)
        # Probe the encoder with a dummy input to discover each returned
        # layer's channel count and downsampling stride.
        s = 128
        inp = torch.zeros(1, 3, s, s)
        features = list(self.encoder(inp).values())
        self.skip_dims = [x.shape[1] for x in features]
        self.layer_strides = [s / f.shape[-1] for f in features]
        # The FPN output is at the resolution of the shallowest layer.
        self.scales = [self.layer_strides[0]]

        # NOTE(review): eval() on a config-supplied string is a
        # code-injection risk if the configuration can be user-controlled;
        # consider a whitelist/lookup table instead.
        norm = eval(conf.decoder_norm) if conf.decoder_norm else None
        self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)

        logger.debug(
            "Built feature extractor with layers {name:dim:stride}:\n"
            f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
            f"and output scales {self.scales}."
        )

    def _forward(self, data):
        image = data["image"]
        # Per-channel ImageNet normalization (broadcast over H, W).
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]

        skip_features = self.encoder(image)
        output = self.decoder(skip_features)
        # The camera entry is passed through untouched — presumably consumed
        # by a downstream module; confirm against callers.
        return output, data['camera']
|
| |
|