import contextlib
import importlib.util
import inspect
import logging
import numpy as np
import os
import random
import re
import shutil
import sys
import tempfile
import unittest

from distutils.util import strtobool
from io import StringIO
from packaging import version
from pathlib import Path
from typing import Iterator, Union
from unittest import mock
from unittest.case import SkipTest

from deepspeed.accelerator import get_accelerator

try:
    import torch

    _torch_available = True
except ImportError:
    _torch_available = False

try:
    import datasets

    _datasets_available = True
except ImportError:
    _datasets_available = False

try:
    import tensorflow

    _tf_available = True
except ImportError:
    _tf_available = False


def is_tf_available():
    return _tf_available


def is_datasets_available():
    return _datasets_available


def is_torch_available():
    return _torch_available

def parse_flag_from_env(key, default=False):
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, so fall back to `default`
        _value = default
    else:
        # KEY is set, convert it to True or False
        try:
            _value = strtobool(value)
        except ValueError:
            # more values are supported, but let's keep the error message simple
            raise ValueError(f"If set, {key} must be yes or no.")
    return _value


def parse_int_from_env(key, default=None):
    try:
        value = os.environ[key]
    except KeyError:
        _value = default
    else:
        try:
            _value = int(value)
        except ValueError:
            raise ValueError(f"If set, {key} must be an int.")
    return _value
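# Example usage (illustrative sketch; the variable names and defaults are hypothetical):
#
#   _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
#   _num_test_workers = parse_int_from_env("NUM_TEST_WORKERS", default=2)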


def require_torch(test_case):
    """
    Decorator marking a test that requires PyTorch.

    These tests are skipped when PyTorch isn't installed.
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)
    else:
        return test_case
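# Example usage of the skip decorators (illustrative sketch; the test class and method are hypothetical):
#
#   class MyTest(TestCasePlus):
#       @require_torch
#       def test_something_needing_pytorch(self):
#           ...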


def require_torch_multi_accelerator(test_case):
    """
    Decorator marking a test that requires a multi-accelerator setup (in PyTorch). These tests are
    skipped on a machine without multiple accelerators.

    To run *only* the multi_accelerator tests, assuming all test names contain multi_accelerator:
    $ pytest -sv ./tests/ -k "multi_accelerator"
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    if get_accelerator().device_count() < 2:
        return unittest.skip("test requires multiple Accelerators")(test_case)
    else:
        return test_case


def require_torch_non_multi_accelerator(test_case):
    """
    Decorator marking a test that requires a setup with 0 or 1 accelerators (in PyTorch).
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    if get_accelerator().device_count() > 1:
        return unittest.skip("test requires 0 or 1 Accelerator")(test_case)
    else:
        return test_case


def require_torch_up_to_2_accelerators(test_case):
    """
    Decorator marking a test that requires a setup with 0, 1 or 2 accelerators (in PyTorch).
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    if get_accelerator().device_count() > 2:
        return unittest.skip("test requires 0 or 1 or 2 Accelerators")(test_case)
    else:
        return test_case


if is_torch_available():
    # the canonical device name for the current accelerator, e.g. "cuda", or "cpu" when none is available
    torch_device = get_accelerator().device_name()
else:
    torch_device = None


def require_torch_accelerator(test_case):
    """Decorator marking a test that requires an accelerator and PyTorch."""
    if torch_device == "cpu":
        return unittest.skip("test requires Accelerator")(test_case)
    else:
        return test_case


def require_datasets(test_case):
    """Decorator marking a test that requires the `datasets` library."""
    if not is_datasets_available():
        return unittest.skip("test requires `datasets`")(test_case)
    else:
        return test_case


def is_deepspeed_available():
    return importlib.util.find_spec("deepspeed") is not None


def require_deepspeed(test_case):
    """
    Decorator marking a test that requires deepspeed.
    """
    if not is_deepspeed_available():
        return unittest.skip("test requires deepspeed")(test_case)
    else:
        return test_case


def is_bnb_available():
    return importlib.util.find_spec("bitsandbytes") is not None


def require_bnb(test_case):
    """
    Decorator marking a test that requires bitsandbytes.
    """
    if not is_bnb_available():
        return unittest.skip("test requires bitsandbytes from https://github.com/facebookresearch/bitsandbytes")(test_case)
    else:
        return test_case


def require_bnb_non_decorator():
    """
    Non-decorator function that skips the current test if bitsandbytes is missing.
    """
    if not is_bnb_available():
        raise SkipTest("Test requires bitsandbytes from https://github.com/facebookresearch/bitsandbytes")
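# Example usage (illustrative sketch; the test method and its body are hypothetical):
#
#   def test_bnb_optimizer(self):
#       require_bnb_non_decorator()
#       # ... the rest of the test runs only when bitsandbytes is installed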


def set_seed(seed: int = 42):
    """
    Helper function for reproducible behavior that sets the seed in ``random``, ``numpy`` and ``torch``.

    Args:
        seed (:obj:`int`): The seed to set.
    """
    random.seed(seed)
    np.random.seed(seed)
    if is_torch_available():
        torch.manual_seed(seed)
        get_accelerator().manual_seed_all(seed)


def get_accelerator_count():
    """
    Return the number of available accelerators (regardless of whether torch or tf is used).
    """
    if is_torch_available():
        return get_accelerator().device_count()
    elif is_tf_available():
        import tensorflow as tf

        return len(tf.config.list_physical_devices("GPU"))
    else:
        return 0


def torch_assert_equal(actual, expected, **kwargs):
    # assert_close with rtol=atol=0 checks for exact equality (and, unlike allclose, also validates shapes/dtypes)
    if hasattr(torch.testing, "assert_close"):
        return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
    else:
        # fallback for older torch versions that lack torch.testing.assert_close
        assert torch.allclose(actual, expected, rtol=0.0, atol=0.0)


def torch_assert_close(actual, expected, **kwargs):
    # assert_close was added around torch-1.9; fall back to allclose on older versions
    if hasattr(torch.testing, "assert_close"):
        return torch.testing.assert_close(actual, expected, **kwargs)
    else:
        # torch.allclose doesn't accept a `msg` argument, so drop it before calling
        kwargs.pop("msg", None)
        assert torch.allclose(actual, expected, **kwargs)
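# Example usage (illustrative sketch; the tensors are hypothetical):
#
#   torch_assert_equal(output, expected)                         # exact element-wise match
#   torch_assert_close(output, expected, rtol=1e-5, atol=1e-8)   # approximate match with custom tolerances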


def is_torch_bf16_available():
    if is_torch_available():
        return get_accelerator().is_bf16_supported()
    else:
        return False


def require_torch_bf16(test_case):
    """Decorator marking a test that requires accelerator hardware supporting bf16 and PyTorch >= 1.9."""
    if not is_torch_bf16_available():
        return unittest.skip("test requires Accelerator hardware supporting bf16 and PyTorch >= 1.9")(test_case)
    else:
        return test_case


def get_tests_dir(append_path=None):
    """
    Args:
        append_path: optional path to append to the tests dir path

    Return:
        The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path`
        is joined after the `tests` dir if the former is provided.
    """
    # this function's caller's __file__
    caller__file__ = inspect.stack()[1][1]
    tests_dir = os.path.abspath(os.path.dirname(caller__file__))
    if append_path:
        return os.path.join(tests_dir, append_path)
    else:
        return tests_dir
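# Example usage (illustrative sketch; the sub-directory name is hypothetical):
#
#   data_dir = get_tests_dir("data")   # absolute path to <tests dir>/data, regardless of the cwd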


def apply_print_resets(buf):
    return re.sub(r"^.*\r", "", buf, 0, re.M)


def assert_screenout(out, what):
    out_pr = apply_print_resets(out).lower()
    match_str = out_pr.find(what.lower())
    assert match_str != -1, f"expecting to find {what} in output: {out_pr}"


class CaptureStd:
    """
    Context manager to capture:

    - stdout: replay it, clean it up and make it available via ``obj.out``
    - stderr: replay it and make it available via ``obj.err``

    Init arguments:

    - out - capture stdout: ``True``/``False``, default ``True``
    - err - capture stderr: ``True``/``False``, default ``True``
    - replay - whether to replay or not: ``True``/``False``, default ``True``. By default each
      captured stream gets replayed back on the context's exit, so that one can see what the test
      was doing. If this behavior is not desired and the captured data shouldn't be replayed, pass
      ``replay=False`` to disable this feature.

    Examples::

        # to capture stdout only with auto-replay
        with CaptureStdout() as cs:
            print("Secret message")
        assert "message" in cs.out

        # to capture stderr only with auto-replay
        import sys

        with CaptureStderr() as cs:
            print("Warning: ", file=sys.stderr)
        assert "Warning" in cs.err

        # to capture both streams with auto-replay
        with CaptureStd() as cs:
            print("Secret message")
            print("Warning: ", file=sys.stderr)
        assert "message" in cs.out
        assert "Warning" in cs.err

        # to capture just one of the streams, and not the other, with auto-replay
        with CaptureStd(err=False) as cs:
            print("Secret message")
        assert "message" in cs.out
        # but best use the stream-specific subclasses

        # to capture without auto-replay
        with CaptureStd(replay=False) as cs:
            print("Secret message")
        assert "message" in cs.out
    """

    def __init__(self, out=True, err=True, replay=True):
        self.replay = replay

        if out:
            self.out_buf = StringIO()
            self.out = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.out_buf = None
            self.out = "not capturing stdout"

        if err:
            self.err_buf = StringIO()
            self.err = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.err_buf = None
            self.err = "not capturing stderr"

    def __enter__(self):
        if self.out_buf:
            self.out_old = sys.stdout
            sys.stdout = self.out_buf

        if self.err_buf:
            self.err_old = sys.stderr
            sys.stderr = self.err_buf

        return self

    def __exit__(self, *exc):
        if self.out_buf:
            sys.stdout = self.out_old
            captured = self.out_buf.getvalue()
            if self.replay:
                sys.stdout.write(captured)
            self.out = apply_print_resets(captured)

        if self.err_buf:
            sys.stderr = self.err_old
            captured = self.err_buf.getvalue()
            if self.replay:
                sys.stderr.write(captured)
            self.err = captured

    def __repr__(self):
        msg = ""
        if self.out_buf:
            msg += f"stdout: {self.out}\n"
        if self.err_buf:
            msg += f"stderr: {self.err}\n"
        return msg


class CaptureStdout(CaptureStd):
    """Same as CaptureStd but captures only stdout."""

    def __init__(self, replay=True):
        super().__init__(err=False, replay=replay)


class CaptureStderr(CaptureStd):
    """Same as CaptureStd but captures only stderr."""

    def __init__(self, replay=True):
        super().__init__(out=False, replay=replay)


class CaptureLogger:
    """
    Context manager to capture `logging` streams.

    Args:

    - logger: ``logging`` logger object

    Results:
        The captured output is available via `self.out`

    Example::

        >>> from transformers import logging
        >>> from transformers.testing_utils import CaptureLogger

        >>> msg = "Testing 1, 2, 3"
        >>> logging.set_verbosity_info()
        >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
        >>> with CaptureLogger(logger) as cl:
        ...     logger.info(msg)
        >>> assert cl.out, msg + "\n"
    """

    def __init__(self, logger):
        self.logger = logger
        self.io = StringIO()
        self.sh = logging.StreamHandler(self.io)
        self.out = ""

    def __enter__(self):
        self.logger.addHandler(self.sh)
        return self

    def __exit__(self, *exc):
        self.logger.removeHandler(self.sh)
        self.out = self.io.getvalue()

    def __repr__(self):
        return f"captured: {self.out}\n"


@contextlib.contextmanager
def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
    """
    Temporarily add the given path to `sys.path`.

    Usage ::

        with ExtendSysPath('/path/to/dir'):
            mymodule = importlib.import_module('mymodule')
    """
    path = os.fspath(path)
    try:
        sys.path.insert(0, path)
        yield
    finally:
        sys.path.remove(path)


class TestCasePlus(unittest.TestCase):
    """
    This class extends `unittest.TestCase` with additional features.

    Feature 1: A set of fully resolved important file and dir path accessors.

    In tests it is often necessary to know where things are relative to the current test file, and it's not trivial
    since the test could be invoked from more than one directory or could reside in sub-directories with different
    depths. This class solves this problem by sorting out all the basic paths and provides easy accessors to them:

    * ``pathlib`` objects (all fully resolved):

       - ``test_file_path`` - the current test file path (=``__file__``)
       - ``test_file_dir`` - the directory containing the current test file
       - ``tests_dir`` - the directory of the ``tests`` test suite
       - ``data_dir`` - the directory of the ``tests/data`` test suite
       - ``repo_root_dir`` - the directory of the repository
       - ``src_dir`` - the directory of ``src`` (i.e. where the ``transformers`` sub-dir resides)

    * stringified paths - same as above but these return paths as strings, rather than ``pathlib`` objects:

       - ``test_file_path_str``
       - ``test_file_dir_str``
       - ``tests_dir_str``
       - ``data_dir_str``
       - ``repo_root_dir_str``
       - ``src_dir_str``

    Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of the test.

    1. Create a unique temporary dir:

    ::

        def test_whatever(self):
            tmp_dir = self.get_auto_remove_tmp_dir()

    ``tmp_dir`` will contain the path to the created temporary dir. It will be automatically removed at the end of
    the test.

    2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't empty it after the
    test.

    ::

        def test_whatever(self):
            tmp_dir = self.get_auto_remove_tmp_dir("./xxx")

    This is useful for debugging when you want to monitor a specific directory and want to make sure the previous
    tests didn't leave any data in there.

    3. You can override the first two options by directly overriding the ``before`` and ``after`` args, leading to
    the following behavior:

    ``before=True``: the temporary dir will always be cleared at the beginning of the test.

    ``before=False``: if the temporary dir already existed, any existing files will remain there.

    ``after=True``: the temporary dir will always be deleted at the end of the test.

    ``after=False``: the temporary dir will always be left intact at the end of the test.

    Note 1: In order to run the equivalent of ``rm -r`` safely, only subdirs of the project repository checkout are
    allowed if an explicit ``tmp_dir`` is used, so that by mistake no ``/tmp`` or similar important part of the
    filesystem will get nuked. i.e. please always pass paths that start with ``./``.

    Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
    otherwise.

    Feature 3: Get a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` specific to the current test
    suite. This is useful for invoking external programs from the test suite - e.g. distributed training.

    ::

        def test_whatever(self):
            env = self.get_env()
    """

    def setUp(self):
        # get_auto_remove_tmp_dir feature:
        self.teardown_tmp_dirs = []

        # figure out the resolved paths for repo_root, tests, etc.
        self._test_file_path = inspect.getfile(self.__class__)
        path = Path(self._test_file_path).resolve()
        self._test_file_dir = path.parents[0]
        # walk up the tree until we find the repo root (identified by its `megatron` and `tests` dirs)
        repo_root_dir = None
        for up in [1, 2, 3]:
            tmp_dir = path.parents[up]
            if (tmp_dir / "megatron").is_dir() and (tmp_dir / "tests").is_dir():
                repo_root_dir = tmp_dir
                break
        if repo_root_dir is not None:
            self._repo_root_dir = repo_root_dir
        else:
            raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
        self._tests_dir = self._repo_root_dir / "tests"
        self._data_dir = self._repo_root_dir / "tests" / "data"
        self._src_dir = self._repo_root_dir

    @property
    def test_file_path(self):
        return self._test_file_path

    @property
    def test_file_path_str(self):
        return str(self._test_file_path)

    @property
    def test_file_dir(self):
        return self._test_file_dir

    @property
    def test_file_dir_str(self):
        return str(self._test_file_dir)

    @property
    def tests_dir(self):
        return self._tests_dir

    @property
    def tests_dir_str(self):
        return str(self._tests_dir)

    @property
    def data_dir(self):
        return self._data_dir

    @property
    def data_dir_str(self):
        return str(self._data_dir)

    @property
    def repo_root_dir(self):
        return self._repo_root_dir

    @property
    def repo_root_dir_str(self):
        return str(self._repo_root_dir)

    @property
    def src_dir(self):
        return self._src_dir

    @property
    def src_dir_str(self):
        return str(self._src_dir)

    def get_env(self):
        """
        Return a copy of the ``os.environ`` object that sets up ``PYTHONPATH`` correctly. This is useful
        for invoking external programs from the test suite - e.g. distributed training.

        It always inserts ``.`` first, then ``./tests`` and finally the preset ``PYTHONPATH`` if any
        (all fully resolved paths).
        """
        env = os.environ.copy()
        paths = [self.src_dir_str]
        paths.append(self.tests_dir_str)
        paths.append(env.get("PYTHONPATH", ""))

        env["PYTHONPATH"] = ":".join(paths)
        return env
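    # Example usage (illustrative sketch; the script name and its arguments are hypothetical):
    #
    #   def test_distributed_whatever(self):
    #       cmd = [sys.executable, "-u", "some_launcher_script.py", "--arg", "value"]
    #       execute_subprocess_async(cmd, env=self.get_env())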

    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
        """
        Args:
            tmp_dir (:obj:`string`, `optional`):
                if :obj:`None`:

                   - a unique temporary path will be created
                   - sets ``before=True`` if ``before`` is :obj:`None`
                   - sets ``after=True`` if ``after`` is :obj:`None`
                else:

                   - :obj:`tmp_dir` will be created
                   - sets ``before=True`` if ``before`` is :obj:`None`
                   - sets ``after=False`` if ``after`` is :obj:`None`
            before (:obj:`bool`, `optional`):
                If :obj:`True` and the :obj:`tmp_dir` already exists, make sure to empty it right away; if
                :obj:`False` and the :obj:`tmp_dir` already exists, any existing files will remain there.
            after (:obj:`bool`, `optional`):
                If :obj:`True`, delete the :obj:`tmp_dir` at the end of the test; if :obj:`False`, leave the
                :obj:`tmp_dir` and its contents intact at the end of the test.

        Returns:
            tmp_dir(:obj:`string`): either the same value as passed via `tmp_dir` or the path to the auto-selected
            tmp dir
        """
        if tmp_dir is not None:
            # a custom path was provided - most likely a debug situation, so we want an easily
            # locatable dir that:
            # 1. gets cleared out before the test (if it already exists)
            # 2. is left intact after the test, unless `after=True` is passed
            if before is None:
                before = True
            if after is None:
                after = False

            # use the provided path
            path = Path(tmp_dir).resolve()

            # to avoid nuking parts of the filesystem, only relative paths are allowed
            if not tmp_dir.startswith("./"):
                raise ValueError(
                    f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
                )

            # ensure the dir is empty to start with
            if before is True and path.exists():
                shutil.rmtree(tmp_dir, ignore_errors=True)

            path.mkdir(parents=True, exist_ok=True)

        else:
            # a unique tmp dir is auto-generated; it:
            # 1. is empty before the test (it will be empty in this situation anyway)
            # 2. gets fully removed after the test, unless `after=False` is passed
            if before is None:
                before = True
            if after is None:
                after = True

            # use a unique tmp dir (always empty, regardless of `before`)
            tmp_dir = tempfile.mkdtemp()

        if after is True:
            # register for deletion
            self.teardown_tmp_dirs.append(tmp_dir)

        return tmp_dir

    def tearDown(self):
        # remove the registered temp dirs
        for path in self.teardown_tmp_dirs:
            shutil.rmtree(path, ignore_errors=True)
        self.teardown_tmp_dirs = []


def mockenv(**kwargs):
    """
    A convenience wrapper that allows the following ::

        @mockenv(RUN_SLOW=True, USE_TF=False)
        def test_something():
            run_slow = os.getenv("RUN_SLOW", False)
            use_tf = os.getenv("USE_TF", False)
    """
    return mock.patch.dict(os.environ, kwargs)


@contextlib.contextmanager
def mockenv_context(*remove, **update):
    """
    Temporarily updates the ``os.environ`` dictionary in-place. Similar to mockenv.

    The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations.

    Args:
        remove: Environment variables to remove.
        update: Dictionary of environment variables and values to add/update.
    """
    env = os.environ
    update = update or {}
    remove = remove or []

    # list of environment variables being updated or removed
    stomped = (set(update.keys()) | set(remove)) & set(env.keys())
    # environment variables and values to restore on exit
    update_after = {k: env[k] for k in stomped}
    # environment variables to remove on exit
    remove_after = frozenset(k for k in update if k not in env)

    try:
        env.update(update)
        [env.pop(k, None) for k in remove]
        yield
    finally:
        env.update(update_after)
        [env.pop(k) for k in remove_after]
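# Example usage (illustrative sketch; the variable names and values are hypothetical):
#
#   with mockenv_context("SOME_VAR_TO_UNSET", MASTER_ADDR="localhost", MASTER_PORT="9994"):
#       ...  # the new values are visible here; the original environment is restored on exit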


import asyncio


class _RunOutput:
    def __init__(self, returncode, stdout, stderr):
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr


async def _read_stream(stream, callback):
    while True:
        line = await stream.readline()
        if line:
            callback(line)
        else:
            break


async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    if echo:
        print("\nRunning: ", " ".join(cmd))

    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        line = line.decode("utf-8").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    # stream stdout/stderr while the subprocess runs, collecting the lines and optionally echoing them;
    # the coroutines are wrapped in tasks since newer asyncio versions no longer accept bare coroutines in `wait`
    await asyncio.wait(
        [
            asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))),
            asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))),
        ],
        timeout=timeout,
    )
    return _RunOutput(await p.wait(), out, err)


def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
    """
    Run `cmd` as a subprocess, streaming its output, and return a `_RunOutput` object with the return code and
    the captured stdout/stderr lines. Raises `RuntimeError` if the command fails or produces no output at all.
    """
    loop = asyncio.get_event_loop()
    result = loop.run_until_complete(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )

    cmd_str = " ".join(cmd)
    if result.returncode > 0:
        stderr = "\n".join(result.stderr)
        raise RuntimeError(
            f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
            f"The combined stderr from workers follows:\n{stderr}"
        )

    # check that the subprocess actually did run and produced some output, should the test rely on it
    if not result.stdout and not result.stderr:
        raise RuntimeError(f"'{cmd_str}' produced no output.")

    return result
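# Example usage (illustrative sketch; the script name and assertion are hypothetical):
#
#   cmd = [sys.executable, "some_script.py", "--arg", "value"]
#   result = execute_subprocess_async(cmd, env=self.get_env())
#   assert "expected output" in "\n".join(result.stdout)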


def flatten_arguments(args):
    """
    Converts a dictionary of arguments to a flat list.

    Note: we add "IGNORED" at the beginning as this value is ignored by the argparser.

    Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"]
    """
    return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]