Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import pickle | |
| import pytest | |
| from omegaconf import OmegaConf | |
| from bytelatent.args import TrainArgs | |
| from bytelatent.constants import BLT_DATA | |
def get_test_config():
    """Return the path to the internal ``tests.yaml`` config file.

    The directory containing the config is taken from the ``BLT_INTERNAL``
    environment variable when set; otherwise it falls back to the relative
    ``../internal-blt/configs`` checkout location.
    """
    config_dir = os.environ.get("BLT_INTERNAL", "../internal-blt/configs")
    return os.path.join(config_dir, "tests.yaml")
def test_first_batch_matches():
    """Regression test: the first training batch must match a pickled reference.

    The reference fixture was created by pickling the first batch inside the
    train loop and exiting, so any change to data loading (ordering, masking,
    patching) surfaces as a mismatch here.
    """
    test_config_path = get_test_config()
    # Layer the test YAML over the TrainArgs defaults, then validate the
    # merged dict back into a typed TrainArgs.
    default_cfg = OmegaConf.create(TrainArgs().model_dump())
    file_cfg = OmegaConf.load(test_config_path)
    merged_cfg = OmegaConf.merge(default_cfg, file_cfg)
    merged_cfg = OmegaConf.to_container(
        merged_cfg, resolve=True, throw_on_missing=True
    )
    train_args = TrainArgs.model_validate(merged_cfg)
    # MP doesn't work with async very well, but it doesn't change logic
    train_args.data.load_async = False
    # Test data created by pickling first batch in train loop then exiting
    with open(os.path.join(BLT_DATA, "fixtures", "first_batch_0.pickle"), "rb") as f:
        first_batch = pickle.load(f)
    # Emulate 1 node, 8 gpu training
    data_loader = train_args.data.build_from_rank(0, 8)
    batch_iterator = data_loader.create_iter()
    print("Getting first batch")
    batch = next(batch_iterator)
    assert (batch.x == first_batch.x).all()
    assert (batch.y == first_batch.y).all()
    assert (batch.mask == first_batch.mask).all()
    assert (batch.patch_lengths == first_batch.patch_lengths).all()
    assert batch.ngram_ids is None and first_batch.ngram_ids is None
    # Fix: the original asserted `batch.is_final == False` twice and never
    # checked the reference batch; compare both, and avoid `== False`.
    assert not batch.is_final and not first_batch.is_final