"""Benchmark data-loader construction and batch latency across worker counts."""

import dataclasses
import time

from openpi.training import config as c
from openpi.training import data_loader as dl

# Training config benchmarked by default.
CONFIG_NAME = "pi05_kinova_teddybear"
# Worker counts swept, in order.
WORKER_COUNTS = (16, 32, 48)
# Batches pulled per run; also passed to the loader as ``num_batches`` so the
# two can never drift apart.
NUM_BATCHES = 6


def main(
    worker_counts: tuple[int, ...] = WORKER_COUNTS,
    num_batches: int = NUM_BATCHES,
) -> None:
    """Time loader creation and batch fetching for each worker count.

    For every entry in ``worker_counts`` a config derived from
    ``CONFIG_NAME`` is built with that ``num_workers``, a loader is created,
    and ``num_batches`` batches are pulled from it. One dict per worker count
    is printed with:

    - ``loader_s``: wall time of ``dl.create_data_loader``.
    - ``first_batch_s``: time from just after ``iter(loader)`` to the first
      batch (presumably dominated by worker start-up; not verified here).
    - ``avg_batch_s``: mean time per batch over all ``num_batches`` batches,
      first batch included.
    - ``steady_batch_s``: mean time per batch excluding the first batch
      (equals ``first_batch_s`` when ``num_batches == 1``).

    Args:
        worker_counts: Worker counts to benchmark, in order.
        num_batches: Batches fetched per run; must be >= 1.

    Raises:
        ValueError: If ``num_batches`` is less than 1.
    """
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")

    # The base config does not depend on the worker count; fetch it once.
    base_cfg = c.get_config(CONFIG_NAME)

    for workers in worker_counts:
        cfg = dataclasses.replace(
            base_cfg,
            exp_name=f"bench_{workers}",
            num_workers=workers,
        )

        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        loader = dl.create_data_loader(
            cfg,
            shuffle=True,
            num_batches=num_batches,
            # NOTE(review): presumably avoids needing precomputed norm stats
            # for a pure throughput benchmark — confirm in data_loader.
            skip_norm_stats=True,
        )
        loader_s = time.perf_counter() - t0

        it = iter(loader)
        t1 = time.perf_counter()  # cost of iter() itself is excluded
        first_batch_s = 0.0
        for i in range(num_batches):
            next(it)
            if i == 0:
                first_batch_s = time.perf_counter() - t1
        total_s = time.perf_counter() - t1

        if num_batches > 1:
            steady_batch_s = (total_s - first_batch_s) / (num_batches - 1)
        else:
            steady_batch_s = first_batch_s

        print(
            {
                "workers": workers,
                "loader_s": round(loader_s, 3),
                "first_batch_s": round(first_batch_s, 3),
                "avg_batch_s": round(total_s / num_batches, 3),
                "steady_batch_s": round(steady_batch_s, 3),
            },
            flush=True,
        )

        # Release this run's iterator/loader before the next config starts.
        # Otherwise they stay alive until rebound, which happens only after
        # the next loader is created and iterated. Any worker processes they
        # hold would then overlap the next measurement.
        del it, loader


if __name__ == "__main__":
    main()