ius / model /subnetwork.py
pgatoula's picture
Sync from GitHub via hub-sync
99ec8a2 verified
from typing import Optional

import torch
import torch.nn as nn

from .subnetwork_utils import BaseBlockConvBN, TopHead
from .register_modules import register_model
class BaseSubNetwork(nn.Module):
    """Common base for the sub-network models in this module.

    It only records the constructor hyper-parameters as attributes and keeps a
    slot, ``intermediate_features``, that subclasses may fill during their
    ``forward`` so callers can read an internal activation afterwards. It
    defines no layers and no ``forward`` of its own; interpretation of the
    hyper-parameters is left to subclasses.
    """

    def __init__(
        self,
        input_channels: int,
        base_channels: int,
        fc_hidden_units: int,
        fc_pred_units: int,
        pred_activation: str,
    ):
        """Store the architecture hyper-parameters.

        Args:
            input_channels: Number of input channels expected by the subclass.
            base_channels: Base channel width used by the subclass.
            fc_hidden_units: Hidden width for the subclass's fully-connected head.
            fc_pred_units: Number of prediction units for the subclass's head.
            pred_activation: Name of the prediction activation, e.g. ``"sigmoid"``.
        """
        super().__init__()
        self.input_channels = input_channels
        self.base_channels = base_channels
        self.fc_hidden_units = fc_hidden_units
        self.fc_pred_units = fc_pred_units
        self.pred_activation = pred_activation
        # Filled by subclasses during forward(); stays None until then.
        self.intermediate_features: Optional[torch.Tensor] = None

    def get_intermediate_features(self) -> Optional[torch.Tensor]:
        """Return the features cached by the most recent ``forward`` call.

        Returns:
            Whatever the subclass stored in ``intermediate_features``, or
            ``None`` if nothing has been stored yet (e.g. before the first
            forward pass). The annotation previously claimed a tensor was
            always returned, which is not true before ``forward`` runs.
        """
        return self.intermediate_features
@register_model("base_one")
class Subnet(BaseSubNetwork):
    """Three-stage Conv-BN feature extractor followed by a fully-connected head.

    Registered under the name ``"base_one"``.

    Layout as configured below:
      * ``block_one``:   input_channels      -> base_channels,     2 conv layers
      * ``block_two``:   base_channels       -> 2 * base_channels, 2 conv layers
      * ``block_three``: 2 * base_channels   -> 4 * base_channels, 3 conv layers
      * ``flatten`` then ``head`` (``TopHead`` with one hidden layer).

    Every conv block is built with 3x3 kernels, stride (2, 2), padding (1, 1),
    ReLU activation and normalization enabled. How that stride is applied
    across a block's conv layers is decided by ``BaseBlockConvBN``. TODO:
    confirm the overall spatial downsampling factor.
    """

    def __init__(
        self,
        input_channels=3,
        base_channels=32,
        fc_hidden_units=64,
        fc_pred_units=1,
        pred_activation="sigmoid",
        dropout_rate=0.6,
    ):
        """Build the three conv blocks and the prediction head.

        Args:
            input_channels: Input channels of ``block_one``.
            base_channels: Output channels of ``block_one``; ``block_two`` and
                ``block_three`` use 2x and 4x this width.
            fc_hidden_units: Forwarded to ``TopHead`` as ``fc_units``.
            fc_pred_units: Forwarded to ``TopHead`` as ``num_classes``.
            pred_activation: Forwarded to ``TopHead`` as ``pred_activation``.
            dropout_rate: Forwarded to ``TopHead``. Defaults to 0.6, the value
                that used to be hard-coded, so existing configs are unaffected.
        """
        super().__init__(
            input_channels=input_channels,
            base_channels=base_channels,
            fc_hidden_units=fc_hidden_units,
            fc_pred_units=fc_pred_units,
            pred_activation=pred_activation,
        )
        # Submodules are created in the original order (one, two, three, head)
        # so parameter initialisation consumes the RNG exactly as before.
        self.block_one = self._conv_bn_block(input_channels, base_channels, conv_layers=2)
        self.block_two = self._conv_bn_block(base_channels, base_channels * 2, conv_layers=2)
        self.block_three = self._conv_bn_block(base_channels * 2, base_channels * 4, conv_layers=3)
        self.flatten = nn.Flatten()
        self.head = TopHead(
            fc_units=fc_hidden_units,
            num_classes=fc_pred_units,
            hidden_layers=1,
            dropout_rate=dropout_rate,
            fc_activation="relu",
            pred_activation=pred_activation,
        )
        # intermediate_features is already initialised to None by
        # BaseSubNetwork.__init__, so it is not reset here.

    @staticmethod
    def _conv_bn_block(in_ch, out_ch, conv_layers):
        """Create a BaseBlockConvBN with the settings shared by all three stages."""
        return BaseBlockConvBN(
            in_ch=in_ch,
            out_ch=out_ch,
            conv_layers=conv_layers,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
            activation="relu",
            normalization=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the conv stages and the head.

        Args:
            x: Input batch fed to ``block_one``; presumably shaped
                (N, input_channels, H, W). TODO: confirm.

        Returns:
            The output of ``self.head``.

        Side effect:
            Overwrites ``self.intermediate_features`` with
            ``self.block_two.get_block_feats()`` on every call.
        """
        x = self.block_one(x)
        x = self.block_two(x)
        # NOTE(review): if these features carry autograd history, caching them
        # keeps that graph alive until the next forward. Consider detaching if
        # they are only used for inspection.
        self.intermediate_features = self.block_two.get_block_feats()
        x = self.block_three(x)
        x = self.flatten(x)
        return self.head(x)