ius / model /register_modules.py
pgatoula's picture
Sync from GitHub via hub-sync
99ec8a2 verified
from typing import Type

import torch.nn as nn
# Global registry mapping lower-cased model names to registered model classes.
REGISTERED_MODELS = {}


def register_model(name: str):
    """Return a class decorator that registers a model class under ``name``.

    Registration is case-insensitive: the class is stored under
    ``name.lower()``, so ``'ResNet'`` and ``'resnet'`` address the same entry.

    Args:
        name: Name to register the decorated class under.

    Returns:
        A decorator that records the class in ``REGISTERED_MODELS`` and
        returns the class itself, unchanged.

    Raises:
        ValueError: When the decorator is applied, if a model is already
            registered under the same (case-insensitive) name.
    """

    def decorator(model_class: Type[nn.Module]) -> Type[nn.Module]:
        key = name.lower()
        if key in REGISTERED_MODELS:
            # Refuse silent overwrites: two classes sharing one name would
            # make name-based lookups ambiguous.
            raise ValueError(
                f'Model {name} already registered'
            )
        REGISTERED_MODELS[key] = model_class
        # Hand the class back so the decorated name stays bound to it.
        return model_class

    return decorator
def get_registered_model(name: str) -> Type[nn.Module]:
    """Look up a model class previously registered with ``register_model``.

    The lookup is case-insensitive: ``name`` is lower-cased before it is
    matched against the registry keys.

    Args:
        name: Registered model name, in any letter case.

    Returns:
        The class stored under that name (callers instantiate it themselves).

    Raises:
        ValueError: If nothing is registered under ``name``; the message
            lists the currently registered (lower-cased) names.
    """
    key = name.lower()
    if key not in REGISTERED_MODELS:
        raise ValueError(
            f'Unknown model name: {name}. '
            f'Available models: {list(REGISTERED_MODELS.keys())}'
        )
    return REGISTERED_MODELS[key]