Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os | |
| import time | |
| from io import StringIO | |
| from unittest.mock import patch | |
| import mmcv | |
def reset_string_io(io):
    """Empty the in-memory text stream ``io`` so it can be written again.

    The position is rewound to the beginning and everything from there on is
    discarded, leaving ``io.getvalue()`` equal to ``''``.
    """
    io.seek(0)
    # With no argument, truncate() cuts at the current position, which is 0.
    io.truncate()
class TestProgressBar:
    """Check the exact text ``mmcv.ProgressBar`` writes to a ``StringIO``."""

    def test_start(self):
        """The initial line is written on creation, or on ``start()``."""
        out = StringIO()
        bar_width = 20
        # without total task num
        prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
        assert out.getvalue() == 'completed: 0, elapsed: 0s'
        reset_string_io(out)
        prog_bar = mmcv.ProgressBar(bar_width=bar_width, start=False, file=out)
        assert out.getvalue() == ''
        reset_string_io(out)
        prog_bar.start()
        assert out.getvalue() == 'completed: 0, elapsed: 0s'
        # with total task num
        reset_string_io(out)
        prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
        assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'
        reset_string_io(out)
        prog_bar = mmcv.ProgressBar(
            10, bar_width=bar_width, start=False, file=out)
        assert out.getvalue() == ''
        reset_string_io(out)
        prog_bar.start()
        assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:'

    def test_update(self):
        """One ``update()`` after ~1s reports rate, elapsed time and ETA."""
        out = StringIO()
        bar_width = 20
        # without total task num
        prog_bar = mmcv.ProgressBar(bar_width=bar_width, file=out)
        time.sleep(1)
        reset_string_io(out)
        prog_bar.update()
        assert out.getvalue() == 'completed: 1, elapsed: 1s, 1.0 tasks/s'
        reset_string_io(out)
        # with total task num
        prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
        time.sleep(1)
        reset_string_io(out)
        prog_bar.update()
        # The ETA value is padded (apparently to a width of 5), hence the run
        # of spaces before "9s". That gives a 66-character line, matching the
        # length asserted for the same update in test_adaptive_length.
        assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \
                                 'task/s, elapsed: 1s, ETA:     9s'

    def test_adaptive_length(self):
        """The rendered line length follows the COLUMNS terminal width."""
        with patch.dict('os.environ', {'COLUMNS': '80'}):
            out = StringIO()
            bar_width = 20
            prog_bar = mmcv.ProgressBar(10, bar_width=bar_width, file=out)
            time.sleep(1)
            reset_string_io(out)
            prog_bar.update()
            assert len(out.getvalue()) == 66
            # Direct edits below are undone when patch.dict exits.
            os.environ['COLUMNS'] = '30'
            reset_string_io(out)
            prog_bar.update()
            assert len(out.getvalue()) == 48
            os.environ['COLUMNS'] = '60'
            reset_string_io(out)
            prog_bar.update()
            assert len(out.getvalue()) == 60
def sleep_1s(num):
    """Return ``num`` unchanged after pausing for one second.

    Serves as the per-task function for the progress-tracking tests, so each
    task takes about one second of wall-clock time.
    """
    pause_seconds = 1
    time.sleep(pause_seconds)
    return num
def test_track_progress_list():
    """track_progress over a list: full transcript for a 3-wide bar."""
    out = StringIO()
    ret = mmcv.track_progress(sleep_1s, [1, 2, 3], bar_width=3, file=out)
    # The bar is always bar_width (3) characters between the brackets, and
    # the ETA value is padded, so the spaces below are significant.
    assert out.getvalue() == (
        '[   ] 0/3, elapsed: 0s, ETA:'
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s'
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s'
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n')
    assert ret == [1, 2, 3]
def test_track_progress_iterator():
    """track_progress over an (iterator, length) pair gives the same output."""
    out = StringIO()
    # An iterator has no len(), so the task count is passed alongside it.
    ret = mmcv.track_progress(
        sleep_1s, (iter([1, 2, 3]), 3), bar_width=3, file=out)
    # The bar is always bar_width (3) characters between the brackets, and
    # the ETA value is padded, so the spaces below are significant.
    assert out.getvalue() == (
        '[   ] 0/3, elapsed: 0s, ETA:'
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s'
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s'
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n')
    assert ret == [1, 2, 3]
def test_track_iter_progress():
    """track_iter_progress yields items and draws the bar as they are consumed."""
    out = StringIO()
    ret = []
    for num in mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out):
        ret.append(sleep_1s(num))
    # The bar is always bar_width (3) characters between the brackets, and
    # the ETA value is padded, so the spaces below are significant.
    assert out.getvalue() == (
        '[   ] 0/3, elapsed: 0s, ETA:'
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s'
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s'
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n')
    assert ret == [1, 2, 3]
def test_track_enum_progress():
    """track_iter_progress still works when wrapped in enumerate()."""
    out = StringIO()
    ret = []
    count = []
    for i, num in enumerate(
            mmcv.track_iter_progress([1, 2, 3], bar_width=3, file=out)):
        ret.append(sleep_1s(num))
        count.append(i)
    # The bar is always bar_width (3) characters between the brackets, and
    # the ETA value is padded, so the spaces below are significant.
    assert out.getvalue() == (
        '[   ] 0/3, elapsed: 0s, ETA:'
        '\r[>  ] 1/3, 1.0 task/s, elapsed: 1s, ETA:     2s'
        '\r[>> ] 2/3, 1.0 task/s, elapsed: 2s, ETA:     1s'
        '\r[>>>] 3/3, 1.0 task/s, elapsed: 3s, ETA:     0s\n')
    assert ret == [1, 2, 3]
    assert count == [0, 1, 2]
def test_track_parallel_progress_list():
    """track_parallel_progress over a list with 2 worker processes."""
    out = StringIO()
    results = mmcv.track_parallel_progress(
        sleep_1s, [1, 2, 3, 4], 2, bar_width=4, file=out)
    # The intermediate updates depend on how the workers are scheduled, so
    # the full transcript is not stable on CI (GitHub Actions). Only the
    # timing-independent parts are checked: the initial empty 4-wide bar and
    # the trailing newline.
    output = out.getvalue()
    assert output.startswith('[    ] 0/4, elapsed: 0s, ETA:')
    assert output.endswith('\n')
    assert results == [1, 2, 3, 4]
def test_track_parallel_progress_iterator():
    """track_parallel_progress over an (iterator, length) pair."""
    out = StringIO()
    # An iterator has no len(), so the task count is passed alongside it.
    results = mmcv.track_parallel_progress(
        sleep_1s, (iter([1, 2, 3, 4]), 4), 2, bar_width=4, file=out)
    # The intermediate updates depend on how the workers are scheduled, so
    # the full transcript is not stable on CI (GitHub Actions). Only the
    # timing-independent parts are checked: the initial empty 4-wide bar and
    # the trailing newline.
    output = out.getvalue()
    assert output.startswith('[    ] 0/4, elapsed: 0s, ETA:')
    assert output.endswith('\n')
    assert results == [1, 2, 3, 4]