|
|
from enum import Enum |
|
|
|
|
|
|
|
|
class ExplicitEnum(str, Enum):
    """Base for string-valued enums with a friendlier lookup-failure message.

    Because ``str`` is mixed in, members are ``str`` instances and compare
    equal to their raw values. ``str(member)`` returns the raw value instead of
    the default ``Enum`` rendering. Looking up an unknown value raises a
    ``ValueError`` that lists every accepted value.
    """

    def __str__(self) -> str:
        # Render the member as its underlying value (e.g. "longest").
        raw = self.value
        return str(raw)

    @classmethod
    def _missing_(cls, value):
        # ``Enum`` calls this hook when ``cls(value)`` matches no member.
        # Raising a ValueError here surfaces our message instead of the stock
        # one, and additionally enumerates all valid choices.
        choices = list(cls._value2member_map_.keys())
        message = f"{value} is not a valid {cls.__name__}, please select one of {choices}"
        raise ValueError(message)
|
|
|
|
|
|
|
|
class PaddingStrategy(ExplicitEnum):
    """Accepted values of the ``padding`` argument of [`EventTokenizer.__call__`].

    Declared as an enum mainly so that IDEs can tab-complete the choices.
    """

    # Member order is significant: it is the order in which values are listed
    # in lookup error messages and when iterating the enum.
    LONGEST = 'longest'
    MAX_LENGTH = 'max_length'
    DO_NOT_PAD = 'do_not_pad'
|
|
|
|
|
|
|
|
class TensorType(ExplicitEnum):
    """Accepted values of the ``return_tensors`` argument of [`EventTokenizerBase.__call__`].

    Declared as an enum mainly so that IDEs can tab-complete the choices.
    """

    # Short codes, keyed by framework-named members.
    PYTORCH = 'pt'
    NUMPY = 'np'
|
|
|
|
|
|
|
|
class RunnerPhase(ExplicitEnum):
    """Phase identifiers for the model runner."""

    TRAIN = "train"
    VALIDATE = "validate"
    PREDICT = "predict"
|
|
|
|
|
|
|
|
class LossFunction(ExplicitEnum):
    """Loss-function identifiers for neural TPP models."""

    LOGLIKE = "loglike"
    # NOTE(review): the member names below mention "loss" while their values
    # are metric-style names ('rmse', 'accuracy'). The values are presumably
    # what configurations pass in, so keep them stable if renaming anything.
    PARTIAL_TIME_LOSS = "rmse"
    PARTIAL_EVENT_LOSS = "accuracy"
|
|
|
|
|
|
|
|
class LogConst:
    """``logging`` format strings for log handlers."""

    # Compact layout: timestamp, level name and message.
    DEFAULT_FORMAT = "[%(asctime)s] [%(levelname)s] %(message)s"

    # Verbose layout: also records source file, process id, line number and
    # function name of the logging call.
    DEFAULT_FORMAT_LONG = (
        "%(asctime)s - %(filename)s[pid:%(process)d;line:%(lineno)d:%(funcName)s]"
        " - %(levelname)s: %(message)s"
    )
|
|
|
|
|
|
|
|
class PredOutputIndex:
    """Tuple positions of the prediction outputs returned by ModelRunner."""

    # Position 0: the time prediction (per the attribute name).
    TimePredIndex = 0
    # Position 1: the type prediction (per the attribute name).
    TypePredIndex = 1
|
|
|
|
|
|
|
|
class DefaultRunnerConfig:
    """Default values for model-runner configuration."""

    # Dataset identifier, presumably used when a config omits one — confirm at call sites.
    DEFAULT_DATASET_ID = "conttime"
|
|
|
|
|
|
|
|
class TruncationStrategy(ExplicitEnum):
    """Accepted values of the ``truncation`` argument of [`EventTokenizer.__call__`].

    Declared as an enum mainly so that IDEs can tab-complete the choices.
    """

    LONGEST_FIRST = 'longest_first'
    DO_NOT_TRUNCATE = 'do_not_truncate'
|
|
|
|
|
|
|
|
class Backend(ExplicitEnum):
    """Accepted values of the ``backend`` configuration option."""

    # Mixed-case member name is part of the public API; keep as-is.
    Torch = "torch"
|
|
|