File size: 679 Bytes
99ec8a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from typing import Callable, Type, TypeVar

import torch.nn as nn

# Module-level registry: maps lower-cased model names to the objects passed to
# the decorator returned by ``register_model``; read by ``get_registered_model``.
REGISTERED_MODELS = {}


# Type variable bound to nn.Module so the decorator hands back the exact
# subclass it was given, preserving static type information for callers.
_ModelT = TypeVar('_ModelT', bound=nn.Module)


def register_model(name: str) -> Callable[[Type[_ModelT]], Type[_ModelT]]:
    """Return a class decorator that registers a model class under ``name``.

    The name is lower-cased before it is stored in ``REGISTERED_MODELS``, so
    registration (and lookup via ``get_registered_model``) is
    case-insensitive: ``'ResNet'`` and ``'resnet'`` are the same entry.

    Example::

        @register_model('MyNet')
        class MyNet(nn.Module):
            ...

    Args:
        name: Registry key for the decorated class.

    Returns:
        A decorator that records the class in ``REGISTERED_MODELS`` and
        returns the class unchanged.

    Raises:
        ValueError: When the decorator is applied, if a model with the same
            lower-cased name is already registered.
    """
    # Computed once per register_model() call rather than per decoration.
    key = name.lower()

    def decorator(model_class: Type[_ModelT]) -> Type[_ModelT]:
        if key in REGISTERED_MODELS:
            # Reject duplicates instead of silently overwriting the entry.
            raise ValueError(
                f'Model {name} already registered'
            )
        REGISTERED_MODELS[key] = model_class
        return model_class

    return decorator


def get_registered_model(name: str) -> Type[nn.Module]:
    """Return the model class registered under ``name``.

    The lookup lower-cases ``name``, matching how ``register_model`` stores
    its keys, so any casing of the registered name works.

    Args:
        name: Name the model was registered with (any casing).

    Returns:
        The object stored by ``register_model``, i.e. the decorated model
        class itself (not an instance).

    Raises:
        ValueError: If no model is registered under ``name``; the message
            lists the currently available names.
    """
    key = name.lower()
    try:
        return REGISTERED_MODELS[key]
    except KeyError:
        # ``from None`` hides the internal KeyError so callers see only the
        # ValueError, as with the previous membership-check implementation.
        raise ValueError(
            f'Unknown model name: {name}. '
            f'Available models: {list(REGISTERED_MODELS.keys())}'
        ) from None