import operator
import random

import numpy as np
import torch

from .braceexpand import braceexpand
from .context import autocast_exclude_mps
from .file import get_latest_checkpoint
from .instantiators import instantiate_callbacks, instantiate_loggers
from .logger import RankedLogger
from .logging_utils import log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .utils import extras, get_metric_value, task_wrapper
def set_seed(seed: int) -> int:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch.

    The seed is normalised before use:

    - negative values are replaced by their absolute value;
    - values above ``1 << 31`` are clamped to ``1 << 31``, which keeps the
      value inside the 32-bit unsigned range accepted by
      ``numpy.random.seed``.

    When CUDA is available, every visible GPU generator is seeded as well.
    When cuDNN is available, it is switched to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed. Any object implementing ``__index__`` (for
            example a NumPy integer) is accepted.

    Returns:
        The normalised seed that was actually applied.

    Raises:
        TypeError: If ``seed`` is not an integer (e.g. a float). It is
            raised before any generator has been touched, so the RNGs are
            never left partially reseeded.
    """
    # Convert/validate first: previously a float seed was accepted by
    # random.seed and only rejected later by np.random.seed.
    seed = abs(operator.index(seed))
    seed = min(seed, 1 << 31)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seeds every visible GPU; this subsumes torch.cuda.manual_seed,
        # which only seeds the current device.
        torch.cuda.manual_seed_all(seed)
    if torch.backends.cudnn.is_available():
        # Trade cuDNN autotuning speed for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    return seed
# Public API of the utils package; keep in sync with the names defined or
# re-exported above so that ``from <package> import *`` exposes them all.
__all__ = [
    "enforce_tags",
    "extras",
    "get_metric_value",
    "RankedLogger",
    "instantiate_callbacks",
    "instantiate_loggers",
    "log_hyperparameters",
    "print_config_tree",
    "task_wrapper",
    "braceexpand",
    "get_latest_checkpoint",
    "autocast_exclude_mps",
    "set_seed",
]