| | from . import datasets |
| | from . import encoders |
| | from . import decoders |
| | from . import losses |
| | from . import metrics |
| |
|
| | from .decoders.unet import Unet |
| | from .decoders.unetplusplus import UnetPlusPlus |
| | from .decoders.manet import MAnet |
| | from .decoders.linknet import Linknet |
| | from .decoders.fpn import FPN |
| | from .decoders.pspnet import PSPNet |
| | from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus |
| | from .decoders.pan import PAN |
| |
|
| | from .__version__ import __version__ |
| |
|
| | |
| | from typing import Optional as _Optional |
| | import torch as _torch |
| |
|
| |
|
| | def create_model( |
| | arch: str, |
| | encoder_name: str = "resnet34", |
| | encoder_weights: _Optional[str] = "imagenet", |
| | in_channels: int = 3, |
| | classes: int = 1, |
| | **kwargs, |
| | ) -> _torch.nn.Module: |
| | """Models entrypoint, allows to create any model architecture just with |
| | parameters, without using its class |
| | """ |
| |
|
| | archs = [ |
| | Unet, |
| | UnetPlusPlus, |
| | MAnet, |
| | Linknet, |
| | FPN, |
| | PSPNet, |
| | DeepLabV3, |
| | DeepLabV3Plus, |
| | PAN, |
| | ] |
| | archs_dict = {a.__name__.lower(): a for a in archs} |
| | try: |
| | model_class = archs_dict[arch.lower()] |
| | except KeyError: |
| | raise KeyError( |
| | "Wrong architecture type `{}`. Available options are: {}".format( |
| | arch, list(archs_dict.keys()), |
| | ) |
| | ) |
| | return model_class( |
| | encoder_name=encoder_name, |
| | encoder_weights=encoder_weights, |
| | in_channels=in_channels, |
| | classes=classes, |
| | **kwargs, |
| | ) |
| |
|