# model/subnetwork_utils.py
import torch
import torch.nn as nn
from typing import Tuple, Optional
from model.module_mapping import layer_mapping
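# Building blocks for assembling subnetworks: a convolutional block
# (conv stack + optional max pooling and batch norm) and a fully
# connected prediction head.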
class BaseBlockConvBN(nn.Module):
    """Stack of conv layers followed by optional max pooling and batch norm."""

    def __init__(self,
                 in_ch: int,
                 out_ch: int,
                 kernel_size: Tuple[int, int] = (3, 3),
                 stride: Tuple[int, int] = (2, 2),
                 padding: Tuple[int, int] = (1, 1),
                 conv_layers: int = 2,
                 pool_layer: bool = True,
                 normalization: bool = True,
                 activation: Optional[str] = "linear",
                 # return_feats: bool = False,
                 ):
        super().__init__()
        self.conv_layers = nn.ModuleList()
        # Same params for all convolutional layers in the ModuleList:
        # activation, out_channels, kernel size, stride, padding.
        # The convolutions themselves use stride (1, 1); `stride` is applied
        # by the max-pooling layer below.
        for _ in range(conv_layers):
            self.conv_layers.append(
                nn.Conv2d(in_channels=in_ch,
                          out_channels=out_ch,
                          kernel_size=kernel_size,
                          stride=(1, 1),
                          padding=padding,
                          bias=True)
            )
            in_ch = out_ch
        # `activation` is Optional, so fall back to an identity when it is None.
        self._activation = layer_mapping(activation.lower())() if activation is not None else nn.Identity()
        self._batch_norm = nn.BatchNorm2d(out_ch) if normalization else None
        self._max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=stride) if pool_layer else None
        self._intermediate_feats = None
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for conv in self.conv_layers:
            x = conv(x)
            x = self._activation(x)
        # Cache the activation of the last conv layer (pre-pooling, pre-norm)
        # so it can be retrieved later via get_block_feats().
        self._intermediate_feats = x
        if self._max_pooling is not None:
            x = self._max_pooling(x)
        if self._batch_norm is not None:
            x = self._batch_norm(x)
        return x
    def get_block_feats(self) -> Optional[torch.Tensor]:
        # None until forward() has been called at least once.
        return self._intermediate_feats
class TopHead(nn.Module):
    """Fully connected head: hidden layers with activation (and optional dropout),
    followed by a prediction layer."""

    def __init__(self,
                 # in_feats: int,
                 fc_units: int = 64,
                 num_classes: int = 1,
                 hidden_layers: int = 1,
                 fc_activation: Optional[str] = "relu",
                 pred_activation: Optional[str] = "sigmoid",
                 dropout_rate: Optional[float] = None,
                 ):
        super().__init__()
        self._dense_layers = nn.ModuleList()
        for i in range(hidden_layers + 1):
            if i == 0:
                # LazyLinear infers in_features from the first forward pass.
                self._dense_layers.append(
                    nn.LazyLinear(out_features=fc_units, bias=True)
                )
            else:
                self._dense_layers.append(
                    nn.Linear(in_features=fc_units, out_features=fc_units, bias=True)
                )
            # `fc_activation` is Optional, so only append one when it is set.
            if fc_activation is not None:
                self._dense_layers.append(
                    layer_mapping(fc_activation.lower())()
                )
            if dropout_rate is not None:
                self._dense_layers.append(
                    nn.Dropout(dropout_rate)
                )
        # prediction layer
        self._dense_layers.append(
            nn.Linear(in_features=fc_units, out_features=num_classes, bias=True)
        )
        if pred_activation is not None:
            self._dense_layers.append(
                layer_mapping(pred_activation.lower())()
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self._dense_layers:
            x = layer(x)
        return x
if __name__ == "__main__":
    # Quick smoke test of the conv block on a CIFAR-sized input.
    block = BaseBlockConvBN(in_ch=3, out_ch=64, kernel_size=(3, 3), stride=(2, 2),
                            padding=(1, 1), normalization=True, activation="relu",
                            conv_layers=2)
    dummy_input = torch.randn((1, 3, 32, 32))
    out = block(dummy_input)
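    # Illustrative extra usage (a sketch, not in the original demo): feed the
    # block's flattened output through TopHead. The head starts with
    # nn.LazyLinear, so the input feature size is inferred on this first call.
    head = TopHead(fc_units=64, num_classes=1, hidden_layers=1,
                   fc_activation="relu", pred_activation="sigmoid")
    pred = head(out.flatten(start_dim=1))
    print(out.shape, block.get_block_feats().shape, pred.shape)
    # expected: torch.Size([1, 64, 16, 16]) torch.Size([1, 64, 32, 32]) torch.Size([1, 1])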