| | import os |
| |
|
| | |
# Opt-in check, enabled by setting TEST_ENFORCE_NUMPY_FLOAT32: wraps np.array,
# np.zeros and np.ones so that a float64 result created from mlagents code raises.
if os.getenv("TEST_ENFORCE_NUMPY_FLOAT32"):
    # Clear the flag immediately: if this module is executed again in the same
    # process the guard is then false, so the wrappers are not wrapped a second time.
    del os.environ["TEST_ENFORCE_NUMPY_FLOAT32"]
    import numpy as np
    import traceback

    # Keep the unpatched constructors; the wrappers below delegate to them.
    __old_np_array = np.array
    __old_np_zeros = np.zeros
    __old_np_ones = np.ones

    def _check_no_float64(arr, kwargs_dtype):
        """Raise if ``arr`` is float64 and the wrapper was called from mlagents code.

        :param arr: Array returned by one of the wrapped numpy constructors.
        :param kwargs_dtype: dtype the caller requested (or None). Only used to
            build the error message.
        :raises ValueError: if ``arr.dtype`` is float64 and the file of the frame
            that called the wrapper contains ``ml-agents/mlagents`` or
            ``ml-agents-envs/mlagents``.
        """
        if arr.dtype != np.float64:
            return
        # Only the three innermost frames matter, so do not build the whole stack:
        #   tb[-1] is this function,
        #   tb[-2] is the wrapper, e.g. np_array_no_float64,
        #   tb[-3] is the code that called the wrapper.
        tb = traceback.extract_stack(limit=3)
        if len(tb) < 3:
            # No wrapper-plus-caller pair above this frame; nothing to attribute.
            return
        # Normalise Windows separators so the substring tests also match there.
        filename = tb[-3].filename.replace("\\", "/")
        # Only raise when the array was created by mlagents code.
        if (
            "ml-agents/mlagents" in filename
            or "ml-agents-envs/mlagents" in filename
        ):
            raise ValueError(
                f"float64 array created. Set dtype=np.float32 instead of current dtype={kwargs_dtype}. "
                "Run pytest with TEST_ENFORCE_NUMPY_FLOAT32=1 to confirm fix."
            )

    def _requested_dtype(args, kwargs):
        """Return the dtype the caller passed, by keyword or as second positional.

        np.array, np.zeros and np.ones all take dtype as their second parameter.
        Returns None when no dtype was given.
        """
        if "dtype" in kwargs:
            return kwargs["dtype"]
        return args[1] if len(args) > 1 else None

    def np_array_no_float64(*args, **kwargs):
        """Call the original np.array, then reject float64 results from mlagents."""
        res = __old_np_array(*args, **kwargs)
        # Must be called directly from here: the check reads a fixed stack depth.
        _check_no_float64(res, _requested_dtype(args, kwargs))
        return res

    def np_zeros_no_float64(*args, **kwargs):
        """Call the original np.zeros, then reject float64 results from mlagents."""
        res = __old_np_zeros(*args, **kwargs)
        # Must be called directly from here: the check reads a fixed stack depth.
        _check_no_float64(res, _requested_dtype(args, kwargs))
        return res

    def np_ones_no_float64(*args, **kwargs):
        """Call the original np.ones, then reject float64 results from mlagents."""
        res = __old_np_ones(*args, **kwargs)
        # Must be called directly from here: the check reads a fixed stack depth.
        _check_no_float64(res, _requested_dtype(args, kwargs))
        return res

    # Install the checking wrappers in place of the numpy constructors.
    np.array = np_array_no_float64
    np.zeros = np_zeros_no_float64
    np.ones = np_ones_no_float64
| |
|
| |
|
# Opt-in switch: when TEST_ENFORCE_BUFFER_KEY_TYPES is set to a non-empty value,
# flip AgentBuffer.CHECK_KEY_TYPES_AT_RUNTIME on for the whole test session.
# NOTE(review): presumably this makes AgentBuffer validate its key types on
# access - confirm against mlagents.trainers.buffer.
if os.environ.get("TEST_ENFORCE_BUFFER_KEY_TYPES"):
    # Imported lazily so mlagents.trainers.buffer is only loaded when requested.
    from mlagents.trainers.buffer import AgentBuffer

    AgentBuffer.CHECK_KEY_TYPES_AT_RUNTIME = True
| |
|