import torch
import torch.nn as nn

from .subnetwork_utils import BaseBlockConvBN, TopHead
from .register_modules import register_model


class BaseSubNetwork(nn.Module):
    """Common base for sub-networks: records the configuration values.

    The constructor only stores its arguments as attributes and initialises
    ``intermediate_features`` to ``None``; it builds no layers itself.
    Subclasses are expected to create the layers and define ``forward``.
    """

    def __init__(self,
                 input_channels: int,
                 base_channels: int,
                 fc_hidden_units: int,
                 fc_pred_units: int,
                 pred_activation: str,
                 ):
        super().__init__()

        # Plain configuration values, kept for later inspection.
        self.input_channels = input_channels
        self.base_channels = base_channels
        self.fc_hidden_units = fc_hidden_units
        self.fc_pred_units = fc_pred_units
        self.pred_activation = pred_activation

        # Placeholder; this class never assigns anything else to it.
        self.intermediate_features = None

    def get_intermediate_features(self) -> torch.Tensor:
        """Return the currently stored ``intermediate_features``.

        The value is ``None`` until some code (e.g. a subclass ``forward``)
        assigns to the attribute.
        """
        return self.intermediate_features


@register_model("base_one")
class Subnet(BaseSubNetwork):
    """Three ``BaseBlockConvBN`` stages followed by ``nn.Flatten`` and ``TopHead``.

    The stages use output widths of ``base_channels``, ``2 * base_channels``
    and ``4 * base_channels`` with 2, 2 and 3 conv layers respectively. All
    three share the same kernel size, stride, padding, activation and
    normalization settings. The class is registered under the name
    ``"base_one"`` through ``register_model``.
    """

    def __init__(self,
                 input_channels=3,
                 base_channels=32,
                 fc_hidden_units=64,
                 fc_pred_units=1,
                 pred_activation="sigmoid"):
        # The base constructor stores the configuration and sets
        # ``intermediate_features`` to ``None``.
        super().__init__(
            input_channels=input_channels,
            base_channels=base_channels,
            fc_hidden_units=fc_hidden_units,
            fc_pred_units=fc_pred_units,
            pred_activation=pred_activation,
        )

        # Settings common to every convolutional stage.
        shared_conv_kwargs = dict(
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
            activation="relu",
            normalization=True,
        )
        width_one = base_channels
        width_two = base_channels * 2
        width_three = base_channels * 4

        # Submodules are created in the same order as they run in forward().
        self.block_one = BaseBlockConvBN(
            in_ch=input_channels,
            out_ch=width_one,
            conv_layers=2,
            **shared_conv_kwargs,
        )
        self.block_two = BaseBlockConvBN(
            in_ch=width_one,
            out_ch=width_two,
            conv_layers=2,
            **shared_conv_kwargs,
        )
        self.block_three = BaseBlockConvBN(
            in_ch=width_two,
            out_ch=width_three,
            conv_layers=3,
            **shared_conv_kwargs,
        )

        self.flatten = nn.Flatten()
        self.head = TopHead(
            fc_units=fc_hidden_units,
            num_classes=fc_pred_units,
            hidden_layers=1,
            dropout_rate=0.6,
            fc_activation="relu",
            pred_activation=pred_activation,
        )

    def forward(self, x):
        """Run the three conv stages, flatten, and apply the head.

        As a side effect, ``intermediate_features`` is overwritten with
        whatever ``block_two.get_block_feats()`` returns right after the
        second stage has run.
        """
        stage_two_out = self.block_two(self.block_one(x))

        # Capture the second stage's features before moving on.
        self.intermediate_features = self.block_two.get_block_feats()

        stage_three_out = self.block_three(stage_two_out)
        flat = self.flatten(stage_three_out)
        return self.head(flat)