File size: 1,217 Bytes
import random
import sys
import types
import unittest
from pathlib import Path
import numpy as np
import torch
# Put the repository's ``src`` directory (a sibling of this file's parent
# directory) at the front of the import path.
sys.path.insert(
    0,
    str(Path(__file__).resolve().parents[1].joinpath("src")),
)

# Stand-in for the ``questionary`` package: ``Choice`` is an empty class and
# every other attribute is a callable that accepts anything and returns None.
questionary = types.ModuleType("questionary")
questionary.Choice = type("Choice", (), {})
for _attr in ("Style", "select", "text", "path", "password"):
    setattr(questionary, _attr, lambda *args, **kwargs: None)
del _attr
# ``setdefault`` leaves an already-registered ``questionary`` module in place.
sys.modules.setdefault("questionary", questionary)

# Stand-in for the ``optuna`` package, exposing only an empty ``Trial`` class.
optuna = types.ModuleType("optuna")
optuna.Trial = type("Trial", (), {})
# As above, an already-registered ``optuna`` module is not replaced.
sys.modules.setdefault("optuna", optuna)
from iconoclast.utils import set_random_seed
class UtilsTests(unittest.TestCase):
    """Tests for helpers imported from ``iconoclast.utils``."""

    def test_set_random_seed_is_reproducible(self):
        """Seeding twice with the same value must yield identical draws."""

        def draw_after_seeding(seed):
            # Seed first, then take one sample from each generator in a
            # fixed order: Python's ``random``, NumPy, then PyTorch.
            set_random_seed(seed)
            python_value = random.random()
            numpy_value = np.random.rand()
            torch_value = torch.rand(1).item()
            return (python_value, numpy_value, torch_value)

        first = draw_after_seeding(1234)
        second = draw_after_seeding(1234)

        # Exact equality is intended: the same seed must replay the same
        # sequence from all three generators.
        self.assertEqual(first, second)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|