import json from pathlib import Path import torch from tqdm import tqdm DATASET_PATH = Path("your_dataset_path") if __name__ == "__main__": # "train" or "test" for stage in ["test"]: stage = DATASET_PATH / stage index = {} for chunk_path in tqdm( sorted(list(stage.iterdir())), desc=f"Indexing {stage.name}" ): if chunk_path.suffix == ".torch": chunk = torch.load(chunk_path) for example in chunk: index[example["key"]] = str(chunk_path.relative_to(stage)) with (stage / "index.json").open("w") as f: json.dump(index, f)