Spaces:
Build error
Build error
| import argparse | |
| import json | |
| import os | |
| import tarfile | |
| import tempfile | |
| from typing import Dict, List | |
| from loguru import logger | |
| from tqdm import tqdm | |
# Command-line interface: which annotation file to read, where the downloaded
# images live, how many instances go in one shard, and where shards are saved.
parser = argparse.ArgumentParser(
    description=(
        "Pre-process RedCaps dataset for training VirTex models - make\n"
        "small shards of TAR files containing images and captions."
    ),
)

# Input: a single RedCaps annotation (JSON) file.
parser.add_argument(
    "-a",
    "--annotations",
    required=True,
    help="Path to a RedCaps annotation file.",
)

# Input: root directory of downloaded images.
parser.add_argument(
    "-i",
    "--images",
    default="datasets/redcaps/images",
    help=(
        "Path to RedCaps image directory. This directory is expected to have\n"
        "subreddit specific sub-directories containing images."
    ),
)

# Sharding granularity.
parser.add_argument(
    "-z",
    "--shard-size",
    type=int,
    default=1000,
    help="Maximum number of RedCaps instances in a single TAR file shard.",
)

# Output: common path prefix of all shard files.
parser.add_argument(
    "-o",
    "--output-prefix",
    required=True,
    help=(
        "Path prefix for saving TAR file shards. For example, `/tmp/tarfiles` "
        "will save as `/tmp/tarfiles_000000.tar`, `/tmp/tarfiles_000001.tar`, ..."
    ),
)
def main(_A: argparse.Namespace):
    r"""
    Make TAR files containing images and annotations from a single RedCaps
    annotations file. These TAR files are arranged in a way that
    `WebDataset <https://github.com/tmbdev/webdataset>`_ can understand.

    For every annotation whose image exists on disk, two members are added to
    the current shard: ``<image_id>.jpg`` (the image) and ``<image_id>.json``
    (a JSON object with the ``subreddit`` and ``caption``). Annotations whose
    image file is missing are counted and skipped.

    Shards are named ``<output_prefix>_<index>.tar`` with the index zero-padded
    to six digits, and hold at most ``shard_size`` instances each. A shard file
    is only created once there is an instance to put in it, so no empty shards
    are written.

    Args:
        _A: Parsed command-line arguments. The attributes read here are
            ``annotations``, ``images``, ``shard_size`` and ``output_prefix``.

    Raises:
        ValueError: If ``_A.shard_size`` is not a positive integer.
    """
    # Fail early with a clear message instead of a ZeroDivisionError in the
    # modulo check further down.
    if _A.shard_size <= 0:
        raise ValueError(
            f"--shard-size must be a positive integer, got {_A.shard_size}."
        )

    # Use a context manager so the annotations file is closed after parsing.
    with open(_A.annotations) as annotations_file:
        ANNOTATIONS: List[Dict] = json.load(annotations_file)["annotations"]

    # Keep track of the current index of TAR file shard and dataset index.
    SHARD_INDEX: int = 0
    DATASET_INDEX: int = 0

    # Keep a count of submissions that were skipped because their image was
    # not downloaded (not present in image dir).
    SKIPPED: int = 0

    # Handle of the shard currently being filled; opened lazily so that an
    # empty shard is never created.
    tar_handle = None

    try:
        for ann in tqdm(ANNOTATIONS):
            image_path = os.path.join(
                _A.images, ann["subreddit"], f"{ann['image_id']}.jpg"
            )

            # Only images that exist on disk are added to a shard.
            if not os.path.exists(image_path):
                SKIPPED += 1
                continue

            if tar_handle is None:
                # Same six-digit zero padding as the names reported in the logs.
                tar_handle = tarfile.open(
                    f"{_A.output_prefix}_{SHARD_INDEX:0>6d}.tar", "w"
                )

            tar_handle.add(image_path, arcname=f"{ann['image_id']}.jpg")

            # Save subreddit name and caption as a JSON file.
            subreddit_and_caption = {
                "subreddit": ann["subreddit"], "caption": ann["caption"]
            }
            with tempfile.NamedTemporaryFile("w+") as tmpfile:
                tmpfile.write(json.dumps(subreddit_and_caption))
                # Flush so the content is on disk before tarfile reads the
                # file by name.
                tmpfile.flush()
                tar_handle.add(tmpfile.name, arcname=f"{ann['image_id']}.json")

            DATASET_INDEX += 1

            # Finish the current shard once it is full; the next shard is
            # opened only when another instance needs to be written.
            if DATASET_INDEX % _A.shard_size == 0:
                tar_handle.close()
                tar_handle = None
                logger.success(
                    f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar"
                )
                SHARD_INDEX += 1

        # Close the last, partially filled shard to properly save it.
        if tar_handle is not None:
            tar_handle.close()
            tar_handle = None
            logger.success(
                f"Saved shard: {_A.output_prefix}_{SHARD_INDEX:0>6d}.tar\n"
            )
    finally:
        # Error path: do not leak the open TAR handle.
        if tar_handle is not None:
            tar_handle.close()

    logger.info(f"Skipped {SKIPPED} annotations due to missing images.")
# Script entry point: parse the command-line arguments and build the shards.
if __name__ == "__main__":
    _A = parser.parse_args()
    main(_A)