File size: 1,217 Bytes
3236af9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import random
import sys
import types
import unittest
from pathlib import Path

import numpy as np
import torch

# Put the ``src`` directory that sits next to this file's parent directory at
# the front of the import path, so the ``iconoclast`` package imported below
# resolves from the source tree.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))


def _noop(*args, **kwargs):
    """Accept any arguments, do nothing, and return ``None``."""
    return None


def _stub_module(name, **attributes):
    """Build a placeholder module called *name* carrying *attributes*.

    The placeholder is registered in ``sys.modules`` only when no module of
    that name has been imported yet. Note that this check looks at
    ``sys.modules`` only: a real package that is installed but not yet
    imported is still shadowed by the stub.

    Returns the placeholder itself (not necessarily the ``sys.modules`` entry).
    """
    stub = types.ModuleType(name)
    for attr_name, value in attributes.items():
        setattr(stub, attr_name, value)
    sys.modules.setdefault(name, stub)
    return stub


# Minimal stand-ins for third-party packages, presumably imported somewhere
# under ``iconoclast`` at load time (TODO confirm). Only the attribute names
# are provided; prompt functions return ``None`` and classes are empty.
questionary = _stub_module(
    "questionary",
    Choice=type("Choice", (), {}),
    Style=_noop,
    select=_noop,
    text=_noop,
    path=_noop,
    password=_noop,
)
optuna = _stub_module("optuna", Trial=type("Trial", (), {}))

from iconoclast.utils import set_random_seed


class UtilsTests(unittest.TestCase):
    """Tests for ``set_random_seed`` from ``iconoclast.utils``."""

    @staticmethod
    def _draw_from_all_rngs():
        """Return one sample from each RNG that ``set_random_seed`` should seed.

        The tuple holds, in order: a draw from Python's ``random`` module,
        from NumPy's global RNG (``np.random.rand``), and from PyTorch's
        default generator (``torch.rand``).
        """
        return (
            random.random(),
            np.random.rand(),
            torch.rand(1).item(),
        )

    def test_set_random_seed_is_reproducible(self):
        """Reseeding with the same value replays the same draws."""
        set_random_seed(1234)
        first = self._draw_from_all_rngs()

        set_random_seed(1234)
        second = self._draw_from_all_rngs()

        self.assertEqual(first, second)

    def test_set_random_seed_depends_on_seed(self):
        """Different seeds lead to different draws from every RNG.

        Reproducibility alone would also pass for an implementation that
        ignored its argument (e.g. always seeding with a fixed constant), so
        each RNG is checked individually here.
        """
        set_random_seed(1234)
        first = self._draw_from_all_rngs()

        set_random_seed(4321)
        second = self._draw_from_all_rngs()

        for rng_name, a, b in zip(("random", "numpy", "torch"), first, second):
            with self.subTest(rng=rng_name):
                self.assertNotEqual(a, b)


if __name__ == "__main__":
    # Allow running this module directly with ``python <this file>``;
    # ``module="__main__"`` is unittest.main's default, spelled out here.
    unittest.main(module="__main__")