from typing import Callable, Dict, Type

import torch.nn as nn

# Maps a lower-cased model name to the model class registered under it.
REGISTERED_MODELS: Dict[str, Type[nn.Module]] = {}


def register_model(
    name: str,
) -> Callable[[Type[nn.Module]], Type[nn.Module]]:
    """Return a class decorator that registers a model class under *name*.

    The registry key is ``name.lower()``, so names that differ only by
    case refer to the same entry.

    Args:
        name: Name under which the decorated class is stored.

    Returns:
        A decorator that stores the class in ``REGISTERED_MODELS`` and
        returns it unchanged.

    Raises:
        ValueError: Raised by the returned decorator (when it is applied,
            not when ``register_model`` is called) if the lower-cased
            name is already present in the registry.
    """

    def decorator(model_class: Type[nn.Module]) -> Type[nn.Module]:
        key = name.lower()
        if key in REGISTERED_MODELS:
            raise ValueError(
                f'Model {name} already registered'
            )
        REGISTERED_MODELS[key] = model_class
        # Hand the class back untouched so the decorated name still
        # refers to the original class.
        return model_class

    return decorator


def get_registered_model(name: str) -> Type[nn.Module]:
    """Return the model class registered under *name*.

    The lookup uses ``name.lower()``, matching how ``register_model``
    builds its keys.

    Args:
        name: Name of a previously registered model.

    Returns:
        The registered class itself (not an instance of it).

    Raises:
        ValueError: If no class is registered under the lower-cased name.
            The message lists the currently registered keys.
    """
    key = name.lower()
    if key not in REGISTERED_MODELS:
        raise ValueError(
            f'Unknown model name: {name}. '
            f'Available models: {list(REGISTERED_MODELS)}'
        )
    return REGISTERED_MODELS[key]