File size: 2,064 Bytes
f43af3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from enum import Enum


class ExplicitEnum(str, Enum):
    """
    String-valued enum whose failed value lookups raise a descriptive error.

    Because of the ``str`` mixin, members compare equal to their string values,
    and ``str(member)`` returns the bare value rather than ``ClassName.MEMBER``.
    """

    def __str__(self):
        return str(self.value)

    @classmethod
    def _missing_(cls, value):
        """Raise a ``ValueError`` that lists every accepted value.

        ``Enum`` calls this hook when ``cls(value)`` matches no member. The
        exception raised here propagates to the caller in place of the stock
        message.

        Raises:
            ValueError: always, naming ``value``, the enum class and the valid values.
        """
        # Iterating the class is the public API. It yields canonical members in
        # definition order, so alias members do not produce duplicate entries.
        valid_values = [member.value for member in cls]
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {valid_values}"
        )


class PaddingStrategy(ExplicitEnum):
    """Accepted values for the `padding` argument of [`EventTokenizer.__call__`].

    Declared as an enum so that IDEs can tab-complete the options.
    """

    LONGEST = 'longest'
    MAX_LENGTH = 'max_length'
    DO_NOT_PAD = 'do_not_pad'


class TensorType(ExplicitEnum):
    """Accepted values for the `return_tensors` argument of [`EventTokenizerBase.__call__`].

    Declared as an enum so that IDEs can tab-complete the options.
    """

    PYTORCH = 'pt'
    NUMPY = 'np'


class RunnerPhase(ExplicitEnum):
    """Phase the model runner is operating in."""

    TRAIN = "train"
    VALIDATE = "validate"
    PREDICT = "predict"


class LossFunction(ExplicitEnum):
    """Loss-function options for a neural TPP model."""

    LOGLIKE = "loglike"
    # NOTE(review): the two "partial" options use metric names ('rmse', 'accuracy')
    # as their values; confirm how callers translate these into an actual loss.
    PARTIAL_TIME_LOSS = "rmse"
    PARTIAL_EVENT_LOSS = "accuracy"


class LogConst:
    """%-style format strings for `logging` handlers."""

    DEFAULT_FORMAT = "[%(asctime)s] [%(levelname)s] %(message)s"
    # The long variant also records source file, process id, line number and function name.
    DEFAULT_FORMAT_LONG = (
        "%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s]"
        " - %(levelname)s: %(message)s"
    )


class PredOutputIndex:
    """Slot positions inside the output tuple returned by ModelRunner."""

    # Time prediction sits at slot 0, type prediction at slot 1.
    TimePredIndex, TypePredIndex = 0, 1


class DefaultRunnerConfig:
    """Default constants for the model runner configuration."""

    # NOTE(review): presumably the fallback when no dataset id is configured; confirm against callers.
    DEFAULT_DATASET_ID = "conttime"


class TruncationStrategy(ExplicitEnum):
    """Accepted values for the `truncation` argument of [`EventTokenizer.__call__`].

    Declared as an enum so that IDEs can tab-complete the options.
    """

    LONGEST_FIRST = 'longest_first'
    DO_NOT_TRUNCATE = 'do_not_truncate'


class Backend(ExplicitEnum):
    """
    Possible values for the `backend` argument in configuration.
    """

    Torch = 'torch'
    # UPPER_CASE alias that follows the member-naming convention of the sibling enums.
    # It shares the value 'torch', so Enum makes it an alias: ``Backend.TORCH is
    # Backend.Torch``. Iteration, ``len(Backend)`` and ``Backend('torch')`` are unchanged,
    # and existing ``Backend.Torch`` references keep working.
    TORCH = 'torch'