File size: 3,383 Bytes
99ec8a2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | import torch
import torch.nn as nn
from .subnetwork_utils import BaseBlockConvBN, TopHead
from .register_modules import register_model
class BaseSubNetwork(nn.Module):
    """Common base class for sub-network models.

    Stores the constructor hyper-parameters as attributes so subclasses can
    build their layers from them. Also provides an ``intermediate_features``
    slot that subclasses may populate during ``forward``.

    Args:
        input_channels: Number of channels expected in the input.
        base_channels: Base channel width; how it is used is up to the
            subclass (``Subnet`` uses it as the first conv block's width).
        fc_hidden_units: Width of the fully-connected hidden layer(s), for
            subclasses that build a dense head.
        fc_pred_units: Number of prediction outputs.
        pred_activation: Name of the prediction activation (e.g.
            ``"sigmoid"``); stored for subclasses to use.
    """

    def __init__(self,
                 input_channels: int,
                 base_channels: int,
                 fc_hidden_units: int,
                 fc_pred_units: int,
                 pred_activation: str,
                 ) -> None:
        super().__init__()
        self.input_channels = input_channels
        self.base_channels = base_channels
        self.fc_hidden_units = fc_hidden_units
        self.fc_pred_units = fc_pred_units
        self.pred_activation = pred_activation
        # Features cached by a subclass's forward pass; None until then.
        self.intermediate_features: "torch.Tensor | None" = None

    def get_intermediate_features(self) -> "torch.Tensor | None":
        """Return the features cached by the most recent forward pass.

        Returns:
            Whatever the subclass stored in ``intermediate_features``.
            Returns ``None`` if no forward pass has run yet, or if the
            subclass never sets it.
        """
        return self.intermediate_features
@register_model("base_one")
class Subnet(BaseSubNetwork):
    """Three-stage convolutional sub-network followed by a fully-connected head.

    Registered under the name ``"base_one"``. Every stage is a
    ``BaseBlockConvBN`` with 3x3 kernels, stride 2, padding 1, ReLU
    activation and normalization enabled. The channel widths are
    ``base_channels``, ``2 * base_channels`` and ``4 * base_channels``.
    The last stage's output is flattened and passed to a ``TopHead``.

    Args:
        input_channels: Channels of the input tensor.
        base_channels: Output width of the first stage.
        fc_hidden_units: Passed to the head as ``fc_units``.
        fc_pred_units: Passed to the head as ``num_classes``.
        pred_activation: Passed to the head as ``pred_activation``.
        dropout_rate: Keyword-only; the head's ``dropout_rate``. Defaults to
            0.6, the value previously hard-coded here.
        head_hidden_layers: Keyword-only; the head's ``hidden_layers``.
            Defaults to 1, the value previously hard-coded here.
    """

    def __init__(self,
                 input_channels: int = 3,
                 base_channels: int = 32,
                 fc_hidden_units: int = 64,
                 fc_pred_units: int = 1,
                 pred_activation: str = "sigmoid",
                 *,
                 dropout_rate: float = 0.6,
                 head_hidden_layers: int = 1,
                 ) -> None:
        super().__init__(
            input_channels=input_channels,
            base_channels=base_channels,
            fc_hidden_units=fc_hidden_units,
            fc_pred_units=fc_pred_units,
            pred_activation=pred_activation,
        )
        # Attribute names and creation order are kept stable on purpose: they
        # determine state_dict keys and parameters() order (checkpoints,
        # optimizer state). The base class already initialises
        # intermediate_features to None.
        self.block_one = self._make_stage(input_channels, base_channels, conv_layers=2)
        self.block_two = self._make_stage(base_channels, base_channels * 2, conv_layers=2)
        self.block_three = self._make_stage(base_channels * 2, base_channels * 4, conv_layers=3)
        self.flatten = nn.Flatten()
        self.head = TopHead(fc_units=fc_hidden_units,
                            num_classes=fc_pred_units,
                            hidden_layers=head_hidden_layers,
                            dropout_rate=dropout_rate,
                            fc_activation="relu",
                            pred_activation=pred_activation)

    @staticmethod
    def _make_stage(in_ch: int, out_ch: int, conv_layers: int) -> "BaseBlockConvBN":
        """Build one conv stage using the settings shared by all three stages."""
        return BaseBlockConvBN(in_ch=in_ch,
                               out_ch=out_ch,
                               conv_layers=conv_layers,
                               kernel_size=(3, 3),
                               stride=(2, 2),
                               padding=(1, 1),
                               activation="relu",
                               normalization=True,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three conv stages and the head; return the head's output.

        Side effect: stores ``self.block_two.get_block_feats()`` in
        ``self.intermediate_features``, where ``get_intermediate_features()``
        can read it.
        """
        x = self.block_one(x)
        x = self.block_two(x)
        # NOTE(review): if get_block_feats() returns a non-detached tensor, this
        # cache keeps the current autograd graph alive until the next forward
        # call — confirm whether callers need it detached.
        self.intermediate_features = self.block_two.get_block_feats()
        x = self.block_three(x)
        x = self.flatten(x)
        return self.head(x)
|