| import pytest |
|
|
| from open_clip_train.data import get_dataset_size |
|
|
@pytest.mark.parametrize(
    "shards,expected_size",
    [
        # Single path, or a brace range that covers exactly one value.
        ("/path/to/shard.tar", 1),
        ("/path/to/shard_{000..000}.tar", 1),
        # Brace ranges: one range of ten values, then two ranges combined.
        ("/path/to/shard_{000..009}.tar", 10),
        ("/path/to/shard_{000..009}_{000..009}.tar", 100),
        # Specs joined with "::" are expected to add their counts together.
        ("/path/to/shard.tar::/path/to/other_shard_{000..009}.tar", 11),
        ("/path/to/shard_{000..009}.tar::/path/to/other_shard_{000..009}.tar", 20),
        # Shards given as an explicit list of paths.
        (["/path/to/shard.tar"], 1),
        (["/path/to/shard.tar", "/path/to/other_shard.tar"], 2),
    ],
)
def test_num_shards(shards, expected_size):
    """Check the count that ``get_dataset_size`` reports for a shard spec.

    ``get_dataset_size`` is unpacked as a pair: the first element is ignored
    here and the second is compared against ``expected_size``.

    Args:
        shards: Shard specification under test, either a string (optionally
            using ``{a..b}`` ranges and ``::`` separators) or a list of paths.
        expected_size: Value the second element of the result must equal.
    """
    _ignored, size = get_dataset_size(shards)
    assert size == expected_size, (
        f"Expected {expected_size} for {shards} but found {size} instead."
    )
|
|