Yihua7's picture
Initial commit: AniGen - Animatable 3D Generation
6b92ff7
"""Models from the segmentation_models_pytorch library.

Install the library (``pip install segmentation-models-pytorch``) before using these models.
"""
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
class SegModel(nn.Module):
    """Thin wrapper around a segmentation_models_pytorch (``smp``) model.

    The wrapped network is stored as ``self.model``; ``forward`` delegates to it.
    """

    def __init__(self, architecture='Unet',
                 encoder_name='mobilenet_v2',
                 encoder_weights='imagenet',
                 encoder_depth=4,
                 in_channels=3,
                 classes=4):
        """Build the wrapped smp model.

        Args:
            architecture: Name of an smp model class, resolved as
                ``smp.<architecture>``, e.g. ``Unet``, ``UnetPlusPlus``,
                ``DeepLabV3``, ``DeepLabV3Plus``. ``EfficientUnetPlusPlus``
                only exists in some smp forks.
            encoder_name: smp encoder identifier (approx. parameter count), e.g.
                resnet18 / 11M, resnet34 / 21M, resnet50 / 23M, resnet101 / 42M,
                densenet121 / 6M, densenet169 / 12M, densenet201 / 18M,
                efficientnet-b0 / 4M, -b1 / 6M, -b2 / 7M, -b3 / 10M,
                -b4 / 17M, -b5 / 28M, mobilenet_v2 / 2M.
            encoder_weights: ``"imagenet"`` or ``None`` (passed through as-is).
            encoder_depth: Number of encoder stages. The decoder gets one
                channel width per stage, halving from 256, so depth 4 gives
                (256, 128, 64, 32).
            in_channels: Number of input channels.
            classes: Number of output channels.

        Raises:
            AttributeError: If ``smp`` has no attribute named ``architecture``.
        """
        super().__init__()
        # Resolve the class by name instead of exec'ing a formatted string.
        # exec turned encoder_weights=None into the string 'None' and would
        # execute arbitrary text passed as ``architecture``.
        model_cls = getattr(smp, architecture)
        # decoder_channels used to be defined only for depth 4, so any other
        # depth raised NameError. Halving from 256 reproduces (256, 128, 64, 32)
        # for depth 4 and (256, 128, 64, 32, 16) for depth 5.
        decoder_channels = tuple(256 // 2 ** i for i in range(encoder_depth))
        # NOTE(review): smp's DeepLabV3 / DeepLabV3Plus appear to expect an int
        # decoder_channels; passing a tuple likely fails for them -- confirm.
        self.model = model_cls(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            encoder_depth=encoder_depth,
            decoder_channels=decoder_channels,
            in_channels=in_channels,
            classes=classes,
            activation=None,  # return raw outputs; no final activation
        )

    def forward(self, x):
        """Run the wrapped smp model on ``x`` and return its output unchanged."""
        return self.model(x)