File size: 679 Bytes
99ec8a2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | import torch.nn as nn
# Module-level registry populated by ``register_model``. Keys are the
# lower-cased registration names; values are the decorated classes.
REGISTERED_MODELS: "dict[str, type[nn.Module]]" = {}


def register_model(name: str):
    """Return a class decorator that records the decorated class under *name*.

    Registration is case-insensitive: the class is stored under
    ``name.lower()``, so ``"ResNet"`` and ``"resnet"`` address the same entry.

    Example::

        @register_model("ResNet")
        class ResNet(nn.Module):
            ...

    Args:
        name: Name to register the decorated class under.

    Returns:
        A decorator that stores the class in ``REGISTERED_MODELS`` and hands
        the class back unchanged, so the decorated name still refers to it.

    Raises:
        AttributeError: immediately, if *name* has no ``lower()`` method.
        ValueError: raised by the returned decorator when a class is already
            registered under the same case-insensitive name; the registry is
            left unchanged in that case.
    """
    # Normalise once, up front, so an invalid name fails at the
    # ``register_model(...)`` call instead of at class-decoration time.
    key = name.lower()

    def decorator(model_class: "type[nn.Module]") -> "type[nn.Module]":
        # The decorator receives the class object itself (not an instance),
        # hence ``type[...]`` rather than ``nn.Module`` in the annotations.
        if key in REGISTERED_MODELS:
            raise ValueError(
                f'Model {name} already registered'
            )
        REGISTERED_MODELS[key] = model_class
        return model_class

    return decorator
def get_registered_model(name: str) -> "type[nn.Module]":
    """Return the class registered under *name* (case-insensitive).

    Args:
        name: Model name; it is looked up in ``REGISTERED_MODELS`` as
            ``name.lower()``.

    Returns:
        The object stored for that name by the registry, i.e. the registered
        class itself rather than an instance of it.

    Raises:
        ValueError: if nothing is registered under that name. The message
            lists the currently registered (lower-cased) names.
    """
    key = name.lower()
    try:
        return REGISTERED_MODELS[key]
    except KeyError:
        # ``from None`` hides the internal KeyError so callers see only the
        # ValueError describing what is available.
        raise ValueError(
            f'Unknown model name: {name}. '
            f'Available models: {list(REGISTERED_MODELS)}'
        ) from None
|