|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Registry for the available models we can train.""" |
|
|
|
|
|
from typing import Type |
|
|
|
|
|
from scenic.model_lib.base_models import base_model |
|
|
from scenic.projects.baselines import axial_resnet |
|
|
from scenic.projects.baselines import bit_resnet |
|
|
from scenic.projects.baselines import fully_connected |
|
|
from scenic.projects.baselines import hybrid_vit |
|
|
from scenic.projects.baselines import mixer |
|
|
from scenic.projects.baselines import resnet |
|
|
from scenic.projects.baselines import simple_cnn |
|
|
from scenic.projects.baselines import unet |
|
|
from scenic.projects.baselines import vit |
|
|
|
|
|
# Classification model classes, keyed by the name accepted by
# `get_model_cls`. Both single-label and multi-label variants live here.
CLASSIFICATION_MODELS = {
    'fully_connected_classification':
        fully_connected.FullyConnectedClassificationModel,
    'simple_cnn_classification':
        simple_cnn.SimpleCNNClassificationModel,
    'axial_resnet_multilabel_classification':
        axial_resnet.AxialResNetMultiLabelClassificationModel,
    'resnet_classification':
        resnet.ResNetClassificationModel,
    'resnet_multilabel_classification':
        resnet.ResNetMultiLabelClassificationModel,
    'bit_resnet_classification':
        bit_resnet.BitResNetClassificationModel,
    'bit_resnet_multilabel_classification':
        bit_resnet.BitResNetMultiLabelClassificationModel,
    'vit_multilabel_classification':
        vit.ViTMultiLabelClassificationModel,
    'hybrid_vit_multilabel_classification':
        hybrid_vit.HybridViTMultiLabelClassificationModel,
    'mixer_multilabel_classification':
        mixer.MixerMultiLabelClassificationModel,
}

# Segmentation model classes, keyed by the name accepted by `get_model_cls`.
SEGMENTATION_MODELS = {
    'simple_cnn_segmentation': simple_cnn.SimpleCNNSegmentationModel,
    'unet_segmentation': unet.UNetSegmentationModel,
}

# Combined lookup table consulted by `get_model_cls`. Classification entries
# are inserted first, then segmentation entries. If a name ever appeared in
# both tables, the segmentation entry would silently win.
ALL_MODELS = {**CLASSIFICATION_MODELS, **SEGMENTATION_MODELS}
|
|
|
|
|
|
|
|
def get_model_cls(model_name: str) -> Type[base_model.BaseModel]:
  """Get the model class registered under the given name.

  API:
  ```
      model_builder = get_model_cls('fully_connected_classification')
      model = model_builder(config, ...)
  ```

  Args:
    model_name: str; Registered name of the model, i.e. one of the keys of
      `ALL_MODELS` (e.g. 'fully_connected_classification' or
      'unet_segmentation').

  Returns:
    The model class registered under `model_name` (a subclass of
    `base_model.BaseModel`). The class is returned uninstantiated; the caller
    constructs it.

  Raises:
    ValueError if model_name is unrecognized.
  """
  # Single dict lookup (EAFP) instead of a membership test followed by a
  # second lookup.
  try:
    return ALL_MODELS[model_name]
  except KeyError:
    # Suppress the KeyError context: callers get one ValueError, as before.
    raise ValueError('Unrecognized model: {}'.format(model_name)) from None
|
|
|