|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random |
|
|
|
|
|
import numpy as np |
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from lerobot.common.utils.random_utils import ( |
|
|
deserialize_numpy_rng_state, |
|
|
deserialize_python_rng_state, |
|
|
deserialize_rng_state, |
|
|
deserialize_torch_rng_state, |
|
|
get_rng_state, |
|
|
seeded_context, |
|
|
serialize_numpy_rng_state, |
|
|
serialize_python_rng_state, |
|
|
serialize_rng_state, |
|
|
serialize_torch_rng_state, |
|
|
set_rng_state, |
|
|
set_seed, |
|
|
) |
|
|
|
|
|
|
|
|
@pytest.fixture
def fixed_seed():
    """Seed every RNG with one fixed value before the requesting test runs.

    The fixture yields control to the test and performs no teardown.
    """
    initial_seed = 12345
    set_seed(initial_seed)
    yield
|
|
|
|
|
|
|
|
def test_serialize_deserialize_python_rng(fixed_seed):
    """Restoring a serialized Python `random` state replays the next draw."""
    # Advance once so the snapshot is not taken exactly at the seed point.
    random.random()
    snapshot = serialize_python_rng_state()

    expected = random.random()

    deserialize_python_rng_state(snapshot)
    replayed = random.random()
    assert replayed == expected
|
|
|
|
|
|
|
|
def test_serialize_deserialize_numpy_rng(fixed_seed):
    """Restoring a serialized NumPy global RNG state replays the next draw."""
    # Advance once so the snapshot is not taken exactly at the seed point.
    np.random.rand()
    snapshot = serialize_numpy_rng_state()

    expected = np.random.rand()

    deserialize_numpy_rng_state(snapshot)
    replayed = np.random.rand()
    assert replayed == expected
|
|
|
|
|
|
|
|
def test_serialize_deserialize_torch_rng(fixed_seed):
    """Restoring a serialized torch RNG state replays the next draw."""
    # Advance once so the snapshot is not taken exactly at the seed point.
    torch.rand(1).item()
    snapshot = serialize_torch_rng_state()

    expected = torch.rand(1).item()

    deserialize_torch_rng_state(snapshot)
    replayed = torch.rand(1).item()
    assert replayed == expected
|
|
|
|
|
|
|
|
def test_serialize_deserialize_rng(fixed_seed):
    """The combined serialize/deserialize pair restores all three RNGs together."""

    def draw_all():
        # Draw order (Python, NumPy, torch) matches the assertion order below.
        return random.random(), np.random.rand(), torch.rand(1).item()

    # Advance every generator once before taking the snapshot.
    draw_all()
    snapshot = serialize_rng_state()

    expected_py, expected_np, expected_th = draw_all()

    deserialize_rng_state(snapshot)
    assert random.random() == expected_py
    assert np.random.rand() == expected_np
    assert torch.rand(1).item() == expected_th
|
|
|
|
|
|
|
|
def test_get_set_rng_state(fixed_seed):
    """`set_rng_state` rewinds all RNGs to a state captured by `get_rng_state`."""

    def draw_all():
        return random.random(), np.random.rand(), torch.rand(1).item()

    captured = get_rng_state()
    first_pass = draw_all()

    # Move every generator further along before rewinding.
    random.random()
    np.random.rand()
    torch.rand(1)

    set_rng_state(captured)
    second_pass = draw_all()
    assert first_pass == second_pass
|
|
|
|
|
|
|
|
def test_set_seed():
    """Re-seeding with the same value reproduces identical draws from all RNGs."""
    samples = []
    for _ in range(2):
        set_seed(1337)
        samples.append((random.random(), np.random.rand(), torch.rand(1).item()))

    first, second = samples
    assert first == second
|
|
|
|
|
|
|
|
def test_seeded_context(fixed_seed):
    """Draws inside `seeded_context(1337)` are reproducible and differ from outside draws.

    Note: this test does not itself check that the outer RNG state is restored
    once the context exits; it only compares values drawn inside and outside.
    """

    def sample():
        return random.random(), np.random.rand(), torch.rand(1).item()

    outside_first = sample()
    with seeded_context(1337):
        inside_first = sample()

    outside_second = sample()
    with seeded_context(1337):
        inside_second = sample()

    # Same seed inside both contexts -> same values.
    assert inside_first == inside_second
    # Every component drawn inside the context differs from its outside counterpart.
    assert all(x != y for x, y in zip(outside_first, inside_first, strict=True))
    assert all(x != y for x, y in zip(outside_second, inside_second, strict=True))
|
|
|